# Order matters 
Modifying code from http://nlp.seas.harvard.edu/2018/04/03/attention.html to implement the architechture from https://arxiv.org/pdf/1511.06391.pdf

In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

import sys
sys.path.append('../scripts')
from order_matters import Read, Process, Write, ReadProcessWrite

In [21]:
hidden_dim = 16
#write_hidden_dim = 128
lstm_steps = 5
batch_size = 32
input_dim = 1

read = Read(hidden_dim, input_dim)
process = Process(hidden_dim, hidden_dim, lstm_steps, batch_size)
write = Write(hidden_dim, hidden_dim)

In [22]:
set_size = 5
x = torch.rand(batch_size, input_dim, set_size)
x

tensor([[[0.5166, 0.7594, 0.7064, 0.5585, 0.4379]],

        [[0.1764, 0.1904, 0.9376, 0.0023, 0.4835]],

        [[0.2771, 0.4536, 0.7482, 0.6885, 0.4249]],

        [[0.0790, 0.5032, 0.9270, 0.8757, 0.6437]],

        [[0.1127, 0.0720, 0.0236, 0.9603, 0.0903]],

        [[0.8892, 0.0098, 0.2194, 0.7558, 0.7159]],

        [[0.5756, 0.9800, 0.2571, 0.3288, 0.9987]],

        [[0.4710, 0.1513, 0.4093, 0.1925, 0.0585]],

        [[0.4714, 0.8635, 0.6059, 0.2714, 0.3345]],

        [[0.2199, 0.3989, 0.9071, 0.0848, 0.5998]],

        [[0.9707, 0.5569, 0.7496, 0.5609, 0.4412]],

        [[0.1231, 0.1943, 0.7264, 0.6123, 0.7722]],

        [[0.9680, 0.9313, 0.3200, 0.5120, 0.5995]],

        [[0.1742, 0.9707, 0.4961, 0.9195, 0.7598]],

        [[0.1981, 0.3434, 0.7442, 0.5061, 0.6249]],

        [[0.8896, 0.9216, 0.1101, 0.1940, 0.6588]],

        [[0.3648, 0.3553, 0.9014, 0.4082, 0.7805]],

        [[0.5187, 0.2990, 0.8797, 0.9929, 0.0542]],

        [[0.1717, 0.5700, 0.5858, 0.1278, 0.58

In [23]:
M = read(x)
M.size()

torch.Size([32, 16, 5])

In [24]:
r_t, c_t = process(M)
r_t, c_t

(tensor([[11.7570,  0.0000,  3.5078,  0.0000,  0.0000,  0.0000,  0.0000,  0.7906,
           9.8391,  5.4301,  9.1781,  6.6894,  0.0000,  0.0000,  0.0000,  4.6002],
         [10.9477,  0.0433,  2.5887,  0.0000,  0.0000,  0.1067,  0.0000,  2.9201,
           8.9445,  4.5228,  8.1555,  4.9785,  0.0000,  0.0000,  0.0000,  3.5662],
         [11.4939,  0.0000,  3.2090,  0.0000,  0.0000,  0.0000,  0.0000,  1.3631,
           9.5482,  5.1351,  8.8456,  6.1331,  0.0000,  0.0000,  0.0000,  4.2640],
         [11.7909,  0.0000,  3.5464,  0.0000,  0.0000,  0.0387,  0.0000,  1.2505,
           9.8766,  5.4682,  9.2209,  6.7612,  0.0000,  0.0000,  0.0000,  4.6436],
         [10.5861,  0.0292,  2.1780,  0.0000,  0.0000,  0.1704,  0.0000,  3.8075,
           8.5447,  4.1173,  7.6986,  4.2140,  0.0000,  0.0000,  0.0000,  3.1042],
         [11.4923,  0.0384,  3.2072,  0.0000,  0.0000,  0.1001,  0.0000,  1.7759,
           9.5465,  5.1333,  8.8436,  6.1298,  0.0000,  0.0000,  0.0000,  4.2620],
         [

In [25]:
r_t.size(), c_t.size()

(torch.Size([32, 16]), torch.Size([32, 16]))

In [26]:
decoder_input0 = nn.Parameter(torch.zeros(hidden_dim)).unsqueeze(0).expand(batch_size, -1)
#print('decoder_input0: ', decoder_input0)
decoder_hidden0 = (r_t, c_t)
outputs, pointers, hidden = write(M,
                                       decoder_input0,
                                       decoder_hidden0,
                                         M)
outputs, pointers

(tensor([[[0.2097, 0.1792, 0.1851, 0.2037, 0.2223],
          [0.2715, 0.2282, 0.2371, 0.2632, 0.0000],
          [0.0000, 0.3132, 0.3254, 0.3614, 0.0000],
          [0.0000, 0.4904, 0.5096, 0.0000, 0.0000],
          [0.0000, 1.0000, 0.0000, 0.0000, 0.0000]],
 
         [[0.2202, 0.2175, 0.1303, 0.2598, 0.1723],
          [0.3015, 0.2975, 0.1692, 0.0000, 0.2318],
          [0.0000, 0.4272, 0.2410, 0.0000, 0.3318],
          [0.0000, 0.0000, 0.4205, 0.0000, 0.5795],
          [0.0000, 0.0000, 1.0000, 0.0000, 0.0000]],
 
         [[0.2373, 0.2062, 0.1697, 0.1760, 0.2107],
          [0.0000, 0.2722, 0.2199, 0.2295, 0.2785],
          [0.0000, 0.3775, 0.3046, 0.3180, 0.0000],
          [0.0000, 0.0000, 0.4892, 0.5108, 0.0000],
          [0.0000, 0.0000, 1.0000, 0.0000, 0.0000]],
 
         [[0.2920, 0.2040, 0.1569, 0.1616, 0.1856],
          [0.0000, 0.2924, 0.2178, 0.2254, 0.2645],
          [0.0000, 0.0000, 0.3071, 0.3181, 0.3749],
          [0.0000, 0.0000, 0.4911, 0.5089, 0.0000],
   

In [27]:
rpw = ReadProcessWrite(hidden_dim, lstm_steps, batch_size, input_dim)

In [36]:
l = list(rpw.named_parameters())
for for name, param in rpw.named_parameters():
    if param.requires_grad:

24

In [45]:
import random
weights_indices = {}
l = list(rpw.named_parameters())
for name, param in l:
    if param.requires_grad:
        size = list(param.data.flatten().size())[0]
        weights_indices[name] = random.sample(range(size), 5)
weights_indices

{'decoder_input0': [9, 8, 3, 13, 11],
 'read.W': [8, 0, 7, 14, 5],
 'read.b': [3, 8, 6, 15, 1],
 'process.lstmcell.weight_ih': [479, 530, 490, 669, 362],
 'process.lstmcell.weight_hh': [372, 519, 978, 925, 398],
 'process.lstmcell.bias_ih': [12, 5, 13, 58, 53],
 'process.lstmcell.bias_hh': [53, 19, 14, 43, 41],
 'write.input_to_hidden.weight': [221, 80, 477, 526, 640],
 'write.input_to_hidden.bias': [1, 19, 44, 22, 12],
 'write.hidden_to_hidden.weight': [348, 762, 930, 496, 949],
 'write.hidden_to_hidden.bias': [33, 8, 49, 29, 32],
 'write.hidden_out.weight': [360, 205, 165, 151, 168],
 'write.hidden_out.bias': [6, 1, 3, 9, 4],
 'write.att.V': [12, 5, 2, 1, 13],
 'write.att.input_linear.weight': [162, 145, 24, 227, 61],
 'write.att.input_linear.bias': [2, 4, 9, 1, 14],
 'write.att.context_linear.weight': [247, 0, 199, 225, 160],
 'write.att.context_linear.bias': [8, 11, 14, 9, 10]}

In [None]:
def write_weights(weights_indices, parameters, writer)
    weights_data = {}
    for name, param in parameters:
        if param.requires_grad:
            indices = weights_indices[name]
            for idx in indices:
                weights_data[f'{name}.{idx}'] = params.data/flatten()[idx]
    writer.add_scalars('data/weights', weights_data)

In [63]:
outputs

tensor([[[0.2803, 0.1907, 0.1305, 0.2735, 0.1250],
         [0.0000, 0.2649, 0.1779, 0.3872, 0.1700],
         [0.0000, 0.4355, 0.2888, 0.0000, 0.2756],
         [0.0000, 0.0000, 0.5119, 0.0000, 0.4881],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],

        [[0.2351, 0.1320, 0.1885, 0.0686, 0.3758],
         [0.3834, 0.2081, 0.3033, 0.1052, 0.0000],
         [0.0000, 0.3370, 0.4961, 0.1669, 0.0000],
         [0.0000, 0.6712, 0.0000, 0.3288, 0.0000],
         [0.0000, 0.0000, 0.0000, 1.0000, 0.0000]],

        [[0.2436, 0.2941, 0.0919, 0.1727, 0.1978],
         [0.3481, 0.0000, 0.1279, 0.2436, 0.2803],
         [0.0000, 0.0000, 0.1940, 0.3742, 0.4318],
         [0.0000, 0.0000, 0.3397, 0.6603, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000]],

        [[0.3631, 0.1636, 0.1940, 0.1351, 0.1442],
         [0.0000, 0.2569, 0.3065, 0.2110, 0.2256],
         [0.0000, 0.3710, 0.0000, 0.3038, 0.3252],
         [0.0000, 0.0000, 0.0000, 0.4828, 0.5172],
         [0.0000, 0.0000,

In [60]:
X_train = np.random.uniform(size=(10000, 5))
Y_train = np.sort(X, axis=1)
X_val = np.random.uniform(size=(10000, 5))
Y_val = np.sort(X, axis=1)
Y_train[:5,:]

array([[0.0406089 , 0.11446918, 0.20511638, 0.46667707, 0.95021659],
       [0.04846931, 0.05260525, 0.26244789, 0.42091101, 0.77468845],
       [0.11588492, 0.23446641, 0.30200325, 0.80968508, 0.82180195],
       [0.03560383, 0.04234643, 0.38428259, 0.5542201 , 0.90921458],
       [0.03457714, 0.39351441, 0.41889419, 0.65903414, 0.95649737]])

In [62]:
dict_data = {'attributes': None, 'split':{'train': [], 'val': []}}
for i in range(X_train.shape[0]):
    dict_data['split']['train'].append((X_train[i, :], Y_train[i,:]))
    dict_data['split']['val'].append((X_val[i, :], Y_val[i,:]))
dict_data

{'attributes': None,
 'split': {'train': [(array([0.58206184, 0.47354334, 0.0901827 , 0.77592039, 0.1037447 ]),
    array([0.0406089 , 0.11446918, 0.20511638, 0.46667707, 0.95021659])),
   (array([0.33237462, 0.19105765, 0.5835769 , 0.77370521, 0.21692672]),
    array([0.04846931, 0.05260525, 0.26244789, 0.42091101, 0.77468845])),
   (array([0.17942557, 0.79543805, 0.14740483, 0.52099521, 0.84718262]),
    array([0.11588492, 0.23446641, 0.30200325, 0.80968508, 0.82180195])),
   (array([0.5292463 , 0.59817234, 0.51269798, 0.5227361 , 0.71198074]),
    array([0.03560383, 0.04234643, 0.38428259, 0.5542201 , 0.90921458])),
   (array([0.79237819, 0.73196427, 0.7093469 , 0.76444092, 0.8697165 ]),
    array([0.03457714, 0.39351441, 0.41889419, 0.65903414, 0.95649737])),
   (array([0.17933344, 0.49685717, 0.22859221, 0.33520646, 0.12634754]),
    array([0.1400745 , 0.2375888 , 0.63361123, 0.64325355, 0.7149642 ])),
   (array([0.96864187, 0.93035639, 0.45801691, 0.05330802, 0.08790261]),
    ar