In [1]:
import torch

print(torch.cuda.is_available())

if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"

print(dev)

True
cuda:0


In [68]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
import nltk
from global_variables import *
from preprocessing import *

class LSTM_Language_Model(nn.Module):
    def __init__(self, vocab_size=27597, embedding_dim=100,
                 hidden_dim=100, lstm_layers=2, dropout=0.2):
        super(LSTM_Language_Model, self).__init__()

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, lstm_layers, dropout=dropout)
        self.hl = nn.Linear(hidden_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, vocab_size)
        self.activation = nn.Tanh()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.batchNorm = nn.BatchNorm1d(hidden_dim)
    def embedd(self, word_indexes):
        return self.fc1.weight.index_select(0, word_indexes)

    def forward(self, packed_sents):
                
        
        emb_seq = nn.utils.rnn.PackedSequence(
             self.embedding(packed_sents.data), packed_sents.batch_sizes)
        # embeded_data = self.embedding(packed_sents.data)
        result, _ = self.lstm(emb_seq)
       
        out = self.fc1(result.data)
        return F.log_softmax(out, dim=1)


def divider(data, size=BATCH_SIZE, time=30, window=30):
    batch = []
    count = 0
    for i in range(1, len(data) + 1, window + 1):
        count += 1
        sequence = data[i - 1:i - 1 + time + 1]
        batch.append(sequence)
        if count != 0 and count % size == 0:
            tmp_batch = batch
            tmp_batch.sort(key=lambda l: len(l), reverse=True)
            batch = []
            yield tmp_batch


def pre_process_train_data_LSTM_upgrade(name='wiki.train.txt', is_LSTM=True):
    setup_nltk()
    sliding_window_value = 30
    text = to_number(lists_to_tokens(splitting_tokens(string_to_lower(load_text(name)))))
    unique_n = unique_words(text)
    print('unique_words----->' + str(unique_n))
    mapping = create_integers(text)
    reverse_mapping = {i: k for k, i in mapping.items()}
    integers_texts = words_to_integers(text, mapping)
    ytm_batch = divider(integers_texts, 20, 30, 30)
    net = LSTM_Language_Model(27597, 100, 16, 2, 0.4)
    net.to(dev)
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    train_LSTM(integers_texts, net, optimizer, 100, train=True)
    return net


def pre_process_valid_test_data_LSTM_upgrade(model, name='wiki.valid.txt', is_LSTM=True):
    setup_nltk()
    sliding_window_value = 30
    text = to_number(lists_to_tokens(splitting_tokens(string_to_lower(load_text(name)))))
    unique_n = unique_words(text)
    print('unique_words----->' + str(unique_n))
    mapping = create_integers(text)
    reverse_mapping = {i: k for k, i in mapping.items()}
    integers_texts = words_to_integers(text, mapping)
    model.to(dev)
    loss = valid(integers_texts, model)
    return loss


def train_LSTM(data, model, optimizer, clip_grads, epoch_size=3, train=False):
    train_per = []
    vali_per = []
    if train:
        for i in range(20):
            model.train()
            losses = []
            for index, sequence in enumerate(divider(data, 20)):

                x = nn.utils.rnn.pack_sequence([torch.tensor(token[:-1]) for token in sequence])
                y = nn.utils.rnn.pack_sequence([torch.tensor(token[1:]) for token in sequence])
                x,y =x.to(dev), y.to(dev)
                model.zero_grad()
                out = model(x)
                loss = F.nll_loss(out, y.data)
                losses.append(loss.item())
                loss.backward()
                if clip_grads:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
            mean_loss = sum(losses) / len(losses)
            print("Epoch->" + ' ' + str(i))
            perplexity = np.exp(mean_loss)
            print(perplexity)
            train_per.append(perplexity)
            vali_loss = pre_process_valid_test_data_LSTM_upgrade(model, 'wiki.valid.txt')
            vali_perplexity = np.exp(vali_loss)
            vali_per.append(vali_perplexity)
            print(vali_loss)
            if vali_loss < 6.9:
                print(mean_loss, np.exp(mean_loss))
                break
        print(train_per)
        print(vali_per)
        plt.plot(train_per)
        plt.title("Train Perplexity")
        plt.xlabel("Epoch")
        plt.ylabel("perplexity")
        plt.show()
        plt.plot(vali_per)
        plt.title("Validation Perplexity")
        plt.xlabel("Epoch")
        plt.ylabel("perplexity")
        plt.show()




def valid(data, model):
            model.eval()
            losses = []
            for index, sequence in enumerate(divider(data, 20)):
                x = nn.utils.rnn.pack_sequence([torch.tensor(token[:-1]) for token in sequence])
                y = nn.utils.rnn.pack_sequence([torch.tensor(token[1:]) for token in sequence])
                x,y =x.to(dev), y.to(dev)
                out = model(x)
                loss = F.nll_loss(out, y.data)
                losses.append(loss.item())
            mean_loss = sum(losses) / len(losses)
                # if index % 150 == 0:
                #     perplexity = np.exp(loss.item())
                #     print("Batch" + ' ' + str(index))
                #     print("loss" + ' ' + str(mean_loss))
                #     print('perplexity' + ' ' + str(perplexity))
            print(mean_loss, np.exp(mean_loss))
            # print('final_perplexity_valid' + ' ' + str(perplexity))
            # print("Batch" + ' ' + str(index))
            # print("final_loss_valid" + ' ' + str(loss.item()))
            return mean_loss



In [None]:
import time
import matplotlib.pyplot as plt

def main(is_LSTM=False, use_custom_loss=False, use_valid=False, upgraded= False):
    'if upgraded equal to false run the old models, otherwise run new pipeline with upgraded LSTM '
    if upgraded == False:


        dataset = 'wiki.valid.txt' if use_valid else 'wiki.test.txt'
        to_print = "Training on an LSTM Neural Network Model" if is_LSTM else "Training on a simple Feed Forward Neural Network Model"
        print(to_print)

        train_dataset = pre_process_train_data(name='wiki.train.txt', is_LSTM=is_LSTM)
        valid_dataset = pre_process_val_train_data(name=dataset, is_LSTM=is_LSTM)
        print("Starting the timer")
        start_time = time.time()
        run_nn_model(train_dataset, valid_dataset, is_LSTM=is_LSTM, epoch=1, use_custom_loss=use_custom_loss)
        end_time = time.time() - start_time
        print("Trained in -> " + str(end_time / 60) + " minutes.")
    else:
        'pre_process_train_data_LSTM_upgrade function trains the model, the second function tests it on a valid dataset'
        model = pre_process_train_data_LSTM_upgrade()
      #  pre_process_valid_test_data_LSTM_upgrade(model, 'wiki.valid.txt')

if __name__ == '__main__':
    main(is_LSTM=False, use_custom_loss=False, use_valid=False, upgraded =True)


[nltk_data] Downloading package stopwords to /home/studio-lab-
[nltk_data]     user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


unique_words----->27597
Epoch-> 0
1043.2250001930624


[nltk_data] Downloading package stopwords to /home/studio-lab-
[nltk_data]     user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


unique_words----->11410
8.510806253682013 4968.167076990682
8.510806253682013
Epoch-> 1
833.089937859559


[nltk_data] Downloading package stopwords to /home/studio-lab-
[nltk_data]     user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


unique_words----->11410
8.63985378984092 5652.503310606885
8.63985378984092
Epoch-> 2
807.0681313780915


[nltk_data] Downloading package stopwords to /home/studio-lab-
[nltk_data]     user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


unique_words----->11410
8.69257853825887 5958.526740187342
8.69257853825887
Epoch-> 3
780.1889025010731


[nltk_data] Downloading package stopwords to /home/studio-lab-
[nltk_data]     user/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


unique_words----->11410
8.722722318898077 6140.873770393606
8.722722318898077
