In [1]:
##########################
#### standard library ####
##########################
import ast
import copy
import gc
import os
import random
import shutil
import time
import warnings
from itertools import filterfalse as ifilterfalse
from typing import List, Tuple, Dict, Any, Union, Optional, Sequence, Iterable, Callable
# warnings.filterwarnings("ignore")

###################
#### 3rd party ####
###################
import torch
import torchaudio
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import AdamW
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast #amp = automatic mixed precision
import lightning as L
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt

######################
#### my own files ####
######################
pass

# Config

In [2]:
class Config:
    """Central collection of the notebook's tunable settings.

    Every setting is a plain class attribute and is read as ``Config.<name>``.
    """

    # Dataset root and output directory. The defaults are the original local
    # paths; set BIRDCLEF_DATA_DIR / BIRDCLEF_OUTPUT_DIR to run somewhere else
    # (e.g. '/kaggle/input/birdclef-2024/') without editing this cell.
    data_dir = os.environ.get("BIRDCLEF_DATA_DIR", "/home/nikita/hdd/Data/BirdCLEF2024")
    output_dir = os.environ.get("BIRDCLEF_OUTPUT_DIR", "/home/nikita/Code/kaggle/BirdCLEF2024/output/")

    # Device and random seed.
    # Fall back to the CPU when CUDA is missing instead of failing later on `.to('cuda')`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed = 42

    # Batch size and per-class sample thresholds
    batch_size = 32
    upsample_thr = 50 # min sample of each class (upsample)
    downsample_thr = 500 # max sample of each class (downsample)

    secondary_coef = 1.0 # If there are multiple bird labels, the target is also set to this coeff.

    # Audio duration, sample rate, and length
    duration = 5 # second
    sample_rate = 4000 #32000
    audio_len = duration*sample_rate # number of samples per training clip

    # Number of epochs and number of folds
    epochs = 10
    n_folds = 5

    # Learning rate, optimizer, and cosine scheduler
    lr = 1e-3
    lr_min = 1e-7
    weight_decay = 1e-6
    gradient_clip_val = 1.0
    optimizer = torch.optim.AdamW # AdamW, Adam

    # # Loss function and label smoothing
    # loss = 'CCE' # BCE, CCE
    # label_smoothing = 0.05 # label smoothing

    # # Data augmentation parameters
    # augment=True

    # # Audio Augmentation Settings
    # audio_augment_prob = 0.5

    # mixup_prob = 0.65
    # mixup_alpha = 0.5

    # cutmix_prob = 0.65
    # cutmix_alpha = 2.5

    # timeshift_prob = 0.0

    # gn_prob = 0.35

    # # Class Labels for BirdCLEF 23
    # class_names = sorted(os.listdir('/kaggle/input/birdclef-2024/train_audio/'))
    # num_classes = len(class_names)
    # class_labels = list(range(num_classes))
    # label2name = dict(zip(class_labels, class_names))
    # name2label = {v:k for k,v in label2name.items()}

    # # Class Labels for BirdCLEF 21 & 22
    # class_names2 = sorted(set(os.listdir('/kaggle/input/birdclef-2021/train_short_audio/')
    #                    +os.listdir('/kaggle/input/birdclef-2022/train_audio/')
    #                           +os.listdir('/kaggle/input/birdclef-2023/train_audio/')
    #                    +os.listdir('/kaggle/input/birdsong-recognition/train_audio/')))
    # num_classes2 = len(class_names2)
    # class_labels2 = list(range(num_classes2))
    # label2name2 = dict(zip(class_labels2, class_names2))
    # name2label2 = {v:k for k,v in label2name2.items()}

## 🌱 Seed Everything

In [3]:
def set_seed(seed=42):
    """Seed every random number generator the notebook uses.

    Sets the ``PYTHONHASHSEED`` environment variable, seeds Python's
    ``random``, NumPy and PyTorch (CPU and all CUDA devices), and switches
    cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
# Seed all RNGs once, using the seed from Config, before the data-split and model cells below run.
set_seed(Config.seed)

## Data

In [4]:
# Training metadata: one row per recording, plus the full path of its audio file.
df = pd.read_csv(f"{Config.data_dir}/train_metadata.csv")
df["path"] = f"{Config.data_dir}/train_audio/" + df["filename"]
# Ratings are divided by the largest rating and then clamped into [0.1, 1.0].
df["rating"] = (df["rating"] / df["rating"].max()).clip(lower=0.1, upper=1.0)

# Stratified folds on the primary label; each row stores the fold in which it is validation data.
df["fold"] = -1
skf = StratifiedKFold(n_splits=Config.n_folds, shuffle=True, random_state=Config.seed)
fold_splits = skf.split(X=df, y=df["primary_label"].values)
for ifold, (train_idx, val_idx) in enumerate(fold_splits):
    df.loc[val_idx, "fold"] = ifold

# The sample submission defines the class order: every column after the first is one class.
sub = pd.read_csv(f"{Config.data_dir}/sample_submission.csv")
target_columns = list(sub.columns[1:])
num_classes = len(target_columns)
bird2id = dict(zip(target_columns, range(num_classes)))

In [5]:
df

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename,path,fold
0,asbfly,[],['call'],39.2297,118.1987,Muscicapa dauurica,Asian Brown Flycatcher,Matt Slaymaker,Creative Commons Attribution-NonCommercial-Sha...,1.0,https://www.xeno-canto.org/134896,asbfly/XC134896.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,1
1,asbfly,[],['song'],51.4030,104.6401,Muscicapa dauurica,Asian Brown Flycatcher,Magnus Hellström,Creative Commons Attribution-NonCommercial-Sha...,0.5,https://www.xeno-canto.org/164848,asbfly/XC164848.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,3
2,asbfly,[],['song'],36.3319,127.3555,Muscicapa dauurica,Asian Brown Flycatcher,Stuart Fisher,Creative Commons Attribution-NonCommercial-Sha...,0.5,https://www.xeno-canto.org/175797,asbfly/XC175797.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,3
3,asbfly,[],['call'],21.1697,70.6005,Muscicapa dauurica,Asian Brown Flycatcher,vir joshi,Creative Commons Attribution-NonCommercial-Sha...,0.8,https://www.xeno-canto.org/207738,asbfly/XC207738.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,2
4,asbfly,[],['call'],15.5442,73.7733,Muscicapa dauurica,Asian Brown Flycatcher,Albert Lastukhin & Sergei Karpeev,Creative Commons Attribution-NonCommercial-Sha...,0.8,https://www.xeno-canto.org/209218,asbfly/XC209218.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24454,zitcis1,[],[''],43.5925,4.5434,Cisticola juncidis,Zitting Cisticola,Chèvremont Fabian,Creative Commons Attribution-NonCommercial-Sha...,1.0,https://xeno-canto.org/845747,zitcis1/XC845747.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,3
24455,zitcis1,[],[''],43.5925,4.5434,Cisticola juncidis,Zitting Cisticola,Chèvremont Fabian,Creative Commons Attribution-NonCommercial-Sha...,0.8,https://xeno-canto.org/845817,zitcis1/XC845817.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,4
24456,zitcis1,[],[''],51.1207,4.5607,Cisticola juncidis,Zitting Cisticola,Wim Jacobs,Creative Commons Attribution-NonCommercial-Sha...,0.8,https://xeno-canto.org/856176,zitcis1/XC856176.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,2
24457,zitcis1,[],[''],41.5607,-8.4236,Cisticola juncidis,Zitting Cisticola,Jorge Leitão,Creative Commons Attribution-NonCommercial-Sha...,0.9,https://xeno-canto.org/856723,zitcis1/XC856723.ogg,/home/nikita/hdd/Data/BirdCLEF2024/train_audio...,1


In [6]:
sub

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,soundscape_1446779_5,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,...,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495
1,soundscape_1446779_10,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,...,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495
2,soundscape_1446779_15,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,...,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495,0.005495


In [7]:
target_columns

['asbfly',
 'ashdro1',
 'ashpri1',
 'ashwoo2',
 'asikoe2',
 'asiope1',
 'aspfly1',
 'aspswi1',
 'barfly1',
 'barswa',
 'bcnher',
 'bkcbul1',
 'bkrfla1',
 'bkskit1',
 'bkwsti',
 'bladro1',
 'blaeag1',
 'blakit1',
 'blhori1',
 'blnmon1',
 'blrwar1',
 'bncwoo3',
 'brakit1',
 'brasta1',
 'brcful1',
 'brfowl1',
 'brnhao1',
 'brnshr',
 'brodro1',
 'brwjac1',
 'brwowl1',
 'btbeat1',
 'bwfshr1',
 'categr',
 'chbeat1',
 'cohcuc1',
 'comfla1',
 'comgre',
 'comior1',
 'comkin1',
 'commoo3',
 'commyn',
 'compea',
 'comros',
 'comsan',
 'comtai1',
 'copbar1',
 'crbsun2',
 'cregos1',
 'crfbar1',
 'crseag1',
 'dafbab1',
 'darter2',
 'eaywag1',
 'emedov2',
 'eucdov',
 'eurbla2',
 'eurcoo',
 'forwag1',
 'gargan',
 'gloibi',
 'goflea1',
 'graher1',
 'grbeat1',
 'grecou1',
 'greegr',
 'grefla1',
 'grehor1',
 'grejun2',
 'grenig1',
 'grewar3',
 'grnsan',
 'grnwar1',
 'grtdro1',
 'gryfra',
 'grynig2',
 'grywag',
 'gybpri1',
 'gyhcaf1',
 'heswoo1',
 'hoopoe',
 'houcro1',
 'houspa',
 'inbrob1',
 'indpit1',
 

In [8]:
num_classes

182

In [9]:
bird2id

{'asbfly': 0,
 'ashdro1': 1,
 'ashpri1': 2,
 'ashwoo2': 3,
 'asikoe2': 4,
 'asiope1': 5,
 'aspfly1': 6,
 'aspswi1': 7,
 'barfly1': 8,
 'barswa': 9,
 'bcnher': 10,
 'bkcbul1': 11,
 'bkrfla1': 12,
 'bkskit1': 13,
 'bkwsti': 14,
 'bladro1': 15,
 'blaeag1': 16,
 'blakit1': 17,
 'blhori1': 18,
 'blnmon1': 19,
 'blrwar1': 20,
 'bncwoo3': 21,
 'brakit1': 22,
 'brasta1': 23,
 'brcful1': 24,
 'brfowl1': 25,
 'brnhao1': 26,
 'brnshr': 27,
 'brodro1': 28,
 'brwjac1': 29,
 'brwowl1': 30,
 'btbeat1': 31,
 'bwfshr1': 32,
 'categr': 33,
 'chbeat1': 34,
 'cohcuc1': 35,
 'comfla1': 36,
 'comgre': 37,
 'comior1': 38,
 'comkin1': 39,
 'commoo3': 40,
 'commyn': 41,
 'compea': 42,
 'comros': 43,
 'comsan': 44,
 'comtai1': 45,
 'copbar1': 46,
 'crbsun2': 47,
 'cregos1': 48,
 'crfbar1': 49,
 'crseag1': 50,
 'dafbab1': 51,
 'darter2': 52,
 'eaywag1': 53,
 'emedov2': 54,
 'eucdov': 55,
 'eurbla2': 56,
 'eurcoo': 57,
 'forwag1': 58,
 'gargan': 59,
 'gloibi': 60,
 'goflea1': 61,
 'graher1': 62,
 'grbeat1': 63,
 

## Metrics code (currently commented out)

In [10]:
# from sklearn.metrics import average_precision_score

# def padded_cmap(solution, submission, padding_factor=5):
#     solution = solution.drop(["row_id"], axis=1, errors="ignore")
#     submission = submission.drop(["row_id"], axis=1, errors="ignore")
#     new_rows = []
#     for i in range(padding_factor):
#         new_rows.append([1 for j in range(len(solution.columns))])
#     new_rows = pd.DataFrame(new_rows)
#     new_rows.columns = solution.columns
#     padded_solution = (
#         pd.concat([solution, new_rows]).reset_index(drop=True).copy()
#     )
#     padded_submission = (
#         pd.concat([submission, new_rows]).reset_index(drop=True).copy()
#     )
#     score = average_precision_score(
#         padded_solution.values.astype(int),
#         padded_submission.values,
#         average="macro",
#     )
#     return score


# def padded_cmap_numpy(y_true, y_pred, padding_factor=5):
#     y_true = np.pad(y_true, ((0, padding_factor), (0, 0)), constant_values=1)
#     y_pred = np.pad(y_pred, ((0, padding_factor), (0, 0)), constant_values=1)
#     return average_precision_score(
#         y_true.astype(int),
#         y_pred,
#         average="macro",
#     )

# def calculate_competition_metrics(gt, preds, target_columns, one_hot=True):
#     if not one_hot:
#         ground_truth = np.argmax(gt, axis=1)
#         gt = np.zeros((ground_truth.size, len(target_columns)))
#         gt[np.arange(ground_truth.size), ground_truth] = 1
#     val_df = pd.DataFrame(gt, columns=target_columns)
#     pred_df = pd.DataFrame(preds, columns=target_columns)
#     cmAP_1 = padded_cmap(val_df, pred_df, padding_factor=1)
#     cmAP_5 = padded_cmap(val_df, pred_df, padding_factor=5)
#     mAP = map_score(val_df, pred_df)
#     val_df['id'] = [f'id_{i}' for i in range(len(val_df))]
#     pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]
#     train_score = score(val_df, pred_df, row_id_column_name='id')
#     return {
#         "cmAP_1": cmAP_1,
#         "cmAP_5": cmAP_5,
#         "mAP": mAP,
#         "ROC": train_score,
#     }

# def metrics_to_string(scores, key_word):
#   log_info = ""
#   for key in scores.keys():
#       log_info = log_info + f"{key_word} {key} : {scores[key]:.4f}, "
#   return log_info

# def calculate_competition_metrics_no_map(gt, preds, target_columns, one_hot=True):
#     if not one_hot:
#       ground_truth = np.argmax(gt, axis=1)
#       gt = np.zeros((ground_truth.size, len(target_columns)))
#       gt[np.arange(ground_truth.size), ground_truth] = 1
#     val_df = pd.DataFrame(gt, columns=target_columns)
#   pred_df = pd.DataFrame(preds, columns=target_columns)
#    cmAP_1 = padded_cmap(val_df, pred_df, padding_factor=1)
#   cmAP_5 = padded_cmap(val_df, pred_df, padding_factor=5)
#    val_df['id'] = [f'id_{i}' for i in range(len(val_df))]
#   pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]
#   train_score = score(val_df, pred_df, row_id_column_name='id')
#    return {
#       "cmAP_1": cmAP_1,
#       "cmAP_5": cmAP_5,
#       "ROC": train_score,
#   }

In [11]:

# exp_name = 'exp1'
# backbone = 'eca_nfnet_l0'
# seed = 42
# batch_size = 64
# num_workers = 0

# n_epochs = 100
# warmup_epo = 5
# cosine_epo = n_epochs - warmup_epo

# image_size = 256

# lr_max = 1e-5
# lr_min = 1e-7
# weight_decay = 1e-6

# mel_spec_params = {
#     "sample_rate": 32000,
#     "n_mels": 128,
#     "f_min": 20,
#     "f_max": 16000,
#     "n_fft": 2048,
#     "hop_length": 512,
#     "normalized": True,
#     "center" : True,
#     "pad_mode" : "constant",
#     "norm" : "slaney",
#     "onesided" : True,
#     "mel_scale" : "slaney"
# }

# top_db = 80
# train_period = 5
# val_period = 5

# secondary_coef = 1.0

# train_duration = train_period * mel_spec_params["sample_rate"]
# val_duration = val_period * mel_spec_params["sample_rate"]

# N_FOLD = 5
# fold = 2

# use_amp = True
# max_grad_norm = 10
# early_stopping = 7

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# output_folder = "outputs"
# os.makedirs(output_folder, exist_ok=True)
# os.makedirs(os.path.join(output_folder, exp_name), exist_ok=True)



In [12]:
# #testing
# path = Config.data_dir + "/train_audio/asbfly/XC134896.ogg"
# wav1 = read_wav(path)
# first_5_seconds = crop_start_wav(wav1, Config.audio_len)
# print("wav1.shape:", wav1.shape)
# print("first_5_seconds.shape:", first_5_seconds.shape)

## Dataset

In [13]:
def read_wav(path): # pip install sox       and      sudo apt-get install libsox-dev
    """Load an audio file as a mono waveform resampled to ``Config.sample_rate``.

    Args:
        path: Location of the audio file, passed straight to ``torchaudio.load``.

    Returns:
        The loaded waveform with a single channel, resampled from the file's
        native rate to ``Config.sample_rate``.
    """
    wav, org_sr = torchaudio.load(path, normalize=True)
    # The Hydra kernels have exactly one input channel, so a multi-channel
    # recording is averaged down to mono instead of failing inside conv1d.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    wav = torchaudio.functional.resample(wav, orig_freq=org_sr, new_freq=Config.sample_rate)
    return wav


def crop_start_wav(wav, duration_):
    """Return the first ``duration_`` samples of ``wav``, looping clips that are too short.

    Args:
        wav: 2-D tensor whose second (last) dimension is the sample axis.
        duration_: Number of samples to keep.

    Returns:
        ``wav`` truncated to ``duration_`` samples along dimension 1. A shorter
        clip is first repeated end-to-end until it is long enough.

    Raises:
        ValueError: If ``wav`` contains no samples while ``duration_`` is
            positive; repeating an empty clip can never reach the target
            length (this case used to loop forever).
    """
    n_samples = wav.size(-1)
    if n_samples < duration_:
        if n_samples == 0:
            raise ValueError("cannot extend an empty waveform to the requested duration")
        # Ceil division: the smallest number of copies that covers duration_ samples.
        n_repeats = -(-duration_ // n_samples)
        wav = torch.cat([wav] * n_repeats, dim=1)
    wav = wav[:, :duration_]
    return wav


class BirdWavDataset(torch.utils.data.Dataset):
    """Dataset that turns metadata rows into fixed-length waveforms with multi-hot targets.

    Each item is a dict with:
        ``"wav"``    - waveform cropped/looped to ``Config.audio_len`` samples,
        ``"target"`` - float tensor of length ``num_classes``,
        ``"rating"`` - the row's value from the ``rating`` column.
    """

    def __init__(self, df, transform=None):
        """
        Args:
            df: Metadata frame with ``path``, ``primary_label``,
                ``secondary_labels`` and ``rating`` columns.
            transform: Optional callable applied to the waveform.
        """
        self.df = df # i.e. df[df["fold"] != 0]
        # Class mapping and class count come from the module-level names built in the data cell.
        self.bird2id = bird2id
        self.num_classes = num_classes
        self.secondary_coef = Config.secondary_coef
        self.transform = transform


    def __len__(self):
        """Number of rows (recordings) in the dataset."""
        return len(self.df)


    def prepare_target(self, primary_label, secondary_labels):
        """Build the multi-hot target vector for one recording.

        Args:
            primary_label: Class name of the main species, or ``'nocall'``.
            secondary_labels: Either the list of additional class names or its
                textual form as stored in the CSV (e.g. ``"['a', 'b']"``).

        Returns:
            Float32 tensor of length ``self.num_classes``: 1.0 at the primary
            class, ``self.secondary_coef`` at every known secondary class, and
            all zeros when the primary label is ``'nocall'``.

        Raises:
            KeyError: If ``primary_label`` is neither ``'nocall'`` nor a known class.
            ValueError / SyntaxError: If the textual list is not a valid Python literal.
        """
        # The string comes from a CSV file, so parse it as a literal only;
        # unlike eval(), literal_eval never executes code embedded in the data.
        if isinstance(secondary_labels, str):
            secondary_labels = ast.literal_eval(secondary_labels)
        target = np.zeros(self.num_classes, dtype=np.float32)
        if primary_label != 'nocall':
            target[self.bird2id[primary_label]] = 1.0
            for s in secondary_labels:
                # Empty names and species outside the class list are ignored.
                if s != "" and s in self.bird2id:
                    target[self.bird2id[s]] = self.secondary_coef
        target = torch.from_numpy(target).float()
        return target


    def __getitem__(self, idx):
        """Load, crop and label the recording at positional index ``idx``."""
        path = self.df["path"].iloc[idx]
        primary_label = self.df["primary_label"].iloc[idx]
        secondary_labels = self.df["secondary_labels"].iloc[idx]
        rating = self.df["rating"].iloc[idx]

        wav = read_wav(path)
        wav = crop_start_wav(wav, Config.audio_len)
        target = self.prepare_target(primary_label, secondary_labels)

        if self.transform is not None:
            wav = self.transform(wav)

        return {"wav": wav, "target": target, 'rating': rating}

## Model

In [14]:
class HydraTransform(nn.Module):
    """Random dilated-convolution feature transform used by the Hydra classifier.

    The module holds a fixed bank of random kernels of length 9, organised by
    dilation, by input variant (the raw series and its first difference) and
    by group. For every group it accumulates, per kernel, the responses at the
    time steps where that kernel is the strongest and counts the time steps
    where it is the weakest. It has no trainable parameters.
    """

    def __init__(self, input_length, k = 8, g = 64):
        """
        Args:
            input_length: Number of time steps of the series to be transformed;
                determines how many dilations are used.
            k: Number of kernels per group.
            g: Number of groups.
        """
        super().__init__()

        self.k = k # num kernels per group
        self.g = g # num groups

        max_exponent = np.log2((input_length - 1) / (9 - 1)) # kernel length = 9

        # Dilations are powers of two: 1, 2, 4, ...
        self.dilations = 2 ** torch.arange(int(max_exponent) + 1)
        self.num_dilations = len(self.dilations)

        self.paddings = torch.div((9 - 1) * self.dilations, 2, rounding_mode = "floor").int()

        # Up to two input variants (raw / differenced); the groups are shared between them.
        self.divisor = min(2, self.g)
        self.h = self.g // self.divisor

        # Random kernels, centred to zero mean and scaled to unit L1 norm along the kernel axis.
        W = torch.randn(self.num_dilations, self.divisor, self.k * self.h, 1, 9).to(Config.device)
        W = W - W.mean(-1, keepdims = True)
        W = W / W.abs().sum(-1, keepdims = True)
        # Registered as a buffer so the kernels are stored in state_dict / checkpoints
        # and follow the module on .to(device). As a plain attribute they were lost
        # on save, leaving a reloaded classifier paired with different random kernels.
        # NOTE: state dicts saved before this change have no "W" entry.
        self.register_buffer("W", W)


    # transform in batches of *batch_size*
    def batch(self, X, batch_size = 8):
        """Transform ``X`` in chunks of at most ``batch_size`` examples.

        Returns the same features as ``self(X)``, computed chunk by chunk along
        the first dimension.
        """
        num_examples = X.shape[0]
        if num_examples <= batch_size:
            return self(X)
        else:
            # torch.split already yields the chunks themselves, so each chunk is
            # transformed directly (indexing X with a chunk was a bug).
            Z = [self(chunk) for chunk in torch.split(X, batch_size)]
            return torch.cat(Z)


    def forward(self, X):
        """Compute the feature matrix for a batch of series.

        Args:
            X: Tensor with one channel, indexed as (example, channel, time);
                the kernels in ``self.W`` have a single input channel.

        Returns:
            Tensor of shape ``(num_examples, 2 * num_dilations * divisor * h * k)``.
        """
        num_examples = X.shape[0]

        if self.divisor > 1:
            # Second input variant: first-order difference along the time axis.
            diff_X = torch.diff(X)

        Z = []

        for dilation_index in range(self.num_dilations):

            d = self.dilations[dilation_index].item()
            p = self.paddings[dilation_index].item()

            for diff_index in range(self.divisor):

                # Convolve, then split the kernel axis into (group, kernel-in-group).
                _Z = F.conv1d(X if diff_index == 0 else diff_X, self.W[dilation_index, diff_index], dilation = d, padding = p) \
                      .view(num_examples, self.h, self.k, -1)

                # Per group and time step: which kernel responds most / least.
                max_values, max_indices = _Z.max(2)
                count_max = torch.zeros_like(_Z[:,:,:, 0])

                min_values, min_indices = _Z.min(2)
                count_min = torch.zeros_like(_Z[:,:,:, 0])

                # "Soft" count for the maximum (sum of winning responses),
                # plain count for the minimum (number of time steps won).
                count_max.scatter_add_(-1, max_indices, max_values)
                count_min.scatter_add_(-1, min_indices, torch.ones_like(min_values))

                Z.append(count_max)
                Z.append(count_min)

        Z = torch.cat(Z, 1).view(num_examples, -1)

        return Z
    

class Hydra(nn.Module):
    """Hydra feature transform followed by one linear classification layer."""

    def __init__(self, wav_input_length, num_classes, k = 8, g = 64):
        """
        Args:
            wav_input_length: Number of samples per input waveform.
            num_classes: Size of the output (one logit per class).
            k: Kernels per group, forwarded to ``HydraTransform``.
            g: Number of groups, forwarded to ``HydraTransform``.
        """
        super().__init__()
        transform = HydraTransform(wav_input_length, k, g)
        self.hydra_trans = transform.to(Config.device)

        # The first three axes of W are (dilations, input variants, kernels);
        # the transform emits two features for each such combination.
        kernel_axes = self.hydra_trans.W.shape[:3]
        num_features = 2 * kernel_axes[0] * kernel_axes[1] * kernel_axes[2]
        classifier = nn.Linear(num_features, num_classes)
        self.fc = classifier.to(Config.device)


    def forward(self, X):
        """Return class logits for a batch of waveforms."""
        features = self.hydra_trans(X)
        logits = self.fc(features)
        return logits



# dataset = BirdWavDataset(df[df["fold"] != 0])
# wav = dataset[0]["wav"]
# print("wav", wav.shape)
# hydra = Hydra(Config.audio_len, num_classes, k=8, g=64)
# output = hydra(wav.unsqueeze(0))
# print("output", output.shape)

## Loss

In [15]:
class FocalLossBCE(torch.nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and sigmoid focal loss.

    Args:
        alpha: ``alpha`` argument of the sigmoid focal loss.
        gamma: ``gamma`` argument of the sigmoid focal loss.
        reduction: Reduction mode, used for both loss terms.
        bce_weight: Multiplier of the BCE term.
        focal_weight: Multiplier of the focal term.
    """

    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.reduction = reduction
        self.bce_weight, self.focal_weight = bce_weight, focal_weight
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, logits, targets):
        """Return ``bce_weight * BCE + focal_weight * focal`` for the given logits and targets."""
        bce_term = self.bce(logits, targets)
        focal_term = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        return self.bce_weight * bce_term + self.focal_weight * focal_term


# Stand-alone loss instance with default settings. LightningHydra builds its own
# FocalLossBCE, and nothing else in the visible notebook reads this name.
criterion = FocalLossBCE()

## ⚡ PyTorch Lightning

In [16]:
class LightningHydra(L.LightningModule):
    """Lightning wrapper that trains the linear head of a ``Hydra`` model with ``FocalLossBCE``."""

    def __init__(self, wav_input_length, num_classes, k = 8, g = 64):
        super().__init__()
        self.hydra = Hydra(wav_input_length, num_classes, k, g)
        self.focal_loss = FocalLossBCE()


    def _batch_loss(self, batch):
        """Run the model on ``batch["wav"]`` and score the logits against ``batch["target"]``."""
        logits = self.hydra(batch["wav"])
        return self.focal_loss(logits, batch["target"])


    def training_step(self, batch, batch_idx):
        """Log the batch loss as ``train_loss`` and return it for back-propagation."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss)
        return loss


    def validation_step(self, batch, batch_idx):
        """Log the batch loss as ``val_loss``; nothing is returned."""
        self.log("val_loss", self._batch_loss(batch))


    def configure_optimizers(self):
        """Build the optimizer for the linear head and its cosine learning-rate schedule."""
        # Only the final linear layer's parameters are handed to the optimizer.
        head_optimizer = Config.optimizer(
            self.hydra.fc.parameters(),
            lr=Config.lr,
            weight_decay=Config.weight_decay,
        )
        cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(
            head_optimizer,
            T_max=Config.epochs,
            eta_min=Config.lr_min,
        )
        return {"optimizer": head_optimizer, "lr_scheduler": cosine_schedule}

In [19]:
def train_fold(fold_idx:int = 0, num_workers:int = 4, k:int = 4, g:int = 32):
    """Train a ``LightningHydra`` model with one fold held out for validation.

    Args:
        fold_idx: Rows whose ``fold`` column equals this value form the
            validation set; all other rows are used for training.
        num_workers: Worker processes per DataLoader.
        k: Kernels per group of the Hydra transform.
        g: Number of groups of the Hydra transform.
    """
    #create DataLoaders for train and val
    train_df = df[df["fold"] != fold_idx]
    val_df = df[df["fold"] == fold_idx]
    train_dataset = BirdWavDataset(train_df)
    val_dataset = BirdWavDataset(val_df)
    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=num_workers)

    model = LightningHydra(Config.audio_len, num_classes, k=k, g=g).to(Config.device)
    # Follow Config.device instead of always demanding a GPU, so the trainer
    # and the tensors created via Config.device end up on the same device.
    accelerator = "gpu" if str(Config.device).startswith("cuda") else "cpu"
    trainer = L.Trainer(
            accelerator=accelerator,
            max_epochs=Config.epochs,
            gradient_clip_val=Config.gradient_clip_val,
        )
    trainer.fit(model, train_loader, val_loader)

In [20]:
train_fold(0)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type         | Params
--------------------------------------------
0 | hydra      | Hydra        | 559 K 
1 | focal_loss | FocalLossBCE | 0     
--------------------------------------------
559 K     Trainable params
0         Non-trainable params
559 K     Total params
2.237     Total estimated model params size (MB)


Epoch 0:   6%|▌         | 34/612 [00:40<11:21,  0.85it/s, v_num=11]        