In [1]:
import torch
import torch.distributions as ds
import torch.nn as nn

In [None]:
from torchtext.datasets.babi import BABI20, BABI20Field
import torchtext.data
torch.device.index
train, val, test = BABI20.iters(batch_size=20, task = 1, device=torch.device(0))
QUERY = train.dataset.fields["query"]
STORY = train.dataset.fields["story"]
ANSWER = train.dataset.fields["answer"]
def wrap(data):
    return ((batch.story, 
            batch.query,
            batch.answer.squeeze(1)) for batch in  data)

In [70]:
def display_instance(x, x_tilde, y, p = None, p_y=None):
    print("Correct answer: ", ANSWER.vocab.itos[y])
    print("Story: ")
    for i, s in enumerate(x):
        for k, w in enumerate(s):
            if STORY.vocab.itos[w] != "<pad>":
                print("%10s"%(STORY.vocab.itos[w]), end=" ")
            elif k == 0:
                break
        if k > 0:
            if p is not None:
                print("%f "%p[i].data.item(), end=" ")
            print()

    print("Query:")
    for w in x_tilde:
        print("%20s "%(QUERY.vocab.itos[w]), end = " ")
    print()
    if p_y is not None:
        for i, p_y in enumerate(p_y):
            print("%20s %f"%(ANSWER.vocab.itos[i], py), end= " ")
        print()


In [61]:
x, x_tilde, y = next(wrap(val))
display_instance(x[0], x_tilde[0], y[0])

Correct answer:  bedroom
Story: 
    Daniel  journeyed         to        the    bedroom 
    Daniel  journeyed         to        the    kitchen 
Query:
               Where                    is                Daniel  


In [3]:
class Attention(nn.Module):
    "Simple implementation of attention."
    def __init__(self, value_encoder, query_encoder, 
                 evidence_encoder=None, combiner=None):
        super(Attention, self).__init__()
        self.value_encoder = value_encoder
        self.query_encoder = query_encoder
        
        # If we have acess to the evidence (inference network)
        self.evidence_encoder = evidence_encoder
        self.combiner = combiner
        
    def forward(self, src, query, answer=None):
        value_vecs = self.value_encoder(src)
        query_vecs = self.query_encoder(query)
        # Incorporate evidence.
        if self.answer_encoder is not None:
            answer_vecs = self.answer_encoder(answer)
            query_vecs = self.combiner(torch.cat([query_vecs, answer_vecs], -1))
        
        # Apply attention.
        attention_logits = torch.bmm(src_vecs, query_vecs.transpose(1, 2)).squeeze(-1)
        return value_vecs, attention_logits

In [5]:
class Encoder(nn.Module):
    "A simple LSTM text encoder."
    def __init__(self, vocab_size, hidden_size, last_only=False, pad=0):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(vocab_size, hidden_size, padding_idx=pad) 
        self.encoder = nn.LSTM(hidden_size=hidden_size,
                               input_size=hidden_size,
                               bidirectional=not last_only) 
        self.last_only = last_only
        
    def forward(self, text):
        x = self.lut(text)
        # If there is an extra dimension, sum it out.
        if x.dim() >= 4:
            x = x.sum(2)
        x, _ = self.encoder(x)
        if self.last_only:
            return x[:, -1:]
        else:
            return x

In [44]:
def make_model():
    model = nn.Module()
    model.alignment = Attention(Encoder(len(STORY.vocab), 50, pad=STORY.vocab.stoi["pad"]),
                                Encoder(len(QUERY.vocab), 100, last_only=True, 
                                        pad=STORY.vocab.stoi["pad"])
                               )
    model.generator =  nn.Sequential(
                    nn.Linear(100, 100),
                    nn.ReLU(),
                    nn.Linear(100, 100),
                    nn.ReLU(),
                    nn.Linear(100, len(ANSWER.vocab))) 

    # This will come up later. 
    model.inference = Attention(Encoder(len(STORY.vocab), 50, pad=STORY.vocab.stoi["pad"]),
                                Encoder(len(QUERY.vocab), 100, last_only=True, 
                                        pad=STORY.vocab.stoi["pad"]),
                                Encoder(len(ANSWER.vocab), 100, last_only=True, 
                                        pad=ANSWER.vocab.stoi["pad"]),
                                nn.Sequential(
                                    nn.Linear(200, 100),
                                    nn.ReLU(),
                                    nn.Linear(100, 100)
                                )
                               )
    model.cuda()
    return model

In [None]:
def E(p, v): 
    return torch.bmm(p.unsqueeze(1), v)

def KL(p, q): 
    return ds.kl_divergence(p, q)

def Cat1(logits):
    return ds.OneHotCategorical(logits=logits)

def Cat(logits): 
    return ds.Categorical(logits=logits)

def CatP(probs):
    return ds.Categorical(probs=probs)

In [46]:
def run_training(model, dataset, generator, epochs=50):
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        for x, x_tilde, y in wrap(dataset):
            opt.zero_grad()
            v, theta = model.alignment.forward(x, x_tilde)
            p = Cat1(theta)
            obj, p_y = generator(model.generator, x, x_tilde, y, p, v)
            obj.mean().backward()
            opt.step()
            _, y_hat = p_y.probs.max(1)
            correct = (y_hat == y).data.float().mean()
        print(-p_y.log_prob(y).mean().detach().item(), correct.item() )


In [49]:
def soft(f, x, x_tilde, y, p, v):
    "y ~ Cat(f(E_p[x]))"
    context = E(p.probs, v)
    p_y = Cat(f(context).squeeze(1))
    return -p_y.log_prob(y), p_y

In [9]:
def hard(f, x, x_tilde, y, p, v):
    choice = p.sample()
    context = E(choice, v)
    p_y = Cat(f(context).squeeze(1))
    reward = p_y.log_prob(y).detach()
    return -(p_y.log_prob(y) + p.log_prob(choice) * reward), p_y

In [10]:
def hard_soft(f, x, x_tilde, y, p, v):
    choice = p.sample()
    context = E(choice, v)
    p_y = Cat(f(context).squeeze(1))
    _, p_y_soft = soft(f, None, None, y, p, v)
    reward = (p_y.log_prob(y) - p_y_soft.log_prob(y)).detach()
    return -(p_y_soft.log_prob(y) + p_y.log_prob(y) + p.log_prob(choice) * reward), p_y

In [11]:
def enum(f, x, x_tilde, y, p, v):
    logits = model.generator(v)
    p_y = CatP(probs=E(p.probs, Cat(logits).probs).squeeze(1))
    return -p_y.log_prob(y), p_y

In [12]:
def vae(f, x, x_tilde, y, p, v):
    _, logits = model.inference(x, x_tilde, y.unsqueeze(1))
    q = Cat1(logits)
    choice = q.sample()
    context = E(choice, v)
    p_y = Cat(f(context).squeeze(1))
    _, p_y_soft = soft(f, x, x_tilde, y, p, v)
    reward = (p_y.log_prob(y) - p_y_soft.log_prob(y)).detach()
    return -(p_y_soft.log_prob(y) + p_y.log_prob(y) + q.log_prob(choice) * reward - KL(q, p)), p_y

In [63]:
model = make_model()
run_training(model, test, soft, epochs=20)

1.7560386657714844 0.15000000596046448
1.7189416885375977 0.20000000298023224
1.408884882926941 0.5
0.5281757116317749 0.9000000357627869
0.0964445099234581 1.0
0.11451363563537598 0.949999988079071
0.014591860584914684 1.0
0.010198927484452724 1.0
0.0708063393831253 0.949999988079071
0.01989584043622017 1.0
0.010172558017075062 1.0
0.0027688504196703434 1.0
0.012838221155107021 1.0
0.002007436705753207 1.0
0.1813783496618271 0.949999988079071
0.017470194026827812 1.0
0.0030329227447509766 1.0
0.006617713253945112 1.0
0.0038623332511633635 1.0
0.0015689373249188066 1.0


In [64]:
method = vae
x, x_tilde, y = next(wrap(test))
v, theta = model.alignment.forward(x, x_tilde)
p = Cat1(theta)
loss, p_y = method(model.generator, x, x_tilde, y, p, v)

In [71]:
for j in range(0,10):
    display_instance(x[j], x_tilde[j], y[j], p.probs[j], p_y.probs[j])

Correct answer:  hallway
Story: 
      Mary  journeyed         to        the   bathroom 0.006361  
      John  travelled         to        the    hallway 0.981357  
Query:
               Where                    is                  John  
               <pad> 0.000000                  the 0.000000                   to 0.000000                 went 0.000000               Sandra 0.000000                 Mary 0.000000                 John 0.000000               Daniel 0.000000              hallway 0.000000            journeyed 0.000000                moved 0.000000            travelled 0.000000              bedroom 0.000000                 back 0.000000               garden 0.000000               office 0.000000              kitchen 0.000000             bathroom 0.000000                Where 0.000000                   is 0.000000 
Correct answer:  bathroom
Story: 
      John      moved         to        the    bedroom 0.042216  
    Daniel       went       back         to        the   bat