In [1]:
# %load data_preparation/trigram/v0.py
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split

# create a classifier class that inherits from nn.Module
class TrigramClassifier(torch.nn.Module):
    """Lookup-table trigram model: logits for the next character given the previous two.

    The single parameter ``W`` has shape ``(vocab_size, vocab_size, vocab_size)``;
    ``W[i, j]`` holds the unnormalised scores (logits) of every possible third
    character following the character pair ``(i, j)``.

    Args:
        vocab_size: number of distinct character indices (27 by default).
        generator: optional ``torch.Generator`` used to initialise ``W``. When
            omitted, a module-level generator named ``g`` is used if one exists,
            otherwise torch's default RNG is used.
    """

    def __init__(self, vocab_size=27, generator=None):
        super().__init__()
        if generator is None:
            # Keep the original behaviour of seeding from the notebook-level
            # generator `g`, without failing when it has not been defined.
            generator = globals().get('g')
        self.W = torch.nn.Parameter(
            torch.randn((vocab_size, vocab_size, vocab_size), generator=generator)
        )

    def forward(self, x):
        """Return the logits for each context pair in ``x``.

        Args:
            x: integer tensor whose last dimension has size 2 and holds the
                indices of the first and second character of each trigram.

        Returns:
            Tensor of shape ``x.shape[:-1] + (vocab_size,)``.
        """
        # Index with BOTH context characters. `self.W[x]` would only index the
        # first axis and yield (..., 2, vocab, vocab), which is not a valid
        # input for CrossEntropyLoss against per-example class targets.
        return self.W[x[..., 0], x[..., 1]]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seeded generator shared across the notebook for reproducible randomness.
g = torch.Generator().manual_seed(42)

# One word per line; the context manager closes the file handle promptly.
with open('names.txt') as names_file:
    words = names_file.read().splitlines()
letters = sorted(set(''.join(words)))
# Index 0 is reserved for the '.' boundary token, so letters start at 1.
# Starting the enumeration at 0 would make the first letter collide with '.'.
# NOTE(review): the model's weight tensor is 27x27x27, so this assumes at most
# 26 distinct characters in names.txt - confirm against the data file.
letter_to_index = {letter: index for index, letter in enumerate(letters, start=1)}
letter_to_index['.'] = 0
index_to_letter = {i: letter for letter, i in letter_to_index.items()}

trigram_xs_train, trigram_ys_train = [], []

# Not used further down in this cell; kept in case later cells read them.
trigram_validation_words, trigram_test_words = [], []

trigram_trainValTestSplit = [0.8, 0.1, 0.1]

# Not used further down in this cell; kept so the state of `g` seen by later
# consumers stays the same.
indices = torch.randperm(len(words), generator=g)

for w in words:
    chs = ['.'] + list(w) + ['.']
    # Every window of three consecutive characters becomes one example:
    # the first two characters are the input, the third is the target.
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        trigram_xs_train.append((letter_to_index[ch1], letter_to_index[ch2]))
        trigram_ys_train.append(letter_to_index[ch3])

trigram_xs = torch.as_tensor(trigram_xs_train)
trigram_ys = torch.as_tensor(trigram_ys_train)

trigram_dataset = TensorDataset(trigram_xs, trigram_ys)

trigram_train_ratio = .8
trigram_validation_ratio = .1

trigram_n_total = len(trigram_dataset)
trigram_n_train = int(trigram_n_total * trigram_train_ratio)
# Full-batch training: a single batch holds the whole training split.
trigram_n_train_batch = trigram_n_train
trigram_n_validation = int(trigram_n_total * trigram_validation_ratio)
trigram_n_validation_batch = trigram_n_validation
# The test split takes whatever is left so the three sizes sum to the total.
trigram_n_test = trigram_n_total - trigram_n_train - trigram_n_validation

# Pass the seeded generator so the split is identical on every run.
trigram_train_data, trigram_validation_data, trigram_test_data = random_split(
    trigram_dataset,
    [trigram_n_train, trigram_n_validation, trigram_n_test],
    generator=g,
)

trigram_train_loader = DataLoader(trigram_train_data, batch_size=trigram_n_train_batch, shuffle=True)
trigram_validation_loader = DataLoader(trigram_validation_data, batch_size=trigram_n_validation_batch, shuffle=True)
trigram_test_loader = DataLoader(trigram_test_data, batch_size=trigram_n_test, shuffle=True)

In [2]:
# %load model_configuration/trigram/v0.py
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

def smoothing(model=None, strength=0.01):
    """Return the L2-style regularisation term ``strength * mean(W ** 2)``.

    Args:
        model: object exposing a weight tensor ``W``. Defaults to the
            module-level ``trigram_model`` so existing ``smoothing()`` calls
            keep working.
        strength: multiplier applied to the mean squared weight.

    Returns:
        A zero-dimensional tensor that can be added to a loss.
    """
    if model is None:
        model = trigram_model
    return strength * (model.W ** 2).mean()

def make_trigram_train_step_fn(model, loss_fn, optimizer):
    """Build a closure that performs one optimisation step on a batch.

    Args:
        model: module called on the inputs to produce predictions.
        loss_fn: callable taking ``(predictions, targets)`` and returning a loss tensor.
        optimizer: optimiser whose ``zero_grad`` and ``step`` are invoked.

    Returns:
        A function ``train_step(x, y)`` that returns the batch loss
        (including the ``smoothing()`` term) as a Python float.
    """
    def train_step(x, y):
        # Switch to training mode before the forward pass.
        model.train()
        predictions = model(x)
        data_loss = loss_fn(predictions, y)
        # Regularisation term added on top of the data loss.
        total_loss = data_loss + smoothing()
        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        return total_loss.item()

    return train_step

# make a validation step function
def make_trigram_validation_step_fn(model, loss_fn):
    """Build a closure that evaluates the loss on a batch without training.

    Args:
        model: module called on the inputs to produce predictions.
        loss_fn: callable taking ``(predictions, targets)`` and returning a loss tensor.

    Returns:
        A function ``validation_step(x, y)`` that returns the batch loss as a
        Python float. No regularisation term is added here.
    """
    # Gradient tracking is disabled for the whole step.
    @torch.no_grad()
    def validation_step(x, y):
        # Switch to evaluation mode before the forward pass.
        model.eval()
        return loss_fn(model(x), y).item()

    return validation_step

# Optimiser hyper-parameters.
trigram_lr = 50
trigram_momentum = 0.9
# Place the model on the same device the dummy input is sent to below.
trigram_model = TrigramClassifier().to(device)

# Every name here carries the `trigram_` prefix; the unprefixed names
# (`model`, `lr`, `momentum`, `loss_fn`, `optimizer`, `writer`) are not
# defined anywhere in this notebook.
trigram_optimizer = optim.SGD(trigram_model.parameters(), lr=trigram_lr, momentum=trigram_momentum)
trigram_loss_fn = torch.nn.CrossEntropyLoss()

trigram_train_step_fn = make_trigram_train_step_fn(trigram_model, trigram_loss_fn, trigram_optimizer)
trigram_validation_step_fn = make_trigram_validation_step_fn(trigram_model, trigram_loss_fn)

# Log the model graph to TensorBoard using one batch as the example input.
trigram_writer = SummaryWriter('runs/trigram_classifier')
trigram_x_dummy, trigram_y_dummy = next(iter(trigram_train_loader))
trigram_writer.add_graph(trigram_model, trigram_x_dummy.to(device))

In [3]:
# %load model_training/trigram/v0.py
import torch
import torch.nn.functional as F
import numpy

# Number of passes over the training loader.
trigram_epochs = 10
# One mean training loss per epoch, appended by the training loop.
trigram_losses = list()

def trigram_mini_batch(device, data_loader, stepn_fn):
    """Run ``stepn_fn`` over every batch of ``data_loader`` and average the losses.

    Args:
        device: ``torch.device`` each batch is moved to before the step runs.
        data_loader: iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        stepn_fn: callable ``(x, y) -> loss`` returning a number per batch.

    Returns:
        The mean of the per-batch losses as computed by ``numpy.mean``.
    """
    mini_batch_losses = []
    for x_batch, y_batch in data_loader:
        # Move the batch to the requested device; the `device` argument was
        # previously accepted but never used.
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        mini_batch_losses.append(stepn_fn(x_batch, y_batch))
    return numpy.mean(mini_batch_losses)

# Train for the configured number of epochs, logging the mean loss of each one.
for epoch in range(trigram_epochs):
    # Use the `trigram_`-prefixed loader and step function; the unprefixed
    # `train_loader` / `train_step_fn` are not defined in this notebook.
    loss = float(trigram_mini_batch(device, trigram_train_loader, trigram_train_step_fn))
    # Stored as a plain Python float so the checkpoint below does not embed
    # numpy scalar objects.
    trigram_losses.append(loss)

    trigram_writer.add_scalars(main_tag=f'TRIGRAM&lr={trigram_lr}&momentum={trigram_momentum}&epochs={trigram_epochs}&batch={trigram_n_train_batch}&smoothing=0.01', tag_scalar_dict={'training': loss}, global_step=epoch)

# Flush and release the TensorBoard writer created in the configuration cell.
trigram_writer.close()

# Everything needed to inspect or resume this run.
checkpoint = {
    'epoch': trigram_epochs,
    'lr': trigram_lr,
    'momentum': trigram_momentum,
    # Free-text label only; it is not parsed anywhere in this notebook.
    'smoothing': '0.01+W**2.mean()',
    'model_state_dict': trigram_model.state_dict(),
    'optimizer_state_dict': trigram_optimizer.state_dict(),
    'loss': trigram_losses
}

torch.save(checkpoint, 'trigram_checkpoint.pth')
# print last losses value
print(f'Final trigram training loss: {trigram_losses[-1]}')