In [10]:
import shutil
from pathlib import Path
import random

def split_dataset(
    data_dir: str | Path,
    output_dir: str | Path,
    val_ratio: float = 0.2,
    seed: int = 42,
) -> None:
    """Split dataset into train/val folders preserving class structure."""
    random.seed(seed)
    data_dir, output_dir = Path(data_dir), Path(output_dir)

    for class_dir in data_dir.iterdir():
        if not class_dir.is_dir():
            continue

        files = list(class_dir.glob("*.wav"))
        random.shuffle(files)

        split_idx = int(len(files) * (1 - val_ratio))
        train_files, val_files = files[:split_idx], files[split_idx:]
        print(f"Class {class_dir.name}: {len(train_files)} train, {len(val_files)} val")
        for split, split_files in [("train", train_files), ("val", val_files)]:
            split_class_dir = output_dir / split / class_dir.name
            split_class_dir.mkdir(parents=True, exist_ok=True)

            for f in split_files:
                shutil.copy2(f, split_class_dir / f.name)


split_dataset("data/forest_fire_dataset", "data_split", val_ratio=0.2)


Class class_1: 231 train, 58 val
Class class_0: 1600 train, 400 val
Class .ipynb_checkpoints: 0 train, 0 val


In [5]:
from src.forest_fires.models.ast_lora import LoRaASTClassifier
num_labels = 2
lr = 1e-4
model = LoRaASTClassifier(num_labels=num_labels, lr=lr)
trainable_params_ast = 0
total_params_ast = 0
for p in model.ast_lora.parameters():
    if p.requires_grad:
        trainable_params_ast += p.numel()
    total_params_ast += p.numel()
print(f"Trainable params in AST LoRA: {trainable_params_ast} / {total_params_ast} ({100 * trainable_params_ast / total_params_ast:.2f}%)")

trainable_params_classifier = 0
total_params_classifier = 0
for p in model.classifier.parameters():
    if p.requires_grad:
        trainable_params_classifier += p.numel()
    total_params_classifier += p.numel()
print(f"Trainable params in classifier: {trainable_params_classifier} / {total_params_classifier} ({100 * trainable_params_classifier / total_params_classifier:.2f}%)")

print(f"Total trainable params: {trainable_params_ast + trainable_params_classifier} / {total_params_ast + total_params_classifier} ({100 * (trainable_params_ast + trainable_params_classifier) / (total_params_ast + total_params_classifier):.2f}%)")

Trainable params in AST LoRA: 294912 / 86482176 (0.34%)
Trainable params in classifier: 1538 / 1538 (100.00%)
Total trainable params: 296450 / 86483714 (0.34%)


In [None]:
"""Training script for ASTLightning with LoRA adapters."""

import torch  # type: ignore
from lightning.pytorch import seed_everything  # type: ignore
from torch.optim import AdamW
from tqdm import tqdm
from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

from forest_fires.data_preprocessing.ast_lora_datamodule import get_dataloader
from src.forest_fires.models.ast_lora import LoRaASTClassifier


def main(batch_size: int, lr: float = 1e-4, num_epochs: int = 20, ckpt_dir: str = "checkpoints") -> None:
    """Manual training loop with metrics and checkpointing."""
    seed_everything(42)

    # Data
    train_loader = get_dataloader(
        "data/forest_fire_dataset/train", batch_size=batch_size, shuffle=True
    )
    val_loader = get_dataloader(
        "data/forest_fire_dataset/val", batch_size=batch_size, shuffle=False
    )

    num_labels = len(train_loader.dataset.class_to_idx)
    model = LoRaASTClassifier(num_labels=num_labels, lr=lr)
    model.classifier = model.classifier.to(torch.float32)
    model.to("cuda").train()
    optimizer = AdamW(model.parameters(), lr=lr)
    scaler = torch.cuda.amp.GradScaler()
    loss_fn = torch.nn.CrossEntropyLoss()

    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # ---------------- TRAIN ----------------
        model.train()
        epoch_loss, epoch_acc = 0.0, 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for features, labels in pbar:
            features, labels = features.to(model.device), labels.to(model.device)
            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast():
                logits = model(features)
                loss = loss_fn(logits, labels)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            acc = (logits.argmax(dim=-1) == labels).float().mean().item()
            epoch_loss += loss.item()
            epoch_acc += acc
            pbar.set_postfix(loss=f"{loss.item():.4f}", acc=f"{acc:.3f}")

        avg_loss = epoch_loss / len(train_loader)
        avg_acc = epoch_acc / len(train_loader)
        print(f"Epoch {epoch+1:02d}: train_loss={avg_loss:.4f}, train_acc={avg_acc:.3f}")

        # ---------------- VALIDATION ----------------
        model.eval()
        val_loss, preds_all, labels_all = 0.0, [], []
        with torch.no_grad(), torch.cuda.amp.autocast():
            for features, labels in val_loader:
                features, labels = features.to(model.device), labels.to(model.device)
                logits = model(features)
                loss = loss_fn(logits, labels)
                val_loss += loss.item()

                preds_all.extend(logits.argmax(dim=-1).cpu().tolist())
                labels_all.extend(labels.cpu().tolist())

        val_loss /= len(val_loader)
        acc = accuracy_score(labels_all, preds_all)
        f1 = f1_score(labels_all, preds_all, average="weighted")
        prec = precision_score(labels_all, preds_all, average="weighted", zero_division=0)
        rec = recall_score(labels_all, preds_all, average="weighted")
        cm = confusion_matrix(labels_all, preds_all)

        print(
            f"           val_loss={val_loss:.4f}, val_acc={acc:.3f}, "
            f"val_f1={f1:.3f}, val_prec={prec:.3f}, val_rec={rec:.3f}"
        )
        print("Confusion Matrix:\n", cm)

        # ---------------- CHECKPOINT ----------------
        if acc > best_val_acc:
            best_val_acc = acc
            ckpt_path = ckpt_dir / f"epoch{epoch+1:02d}-acc{acc:.3f}.pt"
            torch.save(
                {"epoch": epoch + 1, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict()},
                ckpt_path,
            )
            print(f"Saved new best model checkpoint to {ckpt_path}")


if __name__ == "__main__":
    main(batch_size=8)


INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler()
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 01: train_loss=0.0691, train_acc=0.963
           val_loss=0.0230, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]
Saved new best model checkpoint to checkpoints/epoch01-acc0.998.pt


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 02: train_loss=0.0001, train_acc=1.000
           val_loss=0.0251, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 03: train_loss=0.0000, train_acc=1.000
           val_loss=0.0267, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 04: train_loss=0.0000, train_acc=1.000
           val_loss=0.0278, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 05: train_loss=0.0000, train_acc=1.000
           val_loss=0.0287, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 06: train_loss=0.0000, train_acc=1.000
           val_loss=0.0295, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 07: train_loss=0.0000, train_acc=1.000
           val_loss=0.0303, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 08: train_loss=0.0000, train_acc=1.000
           val_loss=0.0309, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 09: train_loss=0.0000, train_acc=1.000
           val_loss=0.0314, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 10: train_loss=0.0000, train_acc=1.000
           val_loss=0.0319, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 11: train_loss=0.0000, train_acc=1.000
           val_loss=0.0324, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 12: train_loss=0.0000, train_acc=1.000
           val_loss=0.0328, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 13: train_loss=0.0000, train_acc=1.000
           val_loss=0.0332, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 14: train_loss=0.0000, train_acc=1.000
           val_loss=0.0336, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 15: train_loss=0.0000, train_acc=1.000
           val_loss=0.0339, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 16: train_loss=0.0000, train_acc=1.000
           val_loss=0.0343, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 17: train_loss=0.0000, train_acc=1.000
           val_loss=0.0346, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 18: train_loss=0.0000, train_acc=1.000
           val_loss=0.0349, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 19: train_loss=0.0000, train_acc=1.000
           val_loss=0.0352, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  with torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch 20: train_loss=0.0000, train_acc=1.000
           val_loss=0.0355, val_acc=0.998, val_f1=0.998, val_prec=0.998, val_rec=0.998
Confusion Matrix:
 [[400   0]
 [  1  57]]
