In [None]:
from pathlib import Path

import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Pick the compute device: CUDA if present, then Apple-Silicon MPS, otherwise
# CPU. The fallback matters: hard-coding 'mps' as the only alternative to CUDA
# makes every later `.to(device)` fail on machines that have neither.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('device:', device)

In [None]:
class JokesDataset(Dataset):
    """Map-style dataset over a headerless, tab-separated jokes file.

    Only the first two columns of the file are read: column 0 is the score
    and column 1 is the joke text. Each item is returned as ``(joke, score)``.
    """

    def __init__(self, path):
        # Positional columns 0/1 are named 'score'/'joke'; later columns are dropped.
        self.df = pd.read_csv(
            path,
            sep='\t',
            header=None,
            usecols=[0, 1],
            names=['score', 'joke'],
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        joke = self.df.iloc[idx, 1]
        score = self.df.iloc[idx, 0]
        return joke, score
    
# Load the train / dev / test splits from the data/ directory.
train_dataset, dev_dataset, test_dataset = (
    JokesDataset('data/train.tsv'),
    JokesDataset('data/dev.tsv'),
    JokesDataset('data/test.tsv'),
)

In [None]:
# Pretrained BERT with a single-output head (num_labels=1), used as a
# regressor over joke scores.
checkpoint = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=1)
model = model.to(device)

In [None]:
# Optimisation hyperparameters.
learning_rate = 3e-5
batch_size = 32
epochs = 4

# One regression output per joke -> mean-squared-error objective.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Only the training split is shuffled; dev and test keep file order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
dev_dataloader = DataLoader(dataset=dev_dataset, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Texts are tokenized with the module-level ``tokenizer`` and moved to the
    module-level ``device``. Targets are cast to float and compared against
    the model's squeezed logits with ``loss_fn``. The loss is printed every
    100 batches.

    Args:
        dataloader: DataLoader yielding ``(texts, scores)`` batches.
        model: sequence-classification model whose output exposes ``.logits``.
        loss_fn: loss taking ``(predictions, targets)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
    """
    model.train()
    size = len(dataloader.dataset)
    seen = 0  # samples processed so far; stays exact when the last batch is short
    with tqdm(total=size) as progress_bar:
        for batch, (X, y) in enumerate(dataloader):
            # NOTE(review): padding='max_length' pads every batch to the tokenizer's
            # model_max_length (presumably 512 for bert-base-uncased); padding=True
            # would pad only to the longest joke in the batch and be much cheaper.
            encoded_input = tokenizer(list(X), padding='max_length', truncation=True, return_tensors='pt').to(device)
            y = y.type(torch.float).to(device)

            # Compute prediction and loss. With num_labels=1 the logits are
            # (batch, 1); squeezing to (batch,) matches y so MSE does not broadcast.
            pred = model(**encoded_input)
            loss = loss_fn(pred.logits.squeeze(dim=1), y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            seen += len(X)
            if batch % 100 == 0:
                print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")
            progress_bar.update(len(X))
            

def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients and report the mean loss.

    Preprocessing matches ``train_loop``: the module-level ``tokenizer`` and
    ``device`` are used, and targets are cast to float.

    Args:
        dataloader: DataLoader yielding ``(texts, scores)`` batches.
        model: sequence-classification model whose output exposes ``.logits``.
        loss_fn: loss taking ``(predictions, targets)``; assumed to use mean
            reduction (the ``nn.MSELoss`` default), since each batch loss is
            re-weighted by batch size below.

    Returns:
        float: per-sample mean loss over the whole dataset, or NaN if the
        dataset is empty. The value is also printed.
    """
    model.eval()
    size = len(dataloader.dataset)
    total_loss = 0.0

    with torch.no_grad(), tqdm(total=size) as progress_bar:
        for X, y in dataloader:
            encoded_input = tokenizer(list(X), padding='max_length', truncation=True, return_tensors='pt').to(device)
            y = y.type(torch.float).to(device)

            pred = model(**encoded_input)
            # The model returns an output object, not a tensor. Take its logits
            # and squeeze (batch, 1) -> (batch,) exactly as train_loop does.
            batch_loss = loss_fn(pred.logits.squeeze(dim=1), y)
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += batch_loss.item() * len(X)
            progress_bar.update(len(X))

    test_loss = total_loss / size if size else float('nan')
    print(f"Avg dev loss: {test_loss:>8f} \n")
    return test_loss

In [None]:
# Fine-tune for `epochs` passes over the training set; after each pass the
# dev split is evaluated and its average loss printed.
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(dev_dataloader, model, loss_fn)
print("Done!")

In [None]:
# Persist the fine-tuned model. Create the target directory first so a
# missing models/ folder cannot throw away a finished training run.
# NOTE(review): torch.save(model, ...) pickles the whole module object, so
# loading it needs the same class definitions importable. model.state_dict()
# or model.save_pretrained() are more portable alternatives.
model_path = Path('models/model.pth')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)