In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are re-imported
# before each cell runs.
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv

# Load environment variables from the .env file one directory above the notebook;
# later cells read MLFLOW_ENABLED and MLFLOW_TRACKING_URI via os.getenv.
load_dotenv("../.env")

In [None]:
import os
import random

import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import torch
from prisma import Prisma
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

from such_toxic.t5 import T5Toxic

In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU. The chained
# conditional keeps the original short-circuit order: MPS is only probed when CUDA
# is unavailable.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device

In [None]:
# --- Model and optimisation settings ---
model_name = "t5-small"
max_input_length = 128  # passed to the tokenizer as max_length
num_classes = 6  # one output per key in `thresholds`
batch_size = 16
epochs = 1
lr = 1e-4

# --- Data settings ---
# Fraction of the comment ids to keep before splitting (1.0 keeps all of them).
dataset_sample_percent = 1.0

# Per-label cut-offs: a score at or above the cut-off is treated as a positive.
# Every label uses 0.8 except identity_hate, which uses 0.6.
thresholds = {
    **dict.fromkeys(("toxic", "severe_toxic", "obscene", "threat", "insult"), 0.8),
    "identity_hate": 0.6,
}

In [None]:
# Experiment tracking is opt-in: it only runs when MLFLOW_ENABLED is exactly "true".
if os.getenv("MLFLOW_ENABLED") == "true":
    print("Starting run")
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
    mlflow.set_experiment("t5-toxic")
    mlflow.pytorch.autolog()
    mlflow.start_run()

    # Record the hyperparameters of this run, one log_param call per entry.
    run_params = {
        "model_name": model_name,
        "max_input_length": max_input_length,
        "num_classes": num_classes,
        "batch_size": batch_size,
        "epochs": epochs,
        "learning_rate": lr,
    }
    for param_name, param_value in run_params.items():
        mlflow.log_param(param_name, param_value)

In [None]:
# Connect to the comments database through Prisma. The HTTP timeout is set to None
# (presumably to avoid timing out on the large id query -- confirm against the
# Prisma client documentation).
db = Prisma(
    http={
        "timeout": None,
    },
)
db.connect()

# Fixed seed so the same rows land in train/validation/test on every run. Without
# it, re-running the notebook (e.g. to evaluate a saved checkpoint) reshuffles the
# partition and rows the model was trained on can leak into the test split.
split_seed = 42

# One entry per comment; each entry is later passed unchanged as the `where` filter
# of `db.comments.find_first`. ORDER BY pins the row order so the seeded sampling
# below is actually reproducible.
dataset_ids = db.query_raw(query="select id from comments order by id")

# Optionally down-sample the dataset (dataset_sample_percent == 1.0 keeps every row,
# in shuffled order).
sample_size = int(dataset_sample_percent * len(dataset_ids))
dataset_ids = random.Random(split_seed).sample(dataset_ids, sample_size)

# 70% train / 5% validation / 25% test, split with a seeded generator.
train_ids, validation_ids, test_ids = random_split(
    dataset_ids,
    [0.7, 0.05, 0.25],
    generator=torch.Generator().manual_seed(split_seed),
)

In [None]:
# Tokenizer matching `model_name`; shared by ToxicDataset.__getitem__ and run_inference.
tokenizer = AutoTokenizer.from_pretrained(model_name)


class ToxicDataset(Dataset):
    """Dataset that fetches and tokenizes one comment per lookup.

    `ids` is a sequence of Prisma `where` filters (the rows returned by the id
    query). Nothing is loaded up front: every `__getitem__` call issues one
    `db.comments.find_first` query and tokenizes the comment text on the fly.
    """

    def __init__(self, ids):
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return `(input_ids, attention_mask, labels)` for the comment at `idx`.

        `input_ids` and `attention_mask` are the tokenizer's PyTorch tensors, padded
        or truncated to `max_input_length`. `labels` is a float32 tensor with one
        0/1 entry per toxicity label, set to 1 when the stored score is at or above
        that label's threshold.
        """
        row_filter = self.ids[idx]
        row = db.comments.find_first(where=row_filter)

        encoding = tokenizer(
            row.comment_text,
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            return_length=True,
        )

        # Label order is fixed and must match the order used when thresholding
        # predictions elsewhere in the notebook.
        label_names = (
            "toxic",
            "severe_toxic",
            "obscene",
            "threat",
            "insult",
            "identity_hate",
        )
        labels = [int(getattr(row, name) >= thresholds[name]) for name in label_names]

        return (
            encoding["input_ids"],
            encoding["attention_mask"],
            torch.tensor(labels, dtype=torch.float32),
        )

In [None]:
# Wrap each split in a ToxicDataset and a DataLoader. Only the training loader is
# shuffled: shuffling the validation/test loaders adds nothing (their metrics are
# averaged over all batches) and makes evaluation order non-deterministic.
train_dataset = ToxicDataset(list(train_ids))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

validation_dataset = ToxicDataset(list(validation_ids))
validation_dataloader = DataLoader(
    validation_dataset, batch_size=batch_size, shuffle=False
)

test_dataset = ToxicDataset(list(test_ids))
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: print the tensor shapes of the first training batch only.
# The loop variables deliberately avoid the name `input`, which would shadow the
# builtin for every later cell in the notebook.
for sample_input, sample_mask, sample_labels in train_dataloader:
    print("Input: ", sample_input.shape)
    print("Mask: ", sample_mask.shape)
    print("Output: ", sample_labels.shape)
    break

In [None]:
# Instantiate T5Toxic from `model_name` / `num_classes`, move it to the selected
# device, and create an AdamW optimizer over its parameters with learning rate `lr`.
model = T5Toxic(model_name, num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [None]:
def run_prediction(input, attention_mask, labels):
    """Run the module-level `model` on one batch.

    Dimension 1 of `input` and `attention_mask` is squeezed out before the forward
    pass, and all three tensors are moved to `device`.

    Returns:
        A `(probabilities, loss)` tuple: the sigmoid of the model's logits and the
        loss reported by the model for the given `labels`.
    """
    batch = {
        "input_ids": input.squeeze(1).to(device),
        "attention_mask": attention_mask.squeeze(1).to(device),
        "labels": labels.to(device),
    }
    outputs = model(**batch)

    return torch.sigmoid(outputs.logits), outputs.loss


def train(model, dataloader, optimizer, epoch=0):
    """Train for one pass over `dataloader` and return the mean batch loss.

    Args:
        model: Module switched to training mode. The forward pass itself goes
            through `run_prediction`, which uses the module-level `model`.
        dataloader: Yields `(input_ids, attention_mask, labels)` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index, used only to offset the MLflow step so that
            per-batch losses from successive epochs do not overwrite each other.
            Defaults to 0, which reproduces the previous step numbering.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.train()
    num_batches = len(dataloader)
    # Read the flag once instead of on every batch.
    mlflow_enabled = os.getenv("MLFLOW_ENABLED") == "true"
    total_loss = 0.0
    train_pbar = tqdm(
        enumerate(dataloader), desc="Batch", leave=False, total=num_batches
    )
    for i, (input_ids, attn_mask, labels) in train_pbar:
        optimizer.zero_grad()
        _, loss = run_prediction(input_ids, attn_mask, labels)

        # Convert to a Python float once: every .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        if mlflow_enabled:
            mlflow.log_metric(
                "batch_training_loss", batch_loss, step=epoch * num_batches + i
            )
        train_pbar.set_postfix({"Loss": batch_loss})

        loss.backward()
        optimizer.step()

    return total_loss / num_batches


def validate(model, dataloader):
    """Compute the mean batch loss over `dataloader` without updating the model.

    Args:
        model: Module switched to evaluation mode. The forward pass itself goes
            through `run_prediction`, which uses the module-level `model`.
        dataloader: Yields `(input_ids, attention_mask, labels)` batches.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.eval()
    total_loss = 0.0
    validation_pbar = tqdm(dataloader, desc="Batch", leave=False)
    # No gradients are needed for validation; disabling autograd avoids building
    # (and holding in memory) a graph for every batch.
    with torch.no_grad():
        for input_ids, attn_mask, labels in validation_pbar:
            _, loss = run_prediction(input_ids, attn_mask, labels)
            batch_loss = loss.item()
            validation_pbar.set_postfix({"Loss": batch_loss})
            total_loss += batch_loss

    return total_loss / len(dataloader)


def plot_losses(training_losses, validation_losses):
    """Plot per-epoch training and validation losses and save the figure to disk.

    Args:
        training_losses: Sequence of training losses, one per epoch.
        validation_losses: Sequence of validation losses, one per epoch.

    The figure is written to ``../target/training_loss.png`` (overwriting any
    previous file) and then closed, so nothing is displayed inline.
    """
    output_path = "../target/training_loss.png"
    # savefig fails if the directory is missing, e.g. on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    fig, ax = plt.subplots()
    ax.plot(training_losses, label="Training Loss")
    ax.plot(validation_losses, label="Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training/Validation Loss")
    ax.legend()
    fig.savefig(output_path)
    # Close explicitly: this runs once per epoch and open figures accumulate.
    plt.close(fig)


def save_model(model, training_loss, validation_loss, epoch):
    """Log epoch metrics to MLflow (when enabled) and write a checkpoint to disk.

    Args:
        model: Model whose `state_dict` is saved (and logged to MLflow).
        training_loss: Training loss recorded for `epoch`.
        validation_loss: Validation loss recorded for `epoch`.
        epoch: Epoch index, used as the MLflow metric step and stored in the
            checkpoint.

    The checkpoint is written to ``../target/{model_name}-toxic.pt`` and also
    contains the state of the module-level `optimizer`.
    """
    if os.getenv("MLFLOW_ENABLED") == "true":
        mlflow.log_metric("training_loss", training_loss, step=epoch)
        mlflow.log_metric("validation_loss", validation_loss, step=epoch)
        mlflow.pytorch.log_model(model, f"{model_name}-toxic")

    checkpoint_path = f"../target/{model_name}-toxic.pt"
    # torch.save does not create missing directories; make sure the target exists
    # so a full epoch of training is not lost to a FileNotFoundError.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "training_loss": training_loss,
            "validation_loss": validation_loss,
        },
        checkpoint_path,
    )

In [None]:
# Per-epoch loss history, replotted and checkpointed after every epoch.
training_losses, validation_losses = [], []

pbar = tqdm(range(epochs), desc="Epoch", position=0)
for epoch in pbar:
    # One optimisation pass over the training data, then a validation pass.
    training_loss = train(model, train_dataloader, optimizer)
    validation_loss = validate(model, validation_dataloader)

    training_losses.append(training_loss)
    validation_losses.append(validation_loss)

    pbar.set_postfix(
        {"training_loss": training_loss, "validation_loss": validation_loss}
    )

    # Refresh the loss plot and overwrite the checkpoint with this epoch's state.
    plot_losses(training_losses, validation_losses)
    save_model(model, training_loss, validation_loss, epoch)

In [None]:
# Save the checkpoint again using the variables left over from the last iteration of
# the training loop. This cell therefore only works after the training cell has run
# at least one epoch (otherwise training_loss/validation_loss/epoch are undefined).
save_model(model, training_loss, validation_loss, epoch)

In [None]:
# Switch the model to evaluation mode before computing the test metrics below.
model.eval()


def accuracy(y_true, y_pred):
    """Return the fraction of positions where `y_true` and `y_pred` agree.

    Args:
        y_true: Sequence of ground-truth labels.
        y_pred: Sequence of predicted labels, aligned with `y_true`.

    Returns:
        Number of matching pairs divided by `len(y_true)`, or 0.0 when `y_true`
        is empty (instead of raising ZeroDivisionError).
    """
    if len(y_true) == 0:
        return 0.0
    matches = sum(1 for true, pred in zip(y_true, y_pred) if true == pred)
    return matches / len(y_true)


def true_positives(y_true, y_pred):
    """Return the share of actual positives that were also predicted positive.

    Despite the name, this is a rate (recall), not a count: the number of pairs with
    `true == 1 and pred == 1` divided by the number of pairs with `true == 1`.
    Returns 0 when `y_true` contains no positives.
    """
    # Predictions for the samples whose ground truth is positive.
    preds_for_positives = [pred for true, pred in zip(y_true, y_pred) if true == 1]
    if not preds_for_positives:
        return 0

    hits = sum(1 for pred in preds_for_positives if pred == 1)
    return hits / len(preds_for_positives)


def predict(input_ids, attn_mask, raw=False):
    """Predict the six toxicity labels for a single tokenized comment.

    Args:
        input_ids: Token ids for one comment, as produced by the tokenizer.
        attn_mask: Matching attention mask.
        raw: When true, return the per-label probabilities instead of 0/1 decisions.

    Returns:
        With `raw=True`, the first row of the probability array. Otherwise a list of
        six ints (0 or 1), one per label in the order toxic, severe_toxic, obscene,
        threat, insult, identity_hate, each set to 1 when the probability is at or
        above that label's entry in `thresholds`.
    """
    # run_prediction requires labels; all-zero placeholders are passed and the
    # resulting loss is discarded.
    placeholder_labels = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.float32)

    with torch.no_grad():
        probabilities, _ = run_prediction(
            input_ids,
            attn_mask,
            labels=placeholder_labels,
        )
        scores = probabilities.cpu().numpy()[0]

    if raw:
        return scores

    label_order = (
        "toxic",
        "severe_toxic",
        "obscene",
        "threat",
        "insult",
        "identity_hate",
    )
    return [
        1 if scores[position] >= thresholds[name] else 0
        for position, name in enumerate(label_order)
    ]

In [None]:
# Order of the labels inside each label/prediction vector.
label_names = (
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
)

# Evaluate on at most this many test samples.
max_eval_samples = 1000
eval_count = min(max_eval_samples, len(test_dataset))

# Index the dataset directly. Building `list(test_dataset)` first would fetch and
# tokenize every row of the test split (one DB query each) just to keep the first
# `max_eval_samples` of them.
y_true = []
y_pred = []
for sample_idx in tqdm(range(eval_count)):
    ids, mask, labels = test_dataset[sample_idx]
    y_true.append(labels)
    y_pred.append(predict(ids, mask))

# Per-label accuracy and true-positive rate, keyed as "<label>_accuracy" and
# "<label>_true_positives".
accuracy_metrics = {}
for label_idx, label_name in enumerate(label_names):
    label_true = [sample[label_idx] for sample in y_true]
    label_pred = [sample[label_idx] for sample in y_pred]
    accuracy_metrics[f"{label_name}_accuracy"] = accuracy(label_true, label_pred)
    accuracy_metrics[f"{label_name}_true_positives"] = true_positives(
        label_true, label_pred
    )

if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.log_metrics(accuracy_metrics)

print(accuracy_metrics)

In [None]:
# Close the Prisma connection opened when the comment ids were loaded.
db.disconnect()

# End the MLflow run started in the setup cell (only when tracking is enabled).
if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.end_run()

In [40]:
model.eval()


def run_inference(text):
    """Tokenize `text`, run the model on it, and print one probability per label.

    Each probability is rounded to four decimal places. Nothing is returned.
    """
    encoding = tokenizer(
        text,
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_length=True,
    )
    # run_prediction requires labels; all-zero placeholders are passed and the
    # resulting loss is ignored.
    placeholder_labels = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.float32)

    with torch.no_grad():
        probabilities, _ = run_prediction(
            encoding["input_ids"],
            encoding["attention_mask"],
            labels=placeholder_labels,
        )

        # Print prefixes in the same order as the model's output columns.
        prefixes = (
            "Toxic: ",
            "Severe Toxic: ",
            "Obscene: ",
            "Threat: ",
            "Insult: ",
            "Identity Hate: ",
        )
        for position, prefix in enumerate(prefixes):
            print(prefix, round(probabilities[0][position].item(), 4))


run_inference("Ok twinkle toes that sounds fine to me")

Toxic:  0.0001
Severe Toxic:  0.0
Obscene:  0.0
Threat:  0.0
Insult:  0.0
Identity Hate:  0.0
