In [None]:
!pip install -r requirements.txt
!pip install torchvision torchtext==0.8.1 torchaudio

In [None]:
# import basic packages
import os
import numpy as np
import wget
import sys
import gdown
import zipfile
import librosa
# in the notebook, we can only use one GPU

In [None]:
# Paths for the ESC-50 experiment. Everything lives under the project folder
# on Google Drive (Colab layout).
workspace = "/content/drive/MyDrive/HTS-Audio-Transformer"
checkpoint_path = os.path.join(workspace, "ckpt")

# ESC-50 dataset layout under the workspace. meta_path must be defined here:
# the ground-truth cell reads it to map predicted class indices to names.
dataset_path = os.path.join(workspace, "esc-50")
esc_raw_path = os.path.join(dataset_path, 'raw')
meta_path = os.path.join(esc_raw_path, 'ESC-50-master', 'meta', 'esc50.csv')
# Folder of raw ESC-50 clips, handy for picking a clip to run inference on.
audio_path = os.path.join(esc_raw_path, 'ESC-50-master', 'audio')

In [None]:
# Load the model package
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import warnings

from utils import create_folder, dump_config, process_idc
import esc_config as config
from sed_model import SEDWrapper, Ensemble_SEDWrapper
from data_generator import ESC_Dataset
from model.htsat import HTSAT_Swin_Transformer



In [None]:
# Infer a single clip to sanity-check the trained model.
# Checkpoint to evaluate; it sits in the checkpoint folder from the config cell.
model_path = os.path.join(checkpoint_path, 'htsat_esc-50_pretrain.ckpt')

# Ground truth lookup: column 2 of the metadata CSV (the class index, kept as a
# string) -> column 3 (the name printed next to each prediction). Every clip of a
# class repeats the same pair, so the comprehension keeps one entry per class
# (later rows overwrite earlier ones, exactly like the original loop).
meta = np.loadtxt(meta_path, delimiter=',', dtype='str', skiprows=1)
gd = {row[2]: row[3] for row in meta}

class Audio_Classification:
    """Single-clip inference wrapper around a trained HTS-AT checkpoint.

    Builds an ``HTSAT_Swin_Transformer`` from the hyper-parameters in ``config``,
    loads weights from a checkpoint's ``"state_dict"`` and exposes
    :meth:`predict` to classify one audio file.
    """

    # Audio is resampled to this rate before inference.
    # NOTE(review): assumed to match the rate the checkpoint was trained at
    # (probably config.sample_rate) — confirm before changing.
    SAMPLE_RATE = 32000

    # Prefix on every checkpoint key that must be removed before loading into the
    # bare model (the original code stripped exactly 10 == len("sed_model.") chars).
    CKPT_PREFIX = "sed_model."

    def __init__(self, model_path, config, device=None):
        """Build the model and load its weights.

        Args:
            model_path: path to a checkpoint file whose ``"state_dict"`` holds
                the weights with keys prefixed by ``CKPT_PREFIX``.
            config: config module providing ``classes_num`` and the
                ``htsat_*`` hyper-parameters.
            device: optional ``torch.device`` or device string. When None, CUDA
                is used if available, otherwise CPU.
        """
        self.device = self._resolve_device(device)
        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config = config,
            depths = config.htsat_depth,
            embed_dim = config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head
        )
        # torch.load unpickles the file: only load checkpoints from a trusted source.
        ckpt = torch.load(model_path, map_location="cpu")
        state_dict = {
            self._strip_prefix(key, self.CKPT_PREFIX): value
            for key, value in ckpt["state_dict"].items()
        }
        self.sed_model.load_state_dict(state_dict)
        self.sed_model.to(self.device)
        self.sed_model.eval()

    @staticmethod
    def _resolve_device(device=None):
        """Return ``device`` as a ``torch.device``; default to CUDA if available, else CPU."""
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device)

    @staticmethod
    def _strip_prefix(key, prefix):
        """Return ``key`` without a leading ``prefix``; keys lacking it are returned as-is."""
        return key[len(prefix):] if key.startswith(prefix) else key

    def predict(self, audiofile):
        """Classify a single audio file.

        Args:
            audiofile: path to an audio file readable by ``librosa.load``.

        Returns:
            ``(pred_label, pred_prob)``: index of the highest-scoring class in the
            model's ``'clipwise_output'`` and that class's score.

        Raises:
            ValueError: if ``audiofile`` is empty or None.
        """
        if not audiofile:
            raise ValueError("audiofile must be a non-empty path")

        waveform, _ = librosa.load(audiofile, sr=self.SAMPLE_RATE)

        with torch.no_grad():
            x = torch.from_numpy(waveform).float().to(self.device)
            # x[None, :] adds a leading batch dimension of size 1. The trailing
            # (None, True) positional args are passed through unchanged.
            # NOTE(review): presumably mixup_lambda=None, infer_mode=True — confirm
            # against HTSAT_Swin_Transformer.forward.
            output_dict = self.sed_model(x[None, :], None, True)
            scores = output_dict['clipwise_output'][0].cpu().numpy()

        pred_label = np.argmax(scores)
        pred_prob = scores[pred_label]
        return pred_label, pred_prob


In [None]:
# Inference: load the checkpoint once, then classify a single clip.
Audiocls = Audio_Classification(model_path, config)

# Clip to classify. Any file librosa can read works, e.g. a clip from the
# ESC-50 test fold or your own recording (as used here).
clip_file = "/content/Recording (2).m4a"
pred_label, pred_prob = Audiocls.predict(clip_file)

# The ground-truth dict is keyed by the class index as a string.
pred_category = gd[str(pred_label)]
print('Audiocls predict output: ', pred_label, pred_prob, pred_category)