In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torchtext.datasets import TranslationDataset
from torchtext.data import Field
from torchtext.data import Iterator

from src.model import TransformerModel
from src.helper import save_state, load_state

In [None]:

class Trainer():
    """
    Trains a ``TransformerModel`` on a character-level source -> target translation dataset.

    On construction it builds the torchtext fields, datasets and iterators described by
    ``data_constants``, builds both vocabularies from the training split, and creates the
    model (from ``parameters``) plus an Adam optimizer. ``train_model`` appends one entry per
    evaluation to ``_training_accuracy``, ``_validation_accuracy``, ``_train_loss``,
    ``_validation_loss`` and ``_total_count``.
    """

    # Separator line printed around epoch headers and evaluation reports.
    _RULE = "__________________________________________________________________________________________________________________________________________"

    # Hand-picked sanity-check questions that ``train_model`` decodes and prints at every
    # evaluation, as (question, expected answer) pairs. Pass the set matching the dataset
    # being trained on via ``train_model(..., examples=...)``.
    PLACE_VALUE_EXAMPLES = (
        ("What is the tens digit of 93283843?", "4"),
        ("What is the units digit of 93215897?", "7"),
        ("What is the thousands digit of 58179700?", "9"),
    )
    SORT_EXAMPLES = (
        ("Put 0.4, 5, 30, 50, -2, 16 in descending order.", "50, 30, 16, 5, 0.4, -2"),
        ("Sort -25/127, -2/13, 0.2.", "-25/127, -2/13, 0.2"),
        ("Sort 3, -0.2, 927897, 3/7 in ascending order.", "-0.2, 3/7, 3, 927897"),
    )
    LINEAR_1D_EXAMPLES = (
        ("Solve -282*d + 929 - 178 = -1223 for d.", "7"),
        ("Solve 0 = -i - 91*i - 1598*i - 64220 for i.", "-38"),
        ("Solve -25*m - 2084 = -2559 for m.", "19"),
    )

    def __init__(self, data_constants : dict, parameters : dict, learning_rate : float = 0.0001):
        """
        :param data_constants: dataset configuration, see ``_create_dataset`` for the keys
        :param parameters: keyword arguments forwarded to ``TransformerModel``
        :param learning_rate: Adam learning rate, defaults to 0.0001
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.source_field, self.target_field = self._create_fields(batch_first=True)
        train_dataset, valid_dataset, self.train_iterator, self.valid_iterator = self._create_dataset(data_constants)
        # Vocabularies must exist before the model is built: their sizes fix the
        # embedding and output dimensions below.
        self.source_field.build_vocab(train_dataset)
        self.target_field.build_vocab(train_dataset)
        # Metric histories, one entry per evaluation in ``train_model``.
        self._validation_accuracy = []
        self._training_accuracy = []
        self._train_loss = []
        self._validation_loss = []
        self._total_count = []
        self.model = TransformerModel(**parameters, src_tokens=len(self.source_field.vocab), tgt_tokens=len(self.target_field.vocab)).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    def greedy_search(self, model, bs, src, seq_length=None, max_length=256):
        """
        Greedily decode ``src`` one token at a time, feeding each prediction back to the decoder.

        Decoding starts from the target field's ``init_token``; the encoder output is computed
        once and reused for every step.

        :param model: the model to use; must provide ``generate_masks_and_encoding``,
            ``transformer.encoder``, ``transformer.decoder`` and ``fc``
        :param bs: batch size (number of rows in ``src``)
        :param src: source token indices, batch-first (as produced by this class's fields)
        :param seq_length: if given, decode exactly this many steps. If None, decode until every
            sequence in the batch has produced the end-of-sentence token at least once
        :param max_length: when ``seq_length`` is None, the maximum number of steps, so a model
            that never emits EOS cannot loop forever; None disables the cap
        :return: ``(logits, argmax)`` -- the logits of the final decoder pass, indexed
            ``[batch, step, vocab]``, and their argmax over the vocabulary
        """
        model.eval()
        sos_index = self.target_field.vocab.stoi[self.target_field.init_token]
        eos_index = self.target_field.vocab.stoi[self.target_field.eos_token]
        trg = torch.full((bs, 1), sos_index, dtype=torch.long, device=self.device)
        src_key_padding_mask, memory_key_padding_mask, _, src_pos_encoder = model.generate_masks_and_encoding(src, src_embedding=True)
        memory_tensor = model.transformer.encoder(src_pos_encoder, src_key_padding_mask=src_key_padding_mask).to(self.device)
        # Rows stay "finished" once they have emitted EOS, so rows ending on different
        # steps still terminate the search.
        finished = torch.zeros(bs, dtype=torch.bool, device=trg.device)
        step = 1
        while True:
            tgt_key_padding_mask, tgt_mask, tgt_pos_encoder = model.generate_masks_and_encoding(trg, src_embedding=False)
            prediction = model.transformer.decoder(tgt_pos_encoder, memory_tensor, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask)
            logits = model.fc(prediction.transpose(1, 0))
            next_tokens = torch.argmax(logits, dim=2)[:, -1]
            trg = torch.cat([trg, next_tokens.view(bs, -1)], dim=1)
            if seq_length is not None:
                if step >= seq_length:
                    break
            else:
                finished |= next_tokens == eos_index
                if finished.all() or (max_length is not None and step >= max_length):
                    break
            step += 1
        return logits, torch.argmax(logits, dim=2)

    def train_model(self, total_batches_processed, hyperparameter_tuning=False, num_epochs=20,
                    eval_every=500, accumulation_steps=5, max_grad_norm=0.1,
                    examples=None, checkpoint_path="Results/model.pt"):
        """
        Train ``self.model`` with teacher forcing, evaluating and checkpointing periodically.

        Gradients are accumulated over ``accumulation_steps`` batches, clipped to
        ``max_grad_norm`` and applied. Every ``eval_every`` global batches the model is
        greedily evaluated on the validation set, the example questions are decoded and
        printed, metrics are appended to the history lists, and a checkpoint is saved.

        :param total_batches_processed: number of batches processed before this call; used as
            the global step for evaluation/checkpoint scheduling
        :param hyperparameter_tuning: if True, stop as soon as validation accuracy reaches 100%
        :param num_epochs: number of passes over the training iterator, defaults to 20
        :param eval_every: evaluate/checkpoint every this many global batches, defaults to 500
        :param accumulation_steps: batches per optimizer step, defaults to 5
        :param max_grad_norm: gradient-norm clipping threshold, defaults to 0.1
        :param examples: iterable of (question, expected answer) pairs printed at each
            evaluation; defaults to ``LINEAR_1D_EXAMPLES``
        :param checkpoint_path: path passed to ``save_state``, defaults to "Results/model.pt"
        :return: the updated global batch counter
        """
        # Labels are target tokens, so padding must be identified with the target vocabulary.
        target_pad_index = self.target_field.vocab.stoi[self.target_field.pad_token]
        loss_fn = nn.CrossEntropyLoss(ignore_index=target_pad_index)
        model = self.model.to(self.device)
        if examples is None:
            examples = self.LINEAR_1D_EXAMPLES
        # Drop any gradients left over from a previous call.
        self.optimizer.zero_grad()
        cost = []
        correct_train = 0
        total_train = 0
        valid_accuracy = 0.0
        for epoch in range(num_epochs):
            if hyperparameter_tuning and valid_accuracy == 100.0:
                break
            print()
            print(self._RULE)
            print(f'Epoch Number: {epoch}')
            print(self._RULE)
            print()
            num_batch_processed = 0
            for batch in self.train_iterator:
                model.train()
                src = batch.src.to(self.device)
                trg = batch.trg.to(self.device)
                # Teacher forcing: feed trg without its last token, predict trg shifted by one.
                trg_labels = trg[:, 1:]
                prediction = model(src, trg[:, :-1])
                loss = loss_fn(prediction.permute(0, 2, 1), trg_labels)
                cost.append(loss.item())
                # Gradients accumulate (unscaled) until the next optimizer step.
                loss.backward()
                num_batch_processed += 1
                total_batches_processed += 1
                pred_max = torch.argmax(prediction, dim=2)
                total_train += trg_labels.size(0)
                correct_train += self._count_correct_sequences(pred_max, trg_labels, target_pad_index)
                if num_batch_processed % accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                if total_batches_processed % eval_every != 0:
                    continue

                # ---- periodic evaluation, reporting and checkpointing ----
                train_accuracy = 100.0 * correct_train / total_train
                total_train = 0
                correct_train = 0
                valid_accuracy, valid_loss = self._evaluate(model, loss_fn, target_pad_index)
                train_loss = float(np.mean(cost))
                cost = []
                print("Steps:", total_batches_processed)
                self._print_examples(model, examples)
                self._training_accuracy.append(train_accuracy)
                self._validation_accuracy.append(valid_accuracy)
                self._train_loss.append(train_loss)
                self._validation_loss.append(valid_loss)
                self._total_count.append(total_batches_processed)
                print()
                print(f'Train Accuracy {train_accuracy}%')
                print(f'Valid Accuracy {valid_accuracy} %')
                print(f'Train Loss: {train_loss}')
                print("Val Loss:", valid_loss)
                print(self._RULE)
                print()
                # NOTE(review): only the step count is passed here, while load_state in this
                # notebook restores model/optimizer/metric histories -- confirm save_state's signature.
                save_state(checkpoint_path, total_batches_processed)
                if hyperparameter_tuning and valid_accuracy == 100.0:
                    print("Final Accuracy for Validation:", valid_accuracy)
                    break
        return total_batches_processed

    def _evaluate(self, model, loss_fn, pad_index):
        """Greedy-decode the validation set; return (sequence accuracy in %, mean loss)."""
        valid_cost = []
        correct = 0
        total = 0
        with torch.no_grad():
            model.eval()
            for valid_batch in self.valid_iterator:
                src = valid_batch.src.to(self.device)
                trg = valid_batch.trg.to(self.device)
                valid_labels = trg[:, 1:]
                # Decode exactly as many steps as there are labels so shapes line up.
                prediction, pred_max = self.greedy_search(model, src.size(0), src, valid_labels.size(1))
                loss = loss_fn(prediction.permute(0, 2, 1), valid_labels)
                valid_cost.append(loss.item())
                total += valid_labels.size(0)
                correct += self._count_correct_sequences(pred_max, valid_labels, pad_index)
        valid_accuracy = 100.0 * correct / total if total else 0.0
        valid_loss = float(np.mean(valid_cost)) if valid_cost else float('nan')
        return valid_accuracy, valid_loss

    def _print_examples(self, model, examples):
        """Greedy-decode each (question, expected) pair and print the generated answer."""
        eos_token = self.target_field.eos_token
        with torch.no_grad():
            for question, expected in examples:
                src = self.source_field.process([question]).to(self.device)
                _, decoded = self.greedy_search(model, 1, src.view(1, -1))
                tokens = [self.target_field.vocab.itos[ind] for ind in decoded[0].tolist()]
                if tokens and tokens[-1] == eos_token:
                    tokens = tokens[:-1]
                print(f"Example Question: {question} | Expected Answer: {expected} | Generated Answer: {''.join(tokens)}")

    @staticmethod
    def _count_correct_sequences(predictions, labels, pad_index):
        """Count rows whose predictions match the labels at every non-padding position."""
        matches = (predictions == labels) | (labels == pad_index)
        return int(matches.all(dim=1).sum().item())

    def _create_fields(self, batch_first : bool = True):
        """
        Create two independent character-level fields, one for the source and one for the target.

        Both tokenize a string into its characters and use '<sos>', '<eos>' and '<pad>' as
        special tokens; they are separate objects so each builds its own vocabulary.

        :param batch_first: if True, batches are laid out batch-first, defaults to True
        :return: ``(source_field, target_field)``
        """
        def make_field():
            return Field(tokenize=list,
                         init_token='<sos>',
                         eos_token='<eos>',
                         pad_token='<pad>',
                         batch_first=batch_first)
        return make_field(), make_field()

    def _create_dataset(self, data_constants : dict):
        """
        Load the train and validation splits and wrap each in a torchtext ``Iterator``.

        Requires ``self.source_field`` and ``self.target_field`` to be set already.

        :param data_constants: mapping with keys 'folder' (directory holding the files),
            'train_name' / 'valid_name' (file stems), 'inputs_ending' / 'targets_ending'
            (source / target file extensions, e.g. ".x" / ".y"), and
            'train_batch_size' / 'valid_batch_size'
        :return: ``(train_dataset, valid_dataset, train_iterator, valid_iterator)``
        """
        train_dataset, valid_dataset = TranslationDataset.splits(
            path=data_constants['folder'],
            root=data_constants['folder'],
            exts=(data_constants['inputs_ending'], data_constants['targets_ending']),
            fields=(self.source_field, self.target_field),
            train=data_constants['train_name'],
            validation=data_constants['valid_name'],
            # No test split is used; test=None avoids reading the validation file a second time.
            test=None
        )
        train_iterator = Iterator(dataset=train_dataset, batch_size=data_constants['train_batch_size'], train=True, repeat=False, shuffle=True, device=self.device)
        valid_iterator = Iterator(dataset=valid_dataset, batch_size=data_constants['valid_batch_size'], train=False, repeat=False, shuffle=True, device=self.device)
        return train_dataset, valid_dataset, train_iterator, valid_iterator

    

In [None]:
# Reproducibility: seed every RNG the run may touch (numpy/torch for model initialisation
# and training; Python's `random` presumably for torchtext's iterator shuffling).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset location and batch sizes.
TRAIN_BS = 128
VALID_BS = 64
DATA_FOLDER = 'data/numbers__place_value'
TRAIN_FILE_NAME = "train"
VALID_FILE_NAME = "interpolate"
INPUTS_FILE_ENDING = ".x"
TARGETS_FILE_ENDING = ".y"

# Transformer hyper-parameters, forwarded to TransformerModel.
DIM_MODEL = 512
DIM_FEEDFORWARD = 2048
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
NUM_HEADS = 8

data_constants = {
    'folder' : DATA_FOLDER,
    'train_name' : TRAIN_FILE_NAME,
    'valid_name' : VALID_FILE_NAME,
    'inputs_ending' : INPUTS_FILE_ENDING,
    'targets_ending' : TARGETS_FILE_ENDING,
    'train_batch_size' : TRAIN_BS,
    'valid_batch_size' : VALID_BS
    }
parameters = {
    'd_model' : DIM_MODEL,
    'nhead' : NUM_HEADS,
    'dim_feedforward' : DIM_FEEDFORWARD,
    'num_encoder_layers' : NUM_ENCODER_LAYERS,
    'num_decoder_layers' : NUM_DECODER_LAYERS
}
total_batches = 0  # global step counter handed to train_model (0 = fresh run)

trainer = Trainer(data_constants=data_constants, parameters=parameters)
# NOTE(review): the example questions train_model prints during training are
# linear-equation problems, not place-value ones, so they are not meaningful for DATA_FOLDER.
trainer.train_model(total_batches_processed=total_batches)

In [None]:
# Rebuild the trainer, restore the checkpoint written during training, and plot the
# loss/accuracy histories against the global step counter.
CHECKPOINT_PATH = "Results/model.pt"  # the path train_model saves to
TASK_NAME = DATA_FOLDER.rstrip("/").split("/")[-1]  # e.g. 'numbers__place_value'

trainer = Trainer(data_constants=data_constants, parameters=parameters)
(model, optimizer, total_batches, total_count, validation_accuracy,
 training_accuracy, train_loss, validation_loss) = load_state(
    CHECKPOINT_PATH, trainer.model, trainer.optimizer, total_batches, trainer._total_count,
    trainer._validation_accuracy, trainer._training_accuracy, trainer._train_loss,
    trainer._validation_loss)
print(optimizer)
# Validation-accuracy history restored from the same checkpoint that is plotted below.
print(validation_accuracy)


def _to_floats(values):
    """Convert metric entries (possibly 0-dim tensors, even on GPU) to floats for matplotlib."""
    return [float(v) for v in values]


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(total_count, _to_floats(train_loss), 'deepskyblue')
ax1.plot(total_count, _to_floats(validation_loss), 'coral')
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss')
ax1.legend(["Train Loss", "Validation Loss"])
ax1.set_title(f"Training/Validation Loss for {TASK_NAME}")
ax2.plot(total_count, _to_floats(training_accuracy), 'deepskyblue')
ax2.plot(total_count, _to_floats(validation_accuracy), 'coral')
ax2.set_xlabel('Steps')
ax2.set_ylabel('Accuracy')
ax2.legend(["Train Accuracy", "Validation Accuracy"])
ax2.set_title(f"Training/Validation Accuracy for {TASK_NAME}")
fig.tight_layout()
plt.show()
