In [None]:
# %%
# !pip install lightning timm librosa albumentations torchaudio==2.2.2 audiomentations


In [1]:
# %%
import gc
import logging
import os
import random
import warnings
from datetime import datetime
from pprint import pformat

import albumentations
import librosa
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch_audiomentations as audiomentations
import torchaudio
import torchmetrics
import torchvision
from joblib import Parallel, delayed
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from tqdm.notebook import tqdm

sns.set(
    rc={
        "figure.figsize": (8, 4),
        "figure.dpi": 240,
    }
)
sns.set_style("darkgrid", {"axes.grid": False})
sns.set_context("paper", font_scale=0.6)

torch.set_float32_matmul_precision("high")
warnings.simplefilter("ignore")


In [2]:
# %%
class cfg:
    experiment_name = "foundation"
    data_path = "../data"

    debug_run = False  # run on a small sample
    experiment_run = False  # run on a stratified data sample
    production_run = True  # run on all data

    normalize_waveform = True
    sample_rate = 32000
    n_fft = 2048
    hop_length = 512
    window_length = None
    melspec_hres = 128
    melspec_wres = 312
    freq_min = 40
    freq_max = 15000
    log_scale_power = 2
    max_decibels = 100
    frame_duration = 5
    frame_rate = sample_rate / hop_length

    # vit_b0 = "efficientvit_b0.r224_in1k"
    vit_b1 = "efficientvit_b1.r224_in1k"
    # vit_b1 = "efficientvit_b1.r288_in1k"
    # vit_b2 = "efficientvit_b2.r224_in1k"
    # effnet_b0 = "tf_efficientnetv2_b0.in1k"
    # effnet_b1 = "tf_efficientnetv2_b1.in1k"
    # vit_m0 = "efficientvit_m0.r224_in1k"
    # vit_m1 = "efficientvit_m1.r224_in1k"
    backbone = vit_b1

    num_classes = 182
    add_secondary_labels = False
    label_smoothing = 0.05
    weighted_sampling = False
    sample_weight_factor = 0.25
    waveform_augmentation = False
    melspec_augmentation = True
    melspec_mixup_prob = 0.5
    melspec_mixup = False
    melspec_supermixup = False

    accelerator = "gpu"
    precision = "bf16-mixed"
    n_workers = os.cpu_count() - 2

    n_epochs = 50
    batch_size = 128
    val_ratio = 0.25
    patience = 10

    lr_min = 1e-6
    lr_max = 5e-4
    weight_decay = 1e-6
    # fused_adamw = True

    focal_alpha = 0.25
    focal_gamma = 2.0
    focal_weight = 1.0
    bce_weight = 1.0

    timestamp = datetime.now().replace(microsecond=0)
    run_tag = f"{timestamp}_{backbone}_{experiment_name}_val_{val_ratio}_lr_{lr_max}_decay_{weight_decay}"

    if debug_run:
        run_tag = f"{timestamp}_{backbone}_debug"
        # accelerator = "cpu"
        # n_epochs = 1
        # batch_size = 32
        # fused_adamw = False
        # num_classes = 10


In [3]:
# %%
def define_logger():
    handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"../logs/{cfg.run_tag}.log"),
    ]

    if cfg.debug_run:
        handlers = [logging.StreamHandler()]

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.INFO,
        format=" %(asctime)s [%(threadName)s] 🐦‍🔥 %(message)s",
        handlers=handlers,
        force=True,  # reconfigure root logger, in case of rerunning -> ensures new file
    )

    return logger


def get_config(cfg) -> None:
    cfg_dictionary = {
        key: value
        for key, value in cfg.__dict__.items()
        if not key.startswith("__") and not callable(key)
    }
    logger.info(f"{'—' * 80}")
    logger.info(f"Config: \n{pformat(cfg_dictionary, indent=1)}")
    return cfg_dictionary


def load_metadata(data_path: str) -> pd.DataFrame:
    logger.info(f"Loading prepared dataframes from {data_path}")
    model_input_df = pd.read_csv(f"{data_path}/model_input_df.csv")
    sample_submission = pd.read_csv(f"{data_path}/sample_submission.csv")

    if cfg.debug_run:
        logger.info("Running debug: sampling data to 10 species and 250 samples")
        top_10_labels = model_input_df["primary_label"].value_counts()[0:10].index
        model_input_df = model_input_df[
            model_input_df["primary_label"].isin(top_10_labels)
        ]
        model_input_df = model_input_df.sample(1000).reset_index(drop=True)

    elif cfg.experiment_run:
        logger.info("Running experiment: sampling data")
        model_input_df = model_input_df.sample(frac=0.1).reset_index(drop=True)

    elif cfg.production_run:
        logger.info("Running production: full data")
        model_input_df = model_input_df.sample(frac=1).reset_index(drop=True)

    logger.info(f"Dataframe shape: {model_input_df.shape}")

    return model_input_df, sample_submission


def read_waveform(filename: str) -> np.ndarray:
    filepath = f"{cfg.data_path}/train_windows_nv_c10/{filename}"
    waveform, _ = librosa.load(filepath, sr=cfg.sample_rate)
    return waveform


def read_waveforms_parallel(model_input_df: pd.DataFrame):
    logger.info("Parallel Loading waveforms")
    waveforms = Parallel(n_jobs=cfg.n_workers, prefer="threads")(
        delayed(read_waveform)(filename)
        for filename in tqdm(model_input_df["window_filename"], desc="Loading waves")
    )
    logger.info("Finished loadeding waveforms")
    return waveforms


def create_label_map(submission_df: pd.DataFrame) -> dict:
    logging.info("Creating label mappings")
    cfg.labels = submission_df.columns[1:]
    cfg.num_classes = len(cfg.labels)
    class_to_label_map = dict(zip(cfg.labels, np.arange(cfg.num_classes)))

    return class_to_label_map


def pad_or_crop_waveforms(waveforms: list, pad_method: str = "repeat") -> list:
    logging.info("Padding or cropping waveforms to desired duration")
    desired_length = cfg.sample_rate * cfg.frame_duration

    def _pad_or_crop(waveform: np.ndarray) -> np.ndarray:
        length = len(waveform)

        while length < desired_length:  # repeat if waveform too small
            repeat_length = desired_length - length
            padding_array = waveform[:repeat_length]
            if pad_method != "repeat":
                padding_array = np.zeros(shape=waveform[:repeat_length].shape)
            waveform = np.concatenate([waveform, padding_array])
            length = len(waveform)

        if length > desired_length:  # crop if waveform is too big
            offset = np.random.randint(0, length - desired_length)
            waveform = waveform[offset : offset + desired_length]

        return waveform

    waveforms = [_pad_or_crop(wave) for wave in tqdm(waveforms, desc="Padding waves")]

    return waveforms


def add_sample_weights(
    model_input_df: pd.DataFrame, weight_factor: float = cfg.sample_weight_factor
) -> pd.DataFrame:
    sample_weights = round(
        (
            model_input_df["primary_label"].value_counts()
            / model_input_df["primary_label"].value_counts().sum()
        )
        ** (-weight_factor)
    )
    sample_weights = pd.DataFrame(
        {
            "primary_label": sample_weights.index,
            "sample_weight": sample_weights.values.astype(int),
        }
    )
    model_input_df = model_input_df.merge(
        sample_weights, on="primary_label", how="left"
    )
    return model_input_df


In [4]:
# %%
class BirdDataset(Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        waveforms: list,
        add_secondary_labels: bool = cfg.add_secondary_labels,
        augmentation: list = None,
    ):
        self.df = df
        self.waveforms = waveforms
        self.num_classes = cfg.num_classes
        self.class_to_label_map = class_to_label_map
        self.add_secondary_labels = add_secondary_labels
        self.waveform_augmentation = cfg.waveform_augmentation
        self.melspec_augmentation = cfg.melspec_augmentation
        self.augmentation = augmentation

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=cfg.sample_rate,
            n_mels=cfg.melspec_hres,
            f_min=cfg.freq_min,
            f_max=cfg.freq_max,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            normalized=cfg.normalize_waveform,
            center=True,
            pad_mode="reflect",
            norm="slaney",
            mel_scale="slaney",
        )
        self.db_transform = torchaudio.transforms.AmplitudeToDB(
            stype="power", top_db=cfg.max_decibels
        )

    def create_target(
        self,
        primary_label: str,
        secondary_labels: list,
        secondary_label_weight: float = 1,
    ) -> torch.tensor:
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        # primary_target = torch.tensor(0, dtype=torch.int64)

        if primary_label != "nocall":
            primary_label = self.class_to_label_map[primary_label]
            target[primary_label] = 1
            primary_target = torch.tensor(primary_label, dtype=torch.int64)

            if self.add_secondary_labels:
                secondary_labels = eval(secondary_labels)
                for label in secondary_labels:
                    if label != "" and label in self.class_to_label_map.keys():
                        target[self.class_to_label_map[label]] = secondary_label_weight

        binary_target = target.to(torch.int64)

        return target, binary_target, primary_target

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        waveform = self.waveforms[idx]
        primary_label = self.df.iloc[idx]["primary_label"]
        secondary_labels = self.df.iloc[idx]["secondary_labels"]

        waveform = torch.tensor(waveform, dtype=torch.float32)

        target, binary_target, primary_target = self.create_target(
            primary_label=primary_label, secondary_labels=secondary_labels
        )

        melspec = self.db_transform(self.mel_transform(waveform)).to(torch.uint8)
        melspec = melspec.expand(3, -1, -1).permute(1, 2, 0).numpy()

        if self.melspec_augmentation is not None:
            melspec = self.augmentation(image=melspec)["image"]

        return melspec, target, binary_target, primary_target


In [5]:
# %%
class GeM(torch.nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = torch.nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        out = torch.nn.functional.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))
        ).pow(1.0 / self.p)
        return out


class FocalLoss(torch.nn.Module):
    def __init__(
        self,
        alpha: float = cfg.focal_alpha,
        gamma: float = cfg.focal_gamma,
        reduction: str = "mean",
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, x, y):
        loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=x,
            targets=y,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        return loss


class FocalLossBCE(torch.nn.Module):
    def __init__(
        self,
        alpha: float = cfg.focal_alpha,
        gamma: float = cfg.focal_gamma,
        reduction: str = "mean",
        bce_weight: float = cfg.bce_weight,
        focal_weight: float = cfg.focal_weight,
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight

    def forward(self, inputs, targets):
        focal_loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=inputs,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        bce_loss = self.bce(inputs, targets)
        combined_loss = self.bce_weight * bce_loss + self.focal_weight * focal_loss
        return combined_loss


class EfficientViT(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.vit = timm.create_model(
            cfg.backbone,
            pretrained=True,
            num_classes=cfg.num_classes,
        )

        self.loss_function = FocalLossBCE()
        # self.loss_function = FocalLoss()
        # self.loss_function = torch.nn.BCEWithLogitsLoss(reduction="mean")
        # self.loss_function = torch.nn.CrossEntropyLoss(
        #     label_smoothing=cfg.label_smoothing
        # )

        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=cfg.num_classes, top_k=1
        )
        self.auroc = torchmetrics.AUROC(
            task="multilabel",
            num_labels=cfg.num_classes,
            average="macro",
        )
        self.f1_macro = torchmetrics.F1Score(
            task="multilabel",
            num_labels=cfg.num_classes,
            average="macro",
            threshold=0.5,
        )
        self.f1_weighted = torchmetrics.F1Score(
            task="multilabel",
            num_labels=cfg.num_classes,
            average="weighted",
            threshold=0.5,
        )
        self.lrap = torchmetrics.classification.MultilabelRankingAveragePrecision(
            num_labels=cfg.num_classes,
        )


    def mixup_data(self, x, y):
        """
        Returns mixed inputs, pairs of targets, and lambda
        reference: mixup: Beyond Empirical Risk Minimization
        """
        lam = np.random.choice([0.2, 0.3, 0.4, 0.5])

        batch_size = x.size()[0]
        index = torch.randperm(batch_size)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_loss_function(self, pred, y_a, y_b, lam):
        loss = lam * self.loss_function(pred, y_a) + (
            1 - lam
        ) * self.loss_function(pred, y_b)
        return loss


    def forward(self, x):
        # x = x.float() / 255
        # x = x.expand(-1, 3, 128, -1)  # go from HxW → 3xHxW
        x = x.permute(0, 3, 1, 2)
        out = self.vit(x)
        return out

    def training_step(self, batch, batch_idx):
        x, y, y_binary, y_primary = batch

        y_pred = self(x)
        loss = self.loss_function(y_pred, y)
        y_pred = y_pred.sigmoid()

        # with prob 0.5 do mel spectogram mixup
        if np.random.choice([False, True]):
            y_pred = self(x)
            loss = self.loss_function(y_pred, y)
            y_pred = y_pred.sigmoid()
        else:
            mixed_x, y_a, y_b, lam = self.mixup_data(x, y)
            y_pred = self(mixed_x)
            loss = self.mixup_loss_function(y_pred, y_a, y_b, lam)
            y_pred = y_pred.sigmoid()

        train_accuracy = self.accuracy(y_pred, y_primary)
        train_f1_macro = self.f1_macro(y_pred, y_binary)
        train_lrap = self.lrap(y_pred, y_binary)

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_acc",
            train_accuracy_top1,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_f1_macro",
            train_f1_macro,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_lrap",
            train_lrap,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        x_val, y_val, y_binary_val, y_primary_val = batch
        y_pred = self(x_val)
        val_loss = self.loss_function(y_pred, y_val)

        y_pred = y_pred.sigmoid()
        val_accuracy_top1 = self.accuracy(y_pred, y_primary_val)
        val_auroc = self.auroc(y_pred, y_binary_val)
        # val_f1_weighted = self.f1_weighted(y_pred, y_binary_val)
        val_f1_macro = self.f1_macro(y_pred, y_binary_val)
        val_lrap = self.lrap(y_pred, y_binary_val)

        self.log(
            "val_loss",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_acc",
            val_accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_auroc",
            val_auroc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        # self.log(
        #     "val_f1_weighted",
        #     val_f1_weighted,
        #     on_step=False,
        #     on_epoch=True,
        #     prog_bar=True,
        #     logger=True,
        # )
        self.log(
            "val_f1_macro",
            val_f1_macro,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_lrap",
            val_lrap,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return val_loss

    def on_train_epoch_end(self):
        metrics = self.trainer.progress_bar_callback.get_metrics(trainer, model)
        metrics.pop("v_num", None)
        metrics.pop("train_loss_step", None)
        for key, value in metrics.items():
            metrics[key] = round(value, 5)
        logger.info(f"Epoch {self.trainer.current_epoch}: {metrics}")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            params=self.parameters(),
            lr=cfg.lr_max,
            weight_decay=cfg.weight_decay,
            # fused=cfg.fused_adamw,
        )
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=cfg.n_epochs, T_mult=1, eta_min=cfg.lr_min, last_epoch=-1
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "monitor": "val_loss",
                "frequency": 1,
            },
        }


In [6]:
# %%
if __name__ == "__main__":
    logger = define_logger()
    config_dictionary = get_config(cfg)

    csv_logger = None
    if not cfg.debug_run:
        csv_logger = L.pytorch.loggers.CSVLogger(save_dir="../logs/")
        csv_logger.log_hyperparams(config_dictionary)

    model_input_df, sample_submission = load_metadata(data_path=cfg.data_path)
    model_input_df = add_sample_weights(model_input_df)
    class_to_label_map = create_label_map(submission_df=sample_submission)


 2024-05-27 12:09:02,566 [MainThread] 🐦‍🔥 ————————————————————————————————————————————————————————————————————————————————
 2024-05-27 12:09:02,568 [MainThread] 🐦‍🔥 Config: 
{'accelerator': 'gpu',
 'add_secondary_labels': False,
 'backbone': 'efficientvit_b1.r224_in1k',
 'batch_size': 128,
 'bce_weight': 1.0,
 'data_path': '../data',
 'debug_run': False,
 'experiment_name': 'foundation',
 'experiment_run': False,
 'focal_alpha': 0.25,
 'focal_gamma': 2.0,
 'focal_weight': 1.0,
 'frame_duration': 5,
 'frame_rate': 62.5,
 'freq_max': 15000,
 'freq_min': 40,
 'hop_length': 512,
 'label_smoothing': 0.05,
 'log_scale_power': 2,
 'lr_max': 0.0005,
 'lr_min': 1e-06,
 'max_decibels': 100,
 'melspec_augmentation': True,
 'melspec_hres': 128,
 'melspec_mixup': False,
 'melspec_mixup_prob': 0.5,
 'melspec_supermixup': False,
 'melspec_wres': 312,
 'n_epochs': 50,
 'n_fft': 2048,
 'n_workers': 30,
 'normalize_waveform': True,
 'num_classes': 182,
 'patience': 10,
 'precision': 'bf16-mixed',
 'prod

In [None]:
    # %%
    waveforms = read_waveforms_parallel(model_input_df=model_input_df)

 2024-05-27 12:09:06,270 [MainThread] 🐦‍🔥 Parallel Loading waveforms


Loading waves:   0%|          | 0/88247 [00:00<?, ?it/s]

 2024-05-27 12:23:58,439 [MainThread] 🐦‍🔥 Finished loadeding waveforms


In [8]:
    # %%
    waveforms = pad_or_crop_waveforms(waveforms=waveforms)

 2024-05-27 12:24:05,957 [MainThread] 🐦‍🔥 Padding or cropping waveforms to desired duration


Padding waves:   0%|          | 0/88247 [00:00<?, ?it/s]

In [None]:
    # %%
    train_augmentation = albumentations.Compose(
        [
            albumentations.AdvancedBlur(p=0.25),
            albumentations.GaussNoise(p=0.25),
            albumentations.ImageCompression(
                quality_lower=75, quality_upper=100, p=0.25
            ),
            albumentations.CoarseDropout(
                max_holes=1, max_height=64, max_width=64, p=0.25
            ),
            albumentations.XYMasking(
                p=0.25,
                num_masks_x=(1, 2),
                num_masks_y=(1, 2),
                mask_x_length=(5, 25),
                mask_y_length=(5, 25),
            ),
            albumentations.CLAHE(p=0.15),
            albumentations.Downscale(
                scale_min=0.5, scale_max=0.9, interpolation=4, p=0.15
            ),
            albumentations.Morphological(p=0.15, scale=(1, 3), operation="erosion"),
            albumentations.Normalize(p=1),
        ]
    )
    val_augmentation = albumentations.Compose([albumentations.Normalize(p=1)])

    # grouped split on sample index to keep different windows from the same sample
    # together if splitting randomly this can be considered as a form of leakage
    # validating on a windowed waveform while windows of the same waveform were used for
    # training is easier than classifying a waveform from a different sample, which is
    # the case in practice
    logger.info(f"Splitting {len(waveforms)} waveforms into train/val: {cfg.val_ratio}")
    n_splits = int(round(1 / cfg.val_ratio))
    kfold = StratifiedGroupKFold(n_splits=n_splits, shuffle=True)
    for fold_index, (train_index, val_index) in enumerate(
        kfold.split(
            X=model_input_df,
            y=model_input_df["primary_label"],
            groups=model_input_df["sample_index"],
        )
    ):

        train_df = model_input_df.iloc[train_index]
        val_df = model_input_df.iloc[val_index]

        train_waveforms = [waveforms[i] for i in train_index]
        val_waveforms = [waveforms[i] for i in val_index]

        train_dataset = BirdDataset(
            df=train_df, waveforms=train_waveforms, augmentation=train_augmentation
        )
        val_dataset = BirdDataset(
            df=val_df, waveforms=val_waveforms, augmentation=val_augmentation
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            drop_last=True,
            num_workers=cfg.n_workers,
            persistent_workers=True,
            pin_memory=True,
        )

        if cfg.weighted_sampling:
            logger.info(
                f"Defining weighted sampling with  factor: {cfg.sample_weight_factor}"
            )
            sample_weight = train_df["sample_weight"].values
            sample_weight = torch.from_numpy(sample_weight)

            weighted_sampler = WeightedRandomSampler(
                sample_weight.type("torch.DoubleTensor"),
                len(sample_weight),
                replacement=True,
            )

            train_dataloader = DataLoader(
                train_dataset,
                batch_size=cfg.batch_size,
                sampler=weighted_sampler,
                drop_last=True,
                num_workers=cfg.n_workers,
                persistent_workers=True,
                pin_memory=True,
            )

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=cfg.n_workers,
            persistent_workers=True,
            pin_memory=True,
        )

        logger.info("Dataloaders ready to go brrr")

        progress_bar = TQDMProgressBar(process_position=0)
        early_stopping = EarlyStopping(
            monitor="val_f1_macro",
            min_delta=0.005,
            patience=cfg.patience,
            verbose=True,
            mode="max",
        )
        model_checkpoint = ModelCheckpoint(
            monitor="val_f1_macro",
            every_n_epochs=1,
            mode="max",
            save_on_train_epoch_end=True,
            auto_insert_metric_name=True,
            filename=f"{cfg.run_tag}"
            + f"_fold_{fold_index}_"
            + "{epoch}-{val_lrap:.3f}-{val_acc2:.3f}-{val_acc2:.3f}-{val_f1_macro:.3}",
        )

        os.environ["PJRT_DEVICE"] = "GPU"  # fix for G Cloud to avoid XLA/autocast clash
        model = EfficientViT()
        trainer = L.Trainer(
            fast_dev_run=False,
            enable_model_summary=False,
            max_epochs=cfg.n_epochs,
            accelerator=cfg.accelerator,
            precision=cfg.precision,
            callbacks=[progress_bar, early_stopping, model_checkpoint],
            logger=csv_logger,3
            log_every_n_steps=10,
        )

        logger.info(f"\nStart training fold {fold_index}")
        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
            ckpt_path=None,
        )

        logger.info(f"Finished training fold {fold_index}")
        if not cfg.debug_run and trainer.current_epoch > 10:
            logger.info("Saving model")
            filename = (
                f"{cfg.run_tag}_fold_{fold_index}_epochs_{trainer.current_epoch}.ckpt"
            )
            trainer.save_checkpoint(f"../model_objects/{filename}")
            logger.info(f"Saved model to filename: {filename}")


 2024-05-27 12:24:46,902 [MainThread] 🐦‍🔥 Splitting 88247 waveforms into train/val: 0.25
 2024-05-27 12:24:55,167 [MainThread] 🐦‍🔥 Dataloaders ready to go brrr
 2024-05-27 12:24:55,251 [MainThread] 🐦‍🔥 Loading pretrained weights from Hugging Face hub (timm/efficientvit_b1.r224_in1k)
 2024-05-27 12:24:55,395 [MainThread] 🐦‍🔥 [timm/efficientvit_b1.r224_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
 2024-05-27 12:24:55,416 [MainThread] 🐦‍🔥 Missing keys (head.classifier.4.weight, head.classifier.4.bias) discovered while loading pretrained weights. This is expected if model is being adapted.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
 2024-05-27 12:25:09,588 [MainThread] 🐦‍🔥 
Start training fold 0
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:28:26,010 [MainThread] 🐦‍🔥 Epoch 0: {'val_loss': 0.0275, 'val_acc_1': 0.34047, 'val_auroc': 0.34372, 'val_f1_macro': 0.01601, 'train_loss_epoch': 0.03728, 'train_acc_1': 0.15271, 'train_f1_macro': 0.01517}
Metric val_f1_macro improved. New best score: 0.016


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:31:06,279 [MainThread] 🐦‍🔥 Epoch 1: {'val_loss': 0.01837, 'val_acc_1': 0.56892, 'val_auroc': 0.37135, 'val_f1_macro': 0.09809, 'train_loss_epoch': 0.02177, 'train_acc_1': 0.37558, 'train_f1_macro': 0.07749}
Metric val_f1_macro improved by 0.082 >= min_delta = 0.005. New best score: 0.098


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:33:48,214 [MainThread] 🐦‍🔥 Epoch 2: {'val_loss': 0.01687, 'val_acc_1': 0.60277, 'val_auroc': 0.37412, 'val_f1_macro': 0.1288, 'train_loss_epoch': 0.01895, 'train_acc_1': 0.41383, 'train_f1_macro': 0.10176}
Metric val_f1_macro improved by 0.031 >= min_delta = 0.005. New best score: 0.129


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:36:29,918 [MainThread] 🐦‍🔥 Epoch 3: {'val_loss': 0.0151, 'val_acc_1': 0.65252, 'val_auroc': 0.37531, 'val_f1_macro': 0.15602, 'train_loss_epoch': 0.01686, 'train_acc_1': 0.4746, 'train_f1_macro': 0.13057}
Metric val_f1_macro improved by 0.027 >= min_delta = 0.005. New best score: 0.156


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:39:13,824 [MainThread] 🐦‍🔥 Epoch 4: {'val_loss': 0.01768, 'val_acc_1': 0.59748, 'val_auroc': 0.37458, 'val_f1_macro': 0.13197, 'train_loss_epoch': 0.01522, 'train_acc_1': 0.50758, 'train_f1_macro': 0.14801}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:41:58,758 [MainThread] 🐦‍🔥 Epoch 5: {'val_loss': 0.01579, 'val_acc_1': 0.64413, 'val_auroc': 0.37518, 'val_f1_macro': 0.15519, 'train_loss_epoch': 0.01485, 'train_acc_1': 0.51222, 'train_f1_macro': 0.1498}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:44:44,075 [MainThread] 🐦‍🔥 Epoch 6: {'val_loss': 0.01463, 'val_acc_1': 0.67888, 'val_auroc': 0.37622, 'val_f1_macro': 0.19087, 'train_loss_epoch': 0.01356, 'train_acc_1': 0.53386, 'train_f1_macro': 0.16436}
Metric val_f1_macro improved by 0.035 >= min_delta = 0.005. New best score: 0.191


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:47:32,025 [MainThread] 🐦‍🔥 Epoch 7: {'val_loss': 0.01369, 'val_acc_1': 0.68934, 'val_auroc': 0.37629, 'val_f1_macro': 0.20475, 'train_loss_epoch': 0.01352, 'train_acc_1': 0.51822, 'train_f1_macro': 0.15946}
Metric val_f1_macro improved by 0.014 >= min_delta = 0.005. New best score: 0.205


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:50:21,218 [MainThread] 🐦‍🔥 Epoch 8: {'val_loss': 0.01388, 'val_acc_1': 0.69913, 'val_auroc': 0.37684, 'val_f1_macro': 0.21165, 'train_loss_epoch': 0.01337, 'train_acc_1': 0.53102, 'train_f1_macro': 0.16263}
Metric val_f1_macro improved by 0.007 >= min_delta = 0.005. New best score: 0.212


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:53:10,740 [MainThread] 🐦‍🔥 Epoch 9: {'val_loss': 0.01296, 'val_acc_1': 0.70438, 'val_auroc': 0.37469, 'val_f1_macro': 0.21976, 'train_loss_epoch': 0.01246, 'train_acc_1': 0.55303, 'train_f1_macro': 0.17741}
Metric val_f1_macro improved by 0.008 >= min_delta = 0.005. New best score: 0.220


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:56:02,569 [MainThread] 🐦‍🔥 Epoch 10: {'val_loss': 0.01374, 'val_acc_1': 0.70151, 'val_auroc': 0.37649, 'val_f1_macro': 0.21805, 'train_loss_epoch': 0.01167, 'train_acc_1': 0.57589, 'train_f1_macro': 0.18553}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 12:58:53,768 [MainThread] 🐦‍🔥 Epoch 11: {'val_loss': 0.01325, 'val_acc_1': 0.71731, 'val_auroc': 0.37677, 'val_f1_macro': 0.22912, 'train_loss_epoch': 0.01288, 'train_acc_1': 0.54341, 'train_f1_macro': 0.17}
Metric val_f1_macro improved by 0.009 >= min_delta = 0.005. New best score: 0.229


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:01:46,996 [MainThread] 🐦‍🔥 Epoch 12: {'val_loss': 0.01345, 'val_acc_1': 0.72661, 'val_auroc': 0.37727, 'val_f1_macro': 0.23624, 'train_loss_epoch': 0.01114, 'train_acc_1': 0.59453, 'train_f1_macro': 0.19342}
Metric val_f1_macro improved by 0.007 >= min_delta = 0.005. New best score: 0.236


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:04:40,281 [MainThread] 🐦‍🔥 Epoch 13: {'val_loss': 0.01367, 'val_acc_1': 0.72082, 'val_auroc': 0.37708, 'val_f1_macro': 0.23593, 'train_loss_epoch': 0.01111, 'train_acc_1': 0.57683, 'train_f1_macro': 0.18737}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:07:35,469 [MainThread] 🐦‍🔥 Epoch 14: {'val_loss': 0.01366, 'val_acc_1': 0.73141, 'val_auroc': 0.37701, 'val_f1_macro': 0.24208, 'train_loss_epoch': 0.01096, 'train_acc_1': 0.57048, 'train_f1_macro': 0.18606}
Metric val_f1_macro improved by 0.006 >= min_delta = 0.005. New best score: 0.242


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:10:31,795 [MainThread] 🐦‍🔥 Epoch 15: {'val_loss': 0.01492, 'val_acc_1': 0.71026, 'val_auroc': 0.37667, 'val_f1_macro': 0.23088, 'train_loss_epoch': 0.01059, 'train_acc_1': 0.59036, 'train_f1_macro': 0.19242}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:13:29,869 [MainThread] 🐦‍🔥 Epoch 16: {'val_loss': 0.01363, 'val_acc_1': 0.73379, 'val_auroc': 0.37755, 'val_f1_macro': 0.24124, 'train_loss_epoch': 0.01016, 'train_acc_1': 0.61029, 'train_f1_macro': 0.20156}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:16:27,054 [MainThread] 🐦‍🔥 Epoch 17: {'val_loss': 0.01327, 'val_acc_1': 0.7095, 'val_auroc': 0.37428, 'val_f1_macro': 0.21001, 'train_loss_epoch': 0.01028, 'train_acc_1': 0.58089, 'train_f1_macro': 0.19284}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:19:27,544 [MainThread] 🐦‍🔥 Epoch 18: {'val_loss': 0.01295, 'val_acc_1': 0.73087, 'val_auroc': 0.37552, 'val_f1_macro': 0.23539, 'train_loss_epoch': 0.01004, 'train_acc_1': 0.60366, 'train_f1_macro': 0.20031}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:22:28,690 [MainThread] 🐦‍🔥 Epoch 19: {'val_loss': 0.01232, 'val_acc_1': 0.73464, 'val_auroc': 0.37555, 'val_f1_macro': 0.23421, 'train_loss_epoch': 0.00956, 'train_acc_1': 0.6115, 'train_f1_macro': 0.20438}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:25:30,177 [MainThread] 🐦‍🔥 Epoch 20: {'val_loss': 0.01276, 'val_acc_1': 0.73864, 'val_auroc': 0.37623, 'val_f1_macro': 0.23762, 'train_loss_epoch': 0.00955, 'train_acc_1': 0.60456, 'train_f1_macro': 0.20164}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:28:33,014 [MainThread] 🐦‍🔥 Epoch 21: {'val_loss': 0.0135, 'val_acc_1': 0.74107, 'val_auroc': 0.37634, 'val_f1_macro': 0.24235, 'train_loss_epoch': 0.01012, 'train_acc_1': 0.57775, 'train_f1_macro': 0.19101}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:31:37,581 [MainThread] 🐦‍🔥 Epoch 22: {'val_loss': 0.01253, 'val_acc_1': 0.73262, 'val_auroc': 0.37551, 'val_f1_macro': 0.23004, 'train_loss_epoch': 0.00935, 'train_acc_1': 0.61661, 'train_f1_macro': 0.20621}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:34:42,316 [MainThread] 🐦‍🔥 Epoch 23: {'val_loss': 0.01216, 'val_acc_1': 0.74915, 'val_auroc': 0.37527, 'val_f1_macro': 0.2415, 'train_loss_epoch': 0.00932, 'train_acc_1': 0.61241, 'train_f1_macro': 0.20418}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:37:48,157 [MainThread] 🐦‍🔥 Epoch 24: {'val_loss': 0.01275, 'val_acc_1': 0.73235, 'val_auroc': 0.37557, 'val_f1_macro': 0.23499, 'train_loss_epoch': 0.00982, 'train_acc_1': 0.58151, 'train_f1_macro': 0.19157}
Monitored metric val_f1_macro did not improve in the last 10 records. Best score: 0.242. Signaling Trainer to stop.
 2024-05-27 13:37:52,219 [MainThread] 🐦‍🔥 Finished training fold 0
 2024-05-27 13:37:52,221 [MainThread] 🐦‍🔥 Saving model
 2024-05-27 13:37:52,489 [MainThread] 🐦‍🔥 Saved model to filename: 2024-05-27 12:08:58_efficientvit_b1.r224_in1k_foundation_val_0.25_lr_0.0005_decay_1e-06_fold_0_epochs_25.ckpt
 2024-05-27 13:37:52,599 [MainThread] 🐦‍🔥 Dataloaders ready to go brrr
 2024-05-27 13:37:52,699 [MainThread] 🐦‍🔥 Loading pretrained weights from Hugging Face hub (timm/efficientvit_b1.r224_in1k)
 2024-05-27 13:37:52,879 [MainThread] 🐦‍🔥 [timm/efficientvit_b1.r224_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:41:05,397 [MainThread] 🐦‍🔥 Epoch 0: {'val_loss': 0.02345, 'val_acc_1': 0.43586, 'val_auroc': 0.35653, 'val_f1_macro': 0.06743, 'train_loss_epoch': 0.0372, 'train_acc_1': 0.14378, 'train_f1_macro': 0.01382}
Metric val_f1_macro improved. New best score: 0.067


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:43:45,524 [MainThread] 🐦‍🔥 Epoch 1: {'val_loss': 0.01673, 'val_acc_1': 0.59667, 'val_auroc': 0.36971, 'val_f1_macro': 0.13428, 'train_loss_epoch': 0.02196, 'train_acc_1': 0.35075, 'train_f1_macro': 0.07325}
Metric val_f1_macro improved by 0.067 >= min_delta = 0.005. New best score: 0.134


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:46:27,180 [MainThread] 🐦‍🔥 Epoch 2: {'val_loss': 0.01419, 'val_acc_1': 0.65735, 'val_auroc': 0.3719, 'val_f1_macro': 0.17154, 'train_loss_epoch': 0.01898, 'train_acc_1': 0.4203, 'train_f1_macro': 0.10486}
Metric val_f1_macro improved by 0.037 >= min_delta = 0.005. New best score: 0.172


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:49:09,446 [MainThread] 🐦‍🔥 Epoch 3: {'val_loss': 0.01641, 'val_acc_1': 0.6223, 'val_auroc': 0.37022, 'val_f1_macro': 0.1346, 'train_loss_epoch': 0.01664, 'train_acc_1': 0.47495, 'train_f1_macro': 0.13123}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:51:52,165 [MainThread] 🐦‍🔥 Epoch 4: {'val_loss': 0.01509, 'val_acc_1': 0.64629, 'val_auroc': 0.37139, 'val_f1_macro': 0.1673, 'train_loss_epoch': 0.01597, 'train_acc_1': 0.47934, 'train_f1_macro': 0.13564}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:54:36,961 [MainThread] 🐦‍🔥 Epoch 5: {'val_loss': 0.01334, 'val_acc_1': 0.69042, 'val_auroc': 0.37231, 'val_f1_macro': 0.20158, 'train_loss_epoch': 0.01501, 'train_acc_1': 0.48776, 'train_f1_macro': 0.14142}
Metric val_f1_macro improved by 0.030 >= min_delta = 0.005. New best score: 0.202


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 13:57:22,717 [MainThread] 🐦‍🔥 Epoch 6: {'val_loss': 0.01339, 'val_acc_1': 0.68663, 'val_auroc': 0.37254, 'val_f1_macro': 0.20403, 'train_loss_epoch': 0.01363, 'train_acc_1': 0.53596, 'train_f1_macro': 0.16485}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:00:09,051 [MainThread] 🐦‍🔥 Epoch 7: {'val_loss': 0.01351, 'val_acc_1': 0.69467, 'val_auroc': 0.37225, 'val_f1_macro': 0.2004, 'train_loss_epoch': 0.01363, 'train_acc_1': 0.52865, 'train_f1_macro': 0.16104}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:02:56,669 [MainThread] 🐦‍🔥 Epoch 8: {'val_loss': 0.01358, 'val_acc_1': 0.69668, 'val_auroc': 0.3724, 'val_f1_macro': 0.20932, 'train_loss_epoch': 0.01201, 'train_acc_1': 0.57648, 'train_f1_macro': 0.18432}
Metric val_f1_macro improved by 0.008 >= min_delta = 0.005. New best score: 0.209


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:05:46,951 [MainThread] 🐦‍🔥 Epoch 9: {'val_loss': 0.01308, 'val_acc_1': 0.70948, 'val_auroc': 0.37242, 'val_f1_macro': 0.2215, 'train_loss_epoch': 0.01279, 'train_acc_1': 0.56542, 'train_f1_macro': 0.17721}
Metric val_f1_macro improved by 0.012 >= min_delta = 0.005. New best score: 0.222


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:08:38,419 [MainThread] 🐦‍🔥 Epoch 10: {'val_loss': 0.01265, 'val_acc_1': 0.73118, 'val_auroc': 0.37366, 'val_f1_macro': 0.23282, 'train_loss_epoch': 0.0118, 'train_acc_1': 0.5486, 'train_f1_macro': 0.17766}
Metric val_f1_macro improved by 0.011 >= min_delta = 0.005. New best score: 0.233


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:11:30,562 [MainThread] 🐦‍🔥 Epoch 11: {'val_loss': 0.013, 'val_acc_1': 0.73017, 'val_auroc': 0.37297, 'val_f1_macro': 0.23103, 'train_loss_epoch': 0.0112, 'train_acc_1': 0.58186, 'train_f1_macro': 0.18906}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:14:23,141 [MainThread] 🐦‍🔥 Epoch 12: {'val_loss': 0.01303, 'val_acc_1': 0.7352, 'val_auroc': 0.37384, 'val_f1_macro': 0.23454, 'train_loss_epoch': 0.0112, 'train_acc_1': 0.58128, 'train_f1_macro': 0.1878}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:17:17,216 [MainThread] 🐦‍🔥 Epoch 13: {'val_loss': 0.01345, 'val_acc_1': 0.7246, 'val_auroc': 0.37307, 'val_f1_macro': 0.22983, 'train_loss_epoch': 0.01147, 'train_acc_1': 0.5554, 'train_f1_macro': 0.17887}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:20:12,620 [MainThread] 🐦‍🔥 Epoch 14: {'val_loss': 0.01296, 'val_acc_1': 0.73634, 'val_auroc': 0.37326, 'val_f1_macro': 0.23262, 'train_loss_epoch': 0.01083, 'train_acc_1': 0.59263, 'train_f1_macro': 0.19263}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:23:09,461 [MainThread] 🐦‍🔥 Epoch 15: {'val_loss': 0.01272, 'val_acc_1': 0.74365, 'val_auroc': 0.37376, 'val_f1_macro': 0.24106, 'train_loss_epoch': 0.01013, 'train_acc_1': 0.60397, 'train_f1_macro': 0.19985}
Metric val_f1_macro improved by 0.008 >= min_delta = 0.005. New best score: 0.241


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:26:08,748 [MainThread] 🐦‍🔥 Epoch 16: {'val_loss': 0.01372, 'val_acc_1': 0.74328, 'val_auroc': 0.37431, 'val_f1_macro': 0.23941, 'train_loss_epoch': 0.01007, 'train_acc_1': 0.59881, 'train_f1_macro': 0.1985}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:29:07,934 [MainThread] 🐦‍🔥 Epoch 17: {'val_loss': 0.01349, 'val_acc_1': 0.74114, 'val_auroc': 0.37367, 'val_f1_macro': 0.23866, 'train_loss_epoch': 0.00938, 'train_acc_1': 0.6184, 'train_f1_macro': 0.21047}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:32:09,290 [MainThread] 🐦‍🔥 Epoch 18: {'val_loss': 0.01292, 'val_acc_1': 0.7447, 'val_auroc': 0.37298, 'val_f1_macro': 0.24212, 'train_loss_epoch': 0.0097, 'train_acc_1': 0.60174, 'train_f1_macro': 0.20174}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:35:10,508 [MainThread] 🐦‍🔥 Epoch 19: {'val_loss': 0.01405, 'val_acc_1': 0.7436, 'val_auroc': 0.3733, 'val_f1_macro': 0.2425, 'train_loss_epoch': 0.00971, 'train_acc_1': 0.61332, 'train_f1_macro': 0.20346}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:38:12,252 [MainThread] 🐦‍🔥 Epoch 20: {'val_loss': 0.01225, 'val_acc_1': 0.74721, 'val_auroc': 0.37325, 'val_f1_macro': 0.24263, 'train_loss_epoch': 0.00925, 'train_acc_1': 0.61326, 'train_f1_macro': 0.20748}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:41:16,628 [MainThread] 🐦‍🔥 Epoch 21: {'val_loss': 0.01338, 'val_acc_1': 0.75228, 'val_auroc': 0.37367, 'val_f1_macro': 0.24708, 'train_loss_epoch': 0.01026, 'train_acc_1': 0.57089, 'train_f1_macro': 0.18863}
Metric val_f1_macro improved by 0.006 >= min_delta = 0.005. New best score: 0.247


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:44:21,763 [MainThread] 🐦‍🔥 Epoch 22: {'val_loss': 0.01305, 'val_acc_1': 0.75676, 'val_auroc': 0.37424, 'val_f1_macro': 0.24893, 'train_loss_epoch': 0.00965, 'train_acc_1': 0.59416, 'train_f1_macro': 0.19797}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:47:27,984 [MainThread] 🐦‍🔥 Epoch 23: {'val_loss': 0.01326, 'val_acc_1': 0.76128, 'val_auroc': 0.37453, 'val_f1_macro': 0.25032, 'train_loss_epoch': 0.00919, 'train_acc_1': 0.61531, 'train_f1_macro': 0.20676}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:50:35,274 [MainThread] 🐦‍🔥 Epoch 24: {'val_loss': 0.01307, 'val_acc_1': 0.76339, 'val_auroc': 0.37432, 'val_f1_macro': 0.25097, 'train_loss_epoch': 0.0093, 'train_acc_1': 0.60821, 'train_f1_macro': 0.20358}


Validation: |          | 0/? [00:00<?, ?it/s]

 2024-05-27 14:53:43,534 [MainThread] 🐦‍🔥 Epoch 25: {'val_loss': 0.01331, 'val_acc_1': 0.75736, 'val_auroc': 0.37366, 'val_f1_macro': 0.24755, 'train_loss_epoch': 0.00905, 'train_acc_1': 0.59976, 'train_f1_macro': 0.20245}
