# 02 — Train Model A (Mel-only) — DenseNet121 + BiLSTM (PyTorch)

This notebook trains **Model A: Mel Spectrogram only** using **DenseNet121 (ImageNet pretrained) + BiLSTM** on our merged dataset manifest.

## What you already have (from 01_build_manifest.ipynb)
- `data/processed/manifests/train.csv`
- `data/processed/manifests/val.csv`
- `data/processed/manifests/test.csv`
- `models/label_to_id.json` and `models/id_to_label.json` (may be 8-class)

## ⚠️ Important note about your current data
From your printed counts, your dataset currently contains **6 classes**:
- COPD, Normal, Asthma, Heart failure, Pneumonia, Lung fibrosis

**Bronchitis** and **Pleural effusion** have **0 samples** right now.  
So by default this notebook trains a **6-class** classifier (recommended), and it will **overwrite** the label maps to match available classes.

If you later add data for the missing 2 classes, re-run 01_build_manifest and then re-run this notebook.

## Key settings (easy to tweak)
- Sample rate: 22,050 Hz
- Segments per recording: 5
- Segment length: 2.0 sec
- Mel parameters: n_mels=128, n_fft=2048, hop_length=512
- Image size for DenseNet: 224x224 (3-channel)
- Class imbalance handling: **WeightedRandomSampler + class-weighted CE loss**
- Mixed precision training: enabled when CUDA is available

Outputs:
- `models/best_model.pth`
- `models/config.json`
- Updated `models/label_to_id.json` + `models/id_to_label.json`


In [3]:
# Cell 1 — Install deps (run once if needed)
# NOTE: In Lightning many of these may already exist.
# If you get import errors, uncomment and run.

# !pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip -q install librosa soundfile scikit-learn tqdm matplotlib


In [4]:
# Cell 2 — Imports & seed

from __future__ import annotations

import os
import json
import math
import random
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Tuple, Optional, Dict, Any, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import torchvision
from torchvision import transforms

import librosa
from tqdm.auto import tqdm

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix


In [5]:
# Cell 3 — Resolve paths

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

PROJECT_ROOT = Path.cwd().resolve().parents[0]  # .../lung_sound_project
MANIFEST_DIR = PROJECT_ROOT / "data/processed/manifests"
MODELS_DIR = PROJECT_ROOT / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

train_csv = MANIFEST_DIR / "train.csv"
val_csv   = MANIFEST_DIR / "val.csv"
test_csv  = MANIFEST_DIR / "test.csv"

assert train_csv.exists() and val_csv.exists() and test_csv.exists(), "Run 01_build_manifest.ipynb first!"

train_df = pd.read_csv(train_csv)
val_df   = pd.read_csv(val_csv)
test_df  = pd.read_csv(test_csv)

print("Train:", train_df.shape, "Val:", val_df.shape, "Test:", test_df.shape)
print("Train labels:\n", train_df["label"].value_counts())


Train: (732, 4) Val: (152, 4) Test: (291, 4)
Train labels:
 label
COPD             476
Normal            99
Asthma            68
Pneumonia         42
Heart failure     38
Lung fibrosis      9
Name: count, dtype: int64


## 1) Decide which labels to train

We auto-detect available labels from `manifest_all` (train+val+test).  
This avoids training for classes that have 0 samples.

By default: **use only labels present** and overwrite label maps accordingly.


In [6]:
# Cell 4 — Label mapping (auto from available labels)

all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
available_labels = sorted(all_df["label"].unique().tolist())

print("Available labels in data:", available_labels)

# This is your intended 8-class set (for reference)
intended_8 = [
    "Normal",
    "Asthma",
    "COPD",
    "Pneumonia",
    "Bronchitis",
    "Heart failure",
    "Pleural effusion",
    "Lung fibrosis",
]

missing = [c for c in intended_8 if c not in available_labels]
if missing:
    print("⚠️ Missing classes (0 samples):", missing)
    print("➡️ Training will use only available labels (recommended).")

LABELS = available_labels  # train on what exists
label_to_id = {lbl: i for i, lbl in enumerate(LABELS)}
id_to_label = {i: lbl for lbl, i in label_to_id.items()}

# Save updated label maps (for inference + Gradio)
(MODELS_DIR / "label_to_id.json").write_text(json.dumps(label_to_id, indent=2), encoding="utf-8")
(MODELS_DIR / "id_to_label.json").write_text(json.dumps(id_to_label, indent=2), encoding="utf-8")

print("Saved updated label maps to:", MODELS_DIR)
print("Num classes:", len(LABELS))


Available labels in data: ['Asthma', 'COPD', 'Heart failure', 'Lung fibrosis', 'Normal', 'Pneumonia']
⚠️ Missing classes (0 samples): ['Bronchitis', 'Pleural effusion']
➡️ Training will use only available labels (recommended).
Saved updated label maps to: /teamspace/studios/this_studio/lung_sound_project/models
Num classes: 6


## 2) Feature extraction: Mel Spectrogram (Model A)

We convert each recording into a **sequence** of mel-spectrogram images:
- Split into `NUM_SEGMENTS` segments
- Each segment -> mel spectrogram -> normalize -> resize to 224x224 -> 3 channels
- DenseNet processes each segment image -> produces feature vector
- BiLSTM learns the temporal pattern across segments


In [7]:
# Cell 5 — Config

@dataclass
class TrainConfig:
    sample_rate: int = 22050
    num_segments: int = 5
    segment_seconds: float = 2.0  # per segment
    n_mels: int = 128
    n_fft: int = 2048
    hop_length: int = 512
    img_size: int = 224

    batch_size: int = 8
    num_workers: int = 2

    lstm_hidden: int = 128
    lstm_layers: int = 1
    bidirectional: bool = True
    dropout: float = 0.5

    # training
    epochs_frozen: int = 6      # freeze DenseNet, train LSTM+head
    epochs_finetune: int = 6    # unfreeze last dense block
    lr_frozen: float = 1e-3
    lr_finetune: float = 1e-5
    weight_decay: float = 1e-4

    # imbalance handling
    use_weighted_sampler: bool = True
    use_class_weights: bool = True

    # caching (optional)
    use_cache: bool = False
    cache_dir: str = "data/processed/cache_mels"  # inside project
    cache_format: str = "pt"  # "pt" recommended

    # augmentations
    aug_time_shift: bool = True
    aug_add_noise: bool = True
    aug_spec_mask: bool = True  # lightweight SpecAugment

cfg = TrainConfig()
cfg


TrainConfig(sample_rate=22050, num_segments=5, segment_seconds=2.0, n_mels=128, n_fft=2048, hop_length=512, img_size=224, batch_size=8, num_workers=2, lstm_hidden=128, lstm_layers=1, bidirectional=True, dropout=0.5, epochs_frozen=6, epochs_finetune=6, lr_frozen=0.001, lr_finetune=1e-05, weight_decay=0.0001, use_weighted_sampler=True, use_class_weights=True, use_cache=False, cache_dir='data/processed/cache_mels', cache_format='pt', aug_time_shift=True, aug_add_noise=True, aug_spec_mask=True)

In [8]:
# Cell 6 — Audio helpers

def load_audio(path: str, sr: int) -> np.ndarray:
    y, _ = librosa.load(path, sr=sr, mono=True)
    if y.size == 0:
        return np.zeros(sr, dtype=np.float32)
    # normalize peak
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak
    return y.astype(np.float32)

def pad_or_crop(y: np.ndarray, target_len: int) -> np.ndarray:
    if len(y) == target_len:
        return y
    if len(y) < target_len:
        pad = target_len - len(y)
        return np.pad(y, (0, pad), mode="constant")
    # crop
    start = np.random.randint(0, len(y) - target_len + 1)
    return y[start:start + target_len]

def split_segments(y: np.ndarray, sr: int, num_segments: int, seg_seconds: float) -> np.ndarray:
    seg_len = int(sr * seg_seconds)
    total_len = seg_len * num_segments
    y = pad_or_crop(y, total_len)
    return y.reshape(num_segments, seg_len)

def time_shift(y: np.ndarray, shift_max: float = 0.2) -> np.ndarray:
    # shift up to ±20% of length
    shift = int(random.uniform(-shift_max, shift_max) * len(y))
    return np.roll(y, shift)

def add_noise(y: np.ndarray, noise_level: float = 0.01) -> np.ndarray:
    noise = np.random.randn(len(y)).astype(np.float32)
    return y + noise_level * noise


In [9]:
# Cell 7 — Mel spectrogram -> 3-channel image tensor

def mel_image(segment: np.ndarray, sr: int, n_mels: int, n_fft: int, hop_length: int, img_size: int,
              spec_mask: bool = False) -> torch.Tensor:
    # mel power spectrogram
    S = librosa.feature.melspectrogram(
        y=segment, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, power=2.0
    )
    # log scale
    S_db = librosa.power_to_db(S, ref=np.max)

    # normalize to [0, 1] per-sample
    S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8)

    # SpecAugment (very light): time + freq mask
    if spec_mask:
        # frequency mask
        fm = random.randint(0, max(1, n_mels // 8))
        f0 = random.randint(0, max(0, n_mels - fm))
        if fm > 0:
            S_db[f0:f0+fm, :] = 0.0
        # time mask
        t = S_db.shape[1]
        tm = random.randint(0, max(1, t // 10))
        t0 = random.randint(0, max(0, t - tm))
        if tm > 0:
            S_db[:, t0:t0+tm] = 0.0

    # convert to torch (1, H, W)
    img = torch.from_numpy(S_db).unsqueeze(0).float()  # 1 x 128 x T

    # resize to 224x224
    img = F.interpolate(img.unsqueeze(0), size=(img_size, img_size), mode="bilinear", align_corners=False).squeeze(0)

    # make 3 channels
    img3 = img.repeat(3, 1, 1)  # 3 x 224 x 224
    return img3


## 3) Optional caching (speeds up repeated epochs)

If `cfg.use_cache=True`, each (audio file, segment config) will be cached to disk as a tensor.
This can speed training when CPU mel computation becomes the bottleneck.

With L4 GPU, caching is **nice to have** but not strictly required.  
If training feels slow, enable caching.


In [10]:
# Cell 8 — Dataset

class LungSoundDataset(Dataset):
    def __init__(self, df: pd.DataFrame, cfg: TrainConfig, label_to_id: Dict[str, int], train: bool):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.label_to_id = label_to_id
        self.train = train

        self.cache_dir = (PROJECT_ROOT / cfg.cache_dir)
        if cfg.use_cache:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def __len__(self) -> int:
        return len(self.df)

    def _cache_key(self, wav_path: str) -> str:
        # stable key based on filename + config
        p = Path(wav_path)
        key = f"{p.stem}_sr{self.cfg.sample_rate}_ns{self.cfg.num_segments}_ss{self.cfg.segment_seconds}_m{self.cfg.n_mels}_fft{self.cfg.n_fft}_hop{self.cfg.hop_length}_img{self.cfg.img_size}.pt"
        return key

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        row = self.df.iloc[idx]
        wav_path = row["filepath"]
        label = row["label"]
        y = None

        if self.cfg.use_cache:
            key = self._cache_key(wav_path)
            cache_path = self.cache_dir / key
            if cache_path.exists():
                x = torch.load(cache_path, map_location="cpu")
                return x, self.label_to_id[label]

        y = load_audio(wav_path, self.cfg.sample_rate)

        # augmentations (train only)
        if self.train:
            if self.cfg.aug_time_shift and random.random() < 0.5:
                y = time_shift(y)
            if self.cfg.aug_add_noise and random.random() < 0.5:
                y = add_noise(y, noise_level=random.uniform(0.002, 0.01))

        segs = split_segments(y, self.cfg.sample_rate, self.cfg.num_segments, self.cfg.segment_seconds)

        imgs = []
        for s in segs:
            img = mel_image(
                s, sr=self.cfg.sample_rate,
                n_mels=self.cfg.n_mels, n_fft=self.cfg.n_fft, hop_length=self.cfg.hop_length,
                img_size=self.cfg.img_size,
                spec_mask=(self.train and self.cfg.aug_spec_mask and random.random() < 0.5)
            )
            imgs.append(img)

        x = torch.stack(imgs, dim=0)  # (T, 3, 224, 224)

        if self.cfg.use_cache:
            torch.save(x, cache_path)

        return x, self.label_to_id[label]


## 4) DataLoaders + imbalance handling

Your data is heavily imbalanced (COPD dominates). We use:
- **WeightedRandomSampler** (oversamples minority classes)
- **Class-weighted CrossEntropyLoss** (optional)

We optimize for **macro F1** (more fair than accuracy under imbalance).


In [11]:
# Cell 9 — Build samplers and dataloaders

def make_weights(df: pd.DataFrame, label_to_id: Dict[str, int]) -> np.ndarray:
    counts = df["label"].value_counts().to_dict()
    weights = []
    for lbl in df["label"].tolist():
        w = 1.0 / max(1, counts.get(lbl, 1))
        weights.append(w)
    return np.array(weights, dtype=np.float32)

train_ds = LungSoundDataset(train_df, cfg, label_to_id, train=True)
val_ds   = LungSoundDataset(val_df, cfg, label_to_id, train=False)
test_ds  = LungSoundDataset(test_df, cfg, label_to_id, train=False)

sampler = None
if cfg.use_weighted_sampler:
    w = make_weights(train_df, label_to_id)
    sampler = WeightedRandomSampler(weights=torch.from_numpy(w), num_samples=len(w), replacement=True)

train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=(sampler is None),
    sampler=sampler, num_workers=cfg.num_workers, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True
)
test_loader = DataLoader(
    test_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True
)

# quick sanity batch
xb, yb = next(iter(train_loader))
print("Batch X:", xb.shape, "Batch y:", yb.shape, "num_classes:", len(LABELS))


Batch X: torch.Size([8, 5, 3, 224, 224]) Batch y: torch.Size([8]) num_classes: 6


## 5) Model: DenseNet121 → (sequence) → BiLSTM → classifier

- DenseNet121 extracts a feature vector per segment
- BiLSTM models temporal breathing pattern across segments
- Final linear layer predicts class


In [12]:
# Cell 10 — Model definition

class DenseNetEncoder(nn.Module):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        m = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None)
        self.features = m.features
        self.out_dim = 1024  # DenseNet121 feature dim
        # final classifier removed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 3, 224, 224)
        f = self.features(x)                 # (B, 1024, H, W)
        f = F.relu(f, inplace=True)
        f = F.adaptive_avg_pool2d(f, (1, 1)) # (B, 1024, 1, 1)
        f = f.flatten(1)                     # (B, 1024)
        return f

class DenseNetBiLSTM(nn.Module):
    def __init__(self, num_classes: int, hidden: int = 128, layers: int = 1, bidirectional: bool = True, dropout: float = 0.5):
        super().__init__()
        self.encoder = DenseNetEncoder(pretrained=True)
        self.lstm = nn.LSTM(
            input_size=self.encoder.out_dim,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.0 if layers == 1 else dropout
        )
        lstm_out = hidden * (2 if bidirectional else 1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(lstm_out, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, 3, 224, 224)
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        feats = self.encoder(x)              # (B*T, 1024)
        feats = feats.view(B, T, -1)         # (B, T, 1024)
        out, _ = self.lstm(feats)            # (B, T, hidden*dir)
        last = out[:, -1, :]                 # (B, hidden*dir)
        last = self.dropout(last)
        logits = self.fc(last)               # (B, num_classes)
        return logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenseNetBiLSTM(num_classes=len(LABELS), hidden=cfg.lstm_hidden, layers=cfg.lstm_layers,
                       bidirectional=cfg.bidirectional, dropout=cfg.dropout).to(device)

print("Device:", device)
print("Model params:", sum(p.numel() for p in model.parameters())/1e6, "M")


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /home/zeus/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 46.5MB/s]


Device: cuda
Model params: 8.137094 M


In [13]:
# Cell 11 — Freeze/unfreeze helpers

def freeze_encoder(m: DenseNetBiLSTM, freeze: bool = True):
    for p in m.encoder.parameters():
        p.requires_grad = not freeze

def unfreeze_last_dense_block(m: DenseNetBiLSTM):
    # DenseNet features has: conv0, norm0, relu0, pool0, denseblock1..4, transition1..3, norm5
    # We'll unfreeze denseblock4 + norm5 only
    for name, p in m.encoder.features.named_parameters():
        if name.startswith("denseblock4") or name.startswith("norm5"):
            p.requires_grad = True
        else:
            p.requires_grad = False


## 6) Loss + Metrics

Because of imbalance, we use:
- **macro F1** to choose the best model
- optional class weights in CrossEntropyLoss


In [14]:
# Cell 12 — Loss setup (class weights optional)

def compute_class_weights(df: pd.DataFrame, labels: List[str]) -> torch.Tensor:
    counts = df["label"].value_counts().to_dict()
    w = []
    for lbl in labels:
        w.append(1.0 / max(1, counts.get(lbl, 1)))
    w = np.array(w, dtype=np.float32)
    w = w / w.sum() * len(labels)  # normalize
    return torch.tensor(w, dtype=torch.float32)

class_weights = None
if cfg.use_class_weights:
    class_weights = compute_class_weights(train_df, LABELS).to(device)
    print("Class weights:", class_weights.detach().cpu().numpy().round(3))

criterion = nn.CrossEntropyLoss(weight=class_weights)


Class weights: [0.469 0.067 0.839 3.543 0.322 0.759]


In [15]:
# Cell 13 — Eval function

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> Dict[str, Any]:
    model.eval()
    ys, preds = [], []
    losses = []
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        losses.append(loss.item())
        p = torch.argmax(logits, dim=1)
        ys.extend(y.detach().cpu().numpy().tolist())
        preds.extend(p.detach().cpu().numpy().tolist())

    acc = accuracy_score(ys, preds)
    f1m = f1_score(ys, preds, average="macro")
    return {"loss": float(np.mean(losses)), "acc": float(acc), "f1_macro": float(f1m), "y": ys, "pred": preds}


## 7) Training loop (2 phases)

**Phase 1 (Frozen)**: Freeze DenseNet, train only LSTM + head  
**Phase 2 (Fine-tune)**: Unfreeze last DenseNet block for better adaptation

We save best checkpoint by **val macro F1**.


In [16]:
# Cell 14 — Training loop

from torch.cuda.amp import autocast, GradScaler

def train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, scaler: Optional[GradScaler]) -> float:
    model.train()
    losses = []
    for x, y in tqdm(loader, leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if device.type == "cuda":
            with autocast():
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

        losses.append(loss.item())

    return float(np.mean(losses))


best = {"f1_macro": -1.0, "path": None}
scaler = GradScaler(enabled=(device.type == "cuda"))

def save_checkpoint(path: Path, model: nn.Module, cfg: TrainConfig, labels: List[str]):
    ckpt = {
        "state_dict": model.state_dict(),
        "labels": labels,
        "config": asdict(cfg),
    }
    torch.save(ckpt, path)

# Phase 1: freeze encoder
freeze_encoder(model, freeze=True)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=cfg.lr_frozen, weight_decay=cfg.weight_decay)

print("=== Phase 1: Frozen encoder ===")
for epoch in range(1, cfg.epochs_frozen + 1):
    tr_loss = train_one_epoch(model, train_loader, optimizer, scaler)
    val_metrics = evaluate(model, val_loader)
    print(f"Epoch {epoch}/{cfg.epochs_frozen} | train_loss={tr_loss:.4f} | val_loss={val_metrics['loss']:.4f} | val_acc={val_metrics['acc']:.4f} | val_f1m={val_metrics['f1_macro']:.4f}")

    if val_metrics["f1_macro"] > best["f1_macro"]:
        best["f1_macro"] = val_metrics["f1_macro"]
        best_path = MODELS_DIR / "best_model.pth"
        save_checkpoint(best_path, model, cfg, LABELS)
        best["path"] = str(best_path)
        print("✅ Saved best checkpoint:", best_path, "f1m=", best["f1_macro"])

# Phase 2: fine-tune last block
print("\n=== Phase 2: Fine-tune last DenseNet block ===")
unfreeze_last_dense_block(model)
# make sure LSTM+head still trainable
for p in model.lstm.parameters():
    p.requires_grad = True
for p in model.fc.parameters():
    p.requires_grad = True

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=cfg.lr_finetune, weight_decay=cfg.weight_decay)

for epoch in range(1, cfg.epochs_finetune + 1):
    tr_loss = train_one_epoch(model, train_loader, optimizer, scaler)
    val_metrics = evaluate(model, val_loader)
    print(f"FT Epoch {epoch}/{cfg.epochs_finetune} | train_loss={tr_loss:.4f} | val_loss={val_metrics['loss']:.4f} | val_acc={val_metrics['acc']:.4f} | val_f1m={val_metrics['f1_macro']:.4f}")

    if val_metrics["f1_macro"] > best["f1_macro"]:
        best["f1_macro"] = val_metrics["f1_macro"]
        best_path = MODELS_DIR / "best_model.pth"
        save_checkpoint(best_path, model, cfg, LABELS)
        best["path"] = str(best_path)
        print("✅ Saved best checkpoint:", best_path, "f1m=", best["f1_macro"])

print("\nBest val macro F1:", best["f1_macro"], "Checkpoint:", best["path"])


=== Phase 1: Frozen encoder ===


  scaler = GradScaler(enabled=(device.type == "cuda"))


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()    
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^ ^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/p

Epoch 1/6 | train_loss=1.4565 | val_loss=3.1699 | val_acc=0.0066 | val_f1m=0.0022
✅ Saved best checkpoint: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth f1m= 0.002178649237472767


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620><function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():
 
            ^ ^^^^^^^^^^^^^^^^^^^^^^
^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/p

Epoch 2/6 | train_loss=1.5055 | val_loss=3.1384 | val_acc=0.0066 | val_f1m=0.0022


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()    
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():    
if w.is_alive(): 
             ^^^^^^^^^^^^^^^^^^^^^
^^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

Epoch 3/6 | train_loss=1.4051 | val_loss=3.3743 | val_acc=0.0066 | val_f1m=0.0022


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620><function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

Epoch 4/6 | train_loss=1.4001 | val_loss=2.7811 | val_acc=0.0132 | val_f1m=0.0333
✅ Saved best checkpoint: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth f1m= 0.03333333333333333


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620><function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 5/6 | train_loss=1.3463 | val_loss=3.0289 | val_acc=0.0066 | val_f1m=0.0044


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620><function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

          Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>  
 Except

Epoch 6/6 | train_loss=1.4020 | val_loss=2.6795 | val_acc=0.0132 | val_f1m=0.0303

=== Phase 2: Fine-tune last DenseNet block ===


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620> 
 Traceback (most recent call last):
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
       self._shutdown_workers() 
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^  ^ ^ ^ ^ ^ ^^^^
^Exception ignored in:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.

FT Epoch 1/6 | train_loss=1.4142 | val_loss=2.7587 | val_acc=0.0132 | val_f1m=0.0278


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):
      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive():    
self._shutdown_workers()  
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive(): 
^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/

FT Epoch 2/6 | train_loss=1.2996 | val_loss=2.6996 | val_acc=0.0066 | val_f1m=0.0056


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

    Traceback (most recent call last):
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
self._shutdown_workers()    
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

     if w.is_alive(): 
           ^ ^^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

FT Epoch 3/6 | train_loss=1.3241 | val_loss=2.6613 | val_acc=0.0066 | val_f1m=0.0057


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620><function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

FT Epoch 4/6 | train_loss=1.2566 | val_loss=2.6810 | val_acc=0.0066 | val_f1m=0.0058


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

    Traceback (most recent call last):
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

      self._shutdown_workers() 
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
^ ^ ^ ^ ^ ^ ^ ^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/

FT Epoch 5/6 | train_loss=1.2662 | val_loss=2.6552 | val_acc=0.0066 | val_f1m=0.0068


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

    Traceback (most recent call last):
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
        ^ ^ ^ ^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessi

FT Epoch 6/6 | train_loss=1.2143 | val_loss=2.7095 | val_acc=0.0066 | val_f1m=0.0062

Best val macro F1: 0.03333333333333333 Checkpoint: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth


## 8) Evaluate best checkpoint on Test set

We load `models/best_model.pth`, run test evaluation, and print:
- Accuracy
- Macro F1
- Confusion matrix
- Per-class report


In [17]:
# Cell 15 — Load best checkpoint and evaluate on test

best_path = MODELS_DIR / "best_model.pth"
assert best_path.exists(), "No best_model.pth found. Did training run?"

ckpt = torch.load(best_path, map_location=device)
model.load_state_dict(ckpt["state_dict"])

test_metrics = evaluate(model, test_loader)
print("TEST loss:", test_metrics["loss"])
print("TEST acc :", test_metrics["acc"])
print("TEST f1m :", test_metrics["f1_macro"])

y_true = test_metrics["y"]
y_pred = test_metrics["pred"]

print("\nClassification report:\n")
print(classification_report(y_true, y_pred, target_names=LABELS, digits=4))

cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix (rows=true, cols=pred):\n", cm)


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>
Traceback (most recent call last):
Exception ignored in:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7221b4075620>    
Traceback (most recent call last):
self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
        ^ ^ ^ ^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

TEST loss: 2.6988468363478377
TEST acc : 0.030927835051546393
TEST f1m : 0.04064175896731075

Classification report:

               precision    recall  f1-score   support

       Asthma     0.0000    0.0000    0.0000        19
         COPD     0.0000    0.0000    0.0000       224
Heart failure     0.1000    0.1818    0.1290        11
Lung fibrosis     0.0364    1.0000    0.0702         2
       Normal     0.0000    0.0000    0.0000        27
    Pneumonia     0.0231    0.6250    0.0446         8

     accuracy                         0.0309       291
    macro avg     0.0266    0.3011    0.0406       291
 weighted avg     0.0047    0.0309    0.0066       291

Confusion matrix (rows=true, cols=pred):
 [[  0   0   5  12   0   2]
 [  0   0   2  19   0 203]
 [  0   0   2   9   0   0]
 [  0   0   0   2   0   0]
 [  0   0  10  11   0   6]
 [  0   0   1   2   0   5]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## 9) Save config.json (for inference + reproducibility)

We save a small config describing:
- labels
- audio params
- segment params
- mel params
- model params


In [18]:
# Cell 16 — Save config.json

config_out = {
    "labels": LABELS,
    "label_to_id": label_to_id,
    "id_to_label": id_to_label,
    "audio": {
        "sample_rate": cfg.sample_rate,
        "num_segments": cfg.num_segments,
        "segment_seconds": cfg.segment_seconds,
    },
    "mel": {
        "n_mels": cfg.n_mels,
        "n_fft": cfg.n_fft,
        "hop_length": cfg.hop_length,
        "img_size": cfg.img_size,
    },
    "model": {
        "backbone": "densenet121",
        "lstm_hidden": cfg.lstm_hidden,
        "lstm_layers": cfg.lstm_layers,
        "bidirectional": cfg.bidirectional,
        "dropout": cfg.dropout,
    }
}

(MODELS_DIR / "config.json").write_text(json.dumps(config_out, indent=2), encoding="utf-8")
print("Saved:", MODELS_DIR / "config.json")


Saved: /teamspace/studios/this_studio/lung_sound_project/models/config.json


✅ Done.

Next notebook: **03_evaluate.ipynb** (plots, saved confusion matrix image, error analysis)
and then **app/gradio_app.py** can load:
- `models/best_model.pth`
- `models/config.json`
- `models/id_to_label.json`
