In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision 
from torchvision import transforms

from torch.utils.data import Dataset, DataLoader

from tqdm.autonotebook import tqdm

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

import pandas as pd

from sklearn.metrics import accuracy_score

import time
import os

In [2]:
lesson_folder='/Users/msarica/Desktop/DS606/DS_Capstone/Deliverable3/'
# lesson_folder = 'drive/My Drive/Colab Notebooks/606/'
def delete_file(file_name):
    file = lesson_folder + file_name
    if os.path.exists(file) == False:
        print ("file doesn't exist")
    else:
        os.remove(file)

# delete_file('606_model2')

def get_filename(file_name):
    return  lesson_folder + file_name


In [3]:
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen
import re

def get_data_from_url(url):
    resp = urlopen(url) # snow white
    text = resp.read()
    text = text.decode('utf-8').lower()
    return text

# def read_file(filename):
#     with open(filename, 'r') as f:
#         text = f.read().lower()
#         return text

urls = [
    "https://www.cs.cmu.edu/~spok/grimmtmp/042.txt",
#     'http://www.umich.edu/~umfandsf/other/ebooks/alice30.txt'
]

def get_data_as_text(story_urls=urls):
    all_data = []
    
    for url in story_urls: 
        text = get_data_from_url(url)
        # print(text[:100])
        all_data.append(text)
    
    return " ".join(all_data)

def prepare_data():
    # text = read_file()
    text = get_data_as_text()
    words = text.split()
    unique_words = set(words)
    int_to_vocab = {(key+1): word for key, word in enumerate(unique_words)}
    vocab_to_int = {word: key for key, word in int_to_vocab.items()}
    vocab_to_int['_unknown'] = 0
    int_to_vocab[0] = '_unknown'
    vocabulary_size = len(int_to_vocab)

    print('Vocabulary size', vocabulary_size)
    return words, int_to_vocab, vocab_to_int

def get_word_index(w):
    if w not in vocab_to_int: 
        return 0
    return vocab_to_int[w]

words, int_to_vocab, vocab_to_int = prepare_data()

Vocabulary size 790


In [4]:
class AutoRegressiveDataset(Dataset):

    def __init__(self, words, word_to_int, seq_size):
        self.words = words
        self.word_to_int = word_to_int
        self.seq_size = seq_size
        self.int_text = [word_to_int[w] for w in words]

    def __len__(self):
        return len(self.words) // self.seq_size
  
    def __getitem__(self, i):
        seq_size = self.seq_size

        x = self.int_text[i:i+seq_size]
        y = self.int_text[i+1: i+seq_size+1]

        return torch.tensor(x, dtype=torch.int64), torch.tensor(y, dtype=torch.int64)

class AutoRegressive(nn.Module):
    def __init__(self, vocabulary_length, sequence_size, embedding_size, lstm_size):
        super(AutoRegressive, self).__init__()
        self.seq_size = sequence_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(
            vocabulary_length, 
            embedding_size
            )
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=False)
        self.dense = nn.Linear(lstm_size, vocabulary_length)
    
    def forward(self, input, previous_state ): 
        embed = self.embedding(input)
        output, state = self.lstm(embed, previous_state)
        logits = self.dense(output)

        return logits, state

    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))


In [5]:
def get_loss_and_train_op(model, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    return criterion, optimizer

In [6]:
def train(
    model,
    train_loader,
    optimizer,
    loss_func,
    device,
    lr_schedule=None,
    epochs=50
):
    batch_size = train_loader.batch_size
    gradients_norm=5

    for epoch in tqdm(range(epochs), desc="Epoch", disable=False):
        state_h, state_c = model.zero_state(batch_size)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        running_loss = 0.0
        for x, y in tqdm(train_loader, desc="Train Batch", leave=False, disable=False):
            model.train()

            optimizer.zero_grad()

            x = x.to(device)
            y = y.to(device)

            logits, (state_h, state_c) = model(x, (state_h, state_c))
            loss = loss_func(logits.transpose(1, 2), y)

            loss.backward()

            running_loss += loss.item() * batch_size

            state_h = state_h.detach()
            state_c = state_c.detach()

            # to prevent gradient becoming too large and Nan
            _ = torch.nn.utils.clip_grad_norm_(
                model.parameters(), gradients_norm)

            optimizer.step()

            # #In PyTorch, the convention is to update the learning rate after every epoch
            if not lr_schedule is None:
                if isinstance(lr_schedule, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    lr_schedule.step(running_loss)
                else:
                    lr_schedule.step()

In [30]:
def predict_sentence(device, model, initial_seed, vocab_to_int, int_to_vocab, top_k=5, length=50):
    def get_word_index(w):
        if w not in vocab_to_int: 
            return 0
        return vocab_to_int[w]

    model.eval()
    words = initial_seed.split()

    state_h, state_c = model.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    
    for w in words:
        ix = torch.tensor([[get_word_index(w)]]).to(device)
        output, (state_h, state_c) = model(ix, (state_h, state_c))
    
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])

    words.append(int_to_vocab[choice])
    
    while (len(words) > length and (len(words) < 200 or words[len(words)-1].strip().endswith('.'))) == False:
        ix = torch.tensor([[choice]]).to(device)
        output, (state_h, state_c) = model(ix, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        word = int_to_vocab[choice]
        words.append(word)

    return (' '.join(words))

# predict_sentence(
#     device, 
#     model, 
#     'queen said', 
#     vocab_to_int, 
#     int_to_vocab
#     )

In [25]:
def main():
    batch_size=5
    embedding_size=64
    lstm_size=64
    epochs = 50
    
    autoRegData = AutoRegressiveDataset( words,vocab_to_int, seq_size=5)
    train_loader = DataLoader(autoRegData, batch_size=batch_size, shuffle=False)
    model = AutoRegressive(len(words), 32, 64, 16)
    model = model.to(device)

    criterion, optimizer = get_loss_and_train_op(model, 0.01)
    schedule = torch.optim.lr_scheduler.StepLR(optimizer, epochs//4, gamma=0.3)

    train(
        model=model, 
        train_loader=train_loader, 
        optimizer=optimizer, 
        loss_func=criterion,
        device=device,
        lr_schedule=schedule,
        epochs=50)

    return model, optimizer


In [9]:
def save_model(file_name, model, optimizer, epoch=None):
    torch.save({
                # 'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'words': words,
                'int_to_vocab': int_to_vocab,
                'vocab_to_int': vocab_to_int
                # 'results' : results,
                # 'early_stop': early_stop_flag
                }, file_name)


In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [27]:
model, optimizer = main()


HBox(children=(IntProgress(value=0, description='Epoch', max=50, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Train Batch', max=125, style=ProgressStyle(description_width=…




In [28]:
predict_sentence(
    device, 
    model, 
    'queen said', 
    vocab_to_int, 
    int_to_vocab
    )

KeyError: 1241

In [13]:
file_name = lesson_folder + 'story/story.pt'
save_model(file_name, model, optimizer)