In [None]:
from local.torch_basics import *
from local.test import *
from local.core import *
from local.layers import *
from local.data.all import *
from local.notebook.showdoc import show_doc
from local.optimizer import *
from local.learner import *
from local.metrics import *
from local.text.core import *
from local.text.data import *
from local.text.models.core import *
from local.text.models.awdlstm import *
from local.callback.rnn import *
from local.callback.all import *

# Integration test on Wikitext-2

> Training a Language Model on WT2

## Data

In [None]:
# `path` is the dataset directory; train.txt and valid.txt are read from it below.
# NOTE(review): the notebook title says Wikitext-2 while the URL constant is WIKITEXT_TINY — presumably the same dataset; confirm.
path = untar_data(URLs.WIKITEXT_TINY)

The dataset comes with all the articles concatenated. We split them to be able to shuffle at the beginning of each epoch.

In [None]:
def istitle(line):
    "Return `True` when `line` is a top-level title of the form ' = Some Title = '."
    # `[^=]*` rejects sub-section headings such as ' = = Gameplay = = '.
    title_match = re.search(r'^ = [^=]* = $', line)
    return title_match is not None

def read_file(filename):
    "Read `filename` and split it into articles, each one a list of space-separated tokens."
    with open(filename, encoding='utf8') as f:
        lines = f.readlines()
    n_lines = len(lines)
    articles, pieces = L(), []
    for idx, raw in enumerate(lines):
        # Swap the file's '<unk>' marker for the library's `UNK` token.
        pieces.append(raw.replace('<unk>', UNK))
        # The current line closes an article when the next line is blank (' \n')
        # and the one after it is a top-level title.
        ends_article = idx < n_lines-2 and lines[idx+1] == ' \n' and istitle(lines[idx+2])
        if ends_article:
            articles.append(''.join(pieces).split(' '))
            pieces = []
    # Whatever is left (possibly the empty string) forms the last article.
    articles.append(''.join(pieces).split(' '))
    return articles

Then we put our list of tokenized texts together in an `LM_Dataset`. It will return tuples of sequences of length `seq_len`, with the second sequence being the first one shifted by one to the right.

In [None]:
# Each result is a list of articles, every article being a list of tokens.
trn_txt,val_txt = read_file(path/'train.txt'),read_file(path/'valid.txt')

In [None]:
# Token frequencies are computed over the training articles only.
count = Counter(tok for article in trn_txt for tok in article)
vocab = make_vocab(count)

In [None]:
n_val,n_trn = len(val_txt),len(trn_txt)
# Index ranges into `val_txt+trn_txt`: validation items come first, so the
# training indices start at `n_val`. Order of `splits` is [train, valid].
splits = [list(range(n_val, n_val+n_trn)), list(range(n_val))]
# Reuse the `vocab` built from `count` instead of calling `make_vocab(count)` a second time,
# so the transform and the model size (`len(vocab)`) are guaranteed to use the same object.
tfm = Numericalize(vocab)

In [None]:
# Items are concatenated validation-first, matching the index ranges in `splits`
# (training indices start at len(val_txt)); `tfm` is the only transform applied.
dsrc = DataSource(val_txt+trn_txt, [tfm], filts=splits)

In [None]:
bs = 104  # batch size for training; the validation loader uses twice this
sl = 72   # passed as `seq_len` to both loaders
# Only the training loader shuffles; each loader gets its own `Cuda()` after-batch transform.
train_dl = LMDataLoader(dsrc.train, bs=bs, seq_len=sl, shuffle=True, after_batch=[Cuda()])
valid_dl = LMDataLoader(dsrc.valid, bs=2*bs, seq_len=sl, after_batch=[Cuda()])

In [None]:
# Bundle the training and validation loaders, then display a sample batch as a sanity check.
dbch = DataBunch(train_dl, valid_dl)
dbch.show_batch()

Unnamed: 0,text
0,"\n = Patriarchal Cathedral of the Holy Ascension of God = \n \n The Patriarchal Cathedral of the Holy Ascension of God ( Bulgarian : xxunk xxunk „ xxunk xxunk xxunk “ , xxunk xxunk „ xxunk xxunk xxunk “ ) is a former Eastern Orthodox cathedral in the city of xxunk Tarnovo , in north central Bulgaria . Located on top of the fortified Tsarevets hill in the former capital"
1,"\n \n Some of Balliett 's "" real @-@ world ideas "" in Chasing Vermeer were "" Do coincidences mean anything ? "" and "" What is art and what makes it valuable ? "" Balliett says her "" central message "" is "" kids are powerful thinkers , and their ideas are valuable , and that adults don 't have all the answers . "" \n A book by Rita xxunk"
2,"allowing his friends time to use the collective consciousness to rebuild the guardian that had kept the beast trapped . However , in this process , Swamp Thing has his human soul removed , setting up the fourth run of the comic , relaunched shortly afterward . In the process John loses his memory , setting up the events leading up to the 200th issue . Leading up to the landmark issue"
3,"tracks were included with other background music in the Snow Original Soundtrack released on April 25 , 2003 . Before the visual novel 's release , Snow Image Album was released at xxunk 63 on December 28 , 2002 . \n Three drama CDs based on Snow have been published , the first CD volume was released by xxunk on August 22 , 2003 , focusing on Sumino Yukizuki . xxunk released"
4,"trapped 1 @.@ 7 million birds , the largest number of any nuisance species to be destroyed . In 2005 , the population in the United States was estimated at 140 million birds , around 45 % of the global total of 310 million . \n \n = = = In science and culture = = = \n \n Common starlings may be kept as pets or as laboratory animals . Austrian"
5,"battalions now involved ( the 5th Battalion , Royal West xxunk had by now been tasked on the south east side of the village ) supported by tanks , Villa Grande was finally cleared by the end of 26 December . The troops of the 8th Indian Division entered the village to find a xxunk . One correspondent described the scene "" as though a giant had xxunk on a child 's"
6,"at Lincoln 's Inn for three years . \n \n = = Marriage and family = = \n \n In November 1604 , he married Anne xxunk in a Protestant , Church of England ceremony at St Peter 's , Cornhill , where his address was registered as St Martin in the Fields . His children , including his eldest son and heir , Cecil , who was born in the winter"
7,"Emmy Award for "" Most Outstanding Personality "" . The network 's other notable programs include : \n Ted Mack 's The Original Amateur Hour , which began on radio in the 1930s under original host Edward Bowes \n The Morey Amsterdam Show , a comedy / variety show hosted by Morey Amsterdam , which started on CBS before moving to DuMont in 1949 \n Captain Video and His Video Rangers ,"
8,". By 1964 , APF was the UK 's largest commercial user of colour film , consuming more than three million feet ( 570 miles or 910 kilometres ) of stock per year . \n Alan Pattillo , a veteran xxunk and director for APF , was appointed the company 's first official script editor in late 1964 . This move was aimed to reduce the burden on Gerry Anderson who ,"
9,"\n = = Etymology = = \n \n The earliest named settlement within the domain of modern @-@ day Haifa was a city known as xxunk . Tel Shikmona Hebrew meaning "" mound of the xxunk xxunk "" ( Arabic Tell el @-@ xxunk or Tell es @-@ Samak , meaning "" mound of the fish "" ) preserved and transformed this ancient name and is mentioned once in the xxunk ("


In [None]:
%%time
# Iterate over the whole training loader without doing any work, to time batch production alone.
for x,y in dbch.train_dl: pass

CPU times: user 4.04 s, sys: 16.3 ms, total: 4.06 s
Wall time: 4.06 s


In [None]:
# Stray incomplete expression (`nn.`) removed: it was a SyntaxError that stopped "Run All" at this cell.

## Model

In [None]:
# Work on a copy so the library-level default config is left untouched.
config = awd_lstm_lm_config.copy()
# Override the five `*_p` entries on the copy.
config.update(input_p=0.6, output_p=0.4, weight_p=0.5, embed_p=0.1, hidden_p=0.2)
model = get_language_model(AWD_LSTM, len(vocab), config=config)

In [None]:
class HP():
    """Dict-like view over one param group of an optimizer exposing `param_groups`.

    Every key is forwarded to `opt.param_groups[i]` unchanged, except 'mom',
    which is mapped onto the first element of that group's 'betas' tuple.
    """
    def __init__(self, opt, i):
        self.opt,self.i = opt,i

    @property
    def _group(self):
        "The param-group dict this view points at (looked up on every access, never cached)."
        return self.opt.param_groups[self.i]

    def __getitem__(self, k):
        "Return hyper-parameter `k`; 'mom' reads `betas[0]`."
        if k != 'mom': return self._group[k]
        return self._group['betas'][0]

    def __setitem__(self, k, v):
        "Set hyper-parameter `k` to `v`; 'mom' replaces `betas[0]` only."
        if k != 'mom': self._group[k] = v
        # Keep the remaining beta(s) as they are instead of forcing the second one to 0.99.
        else: self._group['betas'] = (v, *self._group['betas'][1:])

class AdamOpt():
    "Thin wrapper around `torch.optim.Adam` exposing one `HP` view per param group in `hypers`."
    def __init__(self, params, lr, wd=0., eps=1e-7):
        # `wd` is forwarded as Adam's `weight_decay`; betas are fixed at (0.9, 0.99).
        self.opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.99), eps=eps, weight_decay=wd)
        n_groups = len(self.opt.param_groups)
        self.hypers = [HP(self.opt, idx) for idx in range(n_groups)]

    def step(self):
        "Delegate to the wrapped optimizer's `step`."
        self.opt.step()

    def zero_grad(self):
        "Delegate to the wrapped optimizer's `zero_grad`."
        self.opt.zero_grad()

In [None]:
# Build the `AdamOpt` wrapper defined above over all model parameters.
# NOTE(review): `opt` is not referenced by any later cell visible here (the Learner is given `opt_func` instead) — confirm it is still needed.
opt = AdamOpt(model.parameters(), lr=5e-4, wd=0.1, eps=1e-7)

In [None]:
# Optimizer factory passed to the Learner below: `Adam` with wd=0.1 and eps=1e-7 pre-bound.
opt_func = partial(Adam, wd=0.1, eps=1e-7)
# Callback factories: `MixedPrecision` with clip=0.1 and `RNNTrainer` with alpha=2, beta=1.
cb_funcs = [partial(MixedPrecision, clip=0.1), partial(RNNTrainer, alpha=2, beta=1)]

In [None]:
# Tie together model, data, loss, optimizer factory, callbacks and reported metrics.
learn = Learner(
    model, dbch,
    loss_func=CrossEntropyLossFlat(),
    opt_func=opt_func,
    cb_funcs=cb_funcs,
    metrics=[accuracy, Perplexity()],
)

In [None]:
# Run the one-cycle training schedule; the recorded output below shows a single epoch.
# NOTE(review): the positional 5e-3 is presumably the peak learning rate — confirm against `fit_one_cycle`'s signature.
learn.fit_one_cycle(1, 5e-3, moms=(0.8,0.7,0.8), div=10)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,7.702378,6.909925,0.058589,1002.172546,01:03
