In [58]:
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertModel, DistilBertTokenizer
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import sys
from model import DistilBertClassifier
from dataset import NewsDataset
from sklearn import metrics
import matplotlib.pyplot as plt
import json

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# True when the `google.colab` package has already been imported into this process.
is_colab = 'google.colab' in sys.modules

# When False, later cells load the saved tokenizer/checkpoint instead of fine-tuning.
is_train = False

In [3]:
# `base_dir` is the root that every data/checkpoint path below is built from.
if not is_colab:
    # Local run: paths are relative to the notebook's working directory.
    base_dir = '.'
else:
    # Colab run: mount Google Drive and work out of the coursework folder on it.
    from google.colab import drive

    drive.mount('./mnt')
    base_dir = './mnt/My Drive/coursework'

In [4]:
# The CSV has no header row; the columns are named right after loading.
csv_path = f'{base_dir}/data/findata.csv'
dataframe = pd.read_csv(csv_path, encoding='cp1252', header=None)
dataframe.columns = ['sentiment', 'title']

# Training starts from the hub vocabulary; evaluation reuses the vocabulary
# stored next to the checkpoint.
tokenizer_source = 'distilbert-base-uncased' if is_train else f'{base_dir}/savepoints/vocab.pt'
tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_source, truncation=True, do_lower_case=True)

In [5]:
# Fixed seed so the train/test partition is identical on every kernel restart.
# Without it, an evaluation-only run (is_train = False) scores the loaded
# checkpoint on a fresh random split that overlaps the rows it was trained on.
# NOTE(review): a checkpoint produced before this seed was introduced was trained
# on a different, unknown split -- retrain once to get a clean held-out score.
SPLIT_SEED = 42
BATCH_SIZE = 16

train_df, test_df = train_test_split(dataframe, test_size=.2, random_state=SPLIT_SEED)
# Third argument is passed through to NewsDataset unchanged
# (presumably the maximum token length -- confirm in dataset.py).
train_ds = NewsDataset(train_df, tokenizer, 128)
test_ds = NewsDataset(test_df, tokenizer, 128)

train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True)
# Evaluation does not depend on batch order, so the test loader is left unshuffled
# to keep runs deterministic.
test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False)

In [6]:
# Build the classifier, move its parameters to the selected device,
# and end the cell with the module so its layer summary is displayed.
model = DistilBertClassifier()
model = model.to(device)
model

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DistilBertClassifier(
  (l1): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Line

In [7]:
# Fine-tune for two epochs when `is_train` is set; otherwise restore the saved weights.
if is_train:
    loss_fn = nn.BCEWithLogitsLoss()
    optim = torch.optim.Adam(model.parameters(), lr=2e-5)

    for epoch in range(1, 3):
        print(f'epoch {epoch:1}')
        # Make sure dropout is active even if an earlier cell switched the model to eval mode.
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        # Wrapping the loader (not the enumerate object) lets tqdm show the total batch count.
        for bn, data in enumerate(tqdm(train_loader)):
            optim.zero_grad()
            out = model(data['ids'].to(device), data['mask'].to(device), data['token_type_ids'].to(device))

            loss = loss_fn(out, data['targets'].to(device))
            loss.backward()
            optim.step()

            # Accumulate (not overwrite) so the printed value is a true running mean.
            epoch_loss += loss.item()
            n_batches = bn + 1
            if bn%50 == 0:
                print(f'\tloss: {epoch_loss/n_batches:.5f}')
        # Divide by the number of batches seen; max() guards against an empty loader.
        print(f'epoch {epoch:1} loss {epoch_loss/max(n_batches, 1):.5f}')
        # Checkpoint after every epoch so an interrupted run keeps the latest weights.
        torch.save(model.state_dict(), f'{base_dir}/savepoints/model.pt')
else:
    model.load_state_dict(torch.load(f'{base_dir}/savepoints/model.pt', map_location=device))

In [83]:
# Metric registry: name -> callable(y_true, y_pred) returning a JSON-serialisable value.
# Both arguments are 2-D arrays of shape (n_samples, n_classes); the callables below
# index them column-wise and take argmax over axis 1.
#
# Scalar metrics are wrapped in float() rather than calling `.item()`: float() accepts
# both NumPy scalars and built-in floats, so the registry does not depend on which of
# the two a given scikit-learn version returns.
# The number of classes is read from `y_true` instead of being fixed at 3.
# NOTE(review): 'r2' compares the targets with whatever scores are passed in; when those
# are raw logits the value is hard to interpret -- confirm this is intended.
scores = {
    'r2': lambda y_true, y_pred: float(metrics.r2_score(y_true, y_pred)),
    'f1': lambda y_true, y_pred: metrics.f1_score(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1), average=None).tolist(),
    'accuracy': lambda y_true, y_pred: float(metrics.accuracy_score(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1))),
    'prec-recall-perclass': lambda y_true, y_pred: {i: [x.tolist() for x in metrics.precision_recall_curve(y_true[:, i], y_pred[:, i])[:2]] for i in range(y_true.shape[1])},
    'avgprecision-perclass': lambda y_true, y_pred: {i: float(metrics.average_precision_score(y_true[:, i], y_pred[:, i])) for i in range(y_true.shape[1])},
    'prec-recall-micro': lambda y_true, y_pred: [x.tolist() for x in metrics.precision_recall_curve(y_true.ravel(), y_pred.ravel())[:2]],
    'avgprecision-micro': lambda y_true, y_pred: float(metrics.average_precision_score(y_true, y_pred, average='micro')),
}

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Computed as exp(-log(1 + exp(-x))) via ``np.logaddexp`` so that large
    negative inputs do not overflow inside ``exp`` (the direct formula emits a
    RuntimeWarning for inputs such as -1000).

    Args:
        x: NumPy array or scalar of logits.

    Returns:
        Values in [0, 1], with the same shape as ``x``.
    """
    return np.exp(-np.logaddexp(0.0, -x))

def write_precision_data(model, prefix, test_loader, device, out_dir='.'):
    """Evaluate ``model`` on ``test_loader`` and write metrics and precision-recall plots.

    Files written (directories are created if missing):
      * ``{out_dir}/savepoints/{prefix}-acc-scores.json`` -- every entry of the
        module-level ``scores`` registry, evaluated on the raw model outputs.
      * ``{out_dir}/illustrations/precision-recall/{prefix}-microavg.svg``
      * ``{out_dir}/illustrations/precision-recall/{prefix}-multiclass.svg``

    Args:
        model: module called as ``model(ids, mask, token_type_ids)``; it is switched
            to eval mode and left there.
        prefix: filename prefix for all three output files.
        test_loader: iterable of dict batches with the keys ``'ids'``, ``'mask'``,
            ``'token_type_ids'`` and ``'targets'``.
        device: device the batch tensors are moved to before the forward pass.
        out_dir: root directory for the outputs; the default keeps the original
            behaviour of writing relative to the current working directory.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    from pathlib import Path

    scores_dir = Path(out_dir) / 'savepoints'
    plots_dir = Path(out_dir) / 'illustrations' / 'precision-recall'
    scores_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    model.eval()
    # Collect per-batch arrays and concatenate once, instead of re-copying the
    # growing result array on every batch.
    target_batches = []
    output_batches = []
    with torch.no_grad():
        for data in tqdm(test_loader):
            ids = data['ids'].to(device)
            mask = data['mask'].to(device)
            token_type_ids = data['token_type_ids'].to(device)
            outputs = model(ids, mask, token_type_ids)
            target_batches.append(data['targets'].cpu().numpy())
            output_batches.append(outputs.cpu().numpy())

    if not target_batches:
        raise ValueError('test_loader yielded no batches; nothing to evaluate')

    # float64 matches the dtype the metrics were previously computed in.
    fin_targets = np.concatenate(target_batches).astype(np.float64)
    fin_outputs = np.concatenate(output_batches).astype(np.float64)
    n_classes = fin_targets.shape[1]

    # The JSON scores are still computed from the raw outputs, as before.
    results = {name: score_fn(fin_targets, fin_outputs) for name, score_fn in scores.items()}
    with open(scores_dir / f'{prefix}-acc-scores.json', 'w') as f:
        json.dump(results, f)

    # Use the same sigmoid-transformed scores for the per-class and the micro-averaged
    # curves so that both are built from one quantity.
    fin_probs = sigmoid(fin_outputs)

    precision = {}
    recall = {}
    average_precision = {}

    for i in range(n_classes):
        y_true = fin_targets[:, i]
        y_pred = fin_probs[:, i]
        precision[i], recall[i], _ = metrics.precision_recall_curve(y_true, y_pred)
        average_precision[i] = metrics.average_precision_score(y_true, y_pred)

    precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
        fin_targets.ravel(), fin_probs.ravel()
    )
    average_precision["micro"] = metrics.average_precision_score(fin_targets, fin_probs, average="micro")

    # Figure 1: micro-averaged curve on its own.
    fig_micro, ax_micro = plt.subplots()
    metrics.PrecisionRecallDisplay(
        recall=recall["micro"],
        precision=precision["micro"],
        average_precision=average_precision["micro"],
    ).plot(ax=ax_micro)
    fig_micro.savefig(plots_dir / f'{prefix}-microavg.svg')

    # Figure 2: micro-average plus one curve per class. A fresh figure is required:
    # a bare plt.subplot() would hand back the axes of figure 1 and the micro curve
    # would be drawn twice.
    fig_multi, ax_multi = plt.subplots()
    metrics.PrecisionRecallDisplay(
        recall=recall["micro"],
        precision=precision["micro"],
        average_precision=average_precision["micro"],
    ).plot(ax=ax_multi, name="Micro-average precision-recall", color="gold")

    for i in range(n_classes):
        metrics.PrecisionRecallDisplay(
            recall=recall[i],
            precision=precision[i],
            average_precision=average_precision[i],
        ).plot(ax=ax_multi, name=f"Precision-recall for class {i}")

    handles, labels = ax_multi.get_legend_handles_labels()
    ax_multi.set_xlim([0.0, 1.0])
    ax_multi.set_ylim([0.0, 1.05])
    ax_multi.legend(handles=handles, labels=labels, loc="best")
    ax_multi.set_title("Extension of Precision-Recall curve to multi-class")
    fig_multi.savefig(plots_dir / f'{prefix}-multiclass.svg')