In [1]:
import torch
from dataset import TextDataset

In [156]:
def one_hot(batch, vocab_size):

    X = torch.zeros(batch.shape[0], batch.shape[1], vocab_size, device=batch.device)
    X.scatter_(2, batch[:,:,None], 1)

    return X

def sample_output(output, temperature=1.0):
    # helper function to sample an index from a probability array

    output = torch.softmax(output, dim=1) / temperature
    sample = torch.multinomial(output, 1)
    
    return sample

def sample_model(model, vocab_size, device=torch.device('cpu'), temp=1.0, hidden_states=None, n=30, prev_char=None):

    if n == 0:
        return prev_char

    # initialize sequence with random character
    if prev_char is None:
        # random character seed
        prev_char = torch.empty(1, 1).random_(0, vocab_size - 1).type(torch.long)
        # convert to one-hot
        prev_char = one_hot(prev_char, vocab_size).to(device)

    output, hidden_states = model(prev_char, hidden_states)
    output = output[:, -1, :] # last prediction

    # Sample next character from softmax
    next_char = sample_output(output, temperature=temp)  # [B,1]
    next_char = one_hot(next_char, vocab_size)           # [B,1,D]

    # concat the recursive predictions to currect character
    future = sample_model(model, vocab_size, temp=temp, hidden_states=hidden_states, n=n-1, prev_char=next_char)
    encoded_text = torch.cat([next_char, future], dim=1)

    return encoded_text

def string_from_one_hot(sequence, dataset, unk=0):
    char_idxs = sequence.argmax(dim=2).squeeze_(0).cpu().numpy()
    return dataset.convert_to_string([char if char < dataset.vocab_size else unk for char in char_idxs])

In [179]:
def sample_for_epoch(epoch, dataset, n=5, temp=1.0, unk=0):
    
    model = torch.load('grimm/grimm_epoch_{}.pt'.format(epoch),map_location=lambda storage, loc: storage)

    print()
    print("Unkown char:", dataset.convert_to_string([unk]))
    
    samples = []
    for i in range(n):
        chars = sample_model(model, model.out_projection.out_features, temp=temp, n=300)
        sample = string_from_one_hot(chars, dataset)
        
        print("----------")
        print(sample)

In [180]:
dataset = TextDataset('assets/book_EN_grimms_fairy_tails.txt', 30)

for ep in range(0, 21, 5):
    if ep == 0:
        ep += 1
    print()
    print()
    print("Samples for epoch {}".format(ep))
    print()
    sample_for_epoch(ep,dataset, n=5, unk=43, temp=0.00001)

Initialize dataset with 540241 characters, 87 unique.


Samples for epoch 1


Unkown char: Q
----------
arried the door
and stook a choose that morning. He heard them with its for a large prisonor journey, and tied on their dear asks for her;
but they certains is with theirn again, but the tempt of all.

The princesses were sinten he agree door, and she went with
her. 
‘’They doord not agrow there, frr
----------
ND TAR HIP  Ther Sicks there were
dear son their
new griefnes.
‘“ Soon after ask: and when
he he
spring again again. So the king came and one peas me some of the door, and
they weak to
do see now;
for I dear for shee in his peace.

‘’I am joytiged twop, found the daumfs again.
‘“ The king take his bb
----------
ther
ever seeit door. At last, asks for the girl path again, and as he came to let up, behine one of their door alron with asks for the
hes.
‘“ The king said: 
‘’Ild and the beautiful
princesss for the task-door, and had eatour on the
foucarrared their wing until it all

In [161]:
model = torch.load('grimm/grimm_epoch_{}.pt'.format(1),map_location=lambda storage, loc: storage)
model

TextGenerationModel(
  (lstm): LSTM(91, 128, num_layers=2, batch_first=True)
  (out_projection): Linear(in_features=128, out_features=91, bias=True)
)

In [113]:
tens = torch.empty(4).random_(4)
soft1 = torch.softmax(tens, dim=0)
soft2 = torch.softmax(soft1, dim=0)
soft3 = torch.softmax(soft2, dim=0)
print(soft1, soft2, soft3)

tensor([0.3995, 0.0541, 0.1470, 0.3995]) tensor([0.2870, 0.2031, 0.2229, 0.2870]) tensor([0.2592, 0.2384, 0.2432, 0.2592])
