# Train Model

In [None]:
import math, tqdm
import os
import numpy as np
import import_ipynb # used to import the modules notebook
import modules # impor the notebook
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

Set Model Configuration

In [None]:
# Hyperparameters of the transformer and of the synthetic task.
num_tokens = 10  # number of different tokens in the corpus: the integers from 0 to 9
t = 5            # length of the sequences that are fed to the model
depth = 2        # depth of the network (number of transformer blocks)
heads = 3        # number of attention heads in the multi-head attention mechanism
k = 6            # embedding dimension; needs to be a multiple of heads (x * heads)

In [None]:
# The vocabulary is simply the integer token ids 0 .. num_tokens - 1.
vocab = np.arange(0, num_tokens)
print(vocab)

Generate Training Data

- vocabulary = 1, 2, 3, 4, 5, 6, 7, 8, 9, 0
- 0, 1, 2 are reserved as class labels; the sequences themselves only use the tokens 3–9
- sequence length = 5

- sequence class 0:
    - increasing sequence
    - e.g. 3,4,5,6,7
- sequence class 1:
    - decreasing sequence
    - e.g. 9,8,7,6,5
- sequence class 2:
    - pairwise sequence
    - e.g. 3,5,3,5,3

In [None]:
def generate_synthetic_vector_data_increasing(vocab, vector_dim):
    """Build question/answer pairs for the "increasing" sequence class (label 0).

    One pair is produced per cyclic shift of ``vocab``. The question is the
    first ``vector_dim`` entries of the shifted vocabulary; the answer is the
    question moved one position to the left, with its last entry overwritten
    by the class label 0.

    Args:
        vocab: 1-D sequence of tokens to build the sequences from.
        vector_dim: number of tokens kept per sequence.

    Returns:
        Tuple ``(questions, answers)`` of numpy arrays, one row per shift.
    """
    # one window per cyclic shift of the vocabulary
    windows = [np.roll(vocab, shift)[:vector_dim] for shift in range(len(vocab))]

    targets = []
    for window in windows:
        # next-token target: shift left by one, the wrapped-around slot carries the label
        shifted = np.roll(window, -1)
        shifted[-1] = 0
        targets.append(shifted)

    return np.array(windows), np.array(targets)

In [None]:
# example: list every increasing-style question next to its answer
questions, answers = generate_synthetic_vector_data_increasing(vocab[3:], t)
print("There are {} questions and {} answers for the increasing style".format(questions.shape[0], answers.shape[0]))
for question, answer in zip(questions, answers):
    print(str(question) + " -> " + str(answer))

In [None]:
def generate_synthetic_vector_data_decreasing(vocab, vector_dim):
    """Build question/answer pairs for the "decreasing" sequence class (label 1).

    The vocabulary is reversed once; one pair is then produced per cyclic
    shift of the reversed vocabulary. The question is the first ``vector_dim``
    entries of the shifted array; the answer is the question moved one
    position to the left, with its last entry overwritten by the class label 1.

    Args:
        vocab: 1-D sequence of tokens to build the sequences from.
        vector_dim: number of tokens kept per sequence.

    Returns:
        Tuple ``(questions, answers)`` of numpy arrays, one row per shift.
    """
    # The reversed vocabulary does not depend on the shift, so build it once
    # instead of re-flipping it on every iteration.
    reverse_vocab = np.flip(vocab)

    questions = []
    answers = []
    for i in range(len(vocab)):
        sequence = np.roll(reverse_vocab, i)[:vector_dim]
        questions.append(sequence)

        # next-token target: shift left by one, the wrapped-around slot carries the label
        sequence_roll = np.roll(sequence, -1)
        sequence_roll[-1] = 1
        answers.append(sequence_roll)

    questions = np.array(questions)
    answers = np.array(answers)

    return questions, answers

In [None]:
# example: list every decreasing-style question next to its answer
questions, answers = generate_synthetic_vector_data_decreasing(vocab[3:], t)
print("There are {} questions and {} answers for the decreasing style".format(questions.shape[0], answers.shape[0]))
for question, answer in zip(questions, answers):
    print(str(question) + " -> " + str(answer))
    

In [None]:
def generate_synthetic_vector_data_recurring(vocab, vector_dim):
    """Build question/answer pairs for the "recurring" (pairwise) sequence class (label 2).

    For every ordered pair of two different tokens ``(first, second)`` the
    question alternates ``first, second, first, ...`` and is cut to
    ``vector_dim`` entries. The answer is the question moved one position to
    the left, with its last entry overwritten by the class label 2.

    Args:
        vocab: 1-D sequence of tokens to build the sequences from.
        vector_dim: number of tokens kept per sequence.

    Returns:
        Tuple ``(questions, answers)`` of numpy arrays, one row per token pair.
    """
    inputs, targets = [], []

    for first in vocab:
        for second in vocab:
            if second != first:
                # repeat the pair often enough to cover vector_dim, then truncate
                pattern = ([first, second] * math.ceil(vector_dim/2))[0:vector_dim]
                inputs.append(pattern)

                # next-token target: shift left by one, the wrapped-around slot carries the label
                shifted = np.roll(pattern, -1)
                shifted[-1] = 2
                targets.append(shifted)

    return np.array(inputs), np.array(targets)

In [None]:
# example: show the first two recurring-style question/answer pairs
questions, answers = generate_synthetic_vector_data_recurring(vocab[3:], t)
print("There are {} questions and {} answers for the recurring style".format(questions.shape[0], answers.shape[0]))
for idx in (0, 1):
    print("Example: {} --> {}".format(questions[idx], answers[idx]))


In [None]:
# Generate the synthetic question/answer pairs for each of the three sequence classes.
# Only vocab[3:] is used for the sequences because the tokens 0, 1 and 2 serve as class labels.
questions_class0, answers_class0 = generate_synthetic_vector_data_increasing(vocab[3:], t)
questions_class1, answers_class1 = generate_synthetic_vector_data_decreasing(vocab[3:], t)
questions_class2, answers_class2 = generate_synthetic_vector_data_recurring(vocab[3:], t)

# Alternative (currently disabled): split every class into a train and a test part.
# # create train and test sets for each class
# train_test_split = 0.8
# train_test_split_index = int(train_test_split * len(questions_class0))
# train_questions_class0 = questions_class0[:train_test_split_index]
# train_answers_class0 = answers_class0[:train_test_split_index]
# test_questions_class0 = questions_class0[train_test_split_index:]
# test_answers_class0 = answers_class0[train_test_split_index:]

# train_test_split_index = int(train_test_split * len(questions_class1))
# train_questions_class1 = questions_class1[:train_test_split_index]
# train_answers_class1 = answers_class1[:train_test_split_index]
# test_questions_class1 = questions_class1[train_test_split_index:]
# test_answers_class1 = answers_class1[train_test_split_index:]

# train_test_split_index = int(train_test_split * len(questions_class2))
# train_questions_class2 = questions_class2[:train_test_split_index]
# train_answers_class2 = answers_class2[:train_test_split_index]
# test_questions_class2 = questions_class2[train_test_split_index:]
# test_answers_class2 = answers_class2[train_test_split_index:]



# # create tensor datasets for each class
# data_train_class0 = TensorDataset(torch.tensor(train_questions_class0), torch.tensor(train_answers_class0))
# data_test_class0 = TensorDataset(torch.tensor(test_questions_class0), torch.tensor(test_answers_class0))
# data_train_class1 = TensorDataset(torch.tensor(train_questions_class1), torch.tensor(train_answers_class1))
# data_test_class1 = TensorDataset(torch.tensor(test_questions_class1), torch.tensor(test_answers_class1))
# data_train_class2 = TensorDataset(torch.tensor(train_questions_class2), torch.tensor(train_answers_class2))
# data_test_class2 = TensorDataset(torch.tensor(test_questions_class2), torch.tensor(test_answers_class2))

# create full datasets without train and test split
data_class0 = TensorDataset(torch.tensor(questions_class0), torch.tensor(answers_class0))
data_class1 = TensorDataset(torch.tensor(questions_class1), torch.tensor(answers_class1))
data_class2 = TensorDataset(torch.tensor(questions_class2), torch.tensor(answers_class2))


# torch.save does not create missing parent directories, so make sure ./data exists
# before writing; otherwise this cell fails on a fresh checkout.
os.makedirs('./data', exist_ok=True)

# # save tensor datasets to ./data
# torch.save(data_train_class0, './data/data_train_class0.pt')
# torch.save(data_test_class0, './data/data_test_class0.pt')
# torch.save(data_train_class1, './data/data_train_class1.pt')
# torch.save(data_test_class1, './data/data_test_class1.pt')
# torch.save(data_train_class2, './data/data_train_class2.pt')
# torch.save(data_test_class2, './data/data_test_class2.pt')

# save full datasets to ./data
torch.save(data_class0, './data/data_class0.pt')
torch.save(data_class1, './data/data_class1.pt')
torch.save(data_class2, './data/data_class2.pt')


In [None]:
# Only applicable when the train/test split is enabled:
# # print specs of train and test data
# print("Train data class 0: ", len(data_train_class0))
# print("Test data class 0: ", len(data_test_class0))
# print("example: ", data_train_class0[0])

# print("Train data class 1: ", len(data_train_class1))
# print("Test data class 1: ", len(data_test_class1))
# print("example: ", data_train_class1[0])

# print("Train data class 2: ", len(data_train_class2))
# print("Test data class 2: ", len(data_test_class2))
# print("example: ", data_train_class2[0])

# print size and first sample of each full dataset
for heading, dataset in (
    ("Full data class 0: ", data_class0),
    ("Full data class 1: ", data_class1),
    ("Full data class 2: ", data_class2),
):
    print(heading, len(dataset))
    print("example: ", dataset[0])

Initialize the Model

In [None]:
# Sanity check: prints True when the imported `modules` notebook exposes a `GTransformer` attribute.
print(hasattr(modules, 'GTransformer'))

In [None]:
# Build the transformer from the configuration values and move it to the GPU when one is available.
model = modules.GTransformer(num_tokens=num_tokens, t=t, depth=depth, heads=heads, k=k)
gpu_available = torch.cuda.is_available()
if gpu_available:
    model.cuda()
print(gpu_available)


Train the Model

In [None]:
# training hyperparameters
num_batches = 10000
batch_size = 9  # batch size needs to be a multiple of the class number
lr=0.001
gradient_clipping = 1.0
test_every = 100
test_batchsize = 9
patience = 5  # consecutive evaluations without improvement before training stops early


# tensorboard writer
tensorboard_writer_gt = SummaryWriter("torchlogs/gtransformer/")

# define optimization for learning
opt = torch.optim.Adam(lr=lr, params=model.parameters())

# define the learning rate scheduler
# scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.1)

# Early-stopping state. It has to be initialised once, before the loop: if it were reset on
# every step, each evaluation would look like a new best, 'best_model.pth' would always be
# overwritten and the patience counter could never reach `patience`.
best_eval_loss = float('inf')
best_model_state = None
eval_no_improve = 0

# every batch holds the same number of samples from each of the three classes
number_data = batch_size // 3

# Start in training mode. The early-stopping `break` below leaves the model in eval mode,
# so this matters when the cell is run again afterwards.
model.train()

for i in tqdm(range(num_batches)):
    opt.zero_grad()
    # sample source and target from the three classes for batch
    # source_class0, target_class0 = data_train_class0[torch.randint(0, len(data_train_class0), (number_data,))]
    # source_class1, target_class1 = data_train_class1[torch.randint(0, len(data_train_class1), (number_data,))]
    # source_class2, target_class2 = data_train_class2[torch.randint(0, len(data_train_class2), (number_data,))]
    source_class0, target_class0 = data_class0[torch.randint(0, len(data_class0), (number_data,))]
    source_class1, target_class1 = data_class1[torch.randint(0, len(data_class1), (number_data,))]
    source_class2, target_class2 = data_class2[torch.randint(0, len(data_class2), (number_data,))]
    # concatenate the sources and targets
    source = torch.cat([source_class0, source_class1, source_class2])
    target = torch.cat([target_class0, target_class1, target_class2])

    if torch.cuda.is_available():
        source, target = source.cuda(), target.cuda()

    output, _, _ = model(source)
    # print("output: ", output)
    # print("output", output.shape)

    # nll_loss wants the class dimension second, hence the transpose
    # (assumes the model returns log-probabilities shaped (batch, t, num_tokens) — TODO confirm)
    loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
    tensorboard_writer_gt.add_scalar('gtransformer/train-loss', float(loss.item()), i * batch_size)
    loss.backward()

    # clip gradients: if the total gradient vector has a length > 1, we clip it back down to 1.
    if gradient_clipping > 0.0:
        nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

    # update the model weights and the learning rate
    opt.step()
    # scheduler.step()

    # periodic evaluation (and always on the very last step)
    if i != 0 and (i % test_every == 0 or i == num_batches - 1):
        # on the last step draw as many samples per class as class 0 has entries,
        # otherwise a small batch of test_batchsize // 3 samples per class
        upto = data_class0.tensors[0].size(0) if i == num_batches - 1 else test_batchsize // 3
        # NOTE: the evaluation batch is drawn from the same full datasets that are used for
        # training (no held-out split), so eval-loss measures fit rather than generalisation.
        # sample source and target from the three classes for batch
        # source_class0, target_class0 = data_test_class0[torch.randint(0, len(data_test_class0), (upto,))]
        # source_class1, target_class1 = data_test_class1[torch.randint(0, len(data_test_class1), (upto,))]
        # source_class2, target_class2 = data_test_class2[torch.randint(0, len(data_test_class2), (upto,))]
        source_class0, target_class0 = data_class0[torch.randint(0, len(data_class0), (upto,))]
        source_class1, target_class1 = data_class1[torch.randint(0, len(data_class1), (upto,))]
        source_class2, target_class2 = data_class2[torch.randint(0, len(data_class2), (upto,))]
        # concatenate the sources and targets
        source = torch.cat([source_class0, source_class1, source_class2])
        target = torch.cat([target_class0, target_class1, target_class2])


        model.eval()
        with torch.no_grad():

            if torch.cuda.is_available():
                source, target = source.cuda(), target.cuda()

            output, _ , _ = model(source)
            eval_loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
            # keep the tracked best loss as a plain Python float instead of a (possibly GPU) tensor
            eval_loss_value = float(eval_loss.item())
            tensorboard_writer_gt.add_scalar('gtransformer/eval-loss', eval_loss_value, i * batch_size)
            # note: this reports the training loss of step i, not the evaluation loss
            print(f'epoch{i}: {loss:.4} loss')

            if eval_loss_value < best_eval_loss:
                best_eval_loss = eval_loss_value
                # state_dict() holds references to the live parameters, so the file written
                # here is the authoritative snapshot of the best model
                best_model_state = model.state_dict()
                torch.save(best_model_state, 'best_model.pth')
                eval_no_improve = 0
            else:
                eval_no_improve += 1

            if eval_no_improve >= patience:
                print(f'Early stopping at epoch {i}')
                break

        model.train()

# make sure all pending scalars are written to the event file
tensorboard_writer_gt.flush()

In [None]:
# Load the best model state after training
# Read the checkpoint that the training loop stored as 'best_model.pth' and load it into the model
best_checkpoint = torch.load('best_model.pth')
model.load_state_dict(best_checkpoint)

In [None]:
# Start TensorBoard inside the notebook, pointed at the log directory the training cell writes to.
# (Alternatively use the command line: tensorboard --logdir=./src/model-basic/torchlogs/gtransformer/)
%load_ext tensorboard
%tensorboard --logdir ./torchlogs/gtransformer/

Save the Model

In [None]:
# Persist the weights currently held by `model` (this cell only saves; nothing is loaded here)
final_state = model.state_dict()
torch.save(final_state, 'gtransformer.pth')