## Tutorial on training an HTS-AT model for audio classification on the ESC-50 Dataset

Reference: 

[HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection, ICASSP 2022](https://arxiv.org/abs/2202.00874)

Following the HTS-AT paper, this tutorial shows how to train HTS-AT on the ESC-50 Dataset.

The [ESC-50 dataset](https://github.com/karolpiczak/ESC-50) is a labeled collection of 2000 environmental audio recordings suitable for benchmarking methods of environmental sound classification. The dataset consists of 5-second-long recordings organized into 50 semantic classes (with 40 examples per class) loosely arranged into 5 major categories.

Before running this tutorial, please make sure that the packages below are installed by following these steps:

1. Download [the codebase](https://github.com/RetroCirce/HTS-Audio-Transformer), and put this tutorial notebook inside the codebase folder.

2. In the github code folder:

    > pip install -r requirements.txt

3. We do not include the installation of PyTorch in the requirements, since different machines require different versions of CUDA and its toolkits. So make sure you install PyTorch by following [the official guidance](https://pytorch.org/).

4. Install 'SOX' and 'ffmpeg'. We recommend running this code on Linux inside a Conda environment, where you can install them with:

    > sudo apt install sox
    
    > conda install -c conda-forge ffmpeg


In [1]:
# import basic packages
import os
import numpy as np
import wget
import sys
import gdown
import zipfile
import librosa
# Within the notebook only a single GPU can be used, so restrict CUDA to device 0.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Build the workspace and download the needed files

def create_path(path):
    """Create the directory ``path`` if nothing exists at that location yet.

    Missing parent directories are created as well, so nested paths can be
    passed directly. A path that already exists is left untouched.

    Args:
        path: Directory path to create.
    """
    if not os.path.exists(path):
        # makedirs builds intermediate folders too; exist_ok=True keeps the
        # call harmless if the directory appears between the check and here.
        os.makedirs(path, exist_ok=True)

# Directory layout used by the rest of the tutorial
workspace = "./workspace"
checkpoint = "./checkpoint"
dataset_path = os.path.join(workspace, "esc-50")
checkpoint_path = os.path.join(workspace, "ckpt")
esc_raw_path = os.path.join(dataset_path, 'raw')

# Create the folders one by one; each parent is listed before its children
for folder in (workspace, dataset_path, checkpoint_path, esc_raw_path, checkpoint):
    create_path(folder)


# download the esc-50 dataset
# The archive is fetched only when it is not on disk yet.
esc_zip_path = os.path.join(dataset_path, 'ESC-50-master.zip')
if not os.path.exists(esc_zip_path):
    print("-------------Downloading ESC-50 Dataset-------------")
    wget.download('https://github.com/karoldvl/ESC-50/archive/master.zip', out=dataset_path)

# Extraction is checked separately from the download, so a run that fetched
# the zip but never unpacked it is repaired on the next execution.
if not os.path.exists(os.path.join(esc_raw_path, 'ESC-50-master')):
    with zipfile.ZipFile(esc_zip_path, 'r') as zip_ref:
        zip_ref.extractall(esc_raw_path)
    print("-------------Success-------------")

# Checkpoint stored under the workspace "ckpt" folder (the file name suggests
# AudioSet-pretrained weights); downloaded only when missing.
pretrain_ckpt_path = os.path.join(checkpoint_path, 'htsat_audioset_pretrain.ckpt')
if not os.path.exists(pretrain_ckpt_path):
    gdown.download(id='1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv', output=pretrain_ckpt_path)

# Ready-made ESC-50 checkpoint stored under "./checkpoint"; downloaded only when missing.
esc_ckpt_path = os.path.join(checkpoint, 'ESC-50-acc=0.970.ckpt')
if not os.path.exists(esc_ckpt_path):
    gdown.download(id='1MAc06i_6cw9bG6NCj3lVa0Sy78Tzd8qG', output=esc_ckpt_path)

In [None]:
# Process ESC-50 Dataset
# Locations inside the extracted archive, plus the outputs written below
esc_master_path = os.path.join(esc_raw_path, 'ESC-50-master')
meta_path = os.path.join(esc_master_path, 'meta', 'esc50.csv')
audio_path = os.path.join(esc_master_path, 'audio')
resample_path = os.path.join(dataset_path, 'resample')
savedata_path = os.path.join(dataset_path, 'esc-50-data.npy')
create_path(resample_path)

# Read every metadata row as strings, skipping the CSV header line
meta = np.loadtxt(meta_path, delimiter=',', dtype='str', skiprows=1)
# File names of all clips found in the audio folder
audio_list = os.listdir(audio_path)

# resample
# Convert every clip to a 32000 Hz copy with SoX; clips whose resampled copy
# already exists are skipped.
import subprocess

print("-------------Resample ESC-50-------------")
for f in audio_list:
    full_f = os.path.join(audio_path, f)
    resample_f = os.path.join(resample_path, f)
    if not os.path.exists(resample_f):
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact; check=True reports a failed conversion here
        # instead of letting a missing file surface later.
        subprocess.run(['sox', '-V1', full_f, '-r', '32000', resample_f], check=True)
print("-------------Success-------------")

print("-------------Build Dataset-------------")
# Five buckets, one per fold; the fold id from the metadata is shifted by one
# to obtain the 0-based bucket index.
output_dict = [[] for _ in range(5)]
for label in meta:
    name, fold, target = label[0], label[1], label[2]
    # sr=None loads the resampled file at its own sample rate
    y, sr = librosa.load(os.path.join(resample_path, name), sr = None)
    sample = {"name": name, "target": int(target), "waveform": y}
    fold_index = int(fold) - 1
    output_dict[fold_index].append(sample)
np.save(savedata_path, output_dict)
print("-------------Success-------------")
    

In [4]:
# Load the model package
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import warnings

from utils import create_folder, dump_config, process_idc
import esc_config as config
from sed_model import SEDWrapper, Ensemble_SEDWrapper
from data_generator import ESC_Dataset
from model.htsat import HTSAT_Swin_Transformer

In [5]:
# Data Preparation
class data_prep(pl.LightningDataModule):
    """Lightning data module serving the training and evaluation datasets.

    The validation and test loaders both read ``eval_dataset``. Shuffling is
    disabled in the loaders and in the distributed sampler.
    NOTE(review): presumably the dataset handles its own ordering -- confirm
    against ESC_Dataset.

    Args:
        train_dataset: Dataset used by ``train_dataloader``.
        eval_dataset: Dataset used by ``val_dataloader`` and ``test_dataloader``.
        device_num: Number of devices the global batch is split across.
    """

    def __init__(self, train_dataset, eval_dataset, device_num):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num

    def _build_loader(self, dataset):
        """Return a non-shuffling DataLoader over ``dataset``."""
        # A distributed sampler is only needed when several devices are used.
        sampler = DistributedSampler(dataset, shuffle = False) if self.device_num > 1 else None
        # max(..., 1) keeps the division valid when device_num is 0
        # (no GPU visible); the full batch size is then used.
        per_device_batch = config.batch_size // max(self.device_num, 1)
        return DataLoader(
            dataset = dataset,
            num_workers = config.num_workers,
            batch_size = per_device_batch,
            shuffle = False,
            sampler = sampler
        )

    def train_dataloader(self):
        """Return the loader over the training dataset."""
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the loader over the evaluation dataset."""
        return self._build_loader(self.eval_dataset)

    def test_dataloader(self):
        """Return the loader over the evaluation dataset (same data as validation)."""
        return self._build_loader(self.eval_dataset)
    

In [None]:
# Set the workspace
device_num = torch.cuda.device_count()
# device_count() is 0 when no GPU is visible; guard the division so this cell
# does not stop with ZeroDivisionError in that case.
print("each batch size:", config.batch_size // max(device_num, 1))

# allow_pickle is required because the saved array holds Python objects
full_dataset = np.load(os.path.join(config.dataset_path, "esc-50-data.npy"), allow_pickle = True)

# set exp folder
# Layout: <workspace>/results/<exp_name>/checkpoint
results_dir = os.path.join(config.workspace, "results")
exp_dir = os.path.join(results_dir, config.exp_name)
checkpoint_dir = os.path.join(exp_dir, "checkpoint")
if not config.debug:
    # Outside debug mode, create the folders and dump the config next to them
    for folder in (results_dir, exp_dir, checkpoint_dir):
        create_folder(folder)
    dump_config(config, os.path.join(exp_dir, config.exp_name), False)

print("Using ESC")
# Both dataset objects wrap the same loaded array; only eval_mode differs
dataset = ESC_Dataset(dataset = full_dataset, config = config, eval_mode = False)
eval_dataset = ESC_Dataset(dataset = full_dataset, config = config, eval_mode = True)

audioset_data = data_prep(dataset, eval_dataset, device_num)

# Keep the 20 checkpoints with the highest monitored "acc" value
checkpoint_callback = ModelCheckpoint(
    filename='l-{epoch:d}-{acc:.3f}',
    monitor = "acc",
    mode = "max",
    save_top_k = 20
)

In [None]:
# Set the Trainer
# Options are collected first and handed to the Trainer in one call.
trainer_options = dict(
    default_root_dir = checkpoint_dir,
    gpus = device_num,
    # "ddp" is only requested when more than one device is available
    accelerator = "ddp" if device_num > 1 else None,
    sync_batchnorm = True,
    # the data module supplies its own samplers, so they are not replaced
    replace_sampler_ddp = False,
    max_epochs = config.max_epoch,
    val_check_interval = 1.0,
    num_sanity_val_steps = 0,
    callbacks = [checkpoint_callback],
    resume_from_checkpoint = None,
    auto_lr_find = True,
    gradient_clip_val=1.0,
    deterministic=False
)
trainer = pl.Trainer(**trainer_options)

# Backbone network; every size comes from config except the single input channel
sed_model = HTSAT_Swin_Transformer(
    config = config,
    in_chans=1,
    num_classes=config.classes_num,
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    patch_stride=config.htsat_stride,
    window_size=config.htsat_window_size,
    embed_dim = config.htsat_dim,
    depths = config.htsat_depth,
    num_heads=config.htsat_num_head
)

# Wrap the backbone together with the config and the training dataset
model = SEDWrapper(sed_model = sed_model, config = config, dataset = dataset)

if config.resume_checkpoint is not None:
    print("Load Checkpoint from ", config.resume_checkpoint)
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    state_dict = ckpt["state_dict"]
    # These tensors are removed from the checkpoint before loading, so the
    # freshly built model keeps its own values for them. The last two are
    # dropped for finetuning on the esc and spv2 dataset.
    dropped_keys = (
        "sed_model.head.weight",
        "sed_model.head.bias",
        "sed_model.tscam_conv.weight",
        "sed_model.tscam_conv.bias",
    )
    for dropped_key in dropped_keys:
        state_dict.pop(dropped_key)
    # strict=False lets the load proceed although the keys above are missing
    model.load_state_dict(state_dict, strict=False)



In [None]:
# Training the model
# You can select a different fold index by setting 'esc_fold' to any number from 0-4 in esc_config.py
# Fit the wrapped model using the dataloaders provided by the data module.
trainer.fit(model, audioset_data)

## Now Let us Check the Result

Find the path of your saved checkpoint and paste it into the variable below.
You can then use the code below to check the prediction result for any sample you like.

In [12]:
# infer the single data to check the result
# get a model you saved
model_path = './workspace/results/exp_htsat_esc_50/checkpoint/lightning_logs/version_5/checkpoints/l-epoch=18-acc=0.970.ckpt'

# get the groundtruth: map each file name (column 0) to its target (column 2)
meta = np.loadtxt(meta_path , delimiter=',', dtype='str', skiprows=1)
gd = {label[0]: label[2] for label in meta}

class Audio_Classification:
    """Run single-file inference with a trained HTS-AT checkpoint.

    Args:
        model_path: Path to a checkpoint whose ``"state_dict"`` stores the
            network weights under the ``"sed_model."`` prefix.
        config: Configuration module providing the ``htsat_*`` and
            ``classes_num`` settings used to build the network.
    """

    def __init__(self, model_path, config):
        # Use the GPU when one is available, otherwise fall back to the CPU
        # instead of failing on machines without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config = config,
            depths = config.htsat_depth,
            embed_dim = config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head
        )
        ckpt = torch.load(model_path, map_location="cpu")
        self.sed_model.load_state_dict(
            self._extract_sed_state_dict(ckpt["state_dict"])
        )
        self.sed_model.to(self.device)
        self.sed_model.eval()

    @staticmethod
    def _extract_sed_state_dict(state_dict, prefix="sed_model."):
        """Return the entries of ``state_dict`` under ``prefix``, prefix removed.

        Keys that do not start with ``prefix`` are skipped rather than being
        cut at a fixed offset, which would turn them into meaningless names.
        """
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def predict(self, audiofile):
        """Classify one audio file.

        Args:
            audiofile: Path of the audio file; it is loaded at 32000 Hz.

        Returns:
            ``(pred_label, pred_prob)``: the index of the highest entry of the
            model's ``'clipwise_output'`` for the clip and that entry's value.
            ``None`` when ``audiofile`` is empty or ``None``.
        """
        if not audiofile:
            return None

        waveform, sr = librosa.load(audiofile, sr=32000)

        with torch.no_grad():
            x = torch.from_numpy(waveform).float().to(self.device)
            # x[None, :] adds a leading batch axis of size one. The remaining
            # positional arguments (None, True) are passed through unchanged;
            # see HTSAT_Swin_Transformer.forward for their meaning.
            output_dict = self.sed_model(x[None, :], None, True)
            pred = output_dict['clipwise_output']
            pred_post = pred[0].detach().cpu().numpy()
            pred_label = np.argmax(pred_post)
            pred_prob = np.max(pred_post)
        return pred_label, pred_prob


In [None]:
# Inference
Audiocls = Audio_Classification(model_path, config)

# pick any audio you like in the ESC-50 testing set (cross-validation)
sample_file = "./workspace/esc-50/raw/ESC-50-master/audio/1-4211-A-12.wav"
pred_label, pred_prob = Audiocls.predict(sample_file)

# Show the predicted class index, its score and the ground-truth target
groundtruth = gd["1-4211-A-12.wav"]
print('Audiocls predict output: ', pred_label, pred_prob, groundtruth)