In [1]:
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import SubsetRandomSampler, DataLoader
from transformers import AutoTokenizer

from classification.ClassificationDataset import ClassificationDataset
from classification.model.SentenceTransformerAndClassifier import SentenceTransformerAndClassifier
from utils import PROJECT_ROOT

In [2]:
MAX_LEN = 64
BATCH_SIZE = 512
SHUFFLE = True
SEED = 42
VALIDATION_SPLIT = 0.05
EPOCHS = 3

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device: ", device)

Using device:  cuda


In [4]:
def prepare_dataloaders(dataset: ClassificationDataset, validation_split: float):
    dataset_size = len(dataset)
    print("Dataset size: ", dataset_size)
    print(dataset.index2label)

    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if SHUFFLE:
        np.random.seed(SEED)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    print("Train size: {}, Val size: {}".format(len(train_indices), len(val_indices)))

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              sampler=train_sampler)
    validation_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                                   sampler=valid_sampler)
    return train_loader, validation_loader

In [5]:
def fit(model: SentenceTransformerAndClassifier, train_loader: DataLoader):
    epoch_loss = 0
    processed_samples = 0
    correct_classified_samples = 0

    model.train()
    for step, data in enumerate(train_loader, 0):
        input_ids = data["batch_encoding"]["input_ids"].to(device)
        attention_mask = data["batch_encoding"]["attention_mask"].to(device)
        input_ids = torch.squeeze(input_ids)
        attention_mask = torch.squeeze(attention_mask)
        targets = data["class_label"].to(device)

        outputs = model(input_ids, attention_mask)
        predictions = torch.argmax(outputs, dim=1)

        loss = criterion(outputs, targets)
        epoch_loss += loss.item()
        correct_classified_samples += torch.sum(predictions == targets).item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        processed_samples += len(data["class_label"])

    epoch_loss /= processed_samples
    epoch_accuracy = correct_classified_samples / processed_samples
    return {
        "loss": epoch_loss,
        "acc": epoch_accuracy
    }

In [6]:
def validate(model: SentenceTransformerAndClassifier, validation_loader: DataLoader):
    model.eval()
    with torch.no_grad():
        val_loss = 0
        processed_samples = 0
        correct_classified_samples = 0

        for step, data in enumerate(validation_loader, 0):
            input_ids = data["batch_encoding"]["input_ids"].to(device)
            attention_mask = data["batch_encoding"]["attention_mask"].to(device)
            input_ids = torch.squeeze(input_ids)
            attention_mask = torch.squeeze(attention_mask)
            targets = data["class_label"].to(device)

            outputs = model(input_ids, attention_mask)
            predictions = torch.argmax(outputs, dim=1)

            loss = criterion(outputs, targets)
            val_loss += loss.item()
            correct_classified_samples += torch.sum(predictions == targets).item()

            processed_samples += len(data["class_label"])

        val_loss /= processed_samples
        val_accuracy = correct_classified_samples / processed_samples
    return {
        "val_loss": val_loss,
        "val_acc": val_accuracy
    }

In [7]:
base_model = "sentence-transformers/paraphrase-mpnet-base-v2"
model = SentenceTransformerAndClassifier(base_model, n_classes=5)
tokenizer = AutoTokenizer.from_pretrained(base_model)
model.to(device)
model.describe_parameters()

+---------------------+------------+
|       Modules       | Parameters |
+---------------------+------------+
| linear_layer.weight |   393216   |
|  linear_layer.bias  |    512     |
|  classifier.weight  |    2560    |
|   classifier.bias   |     5      |
+---------------------+------------+
Total trainable parameters: 396293


396293

In [8]:

dataset = ClassificationDataset(Path.joinpath(PROJECT_ROOT, "data/processed"), tokenizer, MAX_LEN)
train_loader, validation_loader = prepare_dataloaders(dataset, VALIDATION_SPLIT)

Dataset size:  24995
{0: 'address', 1: 'company_name', 2: 'location', 3: 'physical_good', 4: 'serial_number'}
Train size: 23746, Val size: 1249


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [10]:
for epoch in range(EPOCHS):
    print("Epoch {}/{}".format(epoch + 1, EPOCHS))
    print("-" * 10)

    logs = fit(model, train_loader)

    print("Loss: {}, Acc: {}".format(
        logs["loss"],
        logs["acc"]
    ))

    val_logs = validate(model, validation_loader)
    print("Val Loss: {}, Val Acc: {}".format(
        val_logs["val_loss"],
        val_logs["val_acc"],
    ))

Epoch 1/3
----------
Loss: 0.0009428589104110064, Acc: 0.9145961425082119
Val Loss: 0.00017952515458087142, Val Acc: 0.9847878302642114
Epoch 2/3
----------
Loss: 0.00014538020519938975, Acc: 0.9801229680788344
Val Loss: 9.983928323841935e-05, Val Acc: 0.9879903923138511
Epoch 3/3
----------
Loss: 0.00011213524644375651, Acc: 0.984923776636065
Val Loss: 9.242845050471605e-05, Val Acc: 0.9895916733386709


In [11]:
torch.save(model.state_dict(), Path.joinpath(PROJECT_ROOT, "save_dict_model.pt").absolute())