In [1]:
import torch
import torch.distributions as ds
import torch.nn as nn

In [None]:
from torchtext.datasets.babi import BABI20, BABI20Field
import torchtext.data
torch.device.index
train, val, test = BABI20.iters(batch_size=20, task = 1, device=torch.device(0))
QUERY = train.dataset.fields["query"]
STORY = train.dataset.fields["story"]
ANSWER = train.dataset.fields["answer"]
def wrap(data):
    return ((batch.story, 
            batch.query,
            batch.answer.squeeze(1)) for batch in  data)

In [21]:
def display_instance(x, x_tilde, y):
    print("Correct answer: ", ANSWER.vocab.itos[y])
    for w in x_tilde:
        print("%20s "%(QUERY.vocab.itos[w]), end = " ")
    print()
    #for i, py in enumerate(p_y.probs):
    #    print("%20s %f"%(ANSWER.vocab.itos[i], py), end= " ")
    #print()
    for i, s in enumerate(x):
        for k, w in enumerate(s):
            if STORY.vocab.itos[w] != "<pad>":
                print("%10s"%(STORY.vocab.itos[w]), end=" ")
            elif k == 0:
                break
        if k > 0:
            print("%f "%p.probs[i].data.item())

In [22]:
x, x_tilde, y = next(wrap(val))
display_instance(x[0], x_tilde[0], y[0])

Correct answer:  bedroom
               Where                    is                Daniel  
    Daniel  journeyed         to        the    bedroom 

ValueError: only one element tensors can be converted to Python scalars

In [3]:
class Attention(nn.Module):
    "Simple implementation of attention."
    def __init__(self, value_encoder, query_encoder, 
                 evidence_encoder=None, combiner=None):
        super(Attention, self).__init__()
        self.value_encoder = value_encoder
        self.query_encoder = query_encoder
        
        # If we have acess to the evidence (inference network)
        self.evidence_encoder = evidence_encoder
        self.combiner = combiner
        
    def forward(self, src, query, answer=None):
        value_vecs = self.value_encoder(src)
        query_vecs = self.query_encoder(query)
        # Incorporate evidence.
        if self.answer_encoder is not None:
            answer_vecs = self.answer_encoder(answer)
            query_vecs = self.combiner(torch.cat([query_vecs, answer_vecs], -1))
        
        # Apply attention.
        attention_logits = torch.bmm(src_vecs, query_vecs.transpose(1, 2)).squeeze(-1)
        return value_vecs, attention_logits

In [5]:
class Encoder(nn.Module):
    "A simple LSTM text encoder."
    def __init__(self, vocab_size, hidden_size, last_only=False, pad=0):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(vocab_size, hidden_size, padding_idx=pad) 
        self.encoder = nn.LSTM(hidden_size=hidden_size,
                               input_size=hidden_size,
                               bidirectional=not last_only) 
        self.last_only = last_only
        
    def forward(self, text):
        x = self.lut(text)
        # If there is an extra dimension, sum it out.
        if x.dim() >= 4:
            x = x.sum(2)
        x, _ = self.encoder(x)
        if self.last_only:
            return x[:, -1:]
        else:
            return x

In [7]:
model = nn.Module()
model.alignment = Attention(Encoder(len(STORY.vocab), 50, pad=STORY.vocab.stoi["pad"]),
                            Encoder(len(QUERY.vocab), 100, last_only=True, 
                                    pad=STORY.vocab.stoi["pad"])
                           )
model.generator =  nn.Sequential(
                nn.Linear(100, 100),
                nn.ReLU(),
                nn.Linear(100, 100),
                nn.ReLU(),
                nn.Linear(100, len(ANSWER.vocab))) 

model.inference = Attention(Encoder(len(STORY.vocab), 50, pad=STORY.vocab.stoi["pad"]),
                            Encoder(len(QUERY.vocab), 100, last_only=True, 
                                    pad=STORY.vocab.stoi["pad"]),
                            Encoder(len(ANSWER.vocab), 100, last_only=True, 
                                    pad=ANSWER.vocab.stoi["pad"]),
                            nn.Sequential(
                                nn.Linear(200, 100),
                                nn.ReLU(),
                                nn.Linear(100, 100)
                            )
                           )
model.cuda()


Module(
  (alignment): Attention(
    (src_encoder): Encoder(
      (lut): Embedding(20, 50, padding_idx=0)
      (encoder): LSTM(50, 50, bidirectional=True)
    )
    (query_encoder): Encoder(
      (lut): Embedding(20, 100, padding_idx=0)
      (encoder): LSTM(100, 100)
    )
  )
  (generator): Generator(
    (prediction_network): Sequential(
      (0): Linear(in_features=100, out_features=100, bias=True)
      (1): ReLU()
      (2): Linear(in_features=100, out_features=100, bias=True)
      (3): ReLU()
      (4): Linear(in_features=100, out_features=20, bias=True)
    )
  )
  (inference): Attention(
    (src_encoder): Encoder(
      (lut): Embedding(20, 50, padding_idx=0)
      (encoder): LSTM(50, 50, bidirectional=True)
    )
    (query_encoder): Encoder(
      (lut): Embedding(20, 100, padding_idx=0)
      (encoder): LSTM(100, 100)
    )
    (answer_encoder): Encoder(
      (lut): Embedding(20, 100, padding_idx=0)
      (encoder): LSTM(100, 100)
    )
    (combiner): Sequential(
 

In [None]:
def E(p, v): 
    return torch.bmm(p.unsqueeze(1), v)

def KL(p, q): 
    return ds.kl_divergence(p, q)

def Cat1(logits):
    return ds.OneHotCategorical(logits=logits)

def Cat(logits): 
    return ds.Categorical(logits=logits)

def CatP(probs):
    return ds.Categorical(probs=probs)

In [8]:
def soft(gen, x, x_tilde, y, p, v):
    context = E(p.probs, v)
    logits = gen(context).squeeze(1)
    p_y = Cat(logits)
    return -p_y.log_prob(y), p_y

In [9]:
def hard(gen, x, x_tilde, y, p, v):
    choice = p.sample()
    context = E(choice, v)
    p_y = Cat(gen(context).squeeze(1))
    reward = p_y.log_prob(y).detach()
    return -(p_y.log_prob(y) + p.log_prob(choice) * reward), p_y

In [10]:
def hard_soft(gen, x, x_tilde, y, p, v):
    choice = p.sample()
    context = E(choice, v)
    p_y = Cat(gen(context).squeeze(1))
    nbaseline, p_y_soft = soft(gen, x, x_tilde, y, p, v)
    reward = (p_y.log_prob(y) - (-nbaseline)).detach()
    return -(p_y_soft.log_prob(y) + p_y.log_prob(y) + p.log_prob(choice) * reward), p_y

In [11]:
def enum(gen, x, x_tilde, y, p, v):
    logits = model.generator(v)
    p_y = CatP(probs=E(p.probs, Cat(logits).probs).squeeze(1))
    return -p_y.log_prob(y), p_y

In [12]:
def vae(gen, x, x_tilde, y, p, v):
    _, logits = model.inference(x, x_tilde, y.unsqueeze(1))
    q = Cat1(logits)
    choice = q.sample()
    context = E(choice, v)
    p_y = Cat(gen(context).squeeze(1))
    nbaseline, p_y_soft = soft(gen, x, x_tilde, y, p, v)
    reward = (p_y.log_prob(y) - (-nbaseline)).detach()
    return -(p_y_soft.log_prob(y) + p_y.log_prob(y) + q.log_prob(choice) * reward - KL(q, p)), p_y

In [13]:
# opt = torch.optim.SGD(model.parameters(), lr=5)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

method = vae
for epoch in range(100):
    for x, x_tilde, y in wrap(test):
        opt.zero_grad()
        v, theta = model.alignment.forward(x, x_tilde)
        p = Cat1(theta)
        obj, p_y = method(model.generator, x, x_tilde, y, p, v)
        obj = obj.mean()
        obj.backward()
        opt.step()
        _, y_hat = p_y.probs.max(1)
        correct = (y_hat == y).data.float().mean()
    print(-p_y.log_prob(y).mean().detach().item(), correct.item() )


1.8999223709106445 0.15000000596046448
1.7574795484542847 0.15000000596046448
1.7703651189804077 0.15000000596046448
1.764735221862793 0.15000000596046448
1.7493127584457397 0.15000000596046448
1.7248371839523315 0.20000000298023224
1.7008472681045532 0.3499999940395355
0.6060455441474915 0.8500000238418579
0.5408384203910828 0.9000000357627869
0.01800537109375 1.0
0.007454657461494207 1.0
0.0031062127090990543 1.0
0.002271890640258789 1.0
0.001722145127132535 1.0
0.0015566826332360506 1.0
0.0031032562255859375 1.0
0.0009783267742022872 1.0
0.004041910171508789 1.0
0.0007668972248211503 1.0
0.0005860328674316406 1.0
0.0004953384632244706 1.0
0.0008023261907510459 1.0
0.0007991790771484375 1.0
0.0009588241809979081 1.0
0.0009348392486572266 1.0
0.0007482051732949913 1.0
0.0023152113426476717 1.0
0.0025069713592529297 1.0
0.0010557174682617188 1.0
0.0006608963012695312 1.0
0.0005763531080447137 1.0
0.0005115509266033769 1.0
0.0003913402615580708 1.0
0.0002566814364399761 1.0
0.0002030372

In [14]:
method = vae
x, x_tilde, y = next(wrap(test))
v, theta = model.alignment.forward(x, x_tilde)
p = Cat1(theta)
loss, p_y = method(model.generator, x, x_tilde, y, p, v)

In [15]:
for j in range(0,10):
    print("Correct answer: ", ANSWER.vocab.itos[y[j] ])
    for w in x_tilde[j]:
        print("%20s "%(QUERY.vocab.itos[w]), end = " ")
    print()
    for i, py in enumerate(p_y.probs[j]):
        print("%20s %f"%(ANSWER.vocab.itos[i], py), end= " ")
    print()
    for i, s in enumerate(x[j]):
        for k, w in enumerate(s):
            if STORY.vocab.itos[w] != "<pad>":
                print("%10s"%(STORY.vocab.itos[w]), end=" ")
            elif k == 0:
                break
        if k > 0:
            print("%f "%p.probs[j, i].data.item())
        
        
    #p.probs[0]

Correct answer:  hallway
               Where                    is                  John  
               <pad> 0.000000                  the 0.000000                   to 0.000000                 went 0.000000               Sandra 0.000000                 Mary 0.000000                 John 0.000000               Daniel 0.000000              hallway 1.000000            journeyed 0.000000                moved 0.000000            travelled 0.000000              bedroom 0.000000                 back 0.000000               garden 0.000000               office 0.000000              kitchen 0.000000             bathroom 0.000000                Where 0.000000                   is 0.000000 
      Mary  journeyed         to        the   bathroom 0.000002 
      John  travelled         to        the    hallway 0.999973 
Correct answer:  bathroom
               Where                    is                  Mary  
               <pad> 0.000000                  the 0.000000                   to 0.0

In [16]:
ds.Categorical(logits=torch.autograd.Variable(torch.Tensor([0.2, 0.8]))
              ).sample()

tensor(1)

In [17]:
ds.

SyntaxError: invalid syntax (<ipython-input-17-5fc19e5a66ca>, line 1)

In [None]:
ds.Categorical()