In [1]:
import numpy as np
import random

def make_seq_data(n_samples, seq_len):
    # Boundary tasks
    data, labels = [], []
    for _ in range(n_samples):
        input = np.random.permutation(range(seq_len)).tolist()
        target = sorted(range(len(input)), key=lambda k: input[k])
        data.append(input)
        labels.append(target)
    return data, labels

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class PointerNetwork(nn.Module):
    def __init__(self, input_size, emb_size, weight_size, answer_seq_len, hidden_size=512, is_GRU=True):
        super(PointerNetwork, self).__init__()

        self.hidden_size = hidden_size
        self.input_size = input_size
        self.answer_seq_len = answer_seq_len
        self.weight_size = weight_size
        self.emb_size = emb_size
        self.is_GRU = is_GRU

        self.emb = nn.Embedding(input_size, emb_size)  # embed inputs
        if is_GRU:
            self.enc = nn.GRU(emb_size, hidden_size, batch_first=True)
            self.dec = nn.GRUCell(emb_size, hidden_size) # GRUCell's input is always batch first
        else:
            self.enc = nn.LSTM(emb_size, hidden_size, batch_first=True)
            self.dec = nn.LSTMCell(emb_size, hidden_size) # LSTMCell's input is always batch first

        self.W1 = nn.Linear(hidden_size, weight_size, bias=False) # blending encoder
        self.W2 = nn.Linear(hidden_size, weight_size, bias=False) # blending decoder
        self.vt = nn.Linear(weight_size, 1, bias=False) # scaling sum of enc and dec by v.T

    def forward(self, input):
        batch_size = input.size(0)
        y = F.softmax(input, dim=3)  
        y_bins = y.argmax(dim=3)  
        input = y[:,:,:,1]
        #input = torch.squeeze(input,0)
        path = {}
        for k in range(input.size(1)):
            input = input[:,k,:]
            
            input = self.emb(input.long()) # (bs, L, embd_size)

            # Encoding
            encoder_states, hc = self.enc(input) # encoder_state: (bs, L, H)
            encoder_states = encoder_states.transpose(1, 0) # (L, bs, H)

            # Decoding states initialization
            decoder_input = Variable(torch.zeros(batch_size, self.emb_size)) # (bs, embd_size)
            hidden = Variable(torch.zeros([batch_size, self.hidden_size]))   # (bs, h)
            cell_state = encoder_states[-1]                                # (bs, h)

            probs = []
            # Decoding
            for i in range(self.answer_seq_len): # range(M)
                if self.is_GRU:
                    hidden = self.dec(decoder_input, hidden) # (bs, h), (bs, h)
                else:
                    hidden, cell_state = self.dec(decoder_input, (hidden, cell_state)) # (bs, h), (bs, h)

                # Compute blended representation at each decoder time step
                blend1 = self.W1(encoder_states)          # (L, bs, W)
                blend2 = self.W2(hidden)                  # (bs, W)
                blend_sum = F.tanh(blend1 + blend2)    # (L, bs, W)
                out = self.vt(blend_sum).squeeze()        # (L, bs)
                out = F.log_softmax(out.transpose(0, 1).contiguous(), -1) # (bs, L)
                probs.append(out)

            probs = torch.stack(probs, dim=1)           # (bs, M, L)
        path[k] = 
        return probs

In [531]:
data_0 = np.array([[[0.69246036, 0.59753674, 0.5689253 , 0.56384754, 0.5495899 ,
         0.61832803, 0.5803497 , 0.61829597, 0.5745164 , 0.62113696],
        [0.5894812 , 0.5551746 , 0.49092993, 0.5252931 , 0.587854  ,
         0.53320426, 0.5182203 , 0.5992091 , 0.50958365, 0.60037696],
        [0.5886919 , 0.49092993, 0.5830317 , 0.463785  , 0.59199905,
         0.46117648, 0.46119246, 0.5954041 , 0.45702708, 0.59380126],
        [0.60557234, 0.5252931 , 0.46378502, 0.5862742 , 0.5896137 ,
         0.5020659 , 0.50027835, 0.5985208 , 0.5032851 , 0.5972939 ],
        [0.6093323 , 0.5878541 , 0.48655203, 0.52164733, 0.5841894 ,
         0.4850552 , 0.58888733, 0.49546102, 0.5827878 , 0.49952856],
        [0.61832803, 0.5926625 , 0.4611765 , 0.5020659 , 0.48505518,
         0.60978544, 0.59040314, 0.47951713, 0.58687156, 0.48131478],
        [0.5908085 , 0.51822037, 0.46119246, 0.5002783 , 0.58888733,
         0.4963187 , 0.561206  , 0.5942717 , 0.49203363, 0.5936888 ],
        [0.61829597, 0.599209  , 0.47888252, 0.5985209 , 0.495461  ,
         0.4795172 , 0.5942717 , 0.6245361 , 0.509626  , 0.48861158],
        [0.5745164 , 0.50958365, 0.45702705, 0.5032851 , 0.5827878 ,
         0.4894781 , 0.49203363, 0.58865094, 0.5568124 , 0.58771056],
        [0.62113696, 0.6003769 , 0.48427382, 0.53866506, 0.49952856,
         0.48131478, 0.5936888 , 0.48861155, 0.58771056, 0.6238037 ]]])

In [537]:
from random import *

len(data_0[0][0])

10

In [535]:
shuffle(data_0[0])
print(data_0[0])

[[0.69246036 0.59753674 0.5689253  0.56384754 0.5495899  0.61832803
  0.5803497  0.61829597 0.5745164  0.62113696]
 [0.69246036 0.59753674 0.5689253  0.56384754 0.5495899  0.61832803
  0.5803497  0.61829597 0.5745164  0.62113696]
 [0.69246036 0.59753674 0.5689253  0.56384754 0.5495899  0.61832803
  0.5803497  0.61829597 0.5745164  0.62113696]
 [0.60557234 0.5252931  0.46378502 0.5862742  0.5896137  0.5020659
  0.50027835 0.5985208  0.5032851  0.5972939 ]
 [0.60557234 0.5252931  0.46378502 0.5862742  0.5896137  0.5020659
  0.50027835 0.5985208  0.5032851  0.5972939 ]
 [0.69246036 0.59753674 0.5689253  0.56384754 0.5495899  0.61832803
  0.5803497  0.61829597 0.5745164  0.62113696]
 [0.5894812  0.5551746  0.49092993 0.5252931  0.587854   0.53320426
  0.5182203  0.5992091  0.50958365 0.60037696]
 [0.69246036 0.59753674 0.5689253  0.56384754 0.5495899  0.61832803
  0.5803497  0.61829597 0.5745164  0.62113696]
 [0.5908085  0.51822037 0.46119246 0.5002783  0.58888733 0.4963187
  0.561206   0.

In [486]:
torch.Tensor(data_0).size()

torch.Size([1, 10, 10])

In [392]:
data_t = torch.Tensor(data_0)
#data_t = torch.squeeze(data_t,0)
data_t

tensor([[[0.6925, 0.5975, 0.5689, 0.5638, 0.5496, 0.6183, 0.5803, 0.6183,
          0.5745, 0.6211],
         [0.5895, 0.5552, 0.4909, 0.5253, 0.5879, 0.5332, 0.5182, 0.5992,
          0.5096, 0.6004],
         [0.5887, 0.4909, 0.5830, 0.4638, 0.5920, 0.4612, 0.4612, 0.5954,
          0.4570, 0.5938],
         [0.6056, 0.5253, 0.4638, 0.5863, 0.5896, 0.5021, 0.5003, 0.5985,
          0.5033, 0.5973],
         [0.6093, 0.5879, 0.4866, 0.5216, 0.5842, 0.4851, 0.5889, 0.4955,
          0.5828, 0.4995],
         [0.6183, 0.5927, 0.4612, 0.5021, 0.4851, 0.6098, 0.5904, 0.4795,
          0.5869, 0.4813],
         [0.5908, 0.5182, 0.4612, 0.5003, 0.5889, 0.4963, 0.5612, 0.5943,
          0.4920, 0.5937],
         [0.6183, 0.5992, 0.4789, 0.5985, 0.4955, 0.4795, 0.5943, 0.6245,
          0.5096, 0.4886],
         [0.5745, 0.5096, 0.4570, 0.5033, 0.5828, 0.4895, 0.4920, 0.5887,
          0.5568, 0.5877],
         [0.6211, 0.6004, 0.4843, 0.5387, 0.4995, 0.4813, 0.5937, 0.4886,
          0.5877,

In [487]:
data = Variable(torch.FloatTensor(data_t[:,0,:].cpu().detach().numpy()))
data

tensor([[0.6925, 0.5975, 0.5689, 0.5638, 0.5496, 0.6183, 0.5803, 0.6183, 0.5745,
         0.6211]])

In [488]:
input_size=10
emb_size=32
emb = nn.Embedding(input_size, emb_size)

In [489]:
emb(data.long()).size()

torch.Size([1, 10, 32])

In [490]:
hidden_size=512
enc = nn.GRU(emb_size, hidden_size, batch_first=True)

In [491]:
input = emb(data.long())
encoder_states, hc = enc(input)
encoder_states = encoder_states.transpose(1, 0)

In [492]:
encoder_states.size()

torch.Size([10, 1, 512])

In [493]:
dec = nn.GRUCell(emb_size, hidden_size)

In [494]:
batch_size=1
decoder_input = Variable(torch.zeros(batch_size, emb_size)) # (bs, embd_size)
hidden = Variable(torch.zeros([batch_size, hidden_size]))   # (bs, h)
cell_state = encoder_states[-1]               

In [518]:
probs = []

In [519]:
answer_seq_len = 10
W1 = nn.Linear(hidden_size, weight_size, bias=False)
W2 = nn.Linear(hidden_size, weight_size, bias=False)
vt = nn.Linear(weight_size, 1, bias=False)

for i in range(answer_seq_len): # range(M)
    hidden = dec(decoder_input, hidden) # (bs, h), (bs, h)
    # Compute blended representation at each decoder time step
    blend1 = W1(encoder_states)          # (L, bs, W)
    blend2 = W2(hidden)                  # (bs, W)
    blend_sum = F.tanh(blend1 + blend2)    # (L, bs, W)
    out = vt(blend_sum).squeeze(2)      # (L, bs)
    out = F.softmax(out.transpose(0, 1).contiguous(), -1) # (bs, L)
    probs.append(out)
probs = torch.stack(probs, dim=1)

In [520]:
_v, indices = torch.max(probs, 2)

In [523]:
indices

tensor([[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])

In [60]:
total_size = 1000
weight_size = 256
emb_size = 32
batch_size = 25
n_epochs = 1

input_seq_len = 20
inp_size = input_seq_len
input, targets = make_seq_data(total_size, input_seq_len)

In [76]:
input = Variable(torch.LongTensor(input))     # (N, L)
targets = Variable(torch.LongTensor(targets)) # (N, L)

data_split = (int)(total_size * 0.9)
train_X = input[:data_split]
train_Y = targets[:data_split]
test_X = input[data_split:]
test_Y = targets[data_split:]

In [77]:
print(input.size(), targets.size())

torch.Size([1000, 20]) torch.Size([1000, 20])


In [78]:
model = PointerNetwork(inp_size, emb_size, weight_size, input_seq_len)

In [80]:
from torch import optim

def train(model, X, Y, batch_size, n_epochs):
    model.train()
    optimizer = optim.Adam(model.parameters())
    N = X.size(0)
    L = X.size(1)
    # M = Y.size(1)
    for epoch in range(n_epochs + 1):
        # for i in range(len(train_batches))
        for i in range(0, N-batch_size, batch_size):
            x = X[i:i+batch_size] # (bs, L)
            y = Y[i:i+batch_size] # (bs, M)

            probs = model(x) # (bs, M, L)
            outputs = probs.view(-1, L) # (bs*M, L)
            # outputs = probs.view(L, -1).t().contiguous() # (bs*M, L)
            y = y.view(-1) # (bs*M)
            loss = F.nll_loss(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [81]:
train(model, train_X, train_Y, batch_size, n_epochs)

In [66]:
X,Y = input,targets

N = X.size(0)
L = X.size(1)

In [67]:
X[(int)(total_size * 0.9):].size()

torch.Size([100, 20])

In [68]:
probs = model(X[(int)(total_size * 0.9):])
probs.size()

torch.Size([100, 20, 20])

In [69]:
_v, indices = torch.max(probs, 2)
_v.size(), indices.size()

(torch.Size([100, 20]), torch.Size([100, 20]))

Start working with probabilities as input and return a path with Pointer Network

In [70]:
y = Y[(int)(total_size * 0.99):]

In [71]:
sum([1 if torch.equal(ind.data, y.data) else 0 for ind, y in zip(indices, y)])

0

In [72]:
indices[4]

tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
        19, 19])

In [73]:
_v[4]

tensor([-2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504,
        -2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504, -2.9504,
        -2.9504, -2.9504, -2.9504, -2.9504], grad_fn=<SelectBackward>)

In [82]:
def test(model, X, Y):
    probs = model(X) # (bs, M, L)
    _v, indices = torch.max(probs, 2) # (bs, M)
    # show test examples
    for i in range(len(indices)):
        print('-----')
        print('test', [v for v in X[i].data])
        print('label', [v for v in Y[i].data])
        print('pred', [v for v in indices[i].data])
        if i>4: break
    correct_count = sum([1 if torch.equal(ind.data, y.data) else 0 for ind, y in zip(indices, Y)])
    print('Acc: {:.2f}% ({}/{})'.format(correct_count/len(X)*100, correct_count, len(X)))

In [83]:
test(model,test_X,test_Y)

-----
test [tensor(12), tensor(1), tensor(19), tensor(11), tensor(15), tensor(9), tensor(5), tensor(4), tensor(0), tensor(3), tensor(17), tensor(8), tensor(16), tensor(18), tensor(13), tensor(14), tensor(7), tensor(2), tensor(10), tensor(6)]
label [tensor(8), tensor(1), tensor(17), tensor(9), tensor(7), tensor(6), tensor(19), tensor(16), tensor(11), tensor(5), tensor(18), tensor(3), tensor(0), tensor(14), tensor(15), tensor(4), tensor(12), tensor(10), tensor(13), tensor(2)]
pred [tensor(8), tensor(8), tensor(8), tensor(7), tensor(7), tensor(6), tensor(19), tensor(16), tensor(11), tensor(5), tensor(18), tensor(3), tensor(3), tensor(15), tensor(12), tensor(4), tensor(2), tensor(2), tensor(2), tensor(2)]
-----
test [tensor(19), tensor(3), tensor(7), tensor(8), tensor(16), tensor(6), tensor(15), tensor(11), tensor(0), tensor(12), tensor(4), tensor(13), tensor(1), tensor(14), tensor(17), tensor(10), tensor(2), tensor(9), tensor(18), tensor(5)]
label [tensor(8), tensor(12), tensor(16), tenso