In [158]:
from torchnlp.datasets import penn_treebank_dataset
import torch
from torchnlp.samplers import BPTTBatchSampler
from torch.utils.data import DataLoader
from rsm_samplers import LangSequenceSampler, language_pred_sequence_collate
from ptb_lstm import LSTMModel
from lang_util import Corpus
import torch.nn.functional as F
from importlib import reload 
import rsm

In [75]:
import os

# Location of the PTB data. Override with the PTB_DATA_DIR environment variable
# rather than editing this machine-specific default path.
PTB_DATA_DIR = os.environ.get('PTB_DATA_DIR', '/Users/jgordon/nta/datasets/PTB')
corpus = Corpus(PTB_DATA_DIR)

[slice(0, 3, None),
 slice(185918, 185921, None),
 slice(371836, 371839, None),
 slice(557754, 557757, None),
 slice(743672, 743675, None)]

In [280]:
reload(rsm)

BS = 2
SEQL = 3
VS = 10000
model = rsm.RSMLayer(d_in=30, d_out=VS, m=4, n=3, k=2, k_winner_cells=2,
                     vocab_size=VS, embed_dim=30, debug=False)

sampler = LangSequenceSampler(corpus.train, batch_size=BS,
                              seq_length=SEQL,
                              parallel_seq=True)
loader = DataLoader(corpus.train,
                    batch_sampler=sampler,
                    collate_fn=language_pred_sequence_collate)
# Qualified: `CrossEntropyLoss` was never imported into this namespace
# (and `nn` is only imported in a later cell).
criterion = torch.nn.CrossEntropyLoss()


def _repackage_hidden(h):
    """
    Detach a hidden state from its autograd history.

    Tensors are detached; tuples/lists (including namedtuples) are rebuilt
    with each element repackaged recursively; anything else is returned as is.
    NOTE(review): the exact structure returned by RSMLayer.init_hidden is not
    visible here -- confirm it is a tensor or a (nested) tuple/list of tensors.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    if isinstance(h, (tuple, list)):
        items = [_repackage_hidden(v) for v in h]
        # namedtuples take positional fields; plain tuple/list take an iterable
        return type(h)(*items) if hasattr(h, '_fields') else type(h)(items)
    return h


hidden = model.init_hidden(BS)

# Smoke test: forward/backward on the first 5 batches (i = 0..4), no optimizer step.
for i, (data, targets, _) in enumerate(loader):
    model.zero_grad()
    output, hidden, _ = model(data, hidden)
    print('output', output.size())
    print('targets', targets.size())
    loss = criterion(output.view(-1, VS), targets)
    loss.backward()
    print(loss.item())
    # Cut the graph so the next backward() does not reach into earlier batches.
    hidden = _repackage_hidden(hidden)
    if i > 3:
        break


output torch.Size([3, 2, 10000])
targets torch.Size([6])
9.043150901794434
output torch.Size([3, 2, 10000])
targets torch.Size([6])
9.00977897644043
output torch.Size([3, 2, 10000])
targets torch.Size([6])
9.214032173156738
output torch.Size([3, 2, 10000])
targets torch.Size([6])
9.105277061462402
output torch.Size([3, 2, 10000])
targets torch.Size([6])
9.204472541809082


In [268]:
# Demo: mark the 2 largest entries along the last dim of a random (2, 4, 3) tensor.
a = torch.randn(2, 12).view(2, 4, 3)
print(a)
values, indices = a.topk(2)
print(indices)
arr = torch.zeros_like(a)  # same dtype and device as `a`
arr.scatter_(2, indices, 1)

tensor([[[-0.6471, -1.0628,  1.0940],
         [-0.3992,  2.3036, -1.0940],
         [ 0.8734, -2.1454,  2.0653],
         [ 0.7241, -1.7697, -0.2021]],

        [[ 1.1691, -0.2936, -0.5580],
         [-1.2807, -1.0672,  0.2157],
         [ 0.5309,  0.1857, -0.2516],
         [ 1.1832, -1.2698,  0.2208]]])
tensor([[[2, 0],
         [1, 0],
         [2, 0],
         [0, 2]],

        [[0, 1],
         [2, 1],
         [0, 1],
         [0, 2]]])


tensor([[[1., 0., 1.],
         [1., 1., 0.],
         [1., 0., 1.],
         [1., 0., 1.]],

        [[1., 1., 0.],
         [0., 1., 1.],
         [1., 1., 0.],
         [1., 0., 1.]]])

In [4]:
# Number of entries in corpus.dictionary -- presumably the vocabulary size
# (compare with VS, the vocab_size passed to the model).
len(corpus.dictionary)

10000

In [4]:
# Token counts of the train / valid / test splits, in that order.
[len(split) for split in (corpus.train, corpus.valid, corpus.test)]

[929589, 73760, 82430]

In [6]:
batch_size = 300
# True division: fractional number of batches in the training split.
batches = len(corpus.train) / batch_size
# One quarter of those batches.
batches / 4

774.6575

In [12]:
from torch.nn.functional import softmax

def topk_mask(a, k, dim=0, do_softmax=False):
    """
    Return a binary mask marking the top ``k`` elements of ``a`` along ``dim``.

    Args:
        a: input tensor.
        k: number of elements to mark in each slice along ``dim``.
        dim: dimension to select along (default 0; for a 1-D tensor this is
            also the last dimension).
        do_softmax: if True, skip masking and return ``softmax(a)`` taken
            along ``dim`` instead.

    Returns:
        Tensor of the same shape, dtype and device as ``a`` holding 1 at the
        top-k positions along ``dim`` and 0 elsewhere, or the softmax of ``a``
        along ``dim`` when ``do_softmax`` is True.
    """
    if do_softmax:
        # Explicit dim: the implicit-dim form is deprecated and may choose a
        # different axis than the one the mask uses.
        return softmax(a, dim=dim)
    # Select along the same dim we scatter into. Selecting along the last dim
    # but scattering along `dim` misplaces indices for multi-dim inputs.
    _, indices = torch.topk(a, k, dim=dim)
    mask = a.new_zeros(a.size())  # zeros with a's dtype and device
    mask.scatter_(dim, indices, 1)
    return mask

# Demo on a random 3x4 tensor (softmax branch, along dim 1).
a = torch.randn(3, 4)
print(a)
topk_mask(a, 1, dim=1, do_softmax=True)

tensor([[ 1.3621,  0.2758, -0.5286,  0.0746],
        [ 0.8212,  0.9581, -0.0496, -0.1561],
        [ 1.4873, -0.6304,  1.5969,  1.2753]])


  


tensor([[0.5668, 0.1913, 0.0856, 0.1564],
        [0.3400, 0.3898, 0.1423, 0.1279],
        [0.3284, 0.0395, 0.3664, 0.2657]])

In [65]:
from torch import nn
from copy import deepcopy
import matplotlib.pyplot as plt

class LocalLinear(nn.Module):
    """
    Locally connected linear layer over the last dimension.

    The input's last dimension is cut into windows of ``kernel_size`` features,
    one starting every ``stride`` features. Each window gets its own,
    independently parameterized ``nn.Linear(kernel_size, local_features)``.
    Per-window outputs are concatenated along the last dimension, like a 1-D
    convolution whose weights are not shared between positions.

    Args:
        in_features: size of the input's last dimension.
        local_features: number of output features produced per window.
        kernel_size: number of input features in each window.
        stride: step between the starts of consecutive windows.
        bias: whether each per-window linear layer has a bias.

    Shapes:
        input ``(..., in_features)`` -> output ``(..., n_windows * local_features)``
        where ``n_windows = (in_features - kernel_size) // stride + 1``.
    """
    def __init__(self, in_features, local_features, kernel_size, stride=1, bias=True):
        super(LocalLinear, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride

        fold_num = (in_features - self.kernel_size) // self.stride + 1
        # A fresh nn.Linear is constructed per window, so the windows already
        # have independent parameters; no copying is needed.
        self.lc = nn.ModuleList([nn.Linear(kernel_size, local_features, bias=bias)
                                 for _ in range(fold_num)])

    def forward(self, x):
        # (..., in_features) -> (..., n_windows, kernel_size)
        windows = x.unfold(-1, size=self.kernel_size, step=self.stride)
        n_windows = windows.shape[-2]
        if n_windows != len(self.lc):
            raise ValueError(
                "LocalLinear expected %d windows but input of width %d gives %d"
                % (len(self.lc), x.shape[-1], n_windows))
        # Index the window axis from the end so any number of leading dims works.
        return torch.cat([layer(windows[..., i, :]) for i, layer in enumerate(self.lc)],
                         dim=-1)


class ActiveDendriteLayer(torch.nn.Module):
    """
    Cells with active dendrites.

    Every dendrite segment of every cell is fully connected to the input
    (``linear_dend``) and passed through a ReLU. Each cell then takes a
    learned weighted sum (plus bias) of only its own ``n_dendrites`` segments.
    This uses ``LocalLinear`` with non-overlapping windows, which behaves like
    a 1-D convolution without weight sharing.

    Args:
        input_dim: size of the input's last dimension.
        n_cells: number of cells (output features).
        n_dendrites: dendrite segments per cell.
    """
    def __init__(self, input_dim, n_cells=50, n_dendrites=3):
        super().__init__()
        self.n_cells = n_cells
        self.n_dendrites = n_dendrites

        n_segments = n_cells * n_dendrites
        # Dense projection: one unit per dendrite segment across all cells.
        self.linear_dend = nn.Linear(input_dim, n_segments)
        # Window of n_dendrites, stride n_dendrites: each cell sees only its own segments.
        self.linear_neuron = LocalLinear(n_segments, 1, n_dendrites,
                                         stride=n_dendrites)

    def __repr__(self):
        summary = (self.n_cells, self.n_dendrites)
        return "ActiveDendriteLayer neur=%d, dend per neuron=%d" % summary

    def forward(self, x):
        segment_activity = F.relu(self.linear_dend(x))
        return self.linear_neuron(segment_activity)
 
# Demo: 5 inputs -> 4 cells with 2 dendrite segments each.
x = torch.randn(1, 5)
print(x)
adl = ActiveDendriteLayer(input_dim=5, n_cells=4, n_dendrites=2)
print(adl(x))


tensor([[-0.2939,  0.5844,  0.6376, -0.2840, -0.3008]])
tensor([[ 0.3429, -0.1966, -0.4056, -0.0240]], grad_fn=<CatBackward>)


In [68]:
x = torch.randn(5, 3)
x[:, -2:].fill_(1)  # overwrite the last two columns in place (slice is a view)
print(x)

tensor([[-0.1196,  1.0000,  1.0000],
        [-0.0664,  1.0000,  1.0000],
        [-0.3680,  1.0000,  1.0000],
        [ 1.5387,  1.0000,  1.0000],
        [ 1.1104,  1.0000,  1.0000]])


In [72]:
x.shape  # torch.Size, identical to x.size()

torch.Size([5, 3])