In [None]:
# Setup and Imports
#!/usr/bin/env python3
import ast
import json
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from sklearn.datasets import make_multilabel_classification
from sklearn.preprocessing import StandardScaler

# Import Phase 1 (basic multi-label)
from core.multilabel_net import MultiLabelMEDAF
from core.multilabel_train import train_multilabel

# Import Phase 2 (per-class gating)
from core.multilabel_net_v2 import MultiLabelMEDAFv2
from core.multilabel_train_v2 import train_multilabel_v2, ComparativeTrainingFramework

# Import test utilities (synthetic dataset fallback)
from test_multilabel_medaf import SyntheticMultiLabelDataset

torch.__version__, torch.cuda.is_available()


In [None]:
# Constants and Paths
# Label vocabulary for the known-label split. ChestXrayKnownDataset maps each
# name to its position in this list, so the order here fixes the column order
# of the multi-hot target vectors.
KNOWN_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
]

# Default locations (relative paths). The demo config can override them via the
# "image_root", "known_csv" and "checkpoint_dir" keys respectively.
DEFAULT_IMAGE_ROOT = Path("datasets/data/NIH/images-224")
DEFAULT_KNOWN_CSV = Path("datasets/data/NIH/chestxray_train_known.csv")
DEFAULT_CHECKPOINT_DIR = Path("checkpoints/medaf_phase1")


In [None]:
# Dataset: ChestXrayKnownDataset
class ChestXrayKnownDataset(data.Dataset):
    """Dataset that reads the known-label ChestX-ray14 split for Phase 1 training.

    Each item is an ``(image, labels)`` pair: the transformed image loaded from
    ``image_root / record["Image Index"]`` and a float32 multi-hot vector with
    one slot per entry of ``KNOWN_LABELS``.

    Args:
        csv_path: CSV file that must contain a ``known_labels`` column. Rows are
            also expected to carry an ``Image Index`` column (read lazily in
            ``__getitem__``).
        image_root: Directory containing the image files named in the CSV.
        img_size: Side length used by the default resize transform.
        max_samples: If given and smaller than the number of CSV rows, a
            reproducible random subset (``random_state=42``) of that size is used.
        transform: Optional callable applied to the RGB PIL image. When omitted,
            a resize + ``ToTensor`` + fixed per-channel normalisation pipeline
            is used.

    Raises:
        FileNotFoundError: If the CSV or the image directory does not exist.
        ValueError: If the CSV has no ``known_labels`` column.
    """

    def __init__(
        self,
        csv_path: Path,
        image_root: Path,
        img_size: int = 224,
        max_samples: Optional[int] = None,
        transform=None,
    ) -> None:
        self.csv_path = Path(csv_path)
        self.image_root = Path(image_root)
        self.img_size = img_size
        self.class_names = KNOWN_LABELS
        self.num_classes = len(self.class_names)

        # Fail early with a clear message instead of on the first __getitem__.
        if not self.csv_path.exists():
            raise FileNotFoundError(f"ChestX-ray CSV not found: {self.csv_path}")
        if not self.image_root.exists():
            raise FileNotFoundError(
                f"ChestX-ray image directory not found: {self.image_root}"
            )

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    ),
                ]
            )
        else:
            self.transform = transform

        df = pd.read_csv(self.csv_path)
        if "known_labels" not in df.columns:
            raise ValueError(
                "Expected 'known_labels' column in CSV. Run utils/create_chestxray_splits.py first."
            )

        # Fixed random_state keeps the sub-sample identical across runs.
        if max_samples is not None and max_samples < len(df):
            df = df.sample(n=max_samples, random_state=42).reset_index(drop=True)
        else:
            df = df.reset_index(drop=True)

        self.records = df.to_dict("records")
        self.label_to_idx = {label: idx for idx, label in enumerate(self.class_names)}

    @staticmethod
    def _parse_label_list(raw_value):
        """Normalise a raw ``known_labels`` cell into a list of label names.

        Accepted inputs:
            * a ``list`` (returned as-is);
            * a ``tuple``, ``set``, ``frozenset`` or NumPy array (converted to a list);
            * a missing value (``None``/NaN) or blank string -> ``[]``;
            * a string holding a Python literal list/tuple/str, e.g. ``"['Mass']"``;
            * a ``|``-separated string, e.g. ``"Mass|Nodule"``.

        Any other value yields an empty list.
        """
        if isinstance(raw_value, list):
            return raw_value
        # Handle other collections before pd.isna(): on an ndarray pd.isna()
        # returns an array whose truth value is ambiguous, and tuples/sets
        # would otherwise fall through to the final `return []`.
        if isinstance(raw_value, np.ndarray):
            return raw_value.ravel().tolist()
        if isinstance(raw_value, (tuple, set, frozenset)):
            return list(raw_value)
        if pd.isna(raw_value):
            return []
        if isinstance(raw_value, str):
            raw_value = raw_value.strip()
            if not raw_value:
                return []
            try:
                parsed = ast.literal_eval(raw_value)
                if isinstance(parsed, (list, tuple)):
                    return list(parsed)
                if isinstance(parsed, str):
                    return [parsed]
            except (ValueError, SyntaxError):
                # Not a Python literal; fall back to the pipe-separated format.
                pass
            return [item.strip() for item in raw_value.split("|") if item.strip()]
        return []

    def __len__(self) -> int:
        """Return the number of CSV rows kept after optional sub-sampling."""
        return len(self.records)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, labels)`` where ``image`` is the transformed RGB image and
            ``labels`` is a float32 tensor of length ``num_classes`` with 1.0 at
            every known label of the record. Labels that are not in
            ``class_names`` are ignored.

        Raises:
            FileNotFoundError: If the image file for the record is missing.
        """
        record = self.records[idx]
        image_path = self.image_root / record["Image Index"]
        if not image_path.exists():
            raise FileNotFoundError(f"Missing image: {image_path}")

        # Close the underlying file deterministically; convert() returns a new,
        # fully loaded image that stays valid after the handle is closed.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")
        image = self.transform(image)

        labels = torch.zeros(self.num_classes, dtype=torch.float32)
        for label in self._parse_label_list(record.get("known_labels", [])):
            if label in self.label_to_idx:
                labels[self.label_to_idx[label]] = 1.0

        return image, labels


In [None]:
# Demo Class: MultiLabelMEDAFDemo
class MultiLabelMEDAFDemo:
    """
    Comprehensive demo for Multi-Label MEDAF
    """

    def __init__(self, config):
        self.config = config
        self.config.setdefault("data_source", "chestxray")
        self.config.setdefault("img_size", 224)
        self.config.setdefault("num_classes", len(KNOWN_LABELS))
        self.config.setdefault("val_ratio", 0.1)
        self.config.setdefault("num_workers", 0)
        self.config.setdefault("checkpoint_dir", str(DEFAULT_CHECKPOINT_DIR))
        self.config.setdefault("phase1_checkpoint", "medaf_phase1_chestxray.pt")
        self.config.setdefault("num_samples", 1000)
        self.config.setdefault("avg_labels_per_sample", 3)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.results = {}
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.dataset_name = None
        self.class_names = KNOWN_LABELS

    def create_dataset(self):
        """Create train/validation loaders for Phase 1."""

        data_source = self.config.get("data_source", "chestxray").lower()

        if data_source == "chestxray":
            csv_path = Path(self.config.get("known_csv", DEFAULT_KNOWN_CSV))
            image_root = Path(self.config.get("image_root", DEFAULT_IMAGE_ROOT))
            max_samples = self.config.get("max_samples")
            if isinstance(max_samples, str):
                max_samples = int(max_samples)
            print(f"Loading ChestX-ray14 known-label split from {csv_path}")
            dataset = ChestXrayKnownDataset(
                csv_path=csv_path,
                image_root=image_root,
                img_size=self.config.get("img_size", 224),
                max_samples=max_samples,
            )
            self.dataset_name = "ChestX-ray14 (known labels)"
            self.class_names = dataset.class_names
            self.config["num_classes"] = dataset.num_classes
            self.config["img_size"] = dataset.img_size
        else:
            print(
                "ChestX-ray data not requested or unavailable. Falling back to synthetic dataset."
            )
            dataset = SyntheticMultiLabelDataset(
                num_samples=self.config.get("num_samples", 1000),
                img_size=self.config.get("img_size", 32),
                num_classes=self.config.get("num_classes", 8),
                avg_labels_per_sample=self.config.get("avg_labels_per_sample", 3),
                random_state=42,
            )
            self.dataset_name = "Synthetic"
            self.class_names = [f"class_{i}" for i in range(dataset.num_classes)]

        val_ratio = float(self.config.get("val_ratio", 0.1))
        val_size = max(1, int(len(dataset) * val_ratio))
        train_size = len(dataset) - val_size

        train_dataset, val_dataset = data.random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )

        batch_size = self.config.get("batch_size", 16)
        num_workers = self.config.get("num_workers", 4)
        pin_memory = torch.cuda.is_available()

        self.train_loader = data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        self.val_loader = data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

        self.test_loader = self.val_loader

        print(
            f"Dataset prepared: {train_size} train / {val_size} val samples ({self.dataset_name})"
        )

    def demo_phase1(self):
        """Demonstrate Phase 1: Basic Multi-Label MEDAF"""
        print("\n" + "=" * 60)
        print("PHASE 1 DEMO: Basic Multi-Label MEDAF")
        print("=" * 60)

        # Configuration for Phase 1
        args = {
            "img_size": self.config["img_size"],
            "backbone": "resnet18",
            "num_classes": self.config["num_classes"],
            "gate_temp": 100,
            "loss_keys": ["b1", "b2", "b3", "gate", "divAttn", "total"],
            "acc_keys": ["acc1", "acc2", "acc3", "accGate"],
            "loss_wgts": [0.7, 1.0, 0.01],
        }

        # Create Phase 1 model
        model = MultiLabelMEDAF(args)
        model.to(self.device)

        print(
            f"Phase 1 Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
        )

        # Training setup
        criterion = {"bce": nn.BCEWithLogitsLoss()}
        optimizer = torch.optim.Adam(
            model.parameters(), lr=self.config["learning_rate"]
        )

        # Training
        phase1_metrics = []
        for epoch in range(self.config["num_epochs"]):
            metrics = train_multilabel(
                self.train_loader, model, criterion, optimizer, args, self.device
            )
            phase1_metrics.append(metrics)

            if epoch % 2 == 0:
                print(f"Epoch {epoch}: Loss={metrics:.4f}")

        final_loss = phase1_metrics[-1] if phase1_metrics else float("nan")
        checkpoint_path = self.save_model(model, args, phase1_metrics)

        self.results["phase1"] = {
            "model": model,
            "final_loss": final_loss,
            "metrics_history": phase1_metrics,
            "checkpoint": str(checkpoint_path),
        }

        if phase1_metrics:
            print(f"Phase 1 Final Loss: {final_loss:.4f}")
        else:
            print("Phase 1 completed with zero epochs (no training performed)")
        print(f"Phase 1 checkpoint saved to: {checkpoint_path}")
        print(
            "Use load_phase1_checkpoint(CheckpointPath) to reload this model for evaluation."
        )

    def save_model(self, model, args, loss_history):
        ckpt_dir = Path(self.config["checkpoint_dir"])
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = ckpt_dir / self.config["phase1_checkpoint"]

        payload = {
            "state_dict": model.state_dict(),
            "args": args,
            "class_names": self.class_names,
            "dataset": self.dataset_name,
        }
        torch.save(payload, checkpoint_path)

        metadata = {
            "dataset": self.dataset_name,
            "class_names": self.class_names,
            "num_epochs": self.config["num_epochs"],
            "batch_size": self.config.get("batch_size"),
            "learning_rate": self.config.get("learning_rate"),
            "loss_history": [float(loss) for loss in loss_history],
            "device": str(self.device),
            "checkpoint": str(checkpoint_path),
            "config": {
                k: v
                for k, v in self.config.items()
                if isinstance(v, (int, float, str, bool))
            },
        }
        metadata_path = checkpoint_path.with_suffix(".json")
        with metadata_path.open("w", encoding="utf-8") as fp:
            json.dump(metadata, fp, indent=2)

        return checkpoint_path

    def demo_phase2_comparative(self):
        """Demonstrate Phase 2: Comparative Analysis"""
        print("\n" + "=" * 60)
        print("PHASE 2 DEMO: Per-Class Gating Comparative Analysis")
        print("=" * 60)

        # Base configuration
        base_args = {
            "img_size": self.config["img_size"],
            "backbone": "resnet18",
            "num_classes": self.config["num_classes"],
            "gate_temp": 100,
            "loss_keys": ["b1", "b2", "b3", "gate", "divAttn", "total"],
            "acc_keys": ["acc1", "acc2", "acc3", "accGate"],
            "loss_wgts": [0.7, 1.0, 0.01],
            "enhanced_diversity": True,
            "diversity_type": "cosine",
        }

        # Configurations to compare
        configurations = {
            "global_gating": {
                "name": "Global Gating",
                "use_per_class_gating": False,
                "use_label_correlation": False,
            },
            "per_class_gating": {
                "name": "Per-Class Gating",
                "use_per_class_gating": True,
                "use_label_correlation": False,
            },
            "enhanced_per_class": {
                "name": "Enhanced Per-Class",
                "use_per_class_gating": True,
                "use_label_correlation": True,
                "gating_regularization": 0.01,
            },
        }

        # Create comparative framework
        framework = ComparativeTrainingFramework(base_args)
        criterion = {"bce": nn.BCEWithLogitsLoss()}

        phase2_results = {}

        for config_key, config_opts in configurations.items():
            print(f"\n--- Training {config_opts['name']} ---")

            # Merge configuration
            args = base_args.copy()
            args.update({k: v for k, v in config_opts.items() if k != "name"})

            # Create model
            model = MultiLabelMEDAFv2(args)
            model.to(self.device)

            # Print model info
            summary = model.get_gating_summary()
            param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

            print(f"Model: {summary['gating_type']} gating")
            print(f"Parameters: {param_count:,}")
            print(f"Label correlation: {summary['use_label_correlation']}")

            # Training
            optimizer = torch.optim.Adam(
                model.parameters(), lr=self.config["learning_rate"]
            )

            metrics_history = []
            best_acc = 0

            for epoch in range(self.config["num_epochs"]):
                framework.current_epoch = epoch

                metrics = train_multilabel_v2(
                    self.train_loader,
                    model,
                    criterion,
                    optimizer,
                    args,
                    self.device,
                    framework,
                )

                metrics_history.append(metrics)

                if metrics["subset_acc"] > best_acc:
                    best_acc = metrics["subset_acc"]

                if epoch % 2 == 0:
                    print(
                        f"Epoch {epoch}: Loss={metrics['total_loss']:.4f}, Acc={metrics['subset_acc']:.2f}%"
                    )

            phase2_results[config_key] = {
                "model": model,
                "config": config_opts,
                "best_accuracy": best_acc,
                "final_metrics": metrics_history[-1],
                "metrics_history": metrics_history,
            }

        # Print comparative analysis
        framework.print_comparison()

        self.results["phase2"] = phase2_results

    def analyze_attention_patterns(self):
        """Analyze attention patterns between global and per-class gating"""
        print("\n" + "=" * 60)
        print("ATTENTION PATTERN ANALYSIS")
        print("=" * 60)

        if "phase2" not in self.results:
            print("Phase 2 results not available for analysis")
            return

        # Get sample batch
        sample_batch = next(iter(self.test_loader))
        inputs, targets = sample_batch[0][:4].to(self.device), sample_batch[1][:4].to(
            self.device
        )

        print(f"Analyzing batch with shape: {inputs.shape}")
        print(f"Target labels:\n{targets}")

        # Analyze global vs per-class gating
        global_model = self.results["phase2"]["global_gating"]["model"]
        per_class_model = self.results["phase2"]["per_class_gating"]["model"]

        global_model.eval()
        per_class_model.eval()

        with torch.no_grad():
            # Global gating analysis
            global_outputs = global_model(
                inputs, targets, return_attention_weights=True
            )
            global_gate_pred = global_outputs["gate_pred"]

            print(f"\nGlobal Gating Weights (averaged across samples):")
            print(f"Expert preferences: {global_gate_pred.mean(dim=0)}")

            # Per-class gating analysis
            pc_outputs = per_class_model(inputs, targets, return_attention_weights=True)
            if "per_class_weights" in pc_outputs:
                pc_weights = pc_outputs["per_class_weights"]

                print(f"\nPer-Class Gating Weights:")
                print(f"Shape: {pc_weights.shape}")

                # Average expert preferences per class
                avg_class_prefs = pc_weights.mean(dim=0)
                print(f"Average expert preferences per class:")
                for class_idx in range(avg_class_prefs.shape[0]):
                    expert_prefs = avg_class_prefs[class_idx]
                    dominant_expert = expert_prefs.argmax().item()
                    max_pref = expert_prefs.max().item()
                    print(
                        f"  Class {class_idx}: Expert {dominant_expert} ({max_pref:.3f}) - {expert_prefs}"
                    )

                # Measure specialization
                expert_entropy = -(
                    avg_class_prefs * torch.log(avg_class_prefs + 1e-8)
                ).sum(dim=-1)
                avg_entropy = expert_entropy.mean().item()

                print(f"\nSpecialization Analysis:")
                print(
                    f"Average gating entropy: {avg_entropy:.3f} (lower = more specialized)"
                )
                print(f"Class entropies: {expert_entropy}")

                # Expert usage distribution
                expert_usage = avg_class_prefs.mean(dim=0)
                print(f"Overall expert usage: {expert_usage}")

    def plot_training_curves(self):
        """Plot training curves for comparison"""
        print("\n" + "=" * 60)
        print("PLOTTING TRAINING CURVES")
        print("=" * 60)

        if "phase2" not in self.results:
            print("Phase 2 results not available for plotting")
            return

        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:
            print(f"Matplotlib not available: {exc}")
            return

        plt.figure(figsize=(15, 5))

        # Loss curves
        plt.subplot(1, 3, 1)
        for config_key, results in self.results["phase2"].items():
            metrics_history = results["metrics_history"]
            losses = [m["total_loss"] for m in metrics_history]
            plt.plot(losses, label=results["config"]["name"])
        plt.title("Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True)

        # Accuracy curves
        plt.subplot(1, 3, 2)
        for config_key, results in self.results["phase2"].items():
            metrics_history = results["metrics_history"]
            accuracies = [m["subset_acc"] for m in metrics_history]
            plt.plot(accuracies, label=results["config"]["name"])
        plt.title("Subset Accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy (%)")
        plt.legend()
        plt.grid(True)

        # Diversity loss curves
        plt.subplot(1, 3, 3)
        for config_key, results in self.results["phase2"].items():
            metrics_history = results["metrics_history"]
            diversity_losses = [m["diversity_loss"] for m in metrics_history]
            plt.plot(diversity_losses, label=results["config"]["name"])
        plt.title("Diversity Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Diversity Loss")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig("multilabel_medaf_comparison.png", dpi=150, bbox_inches="tight")
        plt.show()

        print("Training curves saved as 'multilabel_medaf_comparison.png'")

    def print_final_summary(self):
        """Print comprehensive summary"""
        print("\n" + "=" * 70)
        print("MULTI-LABEL MEDAF DEMO SUMMARY")
        print("=" * 70)

        # Phase 1 summary
        if "phase1" in self.results:
            print("\n📊 Phase 1 (Basic Multi-Label MEDAF):")
            print(f"   Final Loss: {self.results['phase1']['final_loss']:.4f}")
            if "checkpoint" in self.results["phase1"]:
                print(f"   Checkpoint: {self.results['phase1']['checkpoint']}")

        # Phase 2 summary
        if "phase2" in self.results:
            print("\n🎯 Phase 2 (Per-Class Gating Comparative Analysis):")

            for config_key, results in self.results["phase2"].items():
                config_name = results["config"]["name"]
                best_acc = results["best_accuracy"]
                final_loss = results["final_metrics"]["total_loss"]

                print(
                    f"   {config_name:20}: Acc={best_acc:6.2f}%, Loss={final_loss:.4f}"
                )

            # Find best configuration
            best_config = max(
                self.results["phase2"].items(), key=lambda x: x[1]["best_accuracy"]
            )

            print(f"\n🏆 Best Configuration: {best_config[1]['config']['name']}")
            print(f"   Best Accuracy: {best_config[1]['best_accuracy']:.2f}%")

            # Calculate improvements
            if (
                "global_gating" in self.results["phase2"]
                and "per_class_gating" in self.results["phase2"]
            ):
                global_acc = self.results["phase2"]["global_gating"]["best_accuracy"]
                pc_acc = self.results["phase2"]["per_class_gating"]["best_accuracy"]
                improvement = pc_acc - global_acc

                print(f"\n📈 Per-Class Gating Improvement: {improvement:+.2f}%")

                if improvement > 0:
                    print("   ✅ Per-class gating shows performance benefits!")
                else:
                    print(
                        "   ℹ️  Results may vary with longer training and real datasets"
                    )

        train_count = len(self.train_loader.dataset) if self.train_loader else 0
        val_count = len(self.val_loader.dataset) if self.val_loader else 0

        print(f"\n🔧 Configuration Used:")
        print(
            f"   Dataset: {self.dataset_name} | train {train_count}, val {val_count}, {self.config['num_classes']} classes"
        )
        print(
            f"   Training: {self.config['num_epochs']} epochs, batch size {self.config['batch_size']}"
        )
        print(f"   Device: {self.device}")

        print(f"\n📝 Key Insights:")
        print(f"   • Multi-label MEDAF successfully handles multiple labels per sample")
        print(f"   • Per-class gating enables class-specific expert specialization")
        print(
            f"   • Attention diversity encourages experts to focus on different regions"
        )
        print(f"   • Configurable architecture allows easy experimentation")

        print("\n🚀 Next Steps:")
        print("   1. Experiment with real multi-label datasets (PASCAL VOC, MS-COCO)")
        print("   2. Conduct comprehensive ablation studies")
        print("   3. Implement advanced research extensions")
        print("   4. Scale to larger models and datasets")

    def run_demo(self):
        """Run complete demonstration"""
        print("🎬 Multi-Label MEDAF Complete Demonstration")
        if self.config.get("run_phase2"):
            print("Phase 1: Basic Multi-Label + Phase 2: Per-Class Gating")
        else:
            print("Phase 1: Basic Multi-Label Training")

        # Create dataset
        self.create_dataset()

        # Demo Phase 1
        self.demo_phase1()

        # Demo Phase 2 with comparative analysis
        if self.config.get("run_phase2"):
            self.demo_phase2_comparative()

        # Analyze attention patterns
        if self.config.get("run_phase2"):
            self.analyze_attention_patterns()

        # Plot results
        if self.config.get("run_phase2"):
            try:
                self.plot_training_curves()
            except Exception as e:
                print(f"Plotting failed: {e} (matplotlib may not be available)")

        # Final summary
        self.print_final_summary()


In [None]:
# Utility: Load Phase 1 Checkpoint
def load_phase1_checkpoint(
    checkpoint_path: Union[str, Path],
    device: Union[str, torch.device, None] = None,
):
    """Rebuild a Phase 1 MEDAF model from a checkpoint file.

    Args:
        checkpoint_path: File holding at least the ``args`` and
            ``state_dict`` entries.
        device: Device to load onto; any falsy value selects the CPU.

    Returns:
        ``(model, checkpoint)`` - the model in eval mode on the chosen device
        and the full checkpoint mapping as loaded.

    Raises:
        KeyError: If the checkpoint has no ``args`` entry.
    """
    target_device = torch.device("cpu") if not device else torch.device(device)
    checkpoint = torch.load(checkpoint_path, map_location=target_device)

    # The constructor arguments travel with the weights.
    model_args = checkpoint.get("args")
    if model_args is None:
        raise KeyError("Checkpoint is missing 'args'.")

    model = MultiLabelMEDAF(model_args)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(target_device)
    model.eval()
    return model, checkpoint


In [None]:
# Override: MultiLabelMEDAF (from core.multilabel_net.py)
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.net import build_backbone, conv1x1, Classifier

class MultiLabelMEDAF(nn.Module):
    """
    Multi-Label version of MEDAF (Multi-Expert Diverse Attention Fusion)

    Key changes from original MEDAF:
    1. Support for multi-hot label targets
    2. BCEWithLogitsLoss instead of CrossEntropyLoss
    3. Multi-label attention diversity computation
    4. Per-sample CAM extraction for multiple positive classes

    The network has a shared stem, three expert branches that each end in a
    1x1-conv classifier producing per-class activation maps, and a separate
    gating network that predicts one softmax weight per expert.
    """

    def __init__(self, args=None):
        """Build the shared stem, the three experts and the gating network.

        Args:
            args: Mapping with the keys ``img_size``, ``backbone``,
                ``gate_temp`` and ``num_classes``.

        Raises:
            TypeError: If ``args`` is None.
        """
        super(MultiLabelMEDAF, self).__init__()
        if args is None:
            # Same exception type as subscripting None, but with a usable message.
            raise TypeError(
                "MultiLabelMEDAF requires an 'args' mapping with "
                "'img_size', 'backbone', 'gate_temp' and 'num_classes'."
            )
        backbone, feature_dim, self.cam_size = build_backbone(
            img_size=args["img_size"],
            backbone_name=args["backbone"],
            projection_dim=-1,
            inchan=3,
        )
        self.img_size = args["img_size"]
        self.gate_temp = args["gate_temp"]
        self.num_classes = args["num_classes"]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        backbone_children = list(backbone.children())

        # Shared layers (L1-L3)
        self.shared_l3 = nn.Sequential(*backbone_children[:-6])

        # Expert branch 1
        self.branch1_l4 = nn.Sequential(*backbone_children[-6:-3])
        self.branch1_l5 = nn.Sequential(*backbone_children[-3])
        self.branch1_cls = conv1x1(feature_dim, self.num_classes)

        # Expert branch 2 (deep copy)
        self.branch2_l4 = copy.deepcopy(self.branch1_l4)
        self.branch2_l5 = copy.deepcopy(self.branch1_l5)
        self.branch2_cls = conv1x1(feature_dim, self.num_classes)

        # Expert branch 3 (deep copy)
        self.branch3_l4 = copy.deepcopy(self.branch1_l4)
        self.branch3_l5 = copy.deepcopy(self.branch1_l5)
        self.branch3_cls = conv1x1(feature_dim, self.num_classes)

        # Gating network (own copies of the stem and of branch 1's layers)
        self.gate_l3 = copy.deepcopy(self.shared_l3)
        self.gate_l4 = copy.deepcopy(self.branch1_l4)
        self.gate_l5 = copy.deepcopy(self.branch1_l5)
        self.gate_cls = nn.Sequential(
            Classifier(feature_dim, int(feature_dim / 4), bias=True),
            Classifier(int(feature_dim / 4), 3, bias=True),  # 3 experts
        )

    def forward(self, x, y=None, return_ft=False):
        """
        Forward pass for multi-label MEDAF

        Args:
            x: Input tensor [B, C, H, W]
            y: Multi-hot labels [B, num_classes] or None
            return_ft: Whether to return features (only honoured when ``y``
                is None)

        Returns:
            Dictionary with ``logits`` (list: expert 1-3 logits, then the
            gate-fused logits), ``gate_pred`` and ``cams_list``, plus either
            ``fts`` (when ``return_ft`` and ``y is None``) or
            ``multi_label_cams`` (None when ``y`` is None).
        """
        b = x.size(0)
        ft_till_l3 = self.shared_l3(x)

        # Expert branch 1
        branch1_l4 = self.branch1_l4(ft_till_l3.clone())
        branch1_l5 = self.branch1_l5(branch1_l4)
        b1_ft_cams = self.branch1_cls(branch1_l5)  # [B, num_classes, H, W]
        b1_logits = self.avg_pool(b1_ft_cams).view(b, -1)

        # Expert branch 2
        branch2_l4 = self.branch2_l4(ft_till_l3.clone())
        branch2_l5 = self.branch2_l5(branch2_l4)
        b2_ft_cams = self.branch2_cls(branch2_l5)  # [B, num_classes, H, W]
        b2_logits = self.avg_pool(b2_ft_cams).view(b, -1)

        # Expert branch 3
        branch3_l4 = self.branch3_l4(ft_till_l3.clone())
        branch3_l5 = self.branch3_l5(branch3_l4)
        b3_ft_cams = self.branch3_cls(branch3_l5)  # [B, num_classes, H, W]
        b3_logits = self.avg_pool(b3_ft_cams).view(b, -1)

        # Store CAMs for diversity loss computation
        cams_list = [b1_ft_cams, b2_ft_cams, b3_ft_cams]

        # Multi-label CAM extraction for positive classes
        if y is not None:
            multi_label_cams = self._extract_multilabel_cams(cams_list, y)
        else:
            multi_label_cams = None

        # Gating network
        gate_l5 = self.gate_l5(self.gate_l4(self.gate_l3(x)))
        gate_pool = self.avg_pool(gate_l5).view(b, -1)
        gate_pred = F.softmax(self.gate_cls(gate_pool) / self.gate_temp, dim=1)

        # Adaptive fusion using gating weights. The expert logits are detached,
        # so the fused output does not back-propagate into the expert branches.
        gate_logits = torch.stack(
            [b1_logits.detach(), b2_logits.detach(), b3_logits.detach()], dim=-1
        )
        gate_logits = gate_logits * gate_pred.view(
            gate_pred.size(0), 1, gate_pred.size(1)
        )
        gate_logits = gate_logits.sum(-1)

        logits_list = [b1_logits, b2_logits, b3_logits, gate_logits]

        if return_ft and y is None:
            # Aggregate (detached) features from all experts; built only on the
            # path that actually returns them.
            fts = (
                b1_ft_cams.detach().clone()
                + b2_ft_cams.detach().clone()
                + b3_ft_cams.detach().clone()
            )
            outputs = {
                "logits": logits_list,
                "gate_pred": gate_pred,
                "fts": fts,
                "cams_list": cams_list,
            }
        else:
            outputs = {
                "logits": logits_list,
                "gate_pred": gate_pred,
                "multi_label_cams": multi_label_cams,
                "cams_list": cams_list,
            }

        return outputs

    def _extract_multilabel_cams(self, cams_list, targets):
        """
        Extract CAMs for all positive classes in multi-label setting

        Args:
            cams_list: One tensor per expert, indexed as
                ``[batch, class, H, W]``.
            targets: Multi-hot labels indexed as ``[batch, class]``; entries
                equal to 1 mark positive classes.

        Returns:
            Nested list ``result[expert][sample]`` of tensors shaped
            ``[num_positive, H, W]``. Samples without any positive class get a
            single all-zero map ``[1, H, W]`` with the dtype and device of the
            expert's CAMs.
        """
        batch_size = targets.size(0)
        extracted_cams = []

        for expert_cams in cams_list:
            expert_extracted = []

            for batch_idx in range(batch_size):
                # Find positive class indices for this sample
                positive_classes = torch.where(targets[batch_idx] == 1)[0]

                if len(positive_classes) > 0:
                    # Extract CAMs for positive classes -> [num_positive, H, W]
                    expert_extracted.append(expert_cams[batch_idx, positive_classes])
                else:
                    # No positive classes: placeholder that matches the CAMs'
                    # dtype as well as device (matters for fp16/fp64 runs).
                    H, W = expert_cams.shape[-2:]
                    expert_extracted.append(expert_cams.new_zeros(1, H, W))

            extracted_cams.append(expert_extracted)

        return extracted_cams

    def get_params(self, prefix="extractor"):
        """Get model parameters for different learning rates

        Args:
            prefix: ``"extractor"``/``"extract"`` for the shared, expert and
                gating feature layers; ``"classifier"`` for every remaining
                parameter of the module.

        Returns:
            A list of parameters for the requested group.

        Raises:
            ValueError: If ``prefix`` names no known group.
        """
        extractor_modules = (
            self.shared_l3,
            self.branch1_l4,
            self.branch1_l5,
            self.branch2_l4,
            self.branch2_l5,
            self.branch3_l4,
            self.branch3_l5,
            self.gate_l3,
            self.gate_l4,
            self.gate_l5,
        )
        extractor_params = [
            param for module in extractor_modules for param in module.parameters()
        ]

        if prefix in ["extractor", "extract"]:
            return extractor_params
        if prefix in ["classifier"]:
            # Set lookup instead of scanning a list for every parameter, and a
            # real list so the result can be iterated more than once.
            extractor_ids = {id(param) for param in extractor_params}
            return [
                param for param in self.parameters() if id(param) not in extractor_ids
            ]
        raise ValueError(
            f"Unknown parameter group {prefix!r}; expected 'extractor', 'extract' or 'classifier'."
        )


In [None]:
# Override: Training utilities (from core.multilabel_train.py)
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from misc.util import *


def multiLabelAttnDiv(cams_list, targets, eps=1e-6):
    """Multi-label attention diversity loss.

    For every (sample, class) entry of ``targets`` that equals 1, the class
    activation map of each expert is flattened, L2-normalised, mean-centred and
    rectified. The loss is the cosine similarity between experts, averaged over
    all expert pairs and all positive entries, so lower values mean the experts
    attend to more different regions.

    Args:
        cams_list: Sequence of per-expert CAM tensors, each indexed as
            ``[batch, class, ...]``. Any number of experts is accepted; with
            fewer than two there are no pairs and the loss is zero.
        targets: Label tensor indexed as ``[batch, class]``; entries equal to 1
            mark positive classes. ``None`` or an all-zero tensor yields zero.
        eps: Forwarded to ``nn.CosineSimilarity`` to avoid division by zero.

    Returns:
        A scalar tensor holding the mean pairwise similarity (zero when nothing
        could be scored).
    """
    if targets is None or targets.sum() == 0:
        return torch.tensor(0.0, device=cams_list[0].device)

    num_experts = len(cams_list)
    cos = nn.CosineSimilarity(dim=1, eps=eps)
    diversity_loss = 0.0
    total_pairs = 0

    for batch_idx in range(targets.size(0)):
        positive_classes = torch.where(targets[batch_idx] == 1)[0]

        for class_idx in positive_classes:
            # One row per expert: [num_experts, flattened CAM].
            expert_cams = torch.stack(
                [cams[batch_idx, class_idx] for cams in cams_list]
            ).view(num_experts, -1)
            expert_cams = F.normalize(expert_cams, p=2, dim=-1)
            # Keep only above-average activations so the similarity compares
            # the regions each expert emphasises.
            mean = expert_cams.mean(dim=-1, keepdim=True)
            expert_cams = F.relu(expert_cams - mean)

            for i in range(num_experts):
                for j in range(i + 1, num_experts):
                    similarity = cos(
                        expert_cams[i : i + 1], expert_cams[j : j + 1]
                    ).mean()
                    diversity_loss += similarity
                    total_pairs += 1

    if total_pairs > 0:
        return diversity_loss / total_pairs
    return torch.tensor(0.0, device=cams_list[0].device)


def multiLabelAccuracy(predictions, targets, threshold=0.5):
    """Score raw multi-label logits against binary targets.

    ``predictions`` are passed through a sigmoid and binarised with
    ``threshold`` (strictly greater than). All computations run without
    gradient tracking.

    Returns:
        Tuple ``(subset_acc, hamming_acc, precision, recall, f1)`` of scalar
        tensors. ``subset_acc`` is the fraction of rows predicted exactly,
        ``hamming_acc`` the fraction of individual labels predicted correctly,
        and precision/recall/F1 are computed per column (dim 0) and then
        averaged over columns.
    """
    tiny = 1e-8  # keeps the ratios finite when a column has no positives
    with torch.no_grad():
        hard_preds = (torch.sigmoid(predictions) > threshold).float()
        matches = hard_preds == targets

        subset_acc = matches.all(dim=1).float().mean()
        hamming_acc = matches.float().mean()

        true_pos = (hard_preds * targets).sum(dim=0)
        false_pos = (hard_preds * (1 - targets)).sum(dim=0)
        false_neg = ((1 - hard_preds) * targets).sum(dim=0)

        class_precision = true_pos / (true_pos + false_pos + tiny)
        class_recall = true_pos / (true_pos + false_neg + tiny)
        class_f1 = (
            2 * (class_precision * class_recall)
            / (class_precision + class_recall + tiny)
        )

        precision = class_precision.mean()
        recall = class_recall.mean()
        f1 = class_f1.mean()
    return subset_acc, hamming_acc, precision, recall, f1


def train_multilabel(train_loader, model, criterion, optimizer, args, device=None):
    """Run one training pass of the multi-label MEDAF model over ``train_loader``.

    Each batch is read as ``(inputs, targets, ...)``. The model output must
    provide ``"logits"`` (the first three entries are scored as expert heads
    and the fourth as the gated head) and ``"cams_list"``. The optimised loss is
    ``loss_wgts[0] * sum(expert BCE) + loss_wgts[1] * gate BCE
    + loss_wgts[2] * diversity``.

    Args:
        train_loader: Iterable of batches; must support ``len()``.
        model: Called as ``model(inputs, targets)``.
        criterion: Mapping with a ``"bce"`` loss callable.
        optimizer: Optimizer stepped once per batch.
        args: Mapping with ``"loss_keys"``, ``"acc_keys"`` and ``"loss_wgts"``.
        device: Target passed to ``.to()`` for inputs and targets.

    Returns:
        The running value of the meter named by the last entry of
        ``args["loss_keys"]``.
    """
    model.train()

    loss_keys = args["loss_keys"]
    acc_keys = args["acc_keys"]

    loss_meter = {name: AverageMeter() for name in loss_keys}
    acc_meter = {name: AverageMeter() for name in acc_keys}
    time_start = time.time()

    def _report(loss_header, acc_header):
        # Render both meter groups behind the supplied section headers.
        loss_part = "".join(
            f"{name}:{meter.value:.4f} " for name, meter in loss_meter.items()
        )
        acc_part = "".join(
            f"{name}:{meter.value:.1f} " for name, meter in acc_meter.items()
        )
        return loss_header + loss_part + acc_header + acc_part

    for step, batch in enumerate(train_loader):
        inputs = batch[0].to(device)
        targets = batch[1].to(device)
        num_samples = inputs.size(0)

        outputs = model(inputs, targets)
        logits = outputs["logits"]
        cams_list = outputs["cams_list"]

        float_targets = targets.float()
        bce = criterion["bce"]
        bce_losses = [bce(head.float(), float_targets) for head in logits[:3]]
        gate_loss = bce(logits[3].float(), targets.float())
        diversity_loss = multiLabelAttnDiv(cams_list, targets)

        weights = args["loss_wgts"]
        total_loss = (
            weights[0] * sum(bce_losses)
            + weights[1] * gate_loss
            + weights[2] * diversity_loss
        )
        # Order matters: values are paired with ``loss_keys`` by position.
        loss_values = [*bce_losses, gate_loss, diversity_loss, total_loss]

        # Subset accuracy (in percent) for every head, paired with ``acc_keys``.
        acc_values = [
            multiLabelAccuracy(head, targets)[0] * 100 for head in logits
        ]

        multi_loss = {key: loss_values[pos] for pos, key in enumerate(loss_keys)}
        train_accs = {key: acc_values[pos] for pos, key in enumerate(acc_keys)}

        update_meter(loss_meter, multi_loss, num_samples)
        update_meter(acc_meter, train_accs, num_samples)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(
                f"Batch [{step}/{len(train_loader)}] "
                + _report("< Training Loss >\n", "\n< Training Accuracy >\n")
            )

    time_elapsed = time.time() - time_start
    print(f"\nEpoch completed in {time_elapsed:.1f}s")
    print(_report("< Final Training Loss >\n", "\n< Final Training Accuracy >\n"))

    return loss_meter[loss_keys[-1]].value


In [None]:
# Override (v2): PerClassGating and LabelCorrelationModule
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.net import build_backbone, conv1x1, Classifier

class PerClassGating(nn.Module):
    """Gating network that produces one expert-weight distribution per class.

    Input features go through a shared ``Linear -> ReLU -> Dropout`` stage and
    then through one small MLP per class
    (``Linear -> ReLU -> Dropout -> Linear``) that emits ``num_experts`` logits.

    Args:
        feature_dim: Size of the last dimension of the input features.
        num_classes: Number of per-class gate heads.
        num_experts: Number of logits produced by each head.
        hidden_dim: Width of the shared stage; defaults to ``feature_dim // 4``.
        dropout: Dropout probability used in the shared stage and the heads.
    """

    def __init__(self, feature_dim, num_classes, num_experts=3, hidden_dim=None, dropout=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.num_experts = num_experts
        self.feature_dim = feature_dim

        width = feature_dim // 4 if hidden_dim is None else hidden_dim

        self.shared_transform = nn.Sequential(
            nn.Linear(feature_dim, width), nn.ReLU(), nn.Dropout(dropout)
        )

        heads = []
        for _ in range(num_classes):
            heads.append(
                nn.Sequential(
                    nn.Linear(width, width // 2),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(width // 2, num_experts),
                )
            )
        self.class_gates = nn.ModuleList(heads)

        # Re-initialise the head linears in a second pass (after all heads are
        # built): small normal weights and zero biases. The shared stage keeps
        # its default initialisation.
        head_linears = (
            layer
            for head in self.class_gates
            for layer in head
            if isinstance(layer, nn.Linear)
        )
        for linear in head_linears:
            nn.init.normal_(linear.weight, 0, 0.01)
            nn.init.constant_(linear.bias, 0)

    def forward(self, features, temperature=1.0):
        """Return ``(gate_weights, gate_logits)``.

        The per-class logits are stacked along dim 1; ``gate_weights`` is the
        softmax of ``gate_logits / temperature`` over the last (expert) dim.
        """
        hidden = self.shared_transform(features)
        gate_logits = torch.stack([head(hidden) for head in self.class_gates], dim=1)
        gate_weights = F.softmax(gate_logits / temperature, dim=-1)
        return gate_weights, gate_logits

class LabelCorrelationModule(nn.Module):
    """Self-attention over learned label embeddings.

    Holds one embedding per class, runs multi-head self-attention (4 heads,
    batch-first) across the full set of label embeddings and applies a final
    linear projection.

    Args:
        num_classes: Number of label embeddings.
        embedding_dim: Embedding width (also the attention and projection width).
    """

    def __init__(self, num_classes, embedding_dim=64):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.label_embeddings = nn.Embedding(num_classes, embedding_dim)
        self.correlation_attention = nn.MultiheadAttention(
            embedding_dim, num_heads=4, batch_first=True
        )
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, predicted_labels=None):
        """Return projected attention features of shape ``[B, num_classes, dim]``.

        ``predicted_labels`` is only used for its leading size ``B``; when it is
        ``None`` a single batch entry is produced.
        """
        embedding_table = self.label_embeddings.weight
        class_ids = torch.arange(self.num_classes, device=embedding_table.device)
        class_embeddings = self.label_embeddings(class_ids)

        if predicted_labels is None:
            num_rows = 1
        else:
            num_rows = predicted_labels.size(0)

        # Same embeddings for every row; ``expand`` avoids copying them.
        tiled = class_embeddings.unsqueeze(0).expand(num_rows, -1, -1)
        attended, _ = self.correlation_attention(tiled, tiled, tiled)
        return self.output_proj(attended)


In [None]:
# Override (v2): MultiLabelMEDAFv2
class MultiLabelMEDAFv2(nn.Module):
    """Multi-label MEDAF with three expert branches and configurable gating.

    A shared trunk (``shared_l3``) feeds three expert branches, each ending in a
    1x1 convolution that produces per-class activation maps (CAMs). A separate
    gating trunk (``gate_l3`` -> ``gate_l4`` -> ``gate_l5``) produces weights
    used to fuse the expert logits: either one weight vector per sample
    ("global") or one per sample and class ("per_class").

    Args:
        args: Mapping of settings. Despite the ``None`` default a mapping is
            required. Required keys: ``"img_size"``, ``"backbone"``,
            ``"gate_temp"``, ``"num_classes"``. Optional keys (defaults in
            parentheses): ``"use_per_class_gating"`` (False),
            ``"use_label_correlation"`` (False), ``"enhanced_diversity"``
            (False), ``"gating_dropout"`` (0.1), ``"label_embedding_dim"`` (64).
    """

    def __init__(self, args=None):
        super().__init__()
        # Optional feature switches.
        self.use_per_class_gating = args.get("use_per_class_gating", False)
        self.use_label_correlation = args.get("use_label_correlation", False)
        self.enhanced_diversity = args.get("enhanced_diversity", False)

        backbone, feature_dim, self.cam_size = build_backbone(
            img_size=args["img_size"],
            backbone_name=args["backbone"],
            projection_dim=-1,
            inchan=3,
        )
        self.img_size = args["img_size"]
        self.gate_temp = args["gate_temp"]
        self.num_classes = args["num_classes"]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Slice the backbone's children into a shared trunk and the first
        # expert branch; the third-from-last child is unpacked into its own
        # sub-modules.
        stages = list(backbone.children())
        self.shared_l3 = nn.Sequential(*stages[:-6])
        self.branch1_l4 = nn.Sequential(*stages[-6:-3])
        self.branch1_l5 = nn.Sequential(*stages[-3])
        self.branch1_cls = conv1x1(feature_dim, self.num_classes)

        # Experts 2 and 3 start as deep copies of expert 1's stages, each with
        # a freshly built classification head. Registration order is
        # l4, l5, cls per branch.
        for branch_no in (2, 3):
            setattr(self, f"branch{branch_no}_l4", copy.deepcopy(self.branch1_l4))
            setattr(self, f"branch{branch_no}_l5", copy.deepcopy(self.branch1_l5))
            setattr(
                self, f"branch{branch_no}_cls", conv1x1(feature_dim, self.num_classes)
            )

        # The gating network has its own copy of the full trunk.
        self.gate_l3 = copy.deepcopy(self.shared_l3)
        self.gate_l4 = copy.deepcopy(self.branch1_l4)
        self.gate_l5 = copy.deepcopy(self.branch1_l5)

        if self.use_per_class_gating:
            self.per_class_gating = PerClassGating(
                feature_dim=feature_dim,
                num_classes=self.num_classes,
                num_experts=3,
                dropout=args.get("gating_dropout", 0.1),
            )
        else:
            quarter_dim = int(feature_dim / 4)
            self.gate_cls = nn.Sequential(
                Classifier(feature_dim, quarter_dim, bias=True),
                Classifier(quarter_dim, 3, bias=True),
            )

        if self.use_label_correlation:
            self.label_correlation = LabelCorrelationModule(
                num_classes=self.num_classes,
                embedding_dim=args.get("label_embedding_dim", 64),
            )

    def forward(self, x, y=None, return_ft=False, return_attention_weights=False):
        """Run the experts and the gate and return a dict of outputs.

        Always present: ``"logits"`` (three expert logits followed by the fused
        logits), ``"gate_pred"``, ``"cams_list"`` and ``"gating_type"``.
        Conditionally present: ``"multi_label_cams"`` (when ``y`` is given),
        ``"fts"`` (when ``return_ft`` is set and ``y`` is ``None``),
        ``"per_class_weights"`` / ``"gate_logits"`` (when
        ``return_attention_weights`` is set under per-class gating) and
        ``"correlation_features"`` (when label correlation is enabled).
        """
        num_samples = x.size(0)
        trunk_features = self.shared_l3(x)

        experts = (
            (self.branch1_l4, self.branch1_l5, self.branch1_cls),
            (self.branch2_l4, self.branch2_l5, self.branch2_cls),
            (self.branch3_l4, self.branch3_l5, self.branch3_cls),
        )
        cams_list = []
        expert_logits = []
        for stage4, stage5, head in experts:
            cams = head(stage5(stage4(trunk_features.clone())))
            cams_list.append(cams)
            # Global average pooling turns each class map into one logit.
            expert_logits.append(self.avg_pool(cams).view(num_samples, -1))

        multi_label_cams = (
            None if y is None else self._extract_multilabel_cams(cams_list, y)
        )

        if return_ft:
            fts = (
                cams_list[0].detach().clone()
                + cams_list[1].detach().clone()
                + cams_list[2].detach().clone()
            )

        gate_maps = self.gate_l5(self.gate_l4(self.gate_l3(x)))
        gate_features = self.avg_pool(gate_maps).view(num_samples, -1)

        if self.use_per_class_gating:
            gate_weights, gate_logits = self.per_class_gating(
                gate_features, self.gate_temp
            )
            stacked = torch.stack(expert_logits, dim=-1)
            fused_logits = (stacked * gate_weights).sum(dim=-1)
            # Average over classes to expose a per-sample summary.
            gate_pred = gate_weights.mean(dim=1)
        else:
            gate_pred = F.softmax(self.gate_cls(gate_features) / self.gate_temp, dim=1)
            # Expert logits are detached here, so the fused output only
            # back-propagates into the gating network.
            stacked = torch.stack(
                [head_logits.detach() for head_logits in expert_logits], dim=-1
            )
            weighted = stacked * gate_pred.view(
                gate_pred.size(0), 1, gate_pred.size(1)
            )
            fused_logits = weighted.sum(-1)

        outputs = {
            "logits": expert_logits + [fused_logits],
            "gate_pred": gate_pred,
            "cams_list": cams_list,
            "gating_type": "per_class" if self.use_per_class_gating else "global",
        }
        if y is not None:
            outputs["multi_label_cams"] = multi_label_cams
        if return_ft and y is None:
            outputs["fts"] = fts
        if return_attention_weights and self.use_per_class_gating:
            outputs["per_class_weights"] = gate_weights
            outputs["gate_logits"] = gate_logits
        if self.use_label_correlation:
            outputs["correlation_features"] = self.label_correlation()
        return outputs

    def _extract_multilabel_cams(self, cams_list, targets):
        """Collect, per expert and per sample, the CAMs of the positive classes.

        Returns a list (one entry per expert) of lists (one entry per sample).
        Each inner entry stacks the maps of the classes whose target equals 1;
        samples without positives get a single all-zero map instead.
        """
        num_samples = targets.size(0)
        per_expert = []
        for expert_cams in cams_list:
            per_sample = []
            for row in range(num_samples):
                positive_ids = torch.where(targets[row] == 1)[0]
                if len(positive_ids) > 0:
                    per_sample.append(expert_cams[row, positive_ids])
                else:
                    height, width = expert_cams.shape[-2:]
                    per_sample.append(
                        torch.zeros(1, height, width, device=expert_cams.device)
                    )
            per_expert.append(per_sample)
        return per_expert

    def get_params(self, prefix="extractor"):
        """Return one of two parameter groups, e.g. for separate learning rates.

        ``"extractor"`` / ``"extract"`` gives the trunk, branch and gating-trunk
        stages plus, when enabled, the per-class gating and label-correlation
        modules. ``"classifier"`` gives every remaining parameter. Any other
        ``prefix`` returns ``None``.
        """
        extractor_modules = [
            self.shared_l3,
            self.branch1_l4,
            self.branch1_l5,
            self.branch2_l4,
            self.branch2_l5,
            self.branch3_l4,
            self.branch3_l5,
            self.gate_l3,
            self.gate_l4,
            self.gate_l5,
        ]
        if self.use_per_class_gating:
            extractor_modules.append(self.per_class_gating)
        if self.use_label_correlation:
            extractor_modules.append(self.label_correlation)

        extractor_params = [
            param for module in extractor_modules for param in module.parameters()
        ]
        if prefix in ("extractor", "extract"):
            return extractor_params
        if prefix == "classifier":
            # Parameters are compared by identity, not by value.
            extractor_ids = {id(param) for param in extractor_params}
            return [
                param for param in self.parameters() if id(param) not in extractor_ids
            ]
        return None

    def get_gating_summary(self):
        """Describe the gating configuration as a plain dict.

        ``"gating_parameters"`` (the element count of the per-class gating
        module) is only included when per-class gating is enabled.
        """
        summary = {
            "gating_type": "per_class" if self.use_per_class_gating else "global",
            "num_classes": self.num_classes,
            "use_label_correlation": self.use_label_correlation,
            "enhanced_diversity": self.enhanced_diversity,
        }
        if self.use_per_class_gating:
            summary["gating_parameters"] = sum(
                param.numel() for param in self.per_class_gating.parameters()
            )
        return summary


In [None]:
# Override (v2): Training and Evaluation Utilities
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from misc.util import *


def enhancedMultiLabelAttnDiv(cams_list, targets, gate_weights=None, diversity_type="cosine", eps=1e-6):
    """Attention diversity loss with optional gate weighting and several metrics.

    For every (sample, class) entry of ``targets`` that equals 1, the class
    activation map of each expert is flattened, L2-normalised, mean-centred and
    rectified, optionally scaled by that expert's gate weight, and then compared
    pairwise across experts.

    Args:
        cams_list: Sequence of per-expert CAM tensors indexed as
            ``[batch, class, ...]``. Any number of experts is accepted.
        targets: Label tensor indexed as ``[batch, class]``; entries equal to 1
            mark positive classes. ``None`` or an all-zero tensor yields zero.
        gate_weights: Optional tensor indexed as ``[batch, class, expert]``
            whose last dimension matches ``len(cams_list)``.
        diversity_type: ``"cosine"`` adds the pairwise cosine similarity;
            ``"l2"`` subtracts the pairwise L2 distance; ``"kl"`` subtracts the
            pairwise KL divergence between softmax-normalised maps.
        eps: Forwarded to ``nn.CosineSimilarity``.

    Returns:
        A scalar tensor: the accumulated value divided by the number of scored
        pairs (zero when no pair was scored).

    Raises:
        ValueError: If ``diversity_type`` is not one of the supported names.
    """
    if diversity_type not in ("cosine", "l2", "kl"):
        # Previously an unknown name silently produced a zero loss.
        raise ValueError(
            f"Unknown diversity_type {diversity_type!r}; expected 'cosine', 'l2' or 'kl'"
        )
    if targets is None or targets.sum() == 0:
        return torch.tensor(0.0, device=cams_list[0].device)

    num_experts = len(cams_list)
    cos = nn.CosineSimilarity(dim=1, eps=eps)
    # Start from a tensor so the result is a tensor even when nothing is
    # scored; callers use ``.item()`` on the returned value.
    diversity_loss = cams_list[0].new_zeros(())
    total_pairs = 0

    for batch_idx in range(targets.size(0)):
        positive_classes = torch.where(targets[batch_idx] == 1)[0]
        for class_idx in positive_classes:
            # One row per expert: [num_experts, flattened CAM].
            expert_cams = torch.stack(
                [cams[batch_idx, class_idx] for cams in cams_list]
            ).view(num_experts, -1)
            expert_cams = F.normalize(expert_cams, p=2, dim=-1)
            mean = expert_cams.mean(dim=-1, keepdim=True)
            expert_cams = F.relu(expert_cams - mean)

            if gate_weights is not None:
                # NOTE(review): cosine similarity ignores positive per-row
                # scaling, so gate weights only influence the "l2" and "kl"
                # variants - confirm this is intended.
                class_gate_weights = gate_weights[batch_idx, class_idx]
                expert_cams = expert_cams * class_gate_weights.unsqueeze(-1)

            if diversity_type == "kl":
                # log_softmax is the numerically stable form of softmax().log().
                expert_log_probs = F.log_softmax(expert_cams, dim=-1)
                expert_probs = F.softmax(expert_cams, dim=-1)

            for i in range(num_experts):
                for j in range(i + 1, num_experts):
                    if diversity_type == "cosine":
                        pair_term = cos(
                            expert_cams[i : i + 1], expert_cams[j : j + 1]
                        ).mean()
                    elif diversity_type == "l2":
                        # Larger distance means more diverse, hence the sign.
                        pair_term = -torch.norm(expert_cams[i] - expert_cams[j], p=2)
                    else:
                        pair_term = -F.kl_div(
                            expert_log_probs[i : i + 1],
                            expert_probs[j : j + 1],
                            reduction="batchmean",
                        )
                    diversity_loss = diversity_loss + pair_term
                    total_pairs += 1

    return diversity_loss / max(total_pairs, 1)


class ComparativeTrainingFramework:
    """Collects per-epoch metrics per model type and compares gating variants.

    Attributes:
        args: The configuration object passed to the constructor (stored as is).
        results: Maps a model type to the list of metric records logged for it.
        current_epoch: Epoch counter, initialised to 0; this class never
            changes it.
    """

    def __init__(self, args):
        self.args = args
        self.results = defaultdict(list)
        self.current_epoch = 0

    def log_metrics(self, model_type, epoch, metrics):
        """Append ``metrics`` for ``model_type``, tagged with epoch and type."""
        record = {"epoch": epoch, "model_type": model_type}
        record.update(metrics)
        self.results[model_type].append(record)

    def get_comparison_summary(self):
        """Summarise the latest record of every model type with logged results.

        Missing metrics default to 0; ``avg_diversity_loss`` averages the
        ``diversity_loss`` of up to the last ten records.
        """
        summary = {}
        for model_type, history in self.results.items():
            if history:
                last = history[-1]
                summary[model_type] = {
                    "final_epoch": last["epoch"],
                    "final_subset_acc": last.get("subset_acc", 0),
                    "final_hamming_acc": last.get("hamming_acc", 0),
                    "final_f1": last.get("f1", 0),
                    "final_loss": last.get("total_loss", 0),
                    "avg_diversity_loss": np.mean(
                        [entry.get("diversity_loss", 0) for entry in history[-10:]]
                    ),
                }
        return summary

    def print_comparison(self):
        """Print a table comparing the "global" and "per_class" summaries.

        The table body is only printed when both model types have results; the
        relative change is expressed in percent of the "global" value.
        """
        summary = self.get_comparison_summary()
        rule = "=" * 60
        print("\n" + rule)
        print("COMPARATIVE ANALYSIS: Global vs Per-Class Gating")
        print(rule)
        if "global" in summary and "per_class" in summary:
            baseline = summary["global"]
            candidate = summary["per_class"]
            print(f"{'Metric':<20} {'Global':<15} {'Per-Class':<15} {'Improvement':<15}")
            print("-" * 65)
            for metric in ("final_subset_acc", "final_hamming_acc", "final_f1"):
                global_val = baseline.get(metric, 0)
                pc_val = candidate.get(metric, 0)
                # Guard the denominator so a zero baseline cannot divide by zero.
                improvement = ((pc_val - global_val) / max(global_val, 1e-8)) * 100
                print(f"{metric.replace('final_', ''):<20} {global_val:<15.4f} {pc_val:<15.4f} {improvement:+.2f}%")
            global_loss = baseline.get("final_loss", float("inf"))
            pc_loss = candidate.get("final_loss", float("inf"))
            # For the loss, a reduction counts as a positive improvement.
            loss_improvement = ((global_loss - pc_loss) / max(global_loss, 1e-8)) * 100
            print(f"{'loss_reduction':<20} {global_loss:<15.4f} {pc_loss:<15.4f} {loss_improvement:+.2f}%")
        print(rule)


def train_multilabel_v2(train_loader, model, criterion, optimizer, args, device=None, comparative_framework=None):
    """Train ``model`` for one pass over ``train_loader`` and return epoch metrics.

    Each batch is read as ``(inputs, targets, ...)``. The model output must
    provide ``"logits"`` (three expert heads followed by the fused head) and
    ``"cams_list"``; ``"per_class_weights"`` is used when present.

    The optimised loss is ``loss_wgts[0] * sum(expert BCE) + loss_wgts[1] *
    fused BCE + loss_wgts[2] * diversity``, plus ``gating_regularization``
    times the negative gate entropy when per-class weights are available.

    Args:
        train_loader: Iterable of batches; must support ``len()``.
        model: Called as ``model(inputs, targets, return_attention_weights=True)``
            and must expose ``get_gating_summary()``.
        criterion: Mapping with a ``"bce"`` loss callable.
        optimizer: Optimizer stepped once per batch.
        args: Mapping with ``"loss_keys"``, ``"acc_keys"``, ``"loss_wgts"`` and
            the optional keys ``"enhanced_diversity"``, ``"diversity_type"`` and
            ``"gating_regularization"``.
        device: Target passed to ``.to()`` for inputs and targets.
        comparative_framework: Optional object whose ``log_metrics`` receives
            the epoch metrics under its ``current_epoch``.

    Returns:
        Dict with ``"total_loss"``, ``"diversity_loss"``, ``"subset_acc"`` and
        ``"hamming_acc"`` (both in percent, for the fused head), ``"f1"`` (0-1,
        fused head), ``"gating_type"``, ``"training_time"`` and, when per-class
        weights were seen, ``"gating_entropy"``.
    """
    model.train()
    loss_keys = args["loss_keys"]
    acc_keys = args["acc_keys"]
    gating_reg_weight = args.get("gating_regularization", 0)
    loss_meter = {p: AverageMeter() for p in loss_keys}
    acc_meter = {p: AverageMeter() for p in acc_keys}
    diversity_meter = AverageMeter()
    gating_entropy_meter = AverageMeter()
    # Dedicated meters: previously "hamming_acc" re-reported the subset
    # accuracy meter and F1 was never recorded.
    hamming_meter = AverageMeter()
    f1_meter = AverageMeter()
    # Tracked explicitly so the epoch summary also works for an empty loader.
    saw_per_class_weights = False
    time_start = time.time()
    gating_summary = model.get_gating_summary()
    gating_type = gating_summary["gating_type"]
    print(f"\nTraining with {gating_type} gating...")

    for i, data_batch in enumerate(train_loader):
        inputs = data_batch[0].to(device)
        targets = data_batch[1].to(device)
        num_samples = inputs.size(0)

        output_dict = model(inputs, targets, return_attention_weights=True)
        logits = output_dict["logits"]
        cams_list = output_dict["cams_list"]
        per_class_weights = output_dict.get("per_class_weights", None)

        bce_losses = [criterion["bce"](logit.float(), targets.float()) for logit in logits[:3]]
        gate_loss = criterion["bce"](logits[3].float(), targets.float())
        if args.get("enhanced_diversity", False) and per_class_weights is not None:
            diversity_loss = enhancedMultiLabelAttnDiv(
                cams_list,
                targets,
                per_class_weights,
                diversity_type=args.get("diversity_type", "cosine"),
            )
        else:
            diversity_loss = multiLabelAttnDiv(cams_list, targets)
        # float() accepts both scalar tensors and plain Python numbers.
        diversity_value = float(diversity_loss)

        gate_entropy = None
        gating_reg_loss = 0.0
        if per_class_weights is not None:
            saw_per_class_weights = True
            # Mean entropy of the per-class expert distributions.
            gate_entropy = -(per_class_weights * torch.log(per_class_weights + 1e-8)).sum(dim=-1).mean()
            if gating_reg_weight > 0:
                # Minimising the negative entropy rewards spreading weight
                # across experts.
                gating_reg_loss = -gate_entropy

        loss_values = bce_losses + [gate_loss, diversity_loss]
        total_loss = (
            args["loss_wgts"][0] * sum(bce_losses)
            + args["loss_wgts"][1] * gate_loss
            + args["loss_wgts"][2] * diversity_loss
            + gating_reg_weight * gating_reg_loss
        )
        loss_values.append(total_loss)

        acc_values = []
        for logit in logits:
            subset_acc, hamming_acc, _, _, f1 = multiLabelAccuracy(logit, targets)
            acc_values.append(subset_acc * 100)
        # After the loop these hold the metrics of the last (fused) head.
        hamming_meter.update(hamming_acc.item() * 100, num_samples)
        f1_meter.update(f1.item(), num_samples)

        multi_loss = {loss_keys[k]: loss_values[k] for k in range(len(loss_keys))}
        train_accs = {acc_keys[k]: acc_values[k] for k in range(len(acc_keys))}
        update_meter(loss_meter, multi_loss, num_samples)
        update_meter(acc_meter, train_accs, num_samples)
        diversity_meter.update(diversity_value, num_samples)
        if gate_entropy is not None:
            gating_entropy_meter.update(gate_entropy.item(), num_samples)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if i % 50 == 0:
            tmp_str = f"Batch [{i}/{len(train_loader)}] ({gating_type}) "
            tmp_str += f"Loss: {total_loss.item():.4f}, "
            tmp_str += f"Div: {diversity_value:.6f}, "
            if per_class_weights is not None:
                tmp_str += f"GateEnt: {gating_entropy_meter.value:.3f}, "
            tmp_str += f"SubsetAcc: {acc_values[3]:.2f}%"
            print(tmp_str)

    time_elapsed = time.time() - time_start
    final_metrics = {
        "total_loss": loss_meter[loss_keys[-1]].value,
        "diversity_loss": diversity_meter.value,
        "subset_acc": acc_meter[acc_keys[-1]].value,
        "hamming_acc": hamming_meter.value,
        "f1": f1_meter.value,
        "gating_type": gating_type,
        "training_time": time_elapsed,
    }
    if saw_per_class_weights:
        final_metrics["gating_entropy"] = gating_entropy_meter.value
    if comparative_framework is not None:
        comparative_framework.log_metrics(gating_type, comparative_framework.current_epoch, final_metrics)

    print(f"\nEpoch Summary ({gating_type} gating):")
    print(f"  Total Loss: {final_metrics['total_loss']:.4f}")
    print(f"  Diversity Loss: {final_metrics['diversity_loss']:.6f}")
    print(f"  Subset Accuracy: {final_metrics['subset_acc']:.2f}%")
    if "gating_entropy" in final_metrics:
        print(f"  Gating Entropy: {final_metrics['gating_entropy']:.3f}")
    print(f"  Training Time: {time_elapsed:.1f}s")
    return final_metrics


def evaluate_multilabel_v2(model, test_loader, criterion, args, device=None):
    """Evaluate the fused head of ``model`` on ``test_loader``.

    The model is switched to eval mode and run without gradients. Accuracy
    metrics are computed once over the concatenated raw fused logits of all
    batches; the loss and diversity values are averaged per batch.

    Args:
        model: Called as ``model(inputs, targets, return_attention_weights=True)``
            and must expose ``get_gating_summary()``.
        test_loader: Iterable of ``(inputs, targets, ...)`` batches.
        criterion: Mapping with a ``"bce"`` loss callable that takes logits.
        args: Accepted for interface compatibility; not read here.
        device: Target passed to ``.to()`` for inputs and targets.

    Returns:
        Dict with ``"eval_loss"``, ``"eval_diversity"``, ``"subset_accuracy"``,
        ``"hamming_accuracy"``, ``"precision"``, ``"recall"``, ``"f1_score"``
        and ``"gating_type"``. When the model returns per-class gate weights,
        gating statistics (``"avg_gating_entropy"``, ``"expert_specialization"``,
        ``"max_expert_preference"``, ``"min_expert_preference"``) are added.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    all_logits = []
    all_targets = []
    all_gate_weights = []
    eval_loss = 0.0
    eval_diversity = 0.0
    num_batches = 0
    gating_type = model.get_gating_summary()["gating_type"]

    with torch.no_grad():
        for data_batch in test_loader:
            inputs, targets = data_batch[0].to(device), data_batch[1].to(device)
            outputs = model(inputs, targets, return_attention_weights=True)
            logits = outputs["logits"]
            cams_list = outputs["cams_list"]
            fused_logits = logits[3]
            # Cast like the training loops do: the loss needs float targets.
            bce_loss = criterion["bce"](fused_logits.float(), targets.float())
            diversity_loss = multiLabelAttnDiv(cams_list, targets)
            eval_loss += bce_loss.item()
            eval_diversity += diversity_loss.item()
            num_batches += 1
            # Keep RAW logits: multiLabelAccuracy applies the sigmoid itself.
            # Storing probabilities here would squash every score into
            # [0.5, 0.73] and mark almost every label as positive.
            all_logits.append(fused_logits)
            all_targets.append(targets)
            if "per_class_weights" in outputs:
                all_gate_weights.append(outputs["per_class_weights"])

    if num_batches == 0:
        raise ValueError("evaluate_multilabel_v2: test_loader yielded no batches")

    all_logits = torch.cat(all_logits, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    subset_acc, hamming_acc, precision, recall, f1 = multiLabelAccuracy(all_logits, all_targets, threshold=0.5)
    eval_metrics = {
        "eval_loss": eval_loss / num_batches,
        "eval_diversity": eval_diversity / num_batches,
        "subset_accuracy": subset_acc.item(),
        "hamming_accuracy": hamming_acc.item(),
        "precision": precision.item(),
        "recall": recall.item(),
        "f1_score": f1.item(),
        "gating_type": gating_type,
    }

    if all_gate_weights:
        all_gate_weights = torch.cat(all_gate_weights, dim=0)
        # Average gate distribution per class over all evaluated samples.
        expert_preferences = all_gate_weights.mean(dim=0)
        expert_entropy = -(expert_preferences * torch.log(expert_preferences + 1e-8)).sum(dim=-1)
        eval_metrics.update({
            "avg_gating_entropy": expert_entropy.mean().item(),
            "expert_specialization": expert_preferences.std(dim=0).mean().item(),
            "max_expert_preference": expert_preferences.max().item(),
            "min_expert_preference": expert_preferences.min().item(),
        })
    return eval_metrics


In [None]:
# Config for running the demo from notebook
# Settings handed to MultiLabelMEDAFDemo in the next cell. How each key is
# interpreted is decided by that class, which is defined elsewhere in the
# notebook.
config = {
    "data_source": "chestxray",  # or "synthetic"
    # Default NIH ChestX-ray locations from the constants cell, as strings.
    "known_csv": str(DEFAULT_KNOWN_CSV),
    "image_root": str(DEFAULT_IMAGE_ROOT),
    "batch_size": 16,
    "num_epochs": 2,
    "learning_rate": 1e-4,
    "val_ratio": 0.1,  # presumably the fraction held out for validation - TODO confirm
    "num_workers": 2,
    # Alternative: "max_samples": None (presumably no sample cap - TODO confirm).
    "max_samples": 200,  # for quick iterations in notebook
    "phase1_checkpoint": "medaf_phase1_chestxray.pt",
    "checkpoint_dir": str(DEFAULT_CHECKPOINT_DIR),
    "run_phase2": True,  # also gates the optional v2 evaluation in the run cell
}

print("Notebook demo config set.")
print(config)


In [None]:
# Run the demo from notebook
# Set seeds for reproducibility
import numpy as np
import torch

torch.manual_seed(42)
np.random.seed(42)

# Initialize and run
_demo = MultiLabelMEDAFDemo(config)
_demo.run_demo()

# Optionally evaluate v2 models if run_phase2
if config.get("run_phase2") and hasattr(_demo, "results") and "phase2" in _demo.results:
    criterion = {"bce": nn.BCEWithLogitsLoss()}
    try:
        phase2_runs = list(_demo.results["phase2"].values())
    except Exception as e:
        print(f"Evaluation skipped due to error: {e}")
        phase2_runs = []

    # Evaluation stays best-effort, but each model is guarded on its own so a
    # failure for one model no longer skips the evaluation of the others.
    for res in phase2_runs:
        try:
            print(f"\nEvaluating {res['config']['name']} model...")
            eval_metrics = evaluate_multilabel_v2(res["model"], _demo.test_loader, criterion, {}, device=_demo.device)
            print(eval_metrics)
        except Exception as e:
            print(f"Evaluation skipped due to error: {e}")
