In [1]:
!pip install timm --no-index --find-links=file:///kaggle/input/ast-materials/timm04
!pip install audiomentations --no-index --find-links=file:///kaggle/input/packages/audiomentations/pck/

In [2]:
import timm
import logging
import time
import os
import copy
import sys
import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import List
from typing import Optional
from sklearn import metrics
from tqdm import tqdm
import random

import gc
import librosa

import scipy
from IPython.display import Audio, display

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data as torchdata
import torchaudio as ta
import torchaudio.functional as taF
import soundfile

#from torch_audiomentations import Compose, Gain, Shift, PeakNormalization, PitchShift, AddColoredNoise
from torchaudio import transforms as T


from typing import Callable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset



from ast import literal_eval as LE
from IPython.display import clear_output

import albumentations as A
import albumentations.pytorch.transforms as T
import matplotlib.pyplot as plt

#import noisereduce as nr
import audiomentations
from audiomentations import Normalize as Normalize_aud
from audiomentations import Normalize as Normalize_aud


In [3]:
def set_seed(seed: int = 42):
    """Seed every RNG the notebook uses so that runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU and
    every visible CUDA device). It also configures cuDNN for deterministic
    execution.

    Args:
        seed: value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started after this point; the hash seed of
    # the running interpreter is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all also covers multi-GPU machines; it is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # benchmark=True lets cuDNN auto-tune per input shape and may select
    # non-deterministic kernels, which defeats `deterministic=True` above.
    torch.backends.cudnn.benchmark = False  # type: ignore

def get_device() -> torch.device:
    """Pick the compute device: CUDA when available, otherwise the CPU.

    Side effect: when a GPU is found, its name is printed.
    """
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        return torch.device("cpu")
    print("cuda: {}\n".format(torch.cuda.get_device_name()))
    return torch.device("cuda")


# Seed every RNG once, right after the helpers above are defined.
set_seed(seed=42)


In [4]:
class CFG:
    """Global inference configuration: label vocabulary, audio settings and backbone."""
    ######################
    # Globals #
    ######################
    seed = 42

    # Label vocabulary; prediction code maps output index i to target_columns[i].
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

    # Derived from the vocabulary (152 codes) so the two can never drift apart;
    # every Model_SED head is sized from this value.
    num_classes = len(target_columns)
    in_channels_ast = 1

    # Presumably the species scored by the competition; used to filter the debug printout.
    scored = ["akiapo", "aniani", "apapan", "barpet", "crehon", "elepai", "ercfra", "hawama", "hawcre",
              "hawgoo", "hawhaw", "hawpet1", "houfin", "iiwi", "jabwar", "maupar",
              "omao", "puaioh", "skylar", "warwhe1", "yefcan"]

    # Subset of `scored`; not referenced in the visible code.
    rare = ['omao', 'akiapo', 'barpet','hawama', 'elepai', 'aniani', 'hawgoo', 'ercfra', 'maupar',
            'hawpet1', 'hawhaw', 'crehon', 'puaioh']

    period = 5  # window length in seconds (crop_or_pad default length = sample_rate * period)
    n_mels = 224 # 128
    n_mels_sed = 224  # size of Model_SED.bn0; must match the mel axis of the input
    sample_rate = 32000

    # Fallback fbank normalisation statistics (used unless they are recomputed).
    var_mean = -6.98279
    var_std = 3.05492

    cuda_num = 0
    device = get_device()  # resolved once, when the class body executes

    base_model_name = "tf_efficientnet_b0_ns"
    pooling = "max"
    pretrained = False
    in_channels = 3

In [5]:
def interpolate(x: torch.Tensor, ratio: int):
    """Upsample frame-level values along the time axis by repetition.

    Every time step is repeated ``ratio`` times consecutively. This
    compensates for the temporal downsampling of the CNN encoder.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of copies of every time step
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape  # also rejects non 3-D input
    # repeat_interleave on dim 1 gives t0,t0,...,t1,t1,... — same layout as repeat+reshape.
    upsampled = torch.repeat_interleave(x, ratio, dim=1)
    return upsampled


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Resize frame-level output to exactly ``frames_num`` time steps.

    The time axis is resampled with bilinear interpolation
    (``align_corners=True``). The class axis keeps its size. Values are
    interpolated, not padded with the last frame.

    Args:
      framewise_output: (batch_size, time_steps, classes_num)
      frames_num: int, target number of time steps
    Returns:
      output: (batch_size, frames_num, classes_num)
    """
    n_classes = framewise_output.size(2)
    # Add a singleton channel so the 3-D tensor can be resized like a 1-channel image.
    as_image = framewise_output.unsqueeze(1)
    resized = F.interpolate(
        as_image,
        size=(frames_num, n_classes),
        mode="bilinear",
        align_corners=True,
    )
    return resized.squeeze(1)

class AttBlockV2(nn.Module):
    """Attention-pooling head.

    Two parallel 1x1 ``Conv1d`` layers map ``in_features`` channels to
    ``out_features`` scores per time step:

    * ``att`` produces attention weights (tanh, then softmax over the last axis).
    * ``cla`` produces per-step class scores, passed through ``activation``.

    The clip-level output is the attention-weighted sum of the class scores.

    Args:
        in_features: number of input channels.
        out_features: number of output classes.
        activation: ``"linear"`` (identity) or ``"sigmoid"``.

    Raises:
        ValueError: if ``activation`` is not supported. Previously an unknown
            value made ``nonlinear_transform`` silently return None.
    """

    _ACTIVATIONS = ("linear", "sigmoid")

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {self._ACTIVATIONS}")

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

    def forward(self, x):
        """Pool ``x`` of shape (n_samples, n_in, n_time) into class scores.

        Returns:
            (clipwise, norm_att, cla): clipwise is (n_samples, out_features);
            norm_att and cla are (n_samples, out_features, n_time).
        """
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to the raw ``cla`` scores."""
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Reached only if `activation` was changed after construction.
        raise ValueError(f"Unsupported activation {self.activation!r}")

class Model_SED(nn.Module):
    """SED model: timm CNN encoder followed by an ``AttBlockV2`` attention head.

    ``forward`` returns a dict with ``clipwise_output`` (attention-pooled
    sigmoid scores), ``logit`` (the same pooling applied to raw scores) and
    ``framewise_output`` / ``framewise_logit`` upsampled to the input length.
    """
    def __init__(self, base_model_name: str, pretrained=False, num_classes=24,
                 in_channels=1, dropout_rate=0.3):
        super().__init__()

        self._dropout_rate = dropout_rate

        # BatchNorm over the mel axis; forward() moves that axis into the channel slot.
        self.bn0 = nn.BatchNorm2d(CFG.n_mels_sed)

        base_model = timm.create_model(
            base_model_name, pretrained=pretrained, in_chans=in_channels)
        # Drop the backbone's last two children (presumably global pooling and
        # classifier — depends on the timm architecture) to keep a feature map.
        layers = list(base_model.children())[:-2]
        self.encoder = nn.Sequential(*layers)

        if hasattr(base_model, "fc"):
            in_features = base_model.fc.in_features
        else:
            in_features = base_model.classifier.in_features

        self.fc1 = nn.Linear(in_features, in_features, bias=True)
        self.att_block = AttBlockV2(
            in_features, num_classes, activation="sigmoid")

    def forward(self, input_data):
        x = input_data # (batch_size, 3, time_steps, mel_bins)

        frames_num = x.shape[2]

        # Normalise per mel bin: move mel bins onto the channel axis for bn0.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = x.transpose(2, 3)

        x = self.encoder(x)

        # NOTE(review): with the (batch, 3, time, mel) layout stated above,
        # transpose(2, 3) puts time on dim 3. So this mean averages over time
        # and the attention below runs along the frequency axis. Kept as-is
        # because the checkpoints were produced with this layout — confirm
        # against the training code.
        x = torch.mean(x, dim=3)

        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=self._dropout_rate, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=self._dropout_rate, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        # Raw (pre-sigmoid) segment scores: computed once and reused.
        # Previously this conv was evaluated twice more per forward pass.
        cla_logit = self.att_block.cla(x)
        logit = torch.sum(norm_att * cla_logit, dim=2)
        segmentwise_logit = cla_logit.transpose(1, 2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        interpolate_ratio = frames_num // segmentwise_output.size(1)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        framewise_logit = interpolate(segmentwise_logit, interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            'framewise_output': framewise_output,
            'clipwise_output': clipwise_output,
            'logit': logit,
            'framewise_logit': framewise_logit,
        }

        return output_dict
    
class Model_linear(nn.Module):
    """timm backbone (default head) followed by a linear layer to ``num_classes``.

    ``dropout_rate`` is accepted for signature parity with ``Model_SED`` but
    is not used.
    """

    # Assumes the backbone's default head emits 1000 logits — TODO confirm
    # for base models other than the configured one.
    _BACKBONE_OUT = 1000

    def __init__(self, base_model_name: str, pretrained=False, num_classes=24,
                 in_channels=1, dropout_rate=0.4):
        super().__init__()
        self.encoder = timm.create_model(
            base_model_name, in_chans=in_channels, pretrained=pretrained)
        self.fc1 = nn.Linear(self._BACKBONE_OUT, num_classes)

    def forward(self, input_data):
        """Return ``num_classes`` logits for ``input_data``."""
        return self.fc1(self.encoder(input_data))

In [6]:
AUDIO_PATH = '../input/birdclef-2022/train_audio'
# Alias the label vocabulary from CFG and cache its size.
CLASSES, NUM_CLASSES = CFG.target_columns, len(CFG.target_columns)
class AudioParams:
    """Settings for the librosa mel-spectrogram built by ``compute_melspec``."""

    # Signal
    sr = 32000          # sample rate, Hz
    duration = 5        # window length, seconds

    # STFT / mel filterbank
    n_fft = 2048
    hop_length = 501
    n_mels = 224
    fmin = 20
    fmax = 16000        # sr / 2
    
TARGET_SR = 32000
DATADIR = Path("../input/birdclef-2022/test_soundscapes/")
datadir2 = Path("../input/sndscps10/ana_b")
datadir3 = Path("../input/sndscps10/ana_no")

# Exactly one .ogg presumably means the public (interactive) run; in that case
# extra local .wav soundscapes are appended so the pipeline can be inspected.
test_oggs = list(DATADIR.glob("*.ogg"))
if len(test_oggs) != 1:
    all_audios = test_oggs
else:
    all_audios = test_oggs + list(datadir2.glob("*.wav")) + list(datadir3.glob("*.wav"))

sample_submission = pd.read_csv('../input/birdclef-2022/sample_submission.csv')
sample_submission

In [7]:
def get_mean_std(df):
    """Estimate normalisation statistics of Kaldi fbank features over audio files.

    Each file is loaded, down-mixed to mono, resampled to 32 kHz if needed and
    DC-centred. It is then turned into a 224-bin Kaldi fbank with the same
    fbank parameters the dataset uses.

    NOTE(review): the dataset also crops to 5 s and peak-normalises (via
    audiomentations) before the fbank; this function does not. The statistics
    may therefore differ from what the models see.

    Args:
        df: sequence of audio paths accepted by ``torchaudio.load``. Despite
            the name, it is iterated, not used as a DataFrame.
    Returns:
        (mean, std): Python floats; the average of per-file means and the
        average of per-file standard deviations.
    Raises:
        ValueError: if ``df`` is empty (previously this silently gave NaN).
    """
    target_sr = 32000
    file_means = []
    file_stds = []
    for path in tqdm(df):
        clip, sr = ta.load(path)
        if clip.size(0) != 1:
            clip = clip.mean(dim=0, keepdim=True)

        if sr != target_sr:
            clip = taF.resample(clip, sr, target_sr, lowpass_filter_width=64,
                                rolloff=0.9475937167399596, resampling_method="kaiser_window",
                                beta=14.769656459379492)

        fbank = ta.compliance.kaldi.fbank(clip - clip.mean(), htk_compat=True,
                                          sample_frequency=target_sr, use_energy=False,
                                          window_type='hanning', num_mel_bins=224,
                                          dither=0.0, frame_shift=9.7)
        # .item() keeps plain floats instead of a list of 0-d tensors.
        file_means.append(fbank.mean().item())
        file_stds.append(fbank.std().item())

    if not file_means:
        raise ValueError("get_mean_std needs at least one audio file")
    return float(np.mean(file_means)), float(np.mean(file_stds))

# fbank normalisation statistics: recompute from the test audio only on
# request, otherwise reuse the values stored in CFG.
GENERATE_NEW_MEAN_STD = False
tar_mean, tar_std = (get_mean_std(all_audios) if GENERATE_NEW_MEAN_STD
                     else (CFG.var_mean, CFG.var_std))

print(tar_mean, tar_std)

In [8]:
# Evaluation-time settings for AudiosetDataset. The dataset reads
# num_mel_bins, target_length, dataset, mean, std and noise; the other keys
# are not read there.
val_audio_conf = dict(
    num_mel_bins=224,
    target_length=512,
    freqm=0,
    timem=0,
    mixup=-1,
    skip_norm=True,
    mode='eval',
    dataset='Birds custom',
    mean=tar_mean,
    std=tar_std,
    noise=False,
)


def crop_or_pad(y, sr=CFG.sample_rate, length=CFG.sample_rate*CFG.period, mode='eval'):
    """Force a waveform tensor to exactly ``length`` samples along dim 1.

    Shorter inputs are right-padded with zeros. Longer inputs are cropped:
    from the start in any mode other than ``'train'``, and at a uniformly
    random offset in ``'train'`` mode.

    Args:
        y: waveform tensor; dim 1 is the sample axis.
        sr: unused; kept for call-site compatibility.
        length: target number of samples.
        mode: ``'train'`` enables random cropping.
    Returns:
        Tensor with ``size(1) == length``.
    """
    n_samples = y.size(1)
    if n_samples <= length:
        return torch.nn.functional.pad(y, (0, length - n_samples), "constant", 0)
    if mode == 'train':
        # randint's upper bound is exclusive; the last valid start is
        # n_samples - length, hence the +1 (the old code could never pick it).
        start = np.random.randint(n_samples - length + 1)
    else:
        start = 0
    return y[:, start: start + length]

def compute_melspec(y, params, to_db=True):
    """Compute a mel power spectrogram with librosa, optionally in decibels.

    Args:
        y: audio signal as a NumPy array.
        params: object exposing ``sr``, ``n_mels``, ``fmin``, ``fmax``,
            ``hop_length`` and ``n_fft`` (e.g. ``AudioParams``).
        to_db: if True, convert with ``librosa.power_to_db`` and cast to float32.
    Returns:
        np.ndarray mel spectrogram (mel bins first).
    """
    spec_kwargs = dict(
        sr=params.sr,
        n_mels=params.n_mels,
        fmin=params.fmin,
        fmax=params.fmax,
        hop_length=params.hop_length,
        n_fft=params.n_fft,
    )
    melspec = librosa.feature.melspectrogram(y=y, **spec_kwargs)
    if not to_db:
        return melspec
    return librosa.power_to_db(melspec).astype(np.float32)

def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """Turn a 2-D spectrogram into a 3-channel uint8 image.

    Steps:
    1. Stack the input into three identical channels on the last axis.
    2. Standardise with ``mean`` / ``std``; each is computed from the data
       when it is None.
    3. Min-max scale to [0, 255] and truncate to uint8.

    A (near-)constant input, whose range after standardisation is not above
    ``eps``, yields an all-zero image.

    Args:
        X: 2-D array.
        eps: guards the division and the constant-input check.
        mean, std: optional standardisation statistics. ``is None`` is used,
            so an explicit 0.0 is honoured rather than silently replaced.
    Returns:
        np.ndarray of dtype uint8 with shape ``X.shape + (3,)``.
    """
    X = np.stack([X, X, X], axis=-1)

    # Standardize
    if mean is None:
        mean = X.mean()
    if std is None:
        std = X.std()
    X = (X - mean) / (std + eps)

    # Normalize to [0, 255]. The former np.clip(X, _min, _max) was a no-op
    # because the bounds are X's own extrema.
    _min, _max = X.min(), X.max()
    if (_max - _min) > eps:
        V = 255 * (X - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)

    return V

# Per-channel statistics for A.Normalize (the commonly used ImageNet RGB values).
mean = (0.485, 0.456, 0.406)  # RGB
std = (0.229, 0.224, 0.225)   # RGB

albu_transforms = {
    'valid': A.Compose([A.Normalize(mean=mean, std=std)]),
}

In [9]:

class AudiosetDataset(Dataset):
    """Per-window dataset over one pre-loaded soundscape waveform.

    Each row of ``df_in`` selects the 5-second slice of ``clip_ast`` that ends
    at ``row.seconds``. It is rendered two ways: a Kaldi fbank stack
    (``image_ast``) and a librosa mel image (``image_sed``).
    """
    def __init__(self, df_in: pd.DataFrame, clip_ast, audio_conf=val_audio_conf, label_csv=None):
        """
        Dataset that slices windows out of a single recording.
        :param df_in: DataFrame with ``row_id`` and ``seconds`` columns, indexed 0..n-1
                      (rows are fetched with ``.loc[index]``)
        :param clip_ast: waveform tensor sliced on dim 1; assumed mono at 32 kHz — TODO confirm at call sites
        :param audio_conf: Dictionary containing the audio loading and preprocessing settings
                           (reads num_mel_bins, dataset, mean, std, noise, target_length)
        :param label_csv: unused
        """
        self.data = df_in
        self.clip_ast = clip_ast

        self.audio_conf = audio_conf
        
        self.melbins = self.audio_conf.get('num_mel_bins')    
        
        self.dataset = self.audio_conf.get('dataset')
        # dataset spectrogram mean and std, used to normalize the input
        self.norm_mean = self.audio_conf.get('mean')
        self.norm_std = self.audio_conf.get('std')
        # noise flag from the config; not read anywhere else in this class
        self.noise = self.audio_conf.get('noise')
        self.label_num = len(CFG.target_columns)
        
        self.SR_ = 32000
        # audiomentations Normalize with p=1, i.e. always applied
        self.wt = audiomentations.Compose([Normalize_aud(p=1)])
        
    def _wav2fbank(self, waveform):
        """Convert a waveform slice to a (target_length, num_mel_bins) Kaldi fbank.

        The slice is fixed to the crop_or_pad default length (eval mode, so it
        is cropped from the start), DC-centred and passed through ``self.wt``.
        The fbank is then zero-padded or truncated along the frame axis.
        """
            
        #waveform, sr = ta.load(filename)
        waveform = crop_or_pad(waveform, sr=self.SR_, mode='eval')
        waveform = waveform - waveform.mean()
        waveform = torch.tensor(self.wt(samples=waveform[0].numpy(), sample_rate=self.SR_)).unsqueeze(0)
        fbank = ta.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=self.SR_, 
                                                  use_energy=False, window_type='hanning', 
                                                  num_mel_bins=self.melbins, dither=0.0, frame_shift=9.7)
        #fbank = fbank ** 2.5 #########
        target_length = self.audio_conf.get('target_length')
        n_frames = fbank.shape[0]
        p = target_length - n_frames
        # cut and pad along the frame (row) axis to exactly target_length rows
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0:target_length, :]
        return fbank
    
    def __len__(self):
        """Number of windows (rows of ``df_in``)."""
        return len(self.data)
            

    def __getitem__(self, index):
        """
        Build both model inputs for one window.

        returns: dict with
          "image_ast": fbank stacked 3x, shape (3, target_length, num_mel_bins),
                       standardised with (x - mean) / (2 * std)
          "image_sed": mono_to_color mel image normalised by albu_transforms['valid']
                       and transposed from (H, W, C) to channels-first
          "row_id":    the row's row_id
        """
        SR = 32000
        sample = self.data.loc[index, :]
        row_id = sample.row_id
        
        # window = [seconds - 5, seconds)
        end_seconds = int(sample.seconds)
        start_seconds = int(end_seconds - 5)
        
        # ast preproc
        img_ast = self.clip_ast[:, SR*start_seconds:SR*end_seconds]         
        
        fbank = self._wav2fbank(img_ast) 
        #fbank = torch.transpose(fbank, 0, 1)
        # NOTE: divides by twice the std; kept as-is to match the trained checkpoints — TODO confirm
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)

        fbank = torch.stack([fbank, fbank, fbank])
        
        
        # sed preproc: uses the raw slice (no crop_or_pad, no DC removal)
        image = img_ast[0].numpy()
        image = self.wt(samples=image, sample_rate=SR)
        image = np.nan_to_num(image)
        
        image = compute_melspec(image, AudioParams, to_db=True)  
        image = mono_to_color(image)
        
        image = image.astype(np.uint8)
        image = albu_transforms['valid'](image=image)['image'].T

        # the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128]
        return {
            "image_ast": fbank,
            "image_sed": image,
            "row_id": row_id,
        }


In [10]:
######################################################################################################
def _load_model_pack(checkpoint_paths, input_kind):
    """Build one Model_SED per checkpoint, ready for inference.

    Args:
        checkpoint_paths: state-dict files loadable with ``torch.load``.
        input_kind: tag stored with every model. prediction_for_clip feeds
            'ast' models the fbank input and any other tag the SED image.
    Returns:
        list of ``(input_kind, model)`` tuples; models are on CFG.device in eval mode.
    """
    pack = []
    for ckpt_path in checkpoint_paths:
        model = Model_SED(
            base_model_name=CFG.base_model_name,
            pretrained=CFG.pretrained,
            num_classes=CFG.num_classes,
            in_channels=CFG.in_channels)
        model.to(CFG.device)
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        model.load_state_dict(torch.load(ckpt_path, map_location=CFG.device))
        model.eval()
        pack.append((input_kind, model))
    return pack


######################################################################################################
path_image_base = '../input/checks-triple/_att_6_BEST_FOR_IMAGE.ckpt'
path_5_2_6 = '../input/checks-triple/_att_5_2_trans_new_data_811.ckpt'
path_6_1_3 = '../input/checks-triple/_att_6_1_03secondary_0845.ckpt'

path_pack1 = [path_image_base, path_5_2_6, path_6_1_3]
models_pack1 = _load_model_pack(path_pack1, 'sed')
print(f'Length of model pack 1: {len(models_pack1)} ')

######################################################################################################
path_9_f0 = '../input/checks-triple/app9/_9_att_1_best05_fold0.ckpt'
path_9_f1 = '../input/checks-triple/app9/_9_att_1_best05_fold1.ckpt'
path_9_f2 = '../input/checks-triple/app9/_9_att_1_best05_fold2.ckpt'
path_9_f3 = '../input/checks-triple/app9/_9_att_2_best05_fold3.ckpt'
path_9_f4 = '../input/checks-triple/app9/_9_att_2_best05_fold4.ckpt'
path_9_f5 = '../input/checks-triple/app9/_9_att_3_best05_fold5.ckpt'
path_9_f6 = '../input/checks-triple/app9/_9_att_3_best05_fold6.ckpt'

path_pack2 = [path_9_f0, path_9_f1, path_9_f2, path_9_f3, path_9_f4, path_9_f5, path_9_f6]
models_pack2 = _load_model_pack(path_pack2, 'sed')
print(f'Length of model pack 2: {len(models_pack2)} ')

######################################################################################################
path_f0 = '../input/checks-triple/_att_6_1_03secondary_0858.ckpt'
path_f1 = '../input/checks-triple/_att_6_2_05secondary_0850.ckpt'
path_f2 = '../input/checks-triple/app8/_att_4_best05_fold0_.ckpt'
path_f3 = '../input/checks-triple/app8/_att_7_best_fold0.ckpt'
path_f4 = '../input/checks-triple/app8/_att_1_best05_fold1_.ckpt'

path_pack3 = [path_f0, path_f1, path_f2, path_f4]  # path_f3 deliberately excluded
models_pack3 = _load_model_pack(path_pack3, 'sed')
print(f'Length of model pack 3: {len(models_pack3)} ')

# SEDMIX
path_SEDMIX_1 = '../input/ftw-new/CHECKS_SEDMIX/SEDMIX_f02_last_epoch.bin'
path_SEDMIX_2 = '../input/ftw-new/CHECKS_SEDMIX/SEDMIX_f02_best_03_epoch_train830_eval767.bin'
path_SEDMIX_3 = '../input/ftw-new/CHECKS_SEDMIX/SEDMIX_f02_best_03_epoch_train833_eval766.bin'

path_SEDMIXFULL_pret_1 = '../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIXFULL_fromcheck_TARGET_03BEST_881_981_808.bin'
path_SEDMIXFULL_pret_2 = '../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIXFULL_fromcheck_TARGET_03BEST_881_987_809.bin'
path_SEDMIXFULL_pret_3 = '../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIXFULL_fromcheck_TRAIN_03BEST_902_985_820.bin'

path_SEDMIXFULL_1 = '../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIXFULL_VALID_03BEST_839_965_770.bin'
path_SEDMIXFULL_2 = '../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIXFULL_TRAIN_03BEST_870_968_801.bin'

path_pack4 = [path_SEDMIX_2, path_SEDMIXFULL_pret_3, path_SEDMIXFULL_2]
models_pack4 = _load_model_pack(path_pack4, 'ast')
print(f'Length of model pack 4: {len(models_pack4)} ')


path_pseudo_1 = '../input/ftw-new/pseudo/PSEUDO_last_epoch_SAVE_974_971_893.bin'
path_pseudo_2 = '../input/ftw-new/pseudo/PSEUDO_TARGET_03BEST_GO.bin'
path_pseudo_3 = '../input/ftw-new/pseudo/PSEUDO_last_epoch.bin'

path_pack5 = [path_pseudo_1, path_pseudo_2, path_pseudo_3]
models_pack5 = _load_model_pack(path_pack5, 'ast')
print(f'Length of model pack 5: {len(models_pack5)} ')


# NOCALL CLS — loaded for completeness; its use in prediction_for_clip is currently commented out.

path_NOCALL_1 = '../input/ftw-new/NOCALLCLS_f00_last_epoch.bin'
path_NOCALL_2 = '../input/ftw-new/NOCALLCLS_f00_best_930.bin'
path_NOCALL_3 = '../input/ftw-new/NOCALLCLS_f00_best_936.bin'

model_cls = Model_linear(base_model_name=CFG.base_model_name,
                        pretrained=False,
                        num_classes=2,
                        in_channels=CFG.in_channels)

model_cls.to(CFG.device)
model_cls.load_state_dict(torch.load(path_NOCALL_3, map_location=CFG.device))
model_cls.eval()
print()

In [11]:
def prediction_for_clip(test_df,
                        clip_ast,
                        device,
                        models, model_cls,
                        threshold,
                        weights,
                        wgt_models,
                        wg_full):
    """Predict space-separated bird labels for every window of one clip.

    ``models`` is a list of packs. Each pack is a list of ``(kind, model)``
    pairs, where ``kind == 'ast'`` selects the fbank input and any other kind
    selects the SED image. ``weights[p][i]`` is a per-class vector multiplied
    into the clipwise probabilities of model ``i`` in pack ``p``.

    For every batch of windows:
    1. Per-model scores within a pack are combined as the average of their
       median and mean. If that mode is disabled, a ``wg_full``-weighted sum
       is used instead.
    2. A class whose pack score exceeds ``2.5 * threshold`` in any window of
       the batch gets ``0.5 * threshold`` added in every window of that batch.
    3. Pack scores are averaged and compared with ``threshold``.

    Args:
        test_df: DataFrame with ``row_id`` and ``seconds`` columns.
        clip_ast: waveform tensor the windows are sliced from.
        device: device the inputs are moved to (must match the models').
        models: list of model packs as described above.
        model_cls: currently unused (the no-call filter is disabled).
        threshold: decision threshold on the aggregated score.
        weights: per-pack, per-model class-weight vectors, parallel to ``models``.
        wgt_models: currently unused.
        wg_full: per-model weights; only used when median/mean aggregation is off.
    Returns:
        dict mapping ``row_id`` -> space-separated labels, or ``"nocall"``.
    """
    batch_size = 12
    # Post-processing switches (loop-invariant, so defined once).
    use_mean_median = True
    use_max_adder = True

    dataset = AudiosetDataset(test_df, clip_ast)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    prediction_dict = {}
    for data in loader:
        row_id = data['row_id']
        image_ast = data['image_ast'].to(device)
        image_sed = data['image_sed'].to(device)

        full_events_all = []
        with torch.no_grad():
            for md, pack in enumerate(models):
                pack_weights = weights[md]
                per_model = []
                for i, (kind, model) in enumerate(pack):
                    with torch.cuda.amp.autocast():
                        output = model(image_ast if kind == 'ast' else image_sed)
                        probs = output['clipwise_output'].cpu().numpy()
                    per_model.append(probs * pack_weights[i])

                # (n_models, batch, n_classes) -> (batch, n_models, n_classes)
                full_events = np.array(per_model).transpose(1, 0, 2)

                if use_mean_median:
                    full_med = np.median(full_events, axis=1)
                    full_mean = np.mean(full_events, axis=1)
                    full_events = np.mean(np.array([full_med, full_mean]), axis=0)
                else:
                    full_events = full_events * wg_full.reshape(-1, 1)
                    full_events = np.sum(full_events, axis=1)

                if use_max_adder:
                    # Vectorised form of the per-class loop: the max is taken
                    # over the batch before any boost is added.
                    boosted = full_events.max(0) > threshold * 2.5
                    full_events[:, boosted] += threshold * 0.5

                full_events_all.append(full_events)

        # (n_packs, batch, n_classes) -> (batch, n_packs, n_classes) -> mean over packs
        full_events_all = np.array(full_events_all).transpose(1, 0, 2)
        full_events_all = np.mean(full_events_all, axis=1)

        is_positive = full_events_all >= threshold
        for kk, row in enumerate(is_positive):
            labels = [CFG.target_columns[idx] for idx in np.flatnonzero(row)]
            prediction_dict[str(row_id[kk])] = " ".join(labels) or "nocall"
    return prediction_dict

In [12]:
def prediction(test_audios,
               models, model_cls,
               threshold,
               weights,
               wgt_models,
               wg_full):
    """Run ``prediction_for_clip`` over every soundscape in ``test_audios``.

    Per file:
    1. Load with torchaudio and down-mix to mono.
    2. Resample to 32 kHz.
    3. Zero-pad to at least 60 s (measured at 32 kHz).
    4. Split into twelve windows ending at 5, 10, ..., 60 s. Row ids are
       ``"_".join(name.split(".")[:-1]) + f"_{end_second}"``.

    Args:
        test_audios: iterable of ``pathlib.Path`` audio files.
        models, model_cls, threshold, weights, wgt_models, wg_full:
            forwarded unchanged to ``prediction_for_clip``.
    Returns:
        dict merging every clip's ``row_id -> labels`` mapping.
    """
    # NOTE: silences all warnings process-wide, not only inside this call.
    warnings.filterwarnings("ignore")
    device = get_device()
    prediction_dicts = {}
    target_sr = 32000
    min_len = target_sr * 60
    window_ends = list(range(5, 65, 5))
    for audio_path in test_audios:
        clip_ast, sr = ta.load(audio_path)

        if clip_ast.size(0) != 1:
            clip_ast = clip_ast.mean(dim=0, keepdim=True)

        # Resample BEFORE padding. Padding first measured 60 s at the source
        # rate, so non-32 kHz clips could end up shorter than 60 s at 32 kHz.
        if sr != target_sr:
            clip_ast = taF.resample(clip_ast, sr, target_sr, lowpass_filter_width=64,
                                    rolloff=0.9475937167399596, resampling_method="kaiser_window",
                                    beta=14.769656459379492)
        if clip_ast.size(1) < min_len:
            clip_ast = crop_or_pad(clip_ast, target_sr, min_len)

        stem = "_".join(audio_path.name.split(".")[:-1])
        test_df = pd.DataFrame({
            "row_id": [stem + f"_{second}" for second in window_ends],
            "seconds": window_ends,
        })

        prediction_dict = prediction_for_clip(test_df,
                                              clip_ast=clip_ast,
                                              device=device,
                                              models=models, model_cls=model_cls,
                                              threshold=threshold,
                                              weights=weights,
                                              wgt_models=wgt_models,
                                              wg_full=wg_full)
        prediction_dicts.update(prediction_dict)
    return prediction_dicts

In [13]:
### FOR PACK 1 — one per-class weight vector per model
weights_p1_base_b, weights_p1_2_b, weights_p1_3_b = (
    np.load(p) for p in ('../input/checks-triple/weights_pu.npy',
                         '../input/checks-triple/weights_ar_250_12.npy',
                         '../input/checks-triple/weights_ar_300_10.npy'))

### FOR PACK 2 — a single vector shared by all fold models
weights_p2 = np.load('../input/checks-triple/app8/weights_ar_WAVE_800_7.npy')

### FOR PACK 3
weights_p3_1, weights_p3_2, weights_p3_3, weights_p3_4 = (
    np.load(p) for p in ('../input/checks-triple/weights_ar_300_10.npy',
                         '../input/checks-triple/app8/weights_ar_WAVE_420_7.npy',
                         '../input/checks-triple/app8/weights_ar_WAVE_800_7.npy',
                         '../input/checks-triple/app8/weights_ar_WAVE_250_7.npy'))
#####################################################################################

In [14]:
# Each model's class-weight vector is softened by raising it to a per-pack exponent.

### FOR PACK 1
st_p1 = 1.
weights_ar_p1 = [w ** st_p1 for w in (weights_p1_base_b, weights_p1_2_b, weights_p1_3_b)]

### FOR PACK 2 — all seven fold models share one vector (separate arrays per model)
st_p2 = 0.35
weights_ar_p2 = [weights_p2 ** st_p2 for _ in range(7)]

### FOR PACK 3
st_p3 = 1.
weights_ar_p3 = [w ** st_p3 for w in (weights_p3_1, weights_p3_1, weights_p3_2, weights_p3_2)]
#####################################################################################


In [15]:
# Class weights for pack 4 (SEDMIX) and pack 5 (pseudo-label). Each vector is
# capped at `floor` — despite the name, this is an upper bound — and then
# softened by an exponent.
floor = 8

### FOR PACK 4 SEDMIX
weights_p4 = np.minimum(np.load('../input/ftw-new/CHECKS_SEDMIX/SEDMIX_weights_ar_ast_300_15.npy'), floor)
weights_p4_2 = np.minimum(np.load('../input/ftw-new/CHECKS_SEDMIX_FULL/SEDMIX_NOFOLDS_weights_ar_ast_300_15.npy'), floor)

st_p4 = 0.6
weights_ar_p4 = [w ** st_p4 for w in (weights_p4, weights_p4_2, weights_p4_2)]

### FOR PACK 5 PSEUDO
weights_p5 = np.minimum(np.load('../input/ftw-new/pseudo/PSEUDO_WGTS_300_15.npy'), floor)

st_p5 = 0.6
weights_ar_p5 = [weights_p5 ** st_p5 for _ in range(3)]

In [16]:
# Debug aid (off by default): show the softened per-class weights for scored species.
print_it = False
if print_it:
    rounded_p4 = np.around(weights_p4**st_p4, decimals=2)
    rounded_p5 = np.around(weights_p5**st_p5/8*0, decimals=2)
    for row in zip(CFG.target_columns, rounded_p4, rounded_p5):
        if row[0] in CFG.scored:
            print(row)

In [17]:
# Ensemble layout: each inner list is one pack. prediction_for_clip aggregates
# the models within a pack, then averages the packs.
model_pack = [models_pack2 + models_pack4, models_pack5]
weights_pack = [weights_ar_p2 + weights_ar_p4, weights_ar_p5]

# prediction_for_clip indexes weights[pack][i] for every model. A shorter
# weight list would raise IndexError mid-inference; a longer one would be
# silently ignored. Fail fast here instead.
assert len(model_pack) == len(weights_pack), "model/weight pack count mismatch"
for pack_idx, (pack_models, pack_weights) in enumerate(zip(model_pack, weights_pack)):
    assert len(pack_models) == len(pack_weights), (
        f"pack {pack_idx}: {len(pack_models)} models but {len(pack_weights)} weight vectors")

print(len(model_pack), len(weights_pack))



In [18]:
threshold = 0.0795

import time
sddt = time.time()

prediction_dicts = prediction(test_audios=all_audios,
                               models=model_pack, model_cls=model_cls,
                               threshold=threshold,
                               weights=weights_pack,
                               wgt_models=None,
                               wg_full=None)

# Submission row ids are "<file stem>_<bird>_<end second>"; prediction keys
# are "<file stem>_<end second>". Splitting from the right keeps working when
# the stem itself contains underscores.
for i, sub_row_id in enumerate(sample_submission.row_id.tolist()):
    stem, target_bird, end_second = sub_row_id.rsplit("_", 2)
    key = f"{stem}_{end_second}"
    if key in prediction_dicts:
        # Exact token match: a substring test on the joined label string could
        # match one code inside another.
        sample_submission.iat[i, 1] = target_bird in prediction_dicts[key].split()
sample_submission.to_csv("submission.csv", index=False)

print(f'Elapsed time: {time.time() - sddt:.3f} seconds')
print()

In [19]:
# Debug view, presumably for the interactive run (a single public test file):
# print the scored species predicted per window, with a gap after every 12 rows.
if len(list(DATADIR.glob("*.ogg"))) == 1:
    cntr = 0
    for row_key, labels in prediction_dicts.items():
        scored_hits = [x for x in labels.split(' ') if x in CFG.scored]
        print(f'{row_key}: ', scored_hits)
        cntr += 1
        if not cntr % 12:
            print('\n')