In [61]:
from dataset import NMnistSampled
import numpy as np 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tnn import *
import matplotlib.pyplot as plt
from tqdm import tqdm

In [15]:
# Trained ConvColumn checkpoint; only its 'weight' entry is used below.
CHECKPOINT_PATH = 'model/n-mnist-cv'
state = torch.load(CHECKPOINT_PATH)
weight = state['weight']

In [24]:
# Hyper-parameters of the TNN convolutional column (must match the checkpoint).
column_config = dict(
    input_channel=2, output_channel=8,
    kernel=3, stride=2,
    step=16, leak=32,
    dense=0.3, fodep=0, w_init=0.5,
)
model = ConvColumn(**column_config).cuda()
# Swap in the trained weights without recording the assignment in autograd.
with torch.no_grad():
    model.weight.set_(weight)
# Inference mode; eval() is train(False) and returns the module for display.
model.eval()

Building convolutional TNN layer with theta=5.4000, dense=0.3000, fodep=48


ConvColumn()

In [59]:
# Positional args forwarded to NMnistSampled — presumably (height, width, time samples); TODO confirm.
NMNIST_ARGS = (34, 34, 256)
DEVICE = 'cuda:0'
BATCH_SIZE = 32


def make_loader(split_dir, shuffle):
    """Build a DataLoader over one sampled N-MNIST split directory."""
    dataset = NMnistSampled(split_dir, *NMNIST_ARGS, device=DEVICE)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


train_data_loader = make_loader('data/n-mnist/TrainSP', shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test passes reproducible.
test_data_loader = make_loader('data/n-mnist/TestSP', shuffle=False)

In [63]:
# Per-sample channel spike counts on the test set, bucketed by digit label (10 classes).
result = [[] for _ in range(10)]
# Inference only: no_grad avoids building an autograd graph through the TNN column.
with torch.no_grad():
    for input_spikes, labels in tqdm(test_data_loader):
        output_spikes = model(input_spikes)
        # Sum over the three trailing dims once per batch -> one count vector per sample.
        counts = output_spikes.sum((-1, -2, -3)).cpu().numpy()
        for label, count in zip(labels.tolist(), counts):
            result[label].append(count)

100%|██████████| 313/313 [07:08<00:00,  1.37s/it]


In [70]:
# Stack each class's count vectors into one (n_samples_for_class, n_channels) array.
rs = list(map(np.vstack, result))

In [71]:
# Mean spike-count vector per class -> (n_classes, n_channels).
avgs = np.stack([class_counts.mean(axis=0) for class_counts in rs])

In [76]:
def float_formatter(value):
    """Render a float in a fixed 5-wide, 2-decimal field for compact array printing."""
    return f"{value:5.2f}"


np.set_printoptions(formatter={'float_kind': float_formatter})

In [77]:
# Plain-text table of class means using the fixed-width float formatter.
print(np.array_str(avgs))

[[29.23 28.30 25.86 29.33 17.86 96.56 35.25 29.33]
 [ 3.84  3.57  3.50  3.58  1.54 19.38  5.46  3.58]
 [27.53 26.63 24.83 27.63 16.94 90.44 32.32 27.63]
 [30.36 29.93 28.00 30.58 20.06 95.17 35.50 30.58]
 [15.26 14.58 14.76 14.62  9.39 55.03 18.19 14.62]
 [29.44 28.52 27.16 29.33 19.49 91.40 33.94 29.33]
 [22.01 21.08 20.24 21.69 13.39 75.84 25.77 21.69]
 [17.55 16.92 15.81 17.65 10.27 59.93 21.24 17.65]
 [27.27 27.77 25.97 27.24 17.93 91.47 32.40 27.24]
 [19.55 19.85 19.08 20.01 12.44 68.95 23.86 20.01]]


In [58]:
class LinearModel(nn.Module):
    """A single linear layer followed by log-softmax over dimension 1.

    Args:
        input_size: number of input features per sample.
        output_size: number of output classes.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.layer = nn.Linear(input_size, output_size)

    def forward(self, input_data):
        """Return log-probabilities; for 2-D input the shape is (batch, output_size)."""
        return self.layer(input_data).log_softmax(dim=1)

In [82]:
# Linear readout: 8 per-channel spike counts (ConvColumn output_channel=8) -> 10 digit classes.
tester = LinearModel(8, 10).cuda()
tester.train()
# Optimize the readout itself; the original passed model.parameters() (the TNN column),
# so the readout's weights were never updated.
optimizer = torch.optim.Adam(tester.parameters())
# tester already returns log-probabilities, so NLLLoss is the matching criterion
# (CrossEntropyLoss would apply log_softmax a second time).
error = torch.nn.NLLLoss()

In [86]:
# Train the linear readout on spike-count features from the TNN column.
train_data_iterator = tqdm(train_data_loader)
for data, label in train_data_iterator:
    # Feature extraction needs no gradients: only the readout is being optimized.
    with torch.no_grad():
        output_spikes = model(data)
        # Cast to float so the counts match nn.Linear's parameter dtype.
        spike_counts = output_spikes.sum((-1, -2, -3)).float()
    optimizer.zero_grad()
    output = tester(spike_counts)
    loss = error(output, label.cuda())
    train_data_iterator.set_description(f'loss={loss.item():.4f}')
    loss.backward()
    optimizer.step()

In [100]:
# Class-mean templates laid out as (1, n_channels, n_classes) so they broadcast
# against a (batch, n_channels, 1) view of the input features.
tavg = torch.Tensor(avgs).cuda().T.unsqueeze(0)


def predictor(x):
    """Nearest-class-mean classifier over per-channel spike counts.

    Args:
        x: spike counts of shape (batch, n_channels), e.g.
           ``output_spikes.sum((-1, -2, -3))``; tensors or array-likes are accepted
           and moved to ``tavg``'s device/dtype.

    Returns:
        (batch,) long tensor holding, for each sample, the index of the class whose
        mean count vector is closest in Euclidean distance. The metric is chosen
        here; the original stub did not specify one.
    """
    x = torch.as_tensor(x, dtype=tavg.dtype, device=tavg.device)
    distances = (x.unsqueeze(-1) - tavg).norm(dim=1)  # (batch, n_classes)
    return distances.argmin(dim=1)

In [112]:
# Extract per-sample channel spike counts for the whole training set (slow: ~1.5 h).
features = []
labels = []
# Inference only: no_grad avoids building an autograd graph through the TNN column.
with torch.no_grad():
    for data, label in tqdm(train_data_loader):
        output_spikes = model(data)
        feature = output_spikes.sum((-1, -2, -3)).cpu().numpy()
        features.append(feature)
        labels.append(label.numpy())

100%|██████████| 1875/1875 [1:30:18<00:00,  2.89s/it]


In [117]:
# Flatten the per-batch training features and the per-class test features into
# sklearn-style (X, y) arrays.
X_train = np.concatenate(features, axis=0)
Y_train = np.hstack(labels)
X_test = np.concatenate(rs, axis=0)
# Test labels follow the class bucket order used to build rs.
Y_test = np.concatenate([
    np.full(len(class_counts), digit, dtype=np.int64)
    for digit, class_counts in enumerate(rs)
])
tuple(arr.shape for arr in (X_train, Y_train, X_test, Y_test))

((60000, 8), (60000,), (10000, 8), (10000,))

In [118]:
# Gradient-boosted trees on the per-channel spike-count features.
# NOTE(review): this rebinds `tester`, which previously held the LinearModel readout.
from sklearn.ensemble import GradientBoostingClassifier
# Fixed seed so the fit (and the accuracy reported below) is reproducible.
tester = GradientBoostingClassifier(random_state=0)
tester.fit(X_train, Y_train)

GradientBoostingClassifier()

In [119]:
# Class predictions for every test sample from the fitted estimator.
Y_pred = tester.predict(X=X_test)

In [120]:
# Explicit imports instead of `import *`, so it is clear which names this notebook uses.
from sklearn.metrics import accuracy_score, confusion_matrix
accuracy_score(Y_test, Y_pred)

0.2929

In [121]:
# Rows are true classes, columns are predicted classes.
confusion_matrix(y_true=Y_test, y_pred=Y_pred)

array([[ 356,   30,  161,  106,   51,    4,  105,  104,   32,   31],
       [   0, 1056,    3,    0,   30,    0,   16,   22,    0,    8],
       [ 248,   37,  197,  115,   77,    6,  115,  138,   45,   54],
       [ 211,   33,  166,  198,   79,   18,   71,  127,   40,   67],
       [  16,  121,   26,    1,  409,    1,  101,  228,    7,   72],
       [ 132,   43,  133,  180,   87,   25,   76,  138,   23,   55],
       [ 129,   69,   96,   67,  162,    4,  159,  175,   33,   64],
       [  22,  109,   68,   24,  258,    2,  104,  376,    8,   57],
       [ 218,   37,  139,  119,   99,   13,  105,  116,   56,   72],
       [  64,  103,  108,   56,  209,    1,  107,  244,   20,   97]],
      dtype=int64)