In [1]:
# !aria2c -x8 https://zenodo.org/record/1203745/files/UrbanSound8K.tar.gz
# tar -xvzf UrbanSound8K.tar.gz

In [2]:

from torch.utils.data import Dataset
from tqdm import tqdm
from torchaudio.functional import resample
import torch
from torch.utils.data import DataLoader
import torchaudio.transforms as T
from torch import nn
import torch.nn.functional as F
import os
import pandas as pd
import torchaudio

In [3]:
class UrbanSound8KDataset(Dataset):
    def __init__(self, 
                 root='/workspace/postdoc/sound_datasets/urbansound8k/UrbanSound8K', # folder where the UrbanSound8K dataset is stored after extraction
                 folds=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # this will load all folds
                 resample_freq=16000 # resample frequency
        ):
        self.root = root
        meta = pd.read_csv(os.path.join(root, 'metadata/UrbanSound8K.csv'))
        self.folded = meta[meta.fold.isin(folds)]
        self.resample_freq = resample_freq

    def __len__(self):
        return len(self.folded)

    def __getitem__(self, idx):
        row = self.folded.iloc[idx]
        file_path = os.path.join(self.root, 'audio', f'fold{row.fold}', row.slice_file_name)
        waveform, sample_rate = torchaudio.load(file_path)
        
        # single channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # resample
        waveform = resample(waveform, orig_freq=sample_rate, new_freq=16000)
        # normalize
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
        
        return {'audio': waveform.squeeze(),  # move it to collate_fn to batch processing and speedup
                'class_id': row["classID"], 
                'class_label': row["class"], 
                'fold': row["fold"]}

In [4]:
def collate_fn(batch):
    waveforms = [item['audio'] for item in batch]
    batched_waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True, padding_value=0)
    return {'audio': batched_waveforms,
            'class_id': torch.tensor([item['class_id'] for item in batch]), 
            'class_label': [item['class_label'] for item in batch], 
            'fold': torch.tensor([item['fold'] for item in batch])}

In [5]:
train_dataset = UrbanSound8KDataset(folds=[1, 2, 3, 4, 5, 6, 7, 8, 9])
val_dataset = UrbanSound8KDataset(folds=[10])

In [7]:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=1)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True, num_workers=1)

In [8]:
class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_fft: int = 400,
        win_length: int = 400,  # 25ms
        hop_length: int = 160,  # 10ms
        n_mels: int = 80,
        sample_rate: int = 16000,
    ):
        super().__init__()
        self.spectogram = T.Spectrogram(n_fft, win_length, hop_length, power=2)
        self.mel_scale = T.MelScale(n_mels, sample_rate, n_stft=n_fft // 2 + 1, mel_scale="slaney", norm="slaney")

    def forward(self, waveforms):
        spec = self.spectogram(waveforms)
        mel_spec = self.mel_scale(spec.float())        
        # the following is doing AmplitudeToDB conversion, but it does not match torchaudios implementation
        log_spec = torch.log10(torch.clamp(mel_spec, min=1e-10))
        log_spec = torch.max(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
    
mel_spectrogram_fn = MelSpectrogram()

In [9]:
# a simple audio classifier with 92K params, we can use more sophisticated models later

class AudioClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.convnet = nn.Sequential(
            # 1-D conv across the time dimension
            nn.Conv2d(1, 32, kernel_size=(1, 5)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=(1, 3)),
            nn.Dropout(0.1),

            nn.Conv2d(32, 32, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
            nn.Dropout(0.1),

            nn.Conv2d(32, 32, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
            nn.Dropout(0.1),

            nn.Conv2d(32, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
            nn.Dropout(0.1),

            nn.Conv2d(32, 32, kernel_size=(3, 3)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
            nn.Dropout(0.1),

            nn.Flatten(),

            nn.Linear(160, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.1),

            nn.Linear(128, num_classes),
            nn.LogSoftmax(dim=1)
        )
        
    def forward(self, x):
        x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
        x = self.convnet(x)
        return x

def num_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

clf = AudioClassifier(num_classes=10)
print(f"Total number of parameters: {num_parameters(clf) / 1e3}k")

Total number of parameters: 92.426k


In [10]:
optim = torch.optim.AdamW(clf.parameters(), lr=1e-4, weight_decay=1e-3)

# select gpu 0
device = torch.device("cuda:0")
clf.to(device)
mel_spectrogram_fn.to(device)

for epoch in range(40):
    clf.train()
    ys = []
    y_hats = []
    progress_bar = tqdm(train_loader, desc=f"Training epoch  {epoch}")
    for batch in progress_bar:
        optim.zero_grad()
        x = batch['audio'].to(device)
        y = batch['class_id'].to(device)
        ys.append(y)
        with torch.no_grad():
            spec = mel_spectrogram_fn(x)
        y_hat = clf(spec)
        y_hats.append(y_hat)
        loss = F.nll_loss(y_hat, y)
        progress_bar.update()
        progress_bar.set_postfix(loss=loss.item())
        loss.backward()
        optim.step()
    ys = torch.cat(ys)
    y_hats = torch.cat(y_hats)
    acc = (ys == y_hats.argmax(dim=1)).float().mean()
    print(f"Training accuracy:  {acc.item() * 100:.2f}%")
    
    # evaluation
    clf.eval()
    ys = []
    y_hats = []
    progress_bar = tqdm(val_loader, desc=f"Validating epoch {epoch}")
    for batch in progress_bar:
        x = batch['audio'].to(device)
        y = batch['class_id'].to(device)
        ys.append(y)
        with torch.no_grad():
            spec = mel_spectrogram_fn(x)
            y_hat = clf(spec)
        y_hats.append(y_hat)
        progress_bar.update()
    ys = torch.cat(ys)
    y_hats = torch.cat(y_hats)
    acc = (ys == y_hats.argmax(dim=1)).float().mean()
    print(f"Validation accuracy: {acc.item() * 100:.2f}%")


Training epoch  0:   0%|          | 0/247 [00:00<?, ?it/s]

Training epoch  0: 100%|██████████| 247/247 [00:47<00:00,  5.18it/s, loss=1.7] 


Training accuracy: 26.79%


Validating epoch 0: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 30.82%


Training epoch  1: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=1.19]


Training accuracy: 43.52%


Validating epoch 1: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 41.70%


Training epoch  2: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=1.47] 


Training accuracy: 51.44%


Validating epoch 2: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 41.94%


Training epoch  3: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=1.47] 


Training accuracy: 56.34%


Validating epoch 3: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 48.98%


Training epoch  4: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=1.11] 


Training accuracy: 60.34%


Validating epoch 4: 100%|██████████| 27/27 [00:05<00:00,  5.38it/s]


Validation accuracy: 50.66%


Training epoch  5: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.899]


Training accuracy: 64.14%


Validating epoch 5: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 51.37%


Training epoch  6: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=1.22] 


Training accuracy: 65.99%


Validating epoch 6: 100%|██████████| 27/27 [00:04<00:00,  5.44it/s]


Validation accuracy: 48.63%


Training epoch  7: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=1.11] 


Training accuracy: 68.23%


Validating epoch 7: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 56.99%


Training epoch  8: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=0.962]


Training accuracy: 70.07%


Validating epoch 8: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 61.05%


Training epoch  9: 100%|██████████| 247/247 [00:46<00:00,  5.27it/s, loss=1.13] 


Training accuracy: 72.05%


Validating epoch 9: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 56.03%


Training epoch  10: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.63] 


Training accuracy: 73.25%


Validating epoch 10: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 59.02%


Training epoch  11: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.913]


Training accuracy: 74.60%


Validating epoch 11: 100%|██████████| 27/27 [00:04<00:00,  5.40it/s]


Validation accuracy: 65.23%


Training epoch  12: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.673]


Training accuracy: 75.06%


Validating epoch 12: 100%|██████████| 27/27 [00:04<00:00,  5.44it/s]


Validation accuracy: 61.53%


Training epoch  13: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.38] 


Training accuracy: 76.80%


Validating epoch 13: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 61.05%


Training epoch  14: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=0.57] 


Training accuracy: 77.15%


Validating epoch 14: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 57.83%


Training epoch  15: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.83] 


Training accuracy: 78.92%


Validating epoch 15: 100%|██████████| 27/27 [00:04<00:00,  5.44it/s]


Validation accuracy: 61.17%


Training epoch  16: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.402]


Training accuracy: 79.71%


Validating epoch 16: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 67.14%


Training epoch  17: 100%|██████████| 247/247 [00:47<00:00,  5.17it/s, loss=0.755]


Training accuracy: 80.27%


Validating epoch 17: 100%|██████████| 27/27 [00:05<00:00,  5.33it/s]


Validation accuracy: 65.95%


Training epoch  18: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.479]


Training accuracy: 81.28%


Validating epoch 18: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 64.64%


Training epoch  19: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.577]


Training accuracy: 81.68%


Validating epoch 19: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 68.46%


Training epoch  20: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=0.566]


Training accuracy: 82.27%


Validating epoch 20: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 62.49%


Training epoch  21: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.341]


Training accuracy: 83.12%


Validating epoch 21: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 69.41%


Training epoch  22: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.501]


Training accuracy: 83.27%


Validating epoch 22: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 69.18%


Training epoch  23: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.662]


Training accuracy: 83.98%


Validating epoch 23: 100%|██████████| 27/27 [00:04<00:00,  5.40it/s]


Validation accuracy: 73.00%


Training epoch  24: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.465]


Training accuracy: 85.04%


Validating epoch 24: 100%|██████████| 27/27 [00:04<00:00,  5.44it/s]


Validation accuracy: 71.57%


Training epoch  25: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.457]


Training accuracy: 84.85%


Validating epoch 25: 100%|██████████| 27/27 [00:04<00:00,  5.40it/s]


Validation accuracy: 76.70%


Training epoch  26: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.58] 


Training accuracy: 85.16%


Validating epoch 26: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 69.65%


Training epoch  27: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.727]


Training accuracy: 86.02%


Validating epoch 27: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 77.18%


Training epoch  28: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.438]


Training accuracy: 86.32%


Validating epoch 28: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]


Validation accuracy: 74.43%


Training epoch  29: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.208]


Training accuracy: 86.05%


Validating epoch 29: 100%|██████████| 27/27 [00:04<00:00,  5.47it/s]


Validation accuracy: 75.51%


Training epoch  30: 100%|██████████| 247/247 [00:46<00:00,  5.30it/s, loss=0.296]


Training accuracy: 87.02%


Validating epoch 30: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 79.09%


Training epoch  31: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.363]


Training accuracy: 87.47%


Validating epoch 31: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 70.49%


Training epoch  32: 100%|██████████| 247/247 [00:46<00:00,  5.31it/s, loss=0.162]


Training accuracy: 88.55%


Validating epoch 32: 100%|██████████| 27/27 [00:04<00:00,  5.41it/s]


Validation accuracy: 73.48%


Training epoch  33: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.295]


Training accuracy: 88.17%


Validating epoch 33: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 70.13%


Training epoch  34: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=0.2]   


Training accuracy: 88.46%


Validating epoch 34: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 71.92%


Training epoch  35: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.408]


Training accuracy: 88.49%


Validating epoch 35: 100%|██████████| 27/27 [00:04<00:00,  5.45it/s]


Validation accuracy: 74.31%


Training epoch  36: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.567] 


Training accuracy: 89.15%


Validating epoch 36: 100%|██████████| 27/27 [00:04<00:00,  5.44it/s]


Validation accuracy: 69.41%


Training epoch  37: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.245]


Training accuracy: 89.01%


Validating epoch 37: 100%|██████████| 27/27 [00:05<00:00,  5.37it/s]


Validation accuracy: 77.18%


Training epoch  38: 100%|██████████| 247/247 [00:46<00:00,  5.28it/s, loss=0.231]


Training accuracy: 89.82%


Validating epoch 38: 100%|██████████| 27/27 [00:04<00:00,  5.42it/s]


Validation accuracy: 71.45%


Training epoch  39: 100%|██████████| 247/247 [00:46<00:00,  5.29it/s, loss=0.196]


Training accuracy: 89.68%


Validating epoch 39: 100%|██████████| 27/27 [00:04<00:00,  5.43it/s]

Validation accuracy: 75.15%



