In [None]:
import os
import time
import utils
import math
import argparse
import requests
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.cuda import amp
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import layer,functional,neuron,surrogate

from scipy.io import loadmat,savemat

In [None]:
# Reproducibility: seed every RNG the notebook uses and force deterministic cuDNN kernels.
# (benchmark must stay False: autotuning may pick different, non-deterministic kernels.)
_seed_ = 42
random.seed(_seed_)
np.random.seed(_seed_)
torch.manual_seed(_seed_)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# TensorBoard writers; created in the setup cell only when --tb is given.
train_tb_writer = None
te_tb_writer = None
# NOTE(review): hard-coded; the --device CLI flag is parsed later but not used.
device = torch.device('cuda:0')
deviceIds = [0]

In [None]:
# Flags this notebook trains with. They are passed explicitly because inside a notebook
# kernel sys.argv holds the kernel's own command line, not these options.
_NOTEBOOK_ARGV = ('--data-path', './data', '--lr', '0.01', '-b', '128', '--epochs', '100',
                  '--print-freq', '100', '--tb', '--cache-dataset')


def parse_args(argv=None):
    """Parse the training options.

    Args:
        argv: optional sequence of command-line tokens. When None (the default) the
            notebook's fixed option list ``_NOTEBOOK_ARGV`` is parsed.

    Returns:
        argparse.Namespace with the parsed options.

    NOTE(review): --device and --cache-dataset are parsed but not referenced elsewhere
    in the visible notebook.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    if argv is None:
        argv = _NOTEBOOK_ARGV
    return parser.parse_args(args=list(argv))


args = parse_args()

In [None]:
class NetworkB(torch.nn.Module):
    """VGG-style spiking CNN with BatchNorm and IF neurons, producing 10 outputs.

    conv1/bn1 run once (single-step); their output is repeated over ``input_steps``
    time steps and pushed through multi-step ``SeqToANNContainer`` conv/BN blocks.
    After conv11 only the first ``self.T`` time steps are kept, and the network output
    is the mean of the final neuron layer over those steps.

    Args:
        T: number of time steps kept after conv11 (can be changed with ``set_T``).
        input_steps: number of time steps the conv1/bn1 output is repeated over.
            ``T`` values larger than this are effectively capped at ``input_steps``.
    """

    def __init__(self, T=8, input_steps=10):
        super(NetworkB, self).__init__()
        self.T = T
        self.input_steps = input_steps

        # Stem: applied once to the static image, before the time axis exists.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = neuron.IFNode(detach_reset=True)

        self.conv2 = layer.SeqToANNContainer(nn.Conv2d(96, 96, kernel_size=3, padding=1), nn.BatchNorm2d(96))
        self.sn2 = neuron.IFNode(detach_reset=True)
        self.pool1 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv3 = layer.SeqToANNContainer(nn.Conv2d(96, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128))
        self.sn3 = neuron.IFNode(detach_reset=True)

        self.conv4 = layer.SeqToANNContainer(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128))
        self.sn4 = neuron.IFNode(detach_reset=True)
        self.pool2 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv5 = layer.SeqToANNContainer(nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
        self.sn5 = neuron.IFNode(detach_reset=True)

        self.conv6 = layer.SeqToANNContainer(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
        self.sn6 = neuron.IFNode(detach_reset=True)

        self.conv7 = layer.SeqToANNContainer(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
        # sn7 is used by forward(); it was previously missing, so forward() raised AttributeError.
        self.sn7 = neuron.IFNode(detach_reset=True)

        self.pool3 = layer.SeqToANNContainer(nn.MaxPool2d(2))
        self.conv8 = layer.SeqToANNContainer(nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512))
        self.sn8 = neuron.IFNode(detach_reset=True)

        self.conv9 = layer.SeqToANNContainer(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512))
        self.sn9 = neuron.IFNode(detach_reset=True)

        self.conv10 = layer.SeqToANNContainer(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512))
        self.sn10 = neuron.IFNode(detach_reset=True)

        self.pool4 = layer.SeqToANNContainer(nn.MaxPool2d(2))
        self.conv11 = layer.SeqToANNContainer(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256))
        self.sn11 = neuron.IFNode(detach_reset=True)

        self.conv13 = layer.SeqToANNContainer(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128))
        self.sn13 = neuron.IFNode(detach_reset=True)

        self.pool5 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        # With 32x32 inputs the five 2x2 max-pools leave a 1x1 map, i.e. 128 features.
        self.fc1 = layer.SeqToANNContainer(nn.Linear(128, 256), nn.BatchNorm1d(256))
        self.sn14 = neuron.IFNode(detach_reset=True)
        self.fc2 = layer.SeqToANNContainer(nn.Linear(256, 10), nn.BatchNorm1d(10))
        self.sn15 = neuron.IFNode(detach_reset=True)

    def forward(self, x):
        """Return the time-averaged output, shape (batch, 10), for an image batch ``x``.

        Assumes x is (batch, 3, 32, 32) - TODO confirm for other resolutions (fc1 expects
        the flattened pool5 output to have 128 features).
        """
        T = self.T
        x = self.bn1(self.conv1(x))
        # Add a leading time axis and repeat the static encoding over input_steps steps.
        x = x.unsqueeze(0).repeat(self.input_steps, 1, 1, 1, 1)
        x = self.sn1(x)

        x = self.sn2(self.conv2(x))
        x = self.pool1(x)

        x = self.sn3(self.conv3(x))
        x = self.sn4(self.conv4(x))
        x = self.pool2(x)

        x = self.sn5(self.conv5(x))
        x = self.sn6(self.conv6(x))
        x = self.sn7(self.conv7(x))
        x = self.pool3(x)

        x = self.sn8(self.conv8(x))
        x = self.sn9(self.conv9(x))
        x = self.sn10(self.conv10(x))
        x = self.pool4(x)

        x = self.sn11(self.conv11(x))

        # Only the first T time steps continue through the remaining layers.
        x = x[:T]
        x = self.sn13(self.conv13(x))
        x = self.pool5(x)

        # (T, batch, C, H, W) -> (T, batch, C*H*W)
        x = torch.flatten(x, 2)

        x = self.sn14(self.fc1(x))
        x = self.sn15(self.fc2(x))
        return x.mean(0)

    def set_T(self, T):
        """Set how many time steps are kept after conv11 in subsequent forward passes."""
        self.T = T

In [None]:
# Smoke test: push one random batch (B=2, 3x32x32) through NetworkB and check that it
# yields one 10-way output per sample.
x = torch.randn(2, 3, 32, 32)
net = NetworkB()
with torch.no_grad():
    smoke_out = net(x)
# Clear the membrane potentials this forward pass left in the spiking neurons.
functional.reset_net(net)
assert smoke_out.shape == (2, 10), smoke_out.shape
smoke_out

In [None]:
args = parse_args()
max_test_acc1 = 0.
test_acc5_at_max_test_acc1 = 0.

utils.init_distributed_mode(args)
print(args)

# Run directory: <output-dir>/b_<batch>_lr<lr>_<year>_<month>_<day>_<hour>__<minute>
output_dir = os.path.join(args.output_dir, f'b_{args.batch_size}_lr{args.lr}')
time_now = datetime.datetime.now()
output_dir += f'_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}__{time_now.minute}'

if output_dir:
    utils.mkdir(output_dir)


# Load data
# NOTE(review): confirm these normalisation statistics. The checkpoints loaded later in
# the notebook were trained with them, so they are left unchanged here.
cifar_normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534])

# Training split: random-crop + horizontal-flip augmentation.
train_transform = torchvision.transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
])

# Test split: deterministic preprocessing only (no random crop / flip), so the reported
# test accuracy is repeatable.
test_transform = torchvision.transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    cifar_normalize,
])

data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.CIFAR10(root=args.data_path, train=True,
                                         transform=train_transform, download=True),
    batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=args.workers)

data_loader_test = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.CIFAR10(root=args.data_path, train=False,
                                         transform=test_transform, download=True),
    batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=False, num_workers=args.workers)


print("Creating model")
net = NetworkB().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')

optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

criterion = nn.CrossEntropyLoss()

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Restore the trained weights, not only the bookkeeping counters. The optimizer state
    # is not part of the saved checkpoint, so SGD momentum restarts from zero.
    net.load_state_dict(checkpoint['model'])
    # The training loop stores the epoch as 'pre-train_epoch'; accept 'epoch' as well.
    last_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else checkpoint['pre-train_epoch']
    args.start_epoch = last_epoch + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']


if args.tb and utils.is_main_process():
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write('Results\n')


In [None]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None, lr=1e-2, opt=None):
    """Train ``net`` for one pass over ``data_loader``.

    Args:
        net: spiking network; its neuron state is reset after every batch.
        criterion: loss applied to ``net(image)`` and ``target``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        epoch: epoch index, used only in the progress header.
        print_freq: logging interval passed to ``MetricLogger.log_every``.
        scaler: optional ``torch.cuda.amp.GradScaler``. When given, the forward pass and
            loss run under autocast and backward/step go through the scaler.
        lr: value shown in the 'lr' log column (it is not applied to the optimizer).
        opt: optimizer to step; defaults to the notebook-global ``optimizer``.

    Returns:
        (loss, acc1, acc5) global averages tracked by the metric logger for this epoch.

    Raises:
        ValueError: if a batch produces a NaN loss. This is raised before the optimizer
            step, so the NaN update never reaches the weights.
    """
    import contextlib

    if opt is None:
        opt = optimizer
    net.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        autocast_ctx = amp.autocast() if scaler is not None else contextlib.nullcontext()
        with autocast_ctx:
            output = net(image)
            loss = criterion(output, target)

        loss_s = loss.item()
        if math.isnan(loss_s):
            functional.reset_net(net)
            raise ValueError('loss is Nan')

        opt.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            opt.step()

        # Clear membrane potentials so the next batch starts from a fresh state.
        functional.reset_net(net)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        acc1_s = acc1.item()
        acc5_s = acc5.item()

        metric_logger.update(loss=loss_s, lr=lr)

        metric_logger.meters['acc1'].update(acc1_s, n=batch_size)
        metric_logger.meters['acc5'].update(acc5_s, n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [None]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``net`` on ``data_loader`` without gradients.

    The network is put in eval mode, and its spiking state is reset after every batch.
    Loss, top-1 and top-5 are all averaged with per-sample weights, so a smaller final
    batch does not skew the result.

    Args:
        net: network to evaluate.
        criterion: loss applied to ``net(image)`` and ``target``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        print_freq: logging interval passed to ``MetricLogger.log_every``.
        header: prefix for the progress log lines.

    Returns:
        (loss, acc1, acc5) global averages over the whole loader.
    """
    net.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = net(image)
            loss = criterion(output, target)
            functional.reset_net(net)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            batch_size = image.shape[0]
            # criterion returns a batch mean, so weight it by batch size like the accuracies.
            metric_logger.meters['loss'].update(loss.item(), n=batch_size)
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    metric_logger.synchronize_between_processes()

    loss, acc1, acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5

In [None]:
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(net, criterion, data_loader, device, epoch, args.print_freq, lr=args.lr)

    # The writers stay None unless --tb was given (and they are created on the main process only).
    if train_tb_writer is not None and utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None and utils.is_main_process():
        te_tb_writer.add_scalar('test_loss', test_loss, epoch)
        te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
        te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:
        checkpoint = {
            'model': net.state_dict(),
            'pre-train_epoch': epoch,
            # The same value under the key the resume code reads.
            'epoch': epoch,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        # Numbered snapshot every 64 epochs and at the final epoch.
        if epoch % 64 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            best_stem = f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1'
            # Saved under its own name so the pickled module below cannot overwrite it.
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, 'checkpoint_max_test_acc1.pth'))
            if utils.is_main_process():
                # Whole pickled module, and the bare state_dict that later cells load.
                torch.save(net, os.path.join(output_dir, f'{best_stem}_all_pretrain.pth'))
                torch.save(net.state_dict(), os.path.join(output_dir, f'{best_stem}_state_pretrain.pth'))
    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
          'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1, 'train_acc1', train_acc1,
          'train_acc5', train_acc5)

    Train_logs = {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc5,
        'args': args
    }
    # The *_logs directory is otherwise only created by the TensorBoard writers (--tb).
    os.makedirs(output_dir + '_logs', exist_ok=True)
    with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
        args_txt.write('\n')
        args_txt.write(str(Train_logs))

In [None]:
# Sweep the number of time steps kept after conv11 (NetworkB.set_T) from 1 to 10 for a
# single trained checkpoint, collecting test top-1 accuracy for each value.
EVAL_CKPT_PATH = 'logs/b_128_lr0.01_2024_9_27_23__48/train_maxacc1_84.04_checkpoint_max_test_acc1_state_pretrain.pth'

net = NetworkB().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')
# The weights are the same for every T, so load them once.
net.load_state_dict(torch.load(EVAL_CKPT_PATH, map_location=device))

acc_list = []
# Loop variable is not named `time`: that would shadow the `time` module used by training.
for num_steps in range(1, 11):
    net.set_T(num_steps)
    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    acc_list.append(test_acc1)

In [None]:
# Test top-1 accuracy for each T in the sweep above, in increasing T order.
print(f'{acc_list}')

In [None]:
# Evaluate a saved NetworkB state_dict on the CIFAR-10 test split.
net = NetworkB().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')

# map_location: place the tensors on this notebook's device, whatever device they were saved from.
weights = torch.load('./models/snn/max_86.81_weights.pth', map_location=device)
net.load_state_dict(weights)
test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')

In [None]:
root1 = 'mat/snn/'


class NetworkFuse(torch.nn.Module):
    """BatchNorm-free spiking CNN (LIF neurons) that can dump its internal activity.

    conv1 runs once in single-step form. Its output is repeated over ``self.T`` time
    steps and passed through multi-step ``SeqToANNContainer`` layers. When ``self.write``
    is True, each forward() saves the input, the spikes, and the final membrane potentials
    of every neuron layer under ``root1`` as ``<name>.mat`` files, one tensor per file.
    """

    # Tensors written when self.write is True, in write order; each goes to root1 + name + '.mat'.
    _DUMP_NAMES = (
        'x_in',
        'sn1_out_spike', 'sn1_out_v',
        'sn2_out_spike', 'sn2_out_v',
        'pool1_out_spike',
        'sn3_out_spike', 'sn3_out_v',
        'pool2_out_spike',
        'sn4_out_spike', 'sn4_out_v',
        'sn6_out_spike', 'sn6_out_v',
        'sn7_out_spike', 'sn7_out_v',
        'sn8_out_spike', 'sn8_out_v',
        'linear_out_spike',
    )

    def __init__(self):
        super(NetworkFuse, self).__init__()
        # When True, forward() writes the intermediate tensors to root1 (see Net_write).
        self.write = False
        # Number of simulation time steps the conv1 output is repeated over.
        self.T = 8

        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, padding=0)
        self.sn1 = neuron.LIFNode(detach_reset=True)

        self.conv2 = layer.SeqToANNContainer(nn.Conv2d(128, 256, kernel_size=3, padding=0))
        self.sn2 = neuron.LIFNode(detach_reset=True)
        self.pool1 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv3 = layer.SeqToANNContainer(nn.Conv2d(256, 512, kernel_size=3, padding=0))
        self.sn3 = neuron.LIFNode(detach_reset=True)
        self.pool2 = layer.SeqToANNContainer(nn.MaxPool2d(2))

        self.conv4 = layer.SeqToANNContainer(nn.Conv2d(512, 256, kernel_size=3, padding=0))
        self.sn4 = neuron.LIFNode(detach_reset=True)

        self.conv6 = layer.SeqToANNContainer(nn.Conv1d(1, 1, kernel_size=7, stride=4, padding=3))
        self.sn6 = neuron.LIFNode(detach_reset=True)

        self.fc1 = layer.SeqToANNContainer(nn.Linear(1024, 1024))
        self.sn7 = neuron.LIFNode(detach_reset=True)
        self.fc2 = layer.SeqToANNContainer(nn.Linear(1024, 512))
        self.sn8 = neuron.LIFNode(detach_reset=True)
        self.fc3 = layer.SeqToANNContainer(nn.Linear(512, 10))

    def forward(self, x):
        """Run ``self.T`` time steps and return the time-averaged fc3 output.

        Assumes x is (batch, 3, 32, 32) - TODO confirm. At that size the flattened conv4
        output has 4096 features, which conv6 (stride 4) reduces to the 1024 that fc2 expects.
        """
        T = self.T
        batch_size = x.shape[0]
        record = self.write
        dumps = {}

        def capture(name, tensor):
            # Host copies are made only when this pass is actually written out.
            if record:
                dumps[name] = tensor.detach().cpu().numpy()

        capture('x_in', x)
        x = self.conv1(x)
        # Add a leading time axis and repeat the static conv1 output over T steps.
        x = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
        x = self.sn1(x)
        capture('sn1_out_spike', x)
        capture('sn1_out_v', self.sn1.v)

        x = self.sn2(self.conv2(x))
        capture('sn2_out_spike', x)
        capture('sn2_out_v', self.sn2.v)

        x = self.pool1(x)
        capture('pool1_out_spike', x)

        x = self.sn3(self.conv3(x))
        capture('sn3_out_spike', x)
        capture('sn3_out_v', self.sn3.v)

        x = self.pool2(x)
        capture('pool2_out_spike', x)

        x = self.sn4(self.conv4(x))
        capture('sn4_out_spike', x)
        capture('sn4_out_v', self.sn4.v)

        # (T, B, C, H, W) -> (T, B, 1, C*H*W): one channel for the 1-D conv6.
        x = torch.flatten(x, 2).unsqueeze(2)

        x = self.sn6(self.conv6(x))
        capture('sn6_out_spike', x)
        capture('sn6_out_v', self.sn6.v)

        x = torch.flatten(x, 2)

        # NOTE(review): fc1/sn7 are defined in __init__ but not used here; fc2 consumes
        # the conv6 features directly. Confirm this is intended.
        x = self.sn8(self.fc2(x))
        capture('sn8_out_spike', x)
        capture('sn8_out_v', self.sn8.v)

        x = self.fc3(x)
        capture('linear_out_spike', x)

        if record:
            # sn7 never runs, so its dumps are all-zero float32 placeholders of shape (T, B, 1024).
            dumps['sn7_out_spike'] = np.zeros((T, batch_size, 1024), dtype=np.float32)
            dumps['sn7_out_v'] = np.zeros((T, batch_size, 1024), dtype=np.float32)
            for name in self._DUMP_NAMES:
                savemat(root1 + name + '.mat', {name: dumps[name]})

        return x.mean(0)

    def Net_write(self, whether_write):
        """Enable (True) or disable (False) writing intermediate tensors in forward()."""
        self.write = whether_write

In [None]:
# Build the BN-free network that receives the fused weights, with tensor dumping on.
net_fuse = NetworkFuse()
net_fuse.Net_write(True)
net_fuse = net_fuse.to(device)
functional.set_step_mode(net_fuse, step_mode='m')
functional.set_backend(net_fuse, backend='cupy')

In [None]:
import torch.nn.utils.fusion as fusion

# Copy the trained weights of `net` into `net_fuse`. Every BatchNorm that directly
# follows a conv/linear layer is folded into that layer's weight and bias; other
# conv/linear layers are copied as-is.
# NOTE(review): NetworkFuse currently does not mirror NetworkB's layer names and widths
# (e.g. NetworkB has conv5..conv13 and NetworkFuse does not), so getattr(net_fuse, ...)
# raises AttributeError until the two architectures are aligned.

# Names of all sub-modules of the source network, in registration order.
model_name = [name for name, _ in net.named_modules()]
print(model_name)


def _bn_kwargs(bn):
    """Keyword arguments describing BatchNorm `bn`, as torch.nn.utils.fusion expects them."""
    return dict(bn_rm=bn.running_mean, bn_rv=bn.running_var, bn_eps=bn.eps, bn_w=bn.weight, bn_b=bn.bias)


def _set_affine(dst, weight, bias):
    """Replace dst.weight / dst.bias with new Parameters built from weight / bias."""
    dst.weight = torch.nn.Parameter(weight)
    dst.bias = torch.nn.Parameter(bias)


def _bn_tail(container):
    """Return the BatchNorm in stage 1 of a SeqToANNContainer, or None if there is none."""
    if isinstance(container, layer.SeqToANNContainer) and len(container) > 1:
        if isinstance(container[1], (nn.BatchNorm2d, nn.BatchNorm1d)):
            return container[1]
    return None


for idx, name in enumerate(model_name):
    if name == '' or '.' in name:
        # Root module and container children are handled through their top-level parent.
        continue

    if name == 'conv1':
        # Exact match: a substring test would also catch conv10 / conv11 / conv13.
        # Stand-alone conv1 followed by a stand-alone BatchNorm (conv1 -> bn1).
        next_name = model_name[idx + 1] if idx + 1 < len(model_name) else ''
        if 'bn' in next_name:
            conv_module = getattr(net, name)
            bn_module = getattr(net, next_name)
            new_w, new_b = fusion.fuse_conv_bn_weights(
                conv_w=conv_module.weight, conv_b=conv_module.bias, **_bn_kwargs(bn_module))
            _set_affine(getattr(net_fuse, name), new_w, new_b)

    elif 'conv' in name:
        conv_modules = getattr(net, name)
        bn_module = _bn_tail(conv_modules)
        if bn_module is not None:
            print(bn_module)
            conv_module = conv_modules[0]
            new_w, new_b = fusion.fuse_conv_bn_weights(
                conv_w=conv_module.weight, conv_b=conv_module.bias, **_bn_kwargs(bn_module))
            _set_affine(getattr(net_fuse, name)[0], new_w, new_b)

    elif 'fc' in name:
        fc_modules = getattr(net, name)
        if isinstance(fc_modules, layer.SeqToANNContainer):
            bn_module = _bn_tail(fc_modules)
            if bn_module is not None:
                print(bn_module)
                fc_module = fc_modules[0]
                new_w, new_b = fusion.fuse_linear_bn_weights(
                    linear_w=fc_module.weight, linear_b=fc_module.bias, **_bn_kwargs(bn_module))
                _set_affine(getattr(net_fuse, name)[0], new_w, new_b)
        else:
            # Plain (non-container) fc layer: share its parameters as-is.
            print(name)
            fuse_fc_module = getattr(net_fuse, name)
            fuse_fc_module.weight = fc_modules.weight
            fuse_fc_module.bias = fc_modules.bias

    elif 'linear' in name or 'connect' in name:
        # Other affine layers without BatchNorm: share their parameters as-is.
        print(name)
        src_module = getattr(net, name)
        dst_module = getattr(net_fuse, name)
        dst_module.weight = src_module.weight
        dst_module.bias = src_module.bias


In [None]:
# Test-split loss / top-1 / top-5 of the fused network.
# NOTE(review): net_fuse has writing enabled, so every batch re-writes the .mat dumps;
# the files end up holding the last test batch.
test_loss, test_acc1, test_acc5 = evaluate(
    net=net_fuse,
    criterion=criterion,
    data_loader=data_loader_test,
    device=device,
    header='Test:',
)

In [None]:
# Show the fused network's layer structure.
print(repr(net_fuse))

In [None]:
# Save the fused network's weights next to the .mat dumps; create the folder if needed.
os.makedirs(root1, exist_ok=True)
torch.save(net_fuse.state_dict(), root1 + 'SNNCom_SNN.pth')

In [None]:
import scipy.io as sio

# Export every parameter of the fused network to its own .mat file under root1, named
# after the parameter (e.g. root1 + 'conv2.0.weight.mat').
os.makedirs(root1, exist_ok=True)
for name, param in net_fuse.named_parameters():
    weight_name = root1 + name + '.mat'
    values = param.data.cpu().numpy()

    if 'weight' in name:
        weight_dict = {'weight': values}
    elif 'bias' in name:
        # 'bisa' is the misspelled key earlier exports used; it is kept for existing
        # readers, alongside the correctly spelled 'bias'.
        weight_dict = {'bias': values, 'bisa': values}
    else:
        # Unknown parameter kind: report it and skip it, rather than saving a stale dict.
        print('error')
        continue
    sio.savemat(weight_name, weight_dict)
