# Pretraining on WikiText103

In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
#export
from exp.nb_12a import *

## Data

In [None]:
# One-time setup (kept commented out): uncomment together with the download cell below
# to fetch the raw dataset. `version` is substituted into the WikiText archive name;
# the trailing `#2` hints at the smaller alternative version '2'.
#path = datasets.Config().data_path()
#version = '103' #2

In [None]:
# One-time download (kept commented out): fetch and unzip the archive into `path`, then rename
# the wiki.{train,valid,test}.tokens files to train.txt / valid.txt / test.txt, which are the
# file names the cells below read. Needs `path` and `version` from the previous cell.
#! wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-{version}-v1.zip -P {path}
#! unzip -q -n {path}/wikitext-{version}-v1.zip  -d {path}
#! mv {path}/wikitext-{version}/wiki.train.tokens {path}/wikitext-{version}/train.txt
#! mv {path}/wikitext-{version}/wiki.valid.tokens {path}/wikitext-{version}/valid.txt
#! mv {path}/wikitext-{version}/wiki.test.tokens {path}/wikitext-{version}/test.txt

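WikiText-103 (WT103) comes as one big text file per split. To be able to shuffle at the beginning of each epoch, we first cut it into separate articles: a new article starts at every top-level ` = Title = ` heading that follows a blank line.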

In [None]:
# Root folder of the unpacked dataset; train.txt and valid.txt are read from here below.
path = datasets.Config().data_path()/'wikitext-103'

In [None]:
def istitle(line):
    """Tell whether `line` is a top-level WikiText heading such as ' = Title = '.

    Sub-headings like ' = = Section = = ' are rejected because the title text
    may not contain '='. A single trailing newline is allowed (`$` matches before it).
    """
    return re.match(r'^ = [^=]* = $', line) is not None

In [None]:
def read_wiki(filename):
    """Split a WikiText token file into a list of article strings.

    A new article starts whenever a line is followed by a blank separator line
    (a single space plus newline) and then a top-level title line (see `istitle`).
    The separator line becomes the first line of the new article. Raw '<unk>'
    markers are replaced by the project's `UNK` token.

    Args:
        filename: path of the WikiText file, read as UTF-8.

    Returns:
        List of article strings that keep their original newlines. An empty file
        yields an empty list (no spurious empty article).
    """
    with open(filename, encoding='utf8') as f:
        lines = f.readlines()
    articles, current = [], []
    # A break after line i needs lines i+1 and i+2 to exist.
    last_check = len(lines) - 2
    for i, line in enumerate(lines):
        current.append(line)
        if i < last_check and lines[i+1] == ' \n' and istitle(lines[i+2]):
            articles.append(''.join(current).replace('<unk>', UNK))
            current = []
    # Flush the final article; nothing is left only when the file was empty.
    if current:
        articles.append(''.join(current).replace('<unk>', UNK))
    return articles

In [None]:
train = TextList(read_wiki(path/'train.txt'), path=path)
valid = TextList(read_wiki(path/'valid.txt'), path=path)

In [None]:
len(train), len(valid)

In [None]:
sd = SplitData(train, valid)

In [None]:
proc_tok, proc_num = TokenizeProcessor(), NumericalizeProcessor()

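The `lambda x: 0` gives every text the same dummy label 0: `label_by_func` expects a labelling function, but the language-model data built below does not need real labels.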

In [None]:
ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok, proc_num])

In [None]:
pickle.dump(ll, open(path/'ld.pkl', 'wb'))

In [None]:
ll = pickle.load(open(path/'ld.pkl', 'rb'))

In [None]:
bs, bptt = 64, 70
data = lm_databunchify(ll, bs, bptt)

In [None]:
vocab = ll.train.proc_x[-1].vocab

In [None]:
len(vocab)

## Model

In [None]:
dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2

In [None]:
tok_pad = vocab.index(PAD)

In [None]:
tok_pad

In [None]:
emb_sz, nh, nl = 300, 300, 2

In [None]:
model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)

In [None]:
# Callback factories handed to Learner as cb_funcs: averaged stats with flattened accuracy,
# CUDA, recorder, gradient clipping at 0.1, the RNN trainer (alpha=2, beta=1 -- presumably
# activation regularisation weights; confirm) and a progress bar.
cbs = [partial(AvgStatsCallback, accuracy_flat),
       CudaCallback,
       Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, alpha=2., beta=1.),
       ProgressCallback]

In [None]:
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())

In [None]:
lr = 5e-3

In [None]:
sched_cos??

In [None]:
def cos_1cycle_anneal(start, high, end):
    """Build the two phases of a one-cycle schedule with `sched_cos`.

    Returns a two-element list: a schedule from `start` to `high`, then one from
    `high` to `end`. It is meant to be passed to `combine_scheds`.
    """
    warmup = sched_cos(start, high)
    cooldown = sched_cos(high, end)
    return [warmup, cooldown]

In [None]:
sched_lr  = combine_scheds([0.3,0.7], cos_1cycle_anneal(lr/10., lr, lr/1e5))
sched_mom = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.8, 0.7, 0.8))

In [None]:
ll = np.arange(0, 1, .01)

In [None]:
plt.plot(ll, [sched_lr(l) for l in ll])

In [None]:
plt.plot(ll, [sched_mom(l) for l in ll])

In [None]:
cbsched = [ParamScheduler('lr', sched_lr), ParamScheduler('mom', sched_mom)]

In [None]:
learn.fit(10, cbs=cbsched)

In [None]:
# Save the trained weights together with the vocabulary they were trained with.
torch.save(learn.model.state_dict(), path/'pretrained.pth')
# The context manager closes (and flushes) vocab.pkl immediately instead of leaving the
# anonymous file object to be closed whenever it is garbage-collected.
with open(path/'vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)