# 02 — Train Model A (Mel-only) — DenseNet121 + BiLSTM (PyTorch) — **v2 FIXED**

Your last run produced ~3% accuracy and collapsed predictions. This v2 notebook fixes the **most common causes**:

## ✅ Fixes in v2
1) **No random cropping in Val/Test** (was causing unstable evaluation and poor learning signals).  
   - Train: random crop  
   - Val/Test: center crop (deterministic)

2) **Do NOT use both WeightedRandomSampler + Class Weights together** (over-corrects imbalance and can collapse).  
   - v2 default: **class-balanced loss only** (Effective Number method), **no sampler**.

3) **ImageNet normalization** before DenseNet (important for transfer learning).

4) **Better fine-tuning strategy**  
   - Train head + LSTM + **last DenseNet block** from the start  
   - Then optionally unfreeze more

5) **Macro-F1 best checkpoint** (best model saved by validation Macro-F1)

Outputs:
- `models/best_model.pth`
- `models/config.json`
- `models/label_to_id.json`, `models/id_to_label.json` (based on available classes)

> NOTE: With your current data you have 6 classes (Bronchitis & Pleural effusion are missing).


In [1]:
# Cell 1 — (Optional) install dependencies if needed
# !pip -q install librosa soundfile scikit-learn tqdm matplotlib


In [2]:
# Cell 2 — Imports & seed
from __future__ import annotations

import json, math, random
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from tqdm.auto import tqdm
import librosa

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix


In [3]:
# Cell 3 — Paths
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

PROJECT_ROOT = Path.cwd().resolve().parents[0]  # .../lung_sound_project
MANIFEST_DIR = PROJECT_ROOT / "data/processed/manifests"
MODELS_DIR = PROJECT_ROOT / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

train_df = pd.read_csv(MANIFEST_DIR / "train.csv")
val_df   = pd.read_csv(MANIFEST_DIR / "val.csv")
test_df  = pd.read_csv(MANIFEST_DIR / "test.csv")

print("Train:", train_df.shape, "Val:", val_df.shape, "Test:", test_df.shape)
print("Train labels:\n", train_df["label"].value_counts())


Train: (732, 4) Val: (152, 4) Test: (291, 4)
Train labels:
 label
COPD             476
Normal            99
Asthma            68
Pneumonia         42
Heart failure     38
Lung fibrosis      9
Name: count, dtype: int64


## 1) Labels present in your data

We train only on labels that exist (currently 6).  
This prevents training for classes with 0 samples.


In [4]:
# Cell 4 — Build label maps from available labels
all_df = pd.concat([train_df, val_df, test_df], ignore_index=True)
LABELS = sorted(all_df["label"].unique().tolist())
label_to_id = {lbl: i for i, lbl in enumerate(LABELS)}
id_to_label = {i: lbl for lbl, i in label_to_id.items()}

print("Available labels:", LABELS, "=> num_classes:", len(LABELS))

(MODELS_DIR / "label_to_id.json").write_text(json.dumps(label_to_id, indent=2), encoding="utf-8")
(MODELS_DIR / "id_to_label.json").write_text(json.dumps(id_to_label, indent=2), encoding="utf-8")
print("Saved label maps to:", MODELS_DIR)


Available labels: ['Asthma', 'COPD', 'Heart failure', 'Lung fibrosis', 'Normal', 'Pneumonia'] => num_classes: 6
Saved label maps to: /teamspace/studios/this_studio/lung_sound_project/models


## 2) Config (tweak here)

- `num_segments * segment_seconds` = total window length per file  
- If your files are short, keep total <= 10s  
- If learning is hard, try `num_segments=6` and `segment_seconds=1.5`


In [5]:
# Cell 5 — Config

@dataclass
class TrainConfig:
    sample_rate: int = 22050
    num_segments: int = 5
    segment_seconds: float = 2.0

    n_mels: int = 128
    n_fft: int = 2048
    hop_length: int = 512
    img_size: int = 224

    batch_size: int = 8
    num_workers: int = 2

    lstm_hidden: int = 128
    lstm_layers: int = 1
    bidirectional: bool = True
    dropout: float = 0.5

    epochs_stage1: int = 10
    epochs_stage2: int = 8

    lr_head: float = 1e-3
    lr_backbone_last: float = 1e-4
    lr_backbone_more: float = 1e-5

    weight_decay: float = 1e-4
    grad_clip: float = 1.0

    loss_beta: float = 0.999  # effective-number

    aug_time_shift: bool = True
    aug_add_noise: bool = True
    aug_spec_mask: bool = True

cfg = TrainConfig()
cfg


TrainConfig(sample_rate=22050, num_segments=5, segment_seconds=2.0, n_mels=128, n_fft=2048, hop_length=512, img_size=224, batch_size=8, num_workers=2, lstm_hidden=128, lstm_layers=1, bidirectional=True, dropout=0.5, epochs_stage1=10, epochs_stage2=8, lr_head=0.001, lr_backbone_last=0.0001, lr_backbone_more=1e-05, weight_decay=0.0001, grad_clip=1.0, loss_beta=0.999, aug_time_shift=True, aug_add_noise=True, aug_spec_mask=True)

In [6]:
# Cell 6 — Audio helpers (FIX: deterministic crop for val/test)

def load_audio(path: str, sr: int) -> np.ndarray:
    y, _ = librosa.load(path, sr=sr, mono=True)
    if y.size == 0:
        return np.zeros(sr, dtype=np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak
    return y.astype(np.float32)

def pad_or_crop(y: np.ndarray, target_len: int, random_crop: bool) -> np.ndarray:
    if len(y) == target_len:
        return y
    if len(y) < target_len:
        return np.pad(y, (0, target_len - len(y)), mode="constant")
    if random_crop:
        start = np.random.randint(0, len(y) - target_len + 1)
    else:
        start = (len(y) - target_len) // 2
    return y[start:start + target_len]

def split_segments(y: np.ndarray, sr: int, num_segments: int, seg_seconds: float, random_crop: bool) -> np.ndarray:
    seg_len = int(sr * seg_seconds)
    total_len = seg_len * num_segments
    y = pad_or_crop(y, total_len, random_crop=random_crop)
    return y.reshape(num_segments, seg_len)

def time_shift(y: np.ndarray, shift_max: float = 0.2) -> np.ndarray:
    shift = int(random.uniform(-shift_max, shift_max) * len(y))
    return np.roll(y, shift)

def add_noise(y: np.ndarray, noise_level: float = 0.01) -> np.ndarray:
    noise = np.random.randn(len(y)).astype(np.float32)
    return y + noise_level * noise


In [7]:
# Cell 7 — Mel image + ImageNet normalization (FIX)

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

def mel_image(segment: np.ndarray, sr: int, cfg: TrainConfig, spec_mask: bool) -> torch.Tensor:
    S = librosa.feature.melspectrogram(
        y=segment, sr=sr, n_mels=cfg.n_mels, n_fft=cfg.n_fft, hop_length=cfg.hop_length, power=2.0
    )
    S_db = librosa.power_to_db(S, ref=np.max)
    S_db = (S_db - S_db.min()) / (S_db.max() - S_db.min() + 1e-8)

    if spec_mask:
        n_mels = S_db.shape[0]
        t = S_db.shape[1]
        fm = random.randint(0, max(1, n_mels // 10))
        f0 = random.randint(0, max(0, n_mels - fm))
        if fm > 0:
            S_db[f0:f0+fm, :] = 0.0
        tm = random.randint(0, max(1, t // 12))
        t0 = random.randint(0, max(0, t - tm))
        if tm > 0:
            S_db[:, t0:t0+tm] = 0.0

    img = torch.from_numpy(S_db).unsqueeze(0).float()
    img = F.interpolate(img.unsqueeze(0), size=(cfg.img_size, cfg.img_size), mode="bilinear", align_corners=False).squeeze(0)
    img3 = img.repeat(3, 1, 1)
    img3 = (img3 - IMAGENET_MEAN) / IMAGENET_STD
    return img3


In [8]:
# Cell 8 — Dataset

class LungSoundDataset(Dataset):
    def __init__(self, df: pd.DataFrame, cfg: TrainConfig, label_to_id: Dict[str,int], train: bool):
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.label_to_id = label_to_id
        self.train = train

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        row = self.df.iloc[idx]
        wav_path = row["filepath"]
        label = row["label"]

        y = load_audio(wav_path, self.cfg.sample_rate)

        if self.train:
            if self.cfg.aug_time_shift and random.random() < 0.5:
                y = time_shift(y)
            if self.cfg.aug_add_noise and random.random() < 0.5:
                y = add_noise(y, noise_level=random.uniform(0.002, 0.01))

        segs = split_segments(
            y, self.cfg.sample_rate, self.cfg.num_segments, self.cfg.segment_seconds,
            random_crop=self.train
        )

        imgs = []
        for s in segs:
            imgs.append(mel_image(
                s, sr=self.cfg.sample_rate, cfg=self.cfg,
                spec_mask=(self.train and self.cfg.aug_spec_mask and random.random() < 0.4)
            ))
        x = torch.stack(imgs, dim=0)
        return x, self.label_to_id[label]


## 3) Class-balanced loss (Effective Number of Samples)

This avoids over-oversampling tiny classes (like Lung fibrosis=12 total).


In [9]:
# Cell 9 — Class-balanced weights

def class_balanced_weights(df: pd.DataFrame, labels: List[str], beta: float = 0.999) -> torch.Tensor:
    counts = df["label"].value_counts().to_dict()
    n = np.array([counts.get(lbl, 0) for lbl in labels], dtype=np.float32)
    eff = (1.0 - np.power(beta, n)) / (1.0 - beta + 1e-12)
    w = 1.0 / (eff + 1e-12)
    w = w / w.sum() * len(labels)
    return torch.tensor(w, dtype=torch.float32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cb_w = class_balanced_weights(train_df, LABELS, beta=cfg.loss_beta).to(device)
print("Class-balanced weights:", {lbl: float(cb_w[i].detach().cpu()) for i,lbl in enumerate(LABELS)})

criterion = nn.CrossEntropyLoss(weight=cb_w)


Class-balanced weights: {'Asthma': 0.47734230756759644, 'COPD': 0.08286261558532715, 'Heart failure': 0.8415800929069519, 'Lung fibrosis': 3.502346992492676, 'Normal': 0.3329238295555115, 'Pneumonia': 0.762944221496582}


In [10]:
# Cell 10 — DataLoaders
train_ds = LungSoundDataset(train_df, cfg, label_to_id, train=True)
val_ds   = LungSoundDataset(val_df, cfg, label_to_id, train=False)
test_ds  = LungSoundDataset(test_df, cfg, label_to_id, train=False)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)

xb, yb = next(iter(train_loader))
print("Batch:", xb.shape, yb.shape, "device:", device)


Batch: torch.Size([8, 5, 3, 224, 224]) torch.Size([8]) device: cuda


## 4) Model: DenseNet121 encoder → BiLSTM → classifier


In [11]:
# Cell 11 — Model

class DenseNetEncoder(nn.Module):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        m = torchvision.models.densenet121(
            weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        )
        self.features = m.features
        self.out_dim = 1024

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.features(x)
        f = F.relu(f, inplace=True)
        f = F.adaptive_avg_pool2d(f, (1, 1)).flatten(1)
        return f

class DenseNetBiLSTM(nn.Module):
    def __init__(self, num_classes: int, hidden: int, layers: int, bidirectional: bool, dropout: float):
        super().__init__()
        self.encoder = DenseNetEncoder(pretrained=True)
        self.lstm = nn.LSTM(
            input_size=self.encoder.out_dim,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.0 if layers == 1 else dropout
        )
        out_dim = hidden * (2 if bidirectional else 1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        feats = self.encoder(x).view(B, T, -1)
        out, _ = self.lstm(feats)
        last = self.dropout(out[:, -1, :])
        return self.fc(last)

model = DenseNetBiLSTM(len(LABELS), cfg.lstm_hidden, cfg.lstm_layers, cfg.bidirectional, cfg.dropout).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)


Params (M): 8.137094


In [12]:
# Cell 12 — Freeze strategy (train last block from start)

def set_trainable_last_block(m: DenseNetBiLSTM):
    for p in m.encoder.parameters():
        p.requires_grad = False
    for name, p in m.encoder.features.named_parameters():
        if name.startswith("denseblock4") or name.startswith("norm5"):
            p.requires_grad = True
    for p in m.lstm.parameters():
        p.requires_grad = True
    for p in m.fc.parameters():
        p.requires_grad = True

def unfreeze_denseblock3(m: DenseNetBiLSTM):
    for name, p in m.encoder.features.named_parameters():
        if name.startswith("denseblock3") or name.startswith("transition3"):
            p.requires_grad = True

set_trainable_last_block(model)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print("Trainable params:", trainable/1e6, "M / Total:", total/1e6, "M")


Trainable params: 3.343366 M / Total: 8.137094 M


In [13]:
# Cell 13 — Optimizer param groups

def make_optimizer(m: DenseNetBiLSTM, lr_head: float, lr_backbone: float, wd: float):
    backbone_params, head_params = [], []
    for name, p in m.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("encoder"):
            backbone_params.append(p)
        else:
            head_params.append(p)
    return torch.optim.AdamW(
        [{"params": head_params, "lr": lr_head},
         {"params": backbone_params, "lr": lr_backbone}],
        weight_decay=wd
    )

optimizer = make_optimizer(model, cfg.lr_head, cfg.lr_backbone_last, cfg.weight_decay)


In [14]:
# Cell 14 — Train/eval loops
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler(enabled=(device.type == "cuda"))

@torch.no_grad()
def evaluate(loader: DataLoader) -> Dict[str, Any]:
    model.eval()
    ys, ps, losses = [], [], []
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        losses.append(loss.item())
        pred = logits.argmax(1)
        ys.extend(y.detach().cpu().numpy().tolist())
        ps.extend(pred.detach().cpu().numpy().tolist())
    acc = accuracy_score(ys, ps)
    f1m = f1_score(ys, ps, average="macro")
    return {"loss": float(np.mean(losses)), "acc": float(acc), "f1m": float(f1m), "y": ys, "pred": ps}

def train_one_epoch(loader: DataLoader) -> float:
    model.train()
    losses = []
    for x, y in tqdm(loader, leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if device.type == "cuda":
            with autocast():
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optimizer.step()

        losses.append(loss.item())
    return float(np.mean(losses))


  scaler = GradScaler(enabled=(device.type == "cuda"))


In [15]:
# Cell 15 — Stage 1 training (save best by Val Macro-F1)

best_path = MODELS_DIR / "best_model.pth"
best_f1 = -1.0
history = []

def save_ckpt(path: Path):
    ckpt = {
        "state_dict": model.state_dict(),
        "labels": LABELS,
        "label_to_id": label_to_id,
        "id_to_label": id_to_label,
        "config": asdict(cfg),
    }
    torch.save(ckpt, path)

print("=== Stage 1: head+LSTM+last DenseNet block ===")
for epoch in range(1, cfg.epochs_stage1 + 1):
    tr_loss = train_one_epoch(train_loader)
    val_m = evaluate(val_loader)
    history.append({"stage": 1, "epoch": epoch, "train_loss": tr_loss, **val_m})
    print(f"Epoch {epoch}/{cfg.epochs_stage1} | tr_loss={tr_loss:.4f} | val_loss={val_m['loss']:.4f} | val_acc={val_m['acc']:.4f} | val_f1m={val_m['f1m']:.4f}")

    if val_m["f1m"] > best_f1:
        best_f1 = val_m["f1m"]
        save_ckpt(best_path)
        print("✅ Saved best:", best_path, "val_f1m=", best_f1)

print("Best Val Macro-F1 after Stage 1:", best_f1)


=== Stage 1: head+LSTM+last DenseNet block ===


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
   Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980> 
 Traceback (most recent call last):
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
       self._shutdown_workers() ^
^  File "/home/zeus/miniconda3/e

Epoch 1/10 | tr_loss=1.7833 | val_loss=1.2417 | val_acc=0.7237 | val_f1m=0.2389
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth val_f1m= 0.2388727858293076


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>^
^Traceback (most recent call last):
^^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^self._shutdown_workers()^
^  File "/home/zeus/min

Epoch 2/10 | tr_loss=1.6443 | val_loss=1.0826 | val_acc=0.7303 | val_f1m=0.2605
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth val_f1m= 0.26048206278026903


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
AssertionErrorException ignored in: : <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>can only test a child process

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  

Epoch 3/10 | tr_loss=1.5910 | val_loss=0.9827 | val_acc=0.7237 | val_f1m=0.2028


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 4/10 | tr_loss=1.5152 | val_loss=0.8235 | val_acc=0.7500 | val_f1m=0.2447


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
           ^ ^ ^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

Epoch 5/10 | tr_loss=1.5277 | val_loss=0.7443 | val_acc=0.8092 | val_f1m=0.2251


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        if w.is_alive():self._shutdown_workers()

   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive(): 
    ^  ^ ^ ^ ^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessin

Epoch 6/10 | tr_loss=1.4362 | val_loss=0.8272 | val_acc=0.7368 | val_f1m=0.2323


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 7/10 | tr_loss=1.4155 | val_loss=0.9530 | val_acc=0.6776 | val_f1m=0.2786
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth val_f1m= 0.2786367966251616


  0%|          | 0/92 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>^
^Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/hom

Epoch 8/10 | tr_loss=1.4145 | val_loss=0.8281 | val_acc=0.7829 | val_f1m=0.2260


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980> 
 Traceback (most recent call last):
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
     ^self._shutdown_workers()^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
 ^ ^^  ^^ 
    File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/proces

Epoch 9/10 | tr_loss=1.3949 | val_loss=0.7818 | val_acc=0.7961 | val_f1m=0.3938
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth val_f1m= 0.393806485911749


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

             ^ ^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproc

Epoch 10/10 | tr_loss=1.3648 | val_loss=0.7531 | val_acc=0.7895 | val_f1m=0.3989
✅ Saved best: /teamspace/studios/this_studio/lung_sound_project/models/best_model.pth val_f1m= 0.39887366671507635
Best Val Macro-F1 after Stage 1: 0.39887366671507635


In [16]:
# Cell 16 — Stage 2: unfreeze denseblock3 and fine-tune

print("\n=== Stage 2: unfreeze denseblock3 + fine-tune ===")
unfreeze_denseblock3(model)
optimizer = make_optimizer(model, lr_head=cfg.lr_head*0.5, lr_backbone=cfg.lr_backbone_more, wd=cfg.weight_decay)

for epoch in range(1, cfg.epochs_stage2 + 1):
    tr_loss = train_one_epoch(train_loader)
    val_m = evaluate(val_loader)
    history.append({"stage": 2, "epoch": epoch, "train_loss": tr_loss, **val_m})
    print(f"FT Epoch {epoch}/{cfg.epochs_stage2} | tr_loss={tr_loss:.4f} | val_loss={val_m['loss']:.4f} | val_acc={val_m['acc']:.4f} | val_f1m={val_m['f1m']:.4f}")

    if val_m["f1m"] > best_f1:
        best_f1 = val_m["f1m"]
        save_ckpt(best_path)
        print("✅ Saved best:", best_path, "val_f1m=", best_f1)

print("Best Val Macro-F1 final:", best_f1)



=== Stage 2: unfreeze denseblock3 + fine-tune ===


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproce

FT Epoch 1/8 | tr_loss=1.3083 | val_loss=0.7524 | val_acc=0.7434 | val_f1m=0.2809


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>    
if w.is_alive():Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
       self._shutdown_workers() 
   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      ^if w.is_alive():^
^^  ^ ^ ^ ^  ^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/pr

FT Epoch 2/8 | tr_loss=1.3327 | val_loss=0.7489 | val_acc=0.7632 | val_f1m=0.3904


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
      File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
self._shutdown_workers()    
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      if w.is_alive(): 
        ^^ ^ ^ ^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessin

FT Epoch 3/8 | tr_loss=1.2813 | val_loss=0.6927 | val_acc=0.7829 | val_f1m=0.3547


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    AssertionErrorself._shutdown_workers(): 
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
can only test a child process
    if w.is_alive():Exc

FT Epoch 4/8 | tr_loss=1.2936 | val_loss=0.6757 | val_acc=0.7961 | val_f1m=0.3923


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Exception ignored in:   File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>    Traceback (most recent call last):
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    self._shutdown_workers()    
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      if w.is_alive(): 
        ^ ^  ^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiprocess

FT Epoch 5/8 | tr_loss=1.2823 | val_loss=0.6532 | val_acc=0.8026 | val_f1m=0.3539


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^ ^^^^^^^^^^^^^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at

FT Epoch 6/8 | tr_loss=1.2778 | val_loss=0.6883 | val_acc=0.7961 | val_f1m=0.3219


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

    Traceback (most recent call last):
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    self._shutdown_workers()    
if w.is_alive():  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

     if w.is_alive():  
         ^ ^ ^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

FT Epoch 7/8 | tr_loss=1.2769 | val_loss=0.6747 | val_acc=0.7961 | val_f1m=0.3681


  0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980><function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive(): 
            ^ ^^^^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproce

FT Epoch 8/8 | tr_loss=1.2794 | val_loss=0.6830 | val_acc=0.8092 | val_f1m=0.3641
Best Val Macro-F1 final: 0.39887366671507635


## 5) Test evaluation (load best checkpoint)


In [17]:
# Cell 17 — Evaluate on test

assert best_path.exists(), "best_model.pth not found"
ckpt = torch.load(best_path, map_location=device)
model.load_state_dict(ckpt["state_dict"])

test_m = evaluate(test_loader)
print("TEST loss:", test_m["loss"])
print("TEST acc :", test_m["acc"])
print("TEST f1m :", test_m["f1m"])

y_true, y_pred = test_m["y"], test_m["pred"]
print("\nClassification report:\n")
print(classification_report(y_true, y_pred, target_names=LABELS, digits=4))

cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix (rows=true, cols=pred):\n", cm)


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ceea819e980>Traceback (most recent call last):

  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()    
self._shutdown_workers()  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
    if w.is_alive(): 
          ^ ^ ^ ^^^^^^^^^^^^^^^^^^^
^  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/multiproces

TEST loss: 0.6022332275940759
TEST acc : 0.7560137457044673
TEST f1m : 0.3412692205109267

Classification report:

               precision    recall  f1-score   support

       Asthma     0.2045    0.4737    0.2857        19
         COPD     0.9949    0.8795    0.9336       224
Heart failure     0.0714    0.0909    0.0800        11
Lung fibrosis     0.0000    0.0000    0.0000         2
       Normal     0.4706    0.2963    0.3636        27
    Pneumonia     0.2778    0.6250    0.3846         8

     accuracy                         0.7560       291
    macro avg     0.3365    0.3942    0.3413       291
 weighted avg     0.8332    0.7560    0.7847       291

Confusion matrix (rows=true, cols=pred):
 [[  9   0   7   0   3   0]
 [  9 197   3   0   3  12]
 [  8   0   1   0   2   0]
 [  2   0   0   0   0   0]
 [ 15   1   2   0   8   1]
 [  1   0   1   0   1   5]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## 6) Save config.json (for Gradio inference)


In [18]:
# Cell 18 — Save config.json

config_out = {
    "labels": LABELS,
    "audio": {"sample_rate": cfg.sample_rate, "num_segments": cfg.num_segments, "segment_seconds": cfg.segment_seconds},
    "mel": {"n_mels": cfg.n_mels, "n_fft": cfg.n_fft, "hop_length": cfg.hop_length, "img_size": cfg.img_size},
    "model": {"backbone": "densenet121", "lstm_hidden": cfg.lstm_hidden, "lstm_layers": cfg.lstm_layers,
              "bidirectional": cfg.bidirectional, "dropout": cfg.dropout},
    "loss": {"class_balanced_beta": cfg.loss_beta}
}
(MODELS_DIR / "config.json").write_text(json.dumps(config_out, indent=2), encoding="utf-8")
print("Saved:", MODELS_DIR / "config.json")


Saved: /teamspace/studios/this_studio/lung_sound_project/models/config.json


✅ If this v2 run still performs badly, next debugging steps:
1) Train **without LSTM** (mean-pool segment logits) to isolate the issue.
2) Train **ICBHI-only 4-class** baseline (Normal/Asthma/COPD/Pneumonia) to validate pipeline.
3) Add datasets for Bronchitis & Pleural effusion to truly reach 8 classes.
