In [50]:
def attention(query, key, value, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    key_dims = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(key_dims)
    product_attention = F.softmax(scores, dim = -1)
    if dropout is not None:
        product_attention = dropout(product_attention)
    return torch.matmul(product_attention, value), product_attention

class MultiHeadedAttention(nn.Module):
    def __init__(self, number_of_heads, hidden_dimensions, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        num_heads = number_of_heads
        h_dims = hidden_dimensions
        
        assert h_dims % num_heads == 0
        
        self.num_heads = num_heads
        self.h_dims = h_dims
        
        self.key_dims = h_dims // num_heads
        
        self.linears = nn.ModuleList([
            nn.Linear(h_dims, h_dims) for _ in range(num_heads)
        ])
        
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key=None, value=None):
        if key is None:
            key = value = query
        batch_size = query.shape[0]
        
        for lin in self.linears:
            query, key, value, = lin(query), lin(key), lin(value)
            
        x, self.attn = attention(query, key, value, dropout=self.dropout)
        
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h_dims)
        last_lin = self.linears[-1](x)

        return last_lin

In [94]:
lstm = nn.LSTM(20, 10, 2)

out, (h, c) = lstm(torch.rand(3, 40, 20))

out.shape, h.shape, c.shape

(torch.Size([3, 40, 10]), torch.Size([2, 40, 10]), torch.Size([2, 40, 10]))

In [101]:
np.ndarray(shape=(50, 50)).shape

(50, 50)

In [95]:
b = torch.rand(10)
a = [3, 2, 2]

b[a]

tensor([ 0.9836,  0.9051,  0.9051])

In [107]:
a = [3, 2, 1]
a[]

[3, 2, 1]

In [81]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from torch import nn
import torch
import torch.nn.functional as F
from lib.JANET import JANET

from torch import optim

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch import optim
import numpy as np
import math, random
from ipdb import set_trace
from torch.nn.utils import weight_norm as wn

in_size=1
hidden_size = 100
num_layers = 4
sequence_len = 784//in_size

class LSTMClassifier(nn.Module):
    def __init__(self, seq_len=784, in_size=1, hidden_size=100, num_layers=4, num_heads=4):
        super(LSTMClassifier, self).__init__()
        seq_len //= in_size
        
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.in_size = in_size
        
        self.attn = MultiHeadedAttention(num_heads, hidden_size)
        self.lstm = nn.LSTM(in_size, hidden_size, num_layers)
        
        for name, params in self.lstm.state_dict().items():
            if "weight" in name:
                wn(self.lstm, name)
                nn.init.xavier_uniform_(params)
            elif "bias" in name:
                init = nn.Parameter(torch.log(torch.rand(hidden_size)*(sequence_len - 1) + 1))
                params[:hidden_size] = -init.clone()
                params[hidden_size:2*hidden_size] = init

        self.lin = nn.Sequential(*[
            nn.Linear(hidden_size, hidden_size)
            , nn.Linear(hidden_size, hidden_size)
            , nn.Linear(hidden_size, 10)
            , nn.LogSoftmax(dim=-1)
        ])
                
        self.hidden = None
        
    def forward(self, sequence):
        for x in sequence:
            x = x.unsqueeze(-1)
            out, (c, h) = self.lstm(x, self.hidden)
            self.hidden = (self.attn(c), self.attn(h))
            
        self.hidden = None
        
        return self.lin(out)

In [88]:
torch.rand(100)*783 + 1

tensor([ 309.0916,  614.5851,   63.4067,  555.8377,  289.4675,  487.7041,
         633.8980,  316.6725,  742.8395,  259.7570,  774.9430,  771.0657,
         368.0129,  138.9766,  310.4167,   40.6794,  438.3559,  770.3306,
         619.8373,  473.7707,  716.7692,  379.0143,  339.0074,  259.6896,
         460.7392,  680.5891,   61.7223,  430.9417,  657.2477,  369.5740,
         149.3458,  202.6087,  231.9478,  264.8786,  691.0543,  477.7019,
         241.2799,  629.0342,  440.2131,  703.0553,  605.2755,  594.0752,
         668.6731,  124.0641,   56.1716,   65.2071,  321.8222,   29.1163,
         370.8891,  720.8765,  439.7370,  200.1614,  503.0772,   37.9705,
         205.3878,   26.3700,  718.4380,  693.7109,   73.8634,  717.4750,
          60.6411,  540.6418,   62.8727,  447.0753,  341.5144,  317.9478,
         727.2603,   96.2434,   37.6566,   97.3319,   19.6317,  236.9877,
         561.6227,  762.7272,  248.9099,  321.9046,  665.2184,  183.6462,
         529.9541,  414.9385,  608.701

In [83]:
lstm = LSTMClassifier()
optimizer = optim.Adam(lstm.parameters(), amsgrad=True)
log_interval = 1

def train(model, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.view(-1, data.shape[0], 1)
        optimizer.zero_grad()
        output = model(data).squeeze(0)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
for i in range(10):
    train(lstm, i)



KeyboardInterrupt: 

In [47]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from torch import nn
import torch
import torch.nn.functional as F
from lib.JANET import JANET

from torch import optim

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch import optim
import numpy as np
import math, random
from ipdb import set_trace
from torch.nn.utils import weight_norm as wn

in_size=1
hidden_size = 100
num_layers = 4
sequence_len = 784//in_size

class LSTMClassifier(nn.Module):
    def __init__(self, seq_len=784, in_size=1, hidden_size=100, num_layers=4):
        super(LSTMClassifier, self).__init__()
        seq_len //= in_size
        
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.in_size = in_size
        
        self.importance_logits = []
        self.mixing_logits = []
        self.lstm_layers = []
        for i in range(num_layers):
            lstm_in = in_size if i == 0 else hidden_size
            self.lstm_layers.append(nn.LSTM(lstm_in, hidden_size, 1))
            
            self.importance_logits.append(
                nn.Sequential(*[
                    nn.Linear(hidden_size*self.in_size, hidden_size)
                    , nn.Linear(hidden_size, hidden_size)
                    , nn.Linear(hidden_size, 1)
                ])
            
#         nn.LSTM(in_size, hidden_size, num_layers)
        
            for name, params in self.lstm_layers[-1].state_dict().items():
                if "weight" in name:
                    wn(self.lstm_layers[-1], name)
                    nn.init.xavier_uniform_(params)
                elif "bias" in name:
                    init = nn.Parameter(torch.log(torch.rand(hidden_size)*(sequence_len - 1) + 1))
                    params[:hidden_size] = -init
                    params[hidden_size:2*hidden_size] = init
                    
        self.lstm_layers = nn.ModuleList(self.lstm_layers)

        self.lin = nn.Sequential(*[
            nn.Linear(hidden_size*sequence_len, 10)
            , nn.LogSoftmax(dim=-1)
        ])
        
        #so whta do I want to do
        #I want the output of each individual layer to output how important it thinks it is
        
    def forward(self, sequence):
        batch_size = sequence.shape[1]
        hidden_importance_logits = torch.zeros(self.num_layers, self.in_size, self.seq_len)
        output_importance_logits = torch.zeros(self.num_layers, self.in_size, self.seq_len)
        
        #so lets see, hiddens should be num_layers (1), 1, hidden_size?
        #I tihnk?
        hs = torch.zeros(self.num_layers, 1, self.in_size, self.hidden_size)
        cs = torch.zeros(self.num_layers, 1, self.in_size, self.hidden_size)
        
        hidden = None
        for j, x in enumerate(sequence):
            mixing_logits = torch.zeros(self.num_layers)
            x = x.unsqueeze(0)
            for i, (layer, importance_logit) in enumerate(zip(self.lstm_layers, self.importance_logits)):
                if j > 0:
                    set_trace()
                    hidden = (hs[i, 0], cs[i, 0])
                    
                if i > 1:
                    importance_probas = F.softmax(importance_logits[:i], dim=-1)
                    
                    hidden = (hs[:i]*importance_probas, cs[:i]*importance_probas)
                    hidden = (hs[i-1]*(1-mixing_probas) + mixing_probas*hidden[0].sum(), 
                              cs[i-1]*(1-mixing_probas) + mixing_probas*hidden[1].sum())
                    
                    x = mixing_probas*hidden[0]
                    
                    #hmmm so lets see...
                    #we're updating the hidden for the previous layer kind of?
                    #because we're basically taking all of the 
                    #so lets see..
                    #like what we're doing is taking all of the importance_logits and hiddens (outs)
                    #for the previous of the layer, and we're ...
                    #I think we can simplify this alot
                    #we can probably make it lists
                    #or keeping it as a tensor is fine
                    #and we want 
                    
                    
                x, (h, c) = layer(x, hidden)
                
                hs[i] = h
                cs[i] = c
                
                #so x should be equal to h
                #so the shape will be (1, 1, self.hidden_size, self.in_size)
                #so the importance_logit will take 
                #nvm the x is the output from the inputs
                #the h is the weight matrix or whatever
                
                #so x will be (batch_size, self.hidden_size, self.in_size)
                #so the importance_logit should take hidden_size, 
                importance_logits[i, :, j] = importance_logit(x)
                mixing_logits[i] = mixing_logit(x)
                
        x = x.permute(1, 0, 2).contiguous().view(batch_size, -1)
        log_probas = self.lin(x)
        
        return log_probas
    
lstm = LSTMClassifier()
    
lstm_opt = optim.Adam(lstm.parameters(), amsgrad=True)

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

batch_size = 1
test_batch_size = 1

use_cuda = False

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)

device = torch.device("cuda" if use_cuda else "cpu")

lstm = lstm.to(device)

lstm.train()

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    
    data = data.view(sequence_len, batch_size, in_size)
    
    lstm_opt.zero_grad()
    lstm_out = lstm(data)
    lstm_loss = F.nll_loss(lstm_out, target)
    lstm_loss.backward()
    lstm_opt.step()

    print(f"LSTM Loss: {lstm_loss.item()}")

> [0;32m<ipython-input-47-0676e3543e0b>[0m(86)[0;36mforward[0;34m()[0m
[0;32m     85 [0;31m                    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 86 [0;31m                    [0mmixing_probas[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0msigmoid[0m[0;34m([0m[0mmixing_logits[0m[0;34m[[0m[0mi[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m[0;34m[0m[0m
[0m
ipdb> mixing_logits[i-1]
tensor(1.00000e-02 *
       3.9722)
ipdb> mixing_logits
tensor(1.00000e-02 *
       [ 3.9722,  0.0000,  0.0000,  0.0000])
ipdb> n
> [0;32m<ipython-input-47-0676e3543e0b>[0m(88)[0;36mforward[0;34m()[0m
[0;32m     87 [0;31m[0;34m[0m[0m
[0m[0;32m---> 88 [0;31m                [0;32mif[0m [0mi[0m [0;34m>[0m [0;36m1[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     89 [0;31m                    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0m
[0m
ipdb> mixing_probas
tensor(0.5099)
ipdb> n
> [

BdbQuit: 