In [2]:
import torch
from torch import nn
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import copy
from torch.utils.data import Dataset, DataLoader, TensorDataset
import gc
import random
# import wandb

In [3]:
def dataLoading(src):
    #path = "/kaggle/input/roman-to-telgu/tel_{}.csv".format(data_type)
    df = pd.read_csv(src,header=None)
    return df[0].to_numpy(), df[1].to_numpy()

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
MAX_LENGTH = 30
BATCH_SIZE = 32
EOS_token = 1
SOS_token = 0
PAD_token = 2
TEACHER_FORCING_RATIO = 0.5

train_csv = "/kaggle/input/assignment3/tel_train.csv"
test_csv = "/kaggle/input/assignment3/tel_test.csv"
val_csv = "/kaggle/input/assignment3/tel_valid.csv"

train_input , train_output = dataLoading(train_csv)
valid_input, valid_output = dataLoading(val_csv)
test_input, test_output = dataLoading(test_csv)

cuda


In [5]:
def characterFetching(x):
    characters = 3
    ind2ch ={SOS_token:'<',EOS_token:'>',PAD_token:'_'}
    ch2ind ={'<':SOS_token,'>':EOS_token,'_':PAD_token}
    for word in x:
        for letter in word:
            if letter not in ch2ind:
                ch2ind[letter] = characters
                ind2ch[characters] = letter
                characters+=1
    return [ch2ind,ind2ch,characters]

In [6]:
meta_data_input , meta_data_output = characterFetching(train_input) , characterFetching(train_output)

In [7]:
def addEosPadding(x,meta_data):
    indexed_data = []
    for word in x:
        l =[]
        word += '>'
        word += (MAX_LENGTH-len(word))*'_'
        for char in word:
            l.append(meta_data[0][char])
        indexed_data.append(l)
    return torch.tensor(indexed_data)

In [8]:
train_data_tensor = DataLoader(TensorDataset(addEosPadding(train_input,meta_data_input), addEosPadding(train_output,meta_data_output)),BATCH_SIZE, shuffle = True)
valid_data_tensor = DataLoader(TensorDataset(addEosPadding(valid_input,meta_data_input), addEosPadding(valid_output,meta_data_output)),BATCH_SIZE, shuffle = True)

In [9]:
class Hyperparameters:
    def __init__(self,input_dim:int,output_dim:int,
                 encoder_layers =1,decoder_layers=1,hidden_size=64,embed_dim=512
                 ,cell_type:str='rnn',bidirectional:bool=False,dropout:float=0,beam_search:int=0,
                 learning_rate=0.001):
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.hidden_size = hidden_size
        #input_dim is size of vocabulary of input language
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        #output_dim is size of vocabulary of output language
        self.output_dim = output_dim
    
        cell_dict = {'rnn':nn.RNN,'gru':nn.GRU,'lstm':nn.LSTM}
        self.cell = cell_dict[cell_type]
        self.cell_name = cell_type
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.beam_search = beam_search
        self.learning_rate = learning_rate

In [50]:
class Attention(nn.Module):
    def __init__(self,parameters:Hyperparameters):
        super(Attention,self).__init__()
        hidden_size = parameters.hidden_size
        self.Wa = nn.Linear(hidden_size,hidden_size)
        self.Ua = nn.Linear(hidden_size,hidden_size)
        self.Va = nn.Linear(hidden_size,1)
        
    def forward(self,queries,keys):
        scores = self.Va(torch.tanh(self.Wa(queries)+self.Ua(keys)))
        scores = scores.squeeze().unsqueeze(1)
        weights = F.softmax(scores, dim =0)
        weights = weights.permute(2,1,0)
        keys = keys.permute(1,0,2)
        context = torch.bmm(weights, keys)
        return context, weights
    
class Encoder(nn.Module):
    def __init__(self,parameters:Hyperparameters):
        super(Encoder,self).__init__()
        self.hidden_size = parameters.hidden_size
        self.num_layers = parameters.encoder_layers
        self.embedding = nn.Embedding(parameters.input_dim,parameters.embed_dim)
        self.cell = parameters.cell(parameters.embed_dim,self.hidden_size,self.num_layers,batch_first=True)
        self.max_length = MAX_LENGTH
        self.batch_size = BATCH_SIZE
        self.cell_name = parameters.cell_name
    
    def forward(self, input_t, current_state):
        encoder_states = torch.zeros(self.max_length, self.num_layers, self.batch_size, self.hidden_size, device = device)
        
        for i in range(self.max_length):
            current_input = input_t[:, i].view(self.batch_size,1)
            _, current_state = self.forwardStep(current_input, current_state)
            if self.cell_name == 'lstm':
                encoder_states[i] = current_state[1]
            else:
                encoder_states[i] = current_state
        return encoder_states, current_state

    def forwardStep(self, current_input, prev_state):
        embd_input = self.embedding(current_input)
        output, prev_state = self.cell(embd_input, prev_state)
        return output, prev_state
        
    def getInitialState(self):
        return torch.zeros(self.num_layers, self.batch_size, self.hidden_size, device=device)
        

class Decoder(nn.Module):
    def __init__(self,parameters:Hyperparameters):
        super(Decoder,self).__init__()
        self.hidden_size = parameters.hidden_size
        self.num_layers = parameters.decoder_layers
        self.batch_size = BATCH_SIZE
        self.max_length = MAX_LENGTH
        self.cell_name = parameters.cell_name
        self.attention = Attention(parameters)
        self.embedding = nn.Embedding(parameters.output_dim,parameters.embed_dim)
        self.cell = parameters.cell(parameters.embed_dim+self.hidden_size, self.hidden_size, self.num_layers, batch_first = True)
        self.fc = nn.Linear(self.hidden_size,parameters.output_dim)
        self.softmax = nn.LogSoftmax(dim=2)
        
        
    def forward(self, current_state, encoder_final_layers, output_batch, loss_fun):

        use_teacher_forcing = True if random.random() < TEACHER_FORCING_RATIO else False
        
        decoder_input = torch.full((self.batch_size,1),SOS_token, device=device)
        embedding = self.embedding(decoder_input)
        soft_embed = F.relu(embedding)
        
        decoder_actual_output = []
        attentions = []
        loss = 0
        

        for i in range(self.max_length):
            decoder_output, current_state, attn_weights = self.forwardStep(decoder_input, current_state, encoder_final_layers)
            
            topv, topi = decoder_output.topk(1)
            
            decoder_input = topi.squeeze().detach().view(self.batch_size, 1)
            decoder_actual_output.append(decoder_input)

            attentions.append(attn_weights)
            
            if(output_batch==None):
                decoder_input = decoder_input.view(self.batch_size, 1)
            else:
                if(i<self.max_length-1):
                    if use_teacher_forcing:
                        decoder_current_input = output_batch[:, i+1].view(self.batch_size, 1)
                decoder_output = decoder_output[:, -1, :]
                loss+=(loss_fun(decoder_output, output_batch[:, i]))

        decoder_actual_output = torch.cat(decoder_actual_output,dim=0).view(self.max_length, self.batch_size).transpose(0,1)

        correct = (decoder_actual_output == output_batch).all(dim=1).sum().item()
        return decoder_actual_output, attentions, loss, correct
    
    def forwardStep(self, current_input, prev_state, encoder_final_layers):
        embedding = self.embedding(current_input)
        if self.cell_name == "lstm":
            context , attn_weights = self.attention(prev_state[1][-1,:,:], encoder_final_layers)
        else:
            context , attn_weights = self.attention(prev_state[-1,:,:], encoder_final_layers)
        activation = F.relu(embedding)
        
        input_gru = torch.cat((activation, context), dim=2)
        output, prev_state = self.cell(input_gru, prev_state)
        output = self.softmax(self.fc(output))
        return output, prev_state, attn_weights 

In [67]:
def evaluate(encoder, decoder, data_t, loss_fun, parameters):
    encoder.eval()
    decoder.eval()
    correct_predictions = 0
    total_loss = 0
    total_predictions = len(data_t.dataset)
    number_of_batches = len(data_t)
    with torch.no_grad():
        for ind, (input_tensor, output_tensor) in enumerate(data_t):
            input_tensor  = input_tensor.to(device)
            output_tensor = output_tensor.to(device)
            encoder_initial = encoder.getInitialState()
            if parameters.cell_name == "lstm":
                encoder_initial = (encoder_initial, encoder.getInitialState())
            encoder_states, encoder_final_state = encoder(input_tensor,encoder_initial)

            current_state = encoder_final_state
            encoder_final_layer_states = encoder_states[:, -1, :, :]

            loss = 0
            correct = 0

            decoder_output, attentions, loss, correct = decoder(current_state, encoder_final_layer_states, output_tensor, loss_fun)

            correct_predictions+=correct
            total_loss +=loss

        accuracy = correct_predictions / total_predictions
        total_loss /= number_of_batches

        return  total_loss, accuracy

In [69]:
def train(parameters, encoder, decoder, train_data, valid_data, epochs ):

    encoder_opt = optim.Adam(encoder.parameters(), lr = parameters.learning_rate)
    decoder_opt = optim.Adam(decoder.parameters(), lr = parameters.learning_rate)
    
    loss_fun = nn.NLLLoss()
    
    total_predictions = len(train_data.dataset)
    total_batches = len(train_data)
    
    for epoch in range(epochs):
        encoder.train()
        decoder.train()
        total_correct = 0
        total_loss = 0
        for ind , (input_tensor, output_tensor) in enumerate(train_data):
            input_tensor  = input_tensor.to(device)
            output_tensor = output_tensor.to(device)
            encoder_initial = encoder.getInitialState()
            
            if parameters.cell_name == 'lstm':
                encoder_initial = (encoder_initial, encoder.getInitialState())
            
            encoder_states, encoder_final_state = encoder(input_tensor,encoder_initial)
            
            decoder_state = encoder_final_state
            encoder_final_layer_states = encoder_states[:,-1,:,:]
            
            loss =0
            correct =0
            
            decoder_output, attentions, loss, correct = decoder(decoder_state,encoder_final_layer_states,output_tensor,loss_fun)
            total_correct +=correct
            total_loss += loss.item()/MAX_LENGTH
            
            if(ind%30==0):
                print("epoch-  ",epoch,"batch number - ",ind, loss.item()/MAX_LENGTH,"ACC - ",correct/BATCH_SIZE)
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            loss.backward()
            encoder_opt.step()
            decoder_opt.step()
        
        train_acc = total_correct/total_predictions
        train_loss = total_loss/total_predictions
        valid_loss, valid_acc = evaluate(encoder,decoder,valid_data,loss_fun,parameters)
        print("Training Accuracy - ",train_acc, "Train_loss - ",train_loss, "Valid_acc - ", valid_acc, "Valid_loss - ", valid_loss)
        

In [70]:
encoder_layers = 5
decoder_layers = 5
hidden_size = 64
embed_dim =256
cell_type = 'lstm'
bidirectional = True
dropout = 0
learning_rate =0.001
input_dim = meta_data_input[2]
output_dim = meta_data_output[2]

In [71]:
parameters = Hyperparameters(input_dim, output_dim, encoder_layers, decoder_layers, hidden_size, embed_dim, cell_type, bidirectional, dropout )
encoder = Encoder(parameters).to(device)
decoder = Decoder(parameters).to(device)

In [72]:
train(parameters,encoder,decoder,train_data_tensor,valid_data_tensor,15)

epoch-   0 batch number -  0 4.092297871907552 ACC -  0.0
epoch-   0 batch number -  30 1.663475545247396 ACC -  0.0
epoch-   0 batch number -  60 1.4596532185872395 ACC -  0.0
epoch-   0 batch number -  90 1.5295500437418619 ACC -  0.0
epoch-   0 batch number -  120 1.4103167215983072 ACC -  0.0
epoch-   0 batch number -  150 1.3071956634521484 ACC -  0.0
epoch-   0 batch number -  180 1.2571194966634114 ACC -  0.0
epoch-   0 batch number -  210 1.2961006164550781 ACC -  0.0
epoch-   0 batch number -  240 1.2250829060872397 ACC -  0.0
epoch-   0 batch number -  270 1.3197568257649739 ACC -  0.0
epoch-   0 batch number -  300 1.1100274403889974 ACC -  0.0
epoch-   0 batch number -  330 1.1388284047444661 ACC -  0.0
epoch-   0 batch number -  360 1.277779261271159 ACC -  0.0
epoch-   0 batch number -  390 1.2282318115234374 ACC -  0.0
epoch-   0 batch number -  420 1.1285836537679037 ACC -  0.0
epoch-   0 batch number -  450 1.2026607513427734 ACC -  0.0
epoch-   0 batch number -  480 1

KeyboardInterrupt: 