In [1]:
from fastai.nlp import *
from fastai.lm_rnn import *
from fastai import sgdr
from torchtext import vocab, data

import pdb

In [30]:
class CharSeqStatefulLSTM(nn.Module):
    """Character-level LSTM language model: embedding -> LSTM -> linear -> log-softmax.

    Args:
        vocab_size: number of tokens in the vocabulary (embedding rows and output classes).
        n_fac: embedding dimension.
        bs: batch size used to allocate the initial hidden state; forward() re-allocates
            it whenever the incoming batch size differs.
        nl: number of stacked LSTM layers.
        hidden_size: LSTM hidden size. When omitted, the notebook-level global
            ``n_hidden`` is used (read once, here in __init__), which is what this class
            always used before the parameter existed.

    Note: despite the name, the hidden state is NOT carried across calls. ``self.h`` is
    only ever (re)set to zeros, so every forward pass starts from a fresh state. Callers
    that want context must feed the whole context window on each call.
    """
    def __init__(self, vocab_size, n_fac, bs, nl, hidden_size=None):
        super().__init__()
        self.vocab_size, self.nl = vocab_size, nl
        # Fall back to the global only when no explicit size is given, so the class can be
        # defined before the hyper-parameter cell that sets n_hidden.
        self.n_hidden = n_hidden if hidden_size is None else hidden_size
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.LSTM(n_fac, self.n_hidden, nl, dropout=0.5)
        self.l_out = nn.Linear(self.n_hidden, vocab_size)
        self.init_hidden(bs)

    def forward(self, cs, **kwargs):
        """Return log-probabilities over the vocabulary, flattened to (-1, vocab_size).

        `cs` is presumably a (seq_len, batch) tensor of token ids; cs[0].size(0) is read
        as the batch size. TODO: confirm.
        """
        bs = cs[0].size(0)
        if self.h[0].size(1) != bs: self.init_hidden(bs)
        self.rnn.flatten_parameters()
        # The hidden state is always moved to the CPU; this notebook runs inference on CPU.
        self.h = (self.h[0].cpu(), self.h[1].cpu())
        outp, _ = self.rnn(self.e(cs), self.h)
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)

    def init_hidden(self, bs):
        """Reset (h, c) to zeros of shape (nl, bs, hidden size)."""
        self.h = (V(torch.zeros(self.nl, bs, self.n_hidden)),
                  V(torch.zeros(self.nl, bs, self.n_hidden)))

In [31]:
class CharSeqStatefulLSTM512(nn.Module):
    """Character-level LSTM language model (wide variant used for the gen-2 model).

    Architecture: embedding -> LSTM -> linear -> log-softmax.

    Args:
        vocab_size: number of tokens in the vocabulary (embedding rows and output classes).
        n_fac: embedding dimension.
        bs: batch size used to allocate the initial hidden state; forward() re-allocates
            it whenever the incoming batch size differs.
        nl: number of stacked LSTM layers.
        hidden_size: LSTM hidden size. When omitted, the notebook-level global
            ``n_hidden2`` is used (read once, here in __init__), which is what this class
            always used before the parameter existed.

    Note: despite the name, the hidden state is NOT carried across calls. ``self.h`` is
    only ever (re)set to zeros, so every forward pass starts from a fresh state.
    """
    def __init__(self, vocab_size, n_fac, bs, nl, hidden_size=None):
        super().__init__()
        self.vocab_size, self.nl = vocab_size, nl
        # Fall back to the global only when no explicit size is given, so the class can be
        # defined before the hyper-parameter cell that sets n_hidden2.
        self.n_hidden = n_hidden2 if hidden_size is None else hidden_size
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.LSTM(n_fac, self.n_hidden, nl, dropout=0.5)
        self.l_out = nn.Linear(self.n_hidden, vocab_size)
        self.init_hidden(bs)

    def forward(self, cs, **kwargs):
        """Return log-probabilities over the vocabulary, flattened to (-1, vocab_size).

        `cs` is presumably a (seq_len, batch) tensor of token ids; cs[0].size(0) is read
        as the batch size. TODO: confirm.
        """
        bs = cs[0].size(0)
        if self.h[0].size(1) != bs: self.init_hidden(bs)
        self.rnn.flatten_parameters()
        # The hidden state is always moved to the CPU; this notebook runs inference on CPU.
        self.h = (self.h[0].cpu(), self.h[1].cpu())
        outp, _ = self.rnn(self.e(cs), self.h)
        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)

    def init_hidden(self, bs):
        """Reset (h, c) to zeros of shape (nl, bs, hidden size)."""
        self.h = (V(torch.zeros(self.nl, bs, self.n_hidden)),
                  V(torch.zeros(self.nl, bs, self.n_hidden)))

In [32]:
# Dataset roots, one per model generation.
PATH, PATH2, PATH3 = 'data/proverbs/', 'data/proverbs2/', 'data/proverbs3/'
# Split sub-directories shared by every dataset root.
TRN_PATH, VAL_PATH = 'train/', 'valid/'

TRN, VAL = f'{PATH}{TRN_PATH}', f'{PATH}{VAL_PATH}'
TRN2, VAL2 = f'{PATH2}{TRN_PATH}', f'{PATH2}{VAL_PATH}'
TRN3, VAL3 = f'{PATH3}{TRN_PATH}', f'{PATH3}{VAL_PATH}'

In [33]:
PATH, TRN, VAL  # sanity-check the gen-0 dataset root and its train/valid directories

('data/proverbs/', 'data/proverbs/train/', 'data/proverbs/valid/')

In [34]:
# Character-level field: lowercased text, one token per character.
TEXT = data.Field(tokenize=list, lower=True)

# Shared hyper-parameters; n_hidden is the LSTM hidden size for the gen-0/gen-1 models.
bs = 64
bptt = 8
n_fac = 42
n_hidden = 128

TEXT

<torchtext.data.field.Field at 0x1adeb1e22b0>

In [35]:
# Separate character-level field for the gen-2 (proverbs3) data, so it gets its own vocab.
TEXT3 = data.Field(lower=True, tokenize=list)
# Hidden size for the gen-2 model. bs, bptt and n_fac are defined once, in the TEXT cell;
# re-assigning them here would silently override any change made there.
n_hidden2 = 512

TEXT3

<torchtext.data.field.Field at 0x1adeb1e2668>

In [36]:
FILES = {'train': TRN_PATH, 'validation': VAL_PATH, 'test': VAL_PATH}
# Gen-0 data built with TEXT. min_freq=3 is presumably the minimum character count for
# vocab inclusion (forwarded by fastai). Confirm.
md = LanguageModelData.from_text_files(PATH, TEXT, bs=bs, bptt=bptt,
                                       min_freq=3, **FILES)

In [37]:
md  # inspect the gen-0 LanguageModelData object

<fastai.nlp.LanguageModelData at 0x1adeb5c2080>

In [38]:
# bs=256 only sizes the initial hidden state; forward() re-sizes it to each incoming batch.
m = CharSeqStatefulLSTM(md.nt, n_fac, bs=256, nl=2)


In [39]:
# Load the gen-0 weights, keeping every deserialized storage on the CPU.
m.load_state_dict(
    torch.load(f'{PATH}models/gen_0_dict',
               map_location=lambda storage, _location: storage)
)


In [40]:
m = m.cpu()  # keep parameters on the CPU, matching forward(), which moves its hidden state to the CPU


In [41]:
m.eval()  # inference mode: disables the LSTM's dropout=0.5 during sampling

CharSeqStatefulLSTM(
  (e): Embedding(37, 42)
  (rnn): LSTM(42, 128, num_layers=2, dropout=0.5)
  (l_out): Linear(in_features=128, out_features=37, bias=True)
)

In [42]:
FILES2 = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
# Gen-1 data. Pass FILES2 (previously defined but unused; FILES was passed by mistake).
# NOTE(review): the gen-0 field TEXT is reused here, and get_next() also numericalizes
# gen-1 input with TEXT. Confirm gen_1_dict was trained against TEXT's vocab.
md2 = LanguageModelData.from_text_files(PATH2, TEXT, **FILES2, bs=bs, bptt=bptt, min_freq=3)

m2 = CharSeqStatefulLSTM(md2.nt, n_fac, 256, 2)
m2.load_state_dict(torch.load(PATH2 + 'models/gen_1_dict', map_location=lambda storage, loc: storage))
m2.eval()

CharSeqStatefulLSTM(
  (e): Embedding(37, 42)
  (rnn): LSTM(42, 128, num_layers=2, dropout=0.5)
  (l_out): Linear(in_features=128, out_features=37, bias=True)
)

In [43]:
FILES3 = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
# Gen-2 data with its own field TEXT3. Pass FILES3 (previously defined but unused; FILES
# was passed by mistake).
md3 = LanguageModelData.from_text_files(PATH3, TEXT3, **FILES3, bs=bs, bptt=bptt, min_freq=3)

m3 = CharSeqStatefulLSTM512(md3.nt, n_fac, 256, 2)
m3.load_state_dict(torch.load(PATH3 + 'models/gen_2_dict', map_location=lambda storage, loc: storage))
m3.eval()

CharSeqStatefulLSTM512(
  (e): Embedding(59, 42)
  (rnn): LSTM(42, 512, num_layers=2, dropout=0.5)
  (l_out): Linear(in_features=512, out_features=59, bias=True)
)

In [44]:
def get_next(inp, gen):
    """Sample the character that follows `inp` using the model for generation `gen`.

    Args:
        inp: context string. It is fed to the model in full on every call, because the
            models never carry hidden state between calls.
        gen: 1 selects (m2, TEXT), 2 selects (m3, TEXT3). Any other value falls back to
            the gen-0 pair (m, TEXT).

    Returns:
        One vocabulary token (a character string), drawn with torch.multinomial from the
        model's predicted distribution at the last position of `inp`.
    """
    if gen == 1:
        sel_m, field = m2, TEXT
    elif gen == 2:
        sel_m, field = m3, TEXT3
    else:
        sel_m, field = m, TEXT
    # The Fields were built with lower=True. In torchtext, lowercasing happens in
    # Field.preprocess(), not in numericalize(). Without this step any uppercase input
    # character (e.g. the 'P' in 'People') would be looked up as the unknown token.
    if field.lower:
        inp = inp.lower()
    idxs = field.numericalize(inp, device=-1)
    # Presumably (1, len(inp)) -> (len(inp), 1), i.e. (seq_len, batch=1). Confirm.
    vpid = VV(idxs.transpose(0, 1)).cpu()
    p = sel_m(vpid)
    r = torch.multinomial(p[-1].exp(), 1)
    return field.vocab.itos[to_np(r)[0]]

In [45]:
def get_next_n(inp, n, gen):
    """Extend `inp` by up to `n` sampled characters from generation `gen`.

    Generation stops early, right after a '.' is produced. The context passed to
    get_next() is a sliding window of the same length as the original `inp`.

    Returns:
        The original `inp` followed by every generated character.
    """
    generated = inp
    window = inp
    for _ in range(n):
        nxt = get_next(window, gen)
        generated += nxt
        # Slide the window: drop the oldest character, append the newest one.
        window = window[1:] + nxt
        if nxt == '.':
            break
    return generated

In [51]:
get_next_n('People ', 1000, 2)  # up to 1000 chars from the gen-2 model, stopping at the first '.'; random (torch.multinomial), so output varies unless a seed is set

'People only the consists.'