In [1]:
!which python

/home/ec2-user/miniconda3/bin/python


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append("/home/ec2-user/nta/nupic.research")
sys.path.append("/home/ec2-user/nta/nupic.torch")

In [4]:
import models

In [5]:
import sys
import torch

from models import MultiHeadedSparseMLP, MultiHeadedDendriticMLP, DendriticMLP, SparseMLP

from nupic.research.frameworks.pytorch.model_utils import count_nonzero_params

from nupic.torch.modules import rezero_weights

from nupic.research.frameworks.pytorch.models.common_models import StandardMLP

from nupic.research.frameworks.dendrites.modules.dendritic_layers import BiasingDendriticLayer

from copy import deepcopy
from torch import nn

In [6]:
test_data_with_context = torch.rand((8,21))


In [7]:
# start a regular MLP
sparse_net = SparseMLP(
    hidden_sizes=(2048,2048,2048),
    input_size=21,
    output_dim=7,
    linear_weight_percent_on=(0.5, 0.5, 0.5),
    linear_activity_percent_on=(0.1, 0.1, 0.1),
    use_batch_norm=False,
)

In [8]:
%time
output = sparse_net(test_data_with_context)
print(output)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.15 µs
tensor([[-7.6795e-03,  3.7848e-03,  6.9775e-03, -9.3998e-03, -3.0156e-03,
          3.1120e-03, -8.1030e-03],
        [-7.4035e-03,  1.1602e-02,  7.0175e-03, -6.0808e-03, -5.8343e-03,
         -2.1367e-03, -1.7176e-02],
        [-3.5647e-03,  4.4708e-03,  1.0158e-02, -2.0132e-03, -1.0010e-02,
         -3.2798e-03, -1.5986e-02],
        [-7.7780e-03,  5.2119e-03,  4.1889e-03, -1.0919e-02, -6.2680e-03,
          3.6875e-03, -1.1003e-02],
        [-1.4235e-02,  1.8707e-02,  1.0515e-02, -6.6521e-03, -1.3724e-02,
          6.5169e-06, -1.1407e-02],
        [ 4.1190e-03,  1.6794e-02,  6.1752e-03, -1.4180e-03, -7.3284e-03,
          4.6146e-04, -1.1511e-02],
        [-1.0167e-02,  1.5454e-02,  5.4558e-03,  6.2702e-04, -7.3032e-03,
         -5.3850e-03, -1.5711e-02],
        [-1.2256e-02,  1.8944e-02,  8.3157e-03, -2.9518e-03, -2.3077e-03,
         -2.5419e-04, -2.1103e-02]], grad_fn=<AddmmBackward>)


In [9]:
# start a dendritic MLP
dendrite_net = DendriticMLP(
    hidden_sizes=(2048,2048,2048),
    input_size=11,
    output_dim=7,
    kw=False,
    relu=True,
    kw_percent_on=1.,
    dim_context=10,
    num_segments=(30, 30, 30),
    sparsity=0.5,
)

In [10]:
test_data_without_context = torch.rand((8,11))
test_context = torch.rand((8,10))

In [11]:
%time
output = dendrite_net(test_data_without_context, test_context)
print(output)

CPU times: user 1 µs, sys: 0 ns, total: 1 µs
Wall time: 6.2 µs
tensor([[ 0.0021,  0.0039,  0.0177, -0.0183, -0.0049, -0.0017,  0.0040],
        [ 0.0029,  0.0058,  0.0173, -0.0170, -0.0067, -0.0014,  0.0042],
        [ 0.0038,  0.0043,  0.0166, -0.0185, -0.0061, -0.0023,  0.0030],
        [ 0.0045,  0.0052,  0.0177, -0.0180, -0.0056, -0.0002,  0.0041],
        [ 0.0048,  0.0048,  0.0191, -0.0194, -0.0071, -0.0029,  0.0046],
        [ 0.0026,  0.0054,  0.0159, -0.0190, -0.0079, -0.0015,  0.0039],
        [ 0.0029,  0.0042,  0.0184, -0.0175, -0.0064, -0.0010,  0.0042],
        [ 0.0031,  0.0040,  0.0177, -0.0184, -0.0066, -0.0005,  0.0051]],
       grad_fn=<AddmmBackward>)


In [12]:
# start a dendritic MLP
dendrite_net2 = DendriticMLP(
    hidden_sizes=(2048,2048,2048),
    input_size=11,
    output_dim=7,
    kw=False,
    relu=True,
    kw_percent_on=1.,
    dim_context=10,
    num_segments=(30, 30, 30),
    sparsity=0.5,
    dendritic_layer_class=BiasingDendriticLayer
)

In [13]:
%time
output = dendrite_net2(test_data_without_context, test_context)
print(output)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.44 µs
tensor([[ 0.2199,  0.0905,  0.1041,  0.1568,  0.1492, -0.5456, -0.2105],
        [ 0.1484,  0.0395,  0.0559,  0.0640,  0.1299, -0.4030, -0.1294],
        [ 0.1744,  0.0723,  0.1382,  0.0609,  0.0445, -0.3637, -0.2328],
        [ 0.2002,  0.1502,  0.0076,  0.1395,  0.0607, -0.5386, -0.2957],
        [ 0.1378,  0.0193,  0.1687,  0.1601,  0.1357, -0.3876, -0.1388],
        [ 0.2569,  0.0463,  0.0882,  0.1333,  0.1584, -0.4526, -0.2106],
        [ 0.1801,  0.1057,  0.1073,  0.0925,  0.1084, -0.4615, -0.2016],
        [ 0.1095,  0.0235,  0.0601,  0.0100,  0.0649, -0.3919, -0.1522]],
       grad_fn=<AddmmBackward>)


In [14]:
dense_net = StandardMLP(
    input_size=21, 
    num_classes=7,
    hidden_sizes=(2048, 2048, 2048)
)

In [15]:
%time
output = dense_net(test_data_with_context)
print(output)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.91 µs
tensor([[ 0.0233,  0.0151, -0.0164,  0.0210, -0.0011,  0.0606, -0.0041],
        [ 0.0246,  0.0303, -0.0120,  0.0199,  0.0279,  0.0658,  0.0231],
        [ 0.0072,  0.0339, -0.0156,  0.0378,  0.0111,  0.0486,  0.0422],
        [ 0.0227,  0.0251, -0.0251,  0.0223,  0.0169,  0.0477,  0.0414],
        [ 0.0225,  0.0143, -0.0296,  0.0042,  0.0103,  0.0580,  0.0302],
        [ 0.0086,  0.0306, -0.0218,  0.0060,  0.0151,  0.0698,  0.0412],
        [-0.0063,  0.0374, -0.0150,  0.0115, -0.0043,  0.0467,  0.0028],
        [ 0.0122,  0.0219, -0.0309,  0.0065, -0.0067,  0.0507,  0.0020]],
       grad_fn=<AddmmBackward>)


In [16]:
# have a small forward and backward pass. time

In [17]:
sparse_net2 = deepcopy(sparse_net)
sparse_net2._hidden_base.linear1_kwinners = nn.ReLU()
sparse_net2._hidden_base.linear2_kwinners = nn.ReLU()
sparse_net2._hidden_base.linear3_kwinners = nn.ReLU()

In [18]:
%time
output = sparse_net2(test_data_with_context)
print(output)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.91 µs
tensor([[-0.0147,  0.0132, -0.0067, -0.0027, -0.0110,  0.0012, -0.0065],
        [-0.0120,  0.0159, -0.0032,  0.0031, -0.0192,  0.0010, -0.0133],
        [-0.0158,  0.0137, -0.0030,  0.0055, -0.0234,  0.0052, -0.0161],
        [-0.0184,  0.0147, -0.0010,  0.0041, -0.0177,  0.0026, -0.0091],
        [-0.0155,  0.0175, -0.0022, -0.0006, -0.0155,  0.0019, -0.0113],
        [-0.0149,  0.0190, -0.0019,  0.0033, -0.0158,  0.0053, -0.0092],
        [-0.0163,  0.0217, -0.0012,  0.0052, -0.0145,  0.0042, -0.0154],
        [-0.0188,  0.0232, -0.0073,  0.0059, -0.0090,  0.0011, -0.0090]],
       grad_fn=<AddmmBackward>)


In [19]:
device = torch.device("cuda")
for model in [dense_net, sparse_net, dendrite_net, sparse_net2, dendrite_net2]:
    model.to(device)

In [20]:
# !cat ~/nta/nupic.research/nupic/research/frameworks/dendrites/functional/apply_dendrites.py

In [21]:
def train(model, device, epochs=1, batch_size=64, dendrites=False, backward_pass=False):
    
    input_dim = 11
    context_dim = 10
    output_dim = 7

    optim = torch.optim.SGD(lr=0.01, params=model.parameters())
    loss_fn = torch.nn.MSELoss()

    for _ in range(epochs):
        target = torch.randn(batch_size, output_dim, device=device)

        if dendrites:
            data = torch.rand(batch_size, input_dim, device=device)
            context = torch.rand(batch_size, context_dim, device=device)
            output = model(data, context)
        else:
            data = torch.rand(batch_size, input_dim+context_dim, device=device)
            output = model(data)

        if backward_pass:
            loss = loss_fn(output, target)
            loss.backward()

            optim.step()

#             model.apply(rezero_weights)
            optim.zero_grad()


In [22]:
epochs = 1000

In [23]:
%time train(dense_net, device, epochs=epochs, batch_size=1024)   # dense

CPU times: user 1.33 s, sys: 440 ms, total: 1.77 s
Wall time: 1.77 s


In [24]:
%time train(sparse_net, device, epochs=epochs, batch_size=1024)  # kwinners - 50% slower

CPU times: user 2.14 s, sys: 280 ms, total: 2.42 s
Wall time: 2.41 s


In [25]:
%time train(sparse_net2, device, epochs=epochs, batch_size=1024)  # no kwinners

CPU times: user 1.07 s, sys: 420 ms, total: 1.49 s
Wall time: 1.49 s


In [26]:
%time train(dendrite_net, device, epochs=epochs, batch_size=1024, dendrites=True)  # gating

CPU times: user 15.8 s, sys: 8.78 s, total: 24.6 s
Wall time: 24.6 s


In [None]:
%time train(dendrite_net2, device, epochs=epochs, batch_size=1024, dendrites=True)  # bias

In [None]:
count_nonzero_params(dendrite_net)

In [None]:
count_nonzero_params(sparse_net)

In [None]:
sparse_net2

In [None]:
dense_net