In [None]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv

load_dotenv("../.env")

In [None]:
import json
import os

import matplotlib.pyplot as plt
import mlflow
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

from such_toxic.text_classifier import TextClassifier

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device

In [None]:
# --- Data / model identifiers ------------------------------------------------
embedded_dataset_path = "../target/wiki_data_all_human.json"
sentence_transformer_model = "sentence-transformers/all-MiniLM-L6-v2"
experiment_name = "such-toxic"
model_name = "such_toxic"

# --- Training hyper-parameters -----------------------------------------------
max_length = 512  # tokenizer padding/truncation length, in tokens
num_classes = 6  # one output per label column in the dataset rows
batch_size = 32
epochs = 15
lr = 1e-3  # learning rate

In [None]:
# Start an MLflow run (only when MLFLOW_ENABLED=true) and record the
# hyper-parameters from the config cell.
if os.getenv("MLFLOW_ENABLED") == "true":
    print("Starting run")
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
    # Use the configured name rather than repeating the literal, so changing
    # `experiment_name` in the config cell actually takes effect.
    mlflow.set_experiment(experiment_name)
    mlflow.pytorch.autolog()
    mlflow.start_run()

    mlflow.log_params(
        {
            "model_name": model_name,
            "num_classes": num_classes,
            "batch_size": batch_size,
            "epochs": epochs,
            "learning_rate": lr,
            "max_length": max_length,
            "sentence_transformer_model": sentence_transformer_model,
        }
    )

In [None]:
# JSON is UTF-8 by specification; be explicit so loading does not depend on
# the platform's default locale encoding (e.g. cp1252 on Windows).
with open(embedded_dataset_path, encoding="utf-8") as f:
    # presumably a list of row dicts consumed by SuchToxicDataset below — verify
    dataset = json.load(f)

if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.log_artifact(embedded_dataset_path)

In [None]:
# Load the tokenizer from the configured checkpoint instead of repeating the
# model id. It presumably must match the encoder inside TextClassifier — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(sentence_transformer_model)


class SuchToxicDataset(Dataset):
    """Map-style dataset over comment rows that tokenizes each row on access.

    Each row must be a mapping with a ``"comment_text"`` string plus one
    numeric entry per name in ``_LABEL_COLUMNS``. Tokenization uses the
    notebook-level ``tokenizer`` and ``max_length`` globals.

    ``__getitem__`` returns ``(input_ids, attention_mask, labels)``:
    the first two come straight from the tokenizer called with
    ``return_tensors="pt"`` on a single string (so they keep a leading
    dimension of 1, squeezed away in the training loop). ``labels`` is a
    float32 tensor with one entry per label column.
    """

    # Order matters: it fixes the position of each label in the target vector.
    _LABEL_COLUMNS = (
        "toxic",
        "severe_toxic",
        "obscene",
        "threat",
        "insult",
        "identity_hate",
    )

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        record = self.rows[idx]
        encoding = tokenizer(
            record["comment_text"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        targets = torch.tensor(
            [record[column] for column in self._LABEL_COLUMNS],
            dtype=torch.float32,
        )
        return encoding["input_ids"], encoding["attention_mask"], targets


# Fixed seed so the 80/20 train/test partition is reproducible across kernel
# restarts (and the held-out evaluation later always sees the same rows).
split_seed = 42
train, test = random_split(
    dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(split_seed),
)

train_dataset = SuchToxicDataset(list(train))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Sanity-check the tensor shapes of a single batch.
batch = next(iter(train_loader))
print("input_ids: ", batch[0].shape)
print("attention_mask: ", batch[1].shape)
print("labels: ", batch[2].shape)

In [None]:
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes
# TextClassifier ends in a sigmoid — TODO confirm. If it returns raw logits,
# switch to BCEWithLogitsLoss.
loss_fn = torch.nn.BCELoss()
model = TextClassifier(num_classes=num_classes).to(device)
# Use the configured learning rate so the value logged to MLflow as
# "learning_rate" is the one the optimizer actually uses.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

model

In [None]:
losses = []  # mean training loss per epoch, plotted in a later cell

pbar = tqdm(range(epochs), desc="Epoch")
for epoch in pbar:
    model.train()

    total_loss = 0.0
    # Build a fresh inner bar each epoch. A tqdm bar disables itself when it
    # closes after a full pass, so a single bar created outside this loop
    # would only show batch progress for the first epoch.
    pbar_trainloader = tqdm(
        train_loader,
        desc="Train Loader",
        leave=False,
        total=len(train_loader),
    )
    for input_ids, attn_mask, labels in pbar_trainloader:
        optimizer.zero_grad()

        # Each sample carries a leading dim of 1 from the tokenizer, so the
        # batch arrives as (batch, 1, seq_len); drop that singleton dim.
        input_ids = input_ids.squeeze(1).to(device)
        attn_mask = attn_mask.squeeze(1).to(device)
        labels = labels.to(device)

        prediction = model(input_ids, attn_mask)
        loss = loss_fn(prediction, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    losses.append(avg_loss)
    pbar.set_postfix(loss=avg_loss)

    if os.getenv("MLFLOW_ENABLED") == "true":
        mlflow.log_metric("loss", avg_loss, step=epoch)

In [None]:
# Draw on an explicit new figure so re-running this cell (or running under a
# non-inline backend) never adds lines to a stale "current" figure.
fig, ax = plt.subplots()
# 1-based epoch numbers on the x-axis.
ax.plot(range(1, len(losses) + 1), losses, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
fig.savefig("../target/training_loss.png")

In [None]:
# Persist model + optimizer state so training can be resumed or inspected.
checkpoint_dir = "../checkpoints"
# torch.save does not create missing parent directories, so ensure it exists.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}.pth")

torch.save(
    {
        "epoch": epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    checkpoint_path,
)

if os.getenv("MLFLOW_ENABLED") == "true":
    mlflow.pytorch.log_model(model, "such-toxic")

In [None]:
def accuracy(y_true, y_pred, threshold=0.8):
    """Per-label accuracy of thresholded predictions for a single sample.

    Each prediction is binarised (``1`` if ``pred >= threshold`` else ``0``)
    and compared with the matching target value.

    Args:
        y_true: Iterable of target values (Python numbers or 0-d tensors).
        y_pred: Iterable of predicted scores, same length as ``y_true``.
        threshold: Score at or above which a prediction counts as positive.

    Returns:
        Fraction of positions where the binarised prediction equals the target.

    Raises:
        ValueError: If the inputs are empty or have different lengths.
    """
    targets = list(y_true)
    scores = list(y_pred)
    # zip() would silently truncate, skewing the ratio computed below.
    if len(targets) != len(scores):
        raise ValueError(
            f"y_true and y_pred differ in length ({len(targets)} vs {len(scores)})"
        )
    if not targets:
        raise ValueError("accuracy() needs at least one label")

    # float() turns 0-d tensors into plain numbers, so the comparison yields a
    # Python bool instead of a tensor.
    matches = sum(
        1
        for true, pred in zip(targets, scores)
        if float(true) == (1.0 if pred >= threshold else 0.0)
    )
    return matches / len(targets)

In [None]:
model.eval()

# Evaluate only a small slice of the held-out split to keep this cell quick;
# set to None to evaluate the whole test split.
eval_limit = 10
test_rows = list(test) if eval_limit is None else list(test)[:eval_limit]
test_dataset = SuchToxicDataset(test_rows)

total_accuracy = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for input_ids, attn_mask, label in tqdm(test_dataset, desc="Test Dataset"):
        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)
        label = label.to(device)

        prediction = model(input_ids, attn_mask)
        # One sample per forward pass: take the only row of the output
        # (presumably shaped (1, num_classes) — TODO confirm).
        output = prediction.cpu().detach().numpy().tolist()[0]
        total_accuracy += accuracy(label, output)

# Average over the samples actually evaluated, not over the full test split.
num_evaluated = len(test_dataset)
mean_accuracy = total_accuracy / num_evaluated if num_evaluated else float("nan")
print(mean_accuracy)

In [None]:
test_str = "What a bunch of fucking nerds..."


tokens = tokenizer(
    test_str,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    max_length=max_length,
)

# A single string already tokenizes to (1, max_length), which is the batch
# shape the model expects, so no squeeze is needed here.
input_ids = tokens["input_ids"].to(device)
attn_mask = tokens["attention_mask"].to(device)

model.eval()
with torch.no_grad():  # inference only
    output = model(input_ids, attn_mask)
output = output.cpu().detach().numpy().tolist()[0]

# Same order as the label columns used to build the training targets.
label_names = [
    "Toxic",
    "Severe Toxic",
    "Obscene",
    "Threat",
    "Insult",
    "Identity Hate",
]
for label_name, score in zip(label_names, output):
    print(f"{label_name}: ", score)

In [None]:
mlflow_enabled = os.getenv("MLFLOW_ENABLED") == "true"
if mlflow_enabled:
    # Close the run opened at the top of the notebook.
    mlflow.end_run()