In [None]:
import math
import os
import random
import re
from collections import Counter

from IPython.display import clear_output
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.functional import F

In [None]:
# Reload edited modules (such as the local `helpers` and `model` modules imported
# below) before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Get Data

### The Critique of Pure Reason (Immanuel Kant) — Project Gutenberg e-text

In [None]:
# Read the raw Project Gutenberg text. The encoding is pinned so the result does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
# NOTE(review): assumes reason.txt is UTF-8 encoded — confirm for your copy.
with open('reason.txt', encoding='utf-8') as f:
    fulltext = f.read()

In [None]:
# Character offset where the kept text begins; everything before it is sliced off
# further down. str.index raises ValueError if the marker is not present.
start = fulltext.index('PREFACE TO THE FIRST EDITION 1781')
start

In [None]:
# Character offset of the Project Gutenberg end-of-book marker; the marker and
# everything after it are excluded by the slice further down.
end = fulltext.index('End of the Project Gutenberg EBook')
end

In [None]:
# Keep only the book body between the two markers.
# NOTE(review): this rebinds `fulltext` in place, so re-running this cell (or the
# start/end cells) against the already-sliced text gives different results —
# re-run from the file-read cell instead.
fulltext = fulltext[start:end]

In [None]:
# Lowercase before tokenizing so differently-cased spellings share one vocab entry.
fulltext = fulltext.lower()

In [None]:
# Number of characters kept after trimming and lowercasing.
len(fulltext)

# Tokenization

In [None]:
def tokenize(text):
    """Split text into word tokens and selected punctuation tokens.

    Underscores are deleted and runs of digits are removed before matching, so
    "a1b2" becomes the single token "ab". Words are runs of word characters and
    apostrophes (so "it's" stays one token). Each of , . ? ; : becomes its own
    token. All other characters are dropped.

    Args:
        text: the string to tokenize (not lowercased here).

    Returns:
        A list of token strings in order of appearance; empty for empty input.
    """
    text = text.replace('_', '')  # literal character, no regex needed
    # Raw string: '\d' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    text = re.sub(r'\d+', '', text)
    return re.findall(r"[\w']+|[,.?;:]", text)

In [None]:
# Token stream for the whole (trimmed, lowercased) book.
fulltoks = tokenize(fulltext)

In [None]:
# Distinct tokens ordered from most to least frequent (ties keep first-seen order).
vocab = [pair[0] for pair in Counter(fulltoks).most_common()]

In [None]:
# Number of distinct tokens in the book.
len(vocab)

# Transforms

In [None]:
from helpers import BaseTransform, TokTransform

In [None]:
# Token <-> id mapping built from the frequency-ordered vocab (helpers.TokTransform).
# NOTE(review): the loss/metric below treat id 0 as the "xxunk" token, so this
# transform presumably reserves index 0 for unknowns — confirm in helpers.
tok_tfm = TokTransform(vocab)

In [None]:
# Encode the token stream, then offset it by one position so that y[i] is the
# token that follows x[i] (next-token prediction targets).
data = tok_tfm.encode(fulltoks)
x, y = data[:-1], data[1:]
assert len(x) == len(y) and len(data) == len(x) + 1

# Chunking

In [None]:
# Tokens per training sequence: the x/y streams are cut into pieces of this length.
chunk_sz = 50

In [None]:
# Number of whole chunk_sz-token sequences that fit in x; the leftover tail is
# trimmed off when chunking. Integer floor division avoids a float round-trip.
n_chunks = len(x) // chunk_sz
# Fail here with a clear message rather than later inside .chunk(0).
assert n_chunks > 0, f'need at least {chunk_sz} tokens to form one chunk, got {len(x)}'
n_chunks

In [None]:
# Trim each stream to a whole number of chunks, then cut it into n_chunks
# equal-length pieces (presumably 1-D tensors, given the .chunk method).
x = x[:n_chunks * chunk_sz].chunk(n_chunks)
y = y[:n_chunks * chunk_sz].chunk(n_chunks)

In [None]:
# Inputs and targets must still line up one-to-one after chunking, and every
# chunk should hold exactly chunk_sz tokens.
assert len(x) == len(y) == n_chunks
assert all(len(xc) == len(yc) == chunk_sz for xc, yc in zip(x, y))

# Train / Val split

In [None]:
def shuffle_same(x_set, y_set, rng=None):
    """Shuffle x_set and y_set with one shared permutation so pairs stay aligned.

    Args:
        x_set: iterable of inputs (here, the tuple of input chunks).
        y_set: iterable of targets, same length as x_set.
        rng: optional random.Random instance for reproducible shuffles. Defaults
            to the global `random` module, matching the previous behavior.

    Returns:
        A two-element list [shuffled_x, shuffled_y] of tuples. Empty inputs give
        [(), ()] so callers can still unpack two values.

    Raises:
        ValueError: if x_set and y_set differ in length. zip() would otherwise
            silently drop the unmatched tail and misalign the data.
    """
    xs, ys = list(x_set), list(y_set)
    if len(xs) != len(ys):
        raise ValueError(f'shuffle_same: length mismatch ({len(xs)} vs {len(ys)})')
    pairs = list(zip(xs, ys))
    (rng if rng is not None else random).shuffle(pairs)
    if not pairs:
        return [(), ()]
    return list(zip(*pairs))

In [None]:
# Seed Python's RNG so the shuffle, and therefore the train/val split, is
# reproducible across kernel restarts.
random.seed(42)
x, y = shuffle_same(x, y)

In [None]:
# Number of (input, target) chunk pairs after shuffling.
len(x)

In [None]:
# Boundary index: the first 80% of the shuffled chunks train, the rest validate.
cut = int(0.8 * len(x))
cut

In [None]:
# Display-only: the index ranges that become the train and validation sets.
range(cut), range(cut,len(x))

In [None]:
# Split both inputs and targets at the same boundary so pairs stay together.
x_train = x[:cut]
x_val = x[cut:]
y_train = y[:cut]
y_val = y[cut:]

In [None]:
# Each split must keep inputs and targets paired one-to-one.
assert (len(x_train), len(x_val)) == (len(y_train), len(y_val))

# Data loading

In [None]:
from helpers import DataLoader, DataLoaders

In [None]:
# Wrap the train/val chunk pairs in the project's loaders (helpers.DataLoader /
# helpers.DataLoaders). NOTE(review): bs is presumably the batch size — confirm
# against helpers.DataLoader.
bs = 16
dl_train = DataLoader(x_train, y_train, bs)
dl_val = DataLoader(x_val, y_val, bs)
dls = DataLoaders(dl_train, dl_val)

# Model

In [None]:
from model import LangModel

In [None]:
# Vocabulary size comes from the fitted token transform; the embedding and
# hidden widths are fixed hyperparameters.
voc_sz, emb_sz, hid_sz = tok_tfm.count, 200, 300

In [None]:
# Build the language model from the local `model` module. Arguments are passed
# as vocab, embedding and hidden sizes (per the variable names — confirm
# against model.LangModel's signature).
model = LangModel(voc_sz, emb_sz, hid_sz)

# Loss

In [None]:
def LM_loss(preds, targ, unk_idx=0):
    """Cross-entropy language-model loss that skips unknown-token targets.

    Args:
        preds: logits whose last dimension is the vocabulary size; all leading
            dimensions are flattened (presumably (batch, seq_len, vocab) —
            confirm against LangModel's output).
        targ: integer token ids with one entry per flattened row of preds.
        unk_idx: target id excluded from the loss. Defaults to 0, the
            "xxunk" index used by this notebook.

    Returns:
        Scalar tensor: mean cross-entropy over positions whose target != unk_idx.
    """
    # Take the vocab size from preds rather than the notebook-global tok_tfm,
    # so the loss has no hidden dependency on notebook state.
    # reshape (unlike view) also accepts non-contiguous tensors.
    logits = preds.reshape(-1, preds.size(-1))
    flat_targ = targ.reshape(-1)
    # With mean reduction, ignore_index averages over the non-ignored targets
    # only. That equals masking them out first, without the boolean-index copies.
    return F.cross_entropy(logits, flat_targ, ignore_index=unk_idx)

# Optimizer

In [None]:
# Adam over all model parameters with a fixed learning rate of 1e-3.
opt = torch.optim.Adam(model.parameters(), lr=.001)

# Metrics (accuracy)

In [None]:
def accuracy(preds, targ, unk_idx=0):
    """Fraction of next-token predictions that match the target, skipping unk targets.

    Args:
        preds: scores whose last dimension is the vocabulary; the leading
            dimensions must match targ's shape (used as a boolean mask).
        targ: integer token ids.
        unk_idx: target id excluded from the metric. Defaults to 0, the
            "xxunk" index used by this notebook.

    Returns:
        Scalar float tensor: the mean of (argmax(preds) == targ) over kept positions.
    """
    keep = targ != unk_idx  # compute the mask once; reuse it for preds and targ
    predicted_ids = preds[keep].argmax(dim=-1)
    return (predicted_ids == targ[keep]).float().mean()

# Learner

In [None]:
from helpers import Learner

In [None]:
# Bundle data loaders, model, loss, optimizer and metric into helpers.Learner.
learn = Learner(dls, model, LM_loss, opt, accuracy)

In [None]:
# Train the model. NOTE(review): the argument is presumably the number of
# epochs — confirm in helpers.Learner.fit.
learn.fit(3)

In [None]:
# Show the training logs kept by helpers.Learner.
learn.print_logs()

# Saving

In [None]:
# torch.save does not create missing directories, so make sure saves/ exists.
# This pickles the whole module object, so loading it later requires the
# `model` module (LangModel) to be importable.
os.makedirs('saves', exist_ok=True)
torch.save(learn.model, 'saves/model')

In [None]:
import pickle

In [None]:
# Persist the token transform next to the model. The `with` block closes and
# flushes the file deterministically instead of leaving that to garbage collection.
# Loading it back with pickle.load can execute arbitrary code — only load trusted copies.
os.makedirs('saves', exist_ok=True)
with open('saves/tok_tfm.p', 'wb') as f:
    pickle.dump(tok_tfm, f)