In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project package are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer

from rate_severity_of_toxic_comments.utilities import *
from rate_severity_of_toxic_comments.model import *
from rate_severity_of_toxic_comments.dataset import *
from rate_severity_of_toxic_comments.training import * 

# Make CUDA kernel launches synchronous so that a failure surfaces at the
# call that caused it, which gives more descriptive error messages.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import wandb

In [None]:
# Hyperparameters and run settings. This dict is passed to build_datasets,
# create_model and wandb.init in the cells below.
CONFIG = dict(
    seed=42,
    epochs=10,
    model_name="roberta-base",
    train_batch_size=32,
    valid_batch_size=64,
    learning_rate=1e-4,
    max_length=128,
    margin=0.5,  # margin given to nn.MarginRankingLoss
    use_gpu=False,  # CUDA is selected only when this is True and a device is available
    run_mode="test",  # used as the W&B tag and as the run-name prefix
)

# Load the tokenizer that belongs to the chosen checkpoint and carry it inside CONFIG.
CONFIG["tokenizer"] = AutoTokenizer.from_pretrained(CONFIG["model_name"])
# Disabled: HASH_NAME is not defined in the visible cells.
# CONFIG['group'] = f'{HASH_NAME}-Baseline'

In [None]:
# Load the validation CSV and keep a small random subset so the run stays quick.
# The draw is seeded from CONFIG["seed"] so that Restart & Run All always
# selects the same 20 rows (an unseeded sample changes on every execution).
df = pd.read_csv("res/data/validation_data.csv")
df = df.sample(20, random_state=CONFIG["seed"])
data_size = len(df.index)

In [None]:
# Fraction of the sampled rows used for training; the remainder is held out
# for validation.
train_split = 0.7
threshold_index = int(data_size * train_split)

# Positional split: the first `threshold_index` rows train, the rest validate.
# Both frames get a fresh 0..n-1 index.
df_train = df.head(threshold_index).reset_index(drop=True)
df_valid = df.iloc[threshold_index:].reset_index(drop=True)

frames = [df_train, df_valid]
training_data, val_data = build_datasets(frames, CONFIG)

In [None]:
# Start the Weights & Biases run. The tokenizer object is left out of the
# logged config: it is a live Python object rather than a hyperparameter, and
# logging it would only add its string representation to the run config.
wandb_config = {key: value for key, value in CONFIG.items() if key != "tokenizer"}

run = wandb.init(
    project="rate-comments",
    entity="toxicity",
    config=wandb_config,
    job_type="Train",
    # group="", TODO?
    tags=[CONFIG["run_mode"]],
)

# Rename the run to "<run_mode>-<run id>"; the id is only known after init.
wandb.run.name = CONFIG["run_mode"] + "-" + wandb.run.id
wandb.run.save()

# Apply the configured seed to PyTorch before anything random happens
# (weight initialisation, any shuffling) so repeated runs are comparable.
torch.manual_seed(CONFIG["seed"])

# Use CUDA only when it is both requested in CONFIG and actually available.
device = torch.device("cuda" if CONFIG["use_gpu"] and torch.cuda.is_available() else "cpu")
loss_fn = nn.MarginRankingLoss(margin=CONFIG["margin"])

# Everything after wandb.init runs inside try/finally so the W&B run is
# always closed, even if building the model or training raises.
try:
    train_loader, valid_loader = build_dataloaders(
        [training_data, val_data],
        batch_sizes=(CONFIG["train_batch_size"], CONFIG["valid_batch_size"]),
    )

    model = create_model(CONFIG)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])

    stats = run_training(
        train_loader, valid_loader, model, loss_fn, optimizer, device,
        CONFIG["epochs"], log_interval=10, verbose=True,
    )
finally:
    run.finish()