In [1]:
from utils.datasets import AlphabetSortingDataset
from torch.utils.data import DataLoader

In [2]:
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np

In [3]:
import warnings
warnings.filterwarnings("ignore")

params = {
    # Data
    'batch_size': 512,
    'shuffle': True,
    'nof_workers': 0, # must stay at 0
    #Train
    'nof_epoch': 1000,
    'lr': 0.0001,
    # GPU
    'gpu': True,
    # Network
    'embedding_size': 300,
    'hiddens': 512,
    'nof_lstms': 8,
    'dropout': 0.3,
    'bidir': False # True not working right now
}

dataset = AlphabetSortingDataset(100000, min_len=4, max_len=4)
dataloader = DataLoader(dataset,
                        batch_size=params['batch_size'],
                        shuffle=params['shuffle'])

In [4]:
if params['gpu'] and torch.cuda.is_available():
    USE_CUDA = True
    print('Using GPU, %i devices.' % torch.cuda.device_count())
else:
    USE_CUDA = False

Using GPU, 1 devices.


In [5]:
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F


class Encoder(nn.Module):
    """
    Encoder class for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 n_layers,
                 dropout,
                 bidir):
        """
        Initiate Encoder
        :param Tensor embedding_dim: Number of embbeding channels
        :param int hidden_dim: Number of hidden units for the LSTM
        :param int n_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param bool bidir: Bidirectional
        """

        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim//2 if bidir else hidden_dim
        self.n_layers = n_layers*2 if bidir else n_layers
        self.bidir = bidir
        self.lstm = nn.LSTM(embedding_dim,
                            self.hidden_dim,
                            n_layers,
                            dropout=dropout,
                            bidirectional=bidir)

        # Used for propagating .cuda() command
        self.h0 = Parameter(torch.zeros(1), requires_grad=False)
        self.c0 = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                hidden):
        """
        Encoder - Forward-pass
        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net
        :param Tensor hidden: Initiated hidden units for the LSTMs (h, c)
        :return: LSTMs outputs and hidden units (h, c)
        """

        embedded_inputs = embedded_inputs.permute(1, 0, 2)

        outputs, hidden = self.lstm(embedded_inputs, hidden)

        return outputs.permute(1, 0, 2), hidden

    def init_hidden(self, embedded_inputs):
        """
        Initiate hidden units
        :param Tensor embedded_inputs: The embedded input of Pointer-NEt
        :return: Initiated hidden units for the LSTMs (h, c)
        """

        batch_size = embedded_inputs.size(0)

        # Reshaping (Expanding)
        h0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)
        c0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)

        return h0, c0


class Attention(nn.Module):
    """
    Attention model for Pointer-Net
    """

    def __init__(self, input_dim,
                 hidden_dim):
        """
        Initiate Attention
        :param int input_dim: Input's diamention
        :param int hidden_dim: Number of hidden units in the attention
        """

        super(Attention, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.context_linear = nn.Conv1d(input_dim, hidden_dim, 1, 1)
        self.V = Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
        self._inf = Parameter(torch.FloatTensor([float('-inf')]), requires_grad=False)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

        # Initialize vector V
        nn.init.uniform(self.V, -1, 1)

    def forward(self, input,
                context,
                mask):
        """
        Attention - Forward-pass
        :param Tensor input: Hidden state h
        :param Tensor context: Attention context
        :param ByteTensor mask: Selection mask
        :return: tuple of - (Attentioned hidden state, Alphas)
        """

        # (batch, hidden_dim, seq_len)
        inp = self.input_linear(input).unsqueeze(2).expand(-1, -1, context.size(1))

        # (batch, hidden_dim, seq_len)
        context = context.permute(0, 2, 1)
        ctx = self.context_linear(context)

        # (batch, 1, hidden_dim)
        V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)

        # (batch, seq_len)
        att = torch.bmm(V, self.tanh(inp + ctx)).squeeze(1)
        if len(att[mask]) > 0:
            att[mask] = self.inf[mask]
        alpha = self.softmax(att)

        hidden_state = torch.bmm(ctx, alpha.unsqueeze(2)).squeeze(2)

        return hidden_state, alpha

    def init_inf(self, mask_size):
        self.inf = self._inf.unsqueeze(1).expand(*mask_size)


class Decoder(nn.Module):
    """
    Decoder model for Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim):
        """
        Initiate Decoder
        :param int embedding_dim: Number of embeddings in Pointer-Net
        :param int hidden_dim: Number of hidden units for the decoder's RNN
        """

        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.input_to_hidden = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.hidden_to_hidden = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.hidden_out = nn.Linear(hidden_dim * 2, hidden_dim)
        self.att = Attention(hidden_dim, hidden_dim)

        # Used for propagating .cuda() command
        self.mask = Parameter(torch.ones(1), requires_grad=False)
        self.runner = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                decoder_input,
                hidden,
                context):
        """
        Decoder - Forward-pass
        :param Tensor embedded_inputs: Embedded inputs of Pointer-Net
        :param Tensor decoder_input: First decoder's input
        :param Tensor hidden: First decoder's hidden states
        :param Tensor context: Encoder's outputs
        :return: (Output probabilities, Pointers indices), last hidden state
        """

        batch_size = embedded_inputs.size(0)
        input_length = embedded_inputs.size(1)

        # (batch, seq_len)
        mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1)
        self.att.init_inf(mask.size())

        # Generating arang(input_length), broadcasted across batch_size
        runner = self.runner.repeat(input_length)
        for i in range(input_length):
            runner.data[i] = i
        runner = runner.unsqueeze(0).expand(batch_size, -1).long()

        outputs = []
        pointers = []

        def step(x, hidden):
            """
            Recurrence step function
            :param Tensor x: Input at time t
            :param tuple(Tensor, Tensor) hidden: Hidden states at time t-1
            :return: Hidden states at time t (h, c), Attention probabilities (Alpha)
            """

            # Regular LSTM
            h, c = hidden

            gates = self.input_to_hidden(x) + self.hidden_to_hidden(h)
            input, forget, cell, out = gates.chunk(4, 1)

            input = F.sigmoid(input)
            forget = F.sigmoid(forget)
            cell = F.tanh(cell)
            out = F.sigmoid(out)

            c_t = (forget * c) + (input * cell)
            h_t = out * F.tanh(c_t)

            # Attention section
            hidden_t, output = self.att(h_t, context, torch.eq(mask, 0))
            hidden_t = F.tanh(self.hidden_out(torch.cat((hidden_t, h_t), 1)))

            return hidden_t, c_t, output

        # Recurrence loop
        for _ in range(input_length):
            h_t, c_t, outs = step(decoder_input, hidden)
            hidden = (h_t, c_t)

            # Masking selected inputs
            masked_outs = outs * mask

            # Get maximum probabilities and indices
            max_probs, indices = masked_outs.max(1)
            one_hot_pointers = (runner == indices.unsqueeze(1).expand(-1, outs.size()[1])).float()

            # Update mask to ignore seen indices
            mask  = mask * (1 - one_hot_pointers)

            # Get embedded inputs by max indices
            embedding_mask = one_hot_pointers.unsqueeze(2).expand(-1, -1, self.embedding_dim).byte()
            decoder_input = embedded_inputs[embedding_mask.data].view(batch_size, self.embedding_dim)

            outputs.append(outs.unsqueeze(0))
            pointers.append(indices.unsqueeze(1))

        outputs = torch.cat(outputs).permute(1, 0, 2)
        pointers = torch.cat(pointers, 1)

        return (outputs, pointers), hidden


class PointerNet(nn.Module):
    """
    Pointer-Net
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 lstm_layers,
                 dropout,
                 bidir=False):
        """
        Initiate Pointer-Net
        :param int embedding_dim: Number of embbeding channels
        :param int hidden_dim: Encoders hidden units
        :param int lstm_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param bool bidir: Bidirectional
        """

        super(PointerNet, self).__init__()
        self.embedding_dim = embedding_dim
        self.bidir = bidir
        self.embedding = nn.Linear(embedding_dim, embedding_dim) #nn.Linear(2, embedding_dim)
        self.encoder = Encoder(embedding_dim,
                               hidden_dim,
                               lstm_layers,
                               dropout,
                               bidir)
        self.decoder = Decoder(embedding_dim, hidden_dim)
        self.decoder_input0 = Parameter(torch.FloatTensor(embedding_dim), requires_grad=False)

        # Initialize decoder_input0
        nn.init.uniform(self.decoder_input0, -1, 1)

    def forward(self, inputs):
        """
        PointerNet - Forward-pass
        :param Tensor inputs: Input sequence
        :return: Pointers probabilities and indices
        """

        batch_size = inputs.size(0)
        input_length = inputs.size(1)

        decoder_input0 = self.decoder_input0.unsqueeze(0).expand(batch_size, -1)

        inputs = inputs.view(batch_size * input_length, -1)
        embedded_inputs = self.embedding(inputs).view(batch_size, input_length, -1)

        encoder_hidden0 = self.encoder.init_hidden(embedded_inputs)
        encoder_outputs, encoder_hidden = self.encoder(embedded_inputs,
                                                       encoder_hidden0)
        if self.bidir:
            decoder_hidden0 = (torch.cat(encoder_hidden[0][-2:], dim=-1),
                               torch.cat(encoder_hidden[1][-2:], dim=-1))
        else:
            decoder_hidden0 = (encoder_hidden[0][-1],
                               encoder_hidden[1][-1])
        (outputs, pointers), decoder_hidden = self.decoder(embedded_inputs,
                                                           decoder_input0,
                                                           decoder_hidden0,
                                                           encoder_outputs)

        return  outputs, pointers

In [6]:
model = PointerNet(params['embedding_size'],
                   params['hiddens'],
                   params['nof_lstms'],
                   params['dropout'],
                   params['bidir'])

if USE_CUDA:
    model.cuda()
    net = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

CCE = torch.nn.CrossEntropyLoss()
model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                model.parameters()),
                                 lr=params['lr'])


In [7]:
from tqdm import tqdm
losses = []

for i_epoch, epoch in enumerate(range(params['nof_epoch'])):
    batch_loss = []
    iterator = tqdm(dataloader, unit='Batch')
    
    for i_batch, sample_batched in enumerate(iterator):
        iterator.set_description('Epoch %i/%i' % (epoch+1, params['nof_epoch']))

        x, y, chars = sample_batched
        train_batch = Variable(x)
        target_batch = Variable(y)

        if USE_CUDA:
            train_batch = train_batch.cuda()
            target_batch = target_batch.cuda()

        o, p = model(train_batch)
        o = o.contiguous().view(-1, o.size()[-1])
        target_batch = target_batch.view(-1)
        
        loss = CCE(o, target_batch) #/ target_batch.shape[1] # need to take the length of the table into account
        #acc = get_accuracy(p, target_batch)
        
        losses.append(loss.data)
        batch_loss.append(loss.data)

        model_optim.zero_grad()
        loss.backward()
        model_optim.step()
        
        iterator.set_postfix(loss='{}'.format(loss.data))
    batch_loss = torch.Tensor(batch_loss)
    iterator.set_postfix(loss=np.average(batch_loss))

Epoch 1/1000: 100%|██████████████████████████████████████| 196/196 [00:22<00:00,  8.86Batch/s, loss=1.3025537729263306]
Epoch 2/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.15Batch/s, loss=1.1566112041473389]
Epoch 3/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.19Batch/s, loss=1.1035975217819214]
Epoch 4/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.27Batch/s, loss=1.0654791593551636]
Epoch 5/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.29Batch/s, loss=1.0273122787475586]
Epoch 6/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.06Batch/s, loss=0.8592767715454102]
Epoch 7/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.17Batch/s, loss=0.8141610026359558]
Epoch 8/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.29Batch/s, loss=0.7907483577728271]
Epoch 9/1000: 100%|█████████████████████

Epoch 69/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  9.04Batch/s, loss=0.7468768954277039]
Epoch 70/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  9.31Batch/s, loss=0.7723820209503174]
Epoch 71/1000: 100%|██████████████████████████████████████| 196/196 [00:21<00:00,  9.19Batch/s, loss=0.762597918510437]
Epoch 72/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  9.08Batch/s, loss=0.7655274868011475]
Epoch 73/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  8.98Batch/s, loss=0.7778753042221069]
Epoch 74/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  8.97Batch/s, loss=0.7775134444236755]
Epoch 75/1000: 100%|███████████████████████████████████████| 196/196 [00:21<00:00,  8.91Batch/s, loss=0.78975909948349]
Epoch 76/1000: 100%|███████████████████████████████████████| 196/196 [00:21<00:00,  9.13Batch/s, loss=0.77491295337677]
Epoch 77/1000: 100%|████████████████████

Epoch 137/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.07Batch/s, loss=0.7796066403388977]
Epoch 138/1000: 100%|█████████████████████████████████████| 196/196 [00:21<00:00,  9.12Batch/s, loss=0.762412428855896]
Epoch 139/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.09Batch/s, loss=0.7749530076980591]
Epoch 140/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.03Batch/s, loss=0.7608559131622314]
Epoch 141/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.05Batch/s, loss=0.7877832651138306]
Epoch 142/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.15Batch/s, loss=0.7655674815177917]
Epoch 143/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.04Batch/s, loss=0.7533819675445557]
Epoch 144/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.15Batch/s, loss=0.7733631730079651]
Epoch 145/1000: 100%|███████████████████

Epoch 205/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.11Batch/s, loss=0.7592939734458923]
Epoch 206/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.04Batch/s, loss=0.7752271890640259]
Epoch 207/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.16Batch/s, loss=0.7764835953712463]
Epoch 208/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.11Batch/s, loss=0.7782772183418274]
Epoch 209/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.15Batch/s, loss=0.7592937350273132]
Epoch 210/1000: 100%|█████████████████████████████████████| 196/196 [00:20<00:00,  9.34Batch/s, loss=0.759919285774231]
Epoch 211/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.05Batch/s, loss=0.7686721682548523]
Epoch 212/1000: 100%|████████████████████████████████████| 196/196 [00:21<00:00,  9.09Batch/s, loss=0.7561687231063843]
Epoch 213/1000: 100%|███████████████████

KeyboardInterrupt: 

In [8]:
model.eval()

PointerNet(
  (embedding): Linear(in_features=300, out_features=300, bias=True)
  (encoder): Encoder(
    (lstm): LSTM(300, 512, num_layers=8, dropout=0.3)
  )
  (decoder): Decoder(
    (input_to_hidden): Linear(in_features=300, out_features=2048, bias=True)
    (hidden_to_hidden): Linear(in_features=512, out_features=2048, bias=True)
    (hidden_out): Linear(in_features=1024, out_features=512, bias=True)
    (att): Attention(
      (input_linear): Linear(in_features=512, out_features=512, bias=True)
      (context_linear): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
      (tanh): Tanh()
      (softmax): Softmax(dim=None)
    )
  )
)

In [14]:
x, y, z = dataset[:100]
x = x.cuda()
y = y.cuda()

In [15]:
o, p = model(x)

In [16]:
sorting = list(zip(z, p.data.cpu().tolist()))

In [17]:
sorting

[(['n', 't', 'y', 'q'], [0, 3, 1, 2]),
 (['o', 'q', 's', 'j'], [3, 0, 1, 2]),
 (['q', 'l', 'r', 'x'], [1, 0, 2, 3]),
 (['j', 'e', 'p', 'y'], [1, 0, 2, 3]),
 (['u', 'k', 'o', 'h'], [3, 1, 2, 0]),
 (['y', 'g', 'u', 'x'], [1, 2, 3, 0]),
 (['e', 'n', 'l', 'f'], [0, 3, 2, 1]),
 (['s', 'k', 't', 'v'], [1, 0, 2, 3]),
 (['r', 'r', 'e', 'd'], [3, 2, 0, 1]),
 (['b', 'x', 'a', 'z'], [0, 2, 1, 3]),
 (['g', 't', 'y', 'e'], [3, 0, 1, 2]),
 (['k', 'h', 'p', 'o'], [1, 0, 3, 2]),
 (['v', 'b', 'h', 'w'], [1, 2, 0, 3]),
 (['d', 'f', 'n', 'u'], [0, 1, 2, 3]),
 (['g', 'b', 'j', 'o'], [1, 0, 2, 3]),
 (['e', 'b', 'b', 't'], [1, 2, 0, 3]),
 (['h', 'c', 'i', 'z'], [1, 0, 2, 3]),
 (['y', 's', 'l', 'e'], [3, 2, 1, 0]),
 (['h', 'w', 'x', 'u'], [0, 3, 1, 2]),
 (['h', 's', 'm', 'v'], [0, 2, 1, 3]),
 (['x', 'r', 'k', 'p'], [2, 3, 1, 0]),
 (['d', 'x', 'q', 's'], [0, 2, 3, 1]),
 (['j', 'n', 'r', 'c'], [3, 0, 1, 2]),
 (['l', 'd', 'o', 'l'], [1, 0, 3, 2]),
 (['d', 'o', 'f', 'm'], [0, 2, 3, 1]),
 (['m', 'f', 'c', 'a'], [

In [18]:
for pair in sorting:
    sequence, pointers = pair
    tmp = {}
    for seq, p in zip(sequence, pointers):
        tmp[seq] = p
    tmp =  [k for k, v in sorted(tmp.items(), key=lambda item: item[1])]
    print(tmp, " - ", sorted(sequence))

['n', 'y', 'q', 't']  -  ['n', 'q', 't', 'y']
['q', 's', 'j', 'o']  -  ['j', 'o', 'q', 's']
['l', 'q', 'r', 'x']  -  ['l', 'q', 'r', 'x']
['e', 'j', 'p', 'y']  -  ['e', 'j', 'p', 'y']
['h', 'k', 'o', 'u']  -  ['h', 'k', 'o', 'u']
['x', 'y', 'g', 'u']  -  ['g', 'u', 'x', 'y']
['e', 'f', 'l', 'n']  -  ['e', 'f', 'l', 'n']
['k', 's', 't', 'v']  -  ['k', 's', 't', 'v']
['e', 'd', 'r']  -  ['d', 'e', 'r', 'r']
['b', 'a', 'x', 'z']  -  ['a', 'b', 'x', 'z']
['t', 'y', 'e', 'g']  -  ['e', 'g', 't', 'y']
['h', 'k', 'o', 'p']  -  ['h', 'k', 'o', 'p']
['h', 'v', 'b', 'w']  -  ['b', 'h', 'v', 'w']
['d', 'f', 'n', 'u']  -  ['d', 'f', 'n', 'u']
['b', 'g', 'j', 'o']  -  ['b', 'g', 'j', 'o']
['b', 'e', 't']  -  ['b', 'b', 'e', 't']
['c', 'h', 'i', 'z']  -  ['c', 'h', 'i', 'z']
['e', 'l', 's', 'y']  -  ['e', 'l', 's', 'y']
['h', 'x', 'u', 'w']  -  ['h', 'u', 'w', 'x']
['h', 'm', 's', 'v']  -  ['h', 'm', 's', 'v']
['p', 'k', 'x', 'r']  -  ['k', 'p', 'r', 'x']
['d', 's', 'x', 'q']  -  ['d', 'q', 's', 'x'