In [4]:
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import pandas as pd
import os
from tqdm import tqdm
from glob import glob

import torch.jit as jit
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision.models import get_model

In [57]:
class Config:
    duration = 10
    sample_rate = 32000
    target_length = 384
    n_mels = 128
    n_fft = 2028
    window = 2028
    audio_len = duration*sample_rate
    hop_length = audio_len // (target_length-1)
    fmin = 20
    fmax = 16000
    top_db = 80

    n_classes = 182
    model_name = 'efficientnet_v2_s'
    checkpoint = 'checkpoints/2024-05-05_17-31-54_fold-0_dim-128x384_model-efficientnet_v2_s/checkpoint.pth'

    standardize = False
    dataset_mean = [-16.8828]
    dataset_std = [12.4019]

    loss = 'crossentropy'

In [58]:
def create_frames(waveform, duration=5, sr=32000):
    frame_size = int(duration * sr)
    surplus = waveform.size(-1)%frame_size
    if surplus > 0:
        waveform = waveform[:, :-surplus]
    frames = waveform.view(-1, 1, frame_size)
    return frames

class AudioDatasetInference(Dataset):
    def __init__(
            self, 
            files,
            targets = None, 
            n_classes = 182,
            duration = 5,
            sample_rate = 32000,
            target_length = 384,
            n_mels = 128,
            n_fft = 2028,
            window = 2028,
            hop_length = None,
            fmin = 20,
            fmax = 16000,
            top_db = 80,
            standardize=True,
            mean=None,
            std=None
            ):
        super(AudioDatasetInference, self).__init__()
        self.files = files
        self.targets = targets
        self.n_classes = n_classes
        self.duration = duration
        self.sample_rate = sample_rate
        self.audio_len = duration*sample_rate
        self.target_length = target_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.window = window
        self.hop_length = self.audio_len // (target_length-1) if not hop_length else hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.top_db = top_db
        self.standardize = standardize

        self.to_mel_spectrogramn = nn.Sequential(
            torchaudio.transforms.MelSpectrogram(self.sample_rate, n_fft=self.n_fft, win_length=self.window,  
                                                 hop_length=self.hop_length, n_mels=self.n_mels, 
                                                 f_min=self.fmin, f_max=self.fmax),
            torchaudio.transforms.AmplitudeToDB(top_db=self.top_db)
        )
        if mean is not None:
            self.to_mel_spectrogramn.append(v2.Normalize(mean=mean, std=std))

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        if self.targets is not None:
            label = torch.tensor(self.targets[idx])

        file = self.files[idx]
        waveform, sr = torchaudio.load(file)
        frames = create_frames(waveform)
        spec = self.to_mel_spectrogramn(frames)
        # Standardize
        if self.standardize:
            spec = (spec - spec.mean()) / spec.std()

        # expand to 3 channels for imagenet trained models
        spec = spec.expand(-1,3,-1,-1)

        if self.targets is not None:
            return spec, label
        else:
            return spec, file

In [59]:
class BasicClassifier(jit.ScriptModule):
    def __init__(self, n_classes, model_name, pretrained=True):
        super(BasicClassifier, self).__init__()
        weights = 'DEFAULT' if pretrained else None
        self.backbone = get_model(model_name, weights=weights).features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2, inplace=True),
            nn.Linear(1280, n_classes)
            )
        
    @jit.script_method
    def forward(self, x):
        x = self.backbone(x)
        x = self.pool(x).squeeze(dim=(-1,-2))
        x = self.classifier(x)
        return x

In [60]:
base_dir = 'data'
train_dir = base_dir + '/train_audio/'
test_dir = base_dir + '/test_soundscapes/'
unlabeled_dir = base_dir + '/unlabeled_soundscapes/'

class_names = sorted(os.listdir(train_dir))
n_classes = len(class_names)
class_labels = list(range(n_classes))
label2name = dict(zip(class_labels, class_names))
name2label = {v:k for k,v in label2name.items()}

In [61]:
test_paths = glob(base_dir + '/test_soundscapes/*ogg')
if len(test_paths)==0:
    test_paths = glob(base_dir + '/unlabeled_soundscapes/*ogg')[:10]
test_df = pd.DataFrame(test_paths, columns=['filepath'])
test_df.head()

Unnamed: 0,filepath
0,data/unlabeled_soundscapes/2003908684.ogg
1,data/unlabeled_soundscapes/514918048.ogg
2,data/unlabeled_soundscapes/62932451.ogg
3,data/unlabeled_soundscapes/735779963.ogg
4,data/unlabeled_soundscapes/1097804416.ogg


In [62]:
test_dataset = AudioDatasetInference(
    test_df['filepath'].values, 
    targets=None, 
    n_classes=Config.n_classes,
    duration=5,
    sample_rate=Config.sample_rate,
    target_length=Config.target_length,
    n_mels=Config.n_mels,
    n_fft=Config.n_fft,
    window=Config.window,
    hop_length=Config.hop_length,
    fmin=Config.fmin,
    fmax=Config.fmax,
    top_db=Config.top_db,
    standardize=Config.standardize,
    mean=Config.dataset_mean,
    std=Config.dataset_std
    )

In [63]:
device = torch.device('cpu')

model = BasicClassifier(n_classes, pretrained=False, model_name=Config.model_name).to(device)
checkpoint_name = Config.checkpoint
checkpoint = torch.load(checkpoint_name, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model = torch.jit.optimize_for_inference(torch.jit.script(model.eval()))

In [64]:
checkpoint['model']

OrderedDict([('backbone.0.0.weight',
              tensor([[[[-3.4210e-01, -2.2651e-02,  3.9878e-01],
                        [-6.6069e-01, -8.1597e-02,  7.9850e-01],
                        [-6.8869e-01, -1.1586e-01,  7.2336e-01]],
              
                       [[-5.4256e-01, -7.1270e-02,  5.4674e-01],
                        [-9.2076e-01, -1.2037e-01,  1.0793e+00],
                        [-9.0337e-01, -9.9570e-02,  1.0339e+00]],
              
                       [[-1.8140e-01, -3.6956e-02,  2.1198e-01],
                        [-4.0819e-01, -1.9971e-02,  5.9507e-01],
                        [-4.5933e-01, -1.3819e-01,  4.7829e-01]]],
              
              
                      [[[ 4.3161e-02,  3.2609e-02, -3.3290e-02],
                        [ 5.9996e-02, -1.0217e-02, -3.8054e-02],
                        [ 5.2515e-02, -1.1411e-02, -6.8200e-02]],
              
                       [[-1.0465e-02, -1.3579e-01, -1.2469e-01],
                        [-1.2292e-01, 

In [66]:
checkpoint['model']['backbone.6.14.block.3.1.running_var'].dtype

torch.float32

In [28]:
ids = []
preds = np.empty(shape=(0, n_classes), dtype='float32')

test_iter = tqdm(range(len(test_dataset)))
for i in test_iter:
    specs, file = test_dataset[i]
    filename = file.split('/')[-1][:-4]
    specs = specs.to(device)
    
    with torch.no_grad():
        outs = model(specs)
        if Config.loss == 'crossentropy':
            outs = nn.functional.softmax(outs, dim=1).detach().cpu().numpy()
        elif Config.loss == 'bce':
            outs = outs.sigmoid().detach().cpu().numpy()

    frame_ids = [f'{filename}_{(frame_id+1)*5}' for frame_id in range(len(specs))]
    ids += frame_ids

    preds = np.concatenate([preds, outs], axis=0)

 50%|█████     | 5/10 [00:08<00:08,  1.61s/it]


KeyboardInterrupt: 

In [18]:
# Submit prediction
pred_df = pd.DataFrame(ids, columns=['row_id'])
pred_df.loc[:, class_names] = preds
pred_df.to_csv('submission.csv',index=False)

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,XC756601_5,0.005847,0.011639,0.002512,0.00011,0.01907,1.5e-05,0.000609,0.000143,6.9e-05,...,0.000811,0.000281,0.001206,0.000926,0.005042,0.010253,5.2e-05,0.000306,3.2e-05,0.004853
1,XC756601_10,0.008243,0.007165,0.001763,9.2e-05,0.018658,1.1e-05,0.000512,0.000136,0.000104,...,0.000528,0.000333,0.002916,0.001131,0.00492,0.010568,6.1e-05,0.000224,2.8e-05,0.005552
2,XC756601_15,0.010177,0.009778,0.003066,0.000166,0.017798,2e-05,0.000657,0.000119,0.000143,...,0.000895,0.000277,0.001253,0.001104,0.007026,0.00824,7.4e-05,0.000242,3.9e-05,0.006269
3,XC756601_20,0.006208,0.006755,0.002717,0.00015,0.014288,2.4e-05,0.000561,0.000175,0.000171,...,0.000782,0.000253,0.001888,0.001559,0.007132,0.009607,5.6e-05,0.000403,2.3e-05,0.012148
4,XC756601_25,0.006516,0.006837,0.001713,0.000164,0.020643,1.6e-05,0.000494,0.000309,8.7e-05,...,0.000877,0.000263,0.001813,0.001479,0.005192,0.010191,6e-05,0.000307,1.8e-05,0.007364
