In [2]:
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import pandas as pd
import os
import gc
from tqdm import tqdm
from glob import glob
import time
import scipy
from functools import partial
from scipy import signal
from scipy.ndimage import gaussian_filter1d

import torch.jit as jit
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision.models import get_model
import timm

import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
from concurrent import futures

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class Config:
    use_1_peak = True
    peak_filter = 'none'
    use_peaks = False
    duration = 5
    sample_rate = 32000
    target_length = 500 #!!!!!!!!!!!!
    n_mels = 128 #!!!!!!!!!!!!!!!
    n_fft = 1024
    window = 160 #!!!!!!!!!!!!!
    audio_len = duration*sample_rate
    hop_length = 64 #!!!!!!!!!!
    fmin = 50
    fmax = 16000
    top_db = 80

    n_classes = 182
    n_channels = 1 #!!!!!!!!!!!
    
    use_openvino = True
    multithreading = False
    checkpoint_dir = 'checkpoints/2024-06-03_11-50-16_128x500_mn20_as_fold-2'
    loss = 'crossentropy'
    ensemble_checkpoints = ['checkpoints/2024-05-23_00-55-30_256x256_convnextv2_tiny.fcmae_ft_in22k_in1k_fold-0',
                            'checkpoints/2024-05-23_00-55-30_256x256_convnextv2_tiny.fcmae_ft_in22k_in1k_fold-2',
                            'checkpoints/2024-05-23_00-55-30_256x256_convnextv2_tiny.fcmae_ft_in22k_in1k_fold-3'
                            ]
    ensemble_losses = ['crossentropy', 'crossentropy', 'crossentropy']

    standardize = False
    dataset_mean = [-16.8828]
    dataset_std = [12.4019]

In [46]:
def create_frames(waveform, duration=5, sr=32000):
    frame_size = int(duration * sr)
    surplus = waveform.size(-1)%frame_size
    if surplus > 0:
        waveform = waveform[:, :-surplus]
    frames = waveform.view(-1, 1, frame_size)
    return frames

def find_peaks_max(x, filter='savgol'):
    if filter == 'savgol':
        smooth_x = signal.savgol_filter(x, window_length=100, polyorder=2)
    elif filter == 'gaussian':
        smooth_x = gaussian_filter1d(x, sigma=25)
    else:
        smooth_x = x
    return smooth_x.argmax(axis=-1)

def window_around_peak(len_x, peak, window_size):
    half_window = window_size // 2
    start_index = max(0, peak - half_window)
    end_index = min(len_x, peak + half_window)

    # Adjust the window if it's too close to the borders
    if end_index - start_index < window_size:
        if start_index == 0:
            end_index = min(len_x, start_index + window_size)
        elif end_index == len_x:
            start_index = max(0, end_index - window_size)
    return start_index, end_index

class AudioDatasetInference(Dataset):
    def __init__(
            self, 
            files,
            cfg,
            targets = None
            ):
        super(AudioDatasetInference, self).__init__()
        self.files = files
        self.targets = targets
        self.n_classes = cfg.n_classes
        self.duration = cfg.duration
        self.sample_rate = cfg.sample_rate
        self.audio_len = self.duration*self.sample_rate
        self.target_length = cfg.target_length
        self.n_mels = cfg.n_mels
        self.n_fft = cfg.n_fft
        self.window = cfg.window
        self.hop_length = cfg.hop_length
        self.fmin = cfg.fmin
        self.fmax = cfg.fmax
        self.top_db = cfg.top_db
        self.standardize = cfg.standardize
        self.mean = cfg.dataset_mean
        self.std = cfg.dataset_std
        self.n_channels = cfg.n_channels
        self.use_1_peak = cfg.use_1_peak
        self.use_peaks = cfg.use_peaks
        self.peak_filter = cfg.peak_filter

        self.to_mel_spectrogramn = torchaudio.transforms.MelSpectrogram(self.sample_rate, n_fft=self.n_fft, win_length=self.window,  
                                                 hop_length=self.hop_length, n_mels=self.n_mels, 
                                                 f_min=self.fmin, f_max=self.fmax)

        self.mel_to_db = nn.Sequential(torchaudio.transforms.AmplitudeToDB(top_db=self.top_db))

        if self.mean is not None and self.std is not None:
            self.mel_to_db.append(v2.Normalize(mean=self.mean, std=self.std))

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        if self.targets is not None:
            label = torch.tensor(self.targets[idx])

        file = self.files[idx]
        waveform, sr = torchaudio.load(file)
        frames = create_frames(waveform)
        spec = self.to_mel_spectrogramn(frames)

        if self.use_1_peak:
            per_frame_energy = spec.sum(dim=-2).squeeze().numpy()
            peaks = find_peaks_max(per_frame_energy, filter=self.peak_filter)
            new_spec = torch.empty((spec.size(0), self.n_channels, self.n_mels, self.target_length))
            for p in range(len(peaks)):
                start_index, end_index = window_around_peak(per_frame_energy.shape[-1], peaks[p], self.target_length)
                new_spec[p] = spec[p,:,:,start_index:end_index]
        
        elif self.use_peaks:
            per_frame_energy = spec.sum(dim=1).squeeze().numpy()
            peak1 = find_peak_max(per_frame_energy, filter=self.peak_filter)
            start_index, end_index = window_around_peak(len(per_frame_energy), peak, self.target_length)
            spec1 = spec[:,:,start_index:end_index]

        spec = self.mel_to_db(new_spec)

        # Standardize
        if self.standardize:
            spec = (spec - spec.mean()) / spec.std()

        # expand to 3 channels for imagenet trained models
        spec = spec.expand(-1, self.n_channels,-1,-1)

        if self.targets is not None:
            return spec, label
        else:
            return spec, file

In [47]:
base_dir = 'data'
train_dir = base_dir + '/train_audio/'
test_dir = base_dir + '/test_soundscapes/'
unlabeled_dir = base_dir + '/unlabeled_soundscapes/'

class_names = sorted(os.listdir(train_dir))
n_classes = len(class_names)
class_labels = list(range(n_classes))
label2name = dict(zip(class_labels, class_names))
name2label = {v:k for k,v in label2name.items()}

In [48]:
test_paths = glob(base_dir + '/test_soundscapes/*ogg')
if len(test_paths)==0:
    test_paths = glob(base_dir + '/unlabeled_soundscapes/*ogg')[:10]
test_df = pd.DataFrame(test_paths, columns=['filepath'])
test_df.head()

Unnamed: 0,filepath
0,data/unlabeled_soundscapes/646255149.ogg
1,data/unlabeled_soundscapes/1171835482.ogg
2,data/unlabeled_soundscapes/1590789246.ogg
3,data/unlabeled_soundscapes/115033522.ogg
4,data/unlabeled_soundscapes/1971688290.ogg


In [49]:
test_dataset = AudioDatasetInference(
    test_df['filepath'].values, 
    targets=None, 
    cfg=Config
    )

In [53]:
if Config.multithreading:
    def predict(specs, infer_request, final_activation):
        sample_preds = np.empty(shape=(0, n_classes), dtype='float32')
        start_time = time.time()
        outs = infer_request.infer([specs])[0]
        outs = final_activation(outs)
        model_time = time.time()-start_time
        sample_preds = np.concatenate([sample_preds, outs], axis=0)
        return sample_preds, model_time

    def helper(inputs):
        return predict(inputs[0], inputs[1], inputs[2])

    
    def get_model(model_id):
        core = ov.Core()
        checkpoint_ov = Config.ensemble_checkpoints[model_id] + '/checkpoint.xml'
        loss = Config.ensemble_losses[model_id]
        config = {hints.performance_mode: hints.PerformanceMode.THROUGHPUT}
        model = core.compile_model(checkpoint_ov, "CPU", config)
        infer_request = model.create_infer_request()
        return infer_request
    
    def get_final_activation(model_id):
        loss = Config.ensemble_losses[model_id]
        if loss == 'crossentropy':
            final_activation = partial(scipy.special.softmax, axis=1)
        elif loss == 'bce':
            final_activation = scipy.special.expit
        return final_activation
        

    start=time.time()

    models = [get_model(model_id) for model_id in range(len(Config.ensemble_checkpoints))]
    f_activations = [get_final_activation(model_id) for model_id in range(len(Config.ensemble_checkpoints))]
    
    preds = np.empty(shape=(0, n_classes), dtype='float32')
    ids = []
    ensemble_preds = np.empty(shape=(0, n_classes), dtype='float32')
    for i in range(len(test_dataset)):
        specs, file = test_dataset[i]
        filename = file.split('/')[-1][:-4]
        frame_ids = [f'{filename}_{(frame_id+1)*5}' for frame_id in range(len(specs))]
        ids += frame_ids
        
        ensemble_preds = []
        list_inputs = [(specs, models[k], f_activations[k]) for k in range(len(models))]
        with futures.ThreadPoolExecutor(max_workers=len(Config.ensemble_checkpoints)) as executor:
            for sample_preds, model_time in executor.map(helper, list_inputs):
                ensemble_preds.append(sample_preds)
                #print('model', model_time)
        ensemble_preds = np.array(ensemble_preds)
        #ensemble_preds = ensemble_preds.mean(axis=0)
        ensemble_preds = (ensemble_preds**2).mean(axis=0) ** 0.5
        preds = np.concatenate([preds, ensemble_preds], axis=0)

    print(time.time()-start)

In [55]:
if Config.use_openvino:
    start=time.time()
    
    checkpoint_ov = Config.checkpoint_dir + '/checkpoint.xml'
    config = {hints.performance_mode: hints.PerformanceMode.THROUGHPUT}
    core = ov.Core()
    model = core.compile_model(checkpoint_ov, "AUTO", config)


    ids = []
    preds = np.empty(shape=(0, n_classes), dtype='float32')
    output_layer = model.output(0)
    if Config.loss == 'crossentropy':
        final_activation = partial(scipy.special.softmax, axis=1)
    elif Config.loss == 'bce':
        final_activation = scipy.special.expit

    test_iter = tqdm(range(len(test_dataset)))
    for i in test_iter:
        #start_sample_time = time.time()
        specs, file = test_dataset[i]
        filename = file.split('/')[-1][:-4]
        #data_time = time.time()
        #print("data", data_time-start_sample_time)
        
        outs = model([specs])[output_layer]
        outs = final_activation(outs)
        #model_time = time.time()
        #print("model", model_time-data_time)

        frame_ids = [f'{filename}_{(frame_id+1)*5}' for frame_id in range(len(specs))]
        ids += frame_ids

        preds = np.concatenate([preds, outs], axis=0)
        #end_time = time.time()
        #print("end", end_time-model_time)

    print(time.time()-start)

100%|██████████| 10/10 [00:19<00:00,  1.96s/it]

20.404430866241455





In [54]:
if not Config.use_openvino and not Config.multithreading:
    device = torch.device('cpu')
    checkpoint_name = Config.checkpoint_dir + '/checkpoint.pth'
    model = src.models.BasicClassifier(Config.n_classes, pretrained=False, model_name=Config.model_name).to(device)
    checkpoint = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model = torch.jit.optimize_for_inference(torch.jit.script(model.eval()))

    ids = []
    preds = np.empty(shape=(0, n_classes), dtype='float32')

    test_iter = tqdm(range(len(test_dataset)))
    for i in test_iter:
        specs, file = test_dataset[i]
        filename = file.split('/')[-1][:-4]
        specs = specs.to(device)

        with torch.no_grad():
            outs = model(specs)
            if Config.loss == 'crossentropy':
                outs = nn.functional.softmax(outs, dim=1).detach().cpu().numpy()
            elif Config.loss == 'bce':
                outs = outs.sigmoid().detach().cpu().numpy()

        frame_ids = [f'{filename}_{(frame_id+1)*5}' for frame_id in range(len(specs))]
        ids += frame_ids

        preds = np.concatenate([preds, outs], axis=0)

In [18]:
# Submit prediction
pred_df = pd.DataFrame(ids, columns=['row_id'])
pred_df.loc[:, class_names] = preds
pred_df.to_csv('submission.csv',index=False)

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,XC756601_5,0.005847,0.011639,0.002512,0.00011,0.01907,1.5e-05,0.000609,0.000143,6.9e-05,...,0.000811,0.000281,0.001206,0.000926,0.005042,0.010253,5.2e-05,0.000306,3.2e-05,0.004853
1,XC756601_10,0.008243,0.007165,0.001763,9.2e-05,0.018658,1.1e-05,0.000512,0.000136,0.000104,...,0.000528,0.000333,0.002916,0.001131,0.00492,0.010568,6.1e-05,0.000224,2.8e-05,0.005552
2,XC756601_15,0.010177,0.009778,0.003066,0.000166,0.017798,2e-05,0.000657,0.000119,0.000143,...,0.000895,0.000277,0.001253,0.001104,0.007026,0.00824,7.4e-05,0.000242,3.9e-05,0.006269
3,XC756601_20,0.006208,0.006755,0.002717,0.00015,0.014288,2.4e-05,0.000561,0.000175,0.000171,...,0.000782,0.000253,0.001888,0.001559,0.007132,0.009607,5.6e-05,0.000403,2.3e-05,0.012148
4,XC756601_25,0.006516,0.006837,0.001713,0.000164,0.020643,1.6e-05,0.000494,0.000309,8.7e-05,...,0.000877,0.000263,0.001813,0.001479,0.005192,0.010191,6e-05,0.000307,1.8e-05,0.007364
