In [None]:
# Automatically re-import edited modules (such as the local `util`) before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from fastai2.text.all import *
from fastai2.text.core import *
from fastai2.text.core import _join_texts
from fastai2.basics import *
from fastai.text.models.qrnn import QRNN, QRNNLayer
from util import *
import fasttext as ft

In [None]:
# Dataset root: the 'giga-fren' folder inside fastai's configured data directory.
path = Config().data_path/'giga-fren'
# List its contents; the CSV and the cached embedding files used below live here.
path.ls()

# Create the DataBunch

In [None]:
def tokenize_df(df, text_cols, n_workers=defaults.cpus, rules=None, mark_fields=None, out_col='text',
                tok_func=SpacyTokenizer, **tok_kwargs):
    """Tokenize the texts in `df[text_cols]` in parallel using `n_workers`.

    Args:
        df: source DataFrame.
        text_cols: column name, or collection of column names, holding the raw text.
        n_workers: number of parallel tokenization workers.
        rules: pre-tokenization rules; defaults to a copy of `defaults.text_proc_rules`.
        mark_fields: whether to insert field markers when joining the columns.
            When None, markers are used only if there is more than one text column.
        out_col: name of the output column that receives the token lists.
        tok_func: tokenizer class handed to `parallel_tokenize`.
        **tok_kwargs: forwarded to `parallel_tokenize`.

    Returns:
        `(res, counts)`. `res` is a copy of `df` with the `text_cols` columns removed and
        `out_col` set to the token lists. `counts` is a `Counter` over every token produced.
    """
    cols = L(text_cols)
    # Field markers are only needed to tell several joined columns apart.
    join_marks = len(cols) > 1 if mark_fields is None else mark_fields
    proc_rules = L(ifnone(rules, defaults.text_proc_rules.copy()))
    joined = _join_texts(df[cols], mark_fields=join_marks)

    # The results are sorted and then the second element of each item is kept. Presumably the
    # items are (row_index, tokens) pairs that may arrive out of order; this restores row order.
    pairs = L(parallel_tokenize(joined, tok_func, proc_rules, n_workers=n_workers, **tok_kwargs))
    toks = pairs.sorted().itemgot(1)

    keep = df.columns[~df.columns.isin(cols)]
    res = df[keep].copy()
    res[out_col] = toks
    return res, Counter(toks.concat())

In [None]:
# Load the English/French question pairs.
# N_SAMPLES caps the number of rows for fast iteration; set it to None to use the whole file.
N_SAMPLES = 100
df = pd.read_csv(path/'questions_easy.csv')
if N_SAMPLES is not None: df = df.iloc[:N_SAMPLES]
# Last expression of the cell, so the preview is actually displayed.
df.head()

In [None]:
# Tokenize the English column, then run the French pass on that result, so that `df_tok`
# ends up with both columns tokenized. `count` and `count_fr` hold each language's token counts.
df_tok, count = tokenize_df(df, text_cols="en", out_col="en")
df_tok, count_fr = tokenize_df(df_tok, text_cols="fr", out_col="fr")

In [None]:
# Preview a few tokenized rows instead of dumping the whole frame.
df_tok.head()

In [None]:
# A fixed seed keeps the train/valid split identical across Restart & Run All.
splits = RandomSplitter(seed=42)(range_of(df_tok))
# One pipeline per side of each pair: select the tokenized column, then map tokens to ids
# using a vocab built from that language's token counts.
# NOTE(review): `count`/`count_fr` were computed on all rows, including the validation split.
dsrc   = DataSource(df_tok,
                    splits=splits, tfms=[[attrgetter("en"), Numericalize(make_vocab(count))],
                                         [attrgetter("fr"), Numericalize(make_vocab(count_fr))]],
                    dl_type=SortedDL)

# TODO: change SortedDL to SortishDL
# TODO: create a Seq2Seq DataBunch class
# pad_fields=[0,1] pads both fields of each sample: the "en" source (0) and the "fr" target (1).
dbch   = dsrc.databunch(before_batch=lambda items: pad_input(items, pad_fields=[0,1]))

In [None]:
# Sanity check: decode and display two samples from a batch.
dbch.show_batch(max_n=2)

In [None]:
class Seq2SeqQRNN(nn.Module):
    """Sequence-to-sequence translation model with a QRNN encoder and a QRNN decoder.

    The encoder's final hidden state goes through dropout and is projected to the decoder's
    hidden size (the decoder embedding width). The decoder then decodes greedily: it starts
    from `bos_idx`, feeds back its own argmax, and runs for at most `max_len` steps.

    Args:
        emb_enc: source-language embedding module; `emb_enc.weight` is (src_vocab, enc_emb_dim).
        emb_dec: target-language embedding module; `emb_dec.weight` is (tgt_vocab, dec_emb_dim).
            Its weight Parameter is shared with the output projection.
        n_hid: hidden size of the encoder QRNN.
        max_len: maximum number of decoding steps.
        n_layers: number of layers in both the encoder and the decoder QRNN.
        p_inp: dropout on the encoder input embeddings.
        p_enc: dropout passed to the encoder QRNN.
        p_dec: dropout passed to the decoder QRNN.
        p_out: dropout on decoder outputs before the vocabulary projection.
        p_hid: dropout on the encoder hidden state before it is projected for the decoder.
        bos_idx: token id fed to the decoder at the first step.
        pad_idx: decoding stops early once every sequence in the batch predicts this id.
    """
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, 
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        emb_sz_enc, emb_sz_dec = emb_enc.weight.size(1), emb_dec.weight.size(1)
        self.emb_enc      = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder      = QRNN(emb_sz_enc, n_hid, n_layers=n_layers, dropout=p_enc)
        # The projected encoder state initialises the decoder QRNN, whose hidden size is
        # emb_sz_dec. The projection must therefore target the decoder width, not the encoder's.
        self.out_enc      = nn.Linear(n_hid, emb_sz_dec, bias=False)
        self.hid_dp       = nn.Dropout(p_hid)
        self.emb_dec      = emb_dec
        self.decoder      = QRNN(emb_sz_dec, emb_sz_dec, n_layers=n_layers, dropout=p_dec)
        self.out_drop     = nn.Dropout(p_out)
        self.out          = nn.Linear(emb_sz_dec, emb_dec.weight.size(0))
        # Tie the output projection to the decoder embedding by sharing the Parameter itself.
        # Assigning `.data` would leave two Parameters aliasing one storage, each with its own
        # gradient and optimizer state.
        self.out.weight   = self.emb_dec.weight

    def forward(self, inp):
        """Greedily translate a batch of source token ids.

        Args:
            inp: 2-D tensor of source token ids, indexed as (batch, seq_len).

        Returns:
            Logits of shape (batch, n_steps, tgt_vocab), where n_steps <= max_len. Decoding
            stops early once every sequence in the batch predicts `pad_idx`.
        """
        self.encoder.reset()
        self.decoder.reset()
        bs, _ = inp.size()
        hid   = self.initHidden(bs)
        emb   = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        # Project the encoder hidden state to the decoder's hidden size.
        hid   = self.out_enc(self.hid_dp(hid))

        # Every sequence starts decoding from the BOS token.
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        outs = []
        for _ in range(self.max_len):
            emb      = self.emb_dec(dec_inp).unsqueeze(1)   # one decoding step: (bs, 1, emb)
            out, hid = self.decoder(emb, hid)
            out      = self.out(self.out_drop(out[:,0]))
            dec_inp  = out.max(1)[1]                        # greedy: feed back the argmax token
            outs.append(out)
            if (dec_inp==self.pad_idx).all(): break
        return torch.stack(outs, dim=1)

    def initHidden(self, bs):
        "Zero initial encoder hidden state of shape (n_layers, bs, n_hid) on the model's device/dtype."
        return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)

In [None]:
# Run once: build the English (encoder) embedding from the fastText vectors and cache it to disk.
# en_vecs = ft.load_model(str((path/'cc.en.300.bin')))
# emb_enc = create_emb(en_vecs, dsrc.vocab[0])
# del en_vecs
# torch.save(emb_enc, path/'models'/'en_enc_emb.pth')

In [None]:
# Run once: build the French (decoder) embedding from the fastText vectors and cache it to disk.
# fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))
# emb_dec = create_emb(fr_vecs, dsrc.vocab[1])
# del fr_vecs
# torch.save(emb_dec, path/'models'/'fr_dec_emb.pth')

In [None]:
# Load the cached embedding modules written by the "run once" cells (English first, then French).
# NOTE: torch.load unpickles arbitrary objects, so only load files you created yourself.
emb_enc, emb_dec = (torch.load(path/'models'/fname) for fname in ('en_enc_emb.pth', 'fr_dec_emb.pth'))

In [None]:
# Encoder hidden size 256, at most 30 decoding steps, 2 QRNN layers on each side.
model = Seq2SeqQRNN(emb_enc, emb_dec, n_hid=256, max_len=30, n_layers=2)
# NOTE(review): an earlier variant also passed CorpusBLEU(len(dbch.vocab[1])) in `metrics`.
learn = Learner(dbch, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc])

In [None]:
# Train for 10 epochs at learning rate 1e-2 (fastai `Learner.fit(n_epoch, lr)`).
learn.fit(10,1e-2)