In [58]:
import numpy as np
import spacy
import torch
from datasets import load_dataset, Dataset, concatenate_datasets
from torch import nn

In [59]:
def determine_device():
    """Return the name of the preferred torch device available on this machine.

    CUDA is checked first, then the MPS backend; 'cpu' is the fallback when
    neither accelerator reports itself as available.

    Returns:
        str: one of 'cuda', 'mps' or 'cpu'.
    """
    # Guard clauses, ordered from most to least preferred backend.
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


# Resolve the compute device once so later cells can move the model and
# tensors to it; the print records which backend this run actually used.
device = determine_device()
print(f'Device is {device}')

Device is mps


In [61]:
from torch.utils.data import DataLoader
from util import BucketBatchSampler, BucketDataset

# spaCy English pipeline; loaded here so it is available to later cells.
nlp = spacy.load('en_core_web_sm')

# Fetch the dataset and keep a handle on each of its three splits.
dataset = load_dataset('Sp1786/multiclass-sentiment-analysis-dataset')
train_dataset: Dataset = dataset['train']
validation_dataset: Dataset = dataset['validation']
test_dataset: Dataset = dataset['test']

# The character vocabulary is built from the train and validation texts only
# (the test split is deliberately left out of this step).
corpus = concatenate_datasets([train_dataset, validation_dataset])['text']
vocabulary = sorted({char for message in corpus for char in message})

# Character -> integer id, following the sorted order of the vocabulary.
char_to_i = {char: index for index, char in enumerate(vocabulary)}


def encode_x(char_to_i, message):
    """Encode a message as an array of integer character ids.

    Args:
        char_to_i: mapping from a single character to its integer id.
        message: the string to encode, one id is produced per character.

    Returns:
        A 1-D ``np.int64`` array with ``len(message)`` entries.

    Raises:
        KeyError: if ``message`` contains a character that is not a key of
            ``char_to_i``.
    """
    # Pin the dtype explicitly: without it an empty message produces a
    # float64 array (NumPy's default for an empty list), which cannot be used
    # as embedding indices, and the integer width would depend on the platform.
    return np.array([char_to_i[char] for char in message], dtype=np.int64)


def encode_y(label, num_classes=3):
    """Encode a class index as a one-hot float vector.

    Args:
        label: integer index of the class to mark.
        num_classes: length of the returned vector. Defaults to 3, the number
            of classes used throughout this notebook.

    Returns:
        A 1-D float tensor of length ``num_classes`` that is 1 at position
        ``label`` and 0 everywhere else.

    Raises:
        IndexError: if ``label`` is outside the range of the vector.
    """
    vector = torch.zeros(num_classes)
    vector[label] = 1
    return vector


# Encode every training message as an array of character ids and every label
# as a one-hot vector.
train_messages = [encode_x(char_to_i, message) for message in train_dataset['text']]
train_labels = [encode_y(label) for label in train_dataset['label']]

# The batch sampler is built from the encoded messages and the value 128
# (presumably the maximum batch size -- confirm in util.BucketBatchSampler).
# Per the original author's note, the sampler does not itself store the
# messages; BucketDataset is what holds the (message, label) pairs.
train_bucket_batch_sampler = BucketBatchSampler(train_messages, 128)
train_bucket_dataset = BucketDataset(train_messages, train_labels)
# batch_sampler is mutually exclusive with batch_size, shuffle, sampler and
# drop_last, so none of those are passed: the sampler alone decides which
# indices form each batch.
train_dataloader = DataLoader(
    train_bucket_dataset,
    batch_sampler=train_bucket_batch_sampler,
    num_workers=8,
)


In [69]:
# See https://docs.pytorch.org/docs/stable/generated/torch.nn.GRU.html.
class Model(nn.Module):
    """Character-level GRU classifier producing logits for 3 classes.

    The network is embedding -> GRU -> linear. The linear layer is applied to
    the final hidden state of the last GRU layer.
    """

    def __init__(self, vocabulary_size, embedding_dim, hidden_size, num_layers=1):
        """Build the layers.

        Args:
            vocabulary_size: number of distinct input ids (rows of the embedding).
            embedding_dim: size of each embedding vector.
            hidden_size: size of the GRU hidden state.
            num_layers: number of stacked GRU layers. It is an explicit
                argument (default 1) so that constructing the model does not
                depend on a global variable having been defined beforehand.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, 3)

    def forward(self, sequence, hidden=None):
        """Classify a batch of id sequences.

        Args:
            sequence: integer tensor of shape (batch_size, sequence_length).
            hidden: initial hidden state of shape
                (num_layers, batch_size, hidden_size). When None, nn.GRU
                starts from an all-zero state.

        Returns:
            A tuple ``(logits, hidden)`` where ``logits`` has shape
            (batch_size, 3) and ``hidden`` is the final GRU state of shape
            (num_layers, batch_size, hidden_size).
        """
        # (batch_size, sequence_length) -> (batch_size, sequence_length, embedding_dim)
        embedded = self.embedding(sequence)
        # (batch_size, sequence_length, embedding_dim)
        # -> (batch_size, sequence_length, hidden_size), (num_layers, batch_size, hidden_size)
        _, hidden = self.rnn(embedded, hidden)
        # See https://docs.pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network.
        # hidden[-1] is the top layer's final state; hidden[0] would be the
        # bottom layer, which is only the same thing when num_layers == 1.
        return self.linear(hidden[-1]), hidden

In [70]:
# Hyper-parameters of the network.
embedding_dim = 256
hidden_size = 1024
num_layers = 1
model = Model(len(vocabulary), embedding_dim, hidden_size).to(device)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

# Make the training mode explicit rather than relying on the module default.
model.train()

# See https://docs.pytorch.org/tutorials/beginner/introyt/trainingyt.html#the-training-loop.
for i, data in enumerate(train_dataloader):
    x, y = data
    batch_size = x.shape[0]
    # Size the initial state from the model's own GRU so it always matches the
    # network that was actually built, and allocate it directly on the target
    # device instead of creating it on the CPU and copying it over.
    h0 = torch.zeros(model.rnn.num_layers, batch_size, hidden_size, device=device)

    # Start each step from clean gradients.
    optimizer.zero_grad()

    # See https://discuss.pytorch.org/t/gru-and-padded-sequences-tipps-and-tricks/90729.
    # Call the module itself (not .forward) so registered hooks are honoured.
    prediction, _ = model(x.to(device), h0)

    # y holds one-hot float targets, which CrossEntropyLoss treats as class
    # probabilities.
    loss = loss_fn(prediction, y.to(device))
    print(f'Batch {i} of size {batch_size}, loss {loss.item()}')

    loss.backward()
    optimizer.step()

Batch 0 of size 128, loss 1.1001335382461548
Batch 1 of size 8, loss 3.8926305770874023
Batch 2 of size 90, loss 1.6588584184646606
Batch 3 of size 26, loss 1.4633713960647583
Batch 4 of size 128, loss 1.1869010925292969
Batch 5 of size 103, loss 1.0807524919509888
Batch 6 of size 1, loss 0.6781896948814392
Batch 7 of size 16, loss 1.9421859979629517
Batch 8 of size 128, loss 1.5011545419692993
Batch 9 of size 4, loss 1.7138030529022217
Batch 10 of size 7, loss 1.9737110137939453
Batch 11 of size 128, loss 1.7998613119125366
Batch 12 of size 1, loss 0.8564717769622803
Batch 13 of size 128, loss 1.4678044319152832
Batch 14 of size 5, loss 3.0721302032470703
Batch 15 of size 19, loss 1.5426909923553467
Batch 16 of size 128, loss 1.4457134008407593
Batch 17 of size 8, loss 1.453172206878662
Batch 18 of size 8, loss 1.4024332761764526
Batch 19 of size 1, loss 2.130967140197754
Batch 20 of size 36, loss 1.6565011739730835
Batch 21 of size 1, loss 1.4566657543182373
Batch 22 of size 128, los

KeyboardInterrupt: 