In [None]:
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from alibi_detect.cd import MMDDriftOnline
from alibi_detect.cd.pytorch import HiddenOutput, preprocess_drift
from alibi_detect.saving import save_detector

# === Constants ===
# Sampling rate (Hz) every clip is resampled to before feature extraction.
SAMPLE_RATE = 16000
# Fixed waveform length in samples: 320000, i.e. 20 seconds at SAMPLE_RATE.
RAW_MODEL_INPUT_LEN = 20 * SAMPLE_RATE
# Dropout probability used by the EmbeddingClassifier hidden layers.
DROPOUT = 0.3

# === Model ===
class EmbeddingClassifier(nn.Module):
    """MLP classification head over fixed-size embedding vectors.

    Three hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout) of widths
    2048, 1024 and 512, followed by a final Linear layer producing ``num_cls``
    logits. The layer order inside ``self.net`` is fixed so that state_dict
    keys (``net.0.weight`` ... ``net.12.bias``) stay compatible with existing
    checkpoints.

    Args:
        emb_dim: Size of the last dimension of the input.
        num_cls: Number of output logits.
        dropout: Dropout probability for the hidden blocks. ``None`` (the
            default) uses the module-level ``DROPOUT`` constant, which matches
            the previous hard-coded behaviour.
    """

    def __init__(self, emb_dim=2048, num_cls=206, dropout=None):
        super().__init__()
        # Resolve the rate lazily so an explicit argument never needs the global.
        p_drop = DROPOUT if dropout is None else dropout
        self.net = nn.Sequential(
            nn.Linear(emb_dim, 2048), nn.BatchNorm1d(2048), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(2048, 1024),    nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(1024, 512),     nn.BatchNorm1d(512),  nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(512, num_cls)
        )

    def forward(self, x):
        """Return the logits produced by ``self.net`` for input ``x``."""
        return self.net(x)

# === Audio preprocessing ===
def preprocess_audio(wav_path):
    """Load an audio file and return a normalized, fixed-length mono waveform.

    Steps, in order: probe the file, load it, resample to ``SAMPLE_RATE`` if
    needed, average the channels, right-pad with zeros up to
    ``RAW_MODEL_INPUT_LEN``, standardize (zero mean, unit std), and finally
    truncate to ``RAW_MODEL_INPUT_LEN`` samples.

    Note that the mean/std are computed on the padded, not-yet-truncated
    signal, so padding zeros and any samples beyond the cut-off contribute to
    the statistics.

    Args:
        wav_path: Path of the audio file to read.

    Returns:
        The processed waveform tensor, cut to ``RAW_MODEL_INPUT_LEN`` samples.

    Raises:
        ValueError: If ``torchaudio.info`` cannot read the file.
    """
    # Probe the file first so unreadable inputs fail with a uniform error.
    try:
        torchaudio.info(wav_path)
    except Exception as err:
        raise ValueError(f"Unreadable audio file: {wav_path}") from err

    audio, native_rate = torchaudio.load(wav_path)

    if native_rate != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(native_rate, SAMPLE_RATE)
        audio = resampler(audio)

    # Collapse the channel axis (if any) into a single mono track.
    if audio.dim() > 1:
        audio = audio.mean(dim=0)

    # Right-pad short clips with zeros; long clips get a pad of 0.
    shortfall = RAW_MODEL_INPUT_LEN - audio.shape[0]
    audio = F.pad(audio, (0, max(0, shortfall)))

    # Standardize; the clamp guards against division by ~0 on silent clips.
    centered = audio - audio.mean()
    scale = audio.std().clamp_min(1e-6)
    normalized = centered / scale

    return normalized[:RAW_MODEL_INPUT_LEN]

def get_mel_tensor(waveform_tensor, n_mels=64, max_frames=313):
    """Convert a waveform into a fixed-size log-mel spectrogram.

    The spectrogram is computed with ``n_fft=1024`` and ``hop_length=256`` at
    ``SAMPLE_RATE``, compressed with ``log1p``, and copied into a zero tensor
    of shape ``[1, n_mels, max_frames]``: longer inputs are truncated along the
    time axis and shorter ones are zero-padded on the right.

    With the defaults, 313 frames at hop 256 cover about 80k samples, so only
    roughly the first 5 s of a clip at 16 kHz are kept (assuming the
    transform's default framing).

    Args:
        waveform_tensor: Either a 1-D waveform ``[time]`` (what
            ``preprocess_audio`` returns) or a single-channel ``[1, time]``
            tensor.
        n_mels: Number of mel bands (default 64, the previous fixed value).
        max_frames: Number of time frames in the output (default 313, the
            previous fixed value).

    Returns:
        Tensor of shape ``[1, n_mels, max_frames]``.
    """
    # A 1-D waveform yields a 2-D [n_mels, frames] spectrogram, which the
    # 3-index slicing below cannot handle (IndexError). Add the channel axis
    # so both 1-D and [1, time] inputs produce [1, n_mels, frames].
    if waveform_tensor.dim() == 1:
        waveform_tensor = waveform_tensor.unsqueeze(0)

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=256, n_mels=n_mels
    )
    mel = torch.log1p(mel_transform(waveform_tensor))

    # Copy into a fixed-size canvas: truncate long inputs, zero-pad short ones.
    mel_padded = torch.zeros((1, n_mels, max_frames))
    mel_len = min(mel.shape[-1], max_frames)
    mel_padded[:, :, :mel_len] = mel[:, :, :mel_len]
    return mel_padded

# === Initialize drift detector ===
def init_drift_detector(model_path: str, ref_paths: list):
    """Build an online MMD drift detector from a set of reference audio files.

    Loads the classifier checkpoint, wraps it with ``HiddenOutput`` as the
    detector's preprocessing step, converts every readable reference file to
    a log-mel tensor, and fits ``MMDDriftOnline`` on the stacked result.

    Args:
        model_path: Path to the ``EmbeddingClassifier`` state_dict checkpoint.
        ref_paths: Paths of the reference audio files. Files that fail to
            process are skipped with a printed warning.

    Returns:
        A configured ``MMDDriftOnline`` instance (``ert=300``,
        ``window_size=10``, PyTorch backend).

    Raises:
        RuntimeError: If none of the reference files could be processed.
    """
    model = EmbeddingClassifier()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    # strict=False tolerates key mismatches, which would otherwise leave
    # layers randomly initialised without any hint. Report them explicitly.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint {model_path} does not match the model exactly "
            f"(missing keys: {list(load_result.missing_keys)}, "
            f"unexpected keys: {list(load_result.unexpected_keys)})"
        )
    model.eval()

    # NOTE(review): EmbeddingClassifier has a single child module (`net`) and
    # expects inputs whose last dimension is emb_dim (2048), while the data
    # fed below is [N, 64, 313]. If HiddenOutput selects layers by slicing the
    # model's direct children, layer=-1 would leave no layers at all, i.e. an
    # identity feature extractor -- confirm that the detector really uses
    # learned features.
    feature_model = HiddenOutput(model, layer=-1)
    preprocess_fn = partial(preprocess_drift, model=feature_model)

    ref_mels = []
    for p in ref_paths:
        # Best effort: one bad file must not abort the whole reference set.
        try:
            print(f"Processing {p}")
            waveform = preprocess_audio(p)
            mel = get_mel_tensor(waveform)  # shape: [1, 64, 313]
            ref_mels.append(mel)
        except Exception as e:
            print(f"Warning: Skipping {p} due to error: {e}")

    if not ref_mels:
        raise RuntimeError("No valid reference audio files found. Please check your dataset.")

    x_ref = torch.cat(ref_mels, dim=0)  # shape: [N, 64, 313]

    return MMDDriftOnline(
        x_ref, ert=300, window_size=10, backend='pytorch', preprocess_fn=preprocess_fn
    )


In [None]:
REF_AUDIO_FOLDER = "training/21211"
TEST_AUDIO_PATH = "validating/21211/XC925908.ogg"

# os.listdir returns entries in arbitrary order; sort so the reference set is
# assembled in the same order on every run.
ref_audio_files = sorted(
    os.path.join(REF_AUDIO_FOLDER, f)
    for f in os.listdir(REF_AUDIO_FOLDER)
    if f.lower().endswith(".ogg")
)

cd_online = init_drift_detector("best_emb_mlp.pt", ref_audio_files)

waveform = preprocess_audio(TEST_AUDIO_PATH)
mel_tensor = get_mel_tensor(waveform)  # shape: [1, 64, 313]

# The online detector's predict() consumes ONE instance per call and adds the
# batch axis itself, so pass the [64, 313] instance rather than a batch of one.
# NOTE(review): a single call only advances the test window by one step;
# stream more instances to get a meaningful online drift signal.
is_drift = cd_online.predict(mel_tensor.squeeze(0).numpy())["data"]["is_drift"]
print("\nDrift detected!" if is_drift else "\nNo drift detected.")

In [None]:
# Persist the detector. `save_detector` comes from `alibi_detect.saving`
# (imported in the first cell); it was previously used without an import,
# which raised NameError on a fresh kernel.
DETECTOR_SAVE_DIR = "cd"

# Reset the online state first so the test instance streamed above is not
# part of the saved detector.
cd_online.reset_state()
save_detector(cd_online, DETECTOR_SAVE_DIR)


In [None]:
# Sanity check: can torchaudio decode one of the reference .ogg files at all?
# Any failure is reported instead of raised so the cell never aborts a run.
try:
    waveform, sr = torchaudio.load("training/21211/XC909280.ogg")
except Exception as e:
    print(f"Failed: {e}")
else:
    print(f"Loaded with sample rate {sr}, shape: {waveform.shape}")

In [None]:
import os
import subprocess

# Convert every .ogg file in `ogg_dir` to a .wav file in `wav_dir` via ffmpeg.
ogg_dir = "training/21211"
wav_dir = "training/21211_wav"
os.makedirs(wav_dir, exist_ok=True)

# Sorted for a reproducible processing/log order.
for fname in sorted(os.listdir(ogg_dir)):
    if not fname.lower().endswith(".ogg"):
        continue

    ogg_path = os.path.join(ogg_dir, fname)
    # Swap only the final extension. The filter above is case-insensitive, so
    # a case-sensitive str.replace(".ogg", ".wav") would leave e.g. "X.OGG"
    # unchanged (and would also rewrite ".ogg" in the middle of a name).
    wav_name = os.path.splitext(fname)[0] + ".wav"
    wav_path = os.path.join(wav_dir, wav_name)

    try:
        # "-y" overwrites existing outputs; ffmpeg's own output is silenced.
        subprocess.run(
            ["ffmpeg", "-y", "-i", ogg_path, wav_path],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        print(f"Converted: {fname} → {wav_name}")
    except subprocess.CalledProcessError:
        # Non-zero ffmpeg exit: report and continue with the next file.
        print(f"Failed to convert: {fname}")
