In [1]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv

load_dotenv("../.env")

True

In [2]:
import os

import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
import numpy as np
import torch
from prisma import Prisma
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from transformers import AutoTokenizer, T5ForSequenceClassification

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

device(type='cuda')

In [4]:
np.set_printoptions(suppress=True)

model_name = "t5-small"
max_input_length = 128
num_classes = 6
batch_size = 16
epochs = 10
lr = 0.001

In [5]:
if os.getenv("MLFLOW_ENABLED") == "true":
    print("Starting run")
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
    mlflow.set_experiment("t5-toxic")
    mlflow.pytorch.autolog()
    mlflow.start_run()

    mlflow.log_param("model_name", model_name)
    mlflow.log_param("max_input_length", max_input_length)
    mlflow.log_param("num_classes", num_classes)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("learning_rate", lr)

In [6]:
model = T5ForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at t5-small and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
db = Prisma(
    http={
        "timeout": None,
    },
)
db.connect()

dataset_ids = db.query_raw(query="select id from comments")
train_ids, validation_ids, test_ids = random_split(dataset_ids, [0.7, 0.05, 0.25])

In [8]:
thresholds = {
    "toxic": 0.7,
    "severe_toxic": 0.7,
    "obscene": 0.7,
    "threat": 0.7,
    "insult": 0.7,
    "identity_hate": 0.6,
}

In [9]:
class ToxicDataset(Dataset):
    def __init__(self, ids):
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        id = self.ids[idx]
        row = db.comments.find_first(where=id)
        tokens = tokenizer(
            row.comment_text,
            return_tensors="pt",
            truncation=True,
            return_length=True,
            max_length=max_input_length,
            padding="max_length",
        )

        labels = [
            1 if row.toxic >= thresholds["toxic"] else 0,
            1 if row.severe_toxic >= thresholds["severe_toxic"] else 0,
            1 if row.obscene >= thresholds["obscene"] else 0,
            1 if row.threat >= thresholds["threat"] else 0,
            1 if row.insult >= thresholds["insult"] else 0,
            1 if row.identity_hate >= thresholds["identity_hate"] else 0,
        ]

        return (
            tokens["input_ids"],
            tokens["attention_mask"],
            torch.tensor(labels, dtype=torch.float32),
        )


train_dataset = ToxicDataset(list(train_ids))
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)

validation_dataset = ToxicDataset(list(validation_ids))
validation_dataloader = DataLoader(validation_dataset, batch_size, shuffle=True)

test_dataset = ToxicDataset(list(test_ids))
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)

for i, (input, mask, output) in enumerate(train_dataloader):
    print("Input: ", input.shape)
    print("Mask: ", mask.shape)
    print("Output: ", output.shape)
    break

Input:  torch.Size([16, 1, 128])
Mask:  torch.Size([16, 1, 128])
Output:  torch.Size([16, 6])


In [10]:
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [11]:
training_losses = []
validation_losses = []


def run_prediction(input, attn_mask, output):
    input = input.squeeze(1).to(device)
    attn_mask = attn_mask.squeeze(1).to(device)
    predictions = model(input_ids=input, attention_mask=attn_mask)
    probabilities = torch.softmax(predictions.logits, dim=-1)

    loss = 0
    if output is not None:
        output = output.to(device)
        loss = loss_fn(probabilities, output)

    return probabilities, loss


total_training_loss = 0
total_validation_loss = 0
pbar = tqdm(range(epochs), desc="Epoch", position=0)
for epoch in pbar:
    model.train()

    for input, attn_mask, output in tqdm(
        train_dataloader,
        desc="Training Batch",
        position=1,
        leave=False,
    ):
        optimizer.zero_grad()
        _, loss = run_prediction(input, attn_mask, output)
        total_training_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_training_loss = total_training_loss / len(train_dataloader)
    training_losses.append(avg_training_loss)

    model.eval()
    for input, attn_mask, output in tqdm(
        validation_dataloader,
        desc="Validation Batch",
        position=1,
        leave=False,
    ):
        _, loss = run_prediction(input, attn_mask, output)
        total_validation_loss += loss.item()

    avg_validation_loss = total_validation_loss / len(validation_dataloader)
    validation_losses.append(avg_validation_loss)
    pbar.set_postfix(
        training_loss=avg_training_loss,
        validation_loss=avg_validation_loss,
    )

    plt.plot(training_losses, label="Training Loss")
    plt.plot(validation_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training/Validation Loss")
    plt.legend()
    plt.savefig("../target/training_loss.png")
    plt.close()

    if os.getenv("MLFLOW_ENABLED") == "true":
        mlflow.log_metric("training_loss", avg_training_loss, step=epoch)
        mlflow.pytorch.log_model(model, "t5-small-toxic")

    model.save_pretrained("../target/t5-small-toxicity")
    tokenizer.save_pretrained("../target/t5-small-toxicity")

Epoch: 100%|██████████| 10/10 [00:27<00:00,  2.73s/it, training_loss=1.91, validation_loss=1.84] 


In [12]:
model.eval()


def accuracy(y_true, y_pred):
    labels = 0
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            labels += 1
    return labels / len(y_true)


def true_positives(y_true, y_pred):
    total_p = 0
    tp = 0
    for true, pred in zip(y_true, y_pred):
        if true == 1:
            total_p += 1
        if true == 1 and pred == 1:
            tp += 1
    if total_p == 0:
        return total_p
    return tp / total_p


def predict(input_ids, attn_mask, raw=False):
    # Forward pass
    with torch.no_grad():
        predictions, _ = run_prediction(input_ids, attn_mask, output=None)
        predictions = predictions.cpu().numpy()
        predictions = predictions[0]
    if raw:
        return predictions

    return [
        1 if predictions[0] >= thresholds["toxic"] else 0,
        1 if predictions[1] >= thresholds["severe_toxic"] else 0,
        1 if predictions[2] >= thresholds["obscene"] else 0,
        1 if predictions[3] >= thresholds["threat"] else 0,
        1 if predictions[4] >= thresholds["insult"] else 0,
        1 if predictions[5] >= thresholds["identity_hate"] else 0,
    ]

In [13]:
y_true = []
y_pred = []
for ids, mask, labels in list(test_dataset):
    y_true.append(labels)
    y_pred.append(predict(ids, mask))

toxic_y_true = [x[0] for x in y_true]
toxic_y_pred = [x[0] for x in y_pred]

severe_toxic_y_true = [x[1] for x in y_true]
severe_toxic_y_pred = [x[1] for x in y_pred]

obscene_y_true = [x[2] for x in y_true]
obscene_y_pred = [x[2] for x in y_pred]

threat_y_true = [x[3] for x in y_true]
threat_y_pred = [x[3] for x in y_pred]

insult_y_true = [x[4] for x in y_true]
insult_y_pred = [x[4] for x in y_pred]

identity_hate_y_true = [x[5] for x in y_true]
identity_hate_y_pred = [x[5] for x in y_pred]

accuracy_metrics = {
    "toxic_accuracy": accuracy(
        toxic_y_true,
        toxic_y_pred,
    ),
    "toxic_true_positives": true_positives(
        toxic_y_true,
        toxic_y_pred,
    ),
    "severe_toxic_accuracy": accuracy(
        severe_toxic_y_true,
        severe_toxic_y_pred,
    ),
    "severe_toxic_true_positives": true_positives(
        severe_toxic_y_true,
        severe_toxic_y_pred,
    ),
    "obscene_accuracy": accuracy(
        obscene_y_true,
        obscene_y_pred,
    ),
    "obscene_true_positives": true_positives(
        obscene_y_true,
        obscene_y_pred,
    ),
    "threat_accuracy": accuracy(
        threat_y_true,
        threat_y_pred,
    ),
    "threat_true_positives": true_positives(
        threat_y_true,
        threat_y_pred,
    ),
    "insult_accuracy": accuracy(
        insult_y_true,
        insult_y_pred,
    ),
    "insult_true_positives": true_positives(
        insult_y_true,
        insult_y_pred,
    ),
    "identity_hate_accuracy": accuracy(
        identity_hate_y_true,
        identity_hate_y_pred,
    ),
    "identity_hate_true_positives": true_positives(
        identity_hate_y_true,
        identity_hate_y_pred,
    ),
}
if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.log_metrics(accuracy_metrics)

print(accuracy_metrics)

{'toxic_accuracy': 1.0, 'toxic_true_positives': 0, 'severe_toxic_accuracy': 1.0, 'severe_toxic_true_positives': 0, 'obscene_accuracy': 1.0, 'obscene_true_positives': 0, 'threat_accuracy': 1.0, 'threat_true_positives': 0, 'insult_accuracy': 1.0, 'insult_true_positives': 0, 'identity_hate_accuracy': 1.0, 'identity_hate_true_positives': 0}


In [14]:
db.disconnect()

if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.end_run()

In [15]:
def run_inference(text):
    tokens = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        return_length=True,
        max_length=max_input_length,
        padding="max_length",
    )
    with torch.no_grad():
        predictions = run_prediction(
            tokens["input_ids"],
            tokens["attention_mask"],
            output=None,
        )
        predictions = predictions[0]
        return predictions.cpu().numpy().tolist()[0]


run_inference("Hello World")

[0.17869669198989868,
 0.14727136492729187,
 0.2174721211194992,
 0.1398240476846695,
 0.17705754935741425,
 0.1396782398223877]