In [59]:
from fastai.text.learner import *
from fastai.basic_data import DataBunch
from torch.utils.data import Dataset, DataLoader
import torch
from torch import nn
import numpy as np
import pickle
from collections import Counter
import math
%matplotlib inline

In [60]:
#tokenize the dataset
#export
#special tokens
UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ = "xxunk xxpad xxbos xxeos xxrep xxwrep xxup xxmaj".split()

def sub_n(t):
    "Replaces the \n by space"
    re_br = re.compile(r'\n', re.IGNORECASE)
    return re_br.sub(" ", t)

def spec_add_spaces(t):
    "Add spaces around / # , . ; :"
    return re.sub(r'([/#,.;:])', r' \1 ', t)

def rm_useless_spaces(t):
    "Remove multiple spaces"
    return re.sub(' {2,}', ' ', t)

def replace_rep(t):
    "Replace repetitions at the character level: cccc -> TK_REP 4 c"
    def _replace_rep(m:str):
        c,cc = m.groups()
        return f' {TK_REP} {len(cc)+1} {c} '
    re_rep = re.compile(r'(\S)(\1{3,})')
    return re_rep.sub(_replace_rep, t)

def replace_section_number(t):
    "Replace section numbers by NEWSECTION"
    return re.sub(r'(\d+:\d+)', 'NEWSECTION', t)

def sep_special(t):
    return re.sub(r'[\.,]', '  ', t)
    
def fixup_text(x):
    "Various messy things we've seen in documents"
    re1 = re.compile(r'  +')
    x = x.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
        '<br />', "\n").replace('\\"', '"').replace('<unk>',UNK).replace(' @.@ ','.').replace(
        ' @-@ ','-').replace('\\', ' \\ ')
    return re1.sub(' ', html.unescape(x))
    
default_pre_rules = [fixup_text,  replace_section_number, spec_add_spaces, replace_rep
                     , sub_n]
default_spec_tok = [UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ]


In [61]:
#process the data
def read_and_tokenize():
    data = open('input_bible.txt', 'r').read()
    data = data.lower()
    for rule in default_pre_rules:
        data = rule(data)
    data = rm_useless_spaces(data)
    return data

df = read_and_tokenize().split()
vocab = Counter(df)
vocab = [k for k in vocab.keys()]
word_to_ix  = {wr:i for i,wr in enumerate(vocab)}
ix_to_word = {i:wr for i,wr in enumerate(vocab)} 

In [79]:
#create a data loader
from torch.utils.data import Dataset, DataLoader
#create a custom data dataset / dataloader
class bible_dataset(Dataset):

    def __init__(self, seq_len):
        self.df = read_and_tokenize().split()
        self.vocab = Counter(self.df)
        self.vocab = [k for k in self.vocab.keys()]
        self.word_to_ix  = {wr:i for i,wr in enumerate(self.vocab)}
        self.ix_to_word = {i:wr for i,wr in enumerate(self.vocab)}        
        self.seq_len = seq_len
        
    def __len__(self):
        return len(self.vocab)

    def __getitem__(self, idx):
        return (np.array([self.word_to_ix[wrd] for wrd in self.df[idx:idx+self.seq_len]]),
                np.array([self.word_to_ix[wrd] for wrd in self.df[idx+1:idx+self.seq_len+1]]))
seq_len=10
bible_df = DataLoader(bible_dataset(seq_len), 1)   

In [80]:
#one hot encode over the whole vocab
def one_hot_encode(sequence, dict_size, seq_len, batch_size):
    # Creating a multi-dimensional array of zeros with the desired output shape
    features = np.zeros((batch_size, seq_len, dict_size), dtype=np.float32)
    
    # Replacing the 0 at the relevant character index with a 1 to represent that character
    for i in range(batch_size):
        for u in range(seq_len):
            features[i, u, sequence[i][u]] = 1
    return features


In [81]:
from torch.optim.lr_scheduler import _LRScheduler
class CyclicLR(_LRScheduler):
    def __init__(self, optimizer, schedule, last_epoch=-1):
        assert callable(schedule)
        self.schedule = schedule
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [self.schedule(self.last_epoch, lr) for lr in self.base_lrs]
    
def cosine(t_max, eta_min=0):  
    def scheduler(epoch, base_lr):
        t = epoch % t_max
        return eta_min + (base_lr - eta_min)*(1 + math.cos(math.pi*t/t_max))/2
    
    return scheduler


In [82]:
from fastai.text import *
awd_lstm_clas_config = dict(emb_sz=400, n_hid=1150, n_layers=3, pad_token=1, qrnn=False, bidir=False, output_p=0.4,
                       hidden_p=0.2, input_p=0.6, embed_p=0.1, weight_p=0.5, tie_weights=None, out_bias=1)
awd_lstm_model = get_language_model(arch=AWD_LSTM,config=awd_lstm_clas_config, vocab_sz=len(bible_df))
old_wgts  = torch.load('models/pretrained/lstm_wt103.pth', map_location='cpu')
old_vocab = pickle.load(open('models/pretrained/itos_wt103.pkl', 'rb'))
print(len(old_vocab))

60002


In [83]:
#match vocab overlap 
def match_embeds(old_wgts, old_vocab, new_vocab):
    wgts = old_wgts['0.encoder.weight']
    bias = old_wgts['1.decoder.bias']
    wgts_m,bias_m = wgts.mean(dim=0),bias.mean()
    new_wgts = wgts.new_zeros(len(word_to_ix), wgts.size(1))
    new_bias = bias.new_zeros(len(word_to_ix))
    otoi = {v:k for k,v in enumerate(old_vocab)}
    for i,w in enumerate(word_to_ix): 
        if w in otoi:
            idx = otoi[w]
            new_wgts[i],new_bias[i] = wgts[idx],bias[idx]
        else: new_wgts[i],new_bias[i] = wgts_m,bias_m
    old_wgts['0.encoder.weight']    = new_wgts
    old_wgts['0.encoder_dp.emb.weight'] = new_wgts
    old_wgts['1.decoder.weight']    = new_wgts
    old_wgts['1.decoder.bias']      = new_bias
    return old_wgts
wgts = match_embeds(old_wgts, old_vocab, word_to_ix)

In [85]:
awd_lstm_model.load_state_dict(wgts)

<All keys matched successfully>

In [87]:
#load the AWD LSTM model with the trained weights
def load_model():
    awd_lstm_lm_config = dict(emb_sz=400, n_hid=1150, n_layers=3, pad_token=1, qrnn=False, bidir=False, output_p=0.1,
                       hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2, tie_weights=True, out_bias=True)
    awd_lstm_model = get_language_model(arch=AWD_LSTM,config=awd_lstm_lm_config, vocab_sz=len(bible_df))
    wgts  = torch.load('pretrained_unfrozen.pth', map_location='cpu')
    awd_lstm_model.load_state_dict(wgts)
    return awd_lstm_model
awd_lstm_model = load_model()

In [213]:
max_length = 35

# Sample from a category and starting word
def sample(start_word='NEWSECTION'):
    awd_lstm_model.reset()
    with torch.no_grad():  # no need to track history in sampling
        input = torch.from_numpy(np.array([[word_to_ix[start_word]]]))
        #make the newsection word realistic
        if start_word == 'NEWSECTION':
            start_word = str(random.choice(range(0,20))) + ':' + str(random.choice(range(0,100)))
        output_name = start_word + ' '

        for i in range(max_length):
            output, hidden1,hidden2 = awd_lstm_model(input.long())
            topv, topi = output.topk(1)
            topi = topi[0][0]
            if topi == len(vocab) - 1:
                break
            else:
                next_word = vocab[topi]
                output_name += next_word + ' '
            input = torch.from_numpy(np.array([[word_to_ix[next_word]]]))
        return output_name     

In [214]:
sample('jesus')

'jesus , who was a member of the two . the first . the new head of the new south , the new of the new league . the new new new life , which was '

In [None]:
pre_words = ['NEWSECTION', 'the']

def sample_ext():

def sentence_tensor():

    tensor = torch.zeros(len(pre_words))
    for p_word in range(len(pre_words)):
        tensor[p_word] = word_to_ix[pre_words[p_word]]

In [212]:
# torch.from_numpy(np.array([[word_to_ix['NEWSECTION']]]))
[[word_to_ix['NEWSECTION']]]
pre_words = ['NEWSECTION', 'the']
tensor = torch.zeros(len(pre_words))
for p_word in range(len(pre_words)):
    tensor[p_word] = word_to_ix[pre_words[p_word]]
tensor    

tensor([0., 2.])

In [206]:
tensor = torch.zeros(len(pre_words))
for p_word in range(len(pre_words)):
#     word = pre_words[p_word]
    tensor[p_word] = word_to_ix[pre_words[p_word]]

KeyError: 0

In [210]:
tensor = torch.zeros(len(pre_words))
# tensor[word_to_ix[p_word]] = [word_to_ix[p_word]]
pre_words[p_word], word_to_ix[pre_words[p_word]]

('NEWSECTION', 0)

In [173]:
import random
print(str(random.choice(range(0,20))) + ':' + str(random.choice(range(0,100))))

13:48
