In [None]:
import os
import time
import utils
import math
import argparse
import requests
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.cuda import amp
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import layer,functional,neuron,surrogate

from scipy.io import loadmat,savemat
from sklearn.cluster import KMeans

In [None]:
torch.backends.cudnn.benchmark = True
_seed_ = 42
random.seed(42)
torch.manual_seed(_seed_)
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(_seed_)

train_tb_writer = None
te_tb_writer = None
device = torch.device('cuda:1')
deviceIds = [0]

In [None]:
def parse_args(argv=None):
    """Build and parse the training command-line options.

    Jupyter injects its own kernel arguments into ``sys.argv``, so by default a fixed
    notebook argument list is parsed instead of the real command line.

    Args:
        argv: optional list of argument strings. ``None`` (the default) uses the
            notebook's built-in list: data path ./data, lr 0.01, batch 128,
            100 epochs, print-freq 100, --tb and --cache-dataset.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    if argv is None:
        argv = ['--data-path', './data', '--lr', '0.01', '-b', '128', '--epochs', '100',
                '--print-freq', '100', '--tb', '--cache-dataset']
    return parser.parse_args(args=list(argv))


args = parse_args()

In [None]:
class NetworkA(torch.nn.Module):
    """Hybrid ANN/SNN classifier: a ReLU VGG-style front end feeding an IF-neuron back end.

    The analog front end (conv1..conv10, then conv11 + bn11) runs once. Its output is
    repeated over ``T`` time steps and drives IF neurons through conv13, linear1 and
    linear3. The result is the spike output of ``sn15`` averaged over the time axis.

    Five 2x2 max-pools followed by ``Linear(128, ...)`` mean the input must reduce to
    1x1, i.e. 32x32 images such as CIFAR-10. The spiking layers receive a leading
    time dimension, so call ``functional.set_step_mode(net, step_mode='m')`` first,
    as every cell in this notebook does.
    """

    def __init__(self):
        super(NetworkA, self).__init__()
        self.T = 8  # number of SNN time steps

        # ---- analog (ReLU) front end ----
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = nn.ReLU()

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        # bn2 / bn4 are used by forward() but were never defined (AttributeError on the
        # first call). Each is registered directly after its conv because the
        # BN-folding cell pairs every conv with the module registered right after it.
        self.bn2 = nn.BatchNorm2d(96)
        self.sn2 = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.sn3 = nn.ReLU()

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.sn4 = nn.ReLU()

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.sn5 = nn.ReLU()

        # NOTE(review): conv6/conv7/conv9 (with their BN/ReLU) are not used by forward().
        # They are kept because removing them changes the state_dict keys that saved
        # checkpoints contain.
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.sn6 = nn.ReLU()

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        self.sn7 = nn.ReLU()

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        self.sn8 = nn.ReLU()

        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(512)
        self.sn9 = nn.ReLU()

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(512)
        self.sn10 = nn.ReLU()

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv11 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(256)

        # ---- spiking back end (inputs carry a leading time dimension) ----
        self.sn11 = neuron.IFNode(detach_reset=True)

        self.conv13 = layer.SeqToANNContainer(nn.Conv2d(256, 128, kernel_size=3, padding=1))
        self.bn13 = layer.SeqToANNContainer(nn.BatchNorm2d(128))
        self.sn13 = neuron.IFNode(detach_reset=True)

        self.pool5 = layer.SeqToANNContainer(nn.MaxPool2d(kernel_size=2, stride=2))

        self.linear1 = layer.SeqToANNContainer(nn.Linear(128, 256))
        self.bn14 = layer.SeqToANNContainer(nn.BatchNorm1d(256))
        self.sn14 = neuron.IFNode(detach_reset=True)

        self.linear3 = layer.SeqToANNContainer(nn.Linear(256, 10))
        self.bn15 = layer.SeqToANNContainer(nn.BatchNorm1d(10))
        self.sn15 = neuron.IFNode(detach_reset=True)

    def forward(self, x):
        """Run the network on a batch ``x`` and return per-class outputs averaged over T."""
        T = self.T

        x = self.sn1(self.bn1(self.conv1(x)))
        x = self.sn2(self.bn2(self.conv2(x)))
        x = self.pool1(x)

        x = self.sn3(self.bn3(self.conv3(x)))
        x = self.sn4(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        x = self.sn5(self.bn5(self.conv5(x)))
        x = self.pool3(x)

        x = self.sn8(self.bn8(self.conv8(x)))
        x = self.sn10(self.bn10(self.conv10(x)))
        x = self.pool4(x)

        x = self.bn11(self.conv11(x))

        # Feed the same analog activation to the IF neurons at every time step.
        x = x.unsqueeze(0)
        x = x.repeat(T, 1, 1, 1, 1)
        x = self.sn11(x)
        x = self.sn13(self.bn13(self.conv13(x)))
        x = self.pool5(x)

        x = torch.flatten(x, 2)  # keep (T, B); flatten the remaining dims

        x = self.sn14(self.bn14(self.linear1(x)))

        x = self.sn15(self.bn15(self.linear3(x)))

        return x.mean(0)

    def set_T(self, T):
        """Set the number of time steps used by subsequent forward passes."""
        self.T = T

In [None]:
# Smoke test: push one random CIFAR-sized batch through the multi-step network and
# check the output shape. no_grad avoids building an autograd graph for a throwaway
# pass, and reset_net clears the IF membrane state the pass leaves behind.
x = torch.randn(2, 3, 32, 32)
net = NetworkA()
functional.set_step_mode(net, step_mode='m')
with torch.no_grad():
    smoke_out = net(x)
functional.reset_net(net)
assert smoke_out.shape == (2, 10), smoke_out.shape
smoke_out

In [None]:
args = parse_args()
max_test_acc1 = 0.
test_acc5_at_max_test_acc1 = 0.

utils.init_distributed_mode(args)
print(args)

# Run directory: <output-dir>/b_<batch>_lr<lr>_<Y>_<M>_<D>_<h>__<m>
time_now = datetime.datetime.now()
output_dir = os.path.join(args.output_dir, f'b_{args.batch_size}_lr{args.lr}')
output_dir += f'_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}__{time_now.minute}'

if output_dir:
    utils.mkdir(output_dir)

# Normalisation constants shared by both splits.
# NOTE(review): these std values differ from the commonly used CIFAR-10 stds
# (~0.247, 0.243, 0.261); kept as-is because the trained checkpoints depend on them.
CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR_STD = [0.557, 0.549, 0.5534]


def _cifar10_loader(train):
    """Build the CIFAR-10 DataLoader for the training split (augmented) or the test split.

    Random crop/flip are applied to the training split only. They used to be applied
    to the test split too, which made each evaluation of the same weights return a
    different accuracy.
    """
    steps = [transforms.Resize((32, 32))]
    if train:
        steps += [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    steps += [transforms.ToTensor(), transforms.Normalize(mean=CIFAR_MEAN, std=CIFAR_STD)]
    dataset = torchvision.datasets.CIFAR10(root=args.data_path, train=train,
                                           transform=transforms.Compose(steps), download=True)
    return torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=train, pin_memory=True,
        drop_last=train, num_workers=args.workers)


# Load data
data_loader = _cifar10_loader(train=True)
data_loader_test = _cifar10_loader(train=False)


print("Creating model")

net = NetworkA().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')

optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# NOTE(review): StepLR(step_size=5, gamma=0.4) multiplies the LR by 0.4 every 5 epochs,
# i.e. by ~1e-8 over 100 epochs; confirm this schedule is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)

criterion = nn.CrossEntropyLoss()

if args.resume:
    # NOTE(review): checkpoints contain an argparse Namespace; on torch versions where
    # torch.load defaults to weights_only=True this load needs weights_only=False.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Restore the weights as well. Previously only the bookkeeping was restored, so a
    # "resumed" run silently restarted from random initialisation.
    net.load_state_dict(checkpoint['model'])
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'lr_scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # The training loop has stored the epoch under 'pre-train_epoch'; accept either key.
    last_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else checkpoint['pre-train_epoch']
    args.start_epoch = last_epoch + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']


if args.tb and utils.is_main_process():
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write('Results\n')


In [None]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None, lr=1e-2):
    """Train ``net`` for one epoch, then step the LR scheduler once.

    Uses the module-level ``optimizer`` and ``scheduler`` created in the setup cell.

    Args:
        net: model to train; its spiking state is reset after every batch.
        criterion: loss function applied to ``net(image)`` and ``target``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval in batches.
        scaler: unused (no AMP path); kept for interface compatibility.
        lr: unused for training. The logged LR is now read from the optimizer, so it
            reflects scheduler decay; the parameter is kept for compatibility.

    Returns:
        ``(loss, acc1, acc5)`` global averages over the epoch.

    Raises:
        ValueError: if the loss is NaN. This is checked before the backward pass, so
            the NaN is never applied to the weights.
    """
    net.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = net(image)
        loss = criterion(output, target)
        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Clear membrane potentials so the next batch starts from a fresh state.
        functional.reset_net(net)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]

        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]['lr'])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    scheduler.step()
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [None]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``net`` on ``data_loader`` without gradients and print the summary.

    The spiking state is reset after every batch. Loss, acc1 and acc5 are all averaged
    per sample, so a smaller final batch (the test loader keeps it) is not over-weighted.

    Returns:
        ``(loss, acc1, acc5)`` global averages, with accuracies as computed by
        ``utils.accuracy``.
    """
    net.eval()
    net.to(device)
    metric_logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = net(image)
            loss = criterion(output, target)
            functional.reset_net(net)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            # Weighted by batch size, like acc1/acc5 (previously each batch counted once).
            metric_logger.meters['loss'].update(loss.item(), n=batch_size)
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    metric_logger.synchronize_between_processes()

    loss, acc1, acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5

In [None]:
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(
        net, criterion, data_loader, device, epoch, args.print_freq, lr=args.lr)

    # The writers exist only when --tb was given; guard both the same way.
    if train_tb_writer is not None and utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None and utils.is_main_process():
        te_tb_writer.add_scalar('test_loss', test_loss, epoch)
        te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
        te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:
        checkpoint = {
            'model': net.state_dict(),
            # Optimizer/scheduler state and 'epoch' let --resume continue exactly.
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': scheduler.state_dict(),
            'epoch': epoch,
            'pre-train_epoch': epoch,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        if epoch % 64 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            best_prefix = f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1'
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'{best_prefix}_all_pretrain.pth'))
            # The pickled whole model used to be written to the *same* file as the
            # checkpoint dict above, overwriting it. It now gets its own name.
            torch.save(net, os.path.join(output_dir, f'{best_prefix}_model_pretrain.pth'))
            torch.save(net.state_dict(), os.path.join(output_dir, f'{best_prefix}_state_pretrain.pth'))

    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
          'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1, 'train_acc1', train_acc1,
          'train_acc5', train_acc5)

    Train_logs = {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc5,
        'args': args
    }
    # The '_logs' directory is created only together with the TensorBoard writers.
    if args.tb and utils.is_main_process():
        with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
            args_txt.write('\n')
            args_txt.write(str(Train_logs))

rep_pre_acc1 = max_test_acc1
rep_pre_acc5 = test_acc5_at_max_test_acc1
# Optional push notification. The token is read from the environment and must never be
# committed. NOTE(review): the token previously hard-coded here is exposed in history
# and should be revoked.
pushplus_token = os.environ.get('PUSHPLUS_TOKEN')
if pushplus_token:
    try:
        requests.get(
            'http://www.pushplus.plus/send',
            params={
                'token': pushplus_token,
                'title': '111程序通知',
                'content': '{}{}\n{}{}'.format("pre_acc1 = ", rep_pre_acc1, "pre_acc5 = ", rep_pre_acc5),
                'template': 'html',
            },
            timeout=10,
        )
    except requests.RequestException as exc:
        print(f'pushplus notification failed: {exc}')

In [None]:
# Rebuild NetworkA, load the best pre-trained state dict and evaluate it at T = 8.
PRETRAINED_STATE = '/home/yanggl/code/snncom/logs/b_128_lr0.01_2024_9_27_23__50/train_maxacc1_86.49_checkpoint_max_test_acc1_state_pretrain.pth'

net = NetworkA().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')

weights = torch.load(PRETRAINED_STATE)
net.load_state_dict(weights)
net.set_T(8)

test_loss, test_acc1, test_acc5 = evaluate(
    net, criterion, data_loader_test, device=device, header='Test:')

In [None]:
# Top-1 test accuracy of the pre-trained NetworkA for T = 1..10.
# The weights are loaded once: evaluate() runs without gradients and resets the
# neurons after every batch, so only set_T needs to change between runs.
acc_list = []
net = NetworkA().to(device)
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')
weights = torch.load('/home/yanggl/code/snncom/logs/b_128_lr0.01_2024_9_27_23__50/train_maxacc1_86.49_checkpoint_max_test_acc1_state_pretrain.pth')
net.load_state_dict(weights)

# The loop variable is `T`, not `time`: naming it `time` shadowed the `time` module and
# broke every later time.time() call (training loop, train_one_epoch).
for T in range(1, 11):
    net.set_T(T)
    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    acc_list.append(test_acc1)

In [None]:
# Show the per-T accuracies collected by the sweep cell above.
print(acc_list)

In [None]:
def find_scale_factor(tensor, num_bits):
    """Return the factor that maps ``tensor``'s largest magnitude onto the signed ``num_bits`` range.

    Multiplying by the result scales ``max(|tensor|)`` to ``2**(num_bits - 1) - 1``
    (127 for 8 bits).

    Args:
        tensor: torch tensor (any device) or array-like.
        num_bits: signed integer bit width.

    Returns:
        The scale factor. An all-zero or empty input has no meaningful scale, so 1.0 is
        returned. The old code divided by zero here, giving inf and then NaN weights
        after ``0 * inf``.
    """
    if isinstance(tensor, torch.Tensor):
        values = tensor.detach().cpu().numpy()
    else:
        values = np.asarray(tensor)
    if values.size == 0:
        return 1.0
    max_val = np.max(np.abs(values))
    if max_val == 0:
        return 1.0
    return (2 ** (num_bits - 1) - 1) / max_val

In [None]:
def k_means_cpu(weight, n_clusters, init='k-means++', quantization_bits=8, device='cuda'):
    """Cluster the scalar values of ``weight`` with k-means (on CPU) and quantize them.

    Every element is replaced by its cluster centroid. The centroid-valued tensor is then
    multiplied by ``find_scale_factor(..., quantization_bits)`` and rounded.

    Args:
        weight: numpy array of any shape.
        n_clusters: requested number of clusters; clipped to the number of elements.
        init: KMeans initialisation strategy.
        quantization_bits: signed bit width used to derive the scale factor.
        device: device for the returned tensors. The default ``'cuda'`` (the current
            CUDA device) matches the previous hard-coded ``.cuda()``.

    Returns:
        ``(centroids, labels, weight_q)``:
        - ``centroids``: shape ``(1, n_clusters)``.
        - ``labels``: int32 cluster indices with ``weight``'s shape.
        - ``weight_q``: rounded, scaled quantized weights, float32, same shape.
    """
    org_shape = weight.shape
    flat = weight.reshape(-1, 1)
    n_clusters = min(n_clusters, flat.size)

    k_means = KMeans(n_clusters=n_clusters, init=init, n_init=1, max_iter=50)
    k_means.fit(flat)

    centroids = torch.from_numpy(k_means.cluster_centers_).to(device).view(1, -1)
    labels = torch.from_numpy(k_means.labels_.reshape(org_shape)).int().to(device)

    # Direct centroid lookup per element. reshape(-1) (not squeeze) stays 1-D, so a
    # single-cluster result no longer breaks the old per-cluster loop.
    weight_q = centroids.reshape(-1).float()[labels.long()]
    scale_factor = find_scale_factor(weight_q, quantization_bits)
    weight_q = torch.round(weight_q * scale_factor)
    print('scale_factor:', scale_factor)
    return centroids, labels, weight_q

In [None]:
class QuantLinear(nn.Linear):
    """``nn.Linear`` whose weights can be replaced in place by k-means-quantized values.

    After ``kmeans_quant()``:
      * ``weight`` holds the rounded, scaled centroid values produced by ``k_means_cpu``;
      * ``weight_labels`` holds each weight's cluster index plus 1 (presumably for
        1-based consumers such as the MATLAB exports; confirm);
      * ``centroids`` holds the cluster centres with shape ``(1, num_cent)``.
    Unlike ``nn.Linear``, bias defaults to False.
    """

    def __init__(self, in_features, out_features, bias=False):
        super(QuantLinear, self).__init__(in_features, out_features, bias)
        self.num_cent = 8  # number of k-means clusters (a 3-bit codebook)
        self.weight_labels = None
        self.centroids = None

    def kmeans_quant(self):
        """Cluster this layer's weights into ``num_cent`` values and quantize them in place."""
        w = self.weight.data
        centroids, labels, w_q = k_means_cpu(w.cpu().numpy(), self.num_cent)
        # k_means_cpu places its results on the default CUDA device. Move them next to
        # this layer's weight; otherwise a model on e.g. cuda:1 would end up with its
        # weight on cuda:0.
        dev = self.weight.device
        self.centroids = centroids.to(dev)
        self.weight_labels = labels.to(dev) + 1
        self.weight.data = w_q.to(device=dev, dtype=self.weight.dtype)

            
class QuantConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weights can be replaced in place by k-means-quantized values.

    After ``kmeans_quant()``:
      * ``weight`` holds the rounded, scaled centroid values produced by ``k_means_cpu``;
      * ``weight_labels`` holds each weight's cluster index plus 1 (presumably for
        1-based consumers such as the MATLAB exports; confirm);
      * ``centroids`` holds the cluster centres with shape ``(1, num_cent)``.
    Unlike ``nn.Conv2d``, bias defaults to False.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super(QuantConv2d, self).__init__(in_channels, out_channels,
            kernel_size, stride, padding, dilation, groups, bias)
        self.num_cent = 8  # number of k-means clusters (a 3-bit codebook)
        self.weight_labels = None
        self.centroids = None

    def kmeans_quant(self):
        """Cluster this layer's weights into ``num_cent`` values and quantize them in place."""
        w = self.weight.data
        centroids, labels, w_q = k_means_cpu(w.cpu().numpy(), self.num_cent)
        # k_means_cpu places its results on the default CUDA device. Move them next to
        # this layer's weight so the model stays on a single device.
        dev = self.weight.device
        self.centroids = centroids.to(dev)
        self.weight_labels = labels.to(dev) + 1
        self.weight.data = w_q.to(device=dev, dtype=self.weight.dtype)

In [None]:
# Output directory for the .mat exports (string-concatenated, so keep the trailing '/').
# NOTE(review): absolute local path; make configurable before sharing the notebook.
root1 = '/home/mlw/Desktop/paper_SNNcom/mat/mix_nn/b8k8/'


class NetworkFuse(torch.nn.Module):
    """Inference copy of NetworkA with BatchNorm folded into the preceding layers.

    conv11 and conv13 are ``QuantConv2d``; linear1 and linear3 are ``QuantLinear``.
    ``kmeans_quant`` quantizes conv13, linear1 and linear3. When ``write`` is enabled
    (``Net_write(True)``), each forward pass saves the intermediate activations, the
    final IF membrane potentials and the quantized weights to .mat files under ``root1``.
    """

    def __init__(self):
        super(NetworkFuse, self).__init__()
        self.write = False
        self.T = 8  # number of SNN time steps

        # ---- analog (ReLU) front end; BN already folded into the convs ----
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, padding=1)
        self.sn1 = nn.ReLU()

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        self.sn2 = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.sn3 = nn.ReLU()

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn4 = nn.ReLU()

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.sn5 = nn.ReLU()

        # NOTE(review): conv6/conv7/conv9 are not used by forward(); kept for parity
        # with NetworkA and existing state_dicts.
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.sn6 = nn.ReLU()

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.sn7 = nn.ReLU()

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.sn8 = nn.ReLU()

        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.sn9 = nn.ReLU()

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.sn10 = nn.ReLU()

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv11 = QuantConv2d(512, 256, kernel_size=3, padding=1)

        # ---- spiking back end (inputs carry a leading time dimension) ----
        self.sn11 = neuron.IFNode(detach_reset=True)

        self.conv13 = layer.SeqToANNContainer(QuantConv2d(256, 128, kernel_size=3, padding=1))
        self.sn13 = neuron.IFNode(detach_reset=True)

        self.pool5 = layer.SeqToANNContainer(nn.MaxPool2d(kernel_size=2, stride=2))

        # Widths match NetworkA (Linear(128, 256) -> Linear(256, 10)). The BN-folding
        # step copies NetworkA's tensors into these layers. The old width of 512 only
        # produced a misleading repr and a state_dict that a freshly built NetworkFuse
        # could not load.
        self.linear1 = layer.SeqToANNContainer(QuantLinear(128, 256))
        self.sn14 = neuron.IFNode(detach_reset=True)

        self.linear3 = layer.SeqToANNContainer(QuantLinear(256, 10))
        self.sn15 = neuron.IFNode(detach_reset=True)

    def forward(self, x):
        """Run the fused network; optionally dump intermediates to .mat; return outputs averaged over T."""
        T = self.T
        record = self.write
        snaps = {}

        def keep(name, tensor):
            # Copy to host only when dumping. Doing it unconditionally forced a
            # GPU->CPU transfer and sync for every layer of every batch.
            if record:
                snaps[name] = tensor.detach().cpu().numpy()

        keep('x_in', x)

        x = self.sn1(self.conv1(x))
        keep('sn1_out', x)
        x = self.sn2(self.conv2(x))
        keep('sn2_out', x)
        x = self.pool1(x)
        keep('pool1_out', x)

        x = self.sn3(self.conv3(x))
        keep('sn3_out', x)
        x = self.sn4(self.conv4(x))
        keep('sn4_out', x)
        x = self.pool2(x)
        keep('pool2_out', x)

        x = self.sn5(self.conv5(x))
        keep('sn5_out', x)
        x = self.pool3(x)
        keep('pool3_out', x)

        x = self.sn8(self.conv8(x))
        keep('sn8_out', x)
        x = self.sn10(self.conv10(x))
        keep('sn10_out', x)
        x = self.pool4(x)
        keep('pool4_out', x)

        x = self.conv11(x)

        # Feed the same analog activation to the IF neurons at every time step.
        x = x.unsqueeze(0)
        x = x.repeat(T, 1, 1, 1, 1)
        x = self.sn11(x)
        keep('sn11_spike', x)
        keep('sn11_v', self.sn11.v)  # neuron's stored v after the pass (presumably last step)

        x = self.sn13(self.conv13(x))
        keep('sn13_spike', x)
        keep('sn13_v', self.sn13.v)

        x = self.pool5(x)
        keep('pool5_spike', x)

        x = torch.flatten(x, 2)  # keep (T, B); flatten the remaining dims

        x = self.sn14(self.linear1(x))
        keep('sn14_spike', x)
        keep('sn14_v', self.sn14.v)

        x = self.sn15(self.linear3(x))
        keep('sn15_spike', x)
        keep('sn15_v', self.sn15.v)

        if record:
            # One file per tensor, named after the variable it contains.
            for name, value in snaps.items():
                savemat(root1 + name + '.mat', {name: value})
            for attr, tag in (('conv13', 'Conv13Weight'), ('linear1', 'Linear1Weight'), ('linear3', 'Linear3Weight')):
                quant_layer = getattr(self, attr)[0]
                savemat(root1 + tag + '.mat', {
                    attr + '_q': quant_layer.weight.cpu().detach().numpy(),
                    attr + '_label': quant_layer.weight_labels.cpu().detach().numpy(),
                    attr + '_center': quant_layer.centroids.cpu().detach().numpy(),
                })

        return x.mean(0)

    def Net_write(self, whether_write):
        """Enable/disable dumping of intermediates and quantized weights on forward."""
        self.write = whether_write

    def set_T(self, T):
        """Set the number of time steps used by subsequent forward passes."""
        self.T = T

    def kmeans_quant(self):
        """k-means-quantize conv13, linear1 and linear3 in place."""
        self.conv13[0].kmeans_quant()
        self.linear1[0].kmeans_quant()
        self.linear3[0].kmeans_quant()

    def set_threshold(self, threshold):
        """Set firing thresholds of sn13, sn14, sn15 from ``threshold[0..2]`` (sn11 is unchanged)."""
        self.sn13.v_threshold = threshold[0]
        self.sn14.v_threshold = threshold[1]
        self.sn15.v_threshold = threshold[2]

In [None]:
# Smoke test: one random CIFAR-sized batch through a freshly built NetworkFuse.
# The batch is drawn before the model is constructed, so the RNG stream (and hence the
# initial weights) is the same as in the original ordering.
x = torch.randn(2, 3, 32, 32).to(device)
net_fuse = NetworkFuse().to(device)

set_threshold = [1., 1., 1.]  # unit thresholds for sn13 / sn14 / sn15
net_fuse.set_threshold(set_threshold)

functional.set_step_mode(net_fuse, step_mode='m')
functional.set_backend(net_fuse, backend='cupy')

net_fuse(x)

In [None]:
import torch.nn.utils.fusion as fusion


def _unwrap(module):
    """Return the layer wrapped by a SeqToANNContainer (its index 0), or ``module`` itself."""
    return module[0] if isinstance(module, layer.SeqToANNContainer) else module


# Copy the trained NetworkA (`net`) into `net_fuse`, folding each BatchNorm into the
# conv/linear layer registered immediately before it.
# Layers with no BatchNorm after them are copied unchanged. Previously they were skipped
# entirely, leaving net_fuse with its random initial weights for those layers.
src_children = list(net.named_children())
for idx, (layer_name, src_wrapped) in enumerate(src_children):
    src = _unwrap(src_wrapped)
    if not isinstance(src, (nn.Conv2d, nn.Linear)):
        continue
    if not hasattr(net_fuse, layer_name):
        print('no counterpart in net_fuse, skipped:', layer_name)
        continue
    dst = _unwrap(getattr(net_fuse, layer_name))
    follower = _unwrap(src_children[idx + 1][1]) if idx + 1 < len(src_children) else None

    with torch.no_grad():
        if isinstance(src, nn.Conv2d) and isinstance(follower, nn.BatchNorm2d):
            new_w, new_b = fusion.fuse_conv_bn_weights(
                conv_w=src.weight, conv_b=src.bias,
                bn_rm=follower.running_mean, bn_rv=follower.running_var, bn_eps=follower.eps,
                bn_w=follower.weight, bn_b=follower.bias)
            action = 'conv+bn fused'
        elif isinstance(src, nn.Linear) and isinstance(follower, nn.BatchNorm1d):
            new_w, new_b = fusion.fuse_linear_bn_weights(
                linear_w=src.weight, linear_b=src.bias,
                bn_rm=follower.running_mean, bn_rv=follower.running_var, bn_eps=follower.eps,
                bn_w=follower.weight, bn_b=follower.bias)
            action = 'linear+bn fused'
        else:
            new_w, new_b = src.weight, src.bias
            action = 'copied (no BN)'
        dst.weight = torch.nn.Parameter(new_w.detach().clone())
        dst.bias = None if new_b is None else torch.nn.Parameter(new_b.detach().clone())
    print(f'{action}: {layer_name}')

  

In [None]:
# Show the module structure of the fused network.
print(net_fuse)

In [None]:
# Start from a clean neuron state, set provisional unit thresholds, then
# k-means-quantize conv13, linear1 and linear3 in place.
functional.reset_net(net_fuse)
set_threshold = [1., 1., 1.]  # set_threshold() reads exactly three values (sn13/sn14/sn15)
net_fuse.set_threshold(set_threshold)
net_fuse.kmeans_quant()

In [None]:
# Re-display the network after the quantization cell.
print(net_fuse)

In [None]:
# Zero the biases of the quantized layers (conv13, linear1, linear3) in place.
# NOTE(review): presumably the target implementation has no bias path; confirm.
with torch.no_grad():
    for quant_layer in (net_fuse.conv13[0], net_fuse.linear1[0], net_fuse.linear3[0]):
        quant_layer.bias.zero_()
print(net_fuse.state_dict()['conv13.0.bias'][:])

In [None]:
# Evaluate the fused + quantized network at T = 8 with the hand-tuned thresholds.
# NOTE(review): 831 / 383 / 566 (sn13 / sn14 / sn15) are magic values, presumably tuned
# for the integer-scaled quantized weights; document how they were obtained.
functional.reset_net(net_fuse)
net_fuse.set_T(8)
set_threshold = [831, 383, 566]
net_fuse.set_threshold(set_threshold)

# Full test-set accuracy, without dumping. With Net_write(True) here, every batch
# rewrote all .mat files (slow) and only the last, partial batch survived.
net_fuse.Net_write(False)
test_loss, test_acc1, test_acc5 = evaluate(net_fuse, criterion, data_loader_test, device=device, header='Test:')

# Dump intermediates and quantized weights for a single (the first) test batch.
net_fuse.eval()
dump_images, _ = next(iter(data_loader_test))
net_fuse.Net_write(True)
with torch.no_grad():
    net_fuse(dump_images.to(device))
functional.reset_net(net_fuse)
net_fuse.Net_write(False)


In [None]:
# Mean top-1 accuracy of the quantized fused network for T = 1..10, averaged over
# N_REPEATS evaluations per T.
# NOTE(review): repeats only matter when the test pipeline is stochastic (random
# crop/flip); with a deterministic test transform one run per T is enough.
N_REPEATS = 10
acc_list = []
net_fuse.Net_write(False)
set_threshold = [831, 383, 566]
net_fuse.set_threshold(set_threshold)

# Loop variable is `T`, not `time`: `time` shadowed the time module for later cells.
for T in range(1, 11):
    net_fuse.set_T(T)
    acc = 0
    for _ in range(N_REPEATS):
        functional.reset_net(net_fuse)
        test_loss, test_acc1, test_acc5 = evaluate(net_fuse, criterion, data_loader_test, device=device, header='Test:')
        acc += test_acc1
    acc_list.append(acc / N_REPEATS)

In [None]:
# Show the per-T mean accuracies collected by the sweep cell above.
print(acc_list)

In [None]:
# Save the fused network's state dict into the same directory as the .mat exports.
torch.save(net_fuse.state_dict(), root1+'SNNCom_Mix_NN.pth')

In [None]:
import scipy.io as sio

# Export every parameter of the fused network to its own .mat file under root1, named
# after the parameter (e.g. 'conv13.0.weight.mat').
for name, param in net_fuse.named_parameters():
    if 'weight' in name:
        mat_key = 'weight'
    elif 'bias' in name:
        # NOTE(review): 'bisa' looks like a typo for 'bias', but existing MATLAB readers
        # may depend on this variable name, so it is kept.
        mat_key = 'bisa'
    else:
        # Previously this printed 'error' and then saved the *previous* parameter's dict
        # (or raised NameError on the first one) under this name; skip instead.
        print('error')
        continue
    sio.savemat(root1 + name + '.mat', {mat_key: param.data.cpu().numpy()})


In [None]:
# Final look at the exported network's structure.
print(net_fuse)