In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Advancing Fluorescence Light Detection and Ranging in Scattering Media with 
a Physics-Guided Mixture-of-Experts and Evidential Critics
------------------------------------------------------------------------------
Version Date: 2025-04-29 
"""

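# Must be the first torch.multiprocessing call in the notebook; force=True overrides any start method
# already set in this kernel. NOTE: 'fork' is not safe once CUDA has been initialised in the parent process.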
import torch
import torch.multiprocessing as mp
mp.set_start_method('fork', force=True)

# now all usual imports
import os, sys, argparse, math, copy, numpy as np, scipy.io as sio
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.tensorboard import SummaryWriter

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
# --- Constants ---
# Small additive constant shared by the model and the EDC loss helpers: it guards the
# divisions alpha / (alpha + beta + EPSILON) and offsets alpha/beta before lgamma/digamma.
EPSILON = 1e-8 # Small value for numerical stability

In [None]:
######################################
# 1. Configuration 
######################################
class EvidenceMoeConfig:
    """Configuration for the EvidenceMoE model with Evidence-based Dirichlet Critics (EDC).

    Groups of settings:
      * Architecture: ``input_size``, ``hidden_size`` (also stored as ``num_filters``, the
        CNN/Transformer width), ``intermediate_size``, ``num_hidden_layers``,
        ``num_attention_heads``, ``max_position_embeddings``, ``attention_dropout``.
      * Optimizers: separate learning rate / weight decay for experts, decider head, critics.
      * EDC losses: ``lambda_KL``, ``gamma_penalty`` and ``quality_kappa`` (target quality
        ``1 / (1 + kappa * MAE)``), plus the per-term loss weights ``lambda_*``.
      * ``damping_factor``: scale applied to critic corrections before they are added to
        the expert predictions.
      * RL parameters: currently inactive.
      * ``ablate_*``: ablation switches.
      * Sequence layout: the global expert sees ``global_seq_len`` bins, the early expert
        the first ``early_seq_len`` bins and the late expert the next ``late_seq_len`` bins.
      * ``conv_kernel``: kernel size of the CNN front-end. It must be odd so the
        ``kernel // 2`` padding preserves the sequence length.

    Raises:
        ValueError: if a sequence length is not positive, if
            ``early_seq_len + late_seq_len > global_seq_len``, or if ``conv_kernel`` is not
            a positive odd integer.
    """
    def __init__(
        self,
        input_size=1,
        hidden_size=224,
        intermediate_size=384,
        num_hidden_layers=2,
        num_attention_heads=16,
        num_experts=3,
        max_position_embeddings=512,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        attention_dropout=0.15194,
        # Optimizer Hyperparameters
        expert_lr: float = 5e-5,         # LR for Experts
        decider_lr: float = 1e-4,        # LR for Decider Head
        critic_lr: float = 1e-4,         # LR for EDC Critics
        expert_weight_decay: float = 0.0, # Weight decay for Experts
        decider_weight_decay: float = 0.0, # Weight decay for Decider Head
        critic_weight_decay: float = 1e-4, # Weight decay for EDC Critics
        # EDC Loss Hyperparameters
        lambda_KL: float = 1e-3,
        gamma_penalty: float = 1e-3,
        quality_kappa: float = 20.0,      # Kappa for 1 / (1 + kappa * MAE) target quality
        # Loss Component Weights
        lambda_primary: float = 1.0,       # Weight for primary MAE loss
        lambda_aux: float = 1.0,           # Weight for auxiliary expert loss
        lambda_critic_quality: float = 1.0, # Weight for combined Evi+KL quality loss
        lambda_corr: float = 1.0,           # Weight for correction accuracy loss
        lambda_penalty: float = 1.0,        # Weight for evidence penalty loss term
        lambda_diversity: float = 0.0,      # Weight for diversity loss
        # Other Model Params
        damping_factor=0.1,
        # RL Params (Inactive)
        rl_weight=0.0,
        rl_entropy_coeff=0.01,
        rl_similarity_penalty=0.5,
        # Ablation Flags
        ablate_quality_weighting: bool = False,
        ablate_correction: bool = False,
        ablate_quality_in_gating: bool = False,
        ablate_decider_feature: bool = False,
        ablate_decider_feature_fusion: bool = False,
        ablate_uniform_gating: bool = False,
        ablate_gating_dropout: bool = False,
        ablate_phased_training: bool = False,
        ablate_mean_pooling: bool = False,
        ablate_auxiliary_mae: bool = False,
        # Sequence layout / CNN front-end (appended so positional callers are unaffected)
        global_seq_len: int = 70,
        early_seq_len: int = 35,
        late_seq_len: int = 35,
        conv_kernel: int = 3,
    ):
        # Validate the (formerly hard-coded) layout before storing anything.
        if min(global_seq_len, early_seq_len, late_seq_len) <= 0:
            raise ValueError(
                f"Sequence lengths must be positive, got global={global_seq_len}, "
                f"early={early_seq_len}, late={late_seq_len}"
            )
        if early_seq_len + late_seq_len > global_seq_len:
            raise ValueError(
                f"early_seq_len + late_seq_len ({early_seq_len + late_seq_len}) "
                f"exceeds global_seq_len ({global_seq_len})"
            )
        if conv_kernel <= 0 or conv_kernel % 2 == 0:
            raise ValueError(f"conv_kernel must be a positive odd integer, got {conv_kernel}")

        # Architecture
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_experts = num_experts
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.attention_dropout = attention_dropout
        # Optimizers
        self.expert_lr = expert_lr
        self.decider_lr = decider_lr
        self.critic_lr = critic_lr
        self.expert_weight_decay = expert_weight_decay
        self.decider_weight_decay = decider_weight_decay
        self.critic_weight_decay = critic_weight_decay
        # EDC losses
        self.lambda_KL = lambda_KL
        self.gamma_penalty = gamma_penalty
        self.quality_kappa = quality_kappa
        self.damping_factor = damping_factor
        # RL (inactive)
        self.rl_weight = rl_weight
        self.rl_entropy_coeff = rl_entropy_coeff
        self.rl_similarity_penalty = rl_similarity_penalty
        # Sequence layout / CNN front-end
        self.global_seq_len = global_seq_len
        self.early_seq_len = early_seq_len
        self.late_seq_len = late_seq_len
        self.conv_kernel = conv_kernel
        # The CNN width doubles as the Transformer width, so it tracks hidden_size.
        self.num_filters = hidden_size
        # Ablations
        self.ablate_quality_weighting = ablate_quality_weighting
        self.ablate_correction = ablate_correction
        self.ablate_quality_in_gating = ablate_quality_in_gating
        self.ablate_decider_feature = ablate_decider_feature
        self.ablate_decider_feature_fusion = ablate_decider_feature_fusion
        self.ablate_uniform_gating = ablate_uniform_gating
        self.ablate_gating_dropout = ablate_gating_dropout
        self.ablate_phased_training = ablate_phased_training
        self.ablate_mean_pooling = ablate_mean_pooling
        self.ablate_auxiliary_mae = ablate_auxiliary_mae
        # Loss Weights
        self.lambda_primary = lambda_primary
        self.lambda_aux = lambda_aux
        self.lambda_critic_quality = lambda_critic_quality
        self.lambda_corr = lambda_corr
        self.lambda_penalty = lambda_penalty
        self.lambda_diversity = lambda_diversity


In [None]:
######################################
# 2. Common Components
######################################
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding for batch-first transformer inputs.

    Builds a fixed table ``pe[pos, 2i] = sin(pos * w_i)`` and ``pe[pos, 2i+1] = cos(pos * w_i)``,
    where ``w_i = 10000^(-2i / embed_dim)``. ``forward`` adds its first ``seq_len`` rows to
    ``x`` of shape ``(batch, seq_len, embed_dim)``.

    Args:
        embed_dim: Feature width. Odd widths are supported; the last sine column then has
            no cosine partner.
        max_len: Longest sequence the table covers.

    Raises:
        ValueError: in ``forward`` if ``x.size(1) > max_len``.
    """
    def __init__(self, embed_dim, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d/2) sine columns but only floor(d/2) cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}")
        return x + self.pe[:seq_len].unsqueeze(0).to(x.device)

class StandardTransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a position-wise feed-forward.

    Each sub-layer is wrapped as ``x = LayerNorm(x + Dropout(sublayer(x)))``. The forward
    pass returns ``(output, attention_weights)``, exposing the weights produced by
    ``nn.MultiheadAttention`` (batch-first).

    Args:
        d_model: Model / embedding width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability for attention, the FFN hidden layer and both residual branches.
    """
    def __init__(self, d_model, nhead, dim_feedforward=256, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are part of the state_dict / seeded-init contract.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, src):
        # Sub-layer 1: self-attention with residual + LayerNorm.
        context, attention = self.self_attn(src, src, src)
        hidden = self.norm1(src + self.dropout1(context))
        # Sub-layer 2: Linear -> ReLU -> Dropout -> Linear, with residual + LayerNorm.
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        output = self.norm2(hidden + self.dropout2(projected))
        return output, attention

class StandardTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` post-norm encoder layers applied in sequence.

    ``forward(src)`` returns ``(hidden, attention_per_layer)``. ``hidden`` is the output of
    the last layer; ``attention_per_layer`` lists each layer's attention weights in layer order.
    """
    def __init__(self, d_model, nhead, dim_feedforward, num_layers, dropout=0.1):
        super().__init__()
        stack = [
            StandardTransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)
        self.num_layers = num_layers

    def forward(self, src):
        hidden = src
        attention_per_layer = []
        for encoder_layer in self.layers:
            hidden, layer_attention = encoder_layer(hidden)
            attention_per_layer.append(layer_attention)
        return hidden, attention_per_layer

class AttentionPooling(nn.Module):
    """Collapse a batch-first sequence of feature vectors into one vector per sample.

    A small MLP (Linear -> Tanh -> Linear) scores every position. The scores are
    softmax-normalised over dim 1 (the sequence axis) and used as weights for a weighted sum
    over that axis. For an input of shape (B, L, D) the output is (B, D).
    """
    def __init__(self, input_dim):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            nn.Linear(input_dim, 1),
        )

    def forward(self, x):
        scores = self.attn(x)            # one score per position
        alphas = scores.softmax(dim=1)   # normalise across the sequence
        return (x * alphas).sum(dim=1)


In [None]:
######################################
# 3. Hybrid CNN-Transformer Encoder
######################################
class HybridCNNTransformerEncoder(nn.Module):
    """CNN front-end followed by a Transformer encoder.

    Pipeline (channels-first for the convolutions, batch-first for the Transformer):
      1. ``residual_conv`` (kernel 1) projects the raw signal to ``num_filters`` channels.
      2. ``conv_initial`` -> ReLU -> ``conv1`` -> LayerNorm over channels -> Dropout(0.1),
         added to the residual projection, then ReLU.
      3. Sinusoidal positional encoding and a stack of post-norm encoder layers.
      4. Final LayerNorm; NaN/Inf are replaced (nan -> 0, +inf -> 1e6, -inf -> -1e6).

    Args:
        config: Model configuration. Reads ``input_size``, ``conv_kernel``, ``num_filters``,
            ``max_position_embeddings``, ``num_attention_heads``, ``intermediate_size``,
            ``num_hidden_layers`` and ``attention_dropout``.
        conv_kernel: Optional override of ``config.conv_kernel``. It should be odd so the
            ``kernel // 2`` padding keeps the sequence length for the residual addition.
        num_filters: Optional override of ``config.num_filters``; also the Transformer width.

    Forward input ``x``: ``(batch, seq_len)`` (treated as one channel) or
    ``(batch, config.input_size, seq_len)``.
    Returns ``(features, attn_maps)``. ``features`` has shape ``(batch, seq_len, num_filters)``;
    ``attn_maps`` is the per-layer list of attention weights.

    Raises:
        ValueError: in ``forward`` if ``x`` is neither 2-D nor 3-D.
    """
    def __init__(self, config: EvidenceMoeConfig, conv_kernel=None, num_filters=None):
        super().__init__()
        self.config = config
        conv_kernel = conv_kernel if conv_kernel is not None else config.conv_kernel
        num_filters = num_filters if num_filters is not None else config.num_filters
        padding = conv_kernel // 2  # keeps seq_len for odd kernels
        # Module creation order is kept stable for state_dict keys and seeded initialisation.
        self.residual_conv = nn.Conv1d(in_channels=config.input_size, out_channels=num_filters, kernel_size=1)
        self.conv_initial = nn.Conv1d(in_channels=config.input_size, out_channels=num_filters, kernel_size=conv_kernel, padding=padding)
        self.conv1 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=conv_kernel, padding=padding)
        self.ln_conv = nn.LayerNorm(num_filters)
        self.dropout1 = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.pos_encoding = PositionalEncoding(embed_dim=num_filters, max_len=config.max_position_embeddings)
        self.transformer_encoder = StandardTransformerEncoder(
            d_model=num_filters,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            num_layers=config.num_hidden_layers,
            dropout=config.attention_dropout,
        )
        self.norm = nn.LayerNorm(num_filters)

    def forward(self, x):
        if x.dim() == 2:
            # (B, L) -> (B, 1, L): single-channel signal for the 1-D convolutions.
            x = x.unsqueeze(1)
        elif x.dim() != 3:
            raise ValueError(
                f"HybridCNNTransformerEncoder expects (B, L) or (B, C, L) input, got shape {tuple(x.shape)}"
            )
        residual = self.residual_conv(x)
        c = self.conv1(self.relu(self.conv_initial(x)))
        # LayerNorm normalises the last dim, so move channels last and back again.
        c = self.ln_conv(c.transpose(1, 2)).transpose(1, 2)
        c = self.relu(self.dropout1(c) + residual)
        t = self.pos_encoding(c.transpose(1, 2))  # (B, L, C)
        t, attn_maps = self.transformer_encoder(t)
        t = self.norm(t)
        # Guard downstream heads against non-finite activations.
        t = torch.nan_to_num(t, nan=0.0, posinf=1e6, neginf=-1e6)
        return t, attn_maps

In [None]:
######################################
# 4. Expert Branch
######################################
class Expert(nn.Module):
    """Expert branch: hybrid CNN-Transformer encoder -> pooling -> MLP regression head.

    Args:
        config: Model configuration. ``hidden_size`` sets the pooled feature width;
            ``ablate_mean_pooling`` replaces attention pooling with a plain mean over positions.
        input_seq_length: Length of the sub-sequence this expert is meant for. It is stored
            as ``self.seq_len`` and not checked in ``forward``.
        output_dim: Number of regressed values.
        return_attn: If True, ``forward`` also returns the encoder attention maps.

    ``forward(x)`` expects a 2-D ``(B, L)`` tensor and returns the 6-tuple
    ``(features, mean, log_prob, entropy, mean, attn_maps or None)``. ``log_prob`` and
    ``entropy`` are zero placeholders of shape ``(B, 1)``, and ``mean`` appears twice.
    """
    def __init__(self, config: EvidenceMoeConfig, input_seq_length, output_dim=2, return_attn=False):
        super().__init__()
        self.seq_len = input_seq_length
        self.config = config
        self.output_dim = output_dim
        self.return_attn = return_attn
        self.hybrid_encoder = HybridCNNTransformerEncoder(config)
        self.pooling = AttentionPooling(config.hidden_size)
        width = config.hidden_size
        self.aux_head_mean = nn.Sequential(
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width // 2), nn.ReLU(),
            nn.Linear(width // 2, output_dim),
        )

    def forward(self, x):
        batch, _ = x.shape  # unpacking also enforces a 2-D (B, L) input
        encoded, attn_maps = self.hybrid_encoder(x)
        if self.config.ablate_mean_pooling:
            features = encoded.mean(dim=1)
        else:
            features = self.pooling(encoded)
        mean = self.aux_head_mean(features)
        log_prob = torch.zeros(batch, 1, device=mean.device)
        entropy = torch.zeros(batch, 1, device=mean.device)
        maps_out = attn_maps if self.return_attn else None
        return features, mean, log_prob, entropy, mean, maps_out

In [None]:
######################################
# 5. Evidence-based Dirichlet Critic
######################################
class EvidenceCritic(nn.Module):
    """Evidence-based Dirichlet Critic (EDC).

    A shared MLP backbone (input_dim -> 32 -> 16, with LayerNorm, GELU and Dropout(0.2))
    feeds two linear heads:
      * ``evi_head`` outputs ``2 * output_dim`` raw evidences. Softplus is applied and 1 is
        added, giving Beta parameters ``alpha`` (first half) and ``beta`` (second half),
        both >= 1.
      * ``corr_head`` outputs an additive correction of width ``output_dim``.

    Args:
        input_dim: Width of the critic input, i.e. expert features concatenated with the
            expert's prediction.
        output_dim: Number of predicted quantities being scored/corrected.

    ``forward(x)`` takes ``x`` of shape ``(B, input_dim)`` and returns ``(alpha, beta, correction)``,
    each of shape ``(B, output_dim)``.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, 32), nn.LayerNorm(32), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(32, 16), nn.LayerNorm(16), nn.GELU(), nn.Dropout(0.2),
        )
        # Column ranges of the evidence head for the positive / negative halves.
        self.evi_pos_indices = slice(0, output_dim)
        self.evi_neg_indices = slice(output_dim, 2 * output_dim)
        self.evi_head = nn.Linear(16, output_dim * 2)
        self.corr_head = nn.Linear(16, output_dim)

    def forward(self, x):
        shared = self.backbone(x)
        raw = self.evi_head(shared)
        pos_evidence = raw[:, self.evi_pos_indices]
        neg_evidence = raw[:, self.evi_neg_indices]
        alpha = F.softplus(pos_evidence) + 1.0
        beta = F.softplus(neg_evidence) + 1.0
        return alpha, beta, self.corr_head(shared)


In [None]:
######################################
# 6. Final Decider Head
######################################
class FinalDeciderHead(nn.Module):
    """Fuse expert predictions with a learned, quality-aware sigmoid gate.

    Gating: ``gate_layer`` maps ``[concat(corrected_aux), decider_feature, quality_scores]``
    to one sigmoid weight per expert. With ``ablate_uniform_gating`` the weights are uniform
    ``1 / num_experts`` instead. In training mode, unless ``ablate_gating_dropout``, the gate
    of expert ``gating_dropout_index`` is zeroed per sample with probability ``gating_dropout_p``.

    Fusion: each expert output is multiplied by its gate. The gated outputs are concatenated
    with the decider feature (zeroed if ``ablate_decider_feature_fusion``) and passed through
    ``fusion_layer`` (2 outputs). Each output goes through ``2 * (a + b * tanh(.))`` with
    learnable ``a``/``b`` and is clamped to [0, 2], yielding ``(depth, lifetime)`` of shape (B, 2).

    ``last_gate_weights`` holds the detached gate weights of the latest forward pass, shape
    ``(B, num_experts)``. It is a non-persistent buffer: its shape follows the batch size, so
    keeping it out of ``state_dict`` lets checkpoints load into a freshly built model.

    Args:
        decider_input_dim: Width of ``decider_feature``.
        num_experts: Number of experts (length of ``corrected_aux``).
        num_quality_inputs: Expected width of ``quality_scores``.
        expert_total_dim: Sum of the widths of the tensors in ``corrected_aux``.
        ablate_*: Ablation switches described above.
        gating_dropout_p: Per-sample probability of dropping the selected expert's gate.
        gating_dropout_index: Index of the expert whose gate is subject to gating dropout.

    Raises:
        ValueError: in ``forward`` if ``len(corrected_aux) != num_experts`` or if
            ``quality_scores.shape[1] != num_quality_inputs`` (the latter only when the learned
            gate is used).
    """

    def __init__(self, decider_input_dim, num_experts, num_quality_inputs, expert_total_dim,
                 ablate_decider_feature: bool = False, ablate_decider_feature_fusion: bool = False,
                 ablate_uniform_gating: bool = False, ablate_gating_dropout: bool = False,
                 gating_dropout_p: float = 0.2, gating_dropout_index: int = 2):
        super().__init__()
        gating_input_dim = expert_total_dim + decider_input_dim + num_quality_inputs
        self.gate_layer = nn.Sequential(
            nn.Linear(gating_input_dim, decider_input_dim), nn.ReLU(),
            nn.Linear(decider_input_dim, num_experts), nn.Sigmoid(),
        )
        fusion_input_dim = expert_total_dim + decider_input_dim
        self.fusion_layer = nn.Linear(fusion_input_dim, 2)
        # Output range parameters: pred = 2 * (alpha + beta * tanh(raw)).
        self.alpha_depth = nn.Parameter(torch.tensor(0.15))
        self.beta_depth = nn.Parameter(torch.tensor(0.85))
        self.alpha_lifetime = nn.Parameter(torch.tensor(0.65))
        self.beta_lifetime = nn.Parameter(torch.tensor(0.35))
        self.ablate_decider_feature = ablate_decider_feature
        self.ablate_decider_feature_fusion = ablate_decider_feature_fusion
        self.ablate_uniform_gating = ablate_uniform_gating
        self.ablate_gating_dropout = ablate_gating_dropout
        self.gating_dropout_p = gating_dropout_p
        self.gating_dropout_index = gating_dropout_index
        self.num_experts = num_experts
        self.num_quality_inputs = num_quality_inputs
        # Non-persistent: forward overwrites it with a (B, num_experts) tensor, whose
        # batch-dependent shape would otherwise break load_state_dict on a fresh model.
        self.register_buffer("last_gate_weights", torch.zeros(num_experts), persistent=False)

    def forward(self, corrected_aux, decider_feature, quality_scores):
        if len(corrected_aux) != self.num_experts:
            raise ValueError(
                f"FinalDeciderHead expected {self.num_experts} expert outputs, got {len(corrected_aux)}"
            )
        corrected_concat = torch.cat(corrected_aux, dim=1)
        batch_size = corrected_concat.size(0)

        if self.ablate_uniform_gating:
            gate_weights = torch.full((batch_size, self.num_experts), 1.0 / self.num_experts,
                                      device=corrected_concat.device)
        else:
            gating_decider = torch.zeros_like(decider_feature) if self.ablate_decider_feature else decider_feature
            if quality_scores.shape[1] != self.num_quality_inputs:
                raise ValueError(f"FinalDeciderHead expected {self.num_quality_inputs} quality scores, got {quality_scores.shape[1]}")
            gating_input = torch.cat([corrected_concat, gating_decider, quality_scores], dim=1)
            gate_weights = self.gate_layer(gating_input)
            if self.training and not self.ablate_gating_dropout:
                drop = (torch.rand(batch_size, device=gate_weights.device) < self.gating_dropout_p).float().unsqueeze(1)
                idx = self.gating_dropout_index
                # Clone so the in-place slice update does not touch the autograd-tracked original.
                gate_weights = gate_weights.clone()
                gate_weights[:, idx:idx + 1] = gate_weights[:, idx:idx + 1] * (1.0 - drop)

        self.last_gate_weights = gate_weights.detach().clone()

        # Weight every expert output by its own gate column.
        gated = [aux * gate_weights[:, i:i + 1] for i, aux in enumerate(corrected_aux)]
        fused_experts = torch.cat(gated, dim=1)
        fusion_decider = torch.zeros_like(decider_feature) if self.ablate_decider_feature_fusion else decider_feature
        raw_output = self.fusion_layer(torch.cat([fused_experts, fusion_decider], dim=1))

        pred_depth = 2.0 * (self.alpha_depth + self.beta_depth * torch.tanh(raw_output[:, 0:1]))
        pred_lifetime = 2.0 * (self.alpha_lifetime + self.beta_lifetime * torch.tanh(raw_output[:, 1:2]))
        pred_depth = torch.clamp(pred_depth, 0.0, 2.0)
        pred_lifetime = torch.clamp(pred_lifetime, 0.0, 2.0)
        return torch.cat([pred_depth, pred_lifetime], dim=1)


In [None]:
######################################
# 8. EvidenceMoe Model 
######################################
class EvidenceMoeModel(nn.Module):
    """Mixture-of-experts model whose experts are scored and corrected by EDC critics.

    The input ``x`` has shape ``(B, L)``. It is split along dim 1 into:
      * early expert: first ``early_seq_len`` bins, predicts 1 value (depth slot);
      * late expert: next ``late_seq_len`` bins, predicts 1 value (lifetime slot);
      * global expert: the whole trace, predicts 2 values (depth, lifetime).

    Modes:
      * ``"pretrain_experts"``: no critics or decider. The prediction averages the matching
        expert outputs. Returns ``(final_pred, [early, late, global], global_attn_maps)``.
      * any other mode (default ``"full"``): ``L`` must equal ``global_seq_len``. Each expert's
        pooled features and prediction are fed, detached, to its critic, which yields Beta
        parameters (alpha, beta) and a correction. Damped corrections are added to the expert
        predictions, which ``FinalDeciderHead`` fuses using the critic quality scores. Returns
        ``(final_pred, aux_preds, critic_outputs_dict, expert_features, decider_feature,
        global_attn_maps)``.
    """
    def __init__(self, config: EvidenceMoeConfig, mode="full"):
        super().__init__()
        self.mode = mode
        self.config = config
        self.early_seq_len = config.early_seq_len
        self.late_seq_len = config.late_seq_len
        self.global_seq_len = config.global_seq_len
        # Module creation order is kept stable for state_dict keys and seeded initialisation.
        self.early_expert = Expert(config, input_seq_length=self.early_seq_len, output_dim=1, return_attn=False)
        self.late_expert = Expert(config, input_seq_length=self.late_seq_len, output_dim=1, return_attn=False)
        self.global_expert = Expert(config, input_seq_length=self.global_seq_len, output_dim=2, return_attn=True)

        # Critic input = expert pooled features (hidden_size) + that expert's prediction.
        feat_dim = config.hidden_size
        self.critic_early = EvidenceCritic(input_dim=feat_dim + 1, output_dim=1)
        self.critic_late = EvidenceCritic(input_dim=feat_dim + 1, output_dim=1)
        self.critic_global = EvidenceCritic(input_dim=feat_dim + 2, output_dim=2)

        self.final_head = FinalDeciderHead(
            decider_input_dim=config.hidden_size,
            num_experts=3,
            num_quality_inputs=4,  # (q_early, q_late, q_global_depth, q_global_lifetime)
            expert_total_dim=4,    # widths of corrected_aux: 1 + 1 + 2
            ablate_decider_feature=config.ablate_decider_feature,
            ablate_decider_feature_fusion=config.ablate_decider_feature_fusion,
            ablate_uniform_gating=config.ablate_uniform_gating,
            ablate_gating_dropout=config.ablate_gating_dropout
        )

    def _split_inputs(self, x):
        """Return ``(early_in, late_in, global_in)`` slices of ``x`` along dim 1."""
        early_end = self.early_seq_len
        late_end = early_end + self.late_seq_len
        return x[:, :early_end], x[:, early_end:late_end], x

    def forward(self, x):
        if self.mode == "pretrain_experts":
            early_in, late_in, global_in = self._split_inputs(x)
            _, early_sample, _, _, _, _ = self.early_expert(early_in)
            _, late_sample, _, _, _, _ = self.late_expert(late_in)
            _, global_sample, _, _, _, global_attn_maps = self.global_expert(global_in)
            pred_depth = (early_sample + global_sample[:, 0:1]) / 2.0
            pred_lifetime = (late_sample + global_sample[:, 1:2]) / 2.0
            final_pred = torch.cat([pred_depth, pred_lifetime], dim=1)
            return final_pred, [early_sample, late_sample, global_sample], global_attn_maps

        # --- Full mode with EvidenceCritic ---
        _, L = x.shape
        assert L == self.global_seq_len, f"Expected length {self.global_seq_len}, got {L}"
        early_in, late_in, global_in = self._split_inputs(x)

        # Expert outputs: pooled features (B, H) and predictions [(B,1), (B,1), (B,2)].
        features_early, early_sample, _, _, _, _ = self.early_expert(early_in)
        features_late, late_sample, _, _, _, _ = self.late_expert(late_in)
        features_global, global_sample, _, _, _, global_attn_maps = self.global_expert(global_in)
        aux_preds = [early_sample, late_sample, global_sample]
        expert_features = [features_early, features_late, features_global]

        # Critic inputs are detached, so critic outputs carry no gradient back into the experts.
        critics = (self.critic_early, self.critic_late, self.critic_global)
        critic_results = [
            critic(torch.cat([feat.detach(), pred.detach()], dim=1))
            for critic, feat, pred in zip(critics, expert_features, aux_preds)
        ]
        (alpha_e, beta_e, corr_e), (alpha_l, beta_l, corr_l), (alpha_g, beta_g, corr_g) = critic_results

        # Mean of each Beta distribution = quality score.
        q_e = alpha_e / (alpha_e + beta_e + EPSILON)
        q_l = alpha_l / (alpha_l + beta_l + EPSILON)
        q_g_vec = alpha_g / (alpha_g + beta_g + EPSILON)  # (B, 2)
        quality_scores_full = torch.cat([q_e, q_l, q_g_vec], dim=1)  # (B, 4)

        # NOTE(review): config.ablate_quality_weighting is not consulted in this forward pass;
        # quality scores are used only for gating and returned for the loss.
        correction_signals = [corr_e, corr_l, corr_g]
        if self.config.ablate_correction:
            correction_signals = [torch.zeros_like(c).detach() for c in correction_signals]
        corrected_aux = [
            aux + corr * self.config.damping_factor
            for aux, corr in zip(aux_preds, correction_signals)
        ]

        decider_feature = features_global  # global expert's pooled features
        gating_quality_full = (
            torch.ones_like(quality_scores_full) if self.config.ablate_quality_in_gating else quality_scores_full
        )
        final_pred = self.final_head(corrected_aux, decider_feature, gating_quality_full)

        # Per-expert mean quality: (q_e, q_l, mean(q_g_depth, q_g_lifetime)).
        q_g_scalar = quality_scores_full[:, 2:4].mean(dim=1, keepdim=True)
        quality_scores_mean = torch.cat([
            quality_scores_full[:, 0:1],  # q_e
            quality_scores_full[:, 1:2],  # q_l
            q_g_scalar,                   # mean of q_g_depth & q_g_lifetime
        ], dim=1)
        critic_outputs_dict = {
            "alpha_early": alpha_e, "beta_early": beta_e, "corr_early": corr_e,
            "alpha_late": alpha_l, "beta_late": beta_l, "corr_late": corr_l,
            "alpha_global": alpha_g, "beta_global": beta_g, "corr_global": corr_g,
            "quality_scores_full": quality_scores_full,
            "quality_scores_mean": quality_scores_mean,
        }

        return (final_pred, aux_preds, critic_outputs_dict, expert_features,
                decider_feature, global_attn_maps)

In [None]:
######################################
# 9. Diversity Loss
######################################
def diversity_loss(aux_preds, margin=0.3):
    """Hinge penalty on pairwise cosine similarity between expert outputs.

    For every unordered pair (i, j), both tensors are L2-normalised along dim 1. Their
    row-wise dot product is taken, the excess over ``margin`` is kept
    (``max(cos - margin, 0)``) and averaged over the batch. The result is the mean of
    these per-pair penalties, or ``0.0`` (a Python float) when fewer than two outputs
    are given.

    NOTE(review): if two tensors have different widths (e.g. (B, 1) vs (B, 2)), the
    element-wise product broadcasts, so the value is not a true cosine similarity for
    that pair.
    """
    count = len(aux_preds)
    pair_penalties = []
    for a in range(count):
        unit_a = F.normalize(aux_preds[a], dim=1)
        for b in range(a + 1, count):
            unit_b = F.normalize(aux_preds[b], dim=1)
            cosine = (unit_a * unit_b).sum(dim=1)
            pair_penalties.append(torch.mean(torch.clamp(cosine - margin, min=0.0)))
    if not pair_penalties:
        return 0.0
    return sum(pair_penalties) / len(pair_penalties)

In [None]:
######################################
# 10. Aux Error Quality Target Calculation
######################################
def compute_aux_error_quality_targets(aux_preds, ground_truth, quality_kappa):
    """Per-expert quality targets ``1 / (1 + kappa * (|aux - gt| + EPSILON))``.

    Args:
        aux_preds: Indexable of at least three predictions ``[early, late, global]``.
            ``early`` is compared with ``ground_truth[:, 0:1]``, ``late`` with
            ``ground_truth[:, 1:2]``, and ``global`` with ``ground_truth``. A prediction whose
            shape differs from its target is reshaped with ``view_as``.
        ground_truth: Target tensor; columns 0 and 1 are used (presumably (B, 2) depth and
            lifetime — confirm with caller).
        quality_kappa: Sharpness; larger values make the target drop faster with error.

    Returns:
        ``(q_early, q_late, q_global)``: element-wise targets, in (0, 1) for ``kappa > 0``.
    """
    gt_depth = ground_truth[:, 0:1]
    gt_lifetime = ground_truth[:, 1:2]

    def _quality(pred, target):
        aligned = pred if pred.shape == target.shape else pred.view_as(target)
        abs_err = torch.abs(aligned - target)
        return 1.0 / (1.0 + quality_kappa * (abs_err + EPSILON))

    return (
        _quality(aux_preds[0], gt_depth),
        _quality(aux_preds[1], gt_lifetime),
        _quality(aux_preds[2], ground_truth),
    )

In [None]:
######################################
# 12. Training Function
######################################

# --- Define EDC Loss Helper Functions ---
def compute_evidential_loss(alpha, beta, target_q_gt):
    """Evidential regression loss for a Beta(alpha, beta) quality prediction.

    With ``a = alpha + EPSILON``, ``b = beta + EPSILON`` and ``S = a + b``:
        L_evi = (target - a / S)^2 + a * b / (S^2 * (S + 1))
    i.e. squared error of the Beta mean plus the Beta variance. No reduction is applied;
    the result has the broadcast shape of the inputs.
    """
    a = alpha + EPSILON
    b = beta + EPSILON
    total = a + b
    expected_q = a / total
    squared_error = (target_q_gt - expected_q) ** 2
    spread = (a * b) / (total.pow(2) * (total + 1.0))
    return squared_error + spread

def compute_kl_divergence_loss(alpha, beta):
    """Element-wise ``KL(Beta(a, b) || Beta(1, 1))``, clamped at 0.

    With ``a = alpha + EPSILON`` and ``b = beta + EPSILON``:
        KL = -ln B(a, b) + (a - 1)(psi(a) - psi(a + b)) + (b - 1)(psi(b) - psi(a + b))
    where ``ln B`` is the log Beta function (via lgamma) and ``psi`` is the digamma function.
    """
    a = alpha + EPSILON
    b = beta + EPSILON
    psi_total = torch.digamma(a + b)
    log_beta_fn = torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
    kl = -log_beta_fn + (a - 1.0) * (torch.digamma(a) - psi_total) + (b - 1.0) * (torch.digamma(b) - psi_total)
    return kl.clamp(min=0)

def compute_evidence_penalty_loss(alpha, beta, correction, gamma):
    """Element-wise penalty ``gamma * correction^2 / (S + EPSILON)``, with ``S = alpha + beta`` (each offset by EPSILON).

    If the last dims of ``correction`` and ``S`` differ:
      * a width-1 correction is divided by ``S`` averaged over its last dim;
      * otherwise ``S`` is expanded to the correction's shape.
    """
    strength = (alpha + EPSILON) + (beta + EPSILON)
    squared_corr = correction ** 2
    if squared_corr.shape[-1] == strength.shape[-1]:
        denom = strength
    elif squared_corr.shape[-1] == 1:
        denom = strength.mean(dim=-1, keepdim=True)
    else:
        denom = strength.expand_as(squared_corr)
    return gamma * squared_corr / (denom + EPSILON)
# --- End Loss Helpers ---

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 1000,
    pretrain_epochs: int = 20,
    integration_epochs: int = 30,
    patience: int = 25,
) -> float:
    """Train an EvidenceMoeModel together with its evidential critics using three AdamW optimizers.

    Unless ``model.config.ablate_phased_training`` is set, training is phased by epoch index:

    * ``epoch < pretrain_epochs`` -> PRETRAIN: ``model.mode = "pretrain_experts"``; only the
      expert parameters are trainable and only the expert optimizer is stepped.
    * ``epoch < integration_epochs`` -> INTEGRATE: ``model.mode = "full"``; experts frozen,
      decider head and critics trained. ``integration_epochs`` is an absolute epoch index,
      not a duration.
    * otherwise -> JOINT: all parameters trainable; the expert LR is ramped linearly from
      0.1x to 1.0x of ``expert_lr`` over the first 10 JOINT epochs.

    Per-batch objective in "full" mode (weights read from ``model.config``)::

        lambda_primary * MAE(final_pred, y) + lambda_aux * aux MAE
        + lambda_critic_quality * (evidential + lambda_KL * KL)
        + lambda_corr * Huber(correction) + lambda_penalty * evidence penalty
        + lambda_diversity * diversity + rl_weight * rl_loss

    Args:
        model: the model, possibly wrapped in DistributedDataParallel (it is unwrapped here).
        train_loader: yields ``(batch_x, batch_y)`` training batches.
        val_loader: yields ``(batch_x, batch_y)`` validation batches; the metric is MAE.
        device: device batches are moved to; AMP and GradScaler are enabled only on CUDA.
        num_epochs: maximum number of epochs.
        pretrain_epochs: epoch index at which PRETRAIN ends.
        integration_epochs: epoch index at which INTEGRATE ends.
        patience: consecutive non-improving validation epochs (outside PRETRAIN) before
            early stopping.

    Returns:
        Best validation MAE observed (``inf`` if no non-pretrain epoch ran).

    Side effects:
        Writes TensorBoard logs under ``runs/``; the main process saves
        ``best_checkpoint_<suffix>.pth`` and ``final_<suffix>.pth``. On return the model holds
        the best recorded weights, if any.

    NOTE(review): the DDP wrapper is stripped and its forward is never called, so gradients
    are NOT all-reduced across ranks; each rank trains independently on its own shard, and
    rank 0's weights are what get saved. TODO confirm this is intended.
    """
    # Unwrap DDP so submodules, `mode` and `config` are accessed directly.
    if hasattr(model, "module"):
        model = model.module
    cfg = model.config

    # Only the main process writes checkpoints; also works without an initialised process group.
    is_main_process = (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0

    config_suffix = get_config_suffix(cfg)
    log_dir = f"runs/EDC_Critic_{config_suffix}_3opt_k8"
    writer = SummaryWriter(log_dir=log_dir)
    print(f"TensorBoard logs will be saved to: {log_dir}")

    # --- Parameter groups ---
    experts_params = (list(model.early_expert.parameters()) + list(model.late_expert.parameters())
                      + list(model.global_expert.parameters()))
    decider_params = list(model.final_head.parameters())
    critic_params = (list(model.critic_early.parameters()) + list(model.critic_late.parameters())
                     + list(model.critic_global.parameters()))

    # --- Optimizers: one per group so each phase can step a subset ---
    expert_lr_base = cfg.expert_lr
    optimizer_experts = AdamW(experts_params, lr=expert_lr_base, weight_decay=cfg.expert_weight_decay)
    optimizer_decider = AdamW(decider_params, lr=cfg.decider_lr, weight_decay=cfg.decider_weight_decay)
    optimizer_critics = AdamW(critic_params, lr=cfg.critic_lr, weight_decay=cfg.critic_weight_decay)
    all_optimizers = [optimizer_experts, optimizer_decider, optimizer_critics]

    # --- Loss functions ---
    mae_loss_fn = nn.L1Loss()
    huber_loss_fn = nn.SmoothL1Loss(reduction='none')

    # --- Mixed precision (disabled unless running on CUDA) ---
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

    # --- Training state ---
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    global_step = 0
    gradual_unfreeze_epochs = 10

    for epoch in range(num_epochs):
        if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)

        # --- Phase handling: trainable params, active optimizers, RL weight ---
        if cfg.ablate_phased_training:
            phase = "FULL JOINT"
            model.mode = "full"
            for p in model.parameters(): p.requires_grad_(True)
            for group in optimizer_experts.param_groups: group['lr'] = expert_lr_base
            active_optimizers = list(all_optimizers)
            current_rl_weight = cfg.rl_weight
            print(f"\n--- Epoch {epoch+1}/{num_epochs} :: Phase: {phase} (Ablated) ---")
        elif epoch < pretrain_epochs:
            phase = "PRETRAIN"
            model.mode = "pretrain_experts"
            for p in model.parameters(): p.requires_grad_(False)  # freeze everything ...
            for p in experts_params: p.requires_grad_(True)       # ... then unfreeze the experts
            active_optimizers = [optimizer_experts]
            current_rl_weight = 0.0
            print(f"\n--- Epoch {epoch+1}/{num_epochs} :: Phase: {phase} ---")
        elif epoch < integration_epochs:
            phase = "INTEGRATE"
            model.mode = "full"
            for p in model.parameters(): p.requires_grad_(False)
            for p in decider_params + critic_params: p.requires_grad_(True)
            active_optimizers = [optimizer_decider, optimizer_critics]
            current_rl_weight = cfg.rl_weight
            print(f"\n--- Epoch {epoch+1}/{num_epochs} :: Phase: {phase} ---")
        else:
            phase = "JOINT"
            model.mode = "full"
            for p in model.parameters(): p.requires_grad_(True)
            # Linear expert-LR warm-up over the first `gradual_unfreeze_epochs` JOINT epochs.
            if epoch < integration_epochs + gradual_unfreeze_epochs:
                factor = (epoch - integration_epochs + 1) / gradual_unfreeze_epochs
                for group in optimizer_experts.param_groups: group['lr'] = expert_lr_base * factor
                phase += f" (Unfrz {factor:.1f})"
            else:
                for group in optimizer_experts.param_groups: group['lr'] = expert_lr_base
            active_optimizers = list(all_optimizers)
            current_rl_weight = cfg.rl_weight
            print(f"\n--- Epoch {epoch+1}/{num_epochs} :: Phase: {phase} ---")

        # Parameters stepped this epoch. Clipping is restricted to them so gradients on
        # inactive/frozen parameters cannot distort the clipping norm.
        active_params = [p for opt in active_optimizers for group in opt.param_groups for p in group["params"]]

        # --- Training epoch ---
        model.train()
        total_samples = 0
        total_grad_norm = 0.0
        grad_steps = 0
        epoch_losses = {"primary": 0.0, "aux": 0.0, "quality": 0.0, "corr": 0.0,
                        "penalty": 0.0, "rl": 0.0, "div": 0.0, "total": 0.0}

        for batch_idx, (batch_x, batch_y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} [{phase}]")):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            bs = batch_x.size(0)

            # Clear gradients on every optimizer (not only the active ones) so gradients left
            # over from an earlier phase cannot leak into clipping or later steps.
            for opt in all_optimizers:
                opt.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
                outputs = model(batch_x)

                if model.mode == "pretrain_experts":
                    final_pred, expert_outputs, _ = outputs
                    loss_e = mae_loss_fn(expert_outputs[0].squeeze(-1), batch_y[:, 0])
                    loss_l = mae_loss_fn(expert_outputs[1].squeeze(-1), batch_y[:, 1])
                    loss_g = mae_loss_fn(expert_outputs[2], batch_y)
                    total_loss = (3.0 * loss_e + loss_l + loss_g) / 5.0  # early expert weighted 3x
                    batch_losses = {"primary": total_loss.item(), "total": total_loss.item()}
                else:
                    (final_pred, aux_preds, critic_outputs, expert_features,
                     rl_log_probs, rl_entropies, decider_feature, _) = outputs

                    # 1. Primary loss on the fused prediction
                    primary_loss = mae_loss_fn(final_pred, batch_y)

                    # 2. Auxiliary per-expert MAE
                    if cfg.ablate_auxiliary_mae:
                        aux_loss = torch.tensor(0.0, device=device)
                    else:
                        loss_e_aux = F.l1_loss(aux_preds[0], batch_y[:, 0:1])
                        loss_l_aux = F.l1_loss(aux_preds[1], batch_y[:, 1:2])
                        loss_g_aux = F.l1_loss(aux_preds[2], batch_y)
                        aux_loss = (loss_e_aux + loss_l_aux + loss_g_aux) / 3.0

                    # 3. Quality targets derived from the auxiliary errors (no gradient)
                    with torch.no_grad():
                        target_q_e, target_q_l, target_q_g_vec = compute_aux_error_quality_targets(
                            aux_preds, batch_y, quality_kappa=cfg.quality_kappa)
                    target_q_g_d = target_q_g_vec[:, 0:1]
                    target_q_g_l = target_q_g_vec[:, 1:2]

                    # Critics evaluated on detached expert features + aux predictions, so the
                    # quality/penalty losses only produce gradients for the critic parameters.
                    critic_input_e = torch.cat([expert_features[0].detach(), aux_preds[0].detach()], dim=1)
                    critic_input_l = torch.cat([expert_features[1].detach(), aux_preds[1].detach()], dim=1)
                    critic_input_g = torch.cat([expert_features[2].detach(), aux_preds[2].detach()], dim=1)
                    alpha_e, beta_e, corr_e = model.critic_early(critic_input_e)
                    alpha_l, beta_l, corr_l = model.critic_late(critic_input_l)
                    alpha_g, beta_g, corr_g = model.critic_global(critic_input_g)

                    # 4. Critic quality loss (evidential + KL); global critic has two columns
                    alpha_g_d, beta_g_d = alpha_g[:, 0:1], beta_g[:, 0:1]
                    alpha_g_l, beta_g_l = alpha_g[:, 1:2], beta_g[:, 1:2]
                    L_evi = (compute_evidential_loss(alpha_e, beta_e, target_q_e)
                             + compute_evidential_loss(alpha_l, beta_l, target_q_l)
                             + compute_evidential_loss(alpha_g_d, beta_g_d, target_q_g_d)
                             + compute_evidential_loss(alpha_g_l, beta_g_l, target_q_g_l))
                    L_KL = (compute_kl_divergence_loss(alpha_e, beta_e)
                            + compute_kl_divergence_loss(alpha_l, beta_l)
                            + compute_kl_divergence_loss(alpha_g_d, beta_g_d)
                            + compute_kl_divergence_loss(alpha_g_l, beta_g_l))
                    critic_quality_loss = torch.mean(L_evi + cfg.lambda_KL * L_KL)

                    # 5./6. Correction (Huber, on the forward-pass corrections) and evidence penalty
                    corr_e_fwd, corr_l_fwd, corr_g_fwd = (critic_outputs["corr_early"], critic_outputs["corr_late"],
                                                          critic_outputs["corr_global"])
                    if cfg.ablate_correction:
                        correction_loss = torch.tensor(0.0, device=device)
                        evidence_penalty_loss = torch.tensor(0.0, device=device)
                    else:
                        damping = cfg.damping_factor
                        corr_loss_e = huber_loss_fn(aux_preds[0] + damping * corr_e_fwd, batch_y[:, 0:1])
                        corr_loss_l = huber_loss_fn(aux_preds[1] + damping * corr_l_fwd, batch_y[:, 1:2])
                        corr_loss_g = huber_loss_fn(aux_preds[2] + damping * corr_g_fwd, batch_y)
                        correction_loss = (torch.mean(corr_loss_e) + torch.mean(corr_loss_l)
                                           + torch.mean(corr_loss_g)) / 3.0

                        gamma = cfg.gamma_penalty
                        L_pen_e = compute_evidence_penalty_loss(alpha_e, beta_e, corr_e, gamma)
                        L_pen_l = compute_evidence_penalty_loss(alpha_l, beta_l, corr_l, gamma)
                        L_pen_g_d = compute_evidence_penalty_loss(alpha_g_d, beta_g_d, corr_g[:, 0:1], gamma)
                        L_pen_g_l = compute_evidence_penalty_loss(alpha_g_l, beta_g_l, corr_g[:, 1:2], gamma)
                        evidence_penalty_loss = torch.mean(L_pen_e + L_pen_l + L_pen_g_d + L_pen_g_l)

                    # 7. Diversity between auxiliary predictions (unweighted; weighted once below)
                    div_loss = diversity_loss(aux_preds, margin=0.3)

                    # NOTE(review): rl_log_probs / rl_entropies are returned by the model but no RL
                    # objective is computed from them here; rl_loss is a zero placeholder so the
                    # weighted sum and logging stay well-defined. TODO implement or remove.
                    rl_loss = torch.tensor(0.0, device=device)

                    total_loss = (cfg.lambda_primary * primary_loss
                                  + cfg.lambda_aux * aux_loss
                                  + cfg.lambda_critic_quality * critic_quality_loss
                                  + cfg.lambda_corr * correction_loss
                                  + cfg.lambda_penalty * evidence_penalty_loss
                                  + cfg.lambda_diversity * div_loss
                                  + current_rl_weight * rl_loss)

                    batch_losses = {
                        "primary": primary_loss.item(), "aux": aux_loss.item(),
                        "quality": critic_quality_loss.item(), "corr": correction_loss.item(),
                        "penalty": evidence_penalty_loss.item(), "rl": rl_loss.item(),
                        "div": div_loss.item(), "total": total_loss.item(),
                    }

            # Skip non-finite batches BEFORE accumulating, so they don't poison epoch averages.
            loss_val = total_loss.item()
            if not np.isfinite(loss_val):
                print(f"[ERROR] Loss is NaN/Inf at E{epoch+1} B{batch_idx+1}. Skipping.")
                continue
            if loss_val > 1e5:
                print(f"[WARN] High loss: {loss_val:.4f} at E{epoch+1} B{batch_idx+1}")

            # --- Backprop & step (GradScaler pattern for several optimizers) ---
            scaler.scale(total_loss).backward()
            for opt in active_optimizers:
                scaler.unscale_(opt)  # clipping must operate on unscaled gradients
            grad_norm = torch.nn.utils.clip_grad_norm_(active_params, max_norm=1.0)
            for opt in active_optimizers:
                scaler.step(opt)
            scaler.update()  # once per iteration, after all optimizer steps

            # --- Accumulate batch statistics ---
            global_step += 1
            total_samples += bs
            grad_norm_value = float(grad_norm)
            total_grad_norm += grad_norm_value
            grad_steps += 1
            for key, value in batch_losses.items():
                epoch_losses[key] += value * bs

            # --- Periodic batch logging ---
            if (batch_idx + 1) % 100 == 0:
                log_prefix = f"E{epoch+1} B{batch_idx+1}"
                if model.mode == "pretrain_experts":
                    print(f"{log_prefix}: Pretrain Loss={total_loss.item():.4f}")
                else:
                    print(f"{log_prefix}: Loss={total_loss.item():.4f} [Pri={primary_loss.item():.4f} Aux={aux_loss.item():.4f} Q={critic_quality_loss.item():.4f} C={correction_loss.item():.4f} P={evidence_penalty_loss.item():.4f} RL={rl_loss.item():.4f} Div={div_loss.item():.4f}] Grad={grad_norm_value:.2f}")
                    with torch.no_grad():
                        quality_scores_mean_print = critic_outputs["quality_scores_mean"]  # (B, 3) per original comment
                        q_stats = {"min": quality_scores_mean_print.min().item(), "max": quality_scores_mean_print.max().item(),
                                   "mean": quality_scores_mean_print.mean().item(), "std": quality_scores_mean_print.std().item()}
                        print(f"    Quality (Mean) Stats: {q_stats}")
                        for i, cs in enumerate(c.detach().cpu() for c in (corr_e, corr_l, corr_g)):
                            print(f"    Correction Signal [{i}] Stats: min={cs.min().item():.3f} max={cs.max().item():.3f} mean={cs.mean().item():.3f} std={cs.std().item():.3f}")
                        decider_stats = {"min": decider_feature.min().item(), "max": decider_feature.max().item(),
                                         "mean": decider_feature.mean().item(), "std": decider_feature.std().item()}
                        print(f"    Decider Feature Stats: {decider_stats}")
                sys.stdout.flush()

        # --- End of epoch: averages ---
        avg_grad_norm = total_grad_norm / grad_steps if grad_steps > 0 else 0.0
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / total_samples if total_samples > 0 else 0.0

        # TensorBoard epoch summaries
        writer.add_scalar("Loss/Train_Total", epoch_losses["total"], epoch)
        writer.add_scalar("Grad_Norm/Avg", avg_grad_norm, epoch)
        for i, opt in enumerate(all_optimizers):
            if opt in active_optimizers:  # log LR only for optimizers used this epoch
                writer.add_scalar(f"Learning_Rate/Group_{i}", opt.param_groups[0]['lr'], epoch)

        if model.mode != "pretrain_experts":
            writer.add_scalar("Loss/Primary", epoch_losses["primary"], epoch)
            writer.add_scalar("Loss/Auxiliary", epoch_losses["aux"], epoch)
            writer.add_scalar("Loss/Critic_Quality", epoch_losses["quality"], epoch)
            writer.add_scalar("Loss/Critic_Correction", epoch_losses["corr"], epoch)
            writer.add_scalar("Loss/Critic_Penalty", epoch_losses["penalty"], epoch)
            writer.add_scalar("Loss/RL", epoch_losses["rl"], epoch)
            writer.add_scalar("Loss/Diversity", epoch_losses["div"], epoch)
            writer.add_histogram("GateWeights", model.final_head.last_gate_weights.cpu(), epoch)
            # Alpha/beta/correction histograms from the last training batch, if one ran in full mode
            if 'critic_outputs' in locals() and critic_outputs:
                writer.add_histogram("Quality/Alpha_Early", critic_outputs["alpha_early"].detach().cpu(), epoch)
                writer.add_histogram("Quality/Beta_Early", critic_outputs["beta_early"].detach().cpu(), epoch)
                writer.add_histogram("Correction/Early", critic_outputs["corr_early"].detach().cpu(), epoch)

        print(f"\nEpoch {epoch+1}/{num_epochs} -> Avg Train Loss={epoch_losses['total']:.6f}, Avg Grad Norm: {avg_grad_norm:.4f}")
        if model.mode != "pretrain_experts":
            print(f"  Avg Losses: Pri={epoch_losses['primary']:.4f} Aux={epoch_losses['aux']:.4f} Q={epoch_losses['quality']:.4f} C={epoch_losses['corr']:.4f} P={epoch_losses['penalty']:.4f} RL={epoch_losses['rl']:.4f} Div={epoch_losses['div']:.4f}")

        # --- Validation epoch (primary MAE only) ---
        model.eval()
        val_loss_sum = 0.0
        num_val_samples = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
                    if model.mode == "pretrain_experts":
                        final_pred, _, _ = model(batch_x)
                    else:
                        final_pred, _, _, _, _, _, _, _ = model(batch_x)
                    loss = mae_loss_fn(final_pred, batch_y)
                bs = batch_x.size(0)
                val_loss_sum += loss.item() * bs
                num_val_samples += bs
            epoch_val_loss = val_loss_sum / num_val_samples if num_val_samples > 0 else float('inf')
            writer.add_scalar("Loss/Validation", epoch_val_loss, epoch)
            print(f"Epoch {epoch+1}/{num_epochs} -> VAL_LOSS={epoch_val_loss:.6f}")  # primary MAE

            # --- Sample predictions from the first validation batch ---
            try:
                val_batch_x_sample, val_batch_y_sample = next(iter(val_loader))
                val_batch_x_sample = val_batch_x_sample.to(device)
                val_batch_y_sample = val_batch_y_sample.to(device)
                with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
                    if model.mode == "pretrain_experts":
                        final_pred_val, _, _ = model(val_batch_x_sample)
                        critic_outputs_val = None
                    else:
                        final_pred_val, _, critic_outputs_val, _, _, _, _, _ = model(val_batch_x_sample)

                val_batch_y_cpu = val_batch_y_sample.cpu()
                final_pred_cpu = final_pred_val.cpu()
                batch_size_val = val_batch_x_sample.size(0)
                num_examples = min(5, batch_size_val)
                indices = np.random.choice(batch_size_val, num_examples, replace=False)

                quality_vals_print_map = {}
                if critic_outputs_val and "quality_scores_mean" in critic_outputs_val:
                    q_mean_cpu = critic_outputs_val["quality_scores_mean"].cpu()
                    for sample_idx in indices:
                        if sample_idx < q_mean_cpu.shape[0]:
                            q_vals = q_mean_cpu[sample_idx].tolist()
                            quality_vals_print_map[sample_idx] = f"[{q_vals[0]:.2f},{q_vals[1]:.2f},{q_vals[2]:.2f}]"
                        else:
                            quality_vals_print_map[sample_idx] = "N/A (Index out of bounds)"
                else:
                    for idx in indices:
                        quality_vals_print_map[idx] = "N/A"

                print("Random validation sample predictions:")
                for idx in indices:
                    gt_d = val_batch_y_cpu[idx, 0].item(); pred_d = final_pred_cpu[idx, 0].item()
                    gt_l = val_batch_y_cpu[idx, 1].item(); pred_l = final_pred_cpu[idx, 1].item()
                    log_str = f"  Ex {idx}: GT D:{gt_d:.3f} P:{pred_d:.3f} | GT L:{gt_l:.3f} P:{pred_l:.3f}"
                    log_str += f" | Q(E,L,G_mean):{quality_vals_print_map.get(idx, 'Error')}"
                    print(log_str)
            except StopIteration:
                print("Validation loader empty, cannot show sample predictions.")
            sys.stdout.flush()

        # --- Checkpointing & early stopping (not during PRETRAIN) ---
        if model.mode != "pretrain_experts":
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                # state_dict() returns references to the live tensors; clone so later
                # optimizer steps do not overwrite the "best" snapshot.
                best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                patience_counter = 0  # reset on every rank, not only the main one
                best_ckpt = f"best_checkpoint_{config_suffix}.pth"
                if is_main_process:
                    torch.save(best_model_state, best_ckpt)
                    print(f"    Best model saved (Val Loss: {best_val_loss:.6f}) to {best_ckpt}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs.")
                    break

    # --- End of training ---
    final_ckpt = f"final_{config_suffix}.pth"
    if best_model_state is not None:
        print(f"Loading best state (Val Loss: {best_val_loss:.6f})")
        model.load_state_dict(best_model_state)
    else:
        print("No best state found or only pretraining done, saving last state.")

    if is_main_process:
        torch.save(model.state_dict(), final_ckpt)
        print(f"Final model saved to {final_ckpt}")
        writer.add_text("Training Summary", f"Complete. Best val loss: {best_val_loss:.6f}. Saved: {final_ckpt}", 0)
        print(f"Training complete. Best val loss: {best_val_loss:.6f}")
    writer.close()  # every rank opened a writer, so every rank closes it
    return best_val_loss

In [None]:
######################################
# Helper: Suffix Generation 
######################################
def get_config_suffix(config: EvidenceMoeConfig):
    """Build a filename suffix from the ablation flags that are enabled on ``config``.

    Tags of enabled ablations are joined (in the fixed order below) with ``"_k2_"``;
    with no ablation enabled the suffix is ``"model_k2_p10_20"``.
    """
    ablation_tags = (
        ("ablate_quality_weighting", "noQualW"),
        ("ablate_correction", "noCorr"),
        ("ablate_quality_in_gating", "noQualGate"),
        ("ablate_decider_feature", "noDecFeat"),
        ("ablate_decider_feature_fusion", "noDecFus"),
        ("ablate_uniform_gating", "uniGate"),
        ("ablate_gating_dropout", "noGateDrop"),
        ("ablate_phased_training", "noPhase"),
        ("ablate_mean_pooling", "meanPool"),
        ("ablate_auxiliary_mae", "noAuxMAE"),
    )
    enabled = [tag for flag_name, tag in ablation_tags if getattr(config, flag_name)]
    if not enabled:
        return "model_k2_p10_20"
    return "_k2_".join(enabled)

In [None]:
######################################
# 14. Main Execution Block
######################################
def main_worker(rank: int, world_size: int, args):
    """Per-process DDP entry point: index ``.mat`` files, build loaders and model, then train.

    Args:
        rank: process rank; also used as the CUDA device index.
        world_size: total number of processes in the group.
        args: namespace providing ``data_path``, ``global_batch_size``, ``global_seed``
            and ``num_workers``.
    """
    # ——— DDP setup ———
    torch.cuda.set_device(rank)
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # 3) index files: the lifetime string is taken from the first "lt_" segment found in
    #    the parent or grandparent directory name; files with neither are ignored.
    main_folder = args.data_path
    if not os.path.isdir(main_folder):
        sys.exit(f"ERROR: Data folder not found: {main_folder}")
    file_entries = []
    print(f"[Rank {rank}] indexing `{main_folder}` …")
    for root, dirs, files in os.walk(main_folder):
        for f in files:
            if not f.endswith(".mat"):
                continue
            full_path = os.path.join(root, f)
            parent_path = os.path.dirname(full_path)
            parent = os.path.basename(parent_path)
            grandparent = os.path.basename(os.path.dirname(parent_path))
            if "lt_" in parent:
                raw_lt_str = parent.split("lt_")[1]
            elif "lt_" in grandparent:
                raw_lt_str = grandparent.split("lt_")[1]
            else:
                continue
            # Only a malformed lifetime string is an expected, skippable error; anything else
            # (e.g. a NameError) is a programming bug and must not be silently swallowed.
            try:
                raw_lt = float(raw_lt_str)
            except ValueError:
                print(f"[Rank {rank}] skipping {full_path}: cannot parse lifetime from '{raw_lt_str}'")
                continue
            # TODO(review): `depth_value` and `lifetime_value` are not defined in this function
            # (`raw_lt` is parsed but unused). Presumably they should be derived from the path
            # (lifetime from `raw_lt`?) — confirm against the dataset class that consumes them.
            file_entries.append((full_path, depth_value, lifetime_value))

    if not file_entries:
        sys.exit("No data files found. Exiting.")
    # Sort by path so the ordering does not depend on os.walk; train_test_split then performs
    # the shuffle with a fixed seed, giving a reproducible split that is identical on every rank.
    file_entries.sort(key=lambda entry: entry[0])
    train_entries, val_entries = train_test_split(
        file_entries,
        test_size=0.2,
        random_state=args.global_seed,
    )
    if rank == 0:
        print(f"  Total files: {len(file_entries)}, train: {len(train_entries)}, val: {len(val_entries)}")

    # 4) shard file lists round-robin by process rank, then make loaders
    train_entries_rank = train_entries[rank::world_size]
    val_entries_rank = val_entries[rank::world_size]

    # TODO(review): `train_dataset` / `val_dataset` are not constructed in this function;
    # presumably they should wrap `train_entries_rank` / `val_entries_rank` in the project's
    # dataset class — confirm.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.global_batch_size // world_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.global_batch_size // world_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # 5) build model + wrap in DDP
    device = torch.device("cuda", rank)
    config = EvidenceMoeConfig(
        hidden_size=224,
        intermediate_size=384,
        num_hidden_layers=2,
        num_attention_heads=16,
        expert_lr=5e-5,
        decider_lr=1e-4,
        critic_lr=1e-4,
        critic_weight_decay=1e-4,
        lambda_KL=1e-3,
        gamma_penalty=1e-3,
        quality_kappa=2.0,
        lambda_primary=1.0,
        lambda_aux=1.0,
        lambda_critic_quality=1.0,
        lambda_corr=1.0,
        lambda_penalty=1.0,
        lambda_diversity=0.0,
        damping_factor=0.1,
        # Select ablations for this run
        ablate_quality_weighting=False, ablate_correction=False, ablate_quality_in_gating=False,
        ablate_decider_feature=False, ablate_decider_feature_fusion=False, ablate_uniform_gating=False,
        ablate_gating_dropout=False, ablate_phased_training=False, ablate_mean_pooling=False,
        ablate_auxiliary_mae=False,
    )
    start_mode = "full" if config.ablate_phased_training else "pretrain_experts"
    model = EvidenceMoeModel(config, mode=start_mode).to(device)
    model = DDP(model, device_ids=[rank])
    model.config = config

    # 6) train
    best_val = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=500,
        pretrain_epochs=10,
        integration_epochs=20,
    )

    if rank == 0:
        print(f"\n Training complete. Best Val MAE: {best_val:.6f}")

    # 7) cleanup
    dist.barrier()
    dist.destroy_process_group()
def launch_ddp(world_size: int, args):
    """Start one ``main_worker`` process per rank via ``mp.Process`` and wait for all of them.

    Args:
        world_size: number of processes (ranks) to launch.
        args: namespace forwarded unchanged to every ``main_worker``.

    Raises:
        RuntimeError: if any rank exits with a non-zero exit code (previously such failures
            were silently ignored).
    """
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=main_worker, args=(rank, world_size, args))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    failed = [(rank, p.exitcode) for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"DDP worker(s) failed as (rank, exitcode): {failed}")
def pick_free_port():
    """Find an unused TCP port on localhost by binding to port 0 and reading the OS choice.

    The socket is closed (even if ``bind`` fails) before returning, so another process could
    in principle grab the port before it is used; callers should use it promptly.

    Returns:
        int: the port number chosen by the OS.
    """
    import socket
    with socket.socket() as sock:
        sock.bind(("127.0.0.1", 0))
        _, port = sock.getsockname()
    return port

if __name__ == "__main__":
    # Arguments are built by hand (no command-line parsing).
    from types import SimpleNamespace

    args = SimpleNamespace(
        data_path="",  # TODO: set to the root folder containing the .mat files
        global_batch_size=512,
        global_seed=42,
        num_workers=8,
    )

    # Fail fast in the parent instead of letting every rank join NCCL and then exit.
    if not os.path.isdir(args.data_path):
        sys.exit(f"ERROR: Data folder not found: {args.data_path!r}. Set args.data_path before launching.")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    ngpus = torch.cuda.device_count()
    if ngpus < 2:
        sys.exit(f"Need multiple GPUs for DDP (found {ngpus}).")

    # Rendezvous address/port read by init_process_group's default env:// initialisation.
    port = pick_free_port()
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    launch_ddp(ngpus, args)

