In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
torch.cuda.set_device(0)
from fastai import *
from fastai.text import *
from pathlib import Path
import pickle
import fastText as ft
from collections import defaultdict

In [3]:
PATH=Path('/mnt/data/extracts/translate')
DATA=PATH/'data/'
TMP=PATH/'tmp/'
TMP.mkdir(exist_ok=True)
MODELS=Path('/mnt/models/translate')

## Load Data

In [4]:
qs = pickle.load((DATA/'fr-en-qs.pkl').open('rb'))
en_qs, fr_qs = zip(*qs)

In [5]:
en_tok = pickle.load((DATA/'en_tok.pkl').open('rb'))
fr_tok = pickle.load((DATA/'fr_tok.pkl').open('rb'))

In [7]:
def tok2ids(tok, lim=40000):
    flat = [w for l in tok for w in l]
    toi = Counter(flat).most_common(lim)
    itos = [o for o, c in toi]
    itos.insert(0, '_bos_')
    itos.insert(1, '_pad_')
    itos.insert(2, '_eos_')
    itos.insert(3, '_unk')
    stoi = {t:i for i, t in enumerate(itos)}
    ids = [([0] + [stoi[t] for t in q] + [2]) for q in tok]
    return ids, itos, stoi

In [8]:
fr_ids, fr_itos, fr_stoi = tok2ids(fr_tok)
en_ids, en_itos, en_stoi = tok2ids(en_tok)

## Load Embeddings

In [194]:
en_ft = ft.load_model(str((DATA/'wiki.en.bin')))
fr_ft = ft.load_model(str((DATA/'wiki.fr.bin')))

def get_vecs(ftxt):
    return {w: ftxt.get_word_vector(w) for w in ftxt.get_words()}

en_vecs = get_vecs(en_ft)
fr_vecs = get_vecs(fr_ft)

In [10]:
en_vecs = pickle.load(open(DATA/'wiki.en.pkl','rb'))
fr_vecs = pickle.load(open(DATA/'wiki.fr.pkl','rb'))

## Create Data Loaders

In [11]:
class Seq2SeqDataset(Dataset):
    def __init__(self, x, y): 
        super(Seq2SeqDataset, self).__init__()
        self.x = x
        self.y = y
        
    def __getitem__(self, i):
        return A(self.x[i], self.y[i])
        
    def __len__(self):
        return len(self.x)

In [12]:
en_cutoff = int(np.percentile([len(q) for q in en_ids], q=97))
fr_cutoff = int(np.percentile([len(q) for q in fr_ids], q=95))

In [13]:
en_arr = np.array([q[:en_cutoff] for q in en_ids])
fr_arr = np.array([q[:fr_cutoff] for q in fr_ids])

In [14]:
val_idx = np.random.rand(len(en_arr)) < 0.2
en_tr, fr_tr = en_arr[~val_idx], fr_arr[~val_idx]
en_val, fr_val = en_arr[val_idx], fr_arr[val_idx]

In [15]:
trn_ds = Seq2SeqDataset(fr_tr, en_tr)
val_ds = Seq2SeqDataset(fr_val, en_val)

In [16]:
bs = 5
tr_samp = SortishSampler(en_tr, key=lambda x: len(en_tr[x]), bs=bs)
val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))

In [17]:
trn_dl = DataLoader(trn_ds, batch_size=bs, sampler=tr_samp, transpose=True, transpose_y=True,
                    pre_pad=False, pad_idx=1, num_workers=1)
val_dl = DataLoader(val_ds, batch_size=bs, sampler=val_samp, transpose=True, transpose_y=True, 
                   pre_pad=False, pad_idx=1, num_workers=1)
md = ModelData(TMP, trn_dl, val_dl)

In [18]:
len(fr_itos)

24793

## Create Seq2SeqRNN

In [19]:
def create_embedding(itos, ft_vecs):
    vs = len(itos)
    emb_sz = list(en_vecs.values())[0].shape[0]
    miss = []
    emb = nn.Embedding(vs, emb_sz)
    wgt = emb.weight.data
    for i, w in enumerate(itos):
        try:
            wgt[i] = torch.from_numpy(ft_vecs[w])
        except KeyError:
            miss.append(w)
    print(len(miss))
    return emb

In [26]:
class Seq2SeqRNN(nn.Module):
    def __init__(self, en_itos, fr_itos, enc_hidden, dec_hidden, en_cutoff, nl=2):
        super(Seq2SeqRNN, self).__init__()
        en_vs = len(en_itos)
        fr_vs = len(fr_itos)
        emb_sz = list(en_vecs.values())[0].shape[0]
        self.enc_emb = create_embedding(fr_itos, fr_vecs)
        self.enc_drop = nn.Dropout(0.15)
        self.enc_gru = nn.GRU(emb_sz, enc_hidden, num_layers=nl, dropout=0.25)
        
        self.dec_emb = create_embedding(en_itos, en_vecs)
        self.dec_drop = nn.Dropout(0.15)
        self.dec_gru = nn.GRU(emb_sz, dec_hidden, num_layers=nl, dropout=0.25)
        self.out = nn.Linear(dec_hidden, fr_vs)
        
        self.enc_hidden = enc_hidden
        self.dec_hidden = dec_hidden
        self.en_cutoff = en_cutoff
        
    def forward(self, x):
        seq_len = x.size()[0]
        bs = x.size()[1]
        x = self.enc_emb(x)
        x = self.enc_drop(x)
        enc_out, enc_h = self.enc_gru(x)
        pdb.set_trace()
        
        dec_h = to_gpu(torch.zeros(seq_len, bs, self.dec_hidden)) # Check how Jeremy deals with this
        for i in range(len(en_cutoff)):
            dec_inp = enc_out[i]
            
        return x

In [None]:
nh = 256
s2s = to_gpu(Seq2SeqRNN(en_itos, fr_itos, nh, nh, en_cutoff))
opt_fn = optim.Adam(s2s.parameters(), lr=0.001)

In [30]:
fit(s2s, md, 1, opt=opt_fn, crit=F.cross_entropy)

HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))

  0%|          | 0/8042 [00:00<?, ?it/s]> <ipython-input-26-2cee549ffcf2>(28)forward()
-> dec_h = to_gpu(torch.zeros(seq_len, bs, self.dec_hidden)) # Check how Jeremy deals with this
(Pdb) l
 23  	        x = self.enc_emb(x)
 24  	        x = self.enc_drop(x)
 25  	        enc_out, enc_h = self.enc_gru(x)
 26  	        pdb.set_trace()
 27  	
 28  ->	        dec_h = to_gpu(torch.zeros(seq_len, bs, self.dec_hidden)) # Check how Jeremy deals with this
 29  	        for i in range(len(en_cutoff)):
 30  	            dec_inp = enc_out[i]
 31  	
 32  	        return x
[EOF]

(Pdb) n
> <ipython-input-26-2cee549ffcf2>(29)forward()
-> for i in range(len(en_cutoff)):
(Pdb) enc_out.size()
torch.Size([31, 5, 256])
(Pdb) enc_h.size()
torch.Size([2, 5, 256])
(Pdb) q


BdbQuit: 