In [1]:
import pandas as pd

In [2]:
import torch
import numpy as np
import pandas as pd
import librosa
from audiomentations import Compose
from typing import List, Dict, Optional


class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-length ``(waveform, target)`` tensor pairs.

    Every item is loaded with ``librosa.load`` at ``sample_rate``, zero-padded or
    cropped to ``int(sample_rate * target_duration)`` samples, optionally blended
    with a second randomly chosen item (mixup), optionally augmented, optionally
    normalized, and finally returned as two float tensors.

    Args:
        input_df: One row per audio file; a copy with a reset index is stored.
        filepath_col: Name of the column holding audio file paths.
        target_col: Name of the column holding integer class labels in
            ``[0, n_classes)``.
        n_classes: Length of the one-hot target vector.
        sample_rate: Sampling rate handed to ``librosa.load``.
        target_duration: Clip length in seconds.
        normalize_audio: If true, apply ``librosa.util.normalize`` at the end.
        mixup_params: Optional dict with any of the keys ``"prob"``, ``"alpha"``
            and ``"hard_target"``; keys that are missing fall back to the
            defaults. Mixup is active only when this is truthy and ``is_train``.
        is_train: Enables mixup and waveform augmentations.
        wave_piece: ``"center"`` or ``"random"`` - where to crop long clips.
        audio_transforms: Optional callable invoked as
            ``audio_transforms(samples=wave, sample_rate=sample_rate)``.
    """

    # Fallback mixup settings; user-supplied values override them key by key.
    _DEFAULT_MIXUP_PARAMS = {"prob": 0.0, "alpha": 0.5, "hard_target": False}

    def __init__(
        self,
        input_df: pd.DataFrame,
        filepath_col: str,
        target_col: str,
        n_classes: int,
        sample_rate: int,
        target_duration: float,
        normalize_audio: bool = True,
        mixup_params: Optional[Dict] = None,
        is_train: bool = True,
        wave_piece: str = "center",
        # Quoted so the annotation is not evaluated when the class is defined.
        audio_transforms: "Optional[Compose]" = None,
    ) -> None:
        assert wave_piece in ("center", "random")

        self.df = input_df.reset_index(drop=True)

        self.filepath_col = filepath_col
        self.target_col = target_col

        self.sample_rate = sample_rate
        self.target_duration = target_duration
        self.target_sample_count = int(sample_rate * target_duration)

        self.normalize_audio = normalize_audio
        self.is_train = is_train
        self.wave_piece = wave_piece

        self.n_classes = n_classes
        self.audio_transforms = audio_transforms

        # Mixup: keep a real boolean flag, and merge the caller's dict over the
        # defaults so a partial dict (e.g. only "prob") cannot raise KeyError later.
        self.mixup_audio = bool(mixup_params) and bool(is_train)
        self.mixup_params = {**self._DEFAULT_MIXUP_PARAMS, **(mixup_params or {})}

    def __len__(self):
        """Return the number of rows (audio files) in the dataset."""
        return len(self.df)

    def _get_wave(self, idx: int) -> np.ndarray:
        """Load the waveform of row ``idx`` resampled to ``self.sample_rate``."""
        filepath = self.df[self.filepath_col].iloc[idx]
        wave, _ = librosa.load(filepath, sr=self.sample_rate)
        return wave

    def _process_wave(self, wave: np.ndarray) -> np.ndarray:
        """Zero-pad (at the end) or crop ``wave`` to ``self.target_sample_count`` samples."""
        length = len(wave)
        missing = self.target_sample_count - length

        if missing > 0:
            return np.pad(wave, (0, missing), mode="constant")

        if self.wave_piece == "center":
            start = max(0, -missing // 2)
        else:
            # The upper bound is exclusive, hence the +1 to allow the last valid offset.
            start = np.random.randint(0, -missing + 1)

        return wave[start:start + self.target_sample_count]

    def _get_mixup_idx(self):
        """Pick a uniformly random row index to use as the mixup partner."""
        return np.random.randint(0, len(self.df))

    def _one_hot(self, row_idx: int) -> np.ndarray:
        """One-hot encode the label of ``row_idx``.

        Raises:
            IndexError: If the label is outside ``[0, n_classes)``. Without this
                check a negative label would silently wrap to the end of the vector.
        """
        label = int(self.df[self.target_col].iloc[row_idx])
        if not 0 <= label < self.n_classes:
            raise IndexError(
                f"class label {label} at row {row_idx} is outside [0, {self.n_classes})"
            )
        encoded = np.zeros(self.n_classes, dtype=np.float32)
        encoded[label] = 1.0
        return encoded

    def _prepare_target(self, idx: int, sec_idx: Optional[int] = None):
        """Build the target vector for ``idx``, blended with ``sec_idx`` when given.

        Without ``sec_idx`` the result is a plain one-hot vector. With it, the two
        one-hot vectors are combined as ``alpha * y1 + (1 - alpha) * y2``; when
        ``hard_target`` is set, every non-zero entry is then replaced by 1.
        """
        y1 = self._one_hot(idx)
        if sec_idx is None:
            return y1

        y2 = self._one_hot(sec_idx)
        alpha = self.mixup_params["alpha"]
        y_mix = alpha * y1 + (1 - alpha) * y2

        if self.mixup_params["hard_target"]:
            y_mix = (y_mix > 0).astype(np.float32)

        return y_mix

    def __getitem__(self, idx: int):
        """Return ``(waveform, target)`` float tensors for row ``idx``."""
        wave = self._process_wave(self._get_wave(idx))

        # Mixup is drawn per item with probability ``prob``.
        if self.mixup_audio and np.random.rand() < self.mixup_params["prob"]:
            sec_idx = self._get_mixup_idx()
            sec_wave = self._process_wave(self._get_wave(sec_idx))

            alpha = self.mixup_params["alpha"]
            wave = alpha * wave + (1 - alpha) * sec_wave
            target = self._prepare_target(idx, sec_idx)
        else:
            target = self._prepare_target(idx)

        # Waveform augmentations are applied to training items only.
        if self.audio_transforms and self.is_train:
            wave = self.audio_transforms(samples=wave, sample_rate=self.sample_rate)

        if self.normalize_audio:
            wave = librosa.util.normalize(wave)

        return torch.from_numpy(wave).float(), torch.from_numpy(target).float()


In [3]:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, Gain

def get_augmentations():
    """Build the waveform augmentation pipeline used for training.

    Returns:
        An audiomentations ``Compose`` chaining Gaussian noise, time stretch,
        pitch shift, shift and gain, each applied with probability 0.5.
    """
    apply_prob = 0.5  # every transform fires independently with this probability
    steps = [
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=apply_prob),
        TimeStretch(min_rate=0.9, max_rate=1.1, p=apply_prob),
        PitchShift(min_semitones=-2, max_semitones=2, p=apply_prob),
        Shift(min_shift=-0.1, max_shift=0.1, p=apply_prob),
        Gain(min_gain_db=-6, max_gain_db=6, p=apply_prob),
    ]
    return Compose(steps)

In [4]:
import torch


class AudioForward(torch.nn.Module):
    """Run the model on a batch and compute a single loss.

    Args:
        loss_function: Callable invoked as ``loss_function(prediction, target)``.
        output_key: Key of the model-output dict fed to the loss as prediction.
        input_key: Key of the inputs dict fed to the loss as target.
    """

    def __init__(
        self,
        loss_function,
        output_key,
        input_key,
    ):
        super().__init__()
        self.input_key = input_key
        self.output_key = output_key
        self.loss_function = loss_function

    def forward(self, runner, batch, epoch=None):
        """Return ``(losses, inputs, output)`` dicts for one batch.

        ``runner.model`` must return a dict containing ``"logits"``; sigmoid and
        softmax (over the last dim) predictions are added to that same dict.
        ``epoch`` is accepted but not used here.
        """
        model_inputs, labels = batch

        model_out = runner.model(model_inputs)
        logits = model_out["logits"]
        model_out["sigmoid_predictions"] = torch.sigmoid(logits)
        model_out["softmax_predictions"] = torch.softmax(logits, dim=-1)

        inputs = dict(
            specs=model_inputs,
            targets=labels,
            # class index view of the (one-hot / soft) targets
            targets_1d=labels.argmax(dim=-1),
        )

        loss_value = self.loss_function(model_out[self.output_key], inputs[self.input_key])
        return {"loss": loss_value}, inputs, model_out

In [5]:
import time

import lightning


class LitTrainer(lightning.LightningModule):
    """LightningModule that delegates the batch computation to a ``forward`` callable.

    Args:
        model: Network exposed as ``self.model`` (the forward callable reads it).
        forward: Callable invoked as ``forward(self, batch, epoch=...)`` returning
            ``(losses, inputs, outputs)`` dicts; ``losses`` must contain ``"loss"``
            for Lightning to optimize.
        optimizer: Optimizer instance returned from ``configure_optimizers``.
        scheduler: LR scheduler instance, or ``None`` for no scheduler.
        scheduler_params: Extra entries merged into the scheduler config dict
            (e.g. ``{"interval": "step"}``); ``None`` means no extras.
        batch_key: Key in ``inputs`` whose first dimension is the batch size.
        metric_input_key: Key in ``inputs`` passed to the metrics as target.
        metric_output_key: Key in ``outputs`` passed to the metrics as prediction.
        val_metrics: Metric collection updated during validation, or ``None``.
        train_metrics: Stored but not used by this class.
    """

    def __init__(
        self,
        model,
        forward,
        optimizer,
        scheduler,
        scheduler_params,
        batch_key,
        metric_input_key,
        metric_output_key,
        val_metrics,
        train_metrics,
    ) -> None:
        super().__init__()

        self.model = model
        self._forward = forward
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._scheduler_params = scheduler_params
        self._batch_key = batch_key

        self._metric_input_key = metric_input_key
        self._metric_output_key = metric_output_key
        self._val_metrics = val_metrics
        self._train_metrics = train_metrics

    def _aggregate_outputs(self, losses, inputs, outputs):
        """Merge the three dicts into one, prefixing input/output keys."""
        united = dict(losses)
        united.update({"input_" + k: v for k, v in inputs.items()})
        united.update({"output_" + k: v for k, v in outputs.items()})
        return united

    def _timed_forward(self, batch):
        """Run the forward callable and return its result plus elapsed seconds.

        NOTE(review): this is host wall-clock time from ``time.perf_counter``; on
        an asynchronous accelerator it may under-report device time - confirm
        whether a synchronize is wanted if precise timings matter.
        """
        started = time.perf_counter()
        losses, inputs, outputs = self._forward(self, batch, epoch=self.current_epoch)
        return losses, inputs, outputs, time.perf_counter() - started

    def _log_stage(self, stage, losses, batch_size, model_time):
        """Log per-step and epoch-averaged losses and forward time under ``stage/``."""
        shared = dict(prog_bar=True, logger=True, sync_dist=True)
        per_step = dict(on_step=True, on_epoch=False, **shared)
        per_epoch = dict(on_step=False, on_epoch=True, **shared)

        for name, value in losses.items():
            self.log(f"{stage}/{name}", value, batch_size=batch_size, **per_step)
            self.log(f"{stage}/avg_{name}", value, batch_size=batch_size, **per_epoch)

        # ``self.log`` requires a value; the elapsed forward time is what is reported.
        self.log(f"{stage}/model_time", model_time, batch_size=1, **per_step)
        self.log(f"{stage}/avg_model_time", model_time, batch_size=1, **per_epoch)

    def training_step(self, batch, batch_idx=None):
        """Compute and log the training losses; return the merged output dict."""
        losses, inputs, outputs, elapsed = self._timed_forward(batch)
        self._log_stage("train", losses, inputs[self._batch_key].shape[0], elapsed)
        return self._aggregate_outputs(losses, inputs, outputs)

    def validation_step(self, batch, batch_idx):
        """Compute validation losses, update metrics and log; return the merged dict."""
        losses, inputs, outputs, elapsed = self._timed_forward(batch)

        if self._val_metrics is not None:
            self._val_metrics.update(
                outputs[self._metric_output_key],
                inputs[self._metric_input_key]
            )

        self._log_stage("valid", losses, inputs[self._batch_key].shape[0], elapsed)
        return self._aggregate_outputs(losses, inputs, outputs)

    def on_train_epoch_end(self):
        """No end-of-epoch work is done for training."""
        pass

    def on_validation_epoch_end(self):
        """Compute, log (as ``valid/<name>``) and reset the validation metrics."""
        # Same guard as in validation_step: metrics are optional.
        if self._val_metrics is None:
            return

        metric_values = self._val_metrics.compute()
        self.log_dict(
            {"valid/" + k: v for k, v in metric_values.items()},
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
        )
        self._val_metrics.reset()

    def configure_optimizers(self):
        """Return the optimizer, plus the scheduler config when a scheduler was given."""
        if self._scheduler is None:
            return self._optimizer

        scheduler = {"scheduler": self._scheduler}
        scheduler.update(self._scheduler_params or {})
        return (
            [self._optimizer], [scheduler],
        )

In [6]:
import os

# CSV manifests describing the training and validation splits.
TRAIN_PATH = "/data/train.csv"
VALID_PATH = "/data/val.csv"

In [7]:
# Load both split manifests into DataFrames, training first.
train_df, val_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_PATH, VALID_PATH))

In [8]:
# Class names in label order; a name's position in the list is its integer label.
classes = ['siren', 'gunshot', 'explosion', 'casual']
class_to_idx = {c: i for i, c in enumerate(classes)}

# NOTE(review): AudioDataset reads the 'target' column through int(...), so the
# column presumably already holds integer labels consistent with class_to_idx - confirm.
train_dataset = AudioDataset(
    input_df=train_df,
    filepath_col='path',
    target_col='target',
    n_classes=len(classes),    # a plain count, derived from the class list
    sample_rate=16000,
    target_duration=10,
    audio_transforms=get_augmentations()
)
val_dataset = AudioDataset(
    input_df=val_df,
    filepath_col='path',
    target_col='target',
    n_classes=len(classes),
    sample_rate=16000,
    target_duration=10,
    is_train=False,            # validation must never enable mixup or augmentation
)

In [9]:

from torch.utils.data import DataLoader

BATCH_SIZE = 16

# Only the training split is shuffled; validation keeps the manifest order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
import torch


class GeMGlobalBlock(torch.nn.Module):
    """Generalized-mean (GeM) global pooling with a learnable exponent ``p``.

    Args:
        p: Initial value of the learnable exponent.
        eps: Lower clamp applied to the input before exponentiation.
    """

    def __init__(self, p: float = 3., eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.p = torch.nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        """Clamp, raise to ``p``, average-pool to 1x1, take the ``p``-th root, flatten to 2D."""
        powered = torch.pow(torch.clamp(x, min=self.eps), self.p)
        averaged = self.pool(powered)
        rooted = torch.pow(averaged, 1.0 / self.p)
        n_rows, n_cols = rooted.size(0), rooted.size(1)
        return rooted.view(n_rows, n_cols)

In [11]:
import torch
import torch.nn as nn
import timm
from typing import Dict, Any, Optional, List
from torchaudio.transforms import MelSpectrogram
import math
import numpy as np
import torch
import torch.nn as nn
from torchaudio.functional import amplitude_to_DB
from torchaudio.transforms import FrequencyMasking, TimeMasking
from typing import Optional


class NormalizeMelSpec(nn.Module):
    """Normalize a tensor over its last two dimensions.

    Args:
        eps: Added to each denominator to avoid division by zero.
        normalize_standart: Subtract the mean and divide by the standard deviation.
        normalize_minmax: Rescale with the min and max (applied after the
            standardization step when both are enabled).
    """

    def __init__(self, eps=1e-6, normalize_standart=True, normalize_minmax=True):
        super().__init__()
        self.normalize_minmax = normalize_minmax
        self.normalize_standart = normalize_standart
        self.eps = eps

    def forward(self, X):
        """Return ``X`` after the enabled normalization steps (``X`` itself if none)."""
        reduce_dims = (-2, -1)

        if self.normalize_standart:
            centre = X.mean(reduce_dims, keepdim=True)
            spread = X.std(reduce_dims, keepdim=True)
            X = (X - centre) / (spread + self.eps)

        if self.normalize_minmax:
            upper = torch.amax(X, dim=reduce_dims, keepdim=True)
            lower = torch.amin(X, dim=reduce_dims, keepdim=True)
            X = (X - lower) / (upper - lower + self.eps)

        return X


class CustomMasking(nn.Module):
    """Base class that applies ``self.mask_module`` a random number of times per item.

    ``mask_module`` is ``None`` here and is expected to be set by subclasses.

    Args:
        mask_max_length: Stored for subclasses to configure their mask module.
        mask_max_masks: Upper bound (inclusive) on masks applied to one item;
            must be a positive ``int``.
        p: Probability that a given item is masked at all.
        inplace: Write into the input tensor instead of a clone.
    """

    def __init__(self, mask_max_length: int, mask_max_masks: int, p=1.0, inplace=True):
        super().__init__()
        assert isinstance(mask_max_masks, int) and mask_max_masks > 0
        self.mask_max_length = mask_max_length
        self.mask_max_masks = mask_max_masks
        self.p = p
        self.inplace = inplace
        self.mask_module = None

    def forward(self, x):
        """Mask items of ``x`` along dim 0 independently; return the masked tensor."""
        # Either work directly on the input or on a private copy.
        canvas = x if self.inplace else x.clone()

        for item in range(x.shape[0]):
            # One Bernoulli draw decides whether this item is touched.
            if not np.random.binomial(n=1, p=self.p):
                continue

            n_applies = np.random.randint(low=1, high=self.mask_max_masks + 1)
            window = slice(item, item + 1)  # keep the leading dim for the mask module
            for _ in range(n_applies):
                canvas[window] = self.mask_module(canvas[window])

        return canvas


class CustomTimeMasking(CustomMasking):
    """``CustomMasking`` whose mask module is torchaudio's ``TimeMasking``.

    ``mask_max_length`` is forwarded to ``TimeMasking`` as ``time_mask_param``.
    """

    def __init__(self, mask_max_length: int, mask_max_masks: int, p=1.0, inplace=True):
        super().__init__(mask_max_length, mask_max_masks, p, inplace)
        self.mask_module = TimeMasking(time_mask_param=mask_max_length)


class CustomFreqMasking(CustomMasking):
    """``CustomMasking`` whose mask module is torchaudio's ``FrequencyMasking``.

    ``mask_max_length`` is forwarded to ``FrequencyMasking`` as ``freq_mask_param``.
    """

    def __init__(self, mask_max_length: int, mask_max_masks: int, p=1.0, inplace=True):
        super().__init__(mask_max_length, mask_max_masks, p, inplace)
        self.mask_module = FrequencyMasking(freq_mask_param=mask_max_length)


class ChannelAgnosticAmplitudeToDB(nn.Module):
    """Convert a 3D or 4D tensor to decibels via torchaudio's ``amplitude_to_DB``.

    A 3D input gets a temporary dimension inserted at position 1 for the
    conversion, which is removed again before returning.

    Args:
        stype: ``"power"`` selects a multiplier of 10, anything else 20.
        top_db: Optional cut-off passed to ``amplitude_to_DB``; must not be negative.

    Raises:
        ValueError: If ``top_db`` is negative.
    """

    def __init__(self, stype: str = "power", top_db: Optional[float] = None):
        super().__init__()
        if top_db is not None and top_db < 0:
            raise ValueError("top_db must be positive value")

        self.stype = stype
        self.top_db = top_db
        self.multiplier = 10.0 if stype == "power" else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` in decibels with the same number of dimensions as the input."""
        assert x.dim() in [3, 4], f"Expected 3D or 4D tensor, but got {x.dim()}D tensor"

        needs_channel = x.dim() == 3
        expanded = x.unsqueeze(1) if needs_channel else x

        x_db = amplitude_to_DB(expanded, self.multiplier, self.amin, self.db_multiplier, self.top_db)

        return x_db.squeeze(1) if needs_channel else x_db

class SpecCNNClassifier(nn.Module):
    """Waveform classifier: mel-spectrogram front end, timm backbone, pooled linear head.

    Args:
        backbone: timm model name passed to ``timm.create_model``.
        device: Device the whole module is moved to at the end of construction.
        n_classes: Number of output logits.
        classifier_dropout: Dropout probability in front of the linear head.
        spec_params: Keyword arguments for torchaudio's ``MelSpectrogram``.
        top_db: Forwarded to ``ChannelAgnosticAmplitudeToDB``.
        normalize_config: Keyword arguments for ``NormalizeMelSpec``.
        pretrained: Forwarded to ``timm.create_model``.
        pool_type: ``"gem"`` or ``"adavg"`` (``"avg"`` is accepted as an alias),
            case-insensitive.
        out_indices: Backbone stages to pool and concatenate. ``None`` or the
            string ``"None"`` means: use only the last stage.
        in_chans: Number of input channels; the single spectrogram is repeated
            this many times.
        timm_kwargs: Extra keyword arguments for ``timm.create_model`` or ``None``.
        spec_augment_config: Optional dict with ``"freq_mask"`` and/or
            ``"time_mask"`` keyword-argument dicts; ``None`` disables SpecAugment.

    Raises:
        ValueError: If ``pool_type`` is not supported.
    """

    def __init__(
            self,
            backbone: str,
            device: str,
            n_classes: int,
            classifier_dropout: float,
            spec_params: Dict[str, Any],
            top_db: float,
            normalize_config: Dict[str, bool],
            pretrained: bool,
            pool_type: str,
            out_indices: Optional[List[int]],
            in_chans: int,
            timm_kwargs: Optional[Dict],
            spec_augment_config: Optional[Dict[str, Any]]

    ):
        super().__init__()
        timm_kwargs = {} if timm_kwargs is None else timm_kwargs
        # A config file can deliver either a real null or the literal string
        # "None"; both mean "no explicit stage selection".
        if out_indices is None or out_indices == "None":
            self.out_indices = None
        else:
            self.out_indices = tuple(out_indices)
        self.n_specs = in_chans

        self.device = device

        self.spectogram_extractor = nn.Sequential(
            MelSpectrogram(**spec_params),
            ChannelAgnosticAmplitudeToDB(top_db=top_db),
            NormalizeMelSpec(**normalize_config),
        )

        if spec_augment_config is not None:
            augment_steps = []
            if "freq_mask" in spec_augment_config:
                augment_steps.append(CustomFreqMasking(**spec_augment_config["freq_mask"]))
            if "time_mask" in spec_augment_config:
                augment_steps.append(CustomTimeMasking(**spec_augment_config["time_mask"]))
            self.spec_augment = nn.Sequential(*augment_steps)
        else:
            self.spec_augment = None

        # model
        self.backbone = timm.create_model(
            backbone,
            features_only=True,
            pretrained=pretrained,
            in_chans=self.n_specs,
            exportable=True,
            out_indices=self.out_indices,
            **timm_kwargs,
        )

        stage_channels = self.backbone.feature_info.channels()
        print(stage_channels)

        # Without explicit out_indices only the last stage feeds the head.
        feature_dims = stage_channels if self.out_indices is not None else [stage_channels[-1]]
        print(f"feature dims: {feature_dims}")

        # pooling: one pooling module per selected stage
        pool_key = pool_type.lower()
        pools: List[nn.Module] = []
        if pool_key == "gem":
            pools = [GeMGlobalBlock() for _ in feature_dims]
        elif pool_key in ("adavg", "avg"):
            pools = [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(start_dim=1)
                )
                for _ in feature_dims
            ]
        else:
            raise ValueError(f"Unsupported pool_type={pool_type!r}; choose 'gem' or 'adavg'")

        self.pool = nn.ModuleList(pools)

        self.emb_dim = sum(feature_dims)

        # head
        self.classifier = nn.Sequential(
            nn.Dropout(p=classifier_dropout),
            nn.Linear(self.emb_dim, n_classes),
        )

        self.to(self.device)

    def forward(self, input, return_spec_feature=False, return_cnn_emb=False):
        """Classify a batch of waveforms.

        Args:
            input: Tensor fed to the spectrogram extractor. The extractor output
                must be 3D, since a channel dim is inserted at position 1 and
                expanded to ``in_chans`` - presumably ``input`` is
                ``(batch, samples)``; TODO confirm against the data loader.
            return_spec_feature: Return the (possibly augmented) spectrograms
                instead of running the backbone.
            return_cnn_emb: Return the concatenated pooled embedding instead of
                the logits.

        Returns:
            ``{"logits": tensor}`` by default, or a tensor for the early exits above.
        """
        # specs
        specs = self.spectogram_extractor(input)

        # multi channel mode support: repeat the spectrogram across channels;
        # contiguous() materializes the copies so in-place masking is safe
        specs = specs.unsqueeze(1).expand(-1, self.n_specs, -1, -1).contiguous()

        if self.spec_augment is not None and self.training:
            specs = self.spec_augment(specs)
        if return_spec_feature:
            return specs

        # features - list of stages
        features = self.backbone(specs)

        if self.out_indices is None:
            features = [features[-1]]

        pooled = [p(fmap) for fmap, p in zip(features, self.pool)]

        emb = torch.cat(pooled, dim=1)

        if return_cnn_emb:
            return emb

        logits = self.classifier(emb)

        return {"logits": logits}

In [12]:
import yaml
# Path to the model config. The MODEL_CONFIG_PATH environment variable overrides
# the machine-specific default so the notebook can run on other machines.
PATH2CONFIG = os.environ.get(
    "MODEL_CONFIG_PATH",
    'D:\\audio_cls_coursework\\src\\model\\model_config.yml',
)
# Explicit encoding: without it the platform default is used, which can
# mis-decode non-ASCII content in the YAML file.
with open(PATH2CONFIG, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Convenience aliases for individual config entries.
device = cfg["device"]
n_classes = cfg["n_classes"]
spec_params = cfg["spec_params"]
normalize_config = cfg["normalize_config"]
backbone = cfg["backbone"]
pretrained = cfg["pretrained"]
pool_type = cfg["pool_type"]
out_indices = cfg["out_indices"]
in_chans = cfg["in_chans"]
timm_kwargs = cfg["timm_kwargs"]



In [None]:
# Constructor arguments taken verbatim from the loaded config.
model_cfg_keys = (
    "backbone",
    "n_classes",
    "classifier_dropout",
    "spec_params",
    "top_db",
    "normalize_config",
    "pretrained",
    "pool_type",
    "out_indices",
    "in_chans",
    "timm_kwargs",
)
model = SpecCNNClassifier(
    device=device,
    spec_augment_config=None,  # SpecAugment is disabled for this run
    **{key: cfg[key] for key in model_cfg_keys},
)


In [19]:
from torchmetrics import MetricCollection
import torch
import torchvision


class FocalLoss(torch.nn.Module):
    """Module wrapper around torchvision's ``sigmoid_focal_loss``.

    Args:
        alpha: Forwarded as ``alpha``.
        gamma: Forwarded as ``gamma``.
        reduction: Forwarded as ``reduction``.
    """

    def __init__(
        self,
        alpha: float,
        gamma: float,
        reduction: str,
    ):
        super().__init__()
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` against ``targets``."""
        focal_kwargs = dict(alpha=self.alpha, gamma=self.gamma, reduction=self.reduction)
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=inputs,
            targets=targets,
            **focal_kwargs,
        )
# --- Parameters ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# Must equal the width of the dataset's one-hot targets, so it is derived from
# the class list instead of being a separate literal.
# NOTE(review): the model's output size comes from cfg["n_classes"]; confirm it
# matches len(classes) as well.
n_classes = len(classes)
precision_mode = "16-mixed"
train_strategy = "auto"
n_epochs = 10
log_every_n_steps = 4



metric_names = ["rocauc"]
metric_params = {"average": "macro", "task": "multiclass", "num_classes": n_classes}
KEY2LOSSES = {
    "bce" : torch.nn.BCEWithLogitsLoss,
    'ce': torch.nn.CrossEntropyLoss,
    'focal': FocalLoss,
}
# The registry holds classes; the forward step needs a callable *instance*.
# alpha/gamma are torchvision's documented defaults; "mean" yields the scalar
# loss required for backpropagation.
loss_fn = KEY2LOSSES['focal'](alpha=0.25, gamma=2.0, reduction="mean")
forward = AudioForward(
    loss_function=loss_fn,
    input_key="targets",
    output_key="logits",
)
import torchmetrics

KEY2METRICS = {
    "f1" : torchmetrics.F1Score,
    'recall':torchmetrics.Recall,
    'precision':torchmetrics.Precision,
    'accuracy':torchmetrics.Accuracy,
    'rocauc': torchmetrics.AUROC
}
# --- Metrics ---
# Built from a dict so each result is keyed by its short name (e.g. "rocauc")
# rather than by the metric's class name; logged keys then read "valid/rocauc".
metrics = MetricCollection(
    {name: KEY2METRICS[name](**metric_params) for name in metric_names}
)

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from torch import optim
from lightning.pytorch import loggers as pl_loggers

# Optimizer plus a cosine schedule stepped once per optimizer step over the whole run.
total_train_steps = len(train_loader) * n_epochs
optimizer = optim.AdamW(params=model.parameters(), lr=5e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=total_train_steps,
    eta_min=1e-6,
)

lit_kwargs = dict(
    model=model,
    forward=forward,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_params={"interval": "step"},
    batch_key="specs",
    metric_input_key="targets_1d",
    metric_output_key="sigmoid_predictions",
    val_metrics=metrics,
    train_metrics=metrics,
)
lightning_model = LitTrainer(**lit_kwargs)

wandb_logger = pl_loggers.WandbLogger(project="audio_project", log_model=True)

# Keep the three best checkpoints according to validation ROC AUC.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints",
    monitor="valid/rocauc",
    mode="max",
    save_top_k=3,
)
lr_monitor = LearningRateMonitor(logging_interval="step")

trainer_kwargs = dict(
    accelerator="auto",
    devices="auto",
    max_epochs=n_epochs,
    precision=precision_mode,
    strategy=train_strategy,
    logger=wandb_logger,
    callbacks=[checkpoint_cb, lr_monitor],
    log_every_n_steps=log_every_n_steps,
)
trainer = lightning.Trainer(**trainer_kwargs)

# --- Training ---
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)