In [1]:
import torch
import numpy as np
import torchvision
from torch.utils.data import Dataset, Subset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
import transformer as transformer
import pickle

import matplotlib.pyplot as plt


SEQ_LENGTH = 20
PRINT_EVERY = 10
N_EPOCHS = 10
LR = 1e-3
ES_PATIENCE = 10
BATCH_SIZE = 2**7

  Referenced from: '/Users/hollymandel/miniconda3/lib/python3.10/site-packages/torchvision/image.so'
  warn(


In [2]:
with open("shakespeare_indices.pkl", 'rb') as file:
    ind = np.asarray(pickle.load(file))
with open("shakespeare_token_dict.pkl", 'rb') as file:
    token_dict = pickle.load(file)  
vocab_size = len(token_dict.keys()) + 1

In [3]:
import pickle

class shakes_data(Dataset):
    def __init__(self, ind, N = SEQ_LENGTH):
        overhang = ind.shape[0] % N
        resid_len = ind.shape[0] - overhang
        data = ind[overhang:]
        data = data.reshape(int(resid_len/N),N)
        self.data = data[:,:-1]
        self.labels = data[:,1:]
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y
    
sh = shakes_data(ind, SEQ_LENGTH)

In [4]:
def one_hotify(vec: np.ndarray, vocab_size):
    oh = np.zeros([len(vec), vocab_size])
    for i in range(vec.size):
        oh[i, vec[i]] = 1
    return oh
    
class EarlyStopping:
    def __init__(self, patience=ES_PATIENCE, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = np.infty
        self.early_stop = False
        self.delta = delta
        self.best_model = None

    def __call__(self, loss, model):
        if loss > self.best_loss + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.counter = 0
            if loss < self.best_loss:
                self.best_model = copy.deepcopy(model)
                self.best_loss = loss

In [7]:
# !pip install --upgrade notebook ipython jupyterlab
# %matplotlib notebook
import time
import copy

def format_input(data):
    return torch.from_numpy(
        np.apply_along_axis(
            one_hotify, 
            axis=1, 
            arr=data, 
            vocab_size = vocab_size
        ).astype(np.float32))

def this_word(inputs):
    return [ index_dict[x] for x in inputs[0,:] ]

def next_word(model, inputs):
    model_probs = model(format_input(inputs))
    max_prob = np.argmax(model_probs.detach().numpy(),2)
    return [ index_dict[x] for x in max_prob[0,:] ]

def get_grad_sum(model):
    all_grad_norms = [ np.linalg.norm(param.grad.detach().numpy()) for param in model.parameters() ]
    # plt.plot(all_grad_norms)
    return np.sum(all_grad_norms) 

def test_loss(model, testloader, criterion):
    total_loss = 0.0
    total_grad = 0.0
    for i, dt in enumerate(testloader, 0):
            inputs, labels = dt
            inputs = format_input(inputs.detach().numpy())
            outputs = model(inputs)
            loss = criterion(outputs.permute(0, 2, 1), labels)
            total_loss += loss.item()
            total_grad += get_grad_sum(model)
    return total_loss, total_grad

def train_model(
    model, 
    trainloader,
    testloader,
    save_out = None,
    optimizer_type = torch.optim.Adam, 
    criterion = nn.CrossEntropyLoss(), 
    lr = LR,
    n_epochs = N_EPOCHS, 
    print_every = PRINT_EVERY,
    es_patience = ES_PATIENCE,
    verbose=True
):
    save_out = save_out or []
    optimizer = optimizer_type(model.parameters(), lr)
    es = EarlyStopping(patience=es_patience, verbose=False)
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        for i, dt in enumerate(trainloader, 0):
            inputs, labels = dt
            inputs = format_input(inputs.detach().numpy())
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.permute(0, 2, 1), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % print_every == print_every-1:
                if verbose:
                    model.eval()
                    this_train, this_train_grad = running_loss / print_every, get_grad_sum(model)
                    this_test, this_test_grad = test_loss(model, testloader, criterion)
                    save_out.append([this_train, this_train_grad, this_test, this_test_grad])
                    print(f"train gradients: {this_train_grad:.0f}, test gradients: {this_test_grad:.0f}")
                    print(f"[epoch {epoch+1},batch {i+1}] train loss: {this_train:.2f}, test loss: {this_test:.2f}")
                es(running_loss, model)
                if es.early_stop:
                    return es.best_model, save_out
                running_loss = 0.0
                
    return es.best_model, save_out

In [11]:
train_fraction = 0.8
train_size = int(train_fraction * len(sh))
test_size = len(sh) - train_size
trainset, testset = random_split(sh, [train_size, test_size])
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

save_out = []
model = transformer.Transformer(n_blocks=2, n_heads = 8, N = SEQ_LENGTH-1, M = 50, embed_dim = vocab_size, d_target = vocab_size,)
model, save_out = train_model(
    model = model,
    trainloader = trainloader,
    testloader = testloader,
    save_out = save_out,
    lr = 1e-3,
    verbose=True,
    print_every = 10
)

train gradients: 5, test gradients: 427
[epoch 1,batch 10] train loss: 10.68, test loss: 731.87
train gradients: 4, test gradients: 326
[epoch 1,batch 20] train loss: 8.83, test loss: 638.74
train gradients: 3, test gradients: 253
[epoch 1,batch 30] train loss: 7.79, test loss: 573.44
train gradients: 3, test gradients: 204
[epoch 1,batch 40] train loss: 7.19, test loss: 544.64
train gradients: 2, test gradients: 191
[epoch 1,batch 50] train loss: 6.92, test loss: 533.94
train gradients: 2, test gradients: 187
[epoch 1,batch 60] train loss: 6.80, test loss: 528.32
train gradients: 2, test gradients: 172
[epoch 1,batch 70] train loss: 6.74, test loss: 524.78
train gradients: 2, test gradients: 169
[epoch 1,batch 80] train loss: 6.70, test loss: 522.11
train gradients: 2, test gradients: 168
[epoch 1,batch 90] train loss: 6.64, test loss: 519.97
train gradients: 2, test gradients: 155
[epoch 1,batch 100] train loss: 6.67, test loss: 518.01
train gradients: 2, test gradients: 165
[epoch 1

In [None]:
# def compute_loss(model, data_loader, criterion):
#     loss = 0
#     with torch.no_grad():
#         for images, labels in data_loader:
#             images = torch.from_numpy(
#                 np.apply_along_axis(
#                     one_hotify, 
#                     axis=1, 
#                     arr=images.detach().numpy(), 
#                     vocab_size = vocab_size
#                 ).astype(np.float32))
#             outputs = model(images)
#             loss += criterion(outputs.permute(0, 2, 1), labels.to(torch.long))
#     return loss

# def grid_search(learning_rates = [1e-3, 5e-4, 1e-4], n_heads = [4,8,16], n_blocks = [2,4,8], batch_sizes = [64,256,1024], models_dict = None, save_file = None, save_every = 1):
#     models_dict = models_dict or {}
#     shared_tf_params = {"N": SEQ_LENGTH-1, "M": 50, "embed_dim": vocab_size, "d_target": vocab_size}
#     losses = np.zeros([len(batch_sizes), len(n_heads), len(learning_rates), len(n_blocks)])
#     save_ix = 0
#     for i in range(len(batch_sizes)):
#         bs = batch_sizes[i]
#         data_loader = DataLoader(sh, batch_size=bs, shuffle=True)
#         for j in range(len(n_heads)):
#             for k in range(len(learning_rates)):
#                 for l in range(len(n_blocks)):
#                     nh, lr, nb = n_heads[j], learning_rates[k], n_blocks[l]
#                     this_model = transformer.Transformer(n_heads=nh, n_blocks=l, **shared_tf_params)
#                     train_model(this_model , data_loader, lr = lr)
#                     losses[i,j,k,l] = compute_loss(this_model, data_loader, nn.CrossEntropyLoss())
#                     models_dict[(bs, nh, lr, nb)] = this_model
#                     print(f"completed batch size: {bs}, n_heads: {nh}, learning_rate: {lr}")
#                     if save_file and save_ix % save_every == save_every-1:
#                         with open(f"{save_file}.pkl", 'wb') as f:
#                             pickle.dump(losses, f)
#                             print(f"saved losses at model index {save_ix}")
#                     save_ix += 1
                
#     return losses

In [48]:
def format_input(data):
    return torch.from_numpy(
        np.apply_along_axis(
            one_hotify, 
            axis=1, 
            arr=data, 
            vocab_size = vocab_size
        ).astype(np.float32))

def this_word(inputs):
    return [ index_dict[x] for x in inputs[0,:] ]

def next_word(model, inputs):
    model_probs = model(format_input(inputs))
    max_prob = np.argmax(model_probs.detach().numpy(),2)
    return [ index_dict[x] for x in max_prob[0,:] ]


In [67]:
random_sample = np.random.randint(len(sh))

next_word(model, sh[random_sample][0][None,:])

['to',
 'the',
 'i',
 'and',
 'to',
 'and',
 'of',
 'i',
 'the',
 'lord',
 'of',
 'a',
 'have',
 'and',
 'i',
 'that',
 'lord',
 'the',
 'in']