In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ls.config.loader import load_config
import IPython.display as ipd
import torch

In [None]:
# --- 1. load config ---
# Load the project configuration; `cfg` is the object the later cells read
# their dataset and audio settings from.
cfg = load_config("../configs/config.yaml")

# Show the two config sections that are handed to the dataloader builders.
print("Dataset config:", cfg.dataset)
print("Audio config:", cfg.audio)

## Test DataLoaders

### ICBHI dataloader using only the test set as the validation set

In [None]:
# Regular training: build one train loader and one test loader from the
# dataset and audio sections of the config.
from ls.data.dataloaders import build_dataloaders

train_loader, test_loader = build_dataloaders(cfg.dataset, cfg.audio)

In [None]:
# Smoke test: fetch a single batch and print the shapes of the inputs and
# labels together with the filenames, then stop.
# `batch` stays bound after the loop; the single-sample inspection cell
# further down indexes into it.
for batch in train_loader:
    print(batch["input_values"].shape, batch["labels"].shape, batch["filename"])
    break

In [None]:
from collections import Counter
import numpy as np

def inspect_batch_balance(train_loader, n_batches=100):
    """Print how often each label pattern occurs in the first batches of a loader.

    Args:
        train_loader: Iterable of batches. Each batch must be a mapping whose
            ``"labels"`` entry supports ``.cpu().numpy()`` (e.g. a torch tensor).
        n_batches: Maximum number of batches to read. Exactly this many are
            consumed unless the loader runs out first.

    Returns:
        None. A summary line and one line per distinct pattern (count and
        percentage of all inspected samples) are written to stdout.
    """
    from itertools import islice

    pattern_counter = Counter()
    total = 0
    seen_batches = 0
    # islice stops after exactly n_batches batches (the previous post-increment
    # check let one extra batch through) and never fetches a batch it will not use.
    for batch in islice(train_loader, max(n_batches, 0)):
        labels = batch["labels"].cpu().numpy()
        if labels.ndim == 2:
            # Multi-label mode: one pattern per row. tolist() yields native
            # Python scalars so the printed tuples stay readable on NumPy >= 2.
            patterns = [tuple(row.tolist()) for row in labels]
        else:
            patterns = [int(v) for v in labels]
        pattern_counter.update(patterns)
        total += len(patterns)
        seen_batches += 1

    # Report the number of batches actually read, which can be smaller than
    # the requested budget when the loader is short.
    print(f"\nChecked {total} samples from {seen_batches} batches")
    for pat, c in sorted(pattern_counter.items()):
        print(f"Pattern {pat}: {c} ({100*c/total:.1f}%)")

In [None]:
# Tally the label patterns over the leading batches of the training loader.
inspect_batch_balance(train_loader, n_batches=100)

In [None]:
def tensors_equal(a, b, atol=1e-6, rtol=1e-5):
    """Return True when tensors ``a`` and ``b`` match element-wise within tolerance.

    Thin wrapper around ``torch.allclose`` with tighter default tolerances
    (``atol=1e-6``, ``rtol=1e-5``).
    """
    tolerances = {"rtol": rtol, "atol": atol}
    return torch.allclose(a, b, **tolerances)

In [None]:
# --- Inspect one sample ---
# `batch` is the batch left bound by the earlier smoke-test loop.
# Draw the index from the size of the batch actually in hand rather than from
# cfg.dataset.batch_size: a batch holding fewer items than the configured size
# would otherwise produce an out-of-range index.
idx = torch.randint(len(batch["labels"]), (1,)).item()
print(f"Inspecting sample index {idx} in the batch")
print("Keys:", batch.keys())
print("Filename:", batch["filename"][idx])
print("Cycle index:", batch["cycle_index"][idx])
print("Label:", batch["labels"][idx])
print("Duration:", batch["duration"][idx])
print("Start-End:", batch["start_time"][idx], "-", batch["end_time"][idx])
print("Crackle/Wheeze:", batch["crackle"][idx], batch["wheeze"][idx])
# print(sample["aug_audio"].shape, sample["aug_fbank"].shape, sample["audio"].shape, sample["fbank"].shape)
print("Waveform shape:", batch["audio"][idx].shape)
print("Mel image shape:", batch["input_values"][idx].shape)
# print(f"Waveform augmented: {not tensors_equal(batch['audio'][idx].view(-1), batch['aug_audio'][idx].view(-1))}")
# print(f"Mel augmented: {not tensors_equal(batch['fbank'][idx].view(-1), batch['aug_fbank'][idx].view(-1))}")

# --- 4. Plot waveform ---
waveform = batch["audio"][idx].squeeze().numpy()
plt.figure(figsize=(12, 3))
plt.plot(waveform)
plt.title(f"Waveform ({batch['filename'][idx]} - cycle {batch['cycle_index'][idx]})")
plt.show()

# --- 5. Plot mel spectrogram ---
mel = batch["input_values"][idx].squeeze(0)  # [freq, time] for imshow
# Axis extents: frequency runs linearly from 0 to half the sample rate, time
# is frame index * frame_shift converted from milliseconds.
# NOTE(review): the frequency axis is labelled linearly in Hz; if the bins are
# mel-spaced the tick values are only approximate — confirm.
freq_axis = np.linspace(0, cfg.audio.sample_rate // 2, mel.shape[0])
time_axis = np.arange(mel.shape[1]) * cfg.audio.frame_shift / 1000  # in seconds
plt.figure(figsize=(12, 5))
plt.imshow(mel, origin="lower", aspect="auto", extent=[0, time_axis[-1], 0, freq_axis[-1]], cmap="viridis")
plt.title("Mel filterbank")
plt.colorbar()
plt.show()

# Must stay the last expression of the cell so the audio player is rendered.
ipd.Audio(waveform, rate=cfg.audio.sample_rate)  # listen to the sound

In [None]:
from ls.data.dataloaders import compute_and_cache_stats

# Compute the mean and standard deviation over the training dataset with the
# project helper, reusing the loader settings from the config.
# NOTE(review): going by the helper's name and `cache_file` argument, the result
# is presumably cached in "train_stats.json" — confirm how later runs reuse it.
mean, std = compute_and_cache_stats(
    train_loader.dataset, cache_file="train_stats.json", batch_size=cfg.dataset.batch_size,
    num_workers=cfg.dataset.num_workers
)
print(f"Dataset mean: {mean:.4f}, std: {std:.4f}")

## Stratified Grouped KFold Dataloader

In [None]:
from ls.data.dataloaders import build_train_val_kfold

In [None]:
# Alternative to the single train/test split: 5-fold cross-validation on the
# train set, seeded from the config.
# `folds` is iterated below as (train_loader, val_loader) pairs. Note that
# `test_loader` is rebound here, replacing the loader built earlier.
folds, test_loader = build_train_val_kfold(
    cfg.dataset, cfg.audio, n_splits=5, max_retries=50, seed=cfg.seed
)

In [None]:
# Smoke-test every fold: fetch the first training batch and print its shapes
# and filenames.
# NOTE: this loop rebinds `train_loader`, `val_loader` and `batch`; after it
# runs they refer to the last fold, not to the loaders built earlier.
for i, (train_loader, val_loader) in enumerate(folds):
    # print(f"Training fold {i+1}")
    for batch in train_loader:
        # Targets live under "labels" — the key every other cell in this
        # notebook reads; "label" does not match it.
        print(batch["input_values"].shape, batch["labels"].shape, batch["filename"])
        break