## Imports

In [46]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [47]:
import numpy as np
import sys
LIBS_AI = '/home/ubuntu/fastai/'
sys.path.append(LIBS_AI)

In [48]:
from fastai.io import *
from fastai.conv_learner import *
from fastai.column_data import *

## Load Data

In [49]:
# Parallel corpus files (English / German); the cells below only read the German side.
DATA_PATH_EN, DATA_PATH_DE = '../data/comtrans_en.txt', '../data/comtrans_de.txt'

In [50]:
# Read the whole German corpus into a single string. The `with` block guarantees the
# file handle is closed, and the explicit encoding avoids depending on the kernel's locale.
with open(DATA_PATH_DE, encoding='utf-8') as corpus_file:
    raw_data = corpus_file.read()

In [51]:
# Report the corpus size in characters (the model works at character level).
print('corpus length:', len(raw_data))

corpus length: 4245799


In [52]:
# Peek at the opening of the corpus to sanity-check the loaded text.
raw_data[0:400]

'Wiederaufnahme der Sitzungsperiode\nIch erklre die am Freitag , dem 17. Dezember unterbrochene Sitzungsperiode des Europischen Parlaments fr wiederaufgenommen , wnsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe , da Sie schne Ferien hatten .\nWie Sie feststellen konnten , ist der gefrchtete " Millenium-Bug " nicht eingetreten . Doch sind Brger einiger unserer Mitgliedstaaten Opfer von sch'

## Preprocessing

In [53]:
# Distinct characters in the corpus, in sorted order; a character's position is its id.
chars = sorted(set(raw_data))
# One more than the number of distinct characters. `char_indices` only assigns ids
# 0..len(chars)-1, so id len(chars) is an extra class with no character behind it.
# Note the printed "total chars" therefore includes that extra slot.
vocab_size = 1 + len(chars)
print('total chars:', vocab_size)

total chars: 82


In [54]:
# Show the full character inventory as one string.
"".join(chars)

'\n !"$%\'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz'

In [55]:
# Two-way lookup between integer ids (0..len(chars)-1) and characters.
indices_char = dict(enumerate(chars))
char_indices = {ch: pos for pos, ch in indices_char.items()}

In [56]:
# Encode the whole corpus as a flat list of integer character ids.
idx = [char_indices[ch] for ch in raw_data]
idx[0:15]

[49, 63, 59, 58, 59, 72, 55, 75, 60, 68, 55, 62, 67, 59, 1]

In [57]:
# Round-trip check: decoding the first ids should reproduce the opening text.
''.join([indices_char[k] for k in idx[:15]])

'Wiederaufnahme '

In [58]:
cs = 8  # context size: number of preceding characters fed to the model per prediction

In [59]:
# Sliding windows over the encoded corpus: input row `start` holds the `cs` ids
# idx[start:start+cs], and its target is the id immediately after that window.
c_in_dat = [idx[start:start + cs] for start in range(len(idx) - cs)]
c_out_dat = idx[cs:]

In [60]:
# Stack the windows row-wise into a (num_windows, cs) integer array.
xs = np.stack(c_in_dat)

In [61]:
np.shape(xs)

(4245791, 8)

In [62]:
# One target id per window, as a 1-D array aligned with the rows of `xs`.
y = np.stack(c_out_dat, axis=0)

In [63]:
# Pick random validation rows. `get_cv_idxs(n)` draws indices from range(n), so n must
# be the exact number of samples (len(xs) == len(idx) - cs); passing one less would mean
# the final window could never be selected for validation.
val_idx = get_cv_idxs(len(xs))

## RNN Model

In [64]:
n_fac = 30     # embedding dimension per character
n_hidden = 10  # RNN hidden-state size

In [65]:
class CharRnn(nn.Module):
    """Single-layer character RNN that scores the next character.

    Each positional argument to `forward` is the batch of character ids for one
    time step. The ids are embedded, run through `nn.RNN`, and the last time step's
    output is projected to log-probabilities over the vocabulary.

    Args:
        vocab_size: number of embedding rows and of output classes.
        n_fac: embedding dimension.
        n_hidden: RNN hidden size. If omitted, the notebook-level `n_hidden`
            setting is used, so existing `CharRnn(vocab_size, n_fac)` calls are
            unchanged.
    """
    def __init__(self, vocab_size, n_fac, n_hidden=None):
        super().__init__()
        if n_hidden is None:
            # Backward compatible default: read the config-cell value once, here,
            # instead of re-reading the global on every forward pass.
            n_hidden = globals()['n_hidden']
        self.n_hidden = n_hidden
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.RNN(n_fac, n_hidden)
        self.l_out = nn.Linear(n_hidden, vocab_size)

    def forward(self, *char_idxs):
        """Return log-probabilities of shape (batch, vocab_size) for the next character.

        `char_idxs` are per-time-step id tensors whose first dimension is the batch.
        """
        bs = char_idxs[0].size(0)
        # Fresh zero initial hidden state for every call: (num_layers=1, batch, hidden).
        h = V(torch.zeros(1, bs, self.n_hidden))
        # Stacking puts time on the leading axis, as nn.RNN expects by default.
        inp = self.e(torch.stack(char_idxs))
        outp, h = self.rnn(inp, h)
        # Only the final time step's output is used to predict the next character.
        return F.log_softmax(self.l_out(outp[-1]), dim=-1)

In [66]:
# Wrap the training arrays for fastai: `xs` holds the input windows, `y` the targets,
# and the rows listed in `val_idx` form the validation set; batch size is 512.
# NOTE(review): presumably each column of `xs` is passed to the model as a separate
# argument, which matches `CharRnn.forward(*cs)` — confirm against fastai's ColumnarDataset.
md = ColumnarModelData.from_arrays('.', val_idx, xs, y, bs=512)

In [67]:
# Build the model, move it to the GPU, and optimise all of its parameters with Adam.
m = CharRnn(vocab_size, n_fac)
m = m.cuda()
opt = optim.Adam(m.parameters(), lr=1e-2)

In [68]:
# Train `m` on `md` for 5 epochs. NLL loss is used because the model already
# returns log-probabilities (log_softmax in CharRnn.forward).
fit(m, md, 5, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss                                
    0      2.143929   2.162634  
    1      2.146127   2.154928                                
    2      2.136336   2.156804                                
    3      2.127147   2.14134                                 
    4      2.132018   2.151597                                



[array([2.1516])]

In [69]:
def get_next(inp):
    """Return the model's most likely character to follow the string `inp`.

    Uses the global trained model `m`. Each character of `inp` is encoded with
    `char_indices` and fed as one time step (presumably as a batch of one).
    """
    idxs = T(np.array([char_indices[c] for c in inp]))
    p = m(*VV(idxs))
    # `vocab_size` is len(chars)+1, so the model has one output class with no entry in
    # `chars`. Restrict argmax to real characters so `chars[i]` can never go out of range.
    scores = to_np(p)[..., :len(chars)]
    i = np.argmax(scores)
    return chars[i]

In [70]:
get_next(inp='Ic')

'h'

In [71]:
def get_next_n(inp, n):
    """Greedily extend `inp` by `n` characters and return the extended string.

    The model's input window keeps the same length as `inp`. After each prediction,
    the oldest character is dropped and the new one is appended.
    """
    window = inp
    generated = []
    for _ in range(n):
        nxt = get_next(window)
        generated.append(nxt)
        window = window[1:] + nxt
    return inp + ''.join(generated)

In [72]:
get_next_n(inp='art', n=5)

'art der '

## Tuning the RNN: continue training at a lower learning rate

In [73]:
# Continue training the same model for 5 more epochs after lowering the optimizer's
# learning rate from 1e-2 to 1e-3 with fastai's `set_lrs`.
set_lrs(opt, 1e-3)
fit(m, md, 5, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss                                
    0      2.115914   2.122613  
    1      2.132142   2.12201                                 
    2      2.116861   2.121286                                
    3      2.121428   2.120786                                
    4      2.139819   2.120255                                



[array([2.12026])]

In [74]:
get_next_n(inp='Ich ', n=5)

'Ich der d'