In [1]:
import argparse
import os
import ast
import pickle
import sys
import time
import torch
from torch.utils.data import WeightedRandomSampler
# Take the directory two levels above sys.path[0] and append it to the module
# search path.
# NOTE(review): presumably this is what lets the project-local imports that follow
# (dataloader, models, traintest) resolve — confirm against the repository layout.
basepath = os.path.dirname(os.path.dirname(sys.path[0]))
sys.path.append(basepath)
import dataloader
from models import ASTModel
import numpy as np
from traintest import train, validate

In [2]:
# Run configuration, built directly as a Namespace (nothing is parsed from the
# command line in this notebook). Attribute order matches the original cell.
args = argparse.Namespace(
    # --- data ---
    data_train='../../data/final_train_data.json',
    data_val='../../data/final_test_data.json',
    data_eval='.json',  # NOTE(review): looks like a placeholder path — confirm
    n_class=6,
    # --- model / dataset selection ---
    model='ast',
    dataset='speechcommands',  # also the key used for the per-dataset lookup tables
    exp_dir='tmp/out',
    # --- optimisation ---
    lr=0.001,
    optim='adam',
    batch_size=12,
    num_workers=32,
    n_epochs=3,
    lr_patience=2,
    n_print_steps=100,
    save_model=None,
    # --- augmentation (all disabled) ---
    freqm=0,
    timem=0,
    mixup=0,
    bal=False,
    # --- patch strides and pretraining flags handed to ASTModel ---
    fstride=10,
    tstride=10,
    imagenet_pretrain=True,
    audioset_pretrain=False,
)


In [3]:
# Per-dataset spectrogram statistics, stored as [mean, std]; they are passed to the
# dataset as 'mean' / 'std' to normalize the input.
norm_stats = {
    'audioset': [-4.2677393, 4.5689974],
    'esc50': [-6.6268077, 5.358466],
    'speechcommands': [-6.845978, 5.5654526],
}
# Per-dataset value passed as 'target_length' (also reused later as the model's input_tdim).
target_length = {'audioset': 1024, 'esc50': 512, 'speechcommands': 128}
# Noise augmentation switch per dataset; only speechcommands turns it on.
noise = {'audioset': False, 'esc50': False, 'speechcommands': True}

dataset_mean, dataset_std = norm_stats[args.dataset]

# Training-time configuration: augmentation settings come from args.
audio_conf = {
    'num_mel_bins': 128,
    'target_length': target_length[args.dataset],
    'freqm': args.freqm,
    'timem': args.timem,
    'mixup': args.mixup,
    'dataset': args.dataset,
    'mode': 'train',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': noise[args.dataset],
}

# Evaluation-time configuration: same as above with every augmentation switched off.
val_audio_conf = {
    **audio_conf,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'mode': 'evaluation',
    'noise': False,
}


In [4]:
# Build the training and validation DataLoaders.
#
# args.bal is set as a boolean in the configuration cell, while this check used to
# compare against the string 'bal' only, so `args.bal = True` silently fell through
# to the unbalanced branch. Both spellings are accepted here; False / None (and any
# other value) still select plain shuffling.
use_balanced_sampler = args.bal is True or args.bal == 'bal'

if use_balanced_sampler:
    print('balanced sampler is being used')
    # Per-sample weights are read from "<training json path without extension>_weight.csv".
    # splitext is used instead of slicing off a fixed 5 characters so the path is
    # derived correctly whatever the extension length.
    weight_path = os.path.splitext(args.data_train)[0] + '_weight.csv'
    samples_weight = np.loadtxt(weight_path, delimiter=',')
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
    # DataLoader does not allow `sampler` together with `shuffle=True`, so exactly
    # one of the two is passed.
    train_sampling = {'sampler': sampler}
else:
    print('balanced sampler is not used')
    train_sampling = {'shuffle': True}

# The status line above is printed before the dataset is constructed, which keeps
# the order of the logged output unchanged.
train_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset(args.data_train, audio_conf=audio_conf),
    batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True,
    **train_sampling)

# Validation: no shuffling, and twice the training batch size.
val_loader = torch.utils.data.DataLoader(
    dataloader.AudiosetDataset(args.data_val, audio_conf=val_audio_conf),
    batch_size=args.batch_size*2, shuffle=False, num_workers=args.num_workers, pin_memory=True)

balanced sampler is not used
---------------the train dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
now use noise augmentation
number of classes is 6
---------------the evaluation dataloader---------------
now using following mask: 0 freq, 0 time
now using mix-up with rate 0.000000
now process speechcommands
use dataset mean -6.846 and std 5.565 to normalize the input.
number of classes is 6




In [5]:
# Instantiate the AST model from the run configuration.
audio_model = ASTModel(
    label_dim=args.n_class,
    fstride=args.fstride,
    tstride=args.tstride,
    # NOTE(review): 128 equals 'num_mel_bins' in the audio configs — presumably the
    # two must stay in sync; confirm.
    input_fdim=128,
    # Same per-dataset value that the dataloaders receive as 'target_length'.
    input_tdim=target_length[args.dataset],
    imagenet_pretrain=args.imagenet_pretrain,
    audioset_pretrain=args.audioset_pretrain,
    model_size='base384',
)


---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: False


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth" to ../../pretrained_models\hub\checkpoints\deit_base_distilled_patch16_384-d0272ac0.pth


frequncey stride=10, time stride=10
number of patches=144


In [6]:
# Announce the epoch budget, then hand the model, both loaders and the run
# configuration to the project's training routine.
print(f'Now starting training for {args.n_epochs:d} epochs')
train(audio_model, train_loader, val_loader, args)

Now starting training for 3 epochs
running on cuda
Total parameter number is : 86.911 million
Total trainable parameter number is : 86.911 million
scheduler for speech commands is used
now training with speechcommands, main metrics: acc, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x000002387BE0A340>
current #steps=0, #epochs=1
start training...
---------------
2022-03-05 15:01:00.046424
current #epochs=1, #steps=0
