In [8]:
# Reference: https://github.com/guacomolia/ptr_net
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import ptr
import generate_data

In [9]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

In [10]:
total_size = 10000
dataset, starts, ends = generate_data.generate_set_seq(total_size)
targets = np.vstack((starts, ends)).T                                   # [total_size, M]
dataset = np.array(dataset)                                             # [total_size, L]

weight_size = 256                           # W
emb_size = 32
batch_size = 250                            # B
n_batches = total_size // batch_size        # NB
max_len = 2                                 # M = 2
n_epochs = 100                              # NE
seq_len = dataset.shape[1]
inp_size = 11                               # I

# Convert to torch tensors
input = to_var(torch.LongTensor(dataset))     # [total_size, L]
targets = to_var(torch.LongTensor(targets))   # [total_size, 2]

train_batches = input.view(n_batches, batch_size, seq_len)              # [NB, B, L]
targets = targets.view(n_batches, batch_size, 2)                        # [NB, B, 2]

In [63]:
def train(n_epochs, model, train_batches, targets):
    model.train()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(n_epochs + 1):
        for i in range(len(train_batches)):
            input = train_batches[i] # [B, L, 2]
            target = targets[i] # [B, M]

            optimizer.zero_grad()

            L = input.data.shape[1]
            probs = model(input)
            probs = probs.view(L, -1).t().contiguous()

            target = target.view(-1)
            loss = F.nll_loss(probs, target)
            loss.backward()
            optimizer.step()

            pick = np.random.randint(0, batch_size)
        if epoch % 2 == 0:
            print('epoch: {}\t\t -- loss: {:.5f}'.format(epoch, loss.data[0]))

            print("trained ", probs.max(1)[1].data[pick], probs.max(1)[1].data[2*pick],
                  "target : ", target.data[pick], target.data[2*pick])
    return 0

In [64]:
class PointerNetwork(nn.Module):
    def __init__(self, input_size, emb_size, weight_size, batch_size, seq_len, max_len, hidden_size=512):
        super(PointerNetwork, self).__init__()
        """ Pointer network implementation. Refer to http://papers.nips.cc/paper/5866-pointer-networks.pdf
        https://github.com/ikostrikov/TensorFlow-Pointer-Networks
        Outputs links to the elements of an input sequence"""

        # Initialization of sizes
        self.batch_size = batch_size    # B
        self.hidden_size = hidden_size  # H
        self.input_size = input_size    # I
        self.max_len = max_len          # M (for decoder)
        self.weight_size = weight_size  # W (size of dec = size of enc)
        self.seq_len = seq_len          # L (length of input sequence)
        self.emb_size = emb_size        # E


        # Initialization of layers (paper's notations preserved)
        self.enc = nn.LSTM(emb_size, hidden_size, batch_first=True)           # encoding
        self.dec = nn.LSTMCell(emb_size, hidden_size)       # decoding
        self.W1 = nn.Linear(hidden_size, weight_size)       # blending encoder
        self.W2 = nn.Linear(hidden_size, weight_size)       # blending decoder
        self.vt = nn.Linear(weight_size, 1)                # scaling sum of enc and dec by v.T
        self.emb = nn.Embedding(input_size, emb_size)       # embed inputs

        # Functions to be used
        self.tanh = nn.Tanh()

    def forward(self, input):
        input = self.emb(input) # (N, L, embd_size)
        # Encoding
        encoder_states, hc = self.enc(input) # encoder_state: (B, L, H)
        encoder_states = encoder_states.transpose(1, 0) # (L, B, H)

        # Decoding states initialization
        decoder_input = to_var(torch.Tensor(self.batch_size, self.emb_size).zero_())      # [B,E]
        hidden = to_var(torch.randn([self.batch_size, self.hidden_size]))                 # [B,H]
        cell_state = encoder_states[-1]                                                   # [B,H]

        probs = []
        # Decoding
        for i in range(self.max_len): # range(M)
            hidden, cell_state = self.dec(decoder_input, (hidden, cell_state)) # (B, H), (B, H)

            # Compute blended representation at each decoder time step
            blend1 = self.W1(encoder_states)        # [L,B,W]
            blend2 = self.W2(hidden)                # [B,W]
            blend_sum = self.tanh(blend1 + blend2)  # [L,B,W]
            out = self.vt(blend_sum)                # [L,B, 1]
            out = torch.squeeze(out)                # [L,B]
            probs.append(out)
        probs = torch.stack(probs, dim=2)           # [L,B,M]
        return F.log_softmax(probs)
    
model = PointerNetwork(inp_size, emb_size, weight_size, batch_size, seq_len, max_len)
model.cuda()
train(10, model, train_batches, targets)

epoch: 0		 -- loss: 0.30443
trained  16 8 target :  16 8
epoch: 2		 -- loss: 0.00182
trained  12 10 target :  12 10


KeyboardInterrupt: 