In [1]:
import json
from pathlib import Path

import numpy as np
import torch
from IPython.core.display import display
from ipywidgets import Output
from livelossplot import PlotLosses
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import SubsetRandomSampler, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from classification.ClassificationDataset import ClassificationDataset
from classification.model.SentenceTransformerAndClassifier import SentenceTransformerAndClassifier
from utils import PROJECT_ROOT

In [2]:
# --- Reproducibility and data split ---
SEED = 42                 # seed for the NumPy shuffle of the sample indices
SHUFFLE = True            # shuffle the indices before the train/validation split
VALIDATION_SPLIT = 0.05   # fraction of the used samples held out for validation
MAX_DATASET_SIZE = None   # cap on the number of samples used; None keeps all of them

# --- Tokenisation and training ---
MAX_LEN = 64              # passed to ClassificationDataset together with the tokenizer
BATCH_SIZE = 512          # batch size of both the training and validation loaders
EPOCHS = 4                # number of train + validation passes

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: ", device)

Using device:  cuda


In [4]:
def prepare_dataloaders(dataset: ClassificationDataset, validation_split: float):
    """Split ``dataset`` into a training and a validation ``DataLoader``.

    Reads the notebook-level constants ``SHUFFLE``, ``SEED``, ``MAX_DATASET_SIZE`` and
    ``BATCH_SIZE``. When ``SHUFFLE`` is set, the global NumPy RNG is re-seeded with
    ``SEED`` before the indices are shuffled, so the split is reproducible.

    Args:
        dataset: the dataset to split; must support ``len()`` and expose ``index2label``.
        validation_split: fraction of the used samples assigned to validation. The
            validation size is ``floor(validation_split * used_size)``.

    Returns:
        A ``(train_loader, validation_loader)`` tuple. Both loaders draw from ``dataset``
        through a ``SubsetRandomSampler`` restricted to their own indices.
    """
    dataset_size = len(dataset)
    print("Full dataset size: ", dataset_size)
    print(dataset.index2label)

    indices = list(range(dataset_size))
    if SHUFFLE:
        np.random.seed(SEED)
        np.random.shuffle(indices)
    if MAX_DATASET_SIZE is not None:
        indices = indices[:MAX_DATASET_SIZE]
        # The slice silently clamps to the available samples, so take the size from what
        # was actually kept. Using MAX_DATASET_SIZE directly would inflate the validation
        # split (and could leave no training samples) when the cap exceeds the dataset.
        dataset_size = len(indices)
        print("Used dataset size: ", dataset_size)
    split = int(np.floor(validation_split * dataset_size))
    # The first `split` indices form the validation set, the rest the training set.
    train_indices, val_indices = indices[split:], indices[:split]
    print("Train size: {}, Val size: {}".format(len(train_indices), len(val_indices)))

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              sampler=train_sampler)
    validation_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                                   sampler=valid_sampler)
    return train_loader, validation_loader

In [5]:
def _squeeze_keep_batch(tensor):
    """Remove singleton dimensions from ``tensor`` while keeping its leading batch axis."""
    squeezed = torch.squeeze(tensor)
    if tensor.dim() > 0 and tensor.shape[0] == 1:
        # A one-sample batch loses its batch axis in torch.squeeze as well; restore it.
        squeezed = squeezed.unsqueeze(0)
    return squeezed


def run_epoch(model: SentenceTransformerAndClassifier, dataloaders, epoch: int, loss_fn):
    """Run one training pass followed by one validation pass over the data loaders.

    Relies on the notebook-level globals ``device`` (where tensors are moved) and
    ``optimizer`` (stepped during the training phase), which must exist before the call.

    Args:
        model: called as ``model(input_ids, attention_mask)``; the output is arg-maxed
            over dimension 1 to obtain the predicted class indices.
        dataloaders: mapping with the keys ``"train"`` and ``"val"``. Every batch must
            provide ``data["batch_encoding"]["input_ids"]``,
            ``data["batch_encoding"]["attention_mask"]`` and ``data["class_label"]``.
        epoch: zero-based epoch index, used only for the progress-bar label.
        loss_fn: called as ``loss_fn(outputs, targets)``. Its result is treated as the
            mean loss of the batch and re-weighted by the batch size.

    Returns:
        A dict with ``{phase}_loss``, ``{phase}_acc`` and ``{phase}_f1_score`` for the
        phases ``train`` and ``val``. Loss is the sample-weighted mean, accuracy is the
        share of correctly classified samples, and the F1 score is the macro F1 computed
        once over all samples seen in the phase. All three are NaN when the phase's
        loader yields no samples.
    """
    logs = {}
    for phase in ["train", "val"]:
        is_training = phase == "train"
        with tqdm(dataloaders[phase], unit="batch") as tepoch:
            if is_training:
                model.train()
            else:
                model.eval()

            total_epoch_loss = 0
            total_correct_classified_samples = 0
            total_processed_samples = 0
            # Macro F1 is not additive across batches, so collect every target and
            # prediction and score the whole phase at the end.
            epoch_targets = []
            epoch_predictions = []

            description = "{} {}".format("Epoch" if is_training else "Validation", epoch + 1)

            for data in tepoch:
                tepoch.set_description(description)

                input_ids = _squeeze_keep_batch(data["batch_encoding"]["input_ids"].to(device))
                attention_mask = _squeeze_keep_batch(data["batch_encoding"]["attention_mask"].to(device))
                targets = data["class_label"].to(device)
                actual_batch_size = len(targets)

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(input_ids, attention_mask)
                    predictions = torch.argmax(outputs, dim=1)

                    mean_batch_loss = loss_fn(outputs, targets)
                    # Re-weight by batch size so a smaller final batch does not skew the mean.
                    total_epoch_loss += mean_batch_loss.item() * actual_batch_size

                    correct_classified_samples = torch.sum(predictions == targets).item()
                    batch_accuracy = correct_classified_samples / actual_batch_size
                    total_correct_classified_samples += correct_classified_samples

                    targets_cpu = targets.cpu()
                    predictions_cpu = predictions.cpu()
                    # The per-batch F1 is shown on the progress bar only.
                    batch_f1_score = f1_score(targets_cpu, predictions_cpu, average="macro")
                    epoch_targets.append(targets_cpu)
                    epoch_predictions.append(predictions_cpu)

                    if is_training:
                        optimizer.zero_grad()
                        mean_batch_loss.backward()
                        optimizer.step()

                    total_processed_samples += actual_batch_size
                    tepoch.set_postfix(loss=mean_batch_loss.item(), accuracy=batch_accuracy, f1_score=batch_f1_score)

            if total_processed_samples == 0:
                # Empty loader (e.g. a validation split of 0): report NaN instead of
                # failing with a division by zero.
                mean_epoch_loss = epoch_accuracy = epoch_f1_score = float("nan")
            else:
                mean_epoch_loss = total_epoch_loss / total_processed_samples
                epoch_accuracy = total_correct_classified_samples / total_processed_samples
                epoch_f1_score = f1_score(
                    torch.cat(epoch_targets).numpy(),
                    torch.cat(epoch_predictions).numpy(),
                    average="macro",
                )

            logs[f"{phase}_loss"] = mean_epoch_loss
            logs[f"{phase}_acc"] = epoch_accuracy
            logs[f"{phase}_f1_score"] = epoch_f1_score
    return logs

In [6]:
# Checkpoint name passed to both the classifier wrapper and the tokenizer.
base_model = "sentence-transformers/paraphrase-mpnet-base-v2"
# Five output classes; the dataset loaded later in this notebook reports five labels.
model = SentenceTransformerAndClassifier(base_model, n_classes=5)
tokenizer = AutoTokenizer.from_pretrained(base_model)
model.to(device)
# Prints a table of the trainable parameters (see the cell output).
model.describe_parameters()

+---------------------+------------+
|       Modules       | Parameters |
+---------------------+------------+
| linear_layer.weight |   393216   |
|  linear_layer.bias  |    512     |
|  classifier.weight  |    2560    |
|   classifier.bias   |     5      |
+---------------------+------------+
Total trainable parameters: 396293


In [7]:
# Build the dataset from the processed-data directory with the tokenizer and MAX_LEN.
dataset = ClassificationDataset(PROJECT_ROOT.joinpath("data/processed"), tokenizer, MAX_LEN)

# Store the index -> label mapping as JSON under the project root.
with open(PROJECT_ROOT.joinpath("class2label.json"), "w") as file:
    json.dump(dataset.index2label, file)

# Split into loaders and index them by the phase names used by run_epoch.
train_loader, validation_loader = prepare_dataloaders(dataset, VALIDATION_SPLIT)
dataloaders = {"train": train_loader, "val": validation_loader}

Full dataset size:  134423
{0: 'address', 1: 'company_name', 2: 'location', 3: 'physical_good', 4: 'serial_number'}
Train size: 127702, Val size: 6721


In [8]:
# Cross-entropy loss for the classifier outputs.
criterion = nn.CrossEntropyLoss()
# Adam with the library's default hyperparameters (no learning rate is passed).
optimizer = torch.optim.Adam(model.parameters())

In [9]:
# Chart groups for livelossplot: each group name maps to the log keys drawn together.
# The keys match the entries of the dict returned by run_epoch.
groups = {
    "accuracy": ["train_acc", "val_acc"],
    "loss": ["train_loss", "val_loss"],
    "f1_score": ["train_f1_score", "val_f1_score"]
}

plotlosses = PlotLosses(groups=groups)

# Show an Output widget here; the training loop draws the charts inside it.
GRAPHS = Output()
display(GRAPHS)

Output()

In [10]:
for epoch in range(EPOCHS):
    logs = run_epoch(model, dataloaders, epoch, criterion)

    # Per-epoch summary: training metrics first, then validation metrics.
    print("Loss: {:.3f}, Acc: {:.3f}, F1 score: {:.3f}".format(
        *(logs[key] for key in ("train_loss", "train_acc", "train_f1_score"))
    ))
    print("Val Loss: {:.3f}, Val Acc: {:.3f}, Val F1 score: {:.3f}".format(
        *(logs[key] for key in ("val_loss", "val_acc", "val_f1_score"))
    ))
    print()

    # Redraw the live charts inside the Output widget displayed earlier.
    with GRAPHS:
        plotlosses.update(logs)
        plotlosses.send()

Epoch 1: 100%|██████████| 250/250 [26:01<00:00,  6.25s/batch, accuracy=0.963, f1_score=0.962, loss=0.124] 
Validation 1: 100%|██████████| 14/14 [01:21<00:00,  5.79s/batch, accuracy=1, f1_score=1, loss=0.00625]       

Loss: 0.150, Acc: 0.966, F1 score: 0.966
Val Loss: 0.037, Val Acc: 0.988, Val F1 score: 0.988




Epoch 2: 100%|██████████| 250/250 [26:11<00:00,  6.29s/batch, accuracy=0.981, f1_score=0.981, loss=0.0884]
Validation 2: 100%|██████████| 14/14 [01:20<00:00,  5.78s/batch, accuracy=1, f1_score=1, loss=0.0104]        

Loss: 0.053, Acc: 0.984, F1 score: 0.984
Val Loss: 0.030, Val Acc: 0.989, Val F1 score: 0.989




Epoch 3: 100%|██████████| 250/250 [26:08<00:00,  6.27s/batch, accuracy=0.981, f1_score=0.981, loss=0.0942]
Validation 3: 100%|██████████| 14/14 [01:21<00:00,  5.79s/batch, accuracy=0.985, f1_score=0.983, loss=0.0417]

Loss: 0.046, Acc: 0.985, F1 score: 0.985
Val Loss: 0.025, Val Acc: 0.992, Val F1 score: 0.992




Epoch 4: 100%|██████████| 250/250 [26:11<00:00,  6.28s/batch, accuracy=1, f1_score=1, loss=0.0127]        
Validation 4: 100%|██████████| 14/14 [01:21<00:00,  5.80s/batch, accuracy=0.985, f1_score=0.987, loss=0.0641]

Loss: 0.041, Acc: 0.987, F1 score: 0.987
Val Loss: 0.024, Val Acc: 0.992, Val F1 score: 0.992






In [11]:
# Save only the model weights (state_dict) as classification_model.pt under the project root.
torch.save(model.state_dict(), PROJECT_ROOT.joinpath("classification_model.pt").absolute())