In [1]:
from torch import nn
from lib.utils import *
from lib.model import *
import torch
import numpy as np
DTYPE = torch.double
class LSTMTextGeneration(Module):
    def __init__(self, input_dim, output_dim, hid_dim):
        super().__init__()  
        self.rnn = nn.LSTM(input_size=input_dim, hidden_size=hid_dim)
        self.dense = nn.Linear(hid_dim, output_dim)
    def forward(self, x):
        
        #src = [batch size, input len, input dim]
        y = self.rnn(x)[0][:,-1,:]
        y = self.dense(y)
        # output = nn.Softmax(dim=-1)(y)
        output = y

        return output
        
import torchtext
d = torchtext.datasets.WikiText2()
def filt(string):
    if string.strip() == '':
        return False
    filt = ['=', '<', '>', '[']
    for s in string:
        if s in filt:
            return False
    return True


l = [string for string in list(iter(d[0]))+list(iter(d[1]))+list(iter(d[2])) if filt(string)]
# with open('quotes.txt', 'w', encoding='utf-8') as f:
#     f.write(''.join(l))
# text = open('quotes.txt', 'r', encoding='utf-8').read()
import string
with open('shakespeare.txt', encoding='utf-8') as f:
    text = f.read()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# text = ''.join([c for c in text.lower() if c in string.printable])
chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_int = {c: i for i, c in enumerate(chars)}
int_char = {i: c for i, c in enumerate(chars)}
chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}

maxlen = 40
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - maxlen, step):
    sentences.append(text[i: i + maxlen])
    next_chars.append(text[i + maxlen])
print(f'num sequences: {len(sentences)}')
print('corpus length:', len(text))

total chars: 65
total chars: 65
num sequences: 371785
corpus length: 1115394


In [3]:
x = np.zeros((len(sentences), maxlen, len(chars)))
y = np.zeros((len(sentences), len(chars)))
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

In [4]:
model = LSTMTextGeneration(input_dim=65,hid_dim=128, output_dim=len(chars),).double().to('cuda:3')
model.count_parameters()
dataset = Dataset(x,y , dtype=DTYPE, device='cuda:3')
# dataset.Y =dataset.Y.softmax(dim=1)
train_data = torch.utils.data.DataLoader(dataset, batch_size=128,drop_last=True)
test_data = torch.utils.data.DataLoader(dataset[:500], batch_size=128,drop_last=True)

The model has 108,225 trainable parameters


In [5]:
model.load_state_dict(torch.load('saved_model/text_generation/saved_model.pt'))

<All keys matched successfully>

In [6]:
# train_model(name='text_generation',model=model,train_data=train_data, test_data=test_data, criterion=nn.CrossEntropyLoss())

In [14]:
def sample(preds, temperature=1.0):
    """
    Helper function to sample an index from a probability array
    
    Parameters:
        preds: numpy array of predicted probabilities
        temperature: controls the diversity when picking from preds
    """
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
import sys
from scipy.special import softmax
def generate_text(diversity=0.2, sentence=None, start_index=None):
    """
    Generate text using trained model
    
    Parameters:
        diversity: controls randomness of texts, higher = more variety
        sentence: starting sentence as seed
        start_index: starting index in text as seed
    """
    print(f'----- diversity: {diversity}')

    generated = ''
    if start_index is None:
        start_index = np.random.randint(0, len(text)-maxlen)
    if sentence is None:
        sentence = text[start_index: start_index + maxlen]
    else:
        assert len(sentence) > maxlen, f'Need at least {maxlen} characters to start'
        sentence = sentence[:maxlen]
    generated += sentence
    print(f'----- Generating with seed: \n  "{sentence}" \n')
    sys.stdout.write(generated)

    for i in range(400):
        x_pred = np.zeros((1, maxlen, len(chars)))
        for t, char in enumerate(sentence):
            x_pred[0, t, char_indices[char]] = 1.

        preds = model(torch.tensor(x_pred, dtype=DTYPE, device='cuda:3'))[0].detach().cpu().numpy()
        preds = softmax(preds)
        next_index = sample(preds, diversity)

        # next_index = np.argmax(preds)

        next_char = indices_char[next_index]

        sentence = sentence[1:] + next_char

        sys.stdout.write(next_char)
        sys.stdout.flush()
    print()
    return sentence

In [16]:
start_index = 321
generate_text(diversity=1, sentence='This is not a good one to learn arbr ary.')

----- diversity: 1
----- Generating with seed: 
  "This is not a good one to learn arbr ary" 

This is not a good one to learn arbr ary.
Fithe omenfors, mesir d mppar,

P n.--
Fhad ow , gombu ime, thoupool akiverngowldlle Pe,


Whe LIatinchet t gin cku hunlore t ha ime whyo,
Fos angerd pin cemy'ed at ofo t s?
LOSAShext, chy ln's hendent, sur athe, athr d, t at yot'sools akes'BAS oversped hery,
Hor, akintrth mblor,

ANThinthedyonorvivo wowher, wll ng,
SATo orge ar:
ONThis y shinf hurdouwarevititreechod
ESSPatckequpe mupod ONLONTan


'ititreechod\nESSPatckequpe mupod ONLONTan'

In [None]:
preds