In [1]:
import os
import time
import utils
import math
import argparse
import datetime
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributed.optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

from spikingjelly.activation_based import functional

import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, random_split

import pickle

In [2]:
# Let cuDNN benchmark and cache the fastest convolution algorithms (pays off when
# input shapes stay constant between batches).
torch.backends.cudnn.benchmark = True

# TensorBoard writers are only created later when --tb is set; None means "disabled".
train_tb_writer = te_tb_writer = None

# Single-GPU setup pinned to GPU index 3. NOTE: this ignores the --device CLI option.
deviceIds = [3]
device = torch.device('cuda', deviceIds[0])

In [None]:
def parse_args(argv=None):
    """Build the training CLI parser and parse options.

    Args:
        argv: Optional sequence of command-line tokens. When ``None`` (the
            default) the notebook's built-in option set is parsed instead of
            ``sys.argv``; inside Jupyter, ``sys.argv`` typically holds the
            kernel launcher's own flags, which this parser would reject.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--device', default='cuda:0', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--data-path', default='./data/', help='dataset')
    parser.add_argument('--epochs', default=320, type=int, metavar='N',
                        help='number of total epochs to pre-train')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.0025, type=float, help='initial learning rate')

    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./logs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--tb', action='store_true',
                        help='Use TensorBoard to record logs')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )

    if argv is None:
        # Notebook defaults: the configuration this experiment was run with.
        argv = ['--data-path', '../data', '--lr', '0.1', '-b', '128', '--epochs', '100',
                '--print-freq', '100', '--tb', '--cache-dataset']
    return parser.parse_args(args=list(argv))


args = parse_args()

In [None]:
class NetworkA(torch.nn.Module):
    """VGG-style CNN classifier over 2-D feature maps (MFCCs in this notebook).

    Stacks of 3x3 convolutions with ReLU (every conv except ``conv11`` is
    followed by BatchNorm), five 2x2 max-pools, then ``linear1`` + ``bn14``
    producing one score per class.

    ``linear1`` expects 256 flattened features, i.e. 128 channels x 1 x 2 after
    the five poolings. That holds for 40x81 inputs -- presumably torchaudio's
    default MFCC of a 1 s / 16 kHz clip (TODO confirm); other input sizes need a
    different ``linear1`` width.

    Args:
        num_classes: number of output scores (default 35).
        in_channels: channels of the input map (default 1).
        final_relu: if True, apply ReLU to the output scores, reproducing the
            original architecture. Defaults to False because
            ``nn.CrossEntropyLoss`` expects unbounded logits; a final ReLU
            clamps every negative score to 0 and gives it zero gradient.
            Parameter/buffer names (state_dict keys) are the same either way.
    """

    def __init__(self, num_classes=35, in_channels=1, final_relu=False):
        super(NetworkA, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(96)
        self.sn1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(96, 96, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(96)
        self.sn2 = nn.ReLU(inplace=True)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.sn3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.sn4 = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.sn5 = nn.ReLU(inplace=True)

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        self.sn8 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(512)
        self.sn10 = nn.ReLU(inplace=True)

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # NOTE(review): unlike the other conv layers, conv11 has no BatchNorm.
        self.conv11 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.sn11 = nn.ReLU(inplace=True)

        self.conv13 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn13 = nn.BatchNorm2d(128)
        self.sn13 = nn.ReLU(inplace=True)

        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(256, num_classes)
        self.bn14 = nn.BatchNorm1d(num_classes)
        # Identity by default so the network emits raw logits for CrossEntropyLoss.
        self.sn14 = nn.ReLU(inplace=True) if final_relu else nn.Identity()

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, num_classes) scores."""
        x = self.sn1(self.bn1(self.conv1(x)))
        x = self.sn2(self.bn2(self.conv2(x)))
        x = self.pool1(x)

        x = self.sn3(self.bn3(self.conv3(x)))
        x = self.sn4(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        x = self.sn5(self.bn5(self.conv5(x)))
        x = self.pool3(x)

        x = self.sn8(self.bn8(self.conv8(x)))
        x = self.sn10(self.bn10(self.conv10(x)))
        x = self.pool4(x)

        x = self.sn11(self.conv11(x))

        x = self.sn13(self.bn13(self.conv13(x)))
        x = self.pool5(x)

        # Flatten everything except the batch dimension (also safe for non-contiguous x).
        x = torch.flatten(x, 1)

        return self.sn14(self.bn14(self.linear1(x)))

In [5]:
# Re-parse so this cell also works when run on its own after a kernel restart.
args = parse_args()
# Best test metrics seen so far; overwritten from the checkpoint when resuming.
max_test_acc1, test_acc5_at_max_test_acc1 = 0., 0.

# Sets up (or declines) distributed mode; adds e.g. `distributed` to args.
utils.init_distributed_mode(args)
print(args)

# Run directory: <output_dir>/b_<batch>_lr<lr>_<Y>_<M>_<D>_<H>__<MIN>
# (date/time components are not zero-padded).
time_now = datetime.datetime.now()
output_dir = os.path.join(
    args.output_dir,
    'b_{}_lr{}_{}_{}_{}_{}__{}'.format(
        args.batch_size, args.lr,
        time_now.year, time_now.month, time_now.day, time_now.hour, time_now.minute,
    ),
)

if output_dir:
    utils.mkdir(output_dir)
    
class SpeechDataLoader(Dataset):
    """Map-style dataset pairing precomputed MFCC features with class indices.

    Despite its name this is a ``torch.utils.data.Dataset``, not a DataLoader.

    Args:
        mfccs: indexable collection of features, one per example.
        labels: indexable collection of labels aligned with ``mfccs``.
        label_dict: ordered list of known labels; a label's position in this
            list is its integer class index.
    """

    def __init__(self, mfccs, labels, label_dict):
        self.mfccs, self.labels, self.label_dict = mfccs, labels, label_dict

    def __len__(self):
        return len(self.mfccs)

    def __getitem__(self, idx):
        """Return ``(mfcc, class_index)``; raise ValueError for an unknown label."""
        label = self.labels[idx]
        if label not in self.label_dict:
            raise ValueError("Label not found in label_dict.")
        return self.mfccs[idx], self.label_dict.index(label)

def save_dataset(mfccs, labels, label_dict, filename):
    """Pickle ``(mfccs, labels, label_dict)`` as a single tuple into ``filename``.

    An existing file at ``filename`` is overwritten.
    """
    payload = (mfccs, labels, label_dict)
    with open(filename, mode='wb') as out_file:
        pickle.dump(payload, out_file)

def load_dataset(filename):
    """Unpickle and return the object stored in ``filename``.

    In this notebook that is the ``(mfccs, labels, label_dict)`` tuple written
    by ``save_dataset``.

    Security note: ``pickle.load`` can execute arbitrary code; only use this on
    cache files this notebook produced itself, never on downloaded/untrusted ones.
    """
    with open(filename, mode='rb') as in_file:
        payload = pickle.load(in_file)
    return payload

# Preprocess and load data
def preprocess_and_load_data():
    """Fetch Speech Commands v0.02 (downloading if needed) and compute MFCCs.

    Returns:
        (mfccs, labels, labels_dict):
            mfccs: list of MFCC tensors, one per kept clip.
            labels: list of labels (dataset item index 2) aligned with ``mfccs``.
            labels_dict: sorted list of class names, one per word folder; a
                name's position in the list is its integer class index.

    Clips whose waveform shape is not exactly (1, 16000) are skipped.
    """
    # Instantiate (and, with download=True, fetch/extract) the dataset BEFORE
    # listing its folder: on a fresh machine the directory does not exist yet.
    dataset = torchaudio.datasets.SPEECHCOMMANDS('./data/', url='speech_commands_v0.02',
                                                 folder_in_archive='SpeechCommands', download=True)
    train_audio_path = './data/SpeechCommands/speech_commands_v0.02/'

    # Every word sub-folder is a class; files (LICENSE, README.md, *_list.txt),
    # hidden entries (.DS_Store) and the background-noise folder are not.
    # Sorting makes the label -> index mapping independent of listing order.
    excluded = {'_background_noise_'}
    labels_dict = sorted(
        entry for entry in os.listdir(train_audio_path)
        if entry not in excluded
        and not entry.startswith('.')
        and os.path.isdir(os.path.join(train_audio_path, entry))
    )

    transform = nn.Sequential(
        torchaudio.transforms.MFCC(log_mels=False)
    )

    mfccs = []
    labels = []
    for i in range(len(dataset)):
        # Indexing the dataset loads the audio, so fetch each item exactly once.
        waveform, _sample_rate, label, *_ = dataset[i]
        if waveform.shape == (1, 16000):
            mfccs.append(transform(waveform))
            labels.append(label)

    return mfccs, labels, labels_dict

# MFCC preprocessing is slow, so its result is cached in a pickle in the working dir.
dataset_filename = 'processed_dataset.pkl'

if os.path.exists(dataset_filename):
    mfccs, labels, labels_dict = load_dataset(dataset_filename)
else:
    mfccs, labels, labels_dict = preprocess_and_load_data()
    save_dataset(mfccs, labels, labels_dict, dataset_filename)

dataset = SpeechDataLoader(mfccs, labels, labels_dict)

# 80/20 train/test split. A fixed-seed generator keeps the split identical across
# kernel restarts, so a resumed run (--resume) is evaluated on the same held-out
# clips it never trained on. The test size is derived by subtraction so the two
# lengths always sum to len(dataset), as random_split requires.
# NOTE(review): the dataset also ships validation_list.txt / testing_list.txt
# defining an official split; a random split may mix one speaker's clips across sets.
SPLIT_SEED = 0
n_train = round(len(dataset) * 0.8)
traindata, testdata = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

Not using distributed mode
Namespace(batch_size=128, cache_dataset=True, data_path='../data', device='cuda:0', distributed=False, epochs=100, lr=0.1, output_dir='./logs', print_freq=100, resume='', start_epoch=0, tb=True, workers=16)


In [6]:
# Training batches are shuffled; the averaged test metrics do not depend on order,
# so the test loader keeps a fixed order.
data_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True)

data_loader_test = torch.utils.data.DataLoader(testdata, batch_size=args.batch_size, shuffle=False)

print("Creating model")

net = NetworkA().to(device)
# spikingjelly helpers: multi-step mode and CuPy backend for spiking modules.
# NOTE(review): NetworkA holds only standard nn layers, so these are presumably no-ops.
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, backend='cupy')

optimizer = torch.optim.SGD(
    net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Multiply the LR by 0.4 every 5 scheduler steps (train_one_epoch steps once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
criterion = nn.CrossEntropyLoss()

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Restore the weights, not just the bookkeeping counters.
    net.load_state_dict(checkpoint['model'])
    # Optimizer / scheduler state is only present in checkpoints that saved it;
    # without it the LR schedule restarts from args.lr.
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'lr_scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # The training loop has stored the epoch under 'pre-train_epoch'; accept either key.
    last_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else checkpoint['pre-train_epoch']
    args.start_epoch = last_epoch + 1
    max_test_acc1 = checkpoint['max_test_acc1']
    test_acc5_at_max_test_acc1 = checkpoint['test_acc5_at_max_test_acc1']

if args.tb and utils.is_main_process():
    # Purge any events at or after start_epoch so a resumed run does not duplicate steps.
    purge_step_train = args.start_epoch
    purge_step_te = args.start_epoch
    train_tb_writer = SummaryWriter(output_dir + '_logs/train', purge_step=purge_step_train)
    te_tb_writer = SummaryWriter(output_dir + '_logs/te', purge_step=purge_step_te)
    with open(output_dir + '_logs/args.txt', 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    # NOTE(review): the file name is misspelled ('resluts'); kept as-is for existing tooling.
    with open(output_dir + '_logs/resluts.txt', 'w', encoding='utf-8') as results_txt:
        results_txt.write('Results\n')

Creating model


In [7]:
def train_one_epoch(net, criterion, data_loader, device, epoch, print_freq, scaler=None, lr=1e-2,
                    optimizer=None, scheduler=None):
    """Train ``net`` for one epoch, then step the LR scheduler once.

    Args:
        net: model to train (switched to train mode).
        criterion: loss applied to ``(output, target)``.
        data_loader: iterable of ``(input, target)`` batches.
        device: device (or device string) batches are moved to.
        epoch: epoch number, used only in the progress header.
        print_freq: log every ``print_freq`` iterations.
        scaler: optional ``GradScaler``. When given, forward/loss run under
            ``torch.autocast`` and backward/step go through the scaler.
        lr: kept for backward compatibility only. The logged learning rate is
            read from the optimizer so it reflects the scheduler's decay.
        optimizer: optimizer to step; defaults to the notebook-level ``optimizer``.
        scheduler: scheduler stepped after the epoch; defaults to the
            notebook-level ``scheduler``.

    Returns:
        ``(loss, acc1, acc5)`` global averages over the epoch.

    Raises:
        ValueError: if the loss is NaN. This is checked before ``backward()``
            so the weights are never updated from a NaN loss.
    """
    import contextlib

    if optimizer is None:
        optimizer = globals()['optimizer']
    if scheduler is None:
        scheduler = globals()['scheduler']

    net.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    device_type = torch.device(device).type

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)

        # Mixed precision only when a scaler is supplied.
        amp_ctx = torch.autocast(device_type=device_type) if scaler is not None else contextlib.nullcontext()
        with amp_ctx:
            output = net(image)
            loss = criterion(output, target)

        loss_s = loss.item()
        if math.isnan(loss_s):
            raise ValueError('loss is Nan')

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        # Clear spikingjelly neuron state between batches, as evaluate() does
        # (presumably a no-op for purely non-spiking networks).
        functional.reset_net(net)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]

        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]['lr'])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    scheduler.step()
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg

In [8]:
def evaluate(net, criterion, data_loader, device, print_freq=100, header='Test:'):
    """Evaluate ``net`` on ``data_loader`` with gradient tracking disabled.

    Moves ``net`` to ``device``, switches it to eval mode, and resets
    spikingjelly state after every batch. Prints the averages and returns
    ``(loss, acc1, acc5)`` as global averages over all batches.
    """
    net.eval()
    net.to(device)
    logger = utils.MetricLogger(delimiter="  ")
    with torch.no_grad():
        for images, targets in logger.log_every(data_loader, print_freq, header):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            logits = net(images)
            batch_loss = criterion(logits, targets)
            functional.reset_net(net)

            top1, top5 = utils.accuracy(logits, targets, topk=(1, 5))
            n_items = images.shape[0]
            logger.update(loss=batch_loss.item())
            logger.meters['acc1'].update(top1.item(), n=n_items)
            logger.meters['acc5'].update(top5.item(), n=n_items)
    logger.synchronize_between_processes()

    loss = logger.loss.global_avg
    acc1 = logger.acc1.global_avg
    acc5 = logger.acc5.global_avg
    print(f' * Acc@1 = {acc1}, Acc@5 = {acc5}, loss = {loss}')
    return loss, acc1, acc5

In [None]:
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    save_max = False

    train_loss, train_acc1, train_acc5 = train_one_epoch(net, criterion, data_loader, device, epoch, args.print_freq, lr=args.lr)

    # Writers exist only when --tb was passed.
    if train_tb_writer is not None and utils.is_main_process():
        train_tb_writer.add_scalar('train_loss', train_loss, epoch)
        train_tb_writer.add_scalar('train_acc1', train_acc1, epoch)
        train_tb_writer.add_scalar('train_acc5', train_acc5, epoch)

    test_loss, test_acc1, test_acc5 = evaluate(net, criterion, data_loader_test, device=device, header='Test:')
    if te_tb_writer is not None and utils.is_main_process():
        te_tb_writer.add_scalar('test_loss', test_loss, epoch)
        te_tb_writer.add_scalar('test_acc1', test_acc1, epoch)
        te_tb_writer.add_scalar('test_acc5', test_acc5, epoch)

    if max_test_acc1 < test_acc1:
        max_test_acc1 = test_acc1
        test_acc5_at_max_test_acc1 = test_acc5
        save_max = True

    if output_dir:
        checkpoint = {
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': scheduler.state_dict(),
            # 'epoch' is what the resume code reads; 'pre-train_epoch' is kept for older readers.
            'epoch': epoch,
            'pre-train_epoch': epoch,
            'args': args,
            'max_test_acc1': max_test_acc1,
            'test_acc5_at_max_test_acc1': test_acc5_at_max_test_acc1,
        }

        utils.save_on_master(
            checkpoint,
            os.path.join(output_dir, 'checkpoint_latest.pth'))

        # Periodic snapshot every 64 epochs and after the final epoch.
        if epoch % 64 == 0 or epoch == args.epochs - 1:
            utils.save_on_master(
                checkpoint,
                os.path.join(output_dir, f'checkpoint_{epoch}.pth'))

        if save_max:
            best_prefix = os.path.join(output_dir, f'train_maxacc1_{max_test_acc1}_checkpoint_max_test_acc1')
            # Resumable checkpoint dict and pickled full module go to DIFFERENT files;
            # sharing one path would overwrite the dict with the module.
            utils.save_on_master(checkpoint, best_prefix + '_ckpt_pretrain.pth')
            if utils.is_main_process():
                torch.save(net, best_prefix + '_all_pretrain.pth')
                torch.save(net.state_dict(), best_prefix + '_state_pretrain.pth')

    print(args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(output_dir)

    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
          'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1, 'train_acc1', train_acc1,
          'train_acc5', train_acc5)

    Train_logs = {
        'Epoch:': epoch,
        'max_test_acc1 ': max_test_acc1,
        'test_acc5_at_max_test_acc1 ': test_acc5_at_max_test_acc1,
        'train_acc1 ': train_acc1,
        'train_acc5 ': train_acc5,
        'args': args
    }
    # '<output_dir>_logs' is otherwise only created by the TensorBoard writers (--tb).
    os.makedirs(output_dir + '_logs', exist_ok=True)
    with open(output_dir + '_logs/args.txt', 'a', encoding='utf-8') as args_txt:
        args_txt.write('\n')
        args_txt.write(str(Train_logs))

Start training
Epoch: [0]  [  0/597]  eta: 1:27:59  lr: 0.1  img/s: 14.494585955877806  loss: 3.7066 (3.7066)  acc1: 4.6875 (4.6875)  acc5: 14.8438 (14.8438)  time: 8.8430  data: 0.0121  max mem: 0
Epoch: [0]  [100/597]  eta: 0:01:59  lr: 0.1  img/s: 697.963733560757  loss: 3.1883 (3.4043)  acc1: 11.7188 (8.5396)  acc5: 42.1875 (32.4644)  time: 0.1651  data: 0.0018  max mem: 0
Epoch: [0]  [200/597]  eta: 0:01:25  lr: 0.1  img/s: 646.6384325666155  loss: 2.9671 (3.2477)  acc1: 17.9688 (11.8120)  acc5: 50.7812 (38.8798)  time: 0.1994  data: 0.0017  max mem: 0


In [None]:
# Append this run's best top-1 test accuracy as one line of a shared results file.
with open('output_ann.txt', 'a') as results_file:
    print(max_test_acc1, file=results_file)