In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy as np
np.random.seed(0)
from tqdm import tqdm

import pickle
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import re
from nltk import tokenize
import nltk
nltk.download('punkt')
import string

Mounted at /content/drive


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [3]:
embedding_model = pickle.load(open('/content/drive/MyDrive/NLP/glove-wiki-gigaword-300-2.p','rb'))

In [None]:
len(embedding_model.index_to_key)

400000

In [5]:
# Hyper-parameter
hidden_size = 256
num_classes = len(embedding_model.index_to_key)
num_epochs = 50
batch_size = 50
learning_rate = 0.001

vector_length = 300
input_size = vector_length
sequence_length = 5
num_layers = 2


In [4]:
with open("/content/drive/MyDrive/NLP/alice_in_wonderland.txt", 'r') as file:
    text_file = file.read().lower()

In [6]:
class MyDataset(Dataset):
  def __init__(self, text_file):
    text = tokenize.word_tokenize(text_file)
    fil_doc_w2index = []
    x = []
    y = []
    for word in text:
        try:
          fil_doc_w2index.append(embedding_model.key_to_index[word])
        except Exception:
          pass
    for i in range(sequence_length,len(fil_doc_w2index)):
        x.append(torch.tensor(fil_doc_w2index[i-sequence_length:i]))
        y.append(torch.tensor(fil_doc_w2index[i]))
    self.data = x
    self.labels = y

  def __len__(self):
    return len(self.labels)
  
  def __getitem__(self, index):
    return self.data[index], self.labels[index]

In [7]:
dataset = MyDataset(text_file)

In [None]:
#lengths = [round(len(dataset)*0.8), round(len(dataset)*0.2)]
#train_dataset, test_dataset = torch.utils.data.random_split(dataset, lengths)

In [8]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
#val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class LSTM(nn.Module):
    def __init__(self, pretrained_model, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        weights = torch.FloatTensor(pretrained_model.vectors)
        self.embedding = nn.Embedding.from_pretrained(weights)


        self.num_layers =  num_layers
        self.hidden_size = hidden_size

        #self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        #self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # x -> (batch_size, sequence_length, input_size) because batch_size = true
        self.fc = nn.Linear(hidden_size, num_classes)
        self.dropout1 = nn.Dropout(0.40)
        self.dropout2 = nn.Dropout(0.40)

    def forward(self, x):
        out = self.dropout1(self.embedding(x))
        # initial hidden state size is always (num_layer, batch_size, hidden_size)
        # h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        h = None

        #out, _ = self.rnn(x, h0)
        #out, _ = self.gru(x, h0)
        out, _ = self.lstm(out, h)
        
        # out -> (batch_size, sequence_length, hidden_size) because batch_size = true
        # if out.dim() == 2:
        #     out = out[-1, :] # only the last time step
        # else:
        #     out = out[:, -1, :] # only the last time step
        out = out[..., -1, :] # only the last time step
        out = self.fc(self.dropout2(out))
        return out

net = LSTM(embedding_model, input_size, hidden_size, num_layers, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
#net.embedding.weight.requires_grad = False
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.5)

In [11]:
#net = net.float()
net.train()
for epoch in range(num_epochs):  # loop over the dataset multiple times
    print("\nStarting epoch {}".format(epoch+1))
    
    total = 0
    running_loss = 0.0

    # to make a beautiful progress bar
    loader = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, data in loader:
        # get the data points
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # zero the parameter gradients (else, they are accumulated)
        optimizer.zero_grad()

        # forward the data through the network
        outputs = net(inputs)
        # calculate the loss given the output of the network and the target labels
        loss = criterion(outputs, labels)
        # calculate the gradients of the network w.r.t. its parameters
        loss.backward()
        # Let the optimiser take an optimization step using the calculated gradients
        optimizer.step()
        
        running_loss += loss
        total += outputs.size(0)

        loader.set_description("loss: {:.5f}".format(running_loss/total))

print('Finished Training')


Starting epoch 1


loss: 0.13317: 100%|██████████| 683/683 [00:43<00:00, 15.73it/s]


Starting epoch 2



loss: 0.11234: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 3



loss: 0.10423: 100%|██████████| 683/683 [00:41<00:00, 16.32it/s]


Starting epoch 4



loss: 0.09808: 100%|██████████| 683/683 [00:41<00:00, 16.28it/s]


Starting epoch 5



loss: 0.09341: 100%|██████████| 683/683 [00:41<00:00, 16.35it/s]


Starting epoch 6



loss: 0.08963: 100%|██████████| 683/683 [00:41<00:00, 16.27it/s]


Starting epoch 7



loss: 0.08634: 100%|██████████| 683/683 [00:41<00:00, 16.27it/s]


Starting epoch 8



loss: 0.08320: 100%|██████████| 683/683 [00:41<00:00, 16.38it/s]


Starting epoch 9



loss: 0.08025: 100%|██████████| 683/683 [00:41<00:00, 16.31it/s]


Starting epoch 10



loss: 0.07769: 100%|██████████| 683/683 [00:42<00:00, 16.23it/s]


Starting epoch 11



loss: 0.07497: 100%|██████████| 683/683 [00:41<00:00, 16.36it/s]


Starting epoch 12



loss: 0.07247: 100%|██████████| 683/683 [00:41<00:00, 16.39it/s]


Starting epoch 13



loss: 0.07000: 100%|██████████| 683/683 [00:41<00:00, 16.35it/s]


Starting epoch 14



loss: 0.06747: 100%|██████████| 683/683 [00:41<00:00, 16.31it/s]


Starting epoch 15



loss: 0.06526: 100%|██████████| 683/683 [00:41<00:00, 16.37it/s]


Starting epoch 16



loss: 0.06306: 100%|██████████| 683/683 [00:41<00:00, 16.36it/s]


Starting epoch 17



loss: 0.06086: 100%|██████████| 683/683 [00:41<00:00, 16.26it/s]


Starting epoch 18



loss: 0.05900: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 19



loss: 0.05689: 100%|██████████| 683/683 [00:41<00:00, 16.35it/s]


Starting epoch 20



loss: 0.05521: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 21



loss: 0.05360: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 22



loss: 0.05188: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 23



loss: 0.05034: 100%|██████████| 683/683 [00:41<00:00, 16.35it/s]


Starting epoch 24



loss: 0.04881: 100%|██████████| 683/683 [00:41<00:00, 16.30it/s]


Starting epoch 25



loss: 0.04729: 100%|██████████| 683/683 [00:41<00:00, 16.37it/s]


Starting epoch 26



loss: 0.04609: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 27



loss: 0.04460: 100%|██████████| 683/683 [00:41<00:00, 16.36it/s]


Starting epoch 28



loss: 0.04362: 100%|██████████| 683/683 [00:41<00:00, 16.28it/s]


Starting epoch 29



loss: 0.04238: 100%|██████████| 683/683 [00:41<00:00, 16.30it/s]


Starting epoch 30



loss: 0.04131: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 31



loss: 0.04008: 100%|██████████| 683/683 [00:41<00:00, 16.41it/s]


Starting epoch 32



loss: 0.03904: 100%|██████████| 683/683 [00:41<00:00, 16.41it/s]


Starting epoch 33



loss: 0.03801: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 34



loss: 0.03705: 100%|██████████| 683/683 [00:41<00:00, 16.39it/s]


Starting epoch 35



loss: 0.03622: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 36



loss: 0.03515: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 37



loss: 0.03455: 100%|██████████| 683/683 [00:41<00:00, 16.42it/s]


Starting epoch 38



loss: 0.03340: 100%|██████████| 683/683 [00:41<00:00, 16.41it/s]


Starting epoch 39



loss: 0.03280: 100%|██████████| 683/683 [00:41<00:00, 16.31it/s]


Starting epoch 40



loss: 0.03217: 100%|██████████| 683/683 [00:41<00:00, 16.38it/s]


Starting epoch 41



loss: 0.03131: 100%|██████████| 683/683 [00:41<00:00, 16.38it/s]


Starting epoch 42



loss: 0.03063: 100%|██████████| 683/683 [00:41<00:00, 16.40it/s]


Starting epoch 43



loss: 0.02993: 100%|██████████| 683/683 [00:41<00:00, 16.40it/s]


Starting epoch 44



loss: 0.02950: 100%|██████████| 683/683 [00:41<00:00, 16.34it/s]


Starting epoch 45



loss: 0.02876: 100%|██████████| 683/683 [00:41<00:00, 16.33it/s]


Starting epoch 46



loss: 0.02809: 100%|██████████| 683/683 [00:41<00:00, 16.37it/s]


Starting epoch 47



loss: 0.02768: 100%|██████████| 683/683 [00:41<00:00, 16.40it/s]


Starting epoch 48



loss: 0.02706: 100%|██████████| 683/683 [00:41<00:00, 16.38it/s]


Starting epoch 49



loss: 0.02677: 100%|██████████| 683/683 [00:41<00:00, 16.40it/s]


Starting epoch 50



loss: 0.02619: 100%|██████████| 683/683 [00:41<00:00, 16.42it/s]

Finished Training





In [None]:
net.eval()
class Accuracy:
    """A class to keep track of the accuracy while training"""
    def __init__(self):
        self.correct = 0
        self.total = 0
        
    def reset(self):
        """Resets the internal state"""
        self.correct = 0
        self.total = 0
        
    def update(self, output, labels):
        """
        Updates the internal state to later compute the overall accuracy
        
        output: the output of the network for a batch
        labels: the target labels
        """
        _, predicted = torch.max(output.data, 1) # predicted now contains the predicted class index/label
        
        self.total += labels.size(0)
        self.correct += (predicted == labels).sum().item() # .item() gets the number, not the tensor

    def compute(self):
        return self.correct/self.total

accuracy = Accuracy()

accuracy.reset()
# Gradients are calculated on the forward pass for every iteration.
# As we do not need gradients now, we can disable the calculation.
with torch.no_grad():
    for data in tqdm(train_loader):
        # get the data points
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward the data through the network
        outputs = net(inputs)
        
        accuracy.update(outputs, labels)

print("Accuracy: {:.2f}%".format(100 * accuracy.compute()))

# accuracy.reset()        
# with torch.no_grad():
#     for data in tqdm(val_loader):
#         # get the data points
#         inputs, labels = data
#         inputs, labels = inputs.to(device), labels.to(device)
#         # forward the data through the network
#         outputs = net(inputs)
        
#         accuracy.update(outputs, labels)
        
# print("\nTesting Accuracy: {:.2f}%".format(100 * accuracy.compute()))

100%|██████████| 682/682 [00:04<00:00, 153.93it/s]

Accuracy: 20.56%





In [13]:
pickle.dump(net,open('/content/drive/MyDrive/NLP/text_generation.p','wb'))

In [None]:
net = pickle.load(open('/content/drive/MyDrive/NLP/text_generation.p','rb'))

In [14]:
net.eval()
#starting_words = 'Alice goes to the forest'
starting_words = 'The cat screams to Alice'
tokenized_starting_words = tokenize.word_tokenize(starting_words.lower())
inputs = torch.tensor([embedding_model.key_to_index[word] for word in tokenized_starting_words])
for i in range(100):
    next_word = torch.argmax(net(inputs.to(device).long()))
    if i == 0:
        print(starting_words, end = ' ')
    print(embedding_model.index_to_key[next_word], end = ' ')
    new_inputs = torch.zeros(5+i+1)
    new_inputs[0:5+i] = inputs
    new_inputs[5+i] = next_word
    inputs = new_inputs

The cat screams to Alice . ` i 'm a hatter . ' and she 's such a capital one of the cakes , and then the other -- the was just , ' the gryphon replied in a low , low , with a eyes eyes , and the baby and the blades of the , and the bright were all getting at the mushroom with his knuckles . ` i 'm not a bit , it 's worth the game , ' the gryphon whispered in reply . ` i do n't know what a mock turtle , you know . ' ` 