In [None]:
import os
import sys
import yaml
import math
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from tqdm.auto import tqdm
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, Any, Optional

# Request the "high" float32 matmul precision mode globally for this kernel.
# NOTE(review): per the PyTorch docs this permits faster reduced-precision
# (e.g. TF32) kernels on supporting hardware — confirm the accuracy trade-off
# is acceptable for this model.
torch.set_float32_matmul_precision("high")

In [None]:
# Add project root to sys.path *before* the project-local imports below.
# Appending it afterwards could never help those imports resolve.
project_root = Path(os.getcwd())
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# NOTE(review): the star imports hide where names used later in this notebook
# (e.g. WCEFrameOnset, frame_onset_bce_count_mix, SpeechCountDataset,
# pad_collate_frame_onset) come from — presumably these two modules; consider
# importing them explicitly once confirmed.
from models.wce_frame_onset import *
from dataloader_wce import *

from utils import set_seed, get_device

In [None]:
# -----------------------------------------------------------------------------
# Metrics Tracker (Supports Subject-wise ERR)
# -----------------------------------------------------------------------------
def compute_MAPE(y_pred, y_true, eps=1e-8):
    """Mean Absolute Percentage Error, expressed in percent.

    Each absolute error is divided by ``max(y_true, 1.0)``, so targets below
    1 (including zeros) are scored against a denominator of 1 instead of
    blowing up the ratio.

    Args:
        y_pred: Predicted values (NumPy array, list, or scalar).
        y_true: Ground-truth values, broadcast-compatible with ``y_pred``.
        eps: Unused; kept only so existing callers that pass it keep working.

    Returns:
        ``mean(|y_pred - y_true| / max(y_true, 1.0)) * 100``.
    """
    # Coerce to arrays so plain Python lists work (list - list is a TypeError).
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)

    denom = np.maximum(y_true, 1.0)
    abs_err = np.abs(y_pred - y_true)
    return np.mean(abs_err / denom) * 100.0

def compute_MPE(y_pred, y_true, eps=1e-8):
    """Mean (signed) Percentage Error, expressed in percent.

    Positive values mean over-prediction on average, negative values mean
    under-prediction. Each error is divided by ``max(y_true, 1.0)`` so that
    targets below 1 (including zeros) do not blow up the ratio.

    Args:
        y_pred: Predicted values (NumPy array, list, or scalar).
        y_true: Ground-truth values, broadcast-compatible with ``y_pred``.
        eps: Unused; kept only so existing callers that pass it keep working.

    Returns:
        ``mean((y_pred - y_true) / max(y_true, 1.0)) * 100``.
    """
    # Coerce to arrays so plain Python lists work (list - list is a TypeError).
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)

    denom = np.maximum(y_true, 1.0)
    signed_err = y_pred - y_true
    return np.mean(signed_err / denom) * 100.0

class MetricsTracker:
    """Accumulates per-batch losses and count predictions into epoch metrics.

    ``update`` is called once per batch; ``result`` returns the sample-weighted
    mean loss, global (micro) MPE/MAPE, and — when subject ids were supplied —
    subject-wise error-rate statistics (``err_mean``, ``err_median``,
    ``err_std``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.losses = []        # per-batch loss * batch_size
        self.y_true = []        # per-batch 1-D arrays of targets
        self.y_pred = []        # per-batch 1-D arrays of predictions
        self.subjects = []      # per-batch subject ids (optional)
        self.sample_count = 0

    @staticmethod
    def _to_flat_numpy(values):
        """Convert a tensor / array-like to a 1-D NumPy array.

        Half-precision tensors are upcast first: NumPy has no bfloat16, so
        calling ``.numpy()`` on a bfloat16 tensor raises. Flattening makes a
        ``(B, 1)`` prediction line up with a ``(B,)`` target instead of
        broadcasting to ``(B, B)`` in the metric functions.
        """
        if isinstance(values, torch.Tensor):
            values = values.detach()
            if values.dtype in (torch.bfloat16, torch.float16):
                values = values.float()
            values = values.cpu().numpy()
        return np.asarray(values).reshape(-1)

    def update(self, loss_val, y_true_batch, y_pred_batch, subject_batch=None):
        """Record one batch.

        Args:
            loss_val: Mean loss of the batch (anything ``float()`` accepts).
            y_true_batch: Targets, tensor or array-like with a length.
            y_pred_batch: Predictions, tensor or array-like.
            subject_batch: Optional list/array of subject ids, one per sample.
        """
        loss_val = float(loss_val)
        batch_size = len(y_true_batch)
        # Weight by batch size so result() yields a per-sample mean loss.
        self.losses.append(loss_val * batch_size)
        self.sample_count += batch_size

        self.y_true.append(self._to_flat_numpy(y_true_batch))
        self.y_pred.append(self._to_flat_numpy(y_pred_batch))

        if subject_batch is not None:
            # subject_batch should be list or numpy array of strings/ids
            self.subjects.append(subject_batch)

    def result(self):
        """Return the aggregated metrics as a dict.

        Always contains ``loss``, ``mpe``, ``mape`` and ``err_median``
        (``err_median`` is 0.0 when no subject ids were recorded).
        ``err_mean`` and ``err_std`` are present only when subject ids were
        recorded.

        Raises:
            ValueError: If subject ids were recorded for only some samples.
        """
        avg_loss = float(np.sum(self.losses) / max(1, self.sample_count))
        if not self.y_true:
            return {"loss": avg_loss, "mpe": 0.0, "mape": 0.0, "err_median": 0.0}

        y_t = np.concatenate(self.y_true)
        y_p = np.concatenate(self.y_pred)

        # 1. Basic Micro Metrics (Global)
        res = {
            "loss": avg_loss,
            "mpe": compute_MPE(y_p, y_t),    # micro MPE
            "mape": compute_MAPE(y_p, y_t),  # micro MAPE
            "err_median": 0.0,               # default when no subjects are known
        }

        # 2. Subject-wise Aggregation (ERR)
        if self.subjects:
            subjs = np.concatenate(self.subjects)
            if len(subjs) != len(y_t):
                raise ValueError(
                    f"Got {len(subjs)} subject ids for {len(y_t)} samples; "
                    "pass subject_batch on every update() call or on none."
                )
            df = pd.DataFrame({"subj": subjs, "pred": y_p, "true": y_t})

            # Group by Subject -> Sum Counts
            grp = df.groupby("subj", sort=False).agg({"pred": "sum", "true": "sum"})

            # ERR Calculation: |Pred - True| / max(True, 1.0)
            err_per_subject = (grp["pred"] - grp["true"]).abs() / np.clip(grp["true"], 1.0, None)

            res["err_mean"] = float(err_per_subject.mean())
            res["err_median"] = float(err_per_subject.median())  # key metric
            res["err_std"] = float(err_per_subject.std(ddof=0))

        return res

In [None]:
# -----------------------------------------------------------------------------
# Utils for Optimizer/Scheduler (Simplified from pipeline)
# -----------------------------------------------------------------------------

def sanitize_optimizer_cfg(cfg):
    """Return a copy of ``cfg`` without the ``"name"`` and ``"scheme"`` entries.

    Those two keys select which optimizer/scheduler to build; everything else
    is meant to be forwarded as keyword arguments.
    """
    selector_keys = ("name", "scheme")
    kwargs = {}
    for key, value in cfg.items():
        if key in selector_keys:
            continue
        kwargs[key] = value
    return kwargs

def build_optimizer(model, cfg):
    """Instantiate the optimizer described by ``cfg`` for ``model``'s parameters.

    ``cfg["name"]`` picks the class (``"Adam"`` or ``"AdamW"``); a missing or
    unrecognised name falls back to AdamW. All remaining entries, minus the
    selector keys stripped by ``sanitize_optimizer_cfg``, are forwarded as
    keyword arguments to the optimizer constructor.
    """
    requested = cfg.get("name", "AdamW")
    kwargs = sanitize_optimizer_cfg(cfg)
    # Only "Adam" deviates from the AdamW default/fallback.
    optimizer_cls = optim.Adam if requested == "Adam" else optim.AdamW
    return optimizer_cls(model.parameters(), **kwargs)

def get_loss_config(train_cfg):
    """Extract the loss configuration from ``train_cfg``.

    Args:
        train_cfg: Training config mapping. Reads the optional keys
            ``loss_weights`` (mapping with ``frame`` / ``count``),
            ``count_loss_type`` and ``pos_weight``.

    Returns:
        Dict with ``frame_weight`` (default 1.0), ``count_weight`` (default
        0.01), ``count_loss_type`` (default ``"l2"``) and ``pos_weight``
        (default ``None``).
    """
    # A YAML key written as `loss_weights:` with no value loads as None;
    # treat that the same as an absent key instead of failing on None.get().
    weights = train_cfg.get("loss_weights") or {}
    return {
        "frame_weight": weights.get("frame", 1.0),
        "count_weight": weights.get("count", 0.01),
        "count_loss_type": train_cfg.get("count_loss_type", "l2"),
        "pos_weight": train_cfg.get("pos_weight", None),
    }

def get_scheduler(optimizer, name, epochs=None, steps_per_epoch=None, total_steps=None, **kwargs):
    """
    Minimal scheduler factory; ``'cosine_warmup'`` is the only supported name.

    The schedule multiplies the optimizer's base LR by a factor that rises
    linearly from 0 to 1 over ``warmup_steps`` and then follows a cosine from
    1 down to ``min_lr`` at ``total_steps``. Note that ``min_lr`` is therefore
    a *fraction of the base LR*, not an absolute learning rate.

    Keyword options read from ``kwargs``: ``warmup_steps`` (default 10200) and
    ``min_lr`` (default 0.0333); any other keyword is ignored.

    Returns ``(scheduler, 'batch')`` — the scheduler is meant to be stepped
    once per batch — or ``(None, None)`` for an unsupported ``name``.
    Raises ValueError when the total step count cannot be determined.
    """
    if name != "cosine_warmup":
        print(f"[Warning] Scheduler '{name}' is not supported in this simplified script. Returning None.")
        return None, None

    # Derive the schedule length when it was not given explicitly.
    if total_steps is None:
        if epochs is None or steps_per_epoch is None:
            raise ValueError("Scheduler needs total_steps or (epochs and steps_per_epoch)")
        total_steps = int(epochs) * int(steps_per_epoch)

    warmup_steps = kwargs.get("warmup_steps", 10200)
    min_lr = kwargs.get("min_lr", 0.0333)

    def _lr_factor(step: int) -> float:
        in_warmup = step < warmup_steps
        if in_warmup:
            # Linear ramp 0 -> 1.
            return float(step) / float(max(1, warmup_steps))

        # Fraction of the decay phase completed, clamped to [0, 1].
        decay_span = float(max(1, total_steps - warmup_steps))
        progress = float(step - warmup_steps) / decay_span
        progress = max(0.0, progress)
        progress = min(1.0, progress)

        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return cosine * (1.0 - min_lr) + min_lr

    scheduler = LambdaLR(optimizer, _lr_factor)
    return scheduler, "batch"

In [None]:
# -----------------------------------------------------------------------------
# Trainer
# -----------------------------------------------------------------------------
class Trainer:
    """Training / validation loop with AMP, LR scheduling and early stopping.

    Args:
        model: Module called as ``model(feats, lengths)`` and returning a
            dict-like with ``'frame_logits'`` and optionally ``'count'``.
        train_loader / val_loader: Loaders yielding
            ``(feats, lengths, labels, counts_gt)`` batches.
        args: Opaque namespace, stored as ``self.args`` and not otherwise used.
        train_cfg: Training config mapping. Keys read here: ``device``,
            ``enable_amp`` (default True), ``num_epochs``, ``steps_per_epoch``,
            ``checkpoint_path`` (default ``"model_best.pt"``) plus the loss
            keys consumed by ``get_loss_config``.
        opt_cfg: Optimizer config; ``"name"`` selects the optimizer class.
        scheduler_cfg: Optional scheduler config; ``"scheme"`` selects it.
    """

    def __init__(self, model, train_loader, val_loader, args, train_cfg, opt_cfg, scheduler_cfg=None):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args
        self.train_cfg = train_cfg
        self.device = torch.device(train_cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu"))
        self.model.to(self.device)

        # Pass the raw config: build_optimizer reads "name" and strips the
        # selector keys itself. Sanitizing first removed "name", so the
        # requested optimizer was silently ignored.
        self.optimizer = build_optimizer(self.model, opt_cfg)

        # Scheduler
        self.scheduler = None
        self.scheduler_mode = None

        scheme = (scheduler_cfg.get("scheme") or "none") if scheduler_cfg else "none"
        if str(scheme).lower() != "none":
            steps_per_epoch = self.train_cfg.get("steps_per_epoch", len(train_loader))
            epochs = self.train_cfg.get("num_epochs")
            # Leave total_steps unset when num_epochs is unknown so that
            # get_scheduler reports the problem instead of None * int failing.
            total_steps = int(steps_per_epoch) * int(epochs) if epochs is not None else None
            self.scheduler, self.scheduler_mode = get_scheduler(
                self.optimizer,
                name=scheme,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                total_steps=total_steps,
                **{k: v for k, v in scheduler_cfg.items() if k != "scheme"}
            )

        # Mixed precision: only on CUDA, and only if the config does not opt out.
        self.use_amp = bool(train_cfg.get("enable_amp", True)) and self.device.type == 'cuda'
        self.amp_dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Loss config
        self.loss_cfg = get_loss_config(self.train_cfg)

    def _to_device(self, batch):
        """Unpack a batch and move its four tensors to ``self.device``."""
        feats, lengths, labels, counts_gt = batch
        return (
            feats.to(self.device),
            lengths.to(self.device),
            labels.to(self.device),
            counts_gt.to(self.device),
        )

    def _compute_loss(self, outputs, labels, lengths):
        """Evaluate the frame-onset BCE + count loss with the configured weights."""
        return frame_onset_bce_count_mix(
            out_dict=outputs,
            frame_targets=labels,
            lengths=lengths,
            pos_weight=self.loss_cfg["pos_weight"],
            frame_weight=self.loss_cfg["frame_weight"],
            count_weight=self.loss_cfg["count_weight"],
            count_loss_type=self.loss_cfg["count_loss_type"]
        )

    @staticmethod
    def _predicted_count(outputs):
        """Return the per-sample predicted count as a detached float32 tensor.

        Uses ``outputs['count']`` when available. Only otherwise does it fall
        back to summing frame probabilities, so the fallback is not evaluated
        (and cannot raise) when the model provides a count.
        """
        pred = outputs.get('count')
        if pred is None:
            # NOTE(review): this sums over every frame along dim 1, padded
            # frames included — confirm the model masks padding in its logits.
            pred = outputs['frame_logits'].sigmoid().sum(dim=1)
        # Upcast: under bf16 autocast the prediction may be bfloat16, which
        # cannot be converted to NumPy by the metrics code.
        return pred.detach().float()

    def train_step(self, batch):
        """Run one optimisation step; returns ``(loss, counts_gt, pred_count)``."""
        self.model.train()
        feats, lengths, labels, counts_gt = self._to_device(batch)

        self.optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
            outputs = self.model(feats, lengths)
            loss = self._compute_loss(outputs, labels, lengths)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        # Scheduler Step (Batch)
        if self.scheduler is not None and self.scheduler_mode == 'batch':
            self.scheduler.step()

        return loss.item(), counts_gt, self._predicted_count(outputs)

    @torch.inference_mode()
    def validation_step(self, batch):
        """Evaluate one batch without gradients; returns ``(loss, counts_gt, pred_count)``."""
        self.model.eval()
        feats, lengths, labels, counts_gt = self._to_device(batch)

        outputs = self.model(feats, lengths)
        loss = self._compute_loss(outputs, labels, lengths)

        return loss.item(), counts_gt, self._predicted_count(outputs)

    def fit(self, num_epochs=100, early_stop_metric="val_err_median", patience=10, min_delta=0.01):
        """Train with early stopping and best-model checkpointing.

        Args:
            num_epochs: Maximum number of epochs.
            early_stop_metric: ``"val_err_median"``, ``"val_loss"`` or
                ``"val_mape"`` (lower is better). ``None`` means the default.
                ``"val_err_median"`` falls back to validation MAPE when the
                validation dataset exposes no ``speaker_id`` metadata.
            patience: Epochs without improvement before stopping.
            min_delta: Minimum decrease that counts as an improvement.

        Returns:
            The best value of the monitored metric (``inf`` if no epoch ran).

        Raises:
            ValueError: If ``early_stop_metric`` is not supported. Previously
                such a value left the monitored metric at ``inf``, so no
                checkpoint was ever written.
        """
        supported_metrics = ("val_err_median", "val_loss", "val_mape")
        if early_stop_metric is None:
            early_stop_metric = "val_err_median"
        if early_stop_metric not in supported_metrics:
            raise ValueError(
                f"Unsupported early_stop_metric '{early_stop_metric}'; "
                f"expected one of {supported_metrics}."
            )

        best_metric = float('inf')
        patience_counter = 0
        checkpoint_path = self.train_cfg.get("checkpoint_path", "model_best.pt")

        # Subject-wise metrics need 'speaker_id' metadata. Look through a
        # Subset-style wrapper (.indices + .dataset) to the underlying dataset,
        # since the wrapper itself does not expose `meta`.
        val_dataset = self.val_loader.dataset
        is_subset = hasattr(val_dataset, 'indices') and hasattr(val_dataset, 'dataset')
        base_dataset = val_dataset.dataset if is_subset else val_dataset
        has_speaker_id = hasattr(base_dataset, 'meta') and 'speaker_id' in base_dataset.meta.columns

        for epoch in range(num_epochs):

            # Training
            train_tracker = MetricsTracker()
            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

            for batch in pbar:
                loss_val, y_true, y_pred = self.train_step(batch)
                train_tracker.update(loss_val, y_true, y_pred)  # no subject info for training
                pbar.set_postfix({'loss': loss_val})

            # Scheduler Step (Epoch)
            if self.scheduler is not None and self.scheduler_mode == 'epoch':
                self.scheduler.step()

            train_res = train_tracker.result()

            # Validation
            val_tracker = MetricsTracker()
            # Position of the current batch's first sample within the dataset.
            # NOTE(review): this assumes the validation loader iterates the
            # dataset sequentially (no shuffling / custom sampler).
            sample_cursor = 0

            for batch in tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]"):
                loss_val, y_true, y_pred = self.validation_step(batch)

                subjects = None
                if has_speaker_id:
                    batch_len = len(y_true)
                    positions = range(sample_cursor, sample_cursor + batch_len)
                    if is_subset:
                        # Map positions in the subset to indices of the base dataset.
                        positions = [val_dataset.indices[p] for p in positions]
                    subjects = [base_dataset.get_subject_id(p) for p in positions]
                    sample_cursor += batch_len

                val_tracker.update(loss_val, y_true, y_pred, subjects)

            val_res = val_tracker.result()

            # Scheduler Step (Metric) - e.g. ReduceLROnPlateau
            if self.scheduler is not None and self.scheduler_mode == 'metric':
                self.scheduler.step(val_res['loss'])

            print(f"\n[Epoch {epoch+1}]")
            print(f"  Train -> Loss: {train_res['loss']:.4f} | MPE: {train_res['mpe']:.2f}% | MAPE: {train_res['mape']:.2f}%")
            print(f"  Valid -> Loss: {val_res['loss']:.4f} | MPE: {val_res['mpe']:.2f}% | MAPE: {val_res['mape']:.2f}%")

            # Early Stopping & Checkpoint
            if early_stop_metric == "val_err_median":
                # The tracker reports err_median = 0.0 when it has no subject
                # ids, so decide the fallback from has_speaker_id rather than
                # from key presence.
                current_val = val_res['err_median'] if has_speaker_id else val_res['mape']
            elif early_stop_metric == "val_loss":
                current_val = val_res['loss']
            else:  # "val_mape"
                current_val = val_res['mape']

            if current_val < best_metric - min_delta:
                best_metric = current_val
                patience_counter = 0
                torch.save(self.model.state_dict(), checkpoint_path)
                print(f"  >> New Best Model saved! ({early_stop_metric}={current_val:.4f})")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"  >> Early stopping triggered! No improvement for {patience} epochs.")
                    break

        return best_metric

In [None]:
# Load Config
class ArgsMock:
    """Empty attribute bag standing in for a parsed command-line namespace."""

def load_data_from_args(pt_path, csv_path):
    """Load the metadata CSV and the serialized feature file.

    Args:
        pt_path: Path to the feature file written with ``torch.save``.
        csv_path: Path to the metadata CSV.

    Returns:
        ``(features, meta)`` — the object stored in ``pt_path`` and the
        metadata as a ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: If either path does not exist.
    """
    print(f"Loading meta: {csv_path}")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"{csv_path} not found.")
    meta = pd.read_csv(csv_path)

    print(f"Loading features: {pt_path}")
    if not os.path.exists(pt_path):
        raise FileNotFoundError(f"{pt_path} not found.")
    # Load onto the CPU regardless of the device the tensors were saved from:
    # DataLoader workers and pin_memory expect CPU tensors, and the file stays
    # loadable on machines without a GPU.
    # SECURITY: torch.load unpickles the file — only load trusted feature files.
    features = torch.load(pt_path, map_location="cpu")
    # Usually dataset expects dict or list. If pt_path is the large dict {utt_id: features}, it's fine.

    return features, meta

# Read the YAML experiment configuration; fail early if it is missing.
config_path = "./conf/wce_frame_onset.yaml"
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config file {config_path} not found.")
with open(config_path, "r") as f:
    CONFIG = yaml.safe_load(f)
print(f"Loaded config from {config_path}")

# Split the config into its per-component sections (each defaults to {}).
DATASET_ARGS, MODEL_ARGS, OPT_CFG, SCHEDULER_CFG, TRAIN_CFG = (
    CONFIG.get(_section, {})
    for _section in ('dataset_args', 'model_args', 'optimizer_args', 'scheduler_args', 'train_args')
)

# Top-level switches, when present, override the values in train_args.
for _flag in ('enable_amp', 'do_compile'):
    if _flag in CONFIG:
        TRAIN_CFG[_flag] = CONFIG[_flag]

# Prefer the GPU when one is available; record the choice in the train config.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
TRAIN_CFG['device'] = str(device)

# Seed the RNGs (42 unless the config says otherwise).
set_seed(TRAIN_CFG.get('seed', 42))

# Load Data
# The SpeechCountDataset expects features and meta.
# In config: X matches meta csv, y matches feature pt.
# NOTE(review): load_data_from_args is declared as (pt_path, csv_path), but the
# X_* entry is passed first below. If the comment above is accurate, the CSV
# path is being handed to torch.load — confirm the intended mapping.
try:
    train_feats, train_meta = load_data_from_args(DATASET_ARGS['X_train'], DATASET_ARGS['y_train'])
    valid_feats, valid_meta = load_data_from_args(DATASET_ARGS['X_valid'], DATASET_ARGS['y_valid'])

    target = TRAIN_CFG.get('target_label', 'frame_onset')

    # Initialize Datasets
    train_ds = SpeechCountDataset(train_feats, train_meta, target=target)
    valid_ds = SpeechCountDataset(valid_feats, valid_meta, target=target)

    # 4. DataLoaders
    # Apply num-workers and prefetch-factor from config
    num_workers = CONFIG.get('num-workers', 0)
    prefetch_factor = CONFIG.get('prefetch-factor', 6)
    batch_size = TRAIN_CFG.get('batch_size', 256)

    # prefetch_factor and persistent_workers both require num_workers > 0;
    # DataLoader raises ValueError if they are set with num_workers == 0.
    actual_prefetch = prefetch_factor if num_workers > 0 else None
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        prefetch_factor=actual_prefetch,
        persistent_workers=num_workers > 0,
        collate_fn=pad_collate_frame_onset,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    # Validation must stay unshuffled: the trainer maps batch positions back
    # to dataset indices to look up subject ids.
    valid_loader = DataLoader(valid_ds, shuffle=False, **loader_kwargs)

    # 5. Model Setup
    model = WCEFrameOnset(**MODEL_ARGS)

    # Compile if requested
    if TRAIN_CFG.get('do_compile', False):
        print("Compiling model (torch.compile)...")
        model = torch.compile(model)

    # 6. Prepare Trainer Args
    ARGS = ArgsMock()
    ARGS.loss = TRAIN_CFG.get('loss_function', 'frame_bce_mix')

    # Update steps
    TRAIN_CFG["steps_per_epoch"] = len(train_loader)
    TRAIN_CFG["total_steps"] = len(train_loader) * TRAIN_CFG["num_epochs"]

    # Pass scheduler scheme and optimizer name correctly
    SCHEDULER_CFG['scheme'] = CONFIG.get('scheduler', 'none')
    # NOTE(review): the key read here is 'optimize' — confirm this is the
    # config's actual key and not a typo for 'optimizer'.
    OPT_CFG['name'] = CONFIG.get('optimize', 'AdamW')

    # 7. Run Training
    trainer = Trainer(model, train_loader, valid_loader, ARGS, TRAIN_CFG, OPT_CFG, SCHEDULER_CFG)
    trainer.fit(
        num_epochs=TRAIN_CFG["num_epochs"],
        # Use the trainer's default metric when the config names none;
        # passing None would leave early stopping without a usable metric.
        early_stop_metric=TRAIN_CFG.get("monitor_metric") or "val_err_median",
        patience=TRAIN_CFG.get("early_stopping_window", 10),
        min_delta=TRAIN_CFG.get("early_stopping_min_delta", 0.01)
    )

except Exception as e:
    # Report, then re-raise so the cell visibly fails (with its traceback)
    # instead of looking like a successful run.
    print(f"An error occurred during setup: {e}")
    raise