# Setup

In [None]:
%load_ext autoreload
%autoreload 2

import sys 
import os
import torch
import numpy as np
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"device: {device}")

In [None]:
# Path prefixes for checkpoints. Other cells append "_bert" and "_proto_net.pth"
# to these prefixes: saving uses the first one, loading uses the second one.
current_save_model_path = "../models/A"
current_load_model_path = "../models/B"

# Loading Jigsaw Dataset & Fine Tuning BERT

In [None]:
# Choose the model class, tokenizer class and checkpoint name used by the rest
# of the notebook. The DistilBERT variants are selected here; the plain BERT
# classes are imported too, presumably so they can be swapped in (model_name
# would then have to change with them).
from transformers import BertForSequenceClassification, BertTokenizer, DistilBertTokenizer, DistilBertForSequenceClassification
bert_fn = DistilBertForSequenceClassification
tokenizer_fn = DistilBertTokenizer
model_name = "distilbert-base-uncased"

## Load Pre-trained Fine-Tuned BERT

In [None]:
# Restore a previously saved fine-tuned model. This appears to be the
# alternative to the "Train a new BERT" cells, which define the same names.
from bert import FinetunedBert
finetuned_bert = FinetunedBert(model_name=model_name, device=device, bert_fn=bert_fn, tokenizer_fn=tokenizer_fn)
# load() receives the "<load prefix>_bert" path; its return value is used as the tokenizer.
tokenizer = finetuned_bert.load(path=current_load_model_path+"_bert")
# Direct handle on the wrapped model, read later by the embedding helper.
finetuned_bert_model = finetuned_bert.model

## Train a new BERT

In [None]:
# Fresh tokenizer for the train-from-scratch path. Only the checkpoint name is
# passed: `num_labels` and `ignore_mismatched_sizes` are model-loading options
# and have no meaning for a tokenizer, so they are not forwarded here.
tokenizer = tokenizer_fn.from_pretrained(model_name)

In [None]:
# Build the Jigsaw train/test dataloaders used to fine-tune and validate BERT.
from datasets.jigsaw import GetDataLoader, GetTestDataLoader

# Each helper returns a (dataloader, length) pair. `n` is presumably the number
# of samples to draw — confirm against datasets.jigsaw.
jigsaw_train_dataloader, total_length = GetDataLoader(tokenizer, device=device, n=5000)
jigsaw_test_dataloader, total_length_2 = GetTestDataLoader(tokenizer, device=device, n=100)

print(f"train dataset length: {total_length}\ntest dataset length: {total_length_2}")

In [None]:
from bert import FinetunedBert

# New, untrained-on-Jigsaw wrapper around the selected checkpoint; the learning
# rate is supplied explicitly for the fine-tuning run.
finetuned_bert = FinetunedBert(
    model_name=model_name,
    lr=5e-5,
    device=device,
    bert_fn=bert_fn,
    tokenizer_fn=tokenizer_fn,
)

In [None]:
# Fine-tune on the Jigsaw training dataloader for 3 epochs.
finetuned_bert.train(jigsaw_train_dataloader, epochs=3)

In [None]:
# Score the fine-tuned model on the Jigsaw test dataloader and report it as a
# percentage with two decimals.
val_accuracy = finetuned_bert.accuracy(jigsaw_test_dataloader)
print(f"validation Accuracy: {val_accuracy * 100:.2f}%")

In [None]:
# Expose the wrapped model under the name the embedding helper (EmbedWrapper) reads.
finetuned_bert_model = finetuned_bert.model

In [None]:
# Best-effort cleanup after fine-tuning: drop the Jigsaw dataloaders and release
# cached CUDA memory. Only NameError is tolerated (the loaders do not exist when
# the model was loaded instead of trained, or when this cell is re-run); any
# other error, including KeyboardInterrupt, is allowed to surface.
try:
    del jigsaw_train_dataloader
except NameError:
    pass
try:
    del jigsaw_test_dataloader
except NameError:
    pass

# Move the fine-tuned model to the CPU before emptying the CUDA cache.
# NOTE(review): EmbedWrapper later calls EmbedUserInput with device=device on
# this same model — confirm that helper moves the model/inputs consistently.
finetuned_bert.model.to("cpu")
torch.cuda.empty_cache()

print("emptied")

# Prototypical Network

In [None]:
from utils import EmbedUserInput, GetFewShotDataLoader
from datasets.jigsaw import GetInputAndLabels, GetInputAndLabelsByClass

def EmbedWrapper(X):
    """Embed X via EmbedUserInput using the notebook's tokenizer, model and device.

    The three globals are looked up at call time, so re-running the cells that
    define them changes what subsequent calls use.
    """
    return EmbedUserInput(tokenizer, finetuned_bert_model, X, device=device)

In [None]:
n_way = 2 # Number of classes: binary classification (toxic / non-toxic)

## Training input

In [None]:
# Episode sizes passed to GetFewShotDataLoader for training — presumably the
# number of support and query examples per class (confirm in utils).
k_shot = 4
q_queries = 9

In [None]:
# Fetch the few-shot training data (n=400); the third return value is unused.
# Alternative call kept for reference: GetInputAndLabels(n=400, toxicity_level=3)
training_inputs, training_labels, _ = GetInputAndLabels(n=400)

# Embed the inputs with the fine-tuned BERT, then sanity-check the sizes.
training_embedding = EmbedWrapper(training_inputs)
print(training_embedding.shape, len(training_labels))

# Wrap embeddings and labels in a few-shot loader with the training episode sizes.
training_loader = GetFewShotDataLoader(
    training_embedding,
    training_labels,
    n_way=n_way,
    k_shot=k_shot,
    q_queries=q_queries,
    device=device,
)
print("made loader")

## Training loop

In [None]:
from model import PrototypicalNetwork
from solver import Solver

In [None]:
# Fresh prototypical network, wrapped in a Solver configured for n_way classes.
proto_net = PrototypicalNetwork()
solver = Solver(proto_net, n_way=n_way, lr=0.0001, device=device)

In [None]:
# Restore solver state from "<load prefix>_proto_net.pth".
# NOTE(review): presumably this replaces the freshly initialised weights above;
# skip this cell to train from scratch — confirm in Solver.load.
solver.load(path=current_load_model_path+"_proto_net.pth")

In [None]:
# Train on the few-shot episodes for 25 epochs. `output_file` is handed to the
# solver — presumably where the loss history is written (verify in Solver.train).
solver.train(training_loader, n_epochs=25, output_file="proto_loss_0001.csv")

In [None]:
# Best-effort cleanup after prototypical-network training: drop the raw
# training inputs and labels, then release cached CUDA memory. Only NameError is
# tolerated (the name was never defined or the cell already ran); any other
# error, including KeyboardInterrupt, is allowed to surface.
try:
    del training_inputs
except NameError:
    pass
try:
    del training_labels
except NameError:
    pass
torch.cuda.empty_cache()

print("emptied")

## Evaluation

In [None]:
# Build the evaluation episodes: fetch data with type="test" (n=1000), embed it
# with the fine-tuned BERT, and wrap it in a few-shot loader. The third return
# value of GetInputAndLabels is unused.
test_inputs, test_labels, _ = GetInputAndLabels(type="test", n=1000)
test_embedding = EmbedWrapper(test_inputs)

# Evaluation uses its own episode sizes, independent of the training ones.
k_shot_test = 50
q_queries_test = 9
# NOTE(review): unlike the training loader, no `device=` is passed here, so
# GetFewShotDataLoader's default applies — confirm this is intended.
testing_loader = GetFewShotDataLoader(test_embedding, test_labels, n_way=n_way, k_shot=k_shot_test, q_queries=q_queries_test)
print("made loader")

In [None]:
# Evaluate the prototypical network on the current `testing_loader`. As the last
# expression of the cell, any return value is shown as the cell output.
solver.evaluate(testing_loader)

In [None]:
# Best-effort cleanup after evaluation: drop the test inputs, embeddings and
# labels, then release cached CUDA memory. Each name gets its own try block so
# that one missing name does not prevent the others from being deleted. Only
# NameError is tolerated; any other error is allowed to surface.
try:
    del test_inputs
except NameError:
    pass
try:
    del test_embedding
except NameError:
    pass
try:
    del test_labels
except NameError:
    pass
torch.cuda.empty_cache()

## Similarity evaluation

In [None]:
# Per-category evaluation: for every toxic category, pair up to
# `n_input_per_toxic_class` of its samples with the same number of non-toxic
# samples and evaluate the prototypical network on that balanced two-class set.
similarity_inputs = GetInputAndLabelsByClass(type="test")
non_toxic = similarity_inputs["non_toxic"]

n_input_per_toxic_class = 14
k_shot_s_test = 4
q_queries_s_test = 9
# Smallest per-class sample count that can fill one support set plus one query set.
min_per_class = k_shot_s_test + q_queries_s_test

for category, inputs in similarity_inputs.items():
    # The non-toxic pool is the shared negative class, not a category to score.
    if category == "non_toxic":
        continue

    if len(inputs) == 0:
        print(f"skipping category '{category}' because it has no inputs")
        continue

    # Cap the positives, then demand an equally sized set of negatives.
    le = min(len(inputs), n_input_per_toxic_class)
    positives = inputs[:le]
    if len(non_toxic) < le:
        raise ValueError(f"not enough non-toxic samples: needed {le}, found {len(non_toxic)}")
    negatives = non_toxic[:le]

    # Positives first (label 1), negatives second (label 0).
    episode_texts = positives + negatives
    episode_labels = [1] * le + [0] * le

    if not episode_texts or le < min_per_class:
        print(f"skipping category '{category}' because combined inputs are empty or too small")
        continue

    episode_embeddings = EmbedWrapper(episode_texts)
    testing_loader = GetFewShotDataLoader(
        episode_embeddings,
        episode_labels,
        n_way=n_way,
        k_shot=k_shot_s_test,
        q_queries=q_queries_s_test,
    )

    print(f"made loader for {category}, length is {le}")
    solver.evaluate(testing_loader)

del similarity_inputs

# User-Testing

In [None]:
# Toy episode: two labelled support sentences and a single query sentence.
support = ["circles are bad", "squares are good"]
labels = torch.tensor([1, 0])  # one label per support sentence: 1 = offensive, 0 = not offensive
query = ["circles are whack"]

# Embed support and query text, then let the solver classify the query.
pred_labels = solver.predict(
    EmbedWrapper(support),
    labels,
    EmbedWrapper(query),
)

# A predicted label of 1 is reported as "sensitive"; anything else is not.
if pred_labels.item() == 1:
    outcome = "sensitive"
else:
    outcome = "not sensitive"
print(f"predicted: {outcome}")



# Save the model

In [None]:
# Save the fine-tuned BERT together with the tokenizer under "<save prefix>_bert".
# optimizers=True presumably stores optimizer state as well — confirm in FinetunedBert.save.
finetuned_bert.save(tokenizer, optimizers=True, path=current_save_model_path+"_bert")

In [None]:
# Save the prototypical-network solver to "<save prefix>_proto_net.pth".
# optimizers=True presumably stores optimizer state as well — confirm in Solver.save.
solver.save(optimizers=True, path=current_save_model_path+"_proto_net.pth")