In [1]:
from torchnlp.datasets import penn_treebank_dataset
import torch
from torchnlp.samplers import BPTTBatchSampler
from torch.utils.data import DataLoader
from rsm_samplers import MNISTSequenceSampler, ptb_pred_sequence_collate
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from importlib import reload 
from torch.utils.data import Sampler, BatchSampler
import rsm
import numpy as np
import torchvision.utils as vutils
from functools import reduce, partial
import matplotlib.pyplot as plt

In [2]:
def _repackage_hidden(h):
    """Detach hidden state from its autograd history.

    A tensor is returned detached. Anything else is treated as an iterable of
    hidden states and rebuilt, recursively, as a tuple of detached tensors.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(_repackage_hidden, h))
    return h.detach()

def activity_square(vector):
    """Lay a 1-D vector out on the smallest square grid that can hold it.

    The values fill the grid in row-major order and any unused trailing cells
    are zero. The result is a new (side, side) tensor built with torch.zeros.
    """
    count = len(vector)
    edge = int(np.sqrt(count))
    if edge * edge < count:
        # Not a perfect square: round the side length up.
        edge = edge + 1
    grid = torch.zeros(edge * edge)
    grid[:count] = vector
    return grid.view(edge, edge)
    
def plot_col_distrs(distrs, n_labels=10):
    """Plot the mean activity for every "i -> j" label pair on a grid.

    Args:
        distrs: dict mapping keys of the form "<i>-<j>" to a list of 1-D
            activity tensors of equal length; each list is stacked and
            averaged before plotting.
        n_labels: number of labels. An n_labels x n_labels grid of subplots
            is drawn and keys are looked up for i, j in range(n_labels).

    Returns:
        An empty dict (apparently so a caller can reset its accumulator by
        assigning the result back to it).
    """
    # squeeze=False keeps `axs` two-dimensional even when n_labels == 1;
    # otherwise plt.subplots returns a bare Axes and axs[i][j] fails.
    fig, axs = plt.subplots(n_labels, n_labels, dpi=200,
                            sharex=True, sharey=True, squeeze=False,
                            gridspec_kw={'wspace': 0, 'hspace': 0})
    for i in range(n_labels):
        for j in range(n_labels):
            ax = axs[i][j]
            # Every panel is drawn without axes; panels with no data stay blank.
            ax.axis('off')
            key = '%d-%d' % (i, j)
            if key not in distrs:
                continue
            ax.set_title("%d -> %d" % (i, j), fontsize=7)
            mean_act = torch.stack(distrs[key]).mean(dim=0)
            ax.imshow(activity_square(mean_act))
    plt.show()
    return {}

In [7]:
import rsm_samplers
import rsm
reload(rsm)
reload(rsm_samplers)

from torch.utils.data import DataLoader, BatchSampler

# MNIST, normalised with the mean/std constants commonly used for this dataset.
dataset = datasets.MNIST(
    "~/nta/datasets",
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Experiment sizes. NOTE(review): the meaning of m, n and k is defined by
# rsm.RSMLayer, which is not visible in this notebook.
bs = 2  # batch size
m = 5
k = 2
n = 3
SEQ = [[0, 1, 2, 3], [0, 3, 2, 1]]  # digit sequences handed to the sampler
sl = 8  # sequence length
d_in = d_out = 28 ** 2  # one flattened 28 x 28 image

sampler = rsm_samplers.MNISTSequenceSampler(
    dataset, batch_size=bs, sequences=SEQ, randomize_sequences=True
)
# NOTE(review): the extra "+ 1" item presumably supplies the next-step target
# for the final position — confirm against pred_sequence_collate.
batch_sampler = BatchSampler(sampler, batch_size=sl * bs + 1, drop_last=True)

collate_fn = partial(
    rsm_samplers.pred_sequence_collate,
    bsz=bs,
    seq_length=sl,
    return_inputs=True,
)
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
model = rsm.RSMLayer(d_in=d_in, d_out=d_out, m=m, n=n, k=k, visual_debug=False, debug=True)

criterion = MSELoss()

LR = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [8]:
# Train the layer for a few batches while recording, for every
# (input digit, next digit) pair, the per-row maximum of the x_b activity.
hidden = model.init_hidden(bs)

MAX_BATCHES = 1
CLASSES = 10
PLOT_EACH = 10
EPOCHS = 1

condtl_column_dists = {}  # 'digit-digit' -> list of distribution arrays

for epoch in range(EPOCHS):
    total_loss = 0.0
    for i, (data, targets, pred_targets, input_labels) in enumerate(loader):
        optimizer.zero_grad()
        # The model returns its output, the new hidden state and x_bs, which
        # is iterated in step with the label batches below.
        output, hidden, x_bs = model(data, hidden)
        # The hidden state has three components; unpacking fails if it does not.
        x_b, phi, psi = hidden
        for x_b_batch, label_batch, target_batch in zip(x_bs, input_labels, pred_targets):
            for _x_b, label, target in zip(x_b_batch, label_batch, target_batch):
                digit = label.item()
                next_digit = target.item()
                # Reshape to m rows and keep the maximum of each row.
                col_activity = _x_b.detach().view(m, -1).max(dim=1).values
                key = "%d-%d" % (digit, next_digit)
                if key not in condtl_column_dists:
                    condtl_column_dists[key] = []
                condtl_column_dists[key].append(col_activity)

        loss = criterion(output, targets)

        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        
        # Detach the hidden state so the next batch does not backpropagate
        # into this batch's history.
        hidden = _repackage_hidden(hidden)

        # Stop once MAX_BATCHES batches have been processed.
        if i >= MAX_BATCHES - 1:
            break

    # Mean loss over the batches processed this epoch.
    # NOTE(review): `i` is unbound here if the loader yields no batches.
    print(epoch, 'loss', total_loss / (i+1))
#     condtl_column_dists = plot_col_distrs(condtl_column_dists, n_labels=4)

['seqi', 0, '-', <class 'int'>]
['x_a_row', tensor([[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]), torch.Size([2, 784]), torch.float32]
['z_a', tensor([[-0.2635, -0.2635, -0.2635, -0.2424, -0.2424, -0.2424, -0.8727, -0.8727,
         -0.8727,  0.4224,  0.4224,  0.4224, -0.1345, -0.1345, -0.1345],
        [ 0.3499,  0.3499,  0.3499, -0.0136, -0.0136, -0.0136,  0.5048,  0.5048,
          0.5048,  0.5542,  0.5542,  0.5542,  0.0808,  0.0808,  0.0808]],
       grad_fn=<IndexSelectBackward>), torch.Size([2, 15]), torch.float32]
['sigma', tensor([[-0.0600, -0.2421, -0.1107, -0.2644, -0.3350, -0.1488, -0.9187, -0.6598,
         -0.8055,  0.3809,  0.4653,  0.3975, -0.1697, -0.3887, -0.3510],
        [ 0.5534,  0.3713,  0.5027, -0.0357, -0.1062,  0.0800,  0.4587,  0.7177,
          0.5719,  0.5127,  0.5970,  0.5293,  0.0455, -0.1735, -0.1358]],
       grad_fn=<AddBackward0>), torch.Size([2, 15]), torch.float32

In [None]:
# Scratch check: build a 0/1 mask marking the top-2 entries along the last dim.
a = torch.randn(2, 12).reshape(2, 4, 3)
print(a)
# torch.topk defaults to the last dimension, which is dim 2 for this 3-D tensor.
values, indices = torch.topk(a, 2)
print(indices)
arr = a.new_zeros(a.size())  # Zeros, conserve device
# In-place scatter; the resulting mask is the cell's displayed value.
arr.scatter_(2, indices, 1)

In [None]:
# NOTE(review): `corpus` is not defined in any visible cell; this relies on kernel state from elsewhere.
len(corpus.dictionary)

In [None]:
# Lengths of the train / valid / test parts of `corpus`.
# NOTE(review): `corpus` is not defined in any visible cell.
[len(corpus.train), len(corpus.valid), len(corpus.test)]

In [None]:
# NOTE(review): `corpus` is not defined in any visible cell.
batch_size = 300
# True division, so `batches` is a float.
batches = len(corpus.train) / batch_size
# A quarter of the batch count.
0.25 * batches

In [None]:
from torch.nn.functional import softmax

def topk_mask(a, k, dim=0, do_softmax=False):
    """
    Return a mask with 1 at the top k elements of a along dim, 0 otherwise.

    If do_softmax is set, k is ignored and the softmax of a along dim is
    returned instead of a hard mask.
    """
    if do_softmax:
        # Pass dim explicitly: the implicit-dim form of softmax is deprecated.
        return softmax(a, dim=dim)
    # topk must select along the same dim that scatter_ writes to. Its
    # default is the last dim, which disagreed with `dim` for dim=0 on
    # multi-dimensional input.
    _, indices = torch.topk(a, k, dim=dim)
    arr = a.new_zeros(a.size())  # Zeros, conserve device
    arr.scatter_(dim, indices, 1)
    return arr

a = torch.randn((3, 4))
print(a)
# With do_softmax=True the function returns early, so k=1 is not used here.
topk_mask(a, 1, dim=1, do_softmax=True)

In [None]:
from torch import nn
from copy import deepcopy
import matplotlib.pyplot as plt

class LocalLinear(nn.Module):
    """
    Locally connected linear layer: an independent nn.Linear per input window.

    The last dimension of the input is cut into windows of `kernel_size`
    elements taken every `stride` elements. Window i is fed through its own
    nn.Linear(kernel_size, local_features) (no weight sharing) and the
    per-window outputs are concatenated along the last dimension.

    Args:
        in_features: size of the input's last dimension.
        local_features: output size of each per-window linear map.
        kernel_size: number of input elements in each window.
        stride: step between the starts of consecutive windows.
        bias: whether each per-window linear map has a bias term.
    """
    def __init__(self, in_features, local_features, kernel_size, stride=1, bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

        fold_num = (in_features - self.kernel_size) // self.stride + 1
        # Each nn.Linear(...) call already builds a fresh, independently
        # initialised module, so wrapping it in deepcopy is unnecessary.
        self.lc = nn.ModuleList([nn.Linear(kernel_size, local_features, bias=bias)
                                 for _ in range(fold_num)])

    def forward(self, x):
        """Map x of shape (..., in_features) to (..., fold_num * local_features)."""
        # unfold turns the last dim into (fold_num, kernel_size).
        x = x.unfold(-1, size=self.kernel_size, step=self.stride)
        fold_num = x.shape[-2]
        # Index windows from the end so any number of leading dims is accepted;
        # for 2-D input this is the same as x[:, i, :] concatenated on dim 1.
        return torch.cat([self.lc[i](x[..., i, :]) for i in range(fold_num)], -1)


class ActiveDendriteLayer(torch.nn.Module):
    """
    Layer of cells that each read from their own group of dendrites.

    Every dendrite is fully connected to the input through a single
    nn.Linear. A LocalLinear with kernel size and stride both equal to
    `n_dendrites` then gives each cell private weights over its own
    consecutive group of dendrites, producing one value per cell.
    """
    def __init__(self, input_dim, n_cells=50, n_dendrites=3):
        super().__init__()
        self.n_cells = n_cells
        self.n_dendrites = n_dendrites

        # input -> all dendrites (n_dendrites for each of the n_cells cells)
        self.linear_dend = nn.Linear(input_dim, n_cells * n_dendrites)
        # non-overlapping groups of n_dendrites -> one output per group
        self.linear_neuron = LocalLinear(
            n_cells * n_dendrites, 1, n_dendrites, stride=n_dendrites
        )

    def __repr__(self):
        return "ActiveDendriteLayer neur=%d, dend per neuron=%d" % (
            self.n_cells,
            self.n_dendrites,
        )

    def forward(self, x):
        """Apply ReLU to the dendrite activations, then combine them per cell."""
        dendrite_act = F.relu(self.linear_dend(x))
        return self.linear_neuron(dendrite_act)
 
x = torch.randn(1, 5)
print(x)
# 5 inputs -> 4 cells x 2 dendrites = 8 dendrite units -> one output per cell.
adl = ActiveDendriteLayer(5, 4, 2)
print(adl(x))


In [None]:
x = torch.randn(5, 3)
# Overwrite the last two columns of every row with 1.
x[:, -2:] = 1
print(x)

In [33]:
class BitwiseWordEmbedding(object):
    """Fixed binary codes for integer ids, produced as strings of '0'/'1'.

    A code is the id written in binary, left-padded with zeros to
    ``dim // 2`` characters, followed by the same string with each '0'
    turned into '1' and every other character turned into '0'.
    """

    def __init__(self, vocab_size=10000, dim=28):
        self.vocab_size = vocab_size
        self.embedding_dict = {}
        self.dim = dim

    def generate_embeddings(self):
        """Fill ``embedding_dict`` with the code of every id below vocab_size."""
        self.embedding_dict.update(
            (idx, self.embed(idx)) for idx in range(self.vocab_size)
        )

    def embed(self, i):
        """Return the code string for id ``i``."""
        half = format(i, "b").zfill(self.dim // 2)
        return "".join((half, self.inverse(half)))

    def inverse(self, binstr):
        """Map '0' to '1' and any other character to '0'."""
        return ''.join('01'[ch == '0'] for ch in binstr)

bwe = BitwiseWordEmbedding()

# With the default dim=28: 14 binary digits of 9381 followed by their complement.
bwe.embed(9381)

'1001001010010101101101011010'

In [8]:
from importlib import reload 
import viz_util
reload(viz_util)

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import io
import numpy as np
import matplotlib.pyplot as plt


# NOTE(review): the result is unpacked as (ax, fig) — confirm this matches the helper's return order.
ax, fig = viz_util.plot_confusion_matrix(np.array([1,2,3]), np.array([1,2,0]), ['0', '1', '2', '3'])

# Presumably converts the figure to an image array; its shape is printed below.
img = viz_util.fig2img(fig)

print(img.shape)
# Side effect: writes test.png into the current working directory.
plt.imsave('test.png', img, format='png')


(288, 432, 3)


In [55]:
# torch.range is deprecated; torch.arange(6.) yields the same float values 0..5.
torch.arange(6.).expand((3, 6)).t()

  """Entry point for launching an IPython kernel.


tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.],
        [5., 5., 5.]])

In [27]:
from nupic.torch.functions import KWinners

kw = KWinners.apply

# NOTE(review): this cell rebinds the notebook-wide names m, n and k.
bsz = 2
m = 5
n = 3

k = 2

a = torch.rand(bsz, m, n)
print(a)
# Flatten to (bsz * m, n) rows, apply k-winners, restore the shape and show
# which entries are still positive.
# NOTE(review): the two 0 arguments are presumably duty cycles and boost
# strength — confirm against nupic.torch.
kw(a.view(bsz * m, n), 0, k, 0).view(bsz, m, n) > 0


tensor([[[0.2177, 0.5582, 0.1910],
         [0.0793, 0.9190, 0.8308],
         [0.4444, 0.2490, 0.6382],
         [0.4885, 0.0823, 0.4100],
         [0.4710, 0.3245, 0.6746]],

        [[0.0987, 0.9523, 0.1822],
         [0.0910, 0.5141, 0.2099],
         [0.9616, 0.5538, 0.5057],
         [0.1261, 0.0208, 0.4588],
         [0.0423, 0.1273, 0.3052]]])


tensor([[[1, 1, 0],
         [0, 1, 1],
         [1, 0, 1],
         [1, 0, 1],
         [1, 0, 1]],

        [[0, 1, 1],
         [0, 1, 1],
         [1, 1, 0],
         [1, 0, 1],
         [0, 1, 1]]], dtype=torch.uint8)

In [97]:
# NOTE(review): this cell rebinds the notebook-wide names sl, m, n and d_in.
sl = 2
bsz = 3
m = 5
n = 4
d_in = 10

x = torch.rand((sl, bsz, d_in))

x_a = x[0, :]  # first item

A = torch.nn.Linear(d_in, m)

u = A(x_a)
print(u, u.size())
# Repeat each of the m outputs n times along dim 1: (bsz, m) -> (bsz, m * n).
z = u.repeat_interleave(n, 1)

print(z)

# NOTE(review): despite the name, index 1 selects the second row of the batch.
first_z_batch = z[1]
# Viewed as (m, n), every row holds n copies of one value.
print(first_z_batch.view(m, n))

tensor([[ 0.2309,  0.2413,  0.0531,  0.1007, -0.1583],
        [ 0.1270,  0.2542, -0.0928,  0.0849, -0.4191],
        [ 0.2453,  0.2738,  0.1485,  0.0858, -0.2552]],
       grad_fn=<AddmmBackward>) torch.Size([3, 5])
tensor([[ 0.2309,  0.2309,  0.2309,  0.2309,  0.2413,  0.2413,  0.2413,  0.2413,
          0.0531,  0.0531,  0.0531,  0.0531,  0.1007,  0.1007,  0.1007,  0.1007,
         -0.1583, -0.1583, -0.1583, -0.1583],
        [ 0.1270,  0.1270,  0.1270,  0.1270,  0.2542,  0.2542,  0.2542,  0.2542,
         -0.0928, -0.0928, -0.0928, -0.0928,  0.0849,  0.0849,  0.0849,  0.0849,
         -0.4191, -0.4191, -0.4191, -0.4191],
        [ 0.2453,  0.2453,  0.2453,  0.2453,  0.2738,  0.2738,  0.2738,  0.2738,
          0.1485,  0.1485,  0.1485,  0.1485,  0.0858,  0.0858,  0.0858,  0.0858,
         -0.2552, -0.2552, -0.2552, -0.2552]], grad_fn=<IndexSelectBackward>)
tensor([[ 0.1270,  0.1270,  0.1270,  0.1270],
        [ 0.2542,  0.2542,  0.2542,  0.2542],
        [-0.0928, -0.0928, -0.0928,

In [155]:
import timeit

import rsm_k_winners
reload(rsm_k_winners)

def run_kwin(size=50, scatter=True):
    """Apply rsm_k_winners.KWinners to a fresh random (size, size) tensor.

    The positional arguments 0, 10, 0 are passed through as-is and `scatter`
    is forwarded as the final argument.
    """
    sample = torch.rand(size, size)
    return rsm_k_winners.KWinners.apply(sample, 0, 10, 0, scatter)


In [127]:
# Smoke test with the defaults (50 x 50 input, scatter=True).
run_kwin()

tensor([[0.4548, 0.8208, 0.3500,  ..., 0.0000, 0.0000, 0.2274],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8647, 0.0963,  ..., 0.7562, 0.0000, 0.6054],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3596, 0.8424, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.4383, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.6986]])

In [156]:
from timeit import Timer

# Time 10000 calls of each variant. Each call also generates a new random
# input, so that cost is included in both timings.
t = Timer(lambda: run_kwin(scatter=True))
print('scatter', t.timeit(number=10000))

t = Timer(lambda: run_kwin(scatter=False))
print('no_scatter', t.timeit(number=10000))

scatter 0.9094837090015062
no_scatter 10.072980703000212


In [154]:
import rsm_k_winners
reload(rsm_k_winners)

a = torch.rand(3, 3)

print(a)

# Run both code paths on identical copies of the input.
res_scatter = rsm_k_winners.KWinners.apply(a.clone(), 0, 2, 0, True)
res_no_scatter = rsm_k_winners.KWinners.apply(a.clone(), 0, 2, 0, False)

print(res_scatter)
print(res_no_scatter)
# True when the two results match element for element.
bool(torch.all(torch.eq(res_scatter, res_no_scatter)).item())

tensor([[0.1703, 0.4110, 0.3670],
        [0.3609, 0.5750, 0.4785],
        [0.5245, 0.6524, 0.4454]])
tensor([[2, 1],
        [2, 1],
        [0, 1]])
tensor([[0.0000, 0.4110, 0.3670],
        [0.0000, 0.5750, 0.4785],
        [0.5245, 0.6524, 0.0000]])
tensor([[0.0000, 0.4110, 0.3670],
        [0.0000, 0.5750, 0.4785],
        [0.5245, 0.6524, 0.0000]])


True