In [1]:
import torch
import nltk
from functools import reduce
from torch.utils.data import Dataset
from model import GRULanguageModel
from dataset import GRULanguageModelDataset, preprocess
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch import nn

In [2]:
# Make sure the NLTK 'punkt' package is available locally.
# NOTE(review): presumably required by the tokenization inside `dataset`/`preprocess` — confirm.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/jkfirst/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Toy one-sentence corpus used to exercise the dataset and dataloader below.
text = 'she sells sea shells by the sea shore'

In [4]:
# Build the dataset from the toy corpus in `text`.
dataset = GRULanguageModelDataset(text)
# Peek at the first item only: the loop exits after a single iteration.
for d in dataset:
    print(d)
    break

tensor([ 1,  4,  5,  6,  7,  8,  9,  6, 10,  2])


In [5]:
# Display the token -> index mapping exposed by the dataset.
dataset.vocab

{'<pad>': 0,
 '<s>': 1,
 '</s>': 2,
 '<unk>': 3,
 'she': 4,
 'sells': 5,
 'sea': 6,
 'shells': 7,
 'by': 8,
 'the': 9,
 'shore': 10}

In [6]:
def collate_fn(batch):
    """Combine variable-length sequences into a single padded tensor.

    Args:
        batch: list of tensors, one per dataset item (1-D index tensors
            in this notebook).

    Returns:
        A tensor with the batch dimension first. Sequences shorter than
        the longest one in the batch are filled at the end with
        ``pad_sequence``'s default padding value (0).
    """
    return pad_sequence(batch, batch_first=True)

In [7]:
# Serve the dataset in batches of up to 16 items; collate_fn pads each batch
# so its sequences share one length.
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=16)

In [8]:
# Pull a single batch to inspect what collate_fn produces.
for d in dataloader:
    print(d)
    print(d.shape)  # batch dimension first, then padded sequence length
    break

tensor([[ 1,  4,  5,  6,  7,  8,  9,  6, 10,  2]])
torch.Size([1, 10])


#### GRULanguageModel를 이용한 문장 생성하기

In [9]:
import pickle
from generate import generate_sentence_from_bos

In [10]:
# Load the pickled vocabulary; its length sets the model's output size below.
# NOTE: pickle can execute arbitrary code while loading, so only open a
# vocab.pickle that comes from a trusted source.
# The context manager closes the file handle as soon as loading finishes.
with open('vocab.pickle', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [11]:
# Define the model and load its trained weights.
# hidden_size has to match the value the checkpoint was trained with,
# otherwise load_state_dict raises a shape-mismatch error.
hidden_size = 30
# One output logit per vocabulary entry.
output_size = len(vocab)
model = GRULanguageModel(hidden_size=hidden_size, output_size=output_size)
# map_location='cpu' lets a checkpoint saved from a CUDA device load on a
# machine without a GPU; the model itself is constructed on CPU, so the
# resulting weights are the same either way.
model.load_state_dict(torch.load('gru_model.bin', map_location='cpu'))
# Switch to inference mode. eval() returns the module, so as the last
# expression of the cell it also displays the model summary.
model.eval()

GRULanguageModel(
  (embedding): Embedding(21, 30)
  (gru): GRU(30, 30, batch_first=True)
  (softmax): LogSoftmax(dim=-1)
  (out): Linear(in_features=30, out_features=21, bias=True)
)

In [12]:
# Generate text with the loaded model, starting from the token with index `bos`.
# NOTE(review): bos=1 is hard-coded; it matches '<s>': 1 in the toy vocab shown
# earlier — confirm the pickled vocab uses the same index for '<s>'.
generated_text = generate_sentence_from_bos(model, vocab, bos=1)

In [13]:
# Display the generated sentence.
generated_text

"<s> for if she sells sea shells by the sea shore then i'm sure she sells sea shore shells."