In [1]:
import wandb
# Authenticate with Weights & Biases: use the Kaggle-secret API key when running
# on Kaggle, otherwise fall back to an anonymous login.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("WANDB")
    # Login to wandb with the API key stored in Kaggle secrets
    wandb.login(key=api_key)
    # Authenticated login, so no anonymous mode
    anonymous = None
except Exception as exc:
    # Catch Exception rather than using a bare `except:` so Ctrl-C (KeyboardInterrupt)
    # and SystemExit still propagate. This covers the missing `kaggle_secrets` module
    # (not on Kaggle), a missing secret, or a rejected key.
    print(f"Kaggle WANDB secret unavailable ({exc!r}); falling back to anonymous wandb login.")
    anonymous = 'must'
    # Login to wandb anonymously and relogin if needed
    wandb.login(anonymous=anonymous, relogin=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/cedric/.netrc


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import pandas as pd
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import IPython.display as ipd
from datetime import datetime

from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAUROC
import audiomentations
from torch.utils.data import default_collate
from torchvision.transforms import v2

from src.audio_utils import play_audio, plot_specgram, plot_waveform
from src.data import AudioDataset, FrequencyMaskingAug, TimeMaskingAug
from src.data_utils import get_metadata, get_fold
from src.train_utils import FocalLoss, get_cosine_schedule_with_warmup, wandb_init
from src.models import BasicClassifier

import ast
import wandb
import yaml

  from .autonotebook import tqdm as notebook_tqdm


### Config

In [2]:
class Config:
    """Hyperparameters and run metadata for this training notebook.

    Attributes are read directly as ``Config.<name>`` by the later cells. The
    training-loop cell also writes every attribute whose name does not contain
    ``'__'`` into the run's text log, in definition order.
    """
    # --- Audio / spectrogram front-end ---
    duration = 10  # clip length in seconds (the run label below says t=10s)
    sample_rate = 32000
    target_length = 384  # spectrogram time frames the hop length below is sized for
    n_mels = 128
    n_fft = 2028  # NOTE(review): 2028 is unusual, possibly a typo for 2048; confirm before changing (alters features)
    window = 2028  # currently equal to n_fft
    audio_len = duration*sample_rate  # samples per clip
    # Hop chosen so that audio_len samples span ~target_length STFT frames
    # (320000 // 383 = 835). The exact frame count depends on the STFT padding used
    # inside AudioDataset; TODO confirm.
    hop_length = audio_len // (target_length-1)
    fmin = 20
    fmax = 16000  # equals sample_rate / 2
    top_db = 80

    # --- Data / model ---
    n_classes = 182
    batch_size = 24
    model_name = 'efficientnet_v2_s'
    n_folds = 5  # fold count passed to get_metadata(...)
    upsample_thr = 50  # forwarded to get_fold(up_thr=...); presumably a per-class upsampling threshold, confirm
    use_class_weights = True  # if True, get_fold's class_weights become the loss pos_weight

    # --- Input standardization (forwarded to AudioDataset; presumably applied only when standardize is True) ---
    standardize = False
    dataset_mean = [-16.8828]
    dataset_std = [12.4019]

    # --- Augmentation & loss ---
    data_aug = True  # master switch for both waveform and spectrogram augmentations
    cutmix_mixup = True  # enables batch-level CutMix/MixUp in the training collate_fn
    loss = 'bce'  # 'bce' or 'crossentropy'
    secondary_labels_weight = 0.3  # forwarded to AudioDataset
    use_focal = True
    focal_gamma = 2
    focal_lambda = 1  # weight of the focal term added to the main loss
    label_smoothing = 0.05  # only used by the 'crossentropy' loss

    # --- Optimization / LR schedule ---
    num_epochs = 10
    warmup_epochs = 0.5  # converted to optimizer steps via len(train_loader)
    lr = 1e-3
    start_lr = 0.01 # relative to lr
    final_lr = 0.01  # presumably also relative to lr like start_lr; confirm in get_cosine_schedule_with_warmup
    weight_decay = 0.0001

    # --- Logging / Weights & Biases ---
    wandb = False  # when True the training cell calls wandb_init and logs metrics
    competition   = 'birdclef-2024'
    _wandb_kernel = 'cvincent13'
    date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # evaluated once, when this class body runs
    # NOTE(review): the fold is hard-coded as 0 here and does not follow the `fold` variable set in the dataset cell.
    run_name = f"{date}_fold-{0}_dim-{n_mels}x{target_length}_model-{model_name}"
    # NOTE(review): the group label says EfficientNetB0 but model_name is 'efficientnet_v2_s'; confirm intended grouping.
    wandb_group = 'EfficientNetB0|FSR|t=10s|128x384|up_thr=50|cv_filter'

# Metadata for the whole dataset. Fold assignment is presumably done inside get_metadata; confirm.
metadata = get_metadata(Config.n_folds)

### Dataset

In [3]:
# Fold index used for this run; passed to get_fold here and to wandb_init in the training cell.
# NOTE(review): Config.run_name hard-codes "fold-0", so changing this does not change the run/checkpoint name.
fold = 0
# class_weights becomes the loss pos_weight when Config.use_class_weights is True.
# up_thr presumably controls upsampling of rare classes; confirm in get_fold.
train_df, valid_df, class_weights = get_fold(metadata, fold, up_thr=Config.upsample_thr)

Num Train: 22045, 182 classes |Num Valid: 4892, 182 classes


In [8]:
# Data transforms and augmentations (training split only; validation is built without them).
waveform_transforms = audiomentations.Compose([
    audiomentations.Shift(min_shift=-0.5, max_shift=0.5, p=0.5),
    audiomentations.SevenBandParametricEQ(min_gain_db=-12., max_gain_db=12., p=0.5),
    audiomentations.AirAbsorption(min_temperature=10, max_temperature=20, min_humidity=30, max_humidity=90,
                                  min_distance=10, max_distance=100, p=1.),

    # One gain variant is picked per clip (OneOf with p=1).
    audiomentations.OneOf([
        # NOTE(review): gain can push samples outside [-1, 1]. The author saw no issue with this;
        # confirm that nothing downstream (e.g. inside AudioDataset) assumes a bounded waveform.
        audiomentations.Gain(min_gain_db=-6., max_gain_db=6., p=1),
        audiomentations.GainTransition(min_gain_db=-12., max_gain_db=3., p=1)
    ], p=1.),

    # One noise variant is picked per clip (OneOf with p=1).
    audiomentations.OneOf([
        audiomentations.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.),
        audiomentations.AddGaussianSNR(min_snr_db=5., max_snr_db=40., p=1.),
        audiomentations.AddColorNoise(min_snr_db=5., max_snr_db=40., min_f_decay=-3.01, max_f_decay=-3.01, p=1.)
    ], p=1.),

    # Disabled: both need `unlabeled_dir` (a folder of noise clips), which is not defined in the cells above.
    #audiomentations.AddShortNoises(sounds_path=unlabeled_dir, min_snr_db=3., max_snr_db=30.,
    #                           noise_rms='relative_to_whole_input',
    #                           min_time_between_sounds=2., max_time_between_sounds=8.,
    #                           noise_transform=audiomentations.PolarityInversion(), p=0.5),
    #audiomentations.AddBackgroundNoise(sounds_path=unlabeled_dir, min_snr_db=3., max_snr_db=30.,
    #                               noise_transform=audiomentations.PolarityInversion(), p=0.5),

    audiomentations.LowPassFilter(min_cutoff_freq=750., max_cutoff_freq=7500., min_rolloff=12, max_rolloff=24, p=0.8),
    audiomentations.PitchShift(min_semitones=-2.5, max_semitones=2.5, p=0.3)
])

spec_transforms = nn.Sequential(
    FrequencyMaskingAug(0.3, 0.1, Config.n_mels, n_masks=3, mask_mode='mean'),
    TimeMaskingAug(0.3, 0.1, Config.target_length, n_masks=3, mask_mode='mean'),
)

# Config.data_aug is the master switch for both augmentation stages.
if not Config.data_aug:
    waveform_transforms = None
    spec_transforms = None

# Spectrogram / label settings shared by the train and validation datasets. Defining
# them once guarantees both splits are featurized identically; previously the two
# constructor calls duplicated every keyword argument.
dataset_kwargs = dict(
    n_classes=Config.n_classes,
    duration=Config.duration,
    sample_rate=Config.sample_rate,
    target_length=Config.target_length,
    n_mels=Config.n_mels,
    n_fft=Config.n_fft,
    window=Config.window,
    hop_length=Config.hop_length,
    fmin=Config.fmin,
    fmax=Config.fmax,
    top_db=Config.top_db,
    standardize=Config.standardize,
    mean=Config.dataset_mean,
    std=Config.dataset_std,
    loss=Config.loss,
    secondary_labels_weight=Config.secondary_labels_weight,
)

train_dataset = AudioDataset(
    train_df,
    waveform_transforms=waveform_transforms,
    spec_transforms=spec_transforms,
    **dataset_kwargs,
)
# Validation never augments.
val_dataset = AudioDataset(
    valid_df,
    waveform_transforms=None,
    spec_transforms=None,
    **dataset_kwargs,
)

### Training

In [12]:
# Batch-level mixing: applied to a batch with probability 0.7. When it fires, CutMix is
# chosen with probability 0.65 and MixUp with probability 0.35.
cutmix_or_mixup = v2.RandomApply(
    [
        v2.RandomChoice(
            [
                v2.CutMix(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=(Config.loss == 'bce')),
                v2.MixUp(num_classes=Config.n_classes, alpha=0.5, one_hot_labels=(Config.loss == 'bce')),
            ],
            p=[0.65, 0.35],
        )
    ],
    p=0.7,
)


def mix_collate_fn(batch):
    """Stack samples with ``default_collate``, then run the batch-level mixer on them.

    Args:
        batch: list of dataset items; presumably ``(spectrogram, label)`` pairs,
            matching how the training loop unpacks each batch.

    Returns:
        The output of ``cutmix_or_mixup`` applied to the collated fields.
    """
    collated = default_collate(batch)
    return cutmix_or_mixup(*collated)


# Only the training loader mixes; validation keeps DataLoader's default collation.
if Config.cutmix_mixup:
    collate_fn = mix_collate_fn
else:
    collate_fn = None

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.batch_size,
    shuffle=True,
    num_workers=6,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=Config.batch_size,
    shuffle=False,
    num_workers=6,
)

In [13]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BasicClassifier(Config.n_classes, Config.model_name).to(device)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=Config.weight_decay, lr=Config.lr)

# scheduler.step() is called once per batch in the training loop, so warmup and
# total lengths are expressed in steps: epochs * steps-per-epoch. The warmup may be
# fractional (e.g. half an epoch).
spe = len(train_loader)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=spe * Config.warmup_epochs,
    num_training_steps=spe * Config.num_epochs,
    start_lr=Config.start_lr,
    final_lr=Config.final_lr,
)

# Per-class weights from get_fold. They are cast to float32 to match the model's
# logits; F.cross_entropy rejects a weight tensor whose dtype differs from its input.
pos_weight = (
    torch.tensor(class_weights, dtype=torch.float32, device=device)
    if Config.use_class_weights else None
)
if Config.loss == 'crossentropy':
    # CrossEntropyLoss has no `pos_weight` argument (passing it raised TypeError);
    # its per-class weighting parameter is `weight`.
    criterion = nn.CrossEntropyLoss(weight=pos_weight, label_smoothing=Config.label_smoothing)
elif Config.loss == 'bce':
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, weight=None)
else:
    raise ValueError(f"Unknown Config.loss {Config.loss!r}; expected 'crossentropy' or 'bce'")

# Always bind focal_criterion (None when disabled) so later cells never hit a NameError.
focal_criterion = (
    FocalLoss(gamma=Config.focal_gamma, pos_weight=pos_weight) if Config.use_focal else None
)

metric = MulticlassAUROC(num_classes=Config.n_classes, average='macro', thresholds=5)

### Training loop

In [14]:
# Training loop. Each epoch makes one pass over train_loader and one over val_loader,
# then writes the checkpoint, a text log and (optionally) W&B metrics.
if Config.wandb:
    run = wandb_init(fold, Config)


def compute_loss(out, labels):
    """Main criterion, plus the focal term weighted by focal_lambda when Config.use_focal is set."""
    loss = criterion(out, labels)
    if Config.use_focal:
        loss = loss + Config.focal_lambda * focal_criterion(out, labels)
    return loss


def hard_targets(labels):
    """Class indices for accuracy: argmax of label vectors, or the labels unchanged if already 1-D."""
    return labels.argmax(1) if labels.ndim > 1 else labels


def train_one_epoch():
    """Run one optimization epoch over train_loader; return the sample-weighted mean loss."""
    model.train()
    total_loss, n_seen = 0.0, 0
    train_iter = tqdm(train_loader)
    for batch, labels in train_iter:
        optimizer.zero_grad()
        batch = batch.to(device)
        labels = labels.to(device)

        out = model(batch)
        loss = compute_loss(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the LR schedule is defined in optimizer steps, not epochs

        batch_loss = loss.item()
        train_iter.set_description(desc=f'train loss: {batch_loss:.3f}')
        # Weight by batch size so the shorter final batch does not skew the epoch mean.
        total_loss += batch_loss * len(labels)
        n_seen += len(labels)
    return total_loss / max(n_seen, 1)


@torch.no_grad()
def evaluate():
    """Evaluate on val_loader; return (sample-weighted mean loss, top-1 accuracy) as floats."""
    model.eval()
    total_loss, n_correct, n_seen = 0.0, 0, 0
    val_iter = tqdm(val_loader)
    for batch, labels in val_iter:
        batch = batch.to(device)
        labels = labels.to(device)

        out = model(batch)
        loss = compute_loss(out, labels)

        n = len(labels)
        total_loss += loss.item() * n
        # .item() keeps accuracy a Python number rather than a device tensor.
        n_correct += (out.argmax(1) == hard_targets(labels)).sum().item()
        n_seen += n
        val_iter.set_description(desc=f'val loss: {loss.item():.3f}')
    n_seen = max(n_seen, 1)
    return total_loss / n_seen, n_correct / n_seen


def write_logs(path):
    """Rewrite the text log: one line per finished epoch, then every Config attribute without '__'."""
    with open(path, "w") as f:
        history = zip(train_losses, val_losses, val_metrics['Accuracy'])
        for i, (tr_loss, va_loss, va_acc) in enumerate(history, start=1):
            f.write(f"Epoch {i}: Train Loss = {tr_loss:.3f} | "
                    f"Val Loss = {va_loss:.3f}, Val Accuracy = {va_acc:.3f}\n")
        f.write("CONFIG:")
        for k, v in vars(Config).items():
            if '__' not in k:
                f.write(f"\n{k}: {v}")


save_dir = f"checkpoints/{Config.run_name}"
# torch.save does not create missing parent directories.
os.makedirs(save_dir, exist_ok=True)
train_losses = []
val_losses = []
# 'AUC' stays empty until AUROC evaluation (the `metric` object) is wired back in. TODO.
val_metrics = {'AUC': [], 'Accuracy': []}

for epoch in range(Config.num_epochs):
    train_loss = train_one_epoch()
    train_losses.append(train_loss)

    val_loss, val_accuracy = evaluate()
    val_losses.append(val_loss)
    val_metrics['Accuracy'].append(val_accuracy)

    save_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch+1,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "val_metrics": val_metrics
    }
    torch.save(save_dict, os.path.join(save_dir, "checkpoint.pth"))
    write_logs(os.path.join(save_dir, "logs.txt"))

    if Config.wandb:
        wandb.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            # get_last_lr() returns one value per param group; the optimizer has a single group.
            "lr": scheduler.get_last_lr()[0],
        })

    print(f'Epoch {epoch+1}: Train Loss = {train_loss:.3f} | '
          f'Val Loss = {val_loss:.3f}, Val Accuracy = {val_accuracy:.3f}')


if Config.wandb:
    wandb.run.finish()
    display(ipd.IFrame(run.url, width=1080, height=720))

  0%|          | 0/919 [00:00<?, ?it/s]

train loss: 0.740:   1%|          | 6/919 [00:09<11:47,  1.29it/s]  