# Neural text generation

In [1]:
from seq2seq import *

In [2]:
path = Config().data_path()/'giga-fren'/'giga-fren'
data = load_data(path)
model_path = Config().model_path()
emb_enc = torch.load(model_path/'fr_emb.pth')
emb_dec = torch.load(model_path/'en_emb.pth')

In [3]:
class Seq2SeqRNN_attn(nn.Module):
    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.emb_enc,self.emb_dec = emb_enc,emb_dec
        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
                 
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, 
                              batch_first=True, bidirectional=True)
        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)
        self.out.weight.data = self.emb_dec.weight.data
        
        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)
        self.V =  self.init_param(self.emb_sz_dec)
        
    def encoder(self, bs, inp):
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.gru_enc(emb, 2*h)
        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()
        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)
        hid = self.out_enc(pre_hid)
        return hid,enc_out
    
    def decoder(self, dec_inp, hid, enc_att, enc_out):
        hid_att = self.hid_att(hid[-1])
        u = torch.tanh(enc_att + hid_att[:,None])
        attn_wgts = F.softmax(u @ self.V, 1)
        ctx = (attn_wgts[...,None] * enc_out).sum(1)
        emb = self.emb_dec(dec_inp)
        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)
        outp = self.out(self.out_drop(outp[:,0]))
        return hid, outp
        
    def forward(self, inp, targ=None):
        bs, sl = inp.size()
        hid,enc_out = self.encoder(bs, inp)
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        enc_att = self.enc_att(enc_out)
        
        res = []
        for i in range(self.out_sl):
            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
            res.append(outp)
            dec_inp = outp.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: continue
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)
    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))

In [4]:
model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=30))

In [5]:
learn.fit_one_cycle(5, 3e-3)

epoch,train_loss,valid_loss,seq2seq_acc,time
0,2.353132,3.401374,0.591976,01:13
1,1.92795,3.190605,0.580701,01:07
2,1.717998,3.750251,0.509902,01:05
3,1.527029,3.694459,0.50927,01:06
4,1.32362,3.821543,0.49899,01:06


In [36]:
#learn.save('52')

In [5]:
learn.load('52');

In [6]:
def preds_acts(learn, ds_type=DatasetType.Valid):
    "Same as `get_predictions` but also returns non-reconstructed activations"
    learn.model.eval()
    ds = learn.data.train_ds
    rxs,rys,rzs,xs,ys,zs = [],[],[],[],[],[] # 'r' == 'reconstructed'
    with torch.no_grad():
        for xb,yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(xb)
            for x,y,z in zip(xb.cpu(),yb.cpu(),out.cpu()):
                rxs.append(ds.x.reconstruct(x))
                rys.append(ds.y.reconstruct(y))
                preds = z.argmax(1)
                rzs.append(ds.y.reconstruct(preds))
                for a,b in zip([xs,ys,zs],[x,y,z]): a.append(b)
    return rxs,rys,rzs,xs,ys,zs

In [7]:
learn.model=learn.model.cuda()
rxs,rys,rzs,xs,ys,zs = preds_acts(learn)

In [43]:
idx=709
rx,ry,rz = rxs[idx],rys[idx],rzs[idx]
x,y,z = xs[idx],ys[idx],zs[idx]
rx,ry,rz

(Text [   2   24   20   12  266   15   12 1584  726   10  220   26  274   12 1984   16  800  186   17   13   87   21   14
   130    9],
 Text [   2   11   16   10  578   14   10  283   20    0   10 1038   12   84  825   13    0  108    9],
 Text [  2  11  16  10 283  14 578  13  10 279  12  84 825  51  87  20  87   9])

In [44]:
def weird_print(*strs): print('-------',*strs,'-------',sep='\n--')

In [45]:
weird_print(rx,ry,rz)

-------
--xxbos quels sont les critères et les démarches susceptibles de mettre en valeur les contributions des chercheurs canadiens à la recherche pour le développement ?
--xxbos what are the approaches and the criteria for xxunk the contributions of canadian researchers to xxunk development ?
--xxbos what are the criteria and approaches to the value of canadian researchers from research for research ?
---------


In [46]:
def select_topk(outp, k=5):
    probs = F.softmax(outp,dim=-1)
    vals,idxs = probs.topk(k, dim=-1)
    return idxs[torch.randint(k, (1,))]

From [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).

In [47]:
from random import choice

def select_nucleus(outp, p=0.5):
    probs = F.softmax(outp,dim=-1)
    idxs = torch.argsort(probs, descending=True)
    res,cumsum = [],0.
    for idx in idxs:
        res.append(idx)
        cumsum += probs[idx]
        if cumsum>p: return idxs.new_tensor([choice(res)])

In [88]:
def decode(self, inp,p=0.5,nucleus=True):
    inp = inp[None]
    bs, sl = inp.size()
    hid,enc_out = self.encoder(bs, inp)
    dec_inp = inp.new_zeros(bs).long() + self.bos_idx
    enc_att = self.enc_att(enc_out)

    res = []
    for i in range(self.out_sl):
        hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
        dec_inp = select_nucleus(outp[0], p=p) if nucleus else select_topk(outp[0], k=int(20*p)+1)
#         dec_inp = select_topk(outp[0], k=2)
        res.append(dec_inp)
        if (dec_inp==self.pad_idx).all(): break
    return torch.cat(res)

In [89]:
def predict_with_decode(learn, x, y,p=0.5,nucleus=True):
    learn.model.eval()
    ds = learn.data.train_ds
    with torch.no_grad():
        out = decode(learn.model, x, p,nucleus)
        rx = ds.x.reconstruct(x)
        ry = ds.y.reconstruct(y)
        rz = ds.y.reconstruct(out)
    return rx,ry,rz

In [90]:
learn.model=learn.model.cpu()

In [91]:
ps=[0.00001,0.1,0.3,0.5,0.99] #for nuclues is what to sum to, for topk is p*20

In [94]:
for p in ps: 
    rx,ry,rz = predict_with_decode(learn, x, y,p=p)
    rx2,ry2,rz2 = predict_with_decode(learn, x, y, nucleus=False,p=p)
    weird_print(rz)
    weird_print(rz2)

-------
--xxbos what are the criteria and approaches to the value of canadian researchers from research for research ?
---------
-------
--xxbos what are the criteria and approaches to the value of canadian researchers from research for research ?
---------
-------
--xxbos what are the criteria and approaches to the future of researchers from researchers to research development ?
---------
-------
--are are xxbos what processes , and processes can make to the of the of aboriginal communities from the research to the future development ?
---------
-------
--xxbos what are the criteria and indicators for the value of canadian researchers from research for the development ?
---------
-------
--, when are these costs associated by the research of canada research to developing developing development development ?
---------
-------
--xxbos what are the criteria and factors that to measure our activity of the researchers for research ?
---------
-------
--and what were the results ? to the de

Notice how we lose 'xxbos' at the beginning of the sentence very early for topk, but never lose it for nucleus. 