In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Reference: https://github.com/guacomolia/ptr_net
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import generate_data
from utils import to_var
import random

total_size = 10000
weight_size = 256                           # W
emb_size = 32
batch_size = 250                            # B
n_batches = total_size // batch_size        # NB
answer_seq_len = 2                          # M = 2
n_epochs = 100                              # NE

In [None]:
# dataset, starts, ends = generate_data.generate_set_seq(total_size)
# targets = np.vstack((starts, ends)).T                                   # [total_size, M]
# dataset = np.array(dataset)                                             # [total_size, L]

# input_seq_len = dataset.shape[1]
# inp_size = 11                               # I

# # Convert to torch tensors
# input = to_var(torch.LongTensor(dataset))     # [total_size, L]
# targets = to_var(torch.LongTensor(targets))   # [total_size, 2]

# train_batches = input.view(n_batches, batch_size, input_seq_len)              # [NB, B, L]
# targets = targets.view(n_batches, batch_size, answer_seq_len)                        # [NB, B, 2]

In [28]:
def make_seq_data(n_samples, seq_len=3, max_val=5):
    data, labels = [], []
    for _ in range(n_samples):
        input = [random.randint(0, max_val-1) for _ in range(seq_len)]
        target = sorted(range(len(input)), key=lambda k: input[k])
        data.append(input)
        labels.append(target)
    return data, labels

input_seq_len = 3
max_val = 4
input, targets = make_seq_data(total_size, input_seq_len, max_val)
input = to_var(torch.LongTensor(input))
targets = to_var(torch.LongTensor(targets))
print(len(targets))
train_batches = input.view(n_batches, batch_size, input_seq_len)
targets = targets.view(n_batches, batch_size, input_seq_len)

10000


In [29]:
# input, targets = make_seq_data(total_size, input_seq_len, max_val)
# print(input[0])
# print(targets[0])

In [30]:
# from pointer_network import PointerNetwork
def train(n_epochs, model, train_batches, targets):
    model.train()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(n_epochs + 1):
        for i in range(len(train_batches)):
            input = train_batches[i] # [B, L, 2]
            target = targets[i] # [B, M]

            optimizer.zero_grad()

            L = input.data.shape[1]
            probs = model(input) # (L, M, N)
            probs = probs.view(L, -1).t().contiguous() # (N*M, L)
            target = target.view(-1) # (N*M)
            loss = F.nll_loss(probs, target)
            loss.backward()
            optimizer.step()

        pick = np.random.randint(0, batch_size)
        if epoch % 2 == 0:
            print('epoch: {}\t\t -- loss: {:.5f}'.format(epoch, loss.data[0]))
            print("trained ", probs.max(1)[1].data[pick], probs.max(1)[1].data[2*pick],
                  "target : ", target.data[pick], target.data[2*pick])

# model = PointerNetwork(inp_size, emb_size, weight_size, batch_size, input_seq_len, answer_seq_len)
model = PointerNetwork(max_val, emb_size, weight_size, batch_size, input_seq_len, input_seq_len)
if torch.cuda.is_available():
    model.cuda()
train(20, model, train_batches, targets)

epoch: 0		 -- loss: 0.52519
trained  1 1 target :  1 1
epoch: 2		 -- loss: 0.01987
trained  1 0 target :  1 0
epoch: 4		 -- loss: 0.00183
trained  0 2 target :  0 2
epoch: 6		 -- loss: 0.00099
trained  2 1 target :  2 1
epoch: 8		 -- loss: 0.00043
trained  0 0 target :  0 0
epoch: 10		 -- loss: 0.00997
trained  2 0 target :  2 0
epoch: 12		 -- loss: 0.00047
trained  0 2 target :  0 2
epoch: 14		 -- loss: 0.00015
trained  1 1 target :  1 1
epoch: 16		 -- loss: 0.00011
trained  2 1 target :  2 1
epoch: 18		 -- loss: 0.00008
trained  2 1 target :  2 1
epoch: 20		 -- loss: 0.00006
trained  2 0 target :  2 0


In [38]:
def predict(model, data):
    outputs = model(data)
    outputs = outputs.view(data.data.shape[1], -1).t().contiguous()
    indices = outputs.max(1)[1].data
    return indices

# Predictions
test_id = random.randint(0, batch_size-1)
test_data = train_batches[0]
test_targets = targets[0]
outputs = predict(model, test_data)
for i in range(len(outputs)):
    print('-----')
    print('test', [v for v in test_data[i].data])
    print('label', [v for v in test_targets[i].data])
    print('pred', [v for v in outputs[i:i+input_seq_len]])
    if i>20: break

-----
test [3, 3, 1]
label [2, 0, 1]
pred [2, 0, 1]
-----
test [3, 3, 0]
label [2, 0, 1]
pred [0, 1, 2]
-----
test [0, 1, 0]
label [0, 2, 1]
pred [1, 2, 0]
-----
test [3, 3, 3]
label [0, 1, 2]
pred [2, 0, 1]
-----
test [0, 3, 0]
label [0, 2, 1]
pred [0, 1, 0]
-----
test [1, 2, 0]
label [2, 0, 1]
pred [1, 0, 2]
-----
test [2, 2, 3]
label [0, 1, 2]
pred [0, 2, 1]
-----
test [2, 1, 2]
label [1, 0, 2]
pred [2, 1, 0]
-----
test [2, 1, 3]
label [1, 0, 2]
pred [1, 0, 1]
-----
test [0, 3, 2]
label [0, 2, 1]
pred [0, 1, 2]
-----
test [0, 2, 0]
label [0, 2, 1]
pred [1, 2, 0]
-----
test [2, 0, 0]
label [1, 2, 0]
pred [2, 0, 2]
-----
test [1, 2, 1]
label [0, 2, 1]
pred [0, 2, 1]
-----
test [0, 2, 0]
label [0, 2, 1]
pred [2, 1, 2]
-----
test [0, 1, 0]
label [0, 2, 1]
pred [1, 2, 0]
-----
test [2, 1, 2]
label [1, 0, 2]
pred [2, 0, 1]
-----
test [0, 0, 1]
label [0, 1, 2]
pred [0, 1, 0]
-----
test [2, 2, 0]
label [2, 0, 1]
pred [1, 0, 1]
-----
test [0, 0, 1]
label [0, 1, 2]
pred [0, 1, 2]
-----
test [