# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
from torchtext.legacy.data import Field
from torchtext.legacy.datasets import LanguageModelingDataset
from torchtext.legacy.data import BPTTIterator
from src.model import LSTMModel
from src.helper import counter, get_fables
import matplotlib.pyplot as plt

# Training Class

In [None]:
class Trainer():
    """
    Character-level language-model trainer built on the torchtext legacy API.

    The constructor reads a text file, builds a character vocabulary, wraps the
    data in a ``BPTTIterator`` and instantiates an ``LSTMModel`` plus an Adam
    optimizer. ``train_model`` runs the optimisation loop and ``predict``
    generates text from a prompt.
    """

    def __init__(self, model_parameters : dict, path : str, bptt_len : int, samples : list,
                 learning_rate : float = 0.001):
        """
        Build the vocabulary, the batch iterator, the model and the optimizer.

        :param model_parameters: keyword arguments forwarded unchanged to ``LSTMModel``
            (together with ``vocab_size``); must contain ``'batch_size'``, which is
            also used for the iterator
        :type model_parameters: dict
        :param path: path of the training text file
        :type path: str
        :param bptt_len: sequence length of each batch produced by the iterator
        :type bptt_len: int
        :param samples: prompts that are decoded and printed after every epoch
        :type samples: list
        :param learning_rate: learning rate of the Adam optimizer, defaults to 0.001
        :type learning_rate: float (optional)
        """
        self.samples = samples
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # `list` splits a string into its individual characters, i.e. one token per character.
        self.train_field = Field(tokenize=list, init_token='<sos>', eos_token='<eos>')
        train_dataset = LanguageModelingDataset(
            path=path,
            text_field=self.train_field
        )
        self.train_field.build_vocab(train_dataset)
        # shuffle=False keeps batches in corpus order, which the hidden-state
        # carry-over in train_model relies on.
        self.bptt_iterator = BPTTIterator(
            dataset=train_dataset,
            batch_size=model_parameters['batch_size'],
            bptt_len=bptt_len,
            shuffle=False
        )
        self.model = LSTMModel(**model_parameters, vocab_size=len(self.train_field.vocab)).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    def predict(self, model : "LSTMModel", prompt : str, sequence_length : int, method : str = 'greedy') -> str:
        """
        Generate ``sequence_length`` characters that continue ``prompt``.

        The prompt is split into characters, fed through the model once, and the
        model is then run one character at a time on its own output. The first
        character predicted after the prompt is part of the returned text.

        :param model: the model to use for prediction
        :type model: LSTMModel
        :param prompt: non-empty text the generation is conditioned on
        :type prompt: str
        :param sequence_length: number of characters to generate
        :type sequence_length: int
        :param method: 'greedy' (most likely character) or 'random' (sample from the
            softmax distribution), defaults to greedy
        :type method: str (optional)
        :return: the generated characters joined into one string (prompt not included)
        :raises ValueError: if ``method`` is unknown or ``prompt`` is empty
        """
        if method not in ('greedy', 'random'):
            raise ValueError(f"unknown decoding method {method!r}; expected 'greedy' or 'random'")
        tokens = self.train_field.tokenize(prompt)
        if not tokens:
            raise ValueError("prompt must contain at least one character")

        stoi = self.train_field.vocab.stoi
        itos = self.train_field.vocab.itos
        # Shape (len(prompt), 1): one sequence, one time step per row.
        inputs = torch.tensor([stoi[t] for t in tokens], dtype=torch.long, device=self.device).view(-1, 1)

        model.eval()
        hidden = None
        generated_sentence = []
        # Generation never back-propagates, so skip building the autograd graph.
        with torch.no_grad():
            for _ in range(sequence_length):
                out, hidden = model(inputs, hidden)
                # Assumes `out` is 2-D with one row per input time step (the original
                # indexing and torch.multinomial call relied on the same layout);
                # the last row is the prediction for the next character.
                last_logits = out[-1:]
                if method == 'greedy':
                    # argmax over logits equals argmax over their softmax.
                    ix = torch.argmax(last_logits, dim=1)
                else:
                    ix = torch.multinomial(torch.softmax(last_logits, dim=1), 1)
                generated_sentence.append(itos[ix.item()])
                # Feed the chosen character back in as the next single-step input.
                inputs = ix.view(-1, 1)
        return ''.join(generated_sentence)

    def train_model(self, num_epochs : int) -> "LSTMModel":
        """
        Train ``self.model`` for ``num_epochs`` epochs.

        After every epoch the sample prompts are decoded (greedy and random) and the
        mean loss and perplexity are printed; after the last epoch both curves are
        plotted.

        :param num_epochs: Number of epochs to train for
        :type num_epochs: int
        :return: the trained model (``self.model``)
        :raises RuntimeError: if the iterator yields no batches
        """
        loss_fn = nn.CrossEntropyLoss()
        total_steps = 0
        loss_plot = []
        perp_plot = []
        for epoch in range(1, num_epochs + 1):
            cost = 0
            num_steps = 0
            hidden = None
            print('Total Steps: ', total_steps)
            # predict() leaves the model in eval mode, so re-enable training mode each epoch.
            self.model.train()
            for batch in self.bptt_iterator:
                self.optimizer.zero_grad()
                output, hidden = self.model(batch.text.to(self.device), hidden)
                # Carry the state into the next batch but cut the graph so that
                # back-propagation stops at the batch boundary.
                hidden = (hidden[0].detach(), hidden[1].detach())
                targets = batch.target.view(-1).to(self.device)
                out = output.view(-1, self.model.vocab_size)
                loss = loss_fn(out, targets)
                loss.backward()
                self.optimizer.step()
                cost += loss.item()
                num_steps += 1
                total_steps += 1
            if num_steps == 0:
                raise RuntimeError("the BPTT iterator produced no batches; is the training file empty?")
            self._print_samples()
            mean_loss = cost / num_steps
            perplexity = np.exp(mean_loss)
            loss_plot.append(mean_loss)
            perp_plot.append(perplexity)
            print('epoch:', epoch, 'loss: ', mean_loss, 'perplexity:', perplexity)
        self._plot_history(loss_plot, perp_plot)
        return self.model

    def _print_samples(self) -> None:
        """Decode every sample prompt with both methods and print the results."""
        for prompt in self.samples:
            print('Greedy decoding')
            gen_text = self.predict(self.model, prompt, 100)
            print(f'Sample prompt: {prompt} | generated text: {gen_text}')
            print('Random decoding')
            gen_text = self.predict(self.model, prompt, 100, method='random')
            print(f'Sample prompt: {prompt} | generated text: {gen_text}')

    def _plot_history(self, loss_plot : list, perp_plot : list) -> None:
        """Plot the per-epoch training loss and perplexity in two separate figures."""
        _, ax1 = plt.subplots()
        _, ax2 = plt.subplots()
        ax1.plot(loss_plot, 'coral')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.title.set_text('Train Loss Plot')
        ax1.legend(["Train Loss"])
        ax2.plot(perp_plot, 'deepskyblue')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Perplexity')
        ax2.title.set_text('Train Perplexity Plot')
        ax2.legend(["Train Perplexity"])
        plt.show()

# Train the fable and the trump models

In [None]:
# Train the character-level model on Aesop's Fables and save its weights.
fable_samples = ['Dogs like best to', 'THERE were once some Frogs who lived together', 'THE WOMAN AND HER HEN']
# Helper from src.helper; presumably fetches/prepares the fables text file — confirm.
get_fables()
book_path = os.path.join("data", "books", "AesopsFables.txt")

# Hyper-parameters. `batch_size` is used by the iterator and, like the other
# entries, is also forwarded to LSTMModel by Trainer.
BATCH_SIZE = 64
EMBEDDING_DIM = 1024
HIDDEN_SIZE = 1024
LAYERS = 2
PATH = book_path
BPTT_LEN = 256
model_parameters = {
    'batch_size': BATCH_SIZE,
    'embedding_dim': EMBEDDING_DIM,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': LAYERS,
}

trainer = Trainer(model_parameters=model_parameters, path=PATH, bptt_len=BPTT_LEN, samples=fable_samples)
# Helper from src.helper; presumably reports corpus statistics — confirm.
counter(PATH)
print("Vocab Count: ", len(trainer.train_field.vocab))
print("Count: ", len(trainer.bptt_iterator))
# The parameter of Trainer.train_model is `num_epochs`; the previous keyword
# `epoch_num` raised a TypeError.
fable_model = trainer.train_model(num_epochs=100)
torch.save(fable_model.state_dict(), 'fable_model.pt')

In [None]:
# Train the character-level model on the Trump transcript and save its weights.
# NOTE: this rebinds `trainer`, so later cells that use `trainer` decode with
# this run's vocabulary and model.
trump_samples = ['Good morning America', 'Very good', 'Donald Trump:']
bonus_path = os.path.join('data', "donaldtrump.txt")

# Hyper-parameters. `batch_size` is used by the iterator and, like the other
# entries, is also forwarded to LSTMModel by Trainer.
BATCH_SIZE = 64
EMBEDDING_DIM = 1024
HIDDEN_SIZE = 1024
LAYERS = 2
PATH = bonus_path
BPTT_LEN = 256
model_parameters = {
    'batch_size': BATCH_SIZE,
    'embedding_dim': EMBEDDING_DIM,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': LAYERS,
}

trainer = Trainer(model_parameters=model_parameters, path=PATH, bptt_len=BPTT_LEN, samples=trump_samples)
# Helper from src.helper; presumably reports corpus statistics — confirm.
counter(PATH)
print("Vocab Count: ", len(trainer.train_field.vocab))
print("Count: ", len(trainer.bptt_iterator))
# The parameter of Trainer.train_model is `num_epochs`; the previous keyword
# `epoch_num` raised a TypeError.
trump_model = trainer.train_model(num_epochs=100)
torch.save(trump_model.state_dict(), 'trump_model.pt')

# Greedy and random decoding for the fable model

In [None]:
# Decode a fixed set of prompts, first greedily, then by sampling, then a
# second greedy round with prompts that are not titles.
#
# The cell used to pass a global `model` that no cell of this notebook defines
# (NameError on a fresh kernel). `trainer.model` is the network that
# `train_model` returns, so it is always paired with the vocabulary it was
# trained on.
# NOTE(review): `trainer` is whichever Trainer was created last — run this cell
# right after the fable training cell to decode with the fable model.

# Each entry: (headers printed before decoding, prompt, method,
#              headers printed after decoding but before the sample line).
_fable_decoding_plan = [
    # --- greedy ---
    ((), 'THE FOX AND THE LION', 'greedy', ("Greedy Decoding", "A title in the book")),
    # A title which you invent, which is not in the book, but similar in the style.
    ((), 'THE TURTLE AND THE BIRD', 'greedy', ("A title in similar style",)),
    # Some texts in a similar style.
    (("Some texts in similar style",), 'Back in my day', 'greedy', ()),
    # Anything you might find interesting.
    (("Anything Interesting",), 'Dallmayr to go', 'greedy', ()),
    ((), 'Covid-19 is', 'greedy', ()),
    # --- multinomial sampling ---
    (("Random Decoding", "A title in the book"), 'THE FOX AND THE LION', 'random', ()),
    (("A title in similar style",), 'THE TURTLE AND THE BIRD', 'random', ()),
    (("Some texts in similar style",), 'Back in my day', 'random', ()),
    (("Anything Interesting",), 'Dallmayr to go', 'random', ()),
    ((), 'Covid-19 is', 'random', ()),
    # --- greedy, second round ---
    (("Greedy Decoding",), ' ', 'greedy', ("A title in the book",)),
    ((), 'A very nice day', 'greedy', ("A title in similar style",)),
    ((), 'I once was', 'greedy', ()),
    (("Anything Interesting",), 'Birds are flying', 'greedy', ()),
    ((), 'Coca Cola', 'greedy', ()),
]

for _before, prompt, _method, _after in _fable_decoding_plan:
    for _line in _before:
        print(_line)
    gen_text = trainer.predict(trainer.model, prompt, 300, method=_method)
    for _line in _after:
        print(_line)
    print(f'Sample prompt: {prompt} | generated text: {gen_text}')





# Random decoding for the fable model

In [None]:
# Sample-based decoding for a list of prompts, followed by two prompts that are
# decoded with both methods at a longer length.
#
# The cell used to pass a global `model` that no cell of this notebook defines
# (NameError on a fresh kernel). `trainer.model` is the network that
# `train_model` returns, so it is always paired with the vocabulary it was
# trained on.
# NOTE(review): `trainer` is whichever Trainer was created last — run this cell
# right after the fable training cell to decode with the fable model.

# Each entry: (headers printed before decoding, prompt, length, method).
_fable_random_plan = [
    (("Random Decoding",), ' ', 300, 'random'),
    # A title which you invent, which is not in the book, but similar in the style.
    ((), 'A very nice day', 300, 'random'),
    # Some texts in a similar style.
    ((), 'I once was', 300, 'random'),
    # Anything you might find interesting.
    (("Anything Interesting",), 'Birds are flying', 300, 'random'),
    ((), 'Coca Cola', 300, 'random'),
    ((), 'The Fox', 300, 'random'),
    ((), 'The Angry Turtle was', 300, 'random'),
    # NOTE(review): the Trainer tokenizer splits prompts into characters, so this
    # is fed as the five literal characters '<', 'e', 'o', 's', '>' rather than
    # as the end-of-sequence token — confirm that is intended.
    ((), '<eos>', 600, 'random'),
    ((), '<eos>', 600, 'greedy'),
    ((), 'THE WOLF AND THE LAMB', 1000, 'random'),
    ((), 'THE WOLF AND THE LAMB', 1000, 'greedy'),
]

for _before, prompt, _length, _method in _fable_random_plan:
    for _line in _before:
        print(_line)
    gen_text = trainer.predict(trainer.model, prompt, _length, method=_method)
    print(f'Sample prompt: {prompt} | generated text: {gen_text}')

# Greedy decoding for the trump model

In [None]:
# Greedy decoding of four prompts with the most recently trained model.
#
# The cell used to pass a global `model` that no cell of this notebook defines
# (NameError on a fresh kernel). `trainer.model` is the network that
# `train_model` returns, so it is always paired with the vocabulary it was
# trained on.
# NOTE(review): `trainer` must be the Trump trainer here, i.e. the Trump
# training cell has to be the last training cell that ran.
_trump_greedy_prompts = ('Thank You', 'Good', 'China', 'We have to')

for _position, prompt in enumerate(_trump_greedy_prompts):
    gen_text = trainer.predict(trainer.model, prompt, 300)
    if _position == 0:
        # The header is printed after the first decode, as in the original cell.
        print("Greedy Decoding")
    print(f'Sample prompt: {prompt} | generated text: {gen_text}')


# Random decoding for the trump model

In [None]:
# Sample-based decoding of four prompts, then a long greedy and a long sampled
# continuation of a single space.
#
# The cell used to pass a global `model` that no cell of this notebook defines
# (NameError on a fresh kernel). `trainer.model` is the network that
# `train_model` returns, so it is always paired with the vocabulary it was
# trained on.
# NOTE(review): `trainer` must be the Trump trainer here, i.e. the Trump
# training cell has to be the last training cell that ran.

# Each entry: (prompt, length, method).
_trump_random_plan = [
    ('Thank You', 300, 'random'),
    ('Good', 300, 'random'),
    ('China', 300, 'random'),
    ('We have to', 300, 'random'),
    (' ', 1000, 'greedy'),
    (' ', 1000, 'random'),
]

print("Random Decoding")
for prompt, _length, _method in _trump_random_plan:
    gen_text = trainer.predict(trainer.model, prompt, _length, method=_method)
    print(f'Sample prompt: {prompt} | generated text: {gen_text}')

print("Something not in the text")
# Left bound on purpose: the following cell decodes this prompt.
prompt = 'Birds fly high'

# Comparing random and greedy for the trump model

In [None]:

# Decode the same prompts with both methods so the outputs can be compared.
#
# Two fixes over the original cell:
# * `prompt` is bound here instead of relying on the value left behind by the
#   previous cell, so the cell also works when run on its own;
# * the undefined global `model` is replaced by `trainer.model`, the network
#   that `train_model` returns, which always matches the trainer's vocabulary.
# NOTE(review): `trainer` must be the Trump trainer here, i.e. the Trump
# training cell has to be the last training cell that ran.

# Each entry: (header, prompt, length, method).
_comparison_plan = [
    ("Random Decoding", 'Birds fly high', 300, 'random'),
    ("Greedy Decoding", 'Birds fly high', 300, 'greedy'),
    ("Greedy Decoding", "President Donald J. Trump: ", 600, 'greedy'),
    ("Random Decoding", "President Donald J. Trump: ", 600, 'random'),
]

for _header, prompt, _length, _method in _comparison_plan:
    print(_header)
    gen_text = trainer.predict(trainer.model, prompt, _length, method=_method)
    print(f'Sample prompt: {prompt} | generated text: {gen_text}')