# 04 — Train with **Focal Loss** (to boost Macro‑F1 on minority classes)
Model A (Mel-only): **DenseNet121 + BiLSTM** (PyTorch)

You already got:
- Test Acc ≈ 0.756
- Test Macro‑F1 ≈ 0.341

Macro‑F1 is low mainly due to **tiny classes** (Heart failure=11, Lung fibrosis=2 in test).  
Focal Loss helps by focusing training on **hard / misclassified** samples.

## What this notebook changes vs v2
✅ Same stable data pipeline (center crop in Val/Test, ImageNet norm)  
✅ Same model and fine-tuning schedule  
✅ Replaces loss with **FocalLoss** (gamma=2)  
✅ Uses **mild alpha weights** (clipped) — safer than extreme class weights  
✅ Saves best checkpoint by **Val Macro‑F1**

Outputs:
- `models/best_model_focal.pth`
- `models/config_focal.json`


In [1]:
# (Optional) install deps if needed
# !pip -q install librosa soundfile scikit-learn tqdm matplotlib


In [2]:
from __future__ import annotations

import json, random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from tqdm.auto import tqdm
import librosa

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix


In [3]:
# Paths + seed
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

PROJECT_ROOT = Path.cwd().resolve().parents[0]  # run from notebooks/
MANIFEST_DIR = PROJECT_ROOT / "data/processed/manifests"
MODELS_DIR = PROJECT_ROOT / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

train_df = pd.read_csv(MANIFEST_DIR / "train.csv")
val_df   = pd.read_csv(MANIFEST_DIR / "val.csv")
test_df  = pd.read_csv(MANIFEST_DIR / "test.csv")

print("Train:", train_df.shape, "Val:", val_df.shape, "Test:", test_df.shape)
print("Train labels:\n", train_df["label"].value_counts())


Train: (732, 4) Val: (152, 4) Test: (291, 4)
Train labels:
 label
COPD             476
Normal            99
Asthma            68
Pneumonia         42
Heart failure     38
Lung fibrosis      9
Name: count, dtype: int64


## Labels present (currently 6)


In [4]:
all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
LABELS = sorted(all_df["label"].unique().tolist())
label_to_id = {lbl: i for i, lbl in enumerate(LABELS)}
id_to_label = {i: lbl for lbl, i in label_to_id.items()}

print("Labels:", LABELS, "num_classes:", len(LABELS))


Labels: ['Asthma', 'COPD', 'Heart failure', 'Lung fibrosis', 'Normal', 'Pneumonia'] num_classes: 6


## Config
Tip: if training is slow, reduce `num_segments` or `img_size`, or increase `batch_size` if VRAM allows.


In [5]:
@dataclass
class TrainConfig:
    sample_rate: int = 22050
    num_segments: int = 5
    segment_seconds: float = 2.0

    n_mels: int = 128
    n_fft: int = 2048
    hop_length: int = 512
    img_size: int = 224

    batch_size: int = 8
    num_workers: int = 2

    lstm_hidden: int = 128
    lstm_layers: int = 1
    bidirectional: bool = True
    dropout: float = 0.5

    epochs_stage1: int = 10
    epochs_stage2: int = 8

    lr_head: float = 1e-3
    lr_backbone_last: float = 1e-4
    lr_backbone_more: float = 1e-5

    weight_decay: float = 1e-4
    grad_clip: float = 1.0

    # Focal loss params
    focal_gamma: float = 2.0
    focal_alpha_clip_min: float = 0.5
    focal_alpha_clip_max: float = 3.0

    # light augmentation (train only)
    aug_time_shift: bool = True
    aug_add_noise: bool = True
    aug_spec_mask: bool = True

cfg = TrainConfig()
cfg


TrainConfig(sample_rate=22050, num_segments=5, segment_seconds=2.0, n_mels=128, n_fft=2048, hop_length=512, img_size=224, batch_size=8, num_workers=2, lstm_hidden=128, lstm_layers=1, bidirectional=True, dropout=0.5, epochs_stage1=10, epochs_stage2=8, lr_head=0.001, lr_backbone_last=0.0001, lr_backbone_more=1e-05, weight_decay=0.0001, grad_clip=1.0, focal_gamma=2.0, focal_alpha_clip_min=0.5, focal_alpha_clip_max=3.0, aug_time_shift=True, aug_add_noise=True, aug_spec_mask=True)

In [6]:
# Audio helpers (FIX: deterministic center crop for val/test)
def load_audio(path: str, sr: int) -> np.ndarray:
    y, _ = librosa.load(path, sr=sr, mono=True)
    if y.size == 0:
        return np.zeros(sr, dtype=np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak
    return y.astype(np.float32)

def pad_or_crop(y: np.ndarray, target_len: int, random_crop: bool) -> np.ndarray:
    if len(y) == target_len:
        return y
    if len(y) < target_len:
        return np.pad(y, (0, target_len - len(y)), mode="constant")
    if random_crop:
        start = np.random.randint(0, len(y) - target_len + 1)
    else:
        start = (len(y) - target_len) // 2
    return y[start:start + target_len]

def split_segments(y: np.ndarray, sr: int, num_segments: int, seg_seconds: float, random_crop: bool) -> np.ndarray:
    seg_len = int(sr * seg_seconds)
    total_len = seg_len * num_segments
    y = pad_or_crop(y, total_len, random_crop=random_crop)
    return y.reshape(num_segments, seg_len)

def time_shift(y: np.ndarray, shift_max: float = 0.2) -> np.ndarray:
    shift = int(random.uniform(-shift_max, shift_max) * len(y))
    return np.roll(y, shift)

def add_noise(y: np.ndarray, noise_level: float = 0.01) -> np.ndarray:
    noise = np.random.randn(len(y)).astype(np.float32)
    return y + noise_level * noise


In [7]:
# Mel image + ImageNet normalization (important for DenseNet transfer)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

def mel_image(segment: np.ndarray, sr: int, cfg: TrainConfig, spec_mask: bool) -> torch.Tensor:
    S = librosa.feature.melspectrogram(
        y=segment, sr=sr, n_mels=cfg.n_mels, n_fft=cfg.n_fft, hop_length=cfg.hop_length, power=2.0
    )
    S_db = librosa.power_to_db(S, ref=np.max)
    S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8)

    if spec_mask:
        n_mels = S_db.shape[0]
        t = S_db.shape[1]
        fm = random.randint(0, max(1, n_mels // 10))
        f0 = random.randint(0, max(0, n_mels - fm))
        if fm > 0:
            S_db[f0:f0+fm, :] = 0.0
        tm = random.randint(0, max(1, t // 12))
        t0 = random.randint(0, max(0, t - tm))
        if tm > 0:
            S_db[:, t0:t0+tm] = 0.0

    img = torch.from_numpy(S_db).unsqueeze(0).float()  # 1 x 128 x T
    img = F.interpolate(img.unsqueeze(0), size=(cfg.img_size, cfg.img_size), mode="bilinear", align_corners=False).squeeze(0)
    img3 = img.repeat(3, 1, 1)
    img3 = (img3 - IMAGENET_MEAN) / IMAGENET_STD
    return img3


In [8]:
class LungSoundDataset(Dataset):
    def __init__(self, df: pd.DataFrame, cfg: TrainConfig, label_to_id: Dict[str,int], train: bool):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.label_to_id = label_to_id
        self.train = train

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        wav_path = row["filepath"]
        label = row["label"]

        y = load_audio(wav_path, self.cfg.sample_rate)

        if self.train:
            if self.cfg.aug_time_shift and random.random() < 0.5:
                y = time_shift(y)
            if self.cfg.aug_add_noise and random.random() < 0.5:
                y = add_noise(y, noise_level=random.uniform(0.002, 0.01))

        segs = split_segments(
            y, self.cfg.sample_rate, self.cfg.num_segments, self.cfg.segment_seconds,
            random_crop=self.train
        )

        imgs = []
        for s in segs:
            imgs.append(mel_image(
                s, sr=self.cfg.sample_rate, cfg=self.cfg,
                spec_mask=(self.train and self.cfg.aug_spec_mask and random.random() < 0.4)
            ))
        x = torch.stack(imgs, dim=0)  # (T, 3, 224, 224)
        return x, self.label_to_id[label]


## Focal Loss (with mild alpha weights)

We compute alpha weights from training label counts, but **clip** them to avoid extreme scaling
(e.g., Lung fibrosis is tiny).


In [9]:
def compute_alpha_from_counts(train_df: pd.DataFrame, labels: List[str], clip_min: float, clip_max: float) -> torch.Tensor:
    counts = train_df["label"].value_counts().to_dict()
    n = np.array([counts.get(lbl, 1) for lbl in labels], dtype=np.float32)

    # mild inverse-sqrt weighting (safer than 1/n)
    median = np.median(n)
    alpha = np.sqrt(median / n)

    alpha = np.clip(alpha, clip_min, clip_max)
    alpha = alpha / alpha.mean()  # normalize around 1
    return torch.tensor(alpha, dtype=torch.float32)

class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, alpha: Optional[torch.Tensor] = None):
        super().__init__()
        self.gamma = gamma
        self.register_buffer("alpha", alpha if alpha is not None else torch.tensor([]))

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        ce = F.cross_entropy(logits, target, reduction="none")
        pt = torch.exp(-ce)  # prob of correct class
        loss = (1 - pt) ** self.gamma * ce

        if self.alpha.numel() > 0:
            a = self.alpha.gather(0, target)
            loss = a * loss

        return loss.mean()


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alpha = compute_alpha_from_counts(
    train_df, LABELS, cfg.focal_alpha_clip_min, cfg.focal_alpha_clip_max
).to(device)

print("Alpha weights:", {lbl: float(alpha[i].detach().cpu()) for i,lbl in enumerate(LABELS)})
criterion = FocalLoss(gamma=cfg.focal_gamma, alpha=alpha)


Alpha weights: {'Asthma': 0.7748332619667053, 'COPD': 0.43077588081359863, 'Heart failure': 1.036503791809082, 'Lung fibrosis': 2.1298129558563232, 'Normal': 0.6421627998352051, 'Pneumonia': 0.9859118461608887}


In [11]:
# DataLoaders (shuffle train; NO sampler)
train_loader = DataLoader(LungSoundDataset(train_df, cfg, label_to_id, train=True),
                          batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True)

val_loader = DataLoader(LungSoundDataset(val_df, cfg, label_to_id, train=False),
                        batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=True)

test_loader = DataLoader(LungSoundDataset(test_df, cfg, label_to_id, train=False),
                         batch_size=cfg.batch_size, shuffle=False,
                         num_workers=cfg.num_workers, pin_memory=True)

xb, yb = next(iter(train_loader))
print("Batch:", xb.shape, yb.shape, "device:", device)


Batch: torch.Size([8, 5, 3, 224, 224]) torch.Size([8]) device: cuda


## Model (same as v2)


In [12]:
class DenseNetEncoder(nn.Module):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        m = torchvision.models.densenet121(
            weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        )
        self.features = m.features
        self.out_dim = 1024

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.features(x)
        f = F.relu(f, inplace=True)
        f = F.adaptive_avg_pool2d(f, (1, 1)).flatten(1)
        return f

class DenseNetBiLSTM(nn.Module):
    def __init__(self, num_classes: int, hidden: int, layers: int, bidirectional: bool, dropout: float):
        super().__init__()
        self.encoder = DenseNetEncoder(pretrained=True)
        self.lstm = nn.LSTM(
            input_size=self.encoder.out_dim,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.0 if layers == 1 else dropout
        )
        out_dim = hidden * (2 if bidirectional else 1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        feats = self.encoder(x).view(B, T, -1)
        out, _ = self.lstm(feats)
        last = self.dropout(out[:, -1, :])
        return self.fc(last)

model = DenseNetBiLSTM(len(LABELS), cfg.lstm_hidden, cfg.lstm_layers, cfg.bidirectional, cfg.dropout).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)


Params (M): 8.137094


In [13]:
# Freeze/unfreeze schedule (same as v2)
def set_trainable_last_block(m: DenseNetBiLSTM):
    for p in m.encoder.parameters():
        p.requires_grad = False
    for name, p in m.encoder.features.named_parameters():
        if name.startswith("denseblock4") or name.startswith("norm5"):
            p.requires_grad = True
    for p in m.lstm.parameters():
        p.requires_grad = True
    for p in m.fc.parameters():
        p.requires_grad = True

def unfreeze_denseblock3(m: DenseNetBiLSTM):
    for name, p in m.encoder.features.named_parameters():
        if name.startswith("denseblock3") or name.startswith("transition3"):
            p.requires_grad = True

set_trainable_last_block(model)
print("Trainable params (M):", sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6)


Trainable params (M): 3.343366


In [14]:
def make_optimizer(m: DenseNetBiLSTM, lr_head: float, lr_backbone: float, wd: float):
    backbone_params, head_params = [], []
    for name, p in m.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("encoder"):
            backbone_params.append(p)
        else:
            head_params.append(p)
    return torch.optim.AdamW(
        [{"params": head_params, "lr": lr_head},
         {"params": backbone_params, "lr": lr_backbone}],
        weight_decay=wd
    )

optimizer = make_optimizer(model, cfg.lr_head, cfg.lr_backbone_last, cfg.weight_decay)


In [15]:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler(enabled=(device.type == "cuda"))

@torch.no_grad()
def evaluate(loader: DataLoader) -> Dict[str, Any]:
    model.eval()
    ys, ps, losses = [], [], []
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        losses.append(loss.item())
        pred = logits.argmax(1)
        ys.extend(y.detach().cpu().numpy().tolist())
        ps.extend(pred.detach().cpu().numpy().tolist())
    acc = accuracy_score(ys, ps)
    f1m = f1_score(ys, ps, average="macro")
    return {"loss": float(np.mean(losses)), "acc": float(acc), "f1m": float(f1m), "y": ys, "pred": ps}

def train_one_epoch(loader: DataLoader) -> float:
    model.train()
    losses = []
    for x, y in tqdm(loader, leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if device.type == "cuda":
            with autocast():
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()

        losses.append(loss.item())

    return float(np.mean(losses))


  scaler = GradScaler(enabled=(device.type == "cuda"))


## Train (Stage 1 + Stage 2) — save best by Val Macro‑F1


In [16]:
best_path = MODELS_DIR / "best_model_focal.pth"
best_f1 = -1.0
history = []

def save_ckpt(path: Path):
    ckpt = {
        "state_dict": model.state_dict(),
        "labels": LABELS,
        "label_to_id": label_to_id,
        "id_to_label": id_to_label,
        "config": asdict(cfg),
        "loss": {"name": "focal", "gamma": cfg.focal_gamma,
                 "alpha": alpha.detach().cpu().numpy().tolist()},
    }
    torch.save(ckpt, path)

print("=== Stage 1: head+LSTM+last DenseNet block (Focal Loss) ===")
for epoch in range(1, cfg.epochs_stage1 + 1):
    tr_loss = train_one_epoch(train_loader)
    val_m = evaluate(val_loader)
    history.append({"stage": 1, "epoch": epoch, "train_loss": tr_loss, **val_m})
    print(f"Epoch {epoch}/{cfg.epochs_stage1} | tr_loss={tr_loss:.4f} | val_loss={val_m['loss']:.4f} | val_acc={val_m['acc']:.4f} | val_f1m={val_m['f1m']:.4f}")

    if val_m["f1m"] > best_f1:
        best_f1 = val_m["f1m"]
        save_ckpt(best_path)
        print("✅ Saved best:", best_path, "val_f1m=", best_f1)

print("Best Val Macro‑F1 after Stage 1:", best_f1)


=== Stage 1: head+LSTM+last DenseNet block (Focal Loss) ===


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

 Traceback (most recent call last):
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers()  
   File "/home/zeus/miniconda3/envs/cl

Epoch 1/10 | tr_loss=0.5838 | val_loss=0.2974 | val_acc=0.7763 | val_f1m=0.2452
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model_focal.pth val_f1m= 0.24521126419860595


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
        Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20> 
 ^Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/home/zeus/minicond

Epoch 2/10 | tr_loss=0.5198 | val_loss=0.2695 | val_acc=0.7303 | val_f1m=0.1939


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20> 
 Traceback (most recent call last):
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers() 
^^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^^if w.is_alive():^
^ ^ ^  ^ ^ ^ 
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/proce

Epoch 3/10 | tr_loss=0.4959 | val_loss=0.2707 | val_acc=0.7434 | val_f1m=0.2043


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    if w.is_alive():    
self._shutdown_workers() 
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
    ^ ^ ^ ^ ^ ^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessi

Epoch 4/10 | tr_loss=0.4787 | val_loss=0.2598 | val_acc=0.8092 | val_f1m=0.2243


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 5/10 | tr_loss=0.4593 | val_loss=0.2514 | val_acc=0.8355 | val_f1m=0.2408


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproce

Epoch 6/10 | tr_loss=0.4463 | val_loss=0.2585 | val_acc=0.7697 | val_f1m=0.2166


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 7/10 | tr_loss=0.4316 | val_loss=0.2615 | val_acc=0.7697 | val_f1m=0.2673
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model_focal.pth val_f1m= 0.26731812696724977


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

             ^ ^^^^Exception ignored in: ^^Exception ignored in: <function _MultiProcessingDataLoaderIter.__de

Epoch 8/10 | tr_loss=0.4381 | val_loss=0.2560 | val_acc=0.8158 | val_f1m=0.2900
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model_focal.pth val_f1m= 0.29001937984496123


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>^
^Traceback (most recent call last):
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^^self._shutdown_workers()
^  File "/home/zeus/min

Epoch 9/10 | tr_loss=0.4193 | val_loss=0.2348 | val_acc=0.8421 | val_f1m=0.3396
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model_focal.pth val_f1m= 0.3395510264362724


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^ ^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

Epoch 10/10 | tr_loss=0.3971 | val_loss=0.2436 | val_acc=0.8618 | val_f1m=0.4482
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model_focal.pth val_f1m= 0.4482051326754768
Best Val Macro‑F1 after Stage 1: 0.4482051326754768


In [17]:
print("\n=== Stage 2: unfreeze denseblock3 + fine‑tune ===")
unfreeze_denseblock3(model)
optimizer = make_optimizer(model, lr_head=cfg.lr_head*0.5, lr_backbone=cfg.lr_backbone_more, wd=cfg.weight_decay)

for epoch in range(1, cfg.epochs_stage2 + 1):
    tr_loss = train_one_epoch(train_loader)
    val_m = evaluate(val_loader)
    history.append({"stage": 2, "epoch": epoch, "train_loss": tr_loss, **val_m})
    print(f"FT Epoch {epoch}/{cfg.epochs_stage2} | tr_loss={tr_loss:.4f} | val_loss={val_m['loss']:.4f} | val_acc={val_m['acc']:.4f} | val_f1m={val_m['f1m']:.4f}")

    if val_m["f1m"] > best_f1:
        best_f1 = val_m["f1m"]
        save_ckpt(best_path)
        print("✅ Saved best:", best_path, "val_f1m=", best_f1)

print("Best Val Macro‑F1 final:", best_f1)



=== Stage 2: unfreeze denseblock3 + fine‑tune ===


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

             ^ ^Exception ignored in: ^^^<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>^


FT Epoch 1/8 | tr_loss=0.4008 | val_loss=0.2493 | val_acc=0.8224 | val_f1m=0.3819


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):
      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive():    
self._shutdown_workers() 
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
    ^ ^ ^ ^ ^ ^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f8

FT Epoch 2/8 | tr_loss=0.3776 | val_loss=0.2425 | val_acc=0.8289 | val_f1m=0.3141


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
self._shutdown_workers()Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    self._shutdown_workers()    
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

    if w.is_alive(): 
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

FT Epoch 3/8 | tr_loss=0.3712 | val_loss=0.2409 | val_acc=0.8224 | val_f1m=0.3473


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

FT Epoch 4/8 | tr_loss=0.3904 | val_loss=0.2402 | val_acc=0.8158 | val_f1m=0.2701


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20><function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

            ^ ^ ^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

FT Epoch 5/8 | tr_loss=0.3875 | val_loss=0.2431 | val_acc=0.8289 | val_f1m=0.3409


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>    
self._shutdown_workers()Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        self._shutdown_workers()if w.is_alive(): 

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
        ^^ ^ ^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

FT Epoch 6/8 | tr_loss=0.3766 | val_loss=0.2490 | val_acc=0.8092 | val_f1m=0.3227


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        if w.is_alive():self._shutdown_workers()

   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
      ^ ^^ ^ ^^ ^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing

FT Epoch 7/8 | tr_loss=0.3646 | val_loss=0.2428 | val_acc=0.8026 | val_f1m=0.3501


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    if w.is_alive():    
self._shutdown_workers()
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
      ^ ^ ^ ^^ ^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessin

FT Epoch 8/8 | tr_loss=0.3540 | val_loss=0.2376 | val_acc=0.8487 | val_f1m=0.4254
Best Val Macro‑F1 final: 0.4482051326754768


## Test evaluation (load best focal checkpoint)


In [18]:
assert best_path.exists(), "best_model_focal.pth not found"
ckpt2 = torch.load(best_path, map_location=device)
model.load_state_dict(ckpt2["state_dict"])

test_m = evaluate(test_loader)
print("TEST loss:", test_m["loss"])
print("TEST acc :", test_m["acc"])
print("TEST f1m :", test_m["f1m"])

y_true, y_pred = test_m["y"], test_m["pred"]
print("\nClassification report:\n")
print(classification_report(y_true, y_pred, target_names=LABELS, digits=4))
print("Confusion matrix (rows=true, cols=pred):\n", confusion_matrix(y_true, y_pred))


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x787fc0f86a20>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    if w.is_alive():    
self._shutdown_workers()
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive():
      ^ ^  ^^ ^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

TEST loss: 0.23137015105867004
TEST acc : 0.8109965635738832
TEST f1m : 0.3717223311902946

Classification report:

               precision    recall  f1-score   support

       Asthma     0.3158    0.3158    0.3158        19
         COPD     0.9464    0.9464    0.9464       224
Heart failure     0.2500    0.0909    0.1333        11
Lung fibrosis     0.0000    0.0000    0.0000         2
       Normal     0.3571    0.5556    0.4348        27
    Pneumonia     1.0000    0.2500    0.4000         8

     accuracy                         0.8110       291
    macro avg     0.4782    0.3598    0.3717       291
 weighted avg     0.8192    0.8110    0.8055       291

Confusion matrix (rows=true, cols=pred):
 [[  6   3   2   0   8   0]
 [  4 212   0   0   8   0]
 [  2   0   1   0   8   0]
 [  1   0   0   0   1   0]
 [  6   6   0   0  15   0]
 [  0   3   1   0   2   2]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## Save config_focal.json (for Gradio inference)


In [19]:
config_out = {
    "labels": LABELS,
    "audio": {"sample_rate": cfg.sample_rate, "num_segments": cfg.num_segments, "segment_seconds": cfg.segment_seconds},
    "mel": {"n_mels": cfg.n_mels, "n_fft": cfg.n_fft, "hop_length": cfg.hop_length, "img_size": cfg.img_size},
    "model": {"backbone": "densenet121", "lstm_hidden": cfg.lstm_hidden, "lstm_layers": cfg.lstm_layers,
              "bidirectional": cfg.bidirectional, "dropout": cfg.dropout},
    "loss": {"name": "focal", "gamma": cfg.focal_gamma, "alpha": alpha.detach().cpu().numpy().tolist()}
}
(MODELS_DIR / "config_focal.json").write_text(json.dumps(config_out, indent=2), encoding="utf-8")
print("Saved:", MODELS_DIR / "config_focal.json")


Saved: /teamspace/studios/this_studio/lung_sound_project/models/config_focal.json


### What to compare with v2
Report both models:
- v2 (class-balanced CE): `best_model.pth`
- focal (this notebook): `best_model_focal.pth`

Compare:
- Val Macro‑F1 (best)
- Test Macro‑F1
- Heart failure & Normal recall

Paste your focal test metrics after running — if Macro‑F1 improves, we'll keep focal as final model.
