In [1]:
import torch
import sencoder as sen
import numpy as np
import os
import tokenizer as tok
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

Using TensorFlow backend.


In [2]:
sample_strs = {"greetings":["how do you do kind man?", 
                            "how are you doing today?",
                            "what is going on?",
                            "good morning dear, It is a pleasure to see you",
                            "how is it going?"],
                "commands":["do not go into the forest.",
                            "you must cut down this tree.",
                            "build a house from the wood.",
                            "go into that cave."
                            ],
                "ramblings":["it is god's will, to be killed or to be lived.",
                            "the subject of good an evil cannot be contained.",
                            "the herd will not see the truths that stare them in the face.",                    
                            "death will come to those who wait.",                  
                            "life is tedious and brief."                           
                            ]
                }

tokens = {k:[tok.tokenize(s) for s in v] for k,v in sample_strs.items()}           

for k,v in tokens.items():
    print(k)
    for vv in v:
        print(vv)
    print()

greetings
['how', 'do', 'you', 'do', 'kind', 'man', '?']
['how', 'are', 'you', 'doing', 'today', '?']
['what', 'is', 'going', 'on', '?']
['good', 'morning', 'dear', ',', 'It', 'is', 'a', 'pleasure', 'to', 'see', 'you']
['how', 'is', 'it', 'going', '?']

commands
['do', 'not', 'go', 'into', 'the', 'forest', '.']
['you', 'must', 'cut', 'down', 'this', 'tree', '.']
['build', 'a', 'house', 'from', 'the', 'wood', '.']
['go', 'into', 'that', 'cave', '.']

ramblings
['it', 'is', 'god', "'", 's', 'will', ',', 'to', 'be', 'killed', 'or', 'to', 'be', 'lived', '.']
['the', 'subject', 'of', 'good', 'an', 'evil', 'cannot', 'be', 'contained', '.']
['the', 'herd', 'will', 'not', 'see', 'the', 'truths', 'that', 'stare', 'them', 'in', 'the', 'face', '.']
['death', 'will', 'come', 'to', 'those', 'who', 'wait', '.']
['life', 'is', 'tedious', 'and', 'brief', '.']



In [5]:
prepath = "../training_scripts/"
s = "attention/attention_0_lr0.005"
s = os.path.join(prepath,s)
model, chkpt = sen.io.load_model(s, ret_chkpt=True)
model.cuda()
model.eval()

SeqAutoencoder(
  (word_encoder): WordEncoder(
    (encoder): Linear(in_features=100, out_features=600, bias=True)
  )
  (encoder): Encoder(
    (rssm): RSSM(
      h_size=300, s_size=300, emb_size=100, min_sigma=0.0001
      (rnn): GRUCell(
        (gru): GRUCell(400, 300)
      )
      (state_layer): Linear(in_features=300, out_features=600, bias=True)
    )
    (attention): Linear(in_features=600, out_features=1, bias=True)
  )
  (decoder): Decoder(
    (rssm): RSSM(
      h_size=300, s_size=300, emb_size=400, min_sigma=0.0001
      (rnn): GRUCell(
        (gru): GRUCell(700, 300)
      )
      (state_layer): Linear(in_features=300, out_features=600, bias=True)
    )
  )
  (classifier): SimpleClassifier(
    (classifier): Sequential(
      (0): Linear(in_features=300, out_features=5760, bias=True)
      (1): ReLU()
      (2): Linear(in_features=5760, out_features=11521, bias=True)
    )
  )
)

In [6]:
word2idx = chkpt['word2idx']
idx2word = chkpt['idx2word']
wordidxs = {k:[[word2idx[w] for w in t] for t in v]\
                            for k,v in tokens.items()}
for k,v in wordidxs.items():
    print(k)
    for idxs in v:
        print(idxs)
    print()

greetings
[10188, 2334, 6166, 2334, 11028, 8296, 9917]
[10188, 2392, 6166, 9674, 7995, 9917]
[10310, 1144, 3098, 8355, 9917]
[7532, 3726, 4878, 10810, 4750, 1144, 676, 9795, 7640, 7112, 6166]
[10188, 1144, 5711, 3098, 9917]

commands
[2334, 9351, 1003, 1853, 1116, 3711, 9082]
[6166, 3123, 4804, 915, 8294, 4764, 9082]
[3308, 676, 4631, 8051, 1116, 5422, 9082]
[1003, 1853, 11406, 1273, 9082]

ramblings
[5711, 1144, 8704, 1938, 1413, 4595, 10810, 7640, 4017, 4758, 6814, 7640, 4017, 6908, 9082]
[1116, 5159, 5589, 7532, 4779, 10008, 6299, 4017, 3590, 9082]
[1116, 2122, 4595, 9351, 7112, 1116, 3068, 11406, 3524, 8345, 7280, 1116, 11514, 9082]
[4500, 4595, 10377, 7640, 390, 3421, 59, 9082]
[7542, 1144, 7531, 9894, 10814, 9082]



In [9]:
states = {}
for k,v in wordidxs.items():
    states[k] = []
    for idxs in v:
        words = [idx2word[idx] for idx in idxs]
        print("Real:", " ".join(words))
        idxs = torch.LongTensor(idxs).cuda()
        embs = model.embed(idxs)
        embs = embs.reshape(1,len(idxs),-1)
        enc_hs, enc_mus, enc_sigmas, enc_states = model.encode(embs)
        state = Normal(enc_mus[-1], enc_sigmas[-1])
        states[k].append(state)
        mus = enc_mus.reshape(-1,enc_mus.shape[-1])
        sigmas = enc_sigmas.reshape(-1,enc_sigmas.shape[-1])
        enc_preds = model.classify(mus, sigmas)
        argmaxes = torch.argmax(enc_preds, dim=-1).long()
        words = reversed([idx2word[arg.item()] for arg in argmaxes])
        print("Enc:", " ".join(words))
        
        state = enc_states[-1]
        h = (enc_hs[-1],enc_mus[:,-1],enc_sigmas[:,-1])
        hs,mus,sigmas = model.decode(state, h, seq_len=len(idxs),
                                classifier=model.classifier,
                                embeddings=model.embeddings)
        mus = torch.stack(mus, dim=1)
        mus = mus.reshape(-1,mus.shape[-1])
        sigmas = torch.stack(sigmas, dim=1)
        sigmas = sigmas.reshape(-1,sigmas.shape[-1])
        dec_preds = model.classify(mus, sigmas)
        argmaxes = torch.argmax(dec_preds, dim=-1).long()
        words = reversed([idx2word[arg.item()] for arg in argmaxes])
        print("Dec:", " ".join(words))
        print("\n")


Real: how do you do kind man ?
Enc: 
 for of 
 to not to
Dec: you but you do matter man ?


Real: how are you doing today ?
Enc: Or , , , the to
Dec: how with ask doing thereby ?


Real: what is going on ?
Enc: . the 
 the is
Dec: if is last on ?


Real: good morning dear , It is a pleasure to see you
Enc: 
 
 the in 
 perhaps is and 
 , 

Dec: called curiosity knowledge , it is a love to see you


Real: how is it going ?
Enc: - to 
 , to
Dec: how is it going ?


Real: do not go into the forest .
Enc: 
 and 
 churches that be not
Dec: do not go into the heart .


Real: you must cut down this tree .
Enc: 
 , 
 . , be are
Dec: must you here than this tree .


Real: build a house from the wood .
Enc: 
 to 
 them is 
 .
Dec: or a healthier on the inclination .


Real: go into that cave .
Enc: 
 : 
 churches to
Dec: go into that cave .


Real: it is god ' s will , to be killed or to be lived .
Enc: 
 , 
 
 it ) able be fools not 
 s , not is
Dec: it 
 will to never will , , be get or to be 