In [1]:
import argparse
import math
import os
import random
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable
import pandas as pd

from dataset import TimeMachineData, TimeMachineDataset
from model import GRU, RNN, LSTM

In [2]:
def batchify(data, bsz):
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

In [3]:
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len].reshape(-1)
    return data, target

In [4]:
# device = f"cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"
temperature = 1.0
bptt = 35

In [5]:
timemachine = TimeMachineData()

In [6]:
ntokens = timemachine.vocab_size

In [7]:
args_model = "lstm"

In [8]:
if args_model == "rnn":
    model = RNN(ntokens).to(device)
elif args_model == "gru":
    model = GRU(ntokens).to(device)
elif args_model == "lstm":
    model = LSTM(ntokens).to(device)
else:
    raise ValueError("Invalid model argument: {}".format(args_model))
model.load_state_dict(torch.load(f"{args_model}_thetimemachine.pth"))
model = model.to(device)
model.eval()

LSTM(
  (embed): Embedding(94, 128)
  (dropout1): Dropout(p=0.5, inplace=False)
  (lstm): LSTM(128, 128, num_layers=2, dropout=0.3)
  (dropout2): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=128, out_features=94, bias=True)
)

In [9]:
text_path="../data/TheTimeMachine/35-0.txt"
dict_path="../data/TheTimeMachine/char.csv"

In [10]:
word_df = pd.read_csv(dict_path)
char2idx = dict(zip(word_df["word"], word_df.index))
idx2char = dict(zip(word_df.index, word_df["word"]))

# multi char input

In [11]:
text = "A queer thing I soon"
text = [char2idx["<BOS>"]] + [char2idx[char] for char in text.strip()]
data = [text]
data = torch.from_numpy(np.array(sum(data, []))).long()
data = torch.unsqueeze(data, 0)

In [12]:
data

tensor([[ 0, 31,  3, 51, 15,  4,  4, 11,  3,  5, 12,  9,  8, 19,  3, 26,  3, 10,
          7,  7,  8]])

In [13]:
data = data.to(device)

In [14]:
data

tensor([[ 0, 31,  3, 51, 15,  4,  4, 11,  3,  5, 12,  9,  8, 19,  3, 26,  3, 10,
          7,  7,  8]])

In [15]:
hidden = model.init_hidden(data.shape[1])
if args_model == "lstm":
    cell = model.init_hidden(data.shape[1])
result = data.cpu().numpy().tolist()[0]
for i in range(100):
    with torch.no_grad():
        if args_model == "lstm":
            output, hidden, cell = model(data, hidden, cell)
        else:
            output, hidden = model(data, hidden)
        
        word_weights = output.squeeze().data.div(1.0).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1).squeeze()
        
        result += word_idx.cpu().numpy().tolist()
        data = word_idx.reshape(1, -1)
        
        if 0 in word_idx.cpu().numpy().tolist():
            break

In [16]:
"".join([idx2char[x] for x in result])

'<BOS>A queer thing I soonPts(:8O_r”esitç z.lcd.rh;t”s h<EOS> islotu kt   o h ata tTlivh iœmdhtuIawbhltr lyeehb!aoern ieiedhaey  ras lnfedFv n  etf gmytimuk,n een<EOS>ats ewiy <EOS>sato ea  u<BOS>ze w—ag,a<BOS>iren'

# single char input

In [17]:
for j in range(10):
    text = [0]
    data = [text]
    data = torch.from_numpy(np.array(sum(data, []))).long()
    data = torch.unsqueeze(data, 0)
    data = data.to(device)
    
    hidden = model.init_hidden(1)
    if args_model == "lstm":
        cell = model.init_hidden(1)
    result = [0]
    for i in range(100):
        if args_model == "lstm":
            output, hidden, cell = model(data, hidden, cell)
        else:
            output, hidden = model(data, hidden)
        
        # word_idx = torch.argmax(output, dim=2).item()
        word_weights = output.squeeze().data.div(1.0).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        
        data.data.fill_(word_idx)
        result.append(word_idx.item())
        if word_idx == 1:
            break
    print("".join([idx2char[x] for x in result]))

<BOS>xond ay ay Wime son ferer this a<EOS>
<BOS><EOS>
<BOS>qlimin was a lait-et the wreow in the<EOS>
<BOS>Xcone thauth to tituin,<EOS>
<BOS>çenlofedWal<EOS>
<BOS>flamkeadet clareudimly stust and sumile in speens ever thit had at<EOS>
<BOS>conm burdy vemmess, and down ald in the breled the qapuinny To lent at all hamk<EOS>
<BOS>-hopomot upot read me avibs in not I wharet that evind tremen I wand almy the Tinging, werled some I
<BOS>heo<EOS>
<BOS>bHemAcs Urfurdfised muce so, whe me<EOS>
