In [None]:
import os
import time
import utils
import math
import argparse
import requests
import random
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.cuda import amp
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import torchvision
from torchvision import transforms
from spikingjelly.activation_based import layer,functional,neuron,surrogate

from scipy.io import loadmat,savemat
from sklearn.cluster import KMeans

import torchaudio
from torchaudio import transforms

In [None]:
torch.backends.cudnn.benchmark = True
_seed_ = 42
random.seed(42)
torch.manual_seed(_seed_)  
torch.cuda.manual_seed_all(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(_seed_)

train_tb_writer = None
te_tb_writer = None
device = torch.device('cuda:0')
deviceIds = [0]

In [None]:
def parse_args(argv=None):
    """Build the training command-line interface and parse it.

    Args:
        argv: list of command-line tokens to parse. ``None`` (the default) uses
            the notebook's fixed configuration below instead of ``sys.argv``,
            which inside Jupyter holds the kernel's own arguments.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    # NOTE: not consulted by this notebook, which uses the module-level `device`.
    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    if argv is None:
        # Fixed notebook configuration (overrides the parser defaults above).
        argv = ['--data-path', './data', '--lr', '0.01', '-b', '128', '--epochs', '50',
                '--print-freq', '100', '--tb', '--cache-dataset']
    args = parser.parse_args(args=argv)
    return args

args = parse_args()


In [None]:
class NetworkA(torch.nn.Module):
    """Hybrid ANN -> SNN classifier with 35 output classes.

    conv1..conv11 are ordinary CNN layers (ReLU activations, despite the
    ``sn*`` names) applied once. Their output is repeated ``T`` times along a
    new leading time axis and passed through spikingjelly layers (IFNode
    neurons, SeqToANNContainer-wrapped conv/linear/BN). The final spike train
    is averaged over time.

    conv6/conv7/conv9 (and their bn/sn) are constructed but never used in
    ``forward``; they are kept so state_dict keys match existing checkpoints.

    linear1 expects 256 = 128 channels x 1 x 2 features after pool5, which
    holds for (B, 1, 40, 81) inputs as in the smoke-test cell; other spatial
    sizes will break it.

    Args:
        verbose: if True, print intermediate tensor shapes in ``forward``
            (debug aid; off by default so training loops are not flooded).
    """

    def __init__(self, verbose=False):
        super(NetworkA, self).__init__()
        self.T = 8
        self.verbose = verbose
        self.conv1 = nn.Conv2d(1, 96, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = nn.ReLU()

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(96)
        self.sn2 = nn.ReLU()

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.sn3 = nn.ReLU()

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.sn4 = nn.ReLU()

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.sn5 = nn.ReLU()

        # Unused in forward; kept for state_dict compatibility.
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.sn6 = nn.ReLU()

        # Unused in forward; kept for state_dict compatibility.
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        self.sn7 = nn.ReLU()

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        self.sn8 = nn.ReLU()

        # Unused in forward; kept for state_dict compatibility.
        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(512)
        self.sn9 = nn.ReLU()

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(512)
        self.sn10 = nn.ReLU()

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv11 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(256)
        self.sn11 = neuron.IFNode(detach_reset=True)

        self.conv13 = layer.SeqToANNContainer(nn.Conv2d(256, 128, kernel_size=3, padding=1))
        self.bn13 = layer.SeqToANNContainer(nn.BatchNorm2d(128))
        self.sn13 = neuron.IFNode(detach_reset=True)

        self.pool5 = layer.SeqToANNContainer(nn.MaxPool2d(kernel_size=2, stride=2))

        self.linear1 = layer.SeqToANNContainer(nn.Linear(256, 256))
        self.bn14 = layer.SeqToANNContainer(nn.BatchNorm1d(256))
        self.sn14 = neuron.IFNode(detach_reset=True)

        self.linear3 = layer.SeqToANNContainer(nn.Linear(256, 35))
        self.bn15 = layer.SeqToANNContainer(nn.BatchNorm1d(35))
        self.sn15 = neuron.IFNode(detach_reset=True)

    def forward(self, x):
        """Map a (B, 1, H, W) batch to (B, 35) time-averaged spike outputs."""
        T = self.T

        # --- ANN front-end: single pass, no time axis ---
        x = self.sn1(self.bn1(self.conv1(x)))
        x = self.sn2(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        self._debug_shape(x)

        x = self.sn3(self.bn3(self.conv3(x)))
        x = self.sn4(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        self._debug_shape(x)

        x = self.sn5(self.bn5(self.conv5(x)))
        x = self.pool3(x)
        self._debug_shape(x)

        x = self.sn8(self.bn8(self.conv8(x)))
        x = self.sn10(self.bn10(self.conv10(x)))
        x = self.pool4(x)
        self._debug_shape(x)

        x = self.bn11(self.conv11(x))
        self._debug_shape(x)
        # Present the same analog features at every time step:
        # (B, C, H, W) -> (T, B, C, H, W).
        x = x.unsqueeze(0)
        x = x.repeat(T, 1, 1, 1, 1)
        self._debug_shape(x)
        x = self.sn11(x)
        self._debug_shape(x)

        # --- SNN back-end: leading time axis ---
        x = self.sn13(self.bn13(self.conv13(x)))
        x = self.pool5(x)
        self._debug_shape(x)

        x = torch.flatten(x, 2)  # (T, B, C*H*W)

        x = self.sn14(self.bn14(self.linear1(x)))

        x = self.sn15(self.bn15(self.linear3(x)))

        # Average the output spikes over the time axis.
        return x.mean(0)

    def _debug_shape(self, x):
        """Print ``x.shape`` when shape tracing is enabled."""
        # getattr: modules unpickled from older torch.save(net) files have no `verbose`.
        if getattr(self, 'verbose', False):
            print(x.shape)

    def set_T(self, T):
        """Set the number of simulation time steps used by ``forward``."""
        self.T = T

In [None]:
# Shape smoke test on CPU: a dummy batch of two (1, 40, 81) feature maps must
# come out as (2, 35) class outputs. No gradients are needed, and the spiking
# state is reset afterwards so the probe leaves nothing behind in the module.
x = torch.randn(2,1,40,81)
net = NetworkA()
functional.set_step_mode(net,step_mode='m')
with torch.no_grad():
    out = net(x)
functional.reset_net(net)
assert out.shape == (2, 35), out.shape
out

In [None]:
from torch.utils.data import DataLoader,random_split,Dataset
import matplotlib.pyplot as plt

args = parse_args()
# Best test top-1 accuracy so far and the top-5 accuracy at that epoch;
# overwritten by --resume and updated by the training loop.
max_test_acc1 = 0.
test_acc5_at_max_test_acc1 = 0.

utils.init_distributed_mode(args)
print(args)

# Run directory: <args.output_dir>/b_<batch>_lr<lr>_<year>_<month>_<day>_<hour>__<minute>
# (date components are not zero-padded).
time_now = datetime.datetime.now()
run_name = f'b_{args.batch_size}_lr{args.lr}'
date_suffix = '_'.join(str(part) for part in (time_now.year, time_now.month, time_now.day, time_now.hour))
output_dir = os.path.join(args.output_dir, run_name) + f'_{date_suffix}__{time_now.minute}'

if output_dir:
    utils.mkdir(output_dir)

class SpeechDataLoader(Dataset):
    """Map-style dataset pairing pre-loaded waveforms with integer class indices.

    Args:
        data: indexable sequence of waveforms (anything ``transform`` accepts).
        labels: sequence of label names, parallel to ``data``.
        list_dir: ordered list of label names; a label's position in this list
            is its class index (first occurrence wins on duplicates).
        transform: optional callable applied to each waveform (e.g. MFCC).
            If None, the raw waveform is returned.

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels, list_dir, transform=None):
        if len(data) != len(labels):
            raise ValueError(f'data has {len(data)} items but labels has {len(labels)}')
        self.data = data
        self.labels = labels
        self.label_dict = list_dir
        self.transform = transform
        # O(1) label -> index lookup; setdefault keeps the first occurrence,
        # matching list.index semantics.
        self._label_to_index = {}
        for position, name in enumerate(list_dir):
            self._label_to_index.setdefault(name, position)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(waveform, class_index)`` for sample ``idx``.

        Raises:
            KeyError: if the sample's label is not in ``list_dir``.
        """
        label = self.labels[idx]
        try:
            out_labels = self._label_to_index[label]
        except KeyError:
            raise KeyError(f'label {label!r} at index {idx} is not in label_dict') from None
        waveform = self.data[idx]
        if self.transform is not None:
            waveform = self.transform(waveform)
        return waveform, out_labels
# Per-sample feature extraction applied inside SpeechDataLoader. With
# torchaudio's defaults a 1 s / 16 kHz clip presumably becomes a (1, 40, 81)
# MFCC map, the shape used in the NetworkA smoke test (TODO confirm).
train_audio_transforms = nn.Sequential(
            torchaudio.transforms.MFCC(log_mels=False)
            )

train_audio_path = './data1/SpeechCommands/speech_commands_v0.02/'

# Download/extract the dataset BEFORE listing its directory: on a fresh machine
# train_audio_path does not exist yet and os.listdir would raise.
a = torchaudio.datasets.SPEECHCOMMANDS('./data1/' , url = 'speech_commands_v0.02',
                                       folder_in_archive= 'SpeechCommands', download = True)

# Class names = every sub-directory except the background-noise folder.
# Filtering on isdir (instead of .remove()-ing a fixed list of file names)
# no longer raises ValueError when an entry such as '.DS_Store' is absent.
# NOTE: a label's class index is its position in this list, and os.listdir
# order is filesystem-dependent. Sorting would make indices portable but would
# remap classes for checkpoints trained with the current order.
labels_dict = [
    name for name in os.listdir(train_audio_path)
    if name != '_background_noise_'
    and os.path.isdir(os.path.join(train_audio_path, name))
]
print(labels_dict)
print("Number of labels: ",len(labels_dict))

filename = "./data1/SpeechCommands/speech_commands_v0.02/backward/0165e0e8_nohash_0.wav"
waveform, sample_rate = torchaudio.load(filename)

print("Shape of waveform: {}".format(waveform.size()))
print("Sample rate of waveform: {}".format(sample_rate))

plt.figure()
plt.plot(waveform.t().numpy())
plt.plot(a[0][0].t())
plt.show()

# Keep only full-length clips of shape (1, 16000) so every sample yields the same
# feature-map size. Each a[i] decodes a wav file from disk, so every item is
# indexed once, and the loop runs over len(a) rather than a hard-coded count.
wave = []
labels = []
for i in range(len(a)):
    item = a[i]
    clip, label = item[0], item[2]
    if clip.shape == (1, 16000):
        wave.append(clip)
        labels.append(label)
if not wave:
    raise RuntimeError('no (1, 16000) clips found under ' + train_audio_path)

specgram = torchaudio.transforms.MFCC()(wave[0])

print("Shape of spectrogram: {}".format(specgram.size()))

plt.figure(figsize=(10,5))
plt.imshow(specgram[0,:,:].numpy())
plt.colorbar()
plt.show()

dataset= SpeechDataLoader(wave,labels,labels_dict, train_audio_transforms)

# 80/20 split. The test length is derived from the train length so the two always
# sum to len(dataset), as random_split requires. The split draws from the global
# torch RNG seeded in the configuration cell.
n_train = round(len(dataset) * .8)
traindata, testdata = random_split(dataset, [n_train, len(dataset) - n_train])

# Batch size comes from args (it also names output_dir) rather than a separate literal.
data_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True)

# Sample order does not affect the accuracy metrics, so the test loader is not shuffled.
data_loader_test = torch.utils.data.DataLoader(testdata, batch_size=args.batch_size, shuffle=False)

print("Creating model")

net = NetworkA().to(device)
functional.set_step_mode(net,step_mode='m')
# CuPy backend for the spiking neurons; requires cupy and a CUDA device.
functional.set_backend(net, backend='cupy')
optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# LR is multiplied by 0.4 every 5 scheduler steps; train_one_epoch steps it once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
criterion = nn.CrossEntropyLoss()
if args.resume:
    # NOTE(review): checkpoints also pickle `args` (an argparse.Namespace); torch
    # versions whose torch.load defaults to weights_only=True will refuse that —
    # confirm against the installed torch version.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Restore the trained weights (previously only the counters were restored).
    net.load_state_dict(checkpoint['model'])
    # The training cell stores the epoch under 'pre-train_epoch'; 'epoch' is accepted too.
    last_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else checkpoint['pre-train_epoch']
    args.start_epoch = last_epoch + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']
    # Optimizer/scheduler state is restored only if the checkpoint contains it;
    # otherwise momentum buffers and the LR schedule start from scratch.
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler'])


if args.tb and utils.is_main_process():
    # purge_step drops any events logged at or after start_epoch by an earlier run.
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write('Results\n')

In [None]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None,lr = 1e-2):
    """Run one training epoch over ``data_loader`` and step the LR scheduler once.

    Relies on the notebook-level ``optimizer`` and ``scheduler`` globals.

    Args:
        net: model to train; its spikingjelly state is reset after every batch.
        criterion: loss applied to ``net(image)`` and ``target``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval passed to ``metric_logger.log_every``.
        scaler: unused (there is no AMP path); kept for signature compatibility.
        lr: unused. The logged learning rate is read from ``optimizer`` so it
            reflects the scheduler's decay. Kept for signature compatibility.

    Returns:
        ``(loss, acc1, acc5)`` global averages from the metric logger.

    Raises:
        ValueError: if the loss is NaN. The check runs before the optimizer step,
            so NaN gradients are never applied to the weights.
    """
    net.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        output = net(image)
        loss = criterion(output, target)
        loss_s = loss.item()
        if math.isnan(loss_s):
            functional.reset_net(net)
            raise ValueError('loss is Nan')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Clear neuron state so the next batch starts from rest.
        functional.reset_net(net)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]

        # Log the LR actually in use (StepLR decays it over epochs).
        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]['lr'])

        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    scheduler.step()
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [None]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``net`` on ``data_loader`` without gradients.

    spikingjelly state is reset after every batch, so batches are scored
    independently. Accuracies are weighted by batch size; the loss is recorded
    once per batch (its averaging is up to ``utils.MetricLogger``).

    Returns:
        ``(loss, acc1, acc5)`` global averages, also printed to stdout.
    """
    net.eval()
    net.to(device)
    logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for inputs, labels in logger.log_every(data_loader, print_freq, header):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = net(inputs)
            batch_loss = criterion(logits, labels)
            functional.reset_net(net)

            top1, top5 = utils.accuracy(logits, labels, topk=(1, 5))
            n_samples = inputs.shape[0]
            logger.update(loss=batch_loss.item())
            for meter_name, value in (('acc1', top1), ('acc5', top5)):
                logger.meters[meter_name].update(value.item(), n=n_samples)
    logger.synchronize_between_processes()

    loss, acc1, acc5 = (getattr(logger, name).global_avg for name in ('loss', 'acc1', 'acc5'))
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5

In [None]:
import numpy as np
# Re-create the `np.int` alias removed in NumPy 1.24; presumably a dependency used
# during training still references it (TODO: identify which and upgrade it).
np.int = int
print("Start training")
start_time = time.time()

# The per-epoch summary appends to <output_dir>_logs/args.txt. That directory is
# otherwise only created when the TensorBoard writers are set up, so ensure it
# exists regardless of --tb.
os.makedirs(output_dir + '_logs', exist_ok=True)

for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(net, criterion,data_loader, device, epoch, args.print_freq,lr=args.lr)

    # Writers are None unless --tb was given (and this is the main process).
    if train_tb_writer is not None and utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None and utils.is_main_process():
        te_tb_writer.add_scalar('test_loss', test_loss, epoch)
        te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
        te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:

        checkpoint = {
            'model': net.state_dict(),
            'pre-train_epoch': epoch,
            # Duplicated under 'epoch' for loaders that read that key.
            'epoch': epoch,
            # Saved so --resume can restore momentum and the LR schedule.
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        # Periodic snapshot every 64 epochs, plus the final epoch.
        if epoch % 64 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max and utils.is_main_process():
            best_prefix = os.path.join(output_dir, f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1')
            # *_all_pretrain.pth holds the pickled module (writing `checkpoint`
            # to the same path first was always overwritten); *_state_pretrain.pth
            # holds only the state_dict.
            torch.save(net, best_prefix + '_all_pretrain.pth')
            torch.save(net.state_dict(), best_prefix + '_state_pretrain.pth')
    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
            'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1,'train_acc1', train_acc1,
            'train_acc5', train_acc5)

    Train_logs= {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1 ,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc5,
        'args': args
    }
    with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
        args_txt.write('\n')
        args_txt.write(str(Train_logs))

rep_pre_acc1 = max_test_acc1
rep_pre_acc5 = test_acc5_at_max_test_acc1
# PushPlus notification. The token is read from the environment instead of being
# embedded in the notebook; revoke any token that was previously committed.
# A failed notification is reported but does not raise after a finished run.
pushplus_token = os.environ.get('PUSHPLUS_TOKEN')
if pushplus_token:
    try:
        requests.get(
            "http://www.pushplus.plus/send",
            params={
                'token': pushplus_token,
                'title': '111程序通知',
                'content': "{}{}\n{}{}".format(
                    "pre_acc1 = ", rep_pre_acc1, "pre_acc5 = ", rep_pre_acc5),
                'template': 'html',
            },
            timeout=10,
        )
    except requests.RequestException as exc:
        print(f'PushPlus notification failed: {exc}')
else:
    print('PUSHPLUS_TOKEN not set; skipping notification')

In [None]:
# Test accuracy of a previously trained checkpoint as a function of the number of
# simulation time steps T = 1..10. The path is machine-specific (TODO: make configurable).
EVAL_CHECKPOINT = ('/home/curry/snncom/logs/b_128_lr0.01_2024_9_27_23__39/'
                   'train_maxacc1_94.25022275800619_checkpoint_max_test_acc1_state_pretrain.pth')

acc_list = []
net = NetworkA().to(device)
functional.set_step_mode(net,step_mode='m')
functional.set_backend(net, backend='cupy')
# Load the weights once; only T changes between runs. evaluate() switches to eval
# mode and resets the spiking state after every batch, so runs are independent.
net.load_state_dict(torch.load(EVAL_CHECKPOINT, map_location=device))
# The loop variable is deliberately not named `time`: that would shadow the
# `time` module used by the training cells.
for num_steps in range(1, 11):
    net.set_T(num_steps)
    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    acc_list.append(test_acc1)

In [None]:
# Top-1 test accuracy for T = 1..10, in order.
print(f"{acc_list}")