In [1]:
import os
import numpy as np
import wget
import gdown
import librosa

# Expose only GPU 0 to the CUDA libraries (torch) imported later in this notebook.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
def create_path(path):
    """Create directory ``path`` (and any missing parents) if nothing exists there yet.

    Existing paths are left untouched, so calling this again on a re-run is a no-op.
    """
    if not os.path.exists(path):
        # makedirs creates missing parent directories; exist_ok guards against another
        # process creating the directory between the existence check and this call.
        os.makedirs(path, exist_ok=True)

import tarfile

workspace = "./workspace"
dataset_path = os.path.join(workspace, "UrbanSound8k")
checkpoint_path = os.path.join(workspace, "ckpt")  # AudioSet-pretrained HTS-AT weights
UrbanSound8k_raw_path = os.path.join(dataset_path, 'raw')
checkpoint = "./checkpoint"  # fine-tuned UrbanSound8K checkpoint used for inference below

for directory in (workspace, dataset_path, checkpoint_path, UrbanSound8k_raw_path, checkpoint):
    create_path(directory)

archive_path = os.path.join(dataset_path, 'UrbanSound8K.tar.gz')
if not os.path.exists(archive_path):
    print("-------------Downloading Dataset-------------")
    wget.download('https://zenodo.org/record/1203745/files/UrbanSound8K.tar.gz', out=dataset_path)
    print("-------------Success-------------")

# Extraction is checked separately from the download, so an interrupted extraction is
# retried on the next run instead of being skipped because the archive already exists.
if not os.path.isdir(os.path.join(UrbanSound8k_raw_path, 'UrbanSound8K')):
    print("-------------Extracting Dataset-------------")
    # NOTE(review): the archive comes from a fixed Zenodo URL and is treated as trusted;
    # on Python >= 3.12 consider passing filter="data" to extractall.
    with tarfile.open(archive_path, 'r:gz') as archive:
        archive.extractall(UrbanSound8k_raw_path)
    print("-------------Success-------------")

htsat_pretrain_ckpt = os.path.join(checkpoint_path, 'htsat_audioset_pretrain.ckpt')
if not os.path.exists(htsat_pretrain_ckpt):
    gdown.download(id='1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv', output=htsat_pretrain_ckpt)

us8k_finetuned_ckpt = os.path.join(checkpoint, 'US8K-acc=0.891.ckpt')
if not os.path.exists(us8k_finetuned_ckpt):
    gdown.download(id='1g6Bpnx6FqKut7SsGdQDnlLSOkvyX7U1D', output=us8k_finetuned_ckpt)

In [None]:
import subprocess

meta_path = os.path.join(UrbanSound8k_raw_path, 'UrbanSound8K', 'metadata', 'UrbanSound8K.csv')
audio_path = os.path.join(UrbanSound8k_raw_path, 'UrbanSound8K', 'audio')
resample_path = os.path.join(dataset_path, 'resample')
savedata_path = os.path.join(dataset_path, 'UrbanSound8K-data.npy')

NUM_FOLDS = 10      # folds are numbered 1..NUM_FOLDS (fold{i} directories)
TARGET_SR = 44100   # sample rate every clip is resampled to with sox

create_path(resample_path)
for i in range(1, NUM_FOLDS + 1):
    create_path(os.path.join(resample_path, f'fold{i}'))

meta = np.loadtxt(meta_path, delimiter=',', dtype='str', skiprows=1)
audio_folds = os.listdir(audio_path)

print("-------------Resample-------------")
for f in audio_folds:
    if f.startswith('.'):  # skip hidden entries such as .DS_Store
        continue
    for wav in os.listdir(os.path.join(audio_path, f)):
        # Only audio clips are converted; stray files would otherwise make sox fail.
        if wav.startswith('.') or not wav.lower().endswith('.wav'):
            continue
        full_f = os.path.join(audio_path, f, wav)
        resample_f = os.path.join(resample_path, f, wav)
        if not os.path.exists(resample_f):
            # Argument list (no shell): paths with spaces/metacharacters are safe, and
            # check=True surfaces a failed conversion instead of silently continuing.
            subprocess.run(
                ['sox', '-V1', full_f, '-r', str(TARGET_SR), '-c', '1', resample_f],
                check=True,
            )
print("-------------Success-------------")

print("-------------Build Dataset-------------")
# output_dict[k] collects the clips of fold k + 1.
output_dict = [[] for _ in range(NUM_FOLDS)]
for label in meta:
    name = label[0]
    fold = label[5]
    target = label[6]
    y, _ = librosa.load(os.path.join(resample_path, f"fold{fold}", name), sr=None)
    output_dict[int(fold) - 1].append(
        {
            "name": name,
            "target": int(target),
            "waveform": y,
        }
    )

# Folds hold different numbers of clips, so np.save(list_of_lists) would need a ragged
# array, which NumPy >= 1.24 rejects. Build the 1-D object array explicitly instead:
# element i is the list of dicts for fold i + 1 (saved via pickle; load with allow_pickle=True).
fold_array = np.empty(len(output_dict), dtype=object)
for i, fold_items in enumerate(output_dict):
    fold_array[i] = fold_items
np.save(savedata_path, fold_array)
print("-------------Success-------------")

In [4]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from utils import create_folder, dump_config
import us8k_config as config
from sed_model import SEDWrapper
from data_generator import UrbanSound8k_Dataset
from model.htsat import HTSAT_Swin_Transformer

In [5]:
class data_prep(pl.LightningDataModule):
    """LightningDataModule serving the UrbanSound8K training and evaluation datasets.

    Validation and test both read ``eval_dataset``. With more than one device each
    loader is sharded with a non-shuffling DistributedSampler.
    """

    def __init__(self, train_dataset, eval_dataset, device_num):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num

    def _make_loader(self, dataset):
        """Build the (non-shuffling) DataLoader shared by train/val/test for ``dataset``."""
        # NOTE(review): shuffling is disabled for training as well; presumably the
        # dataset orders/shuffles its own samples per epoch — confirm in data_generator.
        sampler = DistributedSampler(dataset, shuffle=False) if self.device_num > 1 else None
        return DataLoader(
            dataset=dataset,
            num_workers=config.num_workers,
            # config.batch_size is split across devices; max(..., 1) avoids a
            # ZeroDivisionError when no GPU is visible (device_num == 0).
            batch_size=config.batch_size // max(self.device_num, 1),
            shuffle=False,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.eval_dataset)

    def test_dataloader(self):
        return self._make_loader(self.eval_dataset)

In [None]:
device_num = torch.cuda.device_count()
# device_num is 0 on a machine without a visible GPU; guard the division.
print("each batch size:", config.batch_size // max(device_num, 1))

# The .npy file is a pickled object array written by the dataset-building cell; only
# load it from that trusted local source (allow_pickle can execute arbitrary code).
full_dataset = np.load(os.path.join(config.dataset_path, "UrbanSound8K-data.npy"), allow_pickle = True)

results_dir = os.path.join(config.workspace, "results")
exp_dir = os.path.join(results_dir, config.exp_name)
checkpoint_dir = os.path.join(exp_dir, "checkpoint")
if not config.debug:
    for folder in (results_dir, exp_dir, checkpoint_dir):
        create_folder(folder)
    dump_config(config, os.path.join(exp_dir, config.exp_name), False)

print("Using UrbanSound8K")
dataset = UrbanSound8k_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = False
)
eval_dataset = UrbanSound8k_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = True
)

audioset_data = data_prep(dataset, eval_dataset, device_num)
# Keep the 20 checkpoints with the highest logged "acc" metric.
checkpoint_callback = ModelCheckpoint(
    monitor = "acc",
    filename='l-{epoch:d}-{acc:.3f}',
    save_top_k = 20,
    mode = "max"
)

In [None]:
trainer = pl.Trainer(
    deterministic=False,
    default_root_dir = checkpoint_dir,
    gpus = device_num,
    val_check_interval = 1.0,
    max_epochs = config.max_epoch,
    auto_lr_find = True,
    sync_batchnorm = True,
    callbacks = [checkpoint_callback],
    accelerator = "ddp" if device_num > 1 else None,
    num_sanity_val_steps = 0,
    resume_from_checkpoint = None,
    # data_prep builds its own DistributedSampler, so Lightning must not replace it.
    replace_sampler_ddp = False,
    gradient_clip_val=1.0
)

sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config = config,
    depths = config.htsat_depth,
    embed_dim = config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head
)

model = SEDWrapper(
    sed_model = sed_model,
    config = config,
    dataset = dataset
)

# Output-layer parameters removed from the resumed checkpoint so they keep their fresh
# initialisation. NOTE(review): presumably their shapes depend on num_classes and differ
# between the pretraining checkpoint and this dataset — confirm.
HEAD_KEYS = (
    "sed_model.head.weight",
    "sed_model.head.bias",
    "sed_model.tscam_conv.weight",
    "sed_model.tscam_conv.bias",
)

if config.resume_checkpoint is not None:
    print("Load Checkpoint from ", config.resume_checkpoint)
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    state_dict = ckpt["state_dict"]
    for key in HEAD_KEYS:
        # pop with a default: a checkpoint lacking one of these keys is not an error.
        state_dict.pop(key, None)
    # strict=False ignores mismatched names silently; report them so a wrong checkpoint is visible.
    load_result = model.load_state_dict(state_dict, strict=False)
    print("missing keys:", load_result.missing_keys)
    print("unexpected keys:", load_result.unexpected_keys)

In [None]:
# Train (and validate each epoch) using the loaders provided by the data module.
trainer.fit(model=model, datamodule=audioset_data)

In [8]:
model_path = './checkpoint/US8K-acc=0.891.ckpt'

# Ground-truth lookup from the metadata CSV: wav file name (column 0) -> class ID (column 6).
meta = np.loadtxt(meta_path, delimiter=',', dtype='str', skiprows=1)
# Store class IDs as int (matching the int targets used for training) so they compare
# equal to the integer labels returned by the classifier.
gd = {row[0]: int(row[6]) for row in meta}

class Audio_Classification:
    """Single-file inference wrapper around an HTS-AT model loaded from a Lightning checkpoint."""

    # Prefix the Lightning wrapper puts in front of the HTS-AT parameter names.
    _CKPT_PREFIX = "sed_model."

    def __init__(self, model_path, config):
        """Build HTS-AT from ``config`` and load the weights stored at ``model_path``.

        The model is put in eval mode on CUDA when available, otherwise on CPU.
        """
        super().__init__()

        # NOTE(review): CPU fallback added for machines without a GPU; confirm HTS-AT has no CUDA-only ops.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config = config,
            depths = config.htsat_depth,
            embed_dim = config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head
        )
        ckpt = torch.load(model_path, map_location="cpu")
        self.sed_model.load_state_dict(self._strip_prefix(ckpt["state_dict"], self._CKPT_PREFIX))
        self.sed_model.to(self.device)
        self.sed_model.eval()

    @staticmethod
    def _strip_prefix(state_dict, prefix="sed_model."):
        """Return a copy of ``state_dict`` with ``prefix`` removed from every key that starts with it.

        Keys without the prefix are kept unchanged, so a strict load reports them by name
        instead of receiving truncated, meaningless keys.
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def predict(self, audiofile):
        """Classify a single audio file.

        Args:
            audiofile: path to an audio file readable by librosa.

        Returns:
            ``(pred_label, pred_prob)``: index of the highest ``clipwise_output`` score
            and that score.

        Raises:
            ValueError: if ``audiofile`` is empty or None.
        """
        if not audiofile:
            raise ValueError("audiofile must be a non-empty path")
        # 44.1 kHz mono, matching the sox resampling used to build the training data.
        waveform, _ = librosa.load(audiofile, sr=44100)
        waveform = librosa.to_mono(waveform)
        with torch.no_grad():
            x = torch.from_numpy(waveform).float().to(self.device)
            # NOTE(review): the positional (None, True) are presumably (mixup_lambda,
            # infer_mode) — confirm against HTSAT_Swin_Transformer.forward.
            output_dict = self.sed_model(x[None, :], None, True)
            pred_post = output_dict['clipwise_output'][0].cpu().numpy()
        return np.argmax(pred_post), np.max(pred_post)

In [None]:
# Classify one known clip and print the prediction next to its ground-truth class ID.
sample_name = "13579-2-0-15.wav"
sample_path = './workspace/UrbanSound8k/raw/UrbanSound8K/audio/fold9/' + sample_name

Audiocls = Audio_Classification(model_path, config)
pred_label, pred_prob = Audiocls.predict(sample_path)
print('Audiocls predict output: ', pred_label, pred_prob, gd[sample_name])