In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from collections import Counter
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# change to your path in google drive
YOUR_PATH_TO_WORKSPACE = "/content/drive/MyDrive/MSAI/MSAI437 - Deep Learning/"

In [4]:
os.chdir(os.path.join(YOUR_PATH_TO_WORKSPACE, "HW3"))

In [5]:
def process_data(path):
    lines = open(path, 'r').readlines()
    data = []
    for line in lines:
        line = line.strip()
        if line.startswith("=") and line.endswith("="):
            continue
        if line == "": continue
        data.append(line)
    return data

def prepare_train_data(path, n=-1):
    data = process_data(path)
    tokens = []
    for text in data: tokens += text.split(" ")
    counter = Counter(tokens)
    if n==-1: return tokens, counter, []

    inversed_rank = counter.most_common()[::-1]
    need_to_replace = set([key for key, freq in inversed_rank if freq <= n])
    replaced_tokens = []
    for token in tokens:
        replaced_tokens.append("<unk>" if token in need_to_replace else token)
    return replaced_tokens, Counter(replaced_tokens), need_to_replace

def prepare_val_test_data(path, need_to_replace=[]):
    data = process_data(path)
    tokens = []
    for text in data: tokens += text.split(" ")
    if len(need_to_replace) == 0:
      return tokens
    replaced_tokens = []
    for token in tokens:
        replaced_tokens.append("<unk>" if token in need_to_replace else token)
    return replaced_tokens

In [30]:
class Wiki2(Dataset):
    def __init__(self, data_tokens: list, vocab: dict, seq_size: int=30) -> None:
        super().__init__()

        self.data = self._index_all(data_tokens, vocab)
        self.seq_size = seq_size

    def __getitem__(self, index):
        indexed_text = torch.tensor(self.data[index:index+self.seq_size]).long()
        indexed_target = torch.tensor(self.data[index+self.seq_size]).long()
        return indexed_text, indexed_target

    def __len__(self):
        return len(self.data) - self.seq_size

    def _index_all(self, data_tokens, vocab):
        set_keys = set(vocab.keys())
        return [vocab[token] if token in set_keys else vocab['<unk>'] for token in data_tokens]

In [20]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embed_size=100, hidden_size=100, num_layers=1):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True, dropout=0.5)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, text, hidden):
        # text: [batch_size, sequence_length]
        embedded = self.embed(text)  # [batch_size, sequence_length, embed_size]
        output, hidden = self.rnn(embedded, hidden)  # output: [batch_size, sequence_length, hidden_size]
        out = self.fc(output[:, -1, :])  # out: [batch_size, vocab_size], only feed final output of rnn to linear layer
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

In [36]:
train_tokens, train_counter, need_to_replace = prepare_train_data("./dataset/wiki2.train.txt", n=10)  # freq threshold: n
val_tokens = prepare_val_test_data("./dataset/wiki2.valid.txt", need_to_replace)
test_tokens = prepare_val_test_data("./dataset/wiki2.test.txt", need_to_replace)
print("length of vocab:", len(train_counter))
print("number of train tokens:", len(train_tokens))
print("number of valid tokens:", len(val_tokens))
print("number of test tokens:", len(test_tokens))
vocab = {key:i for i, key in enumerate(sorted(list(train_counter.keys())))}

length of vocab: 13354
number of train tokens: 2007146
number of valid tokens: 209338
number of test tokens: 235854


In [34]:
epochs = 20
batchsize = 512


train_dataset = Wiki2(train_tokens, vocab, 30)
train_loader = DataLoader(train_dataset, batchsize, shuffle=True, drop_last=True)
val_dataset = Wiki2(val_tokens, vocab, 30)
val_loader = DataLoader(val_dataset, batchsize, shuffle=False, drop_last=True)

model = RNNModel(len(train_counter), 100, 100, 1).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4, weight_decay=4e-5)

train_perplexity, val_perplexity = [], []

for e in range(epochs):
    model.train()
    mean_train_perplexity = 0
    with tqdm(total=len(train_loader), desc=f'Epoch {e+1}/{epochs}', unit='batch') as pbar:
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            hidden = model.init_hidden(batchsize).cuda()

            optimizer.zero_grad()
            output, hidden = model(inputs, hidden)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

            mean_train_perplexity += float(torch.exp(loss))/len(train_loader)

            pbar.set_description(f'Epoch {e+1}/{epochs}, Iter {i+1}/{len(train_loader)} - perplexity: {float(torch.exp(loss)):.4f}')
            pbar.update(1)
        pbar.set_description(f'Epoch {e+1}/{epochs}, Iter {i+1}/{len(train_loader)} - perplexity: {mean_train_perplexity:.4f}')

    model.eval()
    mean_val_perplexity = 0
    with tqdm(total=len(val_loader), desc=f'Epoch (Val) {e+1}/{epochs}', unit='batch') as pbar:
        for i, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            hidden = model.init_hidden(batchsize).cuda()

            output, hidden = model(inputs, hidden)
            loss = criterion(output, targets)

            mean_val_perplexity += float(torch.exp(loss))/len(val_loader)

            pbar.set_description(f'Epoch (Val) {e+1}/{epochs}, Iter {i+1}/{len(train_loader)} - perplexity: {float(torch.exp(loss)):.4f}')
            pbar.update(1)
        pbar.set_description(f'Epoch (Val) {e+1}/{epochs}, Iter {i+1}/{len(train_loader)} - perplexity: {mean_val_perplexity:.4f}')

    train_perplexity.append(mean_train_perplexity)
    val_perplexity.append(mean_val_perplexity)

Epoch 1/20, Iter 3920/3920 - perplexity: 1090.8049: 100%|██████████| 3920/3920 [01:27<00:00, 44.71batch/s]
Epoch (Val) 1/20, Iter 408/3920 - perplexity: 455.0215: 100%|██████████| 408/408 [00:06<00:00, 64.06batch/s]
Epoch 2/20, Iter 3920/3920 - perplexity: 463.5960: 100%|██████████| 3920/3920 [01:27<00:00, 44.87batch/s]
Epoch (Val) 2/20, Iter 408/3920 - perplexity: 364.8809: 100%|██████████| 408/408 [00:06<00:00, 60.17batch/s]
Epoch 3/20, Iter 3920/3920 - perplexity: 375.0421: 100%|██████████| 3920/3920 [01:27<00:00, 44.91batch/s]
Epoch (Val) 3/20, Iter 408/3920 - perplexity: 305.8921: 100%|██████████| 408/408 [00:06<00:00, 59.75batch/s]
Epoch 4/20, Iter 3920/3920 - perplexity: 319.8277: 100%|██████████| 3920/3920 [01:27<00:00, 44.88batch/s]
Epoch (Val) 4/20, Iter 408/3920 - perplexity: 270.6893: 100%|██████████| 408/408 [00:07<00:00, 58.03batch/s]
Epoch 5/20, Iter 3920/3920 - perplexity: 283.8253: 100%|██████████| 3920/3920 [01:27<00:00, 44.67batch/s]
Epoch (Val) 5/20, Iter 408/3920 -

In [None]:
plt.plot(train_perplexity, label='train-perplexity')
plt.plot(val_perplexity, label='val-perplexity')
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Perplexity")

In [38]:
test_dataset = Wiki2(test_tokens, vocab, 30)
test_loader = DataLoader(test_dataset, batchsize, shuffle=False, drop_last=True)

model.eval()
mean_test_perplexity = 0
with tqdm(total=len(val_loader), desc=f'Test', unit='batch') as pbar:
    for i, (inputs, targets) in enumerate(val_loader):
        inputs, targets = inputs.cuda(), targets.cuda()
        hidden = model.init_hidden(batchsize).cuda()

        output, hidden = model(inputs, hidden)
        loss = criterion(output, targets)

        mean_test_perplexity += float(torch.exp(loss))/len(test_loader)

        pbar.set_description(f'Test, Iter {i+1}/{len(test_loader)} - perplexity: {float(torch.exp(loss)):.4f}')
        pbar.update(1)
    pbar.set_description(f'Test, Iter {i+1}/{len(test_loader)} - perplexity: {mean_test_perplexity:.4f}')

Test, Iter 408/460 - perplexity: 141.4453: 100%|██████████| 408/408 [00:07<00:00, 56.54batch/s]


In [42]:
torch.save(model.state_dict(), "./checkpoint.pt")
f = open("./vocab.pkl", 'wb')
pickle.dump(vocab, f)
f.close()