In [1]:
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import numpy as np

import warnings
warnings.filterwarnings("ignore")

from utils.datasets import AlphabetSortingDataset, NumberSortingDataset
from models.pointer_net import PointerNet
from torch.utils.data import DataLoader

In [2]:
params = {
    # Data
    'magnitude': 5,
    'batch_size': 128,
    'shuffle': True,
    'nof_workers': 0, # must stay at 0
    #Train
    'nof_epoch': 3,
    'lr': 0.001,
    # GPU
    'gpu': True,
    # Network
    'input_size': 300,
    'embedding_size': 300,
    'hiddens': 256,
    'nof_lstms': 2,
    'dropout': 0,
    'bidir': True
}

In [3]:
LEN = 10
#dataset = NumberSortingDataset(10**params['magnitude'], min_len=LEN, max_len=LEN)
#dataset = AlphabetSortingDataset(10**params['magnitude'], min_len=LEN, max_len=LEN, alphabet='0123456789')
dataset = AlphabetSortingDataset(10**params['magnitude'], min_len=LEN, max_len=LEN)
dataloader = DataLoader(dataset,
                        batch_size=params['batch_size'],
                        shuffle=params['shuffle'])

In [4]:
model = PointerNet(params['input_size'],
                   params['embedding_size'],
                   params['hiddens'],
                   params['nof_lstms'],
                   params['dropout'],
                   params['bidir'])

if params['gpu'] and torch.cuda.is_available():
    model.cuda()
    net = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

CCE = torch.nn.CrossEntropyLoss()
model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                model.parameters()),
                                 lr=params['lr'])

In [5]:
from tqdm import tqdm
losses = []

model.train()
for i_epoch, epoch in enumerate(range(params['nof_epoch'])):
    batch_loss = []
    iterator = tqdm(dataloader, unit='Batch')
    
    for i_batch, sample_batched in enumerate(iterator):
        iterator.set_description('Epoch %i/%i' % (epoch+1, params['nof_epoch']))

        x, y, _ = sample_batched
        train_batch = Variable(x).float()
        target_batch = Variable(y)

        if torch.cuda.is_available():
            train_batch = train_batch.cuda()
            target_batch = target_batch.cuda()

        o, p = model(train_batch)
        o = o.contiguous().view(-1, o.size()[-1])
        target_batch = target_batch.view(-1)
        
        loss = CCE(o, target_batch) #/ target_batch.shape[1] # need to take the length of the table into account
        #acc = get_accuracy(p, target_batch)
        
        losses.append(loss.data)
        batch_loss.append(loss.data)

        model_optim.zero_grad()
        loss.backward()
        model_optim.step()
        
        iterator.set_postfix(loss='{}'.format(loss.data))
        
    # each epoch, reduce the learning rate
    for param in model_optim.param_groups:
            param['lr'] *= 0.95
            
    batch_loss = torch.Tensor(batch_loss)
    iterator.set_postfix(loss=np.average(batch_loss))

Epoch 1/3: 100%|██████████████████████████████████████████| 782/782 [01:49<00:00,  7.14Batch/s, loss=1.461153268814087]
Epoch 2/3: 100%|██████████████████████████████████████████| 782/782 [01:49<00:00,  7.16Batch/s, loss=1.461151123046875]
Epoch 3/3:  26%|███████████                               | 207/782 [00:28<01:21,  7.10Batch/s, loss=1.461151361465454]

KeyboardInterrupt: 

In [6]:
model.eval()

num_samples = 100
x, y, z = dataset[:num_samples]
x = x.cuda().float()
y = y.cuda()

o, p = model(x)

y_pred = []
for pointers, values in zip(p, z):
    goal = []
    for point in pointers.cpu().numpy():
        goal.append(values[point])
    y_pred.append(goal)
    
y_true = []
for pointers, values in zip(y, z):
    goal = []
    for point in pointers.cpu().numpy():
        goal.append(values[point])
    y_true.append(goal)

In [7]:
correct = 0
for seq_pred, seq_true in zip(y_pred, y_true):
    correct += all([pred == true for (pred, true) in zip(seq_pred, seq_true)])
    print(seq_pred)
    print(seq_true)
    print("-" * 60)
acc = correct / num_samples
acc

['a', 'b', 'c', 'g', 'm', 'q', 'r', 't', 'x', 'y']
['a', 'b', 'c', 'g', 'm', 'q', 'r', 't', 'x', 'y']
------------------------------------------------------------
['e', 'f', 'h', 'i', 'j', 'p', 'r', 's', 'x', 'y']
['e', 'f', 'h', 'i', 'j', 'p', 'r', 's', 'x', 'y']
------------------------------------------------------------
['c', 'e', 'f', 'i', 'k', 'o', 's', 'x', 'y', 'z']
['c', 'e', 'f', 'i', 'k', 'o', 's', 'x', 'y', 'z']
------------------------------------------------------------
['g', 'j', 'm', 'n', 'q', 'r', 's', 'u', 'w', 'z']
['g', 'j', 'm', 'n', 'q', 'r', 's', 'u', 'w', 'z']
------------------------------------------------------------
['d', 'f', 'g', 'l', 'o', 'q', 'r', 's', 't', 'u']
['d', 'f', 'g', 'l', 'o', 'q', 'r', 's', 't', 'u']
------------------------------------------------------------
['b', 'f', 'h', 'm', 'o', 'q', 's', 'v', 'w', 'x']
['b', 'f', 'h', 'm', 'o', 'q', 's', 'v', 'w', 'x']
------------------------------------------------------------
['c', 'f', 'g', 'h', '

1.0