In [84]:
import torch
import torch.distributions as ds
import torch.nn as nn 
from torchtext.datasets.babi import BABI20, BABI20Field
import torchtext.data

In [124]:
def E(p, v): 
    return torch.bmm(p.transpose(1, 2), v)

def KL(p, q): 
    return ds.kl_divergence(p, q)

def Cat(scores, one_hot=True): 
    if one_hot:
        return ds.OneHotCategorical(logits=scores)
    else:
        return ds.Categorical(logits=scores)

In [98]:
class Attention(nn.Module):
    def __init__(self, src_encoder, query_encoder):
        super(Attention, self).__init__()
        self.src_encoder = src_encoder
        self.query_encoder = query_encoder
        
    def forward(self, src, query):
        src_vecs = self.src_encoder(src)
        query_vecs = self.query_encoder(query)
        attention = torch.bmm(src_vecs, query_vecs.transpose(1, 2))
        return src_vecs, attention

In [110]:
class Generator(nn.Module):
    def __init__(self, out_classes, hidden_size):
        super(Generator, self).__init__()
        self.prediction_network = 
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLUL(),
                nn.Linear(hidden_size, out_classes)) 

    def forward(self, context):
        return self.prediction_network(context)


In [158]:
train, val, test = BABI20.iters(batch_size=100, task = 1,)
QUERY = train.dataset.fields["query"]
STORY = train.dataset.fields["story"]
ANSWER = train.dataset.fields["answer"]

In [153]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, last_only=False):
        super(Encoder, self).__init__()
        self.lut = nn.Embedding(vocab_size, hidden_size) 
        self.encoder = nn.LSTM(hidden_size=hidden_size,
                               input_size=hidden_size,
                               bidirectional=not last_only) 
        self.last_only = last_only
        
    def forward(self, text):
        x = self.lut(text)
        x, _ = self.encoder(x)
        if self.last_only:
            return x[:, -1:]
        else:
            return x
    

In [164]:
model = nn.Module()
model.alignment = Attention(Encoder(len(STORY.vocab), 50),
                            Encoder(len(QUERY.vocab), 100, last_only=True)
                           )
model.generator = Generator(len(ANSWER.vocab), 100)
#model.inference = Inferer(len(ANSWER.vocab), 20)

opt = torch.optim.Adam(model.parameters(), lr=0.001)

typ = "soft"
for epoch in range(10):
    for batch in test:
        opt.zero_grad()
        x = batch.story.view(batch.story.shape[0], -1)
        x_tilde = batch.query
        y = batch.answer.squeeze(1)

        v, theta = model.alignment.forward(x, x_tilde)
        p = Cat(theta)
        def soft():
            context = E(p.probs, v)
            logits = model.generator(context).squeeze(1)
            p_y = Cat(logits,one_hot=False)
            #print(logits[0], y[0])
            return -p_y.log_prob(y)

        def hard():
            context = E(p.sample(), v)
            logits = model.generator(context).squeeze(1)
            p_y = Cat(logits, one_hot=False)    
            return -p_y.log_prob(y) 

        def enum():
            logits = model.generator(v)
            p_y = Cat(logits)
            return E(p.probs, p_y.log_prob(y).unsqueeze(2))

        def vae():
            q = model.inference(x, x_tilde, y)
            context = E(q.sample(), v)
            p_y = Cat(model.generator(context).squeeze(1))
            return p_y.log_prob(y) + KL(p, q)

        if typ == "soft":
            loss = soft()
        elif typ == "hard":
            loss = hard()
        elif typ == "enum":
            loss = enum()
        elif typ == "vae":
            loss = vae()

        #loss = E(p, p_y.log_prob(y.unsqueeze()))
        loss = loss.mean()
        loss.backward()
        print(loss.detach())
        opt.step()

tensor(41.3210)
tensor(32.2808)
tensor(32.5659)
tensor(27.1223)
tensor(25.9738)
tensor(24.0419)
tensor(20.5931)
tensor(13.1045)
tensor(8.4104)
tensor(10.7441)
tensor(14.7080)
tensor(13.8520)
tensor(10.5503)
tensor(9.0310)
tensor(8.5378)
tensor(5.5080)
tensor(6.2666)
tensor(5.4611)
tensor(3.6705)
tensor(5.9019)
tensor(8.0330)
tensor(5.3569)
tensor(4.6592)
tensor(4.0378)
tensor(5.2231)
tensor(3.7877)
tensor(3.6607)
tensor(5.5091)
tensor(4.3207)
tensor(4.5140)
tensor(4.5242)
tensor(4.8852)
tensor(4.6404)
tensor(4.0859)
tensor(3.7199)
tensor(4.0425)
tensor(4.2135)
tensor(2.6278)
tensor(4.0896)
tensor(3.4656)
tensor(3.4424)
tensor(3.8586)
tensor(3.5965)
tensor(3.1555)
tensor(3.3140)
tensor(2.7481)
tensor(4.3643)
tensor(4.0031)
tensor(3.2112)
tensor(2.9813)
tensor(3.3521)
tensor(3.8359)
tensor(3.8734)
tensor(3.7003)
tensor(3.9962)
tensor(2.8786)
tensor(2.9738)
tensor(4.3933)
tensor(2.5862)
tensor(2.5975)
tensor(2.2397)
tensor(3.0304)
tensor(2.9350)
tensor(2.5647)
tensor(2.6183)
tensor(2.3786

In [39]:
ds.Categorical(logits=torch.autograd.Variable(torch.Tensor([0.2, 0.8]))
              ).sample()

Variable containing:
 0
[torch.LongTensor of size ()]

In [None]:
ds.

In [None]:
ds.Categorical()