# In [None]:
# Please install the dcll package from the repository below:
# https://github.com/nmi-lab/dcll
# and then enjoy yourself.
# If you have any questions, please email me.
# CUDA_VISIBLE_DEVICES=6 python3 example_gesture_scnn.py

import sys
sys.path.append("..")

import dcll
from dcll.load_dvsgestures_sparse import *
import argparse, pickle, torch, time, os
from importlib import import_module
import torch.nn as nn
import LIAF
import numpy as np
import pandas as pd
import util

import matplotlib.animation as animation
import matplotlib.pyplot as plt


# Module-level switch exposed by the LIAF package.
# NOTE(review): presumably toggles synchronized batch norm inside LIAF — confirm.
LIAF.using_syn_batchnorm = False

# Make runs repeatable (so that every run gives the same result).
# cudnn.deterministic only forces deterministic cuDNN kernels. Weight
# initialisation, and any NumPy-based data sampling in the loader, still depend
# on RNG state, so those generators are seeded explicitly as well.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # also seeds the CUDA generators
torch.backends.cudnn.deterministic = True

# TODO 1: enter your output paths below (results CSV and saved model).
# TODO 2: put the unzipped dataset in dcll-master/data.
# TODO 3: if anything fails, email me (the dcll package has some bugs).
resultPath = './result_LIAF.csv'
# NOTE: the 'gestrue' spelling is left unchanged so previously saved models are still found.
modelPath = './dvs_gestrue_model'

#################################
# Network / optimisation settings
batch_size = 64
learning_rate = 3e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#################################
# Dataset settings
n_test_interval = 20  # how many epochs to run before testing

time_window = 8  # time steps per sample (loader chunk_size and config.timeWindows)
n_iters = n_iters_test = time_window  # chunk length for the train / test generators
dt = 100 * 1000  # microseconds of events accumulated into one frame
ds = 4  # spatial down-scaling: frames are (128 // ds) pixels per side
target_size = 11  # number of classes
n_epochs = 4000  # outer training-loop iterations; each one draws several batches
in_channels = 2  # two input channels ("Green and Red"; presumably event polarity)
im_width = im_height = 128 // ds
im_dims = (im_width, im_height)
names = 'dvsGesture_stbp_cnn_test1'

#################################

parser = argparse.ArgumentParser(description='STBP for DVS gestures')

# History buffers for loss / accuracy (train and test); presumably filled
# further down the file. Each name gets its own, independent list.
loss_train_list, loss_test_list = [], []
acc_train_list, acc_test_list = [], []

train_correct = 0
test_epoch = 20

# Load data.
# Both generators are built with identical loader settings and differ only in
# chunk_size. The train generator comes from the first call and the test
# generator from the second; the other element of each pair is discarded.
def _load_split(chunk_size):
    """Call dcll's ``create_data`` with this script's settings and *chunk_size*."""
    return create_data(
        batch_size=batch_size,
        chunk_size=chunk_size,
        size=[in_channels, im_width, im_height],
        ds=ds,
        dt=dt,
    )


gen_train, _ = _load_split(n_iters)
_, gen_test = _load_split(n_iters_test)

def generate_data(gen_test, n_test: int, offset=0, *, batch=None, frame_shape=None):
    """Draw one block of samples from ``gen_test`` and split it into mini-batches.

    Args:
        gen_test: object exposing ``next(offset=...)`` that returns
            ``(inputs, labels)`` NumPy arrays. Judging by the indexing below,
            ``inputs`` has the sample axis first and the time axis second, while
            ``labels`` has the time axis first and the sample axis second.
        n_test: maximum number of mini-batches to build. It is clamped to the
            number of (possibly partial) batches actually available.
        offset: forwarded unchanged to ``gen_test.next``.
        batch: mini-batch size. Defaults to the module-level ``batch_size``.
        frame_shape: ``(time_steps, *frame_dims)`` used to reshape each input
            batch to ``(time_steps, batch, *frame_dims)``. Defaults to
            ``(n_iters_test, in_channels, im_width, im_height)``.

    Returns:
        ``(n_test, input_tests, labels1h_tests)``:
        - the clamped batch count;
        - a list of float input tensors, time axis first;
        - a list of float label tensors, time axis first.
        The last batch may contain fewer than ``batch`` samples.
    """
    if batch is None:
        batch = batch_size
    if frame_shape is None:
        frame_shape = (n_iters_test, in_channels, im_width, im_height)
    time_steps, *frame_dims = frame_shape

    input_test, labels_test = gen_test.next(offset=offset)
    n_test = min(n_test, int(np.ceil(input_test.shape[0] / batch)))

    # Convert to a time-major tensor ONCE. Previously the whole array was
    # re-converted for every mini-batch, so each batch could pin its own copy.
    inputs_time_major = torch.Tensor(input_test.swapaxes(0, 1))

    input_tests = []
    labels1h_tests = []
    for i in range(n_test):
        chunk = slice(i * batch, (i + 1) * batch)
        input_tests.append(inputs_time_major[:, chunk].reshape(time_steps, -1, *frame_dims))
        labels1h_tests.append(torch.Tensor(labels_test[:, chunk]))
    return n_test, input_tests, labels1h_tests

# Pre-split the test set into mini-batches (at most 300 of them).
n_test, input_tests, labels1h_tests = generate_data(gen_test, n_test=300, offset=0)
# Count the samples actually returned. The last mini-batch may be partial, so
# n_test * batch_size would over-report the size of the test set.
print('test_data_samples:', sum(int(chunk.shape[1]) for chunk in input_tests))

# Build the LIAF convolutional network from its configuration object.
modules = import_module('LIAFnet.LIAFCNN')
config = modules.Config()
# Conv layer specs.
# NOTE(review): the meaning of each tuple field is defined in LIAFnet.LIAFCNN — confirm.
# The first layer's 16 input channels equal in_channels * time_window. The
# training loop reshapes each sample to (batch, 16, 1, W, H), apparently
# folding the channel and time axes together.
config.cfgCnn = [(in_channels * time_window, 64, 3, 1, 1, False),
                 (64, 128, 3, 2, 1, True),
                 (128, 128, 3, 2, 1, True)]
config.cfgFc = [256, target_size]  # last FC width = number of classes (11)
config.decay = 0.3
config.dropOut = 0
config.dataSize = [im_width, im_height]
config.padding = 1
config.timeWindows = time_window
config.useBatchNorm = True
config.useThreshFiring = True
config.actFun = torch.selu
snn = modules.LIAFCNN(config).to(device)

best_acc = acc = 0

criterion = nn.CrossEntropyLoss()
######################################################################################
# note:
# nn.CrossEntropyLoss is the standard loss for classification. It applies
# log-softmax internally, so its inputs are the raw per-class scores (logits)
# of shape (batch, classes) and the target class indices (not one-hot vectors).
######################################################################################
optimizer = torch.optim.Adam(snn.parameters(), lr=learning_rate)

running_loss = 0

# Batches drawn per epoch (1064 presumably = size of the training split) and
# how often (in batches) progress is printed.
n_train_batches = 1064 // batch_size
log_interval = 4

for epoch in range(n_epochs):

    # training
    snn.train(mode=True)
    correct = 0
    total = 0
    running_loss = 0
    for i in range(n_train_batches):
        # The optimizer was built from snn.parameters(), so this clears every gradient.
        optimizer.zero_grad()
        start_time = time.time()

        images, labels = gen_train.next()
        # Swap axes 1 and 2. Judging by generate_data, the raw layout is
        # (batch, time, C, W, H), so this gives (batch, C, time, W, H) — confirm.
        images = torch.Tensor(images.swapaxes(1, 2)).float()
        labels = torch.from_numpy(labels).float()
        # Labels are indexed (time, batch, class) and presumably one-hot.
        # Take time step 1 and convert it to class indices.
        _, labels = labels[1, :, :].max(dim=1)

        # Fold channels and time into one axis: (batch, C * time, 1, W, H),
        # i.e. (batch, 16, 1, 32, 32) with the current settings.
        images = images.reshape(images.shape[0], in_channels * n_iters, 1, im_width, im_height)
        # The model lives on `device`; keep its inputs and targets there too.
        images = images.to(device)
        labels = labels.to(device)

        outputs = snn(images)
        loss = criterion(outputs, labels)
        loss.backward()

        _, predicted = outputs.detach().max(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

        optimizer.step()

        if (i + 1) % log_interval == 0:
            # Average over the log_interval batches accumulated since the last
            # report. The previous code divided by 10 regardless.
            print('Epoch [%d,%d/%d], Loss:%.5f' % (i, epoch + 1, n_epochs, running_loss / log_interval))
            print('Epoch [%d,%d/%d], Acc :%.5f' % (i, epoch + 1, n_epochs, 100 * correct / total))
            correct = 0
            total = 0
            running_loss = 0

    torch.cuda.empty_cache()