In [1]:
import os
import time
import utils
import math
import argparse
import requests
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.cuda import amp
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import layer,functional,neuron,surrogate

from scipy.io import loadmat,savemat
from sklearn.cluster import KMeans

import torchaudio
from torchaudio import transforms

In [2]:
torch.backends.cudnn.benchmark = True

train_tb_writer = None
te_tb_writer = None
device = torch.device('cuda:1')
deviceIds = [1]

In [3]:
def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    args = parser.parse_args(args=['--data-path','../data','--lr','0.01','-b','24','--epochs','100','--print-freq','100','--tb','--cache-dataset'])
    return args

args = parse_args()

In [None]:
class NetworkB(torch.nn.Module):
    def __init__(self):
        super(NetworkB, self).__init__()
        self.T = 10
        self.conv1 = nn.Conv2d(1, 96, kernel_size=3, padding=1) 
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = neuron.IFNode(detach_reset=True)

        self.conv2 = layer.SeqToANNContainer(nn.Conv2d(96, 96, kernel_size=3, padding=1),nn.BatchNorm2d(96)) 
        self.sn2 = neuron.IFNode(detach_reset=True)
        self.pool1 = layer.SeqToANNContainer(nn.MaxPool2d(2)) 

        self.conv3 = layer.SeqToANNContainer(nn.Conv2d(96, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128))
        self.sn3 = neuron.IFNode(detach_reset=True)
      
        self.conv4 = layer.SeqToANNContainer(nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128))
        self.sn4 = neuron.IFNode(detach_reset=True)
        self.pool2 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv5 = layer.SeqToANNContainer(nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256))
        self.sn5 = neuron.IFNode(detach_reset=True)

        self.pool3 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv8 = layer.SeqToANNContainer(nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256)) 
        self.sn8 = neuron.IFNode(detach_reset=True)
        
        self.conv10 = layer.SeqToANNContainer(nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512)) 
        self.sn10 = neuron.IFNode(detach_reset=True)

        self.pool4 = layer.SeqToANNContainer(nn.MaxPool2d(2)) 

        self.conv11 = layer.SeqToANNContainer(nn.Conv2d(512, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256))
        self.sn11 = neuron.IFNode(detach_reset=True)

        self.conv13 = layer.SeqToANNContainer(nn.Conv2d(256, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128))
        self.sn13 = neuron.IFNode(detach_reset=True)

        self.pool5 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.fc1 = layer.SeqToANNContainer(nn.Linear(2048, 256),nn.BatchNorm1d(256))
        self.sn14 = neuron.IFNode(detach_reset=True)
        self.fc2 = layer.SeqToANNContainer(nn.Linear(256, 8),nn.BatchNorm1d(8))
        self.sn15 = neuron.IFNode(detach_reset=True)


    def forward(self, x):
        T = self.T
        x = self.bn1(self.conv1(x))
        x.unsqueeze_(0)
        x = x.repeat(10, 1, 1, 1, 1)
        x = self.sn1(x)

        x = self.sn2(self.conv2(x))
        x = self.pool1(x)

        x = self.sn3(self.conv3(x))
        x = self.sn4(self.conv4(x))
        x = self.pool2(x)

        x = self.sn5(self.conv5(x))
        x = self.pool3(x)

        x = self.sn8(self.conv8(x))
        x = self.sn10(self.conv10(x))
        x = self.pool4(x)

        x = self.sn11(self.conv11(x))

        x = x[:T, :, :, :, :]
        x = self.sn13(self.conv13(x))
        x = self.pool5(x)

        x = torch.flatten(x,2)

        x = self.sn14(self.fc1(x))
        x = self.sn15(self.fc2(x))

        return x.mean(0)
    
    def set_T(self, T):
        self.T = T

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,random_split,Dataset
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
import pickle

args = parse_args()
max_test_acc1 = 0.
test_acc5_at_max_test_acc1 = 0.

utils.init_distributed_mode(args)
print(args)
output_dir = os.path.join(args.output_dir, f'b_{args.batch_size}_lr{args.lr}')

time_now = datetime.datetime.now()
output_dir += f'_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}__{time_now.minute}'


if output_dir:
    utils.mkdir(output_dir)

class MyDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.jpg_list = os.listdir(self.data_path)
        self.label_list = []
        self.data_set = []
        
        label_0 = 0
        for i in range(len(self.jpg_list)):
            if 'c' in self.jpg_list[i]:
                label = int(self.jpg_list[i].split('c')[1].split('.')[0])
                if label in [0, 1, 2, 4, 5, 7, 11, 25]:
                    if label == 0 and label_0 < 6000:
                        label = 0
                        label_0 += 1
                    elif label == 1:
                        label = 1
                    elif label == 2:
                        label = 2
                    elif label == 4:
                        label = 3
                    elif label == 5:
                        label = 4
                    elif label == 7:
                        label = 5
                    elif label == 11:
                        label = 6
                    elif label == 25:
                        label = 7
                    
                    self.label_list.append(label)
                    self.data_set.append(self.jpg_list[i])
        
        self.label_list = np.array(self.label_list)
        self.data_set = [Image.open(os.path.join(self.data_path, i)) for i in self.data_set]
        self.data_set = [self.transform(i) for i in self.data_set]
        
    def __len__(self):
        return len(self.label_list)
    
    def __getitem__(self, idx):
        return self.data_set[idx], self.label_list[idx]

# Parameters
data_path = '/home/mlw/from_108/paper_SNNcom/Neuromorphic-Hybrid-Information-Processing-Architecture/Software/snncom/ecg/data/mitbih_database/JPEG/'
pickle_file = '/home/mlw/from_108/paper_SNNcom/Neuromorphic-Hybrid-Information-Processing-Architecture/Software/snncom/ecg/data/mitbih_database/dataset.pkl'

# Check if the dataset is already saved
if os.path.exists(pickle_file):
    with open(pickle_file, 'rb') as f:
        dataset = pickle.load(f)
else:
    # Create the dataset
    dataset = MyDataset(data_path=data_path, transform=transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ]))
    
    # Save the dataset to a pickle file
    with open(pickle_file, 'wb') as f:
        pickle.dump(dataset, f)

# Split the dataset into training and testing
traindata, testdata = random_split(dataset, [round(len(dataset) * .8), round(len(dataset) * .2)])

# Create data loaders
data_loader = DataLoader(traindata, batch_size=args.batch_size, shuffle=True)
data_loader_test = DataLoader(testdata, batch_size=args.batch_size, shuffle=True)
print("Creating model")

net = NetworkB().to(device)
functional.set_step_mode(net,step_mode='m')
functional.set_backend(net, backend='cupy')
net.to(device)
optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
criterion = nn.CrossEntropyLoss()
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    args.start_epoch = checkpoint['epoch'] + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']


if args.tb and utils.is_main_process():
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write('Results\n')


Not using distributed mode
Namespace(device='cuda:0', batch_size=24, data_path='../data', epochs=100, workers=16, lr=0.01, print_freq=100, output_dir='./logs', resume='', start_epoch=0, tb=True, cache_dataset=True, distributed=False)
Creating model


In [7]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None,lr = 1e-2):
    net.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)


    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        output = net(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        
        loss.backward()

        optimizer.step()


        functional.reset_net(net)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')
        acc1_s = acc1.item()
        acc5_s = acc5.item()

        metric_logger.update(loss=loss_s, lr=lr)

        metric_logger.meters['acc1'].update(acc1_s, n=batch_size)
        metric_logger.meters['acc5'].update(acc5_s, n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [8]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    net.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = net(image)
            loss = criterion(output, target)
            functional.reset_net(net)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    metric_logger.synchronize_between_processes()

    loss, acc1, acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5

In [None]:
print("Start training")
print(device)
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(net, criterion,data_loader, device, epoch, args.print_freq,lr=args.lr)

    if utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)
    # lr_scheduler.step()

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None:
        if utils.is_main_process():
            te_tb_writer.add_scalar('test_loss', test_loss, epoch)
            te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
            te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)


    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:

        checkpoint = {
            'model': net.state_dict(),
            'pre-train_epoch': epoch,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))
        save_flag = False

        if epoch % 64 == 0 or epoch == args.epochs - 1:
            save_flag = True


        if save_flag:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_all_pretrain.pth'))
            torch.save(net,os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_all_pretrain.pth'))
            torch.save(net.state_dict(),os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_state_pretrain.pth'))
    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
            'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1,'train_acc1', train_acc1,
            'train_acc5', train_acc1)

    Train_logs= {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1 ,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc1,
        'args': args
    }
    with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
        args_txt.write('\n')
        args_txt.write(str(Train_logs))

Start training
cuda:1
Epoch: [0]  [   0/2986]  eta: 3:09:26  lr: 0.01  img/s: 6.319760627912262  loss: 2.0629 (2.0629)  acc1: 29.1667 (29.1667)  acc5: 75.0000 (75.0000)  time: 3.8067  data: 0.0091  max mem: 0
Epoch: [0]  [ 100/2986]  eta: 0:13:00  lr: 0.01  img/s: 102.76032597211291  loss: 1.4160 (1.5621)  acc1: 87.5000 (74.1337)  acc5: 100.0000 (98.8036)  time: 0.2339  data: 0.0003  max mem: 0
Epoch: [0]  [ 200/2986]  eta: 0:11:42  lr: 0.01  img/s: 102.74522244234933  loss: 1.3512 (1.4651)  acc1: 91.6667 (82.8358)  acc5: 100.0000 (99.3988)  time: 0.2340  data: 0.0003  max mem: 0
Epoch: [0]  [ 300/2986]  eta: 0:11:01  lr: 0.01  img/s: 102.60363330951627  loss: 1.3458 (1.4273)  acc1: 91.6667 (86.1573)  acc5: 100.0000 (99.5986)  time: 0.2342  data: 0.0003  max mem: 0
Epoch: [0]  [ 400/2986]  eta: 0:10:28  lr: 0.01  img/s: 102.61084993170374  loss: 1.3309 (1.4091)  acc1: 95.8333 (87.8637)  acc5: 100.0000 (99.6987)  time: 0.2342  data: 0.0003  max mem: 0
Epoch: [0]  [ 500/2986]  eta: 0:10:

<Response [200]>

In [10]:
model_dir = os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_state_pretrain.pth')
# print(model_dir)

In [None]:
acc_list = []
for time in range(1,11):
  acc = 0
  for i in range(10):
      net = NetworkB().to(device)
      functional.set_step_mode(net,step_mode='m')
      functional.set_backend(net, backend='cupy')
      weights = torch.load(model_dir)
      net.load_state_dict(weights)
      net.set_T(time)
      test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
      acc += test_acc1
      # acc_list.append(acc)
  acc_list.append(acc/10)

Test:  [  0/747]  eta: 0:04:57  loss: 1.3157 (1.3157)  acc1: 95.8333 (95.8333)  acc5: 100.0000 (100.0000)  time: 0.3977  data: 0.0012  max mem: 0
Test:  [100/747]  eta: 0:01:00  loss: 1.2808 (1.2938)  acc1: 100.0000 (97.5248)  acc5: 100.0000 (100.0000)  time: 0.0910  data: 0.0003  max mem: 0
Test:  [200/747]  eta: 0:00:50  loss: 1.2740 (1.2922)  acc1: 100.0000 (97.7405)  acc5: 100.0000 (100.0000)  time: 0.0909  data: 0.0003  max mem: 0
Test:  [300/747]  eta: 0:00:41  loss: 1.2808 (1.2915)  acc1: 100.0000 (97.8544)  acc5: 100.0000 (100.0000)  time: 0.0909  data: 0.0003  max mem: 0
Test:  [400/747]  eta: 0:00:31  loss: 1.2808 (1.2911)  acc1: 100.0000 (97.9530)  acc5: 100.0000 (100.0000)  time: 0.0909  data: 0.0003  max mem: 0
Test:  [500/747]  eta: 0:00:22  loss: 1.2808 (1.2917)  acc1: 95.8333 (97.8959)  acc5: 100.0000 (100.0000)  time: 0.0909  data: 0.0003  max mem: 0
Test:  [600/747]  eta: 0:00:13  loss: 1.2808 (1.2911)  acc1: 100.0000 (97.9964)  acc5: 100.0000 (100.0000)  time: 0.0909

In [None]:
with open('output_snn.txt', 'a') as file:
    line = ' '.join(map(str, acc_list))
    file.write(line + '\n')