In [None]:
import os
import time
import utils
import math
import argparse
import requests
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.cuda import amp
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import layer,functional,neuron,surrogate

from scipy.io import loadmat,savemat

In [None]:
# Reproducibility configuration: seed every RNG the notebook relies on from a
# single constant so that changing `_seed_` really changes all of them.
_seed_ = 42
random.seed(_seed_)
np.random.seed(_seed_)
torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)

# Deterministic cuDNN kernels; the auto-tuner (benchmark) stays off because it
# may select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# TensorBoard writers are created later, only when --tb is passed.
train_tb_writer = None
te_tb_writer = None

# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
deviceIds = [0]

In [None]:
def parse_args(argv=None):
    """Build the training argument parser and parse ``argv``.

    Args:
        argv: Optional list of command-line style strings. When ``None`` (the
            default) the notebook's fixed experiment configuration is used, so
            existing ``parse_args()`` calls behave exactly as before. Pass an
            explicit list (for example ``['--lr', '0.1']``) to run a different
            configuration without editing this function.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    if argv is None:
        # A notebook kernel has no meaningful sys.argv, so the experiment
        # configuration is spelled out here instead.
        argv = ['--data-path', './data', '--lr', '0.01', '-b', '64',
                '--epochs', '100', '--print-freq', '100', '--tb', '--cache-dataset']

    return parser.parse_args(args=argv)

args = parse_args()

In [None]:
class NetworkA(torch.nn.Module):
    """Convolutional classifier with BatchNorm, producing 10 output logits.

    Layout: ten 3x3 conv + BatchNorm + ReLU layers in four stages, each stage
    followed by 2x2 average pooling; a "connect" bridge (Conv1d then Linear)
    that compresses the flattened features; two further conv layers, a final
    average pool, and two fully connected layers.

    The shape comments below assume a 3x32x32 input, which is what the
    notebook feeds the network.

    Note: ``conv12`` / ``sn12`` are registered here (so they appear in the
    state dict) but ``forward`` never calls them.
    NOTE(review): the activations are named ``sn*`` presumably because the
    ReLUs are later replaced by spiking neurons during ANN-to-SNN conversion
    -- confirm.
    """

    def __init__(self):
        super(NetworkA, self).__init__()
        # ---- Stage 1 ----
        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, padding=1) ## 32x32x3 -> 32x32x96
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1) ## 32x32x96 -> 32x32x96
        self.bn2 = nn.BatchNorm2d(96)
        self.sn2 = nn.ReLU(inplace=True)

        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) ## 32x32x96 -> 16x16x96

        # ---- Stage 2 ----
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1) ## 16x16x96 -> 16x16x128
        self.bn3 = nn.BatchNorm2d(128)
        self.sn3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) ## 16x16x128 -> 16x16x128
        self.bn4 = nn.BatchNorm2d(128)
        self.sn4 = nn.ReLU(inplace=True)

        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) ## 16x16x128 -> 8x8x128

        # ---- Stage 3 ----
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1) ## 8x8x128 -> 8x8x256
        self.bn5 = nn.BatchNorm2d(256)
        self.sn5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1) ## 8x8x256 -> 8x8x256
        self.bn6 = nn.BatchNorm2d(256)
        self.sn6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1) ## 8x8x256 -> 8x8x256
        self.bn7 = nn.BatchNorm2d(256)
        self.sn7 = nn.ReLU(inplace=True)

        self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) ## 8x8x256 -> 4x4x256

        # ---- Stage 4 ----
        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1) ## 4x4x256 -> 4x4x512
        self.bn8 = nn.BatchNorm2d(512)
        self.sn8 = nn.ReLU(inplace=True)

        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## 4x4x512 -> 4x4x512
        self.bn9 = nn.BatchNorm2d(512)
        self.sn9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1) ## 4x4x512 -> 4x4x512
        self.bn10 = nn.BatchNorm2d(512)
        self.sn10 = nn.ReLU(inplace=True)

        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)  # 4x4x512 -> 2x2x512
        # ---- Bridge: compress the flattened 2048-vector ----
        # Single-channel 1-D conv with stride 4: length 2048 -> 512
        self.connect1 = nn.Conv1d(1,1,kernel_size=7,stride=4,padding=3)  # (N, 1, 2048) -> (N, 1, 512)
        self.sn16 = nn.ReLU(inplace=True)

        # 512 -> 512 features, later reshaped to 128 channels of 2x2
        self.connect2 = nn.Linear(512, 2*2*128)
        self.sn17 = nn.ReLU(inplace=True)

        # ---- Tail convolutions on the 2x2x128 map (conv11/conv12 have no BatchNorm) ----
        self.conv11 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn11 = nn.ReLU(inplace=True)

        # conv12 / sn12 are created but not used by forward()
        self.conv12 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn12 = nn.ReLU(inplace=True)

        self.conv13 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn13 = nn.BatchNorm2d(128)
        self.sn13 = nn.ReLU(inplace=True)

        self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)  # 2x2x128 -> 1x1x128

        # ---- Classifier head ----
        self.linear1 = nn.Linear(128, 512)
        self.sn14 = nn.ReLU(inplace=True)

        self.linear3 = nn.Linear(512, 10)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the layer sizes assume shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10) holding the raw outputs of ``linear3``
            (no softmax is applied).
        """
        # Stage 1
        x = self.sn1(self.bn1(self.conv1(x)))
        x = self.sn2(self.bn2(self.conv2(x)))
        x = self.pool1(x)

        # Stage 2
        x = self.sn3(self.bn3(self.conv3(x)))
        x = self.sn4(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        # Stage 3
        x = self.sn5(self.bn5(self.conv5(x)))
        x = self.sn6(self.bn6(self.conv6(x)))
        x = self.sn7(self.bn7(self.conv7(x)))
        x = self.pool3(x)

        # Stage 4
        x = self.sn8(self.bn8(self.conv8(x)))
        x = self.sn9(self.bn9(self.conv9(x)))
        x = self.sn10(self.bn10(self.conv10(x)))
        x = self.pool4(x)

        # Flatten and add a channel axis so Conv1d can consume it: (N, 1, C*H*W)
        x = x.view(x.size(0), -1).unsqueeze(1)
        x = self.sn16(self.connect1(x))
        x = self.sn17(self.connect2(x))

        # Back to a feature map: 512 features -> 128 channels of 2x2
        x = x.reshape(-1, 128, 2, 2)

        # conv12 is skipped here: conv11 feeds conv13 directly
        x = self.sn11(self.conv11(x))
        x = self.sn13(self.bn13(self.conv13(x)))
        x = self.pool5(x)

        # Flatten to (N, 128) for the fully connected head
        x = x.view(x.size(0), -1)

        x = self.sn14(self.linear1(x))

        x = self.linear3(x)

        return x

In [None]:
# Smoke test: push a random batch of two 3x32x32 inputs through a freshly
# initialised NetworkA (on the CPU, untrained) to confirm the layer shapes
# line up. The cell output is the (2, 10) logits tensor.
x = torch.randn(2,3,32,32)
net = NetworkA()
net(x)

In [None]:
# ---------------------------------------------------------------------------
# Experiment setup: arguments, output directory, data loaders, model,
# optimiser / scheduler / loss, optional resume and TensorBoard writers.
# ---------------------------------------------------------------------------
args = parse_args()
max_test_acc1 = 0.
test_acc5_at_max_test_acc1 = 0.

utils.init_distributed_mode(args)
print(args)

# One output directory per (batch size, lr, start time) combination.
output_dir = os.path.join(args.output_dir, f'b_{args.batch_size}_lr{args.lr}')

time_now = datetime.datetime.now()
output_dir += f'_{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}__{time_now.minute}'


if output_dir:
    utils.mkdir(output_dir)


# Load data
cifar_normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534])

# Training images get random crop / flip augmentation.
train_transform = torchvision.transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
])

# Test images must NOT be randomly augmented: random crops and flips make the
# reported test accuracy noisy and non-reproducible, and they also perturb the
# inputs used later for the SNN simulation on the same loader.
test_transform = torchvision.transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    cifar_normalize,
])

data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.CIFAR10(root=args.data_path,
                                         train=True,
                                         transform=train_transform,
                                         download=True),
    batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=args.workers)

data_loader_test = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.CIFAR10(root=args.data_path,
                                         train=False,
                                         transform=test_transform,
                                         download=True),
    batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=False, num_workers=args.workers)


print("Creating model")

net = NetworkA().to(device)


optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-3)


# Multiply the learning rate by 0.4 every 5 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)


criterion = nn.CrossEntropyLoss()


if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Restore the weights; without this a "resumed" run would silently restart
    # from random initialisation while continuing the epoch counter.
    net.load_state_dict(checkpoint['model'])
    # The training loop in this notebook stores the epoch under
    # 'pre-train_epoch'; accept the conventional 'epoch' key as well.
    if 'epoch' in checkpoint:
        resumed_epoch = checkpoint['epoch']
    else:
        resumed_epoch = checkpoint['pre-train_epoch']
    args.start_epoch = resumed_epoch + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']
    # NOTE: optimizer / scheduler state is not stored in these checkpoints, so
    # momentum buffers and the learning-rate schedule restart from scratch.


if args.tb and utils.is_main_process():
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    # SummaryWriter creates '<output_dir>_logs/...', which the text files below reuse.
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write('Results\n')


In [None]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None,lr = 1e-2):
    """Train ``net`` for one pass over ``data_loader`` and step the LR scheduler.

    This function uses the module-level ``optimizer`` and ``scheduler`` objects
    created in the setup cell; they are not passed in as arguments.

    Args:
        net: Model to train (switched to training mode here).
        criterion: Loss function called as ``criterion(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in the progress header.
        print_freq: Logging interval forwarded to ``MetricLogger.log_every``.
        scaler: Unused; kept so existing calls remain valid.
        lr: Unused; kept so existing calls remain valid. The logged learning
            rate is read from the optimizer so it reflects scheduler decay.

    Returns:
        tuple: ``(mean loss, mean top-1 accuracy, mean top-5 accuracy)`` as
        reported by the metric logger's ``global_avg`` values.

    Raises:
        ValueError: If the loss of a batch is NaN. The check runs before the
            optimizer step so the weights are not updated with NaN gradients.
    """
    net.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        output = net(image)
        loss = criterion(output, target)

        # Abort before touching the weights if the loss has diverged.
        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        acc1_s = acc1.item()
        acc5_s = acc5.item()

        # Report the learning rate actually in effect (the scheduler changes it
        # between epochs), not the initial value the caller passed in.
        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]['lr'])

        metric_logger.meters['acc1'].update(acc1_s, n=batch_size)
        metric_logger.meters['acc5'].update(acc5_s, n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    # The schedule advances once per epoch, after all batches.
    scheduler.step()
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [None]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``net`` on ``data_loader`` without gradients and print a summary.

    Args:
        net: Model to evaluate; put into eval mode and moved to ``device``.
        criterion: Loss function called as ``criterion(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device used for the model and the batches.
        print_freq: Logging interval forwarded to ``MetricLogger.log_every``.
        header: Prefix for the progress lines.

    Returns:
        tuple: ``(mean loss, mean top-1 accuracy, mean top-5 accuracy)`` taken
        from the metric logger's ``global_avg`` values.
    """
    net.eval()
    net.to(device)
    metric_logger = utils.MetricLogger(delimiter="  ")

    with torch.no_grad():
        batches = metric_logger.log_every(data_loader, print_freq, header)
        for image, target in batches:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = net(image)
            loss = criterion(output, target)
            # Clear any stateful (spiking) modules before the next batch.
            functional.reset_net(net)

            top1, top5 = utils.accuracy(output, target, topk=(1, 5))
            n_samples = image.shape[0]
            metric_logger.update(loss=loss.item())
            for meter_name, value in (('acc1', top1), ('acc5', top5)):
                metric_logger.meters[meter_name].update(value.item(), n=n_samples)

    metric_logger.synchronize_between_processes()

    avg_loss = metric_logger.loss.global_avg
    avg_acc1 = metric_logger.acc1.global_avg
    avg_acc5 = metric_logger.acc5.global_avg
    print(f' * Acc@1 = {avg_acc1}, Acc@5 = {avg_acc5}, loss = {avg_loss}')
    return avg_loss, avg_acc1, avg_acc5

In [None]:
# ---------------------------------------------------------------------------
# Main training loop: train one epoch, evaluate, log, checkpoint.
# ---------------------------------------------------------------------------
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(net, criterion,data_loader, device, epoch, args.print_freq,lr=args.lr)

    # The writers only exist when --tb was given, so guard both of them.
    if train_tb_writer is not None and utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None and utils.is_main_process():
        te_tb_writer.add_scalar('test_loss', test_loss, epoch)
        te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
        te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

    # Track the best top-1 test accuracy seen so far.
    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:

        checkpoint = {
            'model': net.state_dict(),
            'pre-train_epoch': epoch,
            # Same value under the key the --resume code path reads.
            'epoch': epoch,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        # Keep a numbered snapshot every 64 epochs and at the final epoch.
        if epoch % 64 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            # Previously the checkpoint dict was also written to the
            # '..._all_pretrain.pth' path and then immediately overwritten by
            # the whole-model save below; that wasted write is dropped, so the
            # files left on disk are unchanged.
            torch.save(net,os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_all_pretrain.pth'))
            torch.save(net.state_dict(),os.path.join(output_dir,f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1_state_pretrain.pth'))
    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    # train_acc5 is now reported from train_acc5 (it used to repeat train_acc1).
    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
            'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1,'train_acc1', train_acc1,
            'train_acc5', train_acc5)

    Train_logs= {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1 ,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc5,
        'args': args
    }
    # Without --tb no SummaryWriter has created the log directory yet.
    os.makedirs(output_dir + '_logs', exist_ok=True)
    with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
        args_txt.write('\n')
        args_txt.write(str(Train_logs))

rep_pre_acc1 = max_test_acc1
rep_pre_acc5 = test_acc5_at_max_test_acc1

# Completion notification via PushPlus. The access token is a credential and
# must not live in the notebook: it is read from the PUSHPLUS_TOKEN environment
# variable. Query values go through `params` so they are URL-encoded, and a
# timeout keeps a dead network from hanging the cell.
pushplus_token = os.environ.get('PUSHPLUS_TOKEN')
if pushplus_token:
    try:
        requests.get(
            "http://www.pushplus.plus/send",
            params={
                'token': pushplus_token,
                'title': '111程序通知',
                'content': "{}{}\n{}{}".format(
                    "pre_acc1 = ", rep_pre_acc1, "pre_acc5 = ", rep_pre_acc5),
                'template': 'html',
            },
            timeout=10,
        )
    except requests.RequestException as exc:
        # Best-effort: a failed notification must not discard a finished run.
        print(f'PushPlus notification failed: {exc}')
else:
    print('PUSHPLUS_TOKEN is not set; skipping the PushPlus notification.')

In [None]:
def val(net, device, data_loader, T=None):
    """Measure classification accuracy of ``net`` on ``data_loader``.

    Args:
        net: Model to evaluate; put into eval mode and moved to ``device``.
        device: Device used for the model, images and labels.
        data_loader: Iterable yielding ``(img, label)`` batches.
        T: ``None`` for a single forward pass per batch. Otherwise the number
            of repeated forward passes per batch; outputs are summed across
            passes and accuracy is measured on the running sum after each one.
            Modules exposing ``reset()`` are reset before each batch.

    Returns:
        float accuracy when ``T`` is ``None``; otherwise a NumPy array of
        length ``T`` whose entry ``t`` is the accuracy after ``t + 1`` passes.
    """
    net.eval().to(device)
    single_pass = T is None
    total = 0.0
    correct = 0.0
    if not single_pass:
        corrects = np.zeros(T)

    def count_hits(scores, labels):
        # Number of samples whose arg-max class equals the label.
        return (scores.argmax(dim=1) == labels).float().sum().item()

    with torch.no_grad():
        for img, label in tqdm(data_loader):
            img = img.to(device)
            label = label.to(device)
            if single_pass:
                out = net(img)
                correct += count_hits(out, label)
            else:
                # Start every batch from a clean state.
                for module in net.modules():
                    if hasattr(module, 'reset'):
                        module.reset()
                for step in range(T):
                    if step == 0:
                        out = net(img)
                    else:
                        out += net(img)
                    corrects[step] += count_hits(out, label)
            total += out.shape[0]

    if single_pass:
        return correct / total
    return corrects / total

In [None]:
# Convert the trained ANN (`net`) into an SNN with SpikingJelly's ann2snn
# Converter in 'max' mode, passing the training loader as its `dataloader`.
# `tqdm` is imported here because val() (defined in an earlier cell) needs it
# at call time.
from tqdm import tqdm
from spikingjelly.activation_based import ann2snn
print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=data_loader)
snn_model = model_converter(net)



In [None]:
# Simulate the converted SNN for T forward passes per test batch.
# val(..., T=T) returns an array of length T: entry t is the accuracy obtained
# from the outputs summed over the first t + 1 passes, so [-1] is the accuracy
# after all T passes.
T = 500
print('Simulating...')
mode_robust_accs = val(snn_model, device, data_loader_test, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))

In [None]:
# Full accuracy-vs-time-step curve, then a thinned view taking every 10th
# entry starting at index 9 (i.e. after 10, 20, 30, ... passes).
print(mode_robust_accs)
print(mode_robust_accs[9::10])

In [None]:
# Rebuild the BatchNorm network and restore the saved state dict of the best
# pre-training run. `map_location=device` places the loaded tensors on the
# device used by this notebook regardless of where they were saved from.
net = NetworkA().to(device)
weights = torch.load(
    './logs/b_64_lr0.01_2024_4_10_21__11/train_maxacc1_92.27_checkpoint_max_test_acc1_state_pretrain.pth',
    map_location=device)
net.load_state_dict(weights)

In [None]:
# Default directory prefix (must end with a path separator) for exported .mat files.
root1 = '/home/mlw/Desktop/paper_SNNcom/mat/ann/'
class NetworkFuse(torch.nn.Module):
    """BatchNorm-free counterpart of ``NetworkA`` that can dump its activations.

    The layer names match ``NetworkA`` so weights (with BatchNorm folded into
    the preceding convolution) can be copied over by attribute name. When
    writing is enabled via ``Net_write(True)``, every forward pass saves the
    input, each intermediate activation and the final logits as individual
    ``.mat`` files.

    ``conv12`` / ``sn12`` are registered but not used by ``forward``; a
    zero-filled ``sn12_out.mat`` placeholder is still written so the set of
    exported files stays the same.

    NOTE(review): this class pools with ``MaxPool2d`` whereas ``NetworkA``
    uses ``AvgPool2d``. With copied weights the two networks therefore do not
    compute the same function -- confirm this is intended.
    """

    def __init__(self):
        super(NetworkFuse, self).__init__()
        self.write = False
        # Optional override for the export directory; None means use `root1`.
        self.out_dir = None

        self.conv1 = nn.Conv2d(3, 96, kernel_size=3, padding=1)
        self.sn1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        self.sn2 = nn.ReLU(inplace=True)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.sn3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn4 = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.sn5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.sn6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.sn7 = nn.ReLU(inplace=True)

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.sn8 = nn.ReLU(inplace=True)

        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.sn9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.sn10 = nn.ReLU(inplace=True)

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bridge: stride-4 Conv1d over the flattened features, then Linear.
        self.connect1 = nn.Conv1d(1,1,kernel_size=7,stride=4,padding=3)
        self.sn16 = nn.ReLU(inplace=True)

        self.connect2 = nn.Linear(512, 2*2*128)
        self.sn17 = nn.ReLU(inplace=True)

        self.conv11 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn11 = nn.ReLU(inplace=True)

        # Registered for name parity with NetworkA; not used by forward().
        self.conv12 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn12 = nn.ReLU(inplace=True)

        self.conv13 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.sn13 = nn.ReLU(inplace=True)

        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(128, 512)
        self.sn14 = nn.ReLU(inplace=True)

        self.linear3 = nn.Linear(512, 10)

    def forward(self, x):
        """Compute class logits and, if writing is enabled, export activations.

        Activations are copied to the CPU only when ``self.write`` is true, so
        ordinary inference does not pay for device-to-host transfers or
        throwaway buffers.

        Args:
            x: Input batch; the layer sizes assume shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10) with the raw outputs of ``linear3``.
        """
        record = self.write
        batch = x.shape[0]
        captured = {}  # insertion order == file write order

        def keep(name, tensor):
            # Snapshot an activation for export (no-op during plain inference).
            if record:
                captured[name] = tensor.detach().cpu()

        keep('x_in', x)

        x = self.sn1(self.conv1(x))
        keep('sn1_out', x)
        x = self.sn2(self.conv2(x))
        keep('sn2_out', x)
        x = self.pool1(x)
        keep('pool1_out', x)

        x = self.sn3(self.conv3(x))
        keep('sn3_out', x)
        x = self.sn4(self.conv4(x))
        keep('sn4_out', x)
        x = self.pool2(x)
        keep('pool2_out', x)

        x = self.sn5(self.conv5(x))
        keep('sn5_out', x)
        x = self.sn6(self.conv6(x))
        keep('sn6_out', x)
        x = self.sn7(self.conv7(x))
        keep('sn7_out', x)
        x = self.pool3(x)
        keep('pool3_out', x)

        x = self.sn8(self.conv8(x))
        keep('sn8_out', x)
        x = self.sn9(self.conv9(x))
        keep('sn9_out', x)
        x = self.sn10(self.conv10(x))
        keep('sn10_out', x)
        x = self.pool4(x)
        keep('pool4_out', x)

        # Flatten and add a channel axis for the Conv1d bridge: (N, 1, C*H*W)
        x = x.view(x.size(0), -1).unsqueeze(1)
        x = self.sn16(self.connect1(x))
        keep('sn16_out', x)

        x = self.sn17(self.connect2(x))
        keep('sn17_out', x)

        x = x.reshape(-1, 128, 2, 2)

        x = self.sn11(self.conv11(x))
        keep('sn11_out', x)

        # conv12 is not part of the forward path; export a zero placeholder so
        # consumers that expect sn12_out.mat keep working.
        if record:
            captured['sn12_out'] = torch.zeros([batch, 128, 2, 2])

        x = self.sn13(self.conv13(x))
        keep('sn13_out', x)

        x = self.pool5(x)
        keep('pool5_out', x)

        x = x.view(x.size(0), -1)

        x = self.sn14(self.linear1(x))
        keep('sn14_out', x)
        x = self.linear3(x)
        keep('linear3_out', x)

        if record:
            # getattr: instances unpickled from before `out_dir` existed lack it.
            out_dir = getattr(self, 'out_dir', None)
            prefix = root1 if out_dir is None else os.path.join(out_dir, '')
            for name, value in captured.items():
                savemat(prefix + name + '.mat', {name: value.numpy()})

        return x

    def Net_write(self, whether_write, out_dir=None):
        """Enable or disable activation export.

        Args:
            whether_write: True to write ``.mat`` files on every forward pass.
            out_dir: Optional directory for the files. When omitted, the
                previously set directory (or the module-level ``root1``) is used.
        """
        self.write = whether_write
        if out_dir is not None:
            self.out_dir = out_dir

In [None]:
# Smoke test for the BatchNorm-free network: build it on `device`, enable
# activation export and run one random batch of two 3x32x32 inputs. At this
# point net_fuse still has its random initial weights.
# NOTE(review): Net_write(True) is not switched off again in the visible
# cells, so later forward passes through net_fuse also write .mat files --
# confirm that is intended.
x = torch.randn(2,3,32,32).to(device)
net_fuse = NetworkFuse().to(device)
net_fuse.Net_write(True)
net_fuse(x)

In [None]:
import torch.nn.utils.fusion as fusion

# Names of every registered sub-module of the trained network, in registration
# order (the first entry is '' for the network itself).
model_name = [name for name, _ in net.named_modules()]
print(model_name)

# Transfer weights from `net` (with BatchNorm) to `net_fuse` (without):
#   * a conv layer directly followed by a BatchNorm layer gets the BatchNorm
#     folded into its weight and bias;
#   * every other conv / linear / connect layer is copied as is.
for i, layer_name in enumerate(model_name):
    # Guard the look-ahead so a conv registered last cannot raise IndexError.
    next_name = model_name[i + 1] if i + 1 < len(model_name) else ''

    if 'conv' in layer_name and 'bn' in next_name:
        # Conv followed by BatchNorm: fold the BatchNorm statistics in.
        print(layer_name)
        conv_module = getattr(net, layer_name)
        bn_module = getattr(net, next_name)

        new_conv_weights, new_conv_bias = fusion.fuse_conv_bn_weights(
            conv_w=conv_module.weight, conv_b=conv_module.bias,
            bn_w=bn_module.weight, bn_b=bn_module.bias,
            bn_rm=bn_module.running_mean, bn_rv=bn_module.running_var,
            bn_eps=bn_module.eps)
        fuse_conv_module = getattr(net_fuse, layer_name)
        fuse_conv_module.weight = torch.nn.Parameter(new_conv_weights)
        fuse_conv_module.bias = torch.nn.Parameter(new_conv_bias)

    elif 'conv' in layer_name or 'linear' in layer_name or 'connect' in layer_name:
        # No BatchNorm to fold: reuse the trained parameters directly.
        # Note that the Parameter objects are shared between net and net_fuse.
        print(layer_name)
        source_module = getattr(net, layer_name)
        target_module = getattr(net_fuse, layer_name)
        target_module.weight = source_module.weight
        target_module.bias = source_module.bias

  

In [None]:
# Build an all-zero batch of shape (2, 3, 32, 32) by zeroing a random tensor.
# NOTE(review): `x` is not used by any later visible cell -- possibly leftover
# debugging; confirm before removing.
x = torch.randn(2,3,32,32)
x = x *0

In [None]:
# Evaluate the BatchNorm-free network on the test loader, reusing `criterion`
# and `data_loader_test` from the setup cell.
test_loss, test_acc ,test_acc5 = evaluate(net_fuse, criterion, data_loader_test, device=device, header='Test:')

In [None]:
import torch
import scipy.io as sio


# Export every parameter of the fused network to '<root1><param name>.mat',
# e.g. 'conv1.weight.mat' and 'conv1.bias.mat'.
layer_names = [name.split('.')[0] for name, _ in net_fuse.named_parameters()]
for name, param in net_fuse.named_parameters():
    weight_name = root1 + name + '.mat'
    values = param.data.cpu().numpy()

    if 'weight' in name:
        weight_dict = {'weight': values}
    elif 'bias' in name:
        # 'bisa' is the historical (misspelled) variable name; it is kept so
        # existing readers of these files keep working, and the correctly
        # spelled 'bias' is written alongside it.
        weight_dict = {'bisa': values, 'bias': values}
    else:
        # Neither a weight nor a bias: skip it rather than saving a stale
        # (or not yet defined) dict under this parameter's file name.
        print('error')
        continue
    sio.savemat(weight_name, weight_dict)


In [None]:
# Print the module structure of the fused network for inspection.
print(net_fuse)