In [1]:
import sys, os
# Make the VT_SNN project packages (snn_models, dataset, ...) importable.
# The path is built relative to the kernel's current working directory, so the
# notebook must be launched from its own folder for this to resolve correctly.
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate sys.path entries.
CURRENT_TEST_DIR = os.getcwd()
VT_SNN_PATH = CURRENT_TEST_DIR + "/../../VT_SNN/"
if VT_SNN_PATH not in sys.path:
    sys.path.append(VT_SNN_PATH)

In [2]:
import torch
import numpy as np
import slayerSNN as snn
from pathlib import Path
import logging
from snn_models.baseline_snn import SlayerMLP
from snn_models.multimodal_snn import EncoderVis
from torch.utils.data import DataLoader
from dataset import ViTacDataset, ViTacMMDataset, ViTacVisDataset

In [3]:
# Output name prefix for this run. NOTE(review): `fname` is not referenced in
# the cells shown here, and "rests" may be a typo for "results" - confirm.
fname = "rests/" + "no_delay_no_dropout_new_data_vision_new_model"

In [4]:
class FLAGS:
    """Experiment hyper-parameters and paths (a stand-in for command-line flags).

    ``data_dir`` may be overridden with the ``VT_SNN_DATA_DIR`` environment
    variable so the notebook is not tied to one machine's home directory; when
    the variable is unset the original path is used.
    """

    def __init__(self):
        # Root directory passed as `path` to the dataset classes, together with
        # split files named train_80_20_<sample_file>.txt / test_80_20_<sample_file>.txt.
        self.data_dir = os.environ.get(
            "VT_SNN_DATA_DIR", '/home/tasbolat/some_python_examples/data_VT_SNN_new/'
        )
        self.batch_size = 8
        self.sample_file = 2    # selects which 80/20 split file pair to load
        self.lr = 0.001
        self.epochs = 500
        self.output_size = 20   # size of the network's output layer
        self.theta = 10         # neuron activation threshold
        self.tauRho = 1
        self.tsample = 325      # simulation length (SLAYER "tSample")
        self.tsr_stop = 325     # intended end of the NumSpikes target spike region
        self.sc_true = 150      # target spike count for True (tgtSpikeCount)
        self.sc_false = 5       # target spike count for False (tgtSpikeCount)
        self.hidden_size = 32   # NOTE(review): not used by any cell shown here


args = FLAGS()

In [5]:
# SLAYER configuration: neuron model, simulation window and the training loss.
# Every tunable value is read from `args` so that FLAGS is the single place to
# change them (tauRho and the target-region stop were previously hard-coded,
# which made edits to args.tauRho / args.tsr_stop silently ineffective).
params = {
    "neuron": {
        "type": "SRMALPHA",
        "theta": args.theta,  # activation threshold
        "tauSr": 10.0,
        "tauRef": 1.0,
        "scaleRef": 2,
        "tauRho": args.tauRho,
        "scaleRho": 1,
    },
    "simulation": {"Ts": 1.0, "tSample": args.tsample, "nSample": 1},
    "training": {
        "error": {
            "type": "NumSpikes",  # "NumSpikes" or "ProbSpikes"
            "tgtSpikeRegion": {  # valid for NumSpikes and ProbSpikes
                "start": 0,
                "stop": args.tsr_stop,
            },
            # Desired spike counts keyed by True/False; presumably True is the
            # correct-class neuron and False the rest (SLAYER convention - confirm).
            "tgtSpikeCount": {True: args.sc_true, False: args.sc_false},
        }
    },
}

In [6]:
output_size = args.output_size  # 20


def _vis_split(sample_file, shuffle):
    """Return (dataset, loader) for one vision-only split described by `sample_file`."""
    split_ds = ViTacVisDataset(
        path=args.data_dir,
        sample_file=sample_file,
        output_size=args.output_size,
    )
    split_loader = DataLoader(
        dataset=split_ds,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=4,
    )
    return split_ds, split_loader


# Only the training batches are shuffled; the test order stays fixed.
train_dataset, train_loader = _vis_split(f"train_80_20_{args.sample_file}.txt", shuffle=True)
test_dataset, test_loader = _vis_split(f"test_80_20_{args.sample_file}.txt", shuffle=False)

In [7]:
# Number of samples in the training split.
len(train_dataset)

320

In [8]:
# Peek at the first test sample. Items unpack in the same order the training
# loop uses: (vis, target, label).
first_test_sample = test_dataset[0]
a, b, c = first_test_sample

In [9]:
# Shapes of the vision tensor and the target, plus the integer label.
# NOTE(review): the vision tensor's last axis appears to be time (its size
# matches args.tsample); the meaning of the other axes is not shown here - confirm.
a.shape, b.shape, c

(torch.Size([2, 63, 50, 325]), torch.Size([20, 1, 1, 1]), 0)

In [10]:
# Distinct values in the vision tensor; the recorded output shows it is binary
# (presumably spike / no-spike).
torch.unique(a)

tensor([0., 1.])

In [11]:
# The network and the SLAYER loss module are both moved to the first CUDA device.
device = torch.device("cuda", 0)
net = EncoderVis(params, args.output_size).to(device)
error = snn.loss(params).to(device)

# NOTE(review): weight_decay=0.5 is an unusually strong L2 penalty for RMSprop;
# confirm it is intended.
optimizer = torch.optim.RMSprop(
    net.parameters(),
    lr=args.lr,
    weight_decay=0.5,
)

In [12]:
# Display the network architecture.
net

EncoderVis(
  (slayer): spikeLayer()
  (fc1): _denseLayer(2, 1024, kernel_size=(63, 50, 1), stride=(1, 1, 1), bias=False)
  (fc2): _denseLayer(1024, 20, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
)

In [None]:
# Per-epoch history: averages over the whole split, appended once per epoch.
train_loss = []
test_loss = []
train_acc = []
test_acc = []

LOG_EVERY = 10  # print a summary every LOG_EVERY epochs


def _run_epoch(loader, train):
    """Run one pass over `loader`; return (num_correct, summed_batch_loss).

    With ``train=True`` the network is in training mode and the optimizer is
    stepped after every batch; otherwise it runs in eval mode with autograd off.
    """
    net.train(train)  # train(False) is equivalent to net.eval()
    n_correct = 0
    loss_sum = 0.0
    with torch.set_grad_enabled(train):
        for vis, target, label in loader:
            vis = vis.to(device)
            target = target.to(device)
            # Call the module rather than .forward() so nn.Module hooks run.
            output = net(vis)
            n_correct += torch.sum(snn.predict.getClass(output) == label).item()
            loss = error.numSpikes(output, target)
            loss_sum += loss.item()
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    return n_correct, loss_sum


n_train = len(train_dataset)
n_test = len(test_dataset)

for epoch in range(1, args.epochs + 1):
    tr_correct, btr_loss = _run_epoch(train_loader, train=True)
    te_correct, bte_loss = _run_epoch(test_loader, train=False)

    # Summed batch losses are divided by the number of samples, not batches.
    # NOTE(review): this is a per-sample average only if numSpikes sums over the
    # batch - confirm against the SLAYER implementation.
    train_loss.append(btr_loss / n_train)
    test_loss.append(bte_loss / n_test)
    train_acc.append(tr_correct / n_train)
    test_acc.append(te_correct / n_test)

    if epoch % LOG_EVERY == 0:
        print('Epoch:', epoch)
        print('Accs (train, test):', train_acc[-1], test_acc[-1])
        print('Loss (train, test):', train_loss[-1], test_loss[-1])

Epoch: 10
Accs (train, test): 0.434375 0.2625
Loss (train, test): 25.53571379184723 30.54906005859375
Epoch: 20
Accs (train, test): 0.690625 0.45
Loss (train, test): 14.908182382583618 21.475576305389403
Epoch: 30
Accs (train, test): 0.884375 0.575
Loss (train, test): 8.323894077539444 19.676980113983156
Epoch: 40
Accs (train, test): 0.9375 0.5625
Loss (train, test): 5.016384628415108 19.893998622894287
Epoch: 50
Accs (train, test): 0.990625 0.6875
Loss (train, test): 3.018523934483528 15.983518695831298
Epoch: 60
Accs (train, test): 0.9875 0.6125
Loss (train, test): 3.2015577137470244 17.516480255126954
Epoch: 70
Accs (train, test): 0.975 0.6375
Loss (train, test): 2.9072739705443382 18.493595314025878
Epoch: 80
Accs (train, test): 0.978125 0.65
Loss (train, test): 1.841048077493906 18.275404453277588
Epoch: 90
Accs (train, test): 0.978125 0.6125
Loss (train, test): 2.148072024434805 22.276480865478515
Epoch: 100
Accs (train, test): 0.975 0.7125
Loss (train, test): 2.0226731207221746 