# Patient Stratification using Multiple Instance Learning (MIL) Tutorial - Part 1

## Introduction and Background

Patient stratification is a critical task in precision medicine that aims to classify patients into different groups based on their clinical characteristics, treatment response, or survival outcomes. This tutorial demonstrates how to perform patient stratification using Multiple Instance Learning (MIL) on multiplex immunofluorescence images.

---

## Workflow Overview

This tutorial is organized into multiple parts:

- **Part 1 (Current): Data Preparation Pipeline**
    - **Patch Extraction:** Extract image patches from multiplex immunofluorescence images
    - **Feature Extraction:** Extract deep learning features using pre-trained KRONOS models
    - **H5AD Object Preparation:** Build AnnData objects for downstream analysis

- **Part 2 (Future): MIL Analysis Pipeline**
    - **Patient-level Data Aggregation:** Aggregate patch-level features to patient-level representations
    - **MIL Model Training and Evaluation:** Train MIL models for patient stratification with cross-validation

---

## Input File Requirements

### Dataset Structure

Your dataset folder should be organized as follows:



## Step 1: Experiment Configuration
In this section, we define the configuration and hyperparameters for the patient stratification pipeline. Ensure your dataset folder is organized according to the structure described above.

In [1]:
import os
import pandas as pd
import numpy as np
from utils.PatchExtraction_FeatureExtraction_h5ad import PatchExtraction, FeatureExtraction, H5ADBuilder

# Define the root directory for the project
project_dir = "/fs/ess/PCON0022/Yuzhou/Kronos/HNSCC"  # Replace with your actual project directory

# Configuration dictionary containing all parameters for the pipeline
config = {
    # Dataset-related parameters - modify these paths according to your dataset structure
    "image_dir": f"{project_dir}/Datasets/test_TIFF/",  # Path to multiplex image files
    "marker_csv_path": f"{project_dir}/Datasets/test_TIFF/marker_metadata_HNSCC.csv",  # Path to marker metadata CSV
    "patient_metadata_csv_path": f"{project_dir}/Datasets/HNSCC_meta.csv",  # Path to patient metadata CSV
    
    # Output directories for intermediate and final results
    "patch_output_dir": f"{project_dir}/patches/",  # Directory to save extracted patches
    "feature_output_dir": f"{project_dir}/features/",  # Directory to save extracted features
    "h5ad_output_dir": f"{project_dir}/h5ad_objects/",  # Directory to save AnnData objects
    "results_dir": f"{project_dir}/results/",  # Directory for final results (Part 2)
    
    # Model-related parameters for KRONOS feature extraction
    "checkpoint_path": "hf_hub:MahmoodLab/kronos",  # Pre-trained KRONOS model checkpoint
    "hf_auth_token": None,  # Hugging Face authentication token (if required)
    "cache_dir": f"{project_dir}/models/cache/",  # Directory to cache KRONOS model
    "model_type": "vits16",  # Type of pre-trained model (vits16, vitl16)
    "token_overlap": True,  # Whether to use token overlap during feature extraction
    
    # Patch extraction parameters
    "patch_size": 128,  # Size of patches to extract (128x128 pixels)
    "stride": 128,  # Stride for patch extraction (non-overlapping patches)
    "file_ext": ".tif",  # File extension of input images
    
    # Feature extraction parameters
    "nuclear_stain": "DAPI",  # Name of nuclear stain marker
    "max_value": 65535.0,  # Maximum possible pixel value (depends on image bit depth)
    "batch_size": 4,  # Batch size for feature extraction
    "num_workers": 4,  # Number of workers for data loading
    "extract_token_features": True,  # Whether to extract patch token features
    
    # H5AD building parameters
    "model_name": "Kronos",  # Model name for output file naming
    "dataset_name": "PatientStratification",  # Dataset name for output file naming
    "core_id_column": "TMA_core_num",  # Column name for core IDs in patient metadata
    
    # Control flags for pipeline steps
    "verbose": True,  # Whether to print detailed progress information
}

print("Configuration loaded successfully!")
print(f"Project directory: {project_dir}")
print(f"Patch size: {config['patch_size']}x{config['patch_size']}")
print(f"Model: {config['model_name']} ({config['model_type']})")


Configuration loaded successfully!
Project directory: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC
Patch size: 128x128
Model: Kronos (vits16)
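
Before launching the longer steps, it can help to confirm that the configured inputs exist and that the output directories can be created. The helper below is a minimal sketch and is not part of the tutorial's `utils` package: the name `check_config_paths` and the split into input and output keys are assumptions made for illustration. The demo runs on a throwaway temporary directory rather than on the real project paths.

```python
import os
import tempfile


def check_config_paths(config, input_keys, output_keys):
    """Return the input keys whose paths are missing and create output directories.

    Hypothetical helper for illustration; not provided by the tutorial's utils.
    """
    missing = [key for key in input_keys if not os.path.exists(config[key])]
    for key in output_keys:
        # exist_ok=True keeps the call safe to repeat across notebook runs
        os.makedirs(config[key], exist_ok=True)
    return missing


# Demo on a temporary directory so the sketch runs anywhere
demo_root = tempfile.mkdtemp()
demo_config = {
    "image_dir": demo_root,                                    # exists
    "marker_csv_path": os.path.join(demo_root, "markers.csv"),  # never created
    "patch_output_dir": os.path.join(demo_root, "patches"),
}
demo_missing = check_config_paths(
    demo_config,
    input_keys=["image_dir", "marker_csv_path"],
    output_keys=["patch_output_dir"],
)
print(demo_missing)  # ['marker_csv_path']
```

With the real `config`, the input keys would plausibly be `image_dir`, `marker_csv_path`, and `patient_metadata_csv_path`, and the output keys the four `*_dir` entries for patches, features, H5AD objects, and results.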


## Step 2: Patch Extraction
In this step, we extract patches from the multiplex immunofluorescence images. The patches are extracted using a sliding window approach across the entire image and saved as HDF5 files containing individual marker datasets.

In [2]:
print("=== Step 2: Patch Extraction ===")

# Initialize patch extraction
patch_config = {
    "image_dir": config["image_dir"],
    "output_dir": config["patch_output_dir"],
    "marker_csv_path": config["marker_csv_path"],
    "patch_size": config["patch_size"],
    "stride": config["stride"],
    "file_ext": config["file_ext"]
}

# Create patch extractor
patch_extractor = PatchExtraction(patch_config)

# Extract patches from all images in the directory
print(f"Starting patch extraction from images in: {config['image_dir']}")
print(f"Patch size: {config['patch_size']}x{config['patch_size']}")
print(f"Stride: {config['stride']} (non-overlapping patches)")

# Option 1: Extract from all images in directory
patch_results = patch_extractor.extract_all_patches()

# Option 2: Extract from specific image files (uncomment if needed)
# specific_files = ["sample_001.ome.tiff", "sample_002.ome.tiff"]  # Replace with your files
# patch_results = patch_extractor.extract_all_patches(file_list=specific_files)

# Display results
print("\nPatch extraction completed!")
print("Summary:")
total_patches = sum(patch_results.values())
print(f"- Total images processed: {len(patch_results)}")
print(f"- Total patches extracted: {total_patches}")
print(f"- Average patches per image: {total_patches / max(len(patch_results), 1):.1f}")  # guard against an empty result

# Show detailed results for each image
print("\nDetailed results:")
for image_name, patch_count in patch_results.items():
    print(f"  {image_name}: {patch_count} patches")

print(f"\nPatches saved to: {config['patch_output_dir']}")

=== Step 2: Patch Extraction ===
Using GPU: NVIDIA A100-PCIE-40GB
Starting patch extraction from images in: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/Datasets/test_TIFF/
Patch size: 128x128
Stride: 128 (non-overlapping patches)
Processing file 1/2: A-1.tif
Skipping A-1.tif as patches already exist
Processing file 2/2: L-8.tif
Skipping L-8.tif as patches already exist
Patch extraction completed!

Patch extraction completed!
Summary:
- Total images processed: 2
- Total patches extracted: 722
- Average patches per image: 361.0

Detailed results:
  A-1.tif: 361 patches
  L-8.tif: 361 patches

Patches saved to: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/patches/
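
The sliding-window idea behind this step can be sketched in a few lines. This is an illustration only: `patch_origins` is a hypothetical function, and it assumes that partial patches at the image border are dropped, which may differ from what `PatchExtraction` actually does. The 361 patches per image reported above equal 19 × 19, which is what a square grid of non-overlapping 128-pixel patches would give on an image of 2432 × 2432 pixels (the real image size is not stated in this tutorial).

```python
def patch_origins(height, width, patch_size, stride):
    """Top-left (row, col) corner of every full patch that fits in the image.

    Assumption: patches that would extend past the border are skipped.
    """
    rows = range(0, height - patch_size + 1, stride)
    cols = range(0, width - patch_size + 1, stride)
    return [(r, c) for r in rows for c in cols]


# Toy 512 x 512 image: stride == patch_size gives a non-overlapping 4 x 4 grid
origins = patch_origins(height=512, width=512, patch_size=128, stride=128)
print(len(origins))             # 16
print(origins[0], origins[-1])  # (0, 0) (384, 384)
```

Setting `stride` smaller than `patch_size` would produce overlapping patches and a correspondingly larger patch count.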


## Step 3: Feature Extraction
In this step, we extract deep learning features from the patches using the pre-trained KRONOS model. The features are saved as numpy arrays for downstream analysis.

In [3]:
print("=== Step 3: Feature Extraction ===")

# Configure feature extraction parameters
feature_config = {
    "dataset_dir": os.path.join(config["patch_output_dir"], f"{config['patch_size']}_{config['stride']}"),
    "feature_dir": config["feature_output_dir"],
    "checkpoint_path": config["checkpoint_path"],
    "hf_auth_token": config["hf_auth_token"],
    "cache_dir": config["cache_dir"],
    "model_type": config["model_type"],
    "token_overlap": config["token_overlap"],
    "marker_info": config["marker_csv_path"],
    "nuclear_stain": config["nuclear_stain"],
    "max_value": config["max_value"],
    "batch_size": config["batch_size"],
    "num_workers": config["num_workers"]
}

# Create feature extractor
print(f"Initializing KRONOS model: {config['model_type']}")
print(f"Loading model from: {config['checkpoint_path']}")

feature_extractor = FeatureExtraction(feature_config)

# Extract features from all patches
print(f"Starting feature extraction from patches in: {feature_config['dataset_dir']}")
print(f"Batch size: {config['batch_size']}")
print(f"Extract token features: {config['extract_token_features']}")

num_processed = feature_extractor.extract_features_from_patches(
    token_features=config["extract_token_features"]
)

print(f"\nFeature extraction completed!")
print(f"- Total patches processed: {num_processed}")
print(f"- Features saved to: {config['feature_output_dir']}")

# Display feature types extracted
feature_types = []
if os.path.exists(os.path.join(config["feature_output_dir"], "norm_clstoken")):
    cls_count = len(os.listdir(os.path.join(config["feature_output_dir"], "norm_clstoken")))
    feature_types.append(f"CLS token features: {cls_count} files")

if config["extract_token_features"] and os.path.exists(os.path.join(config["feature_output_dir"], "norm_patchtokens")):
    token_count = len(os.listdir(os.path.join(config["feature_output_dir"], "norm_patchtokens")))
    feature_types.append(f"Patch token features: {token_count} files")

print("Feature types extracted:")
for feature_type in feature_types:
    print(f"  - {feature_type}")

=== Step 3: Feature Extraction ===
Initializing KRONOS model: vits16
Loading model from: hf_hub:MahmoodLab/kronos
Using GPU: NVIDIA A100-PCIE-40GB
Loading Kronos model...
Loaded model weights from /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/models/cache/models--MahmoodLab--kronos/snapshots/e379603f57f68ba61e77a783d83c9899850dcae3/kronos_vits16_model.pt
Model loaded successfully with precision: torch.float32, embedding dim: 384
Starting feature extraction from patches in: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/patches/128_128
Batch size: 4
Extract token features: True
Found 722 patches to process
Processing batch 1/181
Processing batch 2/181
Processing batch 3/181
Processing batch 4/181
Processing batch 5/181
Processing batch 6/181
Processing batch 7/181
Processing batch 8/181
Processing batch 9/181
Processing batch 10/181
Processing batch 11/181
Processing batch 12/181
Processing batch 13/181
Processing batch 14/181
Processing batch 15/181
Processing batch 16/181
Processing batch 17/18
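
The batch total of 181 in the log follows directly from the 722 patches and the configured `batch_size` of 4, assuming the final partially filled batch is kept rather than dropped. A small sketch of that arithmetic (`num_batches` is a name chosen here for illustration):

```python
import math


def num_batches(num_items, batch_size):
    """Number of batches when the last batch may be only partially filled."""
    return math.ceil(num_items / batch_size)


total = num_batches(722, 4)
print(total)  # 181: 180 full batches plus one batch of 2 patches
```

Raising `batch_size` lowers the batch count proportionally, at the cost of more GPU memory per batch.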

## Step 4: H5AD Object Preparation
In this final step of Part 1, we build AnnData (h5ad) objects from the extracted features. These objects combine the features with metadata and can be used for downstream analysis including scanpy workflows and MIL analysis.

In [4]:
print("=== Step 4: H5AD Object Preparation ===")

# Configure H5AD building parameters
h5ad_config = {
    "embedding_path": os.path.join(config["feature_output_dir"], "norm_clstoken"),  # Using CLS token features
    "output_dir": config["h5ad_output_dir"],
    "metadata_path": config["patient_metadata_csv_path"],
    "model_name": config["model_name"],
    "patch_size": f"{config['patch_size']}_{config['stride']}",
    "dataset_name": config["dataset_name"],
    "core_id_column": config["core_id_column"]
}

# Create H5AD builder
print(f"Building H5AD object from features in: {h5ad_config['embedding_path']}")
print(f"Using metadata from: {config['patient_metadata_csv_path']}")

h5ad_builder = H5ADBuilder(h5ad_config)

# Build H5AD object
h5ad_path = h5ad_builder.build_h5ad()

if h5ad_path:
    print(f"\nH5AD object successfully created!")
    print(f"Saved to: {h5ad_path}")
    
    # Load and display basic information about the H5AD object
    import scanpy as sc
    adata = sc.read_h5ad(h5ad_path)
    
    print(f"\nH5AD Object Summary:")
    print(f"- Shape: {adata.shape[0]} observations × {adata.shape[1]} features")
    print(f"- Unique patients/cores: {adata.obs[config['core_id_column']].nunique()}")
    print(f"- Available metadata columns: {list(adata.obs.columns)}")
    
    # Display sample distribution if response column exists
    if 'response' in adata.obs.columns:
        response_counts = adata.obs['response'].value_counts()
        print(f"- Response distribution:")
        for response, count in response_counts.items():
            print(f"  {response}: {count} patches")
    
    print(f"\nThe H5AD object is ready for downstream analysis!")
    print(f"You can now proceed with:")
    print(f"  - Scanpy analysis workflows")
    print(f"  - Patient stratification using MIL (Part 2)")
    print(f"  - Dimensionality reduction and visualization")
    
else:
    print("Failed to create H5AD object. Please check the configuration and try again.")

=== Step 4: H5AD Object Preparation ===
Building H5AD object from features in: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/features/norm_clstoken
Using metadata from: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/Datasets/HNSCC_meta.csv
Loading metadata from /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/Datasets/HNSCC_meta.csv...
Loaded metadata for 87 cores

Processing Kronos 128_128 embeddings...
Looking for embedding files in: /fs/ess/PCON0022/Yuzhou/Kronos/HNSCC/features/norm_clstoken
Found 722 embedding files
Processing embedding files...
Processing file 1/722...
Processing file 101/722...
Processing file 201/722...
Processing file 301/722...
Processing file 401/722...
Processing file 501/722...
Processing file 601/722...
Processing file 701/722...
Successfully loaded 722 valid embeddings
Creating embedding matrix...
Embedding matrix shape: (722, 384)
Creating AnnData object...
AnnData object created with shape: (722, 384)
Adding sample-level metadata...
Metadata successfully added
Saving AnnData object t

# Patient Stratification using Multiple Instance Learning (MIL) Tutorial - Part 2

## MIL Analysis Pipeline

Welcome to Part 2 of the patient stratification tutorial! In this section, we will perform Multiple Instance Learning (MIL) analysis on the prepared data from Part 1. This tutorial focuses on patient-level classification using patch-level features with rigorous cross-validation evaluation.

---

## Overview of Part 2

This part covers the advanced MIL analysis pipeline:

- **Step 4: MIL Dataset Preparation**  
  Convert H5AD objects to MIL format for training

- **Step 5: Cross-Validation Analysis**  
  Perform repeated stratified k-fold cross-validation

- **Step 6: Model Training and Evaluation**  
  Train MIL models with comprehensive evaluation

- **Step 7: Results Analysis and Visualization**  
  Analyze performance across different feature types

---

## Multiple Instance Learning (MIL) Background

In MIL, we work with _"bags"_ (patients) containing multiple _"instances"_ (image patches). The goal is to predict patient-level labels using patch-level features without requiring patch-level annotations. This is particularly valuable in medical imaging where:

- **Patients are bags** with binary labels (e.g., responder vs. non-responder)
- **Image patches are instances** with rich feature representations
- **Only patient-level labels** are available for training

Our MIL model aggregates patch-level features into patient-level representations using a neural network combined with a configurable aggregation function (`torch.mean` or `torch.max` in this tutorial's configuration).
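
To make the aggregation step concrete, here is a minimal sketch of mean and max pooling over a bag, written in plain Python so it runs without any dependencies. It is an illustration of the idea, not the tutorial's model, which operates on tensors; `aggregate_bag` is a hypothetical name.

```python
def aggregate_bag(instances, how="mean"):
    """Pool a bag of instance feature vectors into one vector, feature by feature."""
    if not instances:
        raise ValueError("a bag needs at least one instance")
    columns = list(zip(*instances))  # transpose: one tuple per feature
    if how == "mean":
        return [sum(col) / len(col) for col in columns]
    if how == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown aggregation: {how}")


# Toy bag: 3 patches with 2 features each
bag = [[0.0, 2.0], [1.0, 4.0], [2.0, 0.0]]
mean_vec = aggregate_bag(bag, "mean")
max_vec = aggregate_bag(bag, "max")
print(mean_vec)  # [1.0, 2.0]
print(max_vec)   # [2.0, 4.0]
```

Both pooling choices return a vector of the same length regardless of how many patches a patient has, which is what lets bags of different sizes share one classifier.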

---


## Step 4: MIL Dataset Preparation and Configuration
In this step, we configure the MIL analysis pipeline and prepare datasets from the H5AD objects created in Part 1. We assume you have already generated an H5AD object; supply its path as `h5ad_path` in the configuration below.

In [5]:
import torch
import numpy as np
import pandas as pd
from utils.patient_stratification import PatientStratification, PatientStratificationClassifier

print("=== Step 4: MIL Dataset Preparation and Configuration ===")

# Define the root directory for the project
project_dir = "/fs/ess/PAS2205/Yuzhou/Datasets/CTCL_pembro_data"  # Replace with your actual project directory

# MIL Analysis Configuration
mil_config = {
    # Input data (from Part 1)
    "h5ad_path": f"{project_dir}/CTCL_adata_256_256-Kronos.h5ad",  # Path to H5AD file from Part 1
    
    # Output directory for MIL results
    "output_dir": f"{project_dir}/mil_results/",
    
    # Data parameters - must match your H5AD object structure
    "patient_col": "Patients",  # Column name for patient identifiers in H5AD
    "label_col": "Response",   # Column name for response labels (R/NR or 1/0)
    
    # MIL model architecture parameters
    "n_neurons": 256,          # Number of neurons in the first hidden layer
    "hidden_layers": [],       # Additional hidden layers (e.g., [128, 64] for two extra layers)
    "dropout_rate": 0.0,       # Dropout rate for regularization
    "model_aggregation": torch.mean,  # Aggregation function (torch.mean or torch.max)
    
    # Training parameters
    "lr": 1e-3,               # Learning rate
    "weight_decay": 1e-5,     # L2 regularization weight
    "n_epochs": 50,           # Number of training epochs per fold
    "batch_size": 4,          # Batch size for training
    "normalize_data": False,   # Whether to normalize input features
    
    # Learning rate scheduler
    "lr_scheduler": True,     # Whether to use learning rate scheduling
    "lr_step_size": 50,       # Step size for LR decay
    "lr_gamma": 0.5,          # Gamma for LR decay
    
    # Cross-validation parameters
    "n_repeats": 2,          # Number of repeated cross-validation runs
    "n_folds": 5,             # Number of folds per repetition
    
    # Feature types to benchmark (different dimensionality reduction methods)
    "feature_types": ['pca50', 'pca100'],
    
    # Logging and output control
    "verbose": False,          # Detailed progress logging
    "loss_log_interval": 10,  # Log training loss every N epochs
}

# Validate H5AD file exists
import os
if not os.path.exists(mil_config["h5ad_path"]):
    raise FileNotFoundError(f"H5AD file not found: {mil_config['h5ad_path']}")
    
print(f"✓ H5AD file found: {mil_config['h5ad_path']}")
print(f"✓ Output directory: {mil_config['output_dir']}")
print(f"✓ Cross-validation: {mil_config['n_repeats']} repeats × {mil_config['n_folds']} folds = {mil_config['n_repeats'] * mil_config['n_folds']} total runs")
print(f"✓ Feature types to test: {mil_config['feature_types']}")

# Initialize the patient stratification system
patient_stratification = PatientStratification(mil_config)
print("✓ MIL analysis system initialized successfully!")

=== Step 4: MIL Dataset Preparation and Configuration ===
✓ H5AD file found: /fs/ess/PAS2205/Yuzhou/Datasets/CTCL_pembro_data/CTCL_adata_256_256-Kronos.h5ad
✓ Output directory: /fs/ess/PAS2205/Yuzhou/Datasets/CTCL_pembro_data/mil_results/
✓ Cross-validation: 2 repeats × 5 folds = 10 total runs
✓ Feature types to test: ['pca50', 'pca100']
✓ MIL analysis system initialized successfully!
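
Conceptually, converting patch-level rows into MIL format means grouping them into one bag per patient, keyed by the `patient_col` values and carrying the `label_col` value as the bag label. The sketch below shows that grouping on toy data in plain Python. `build_bags` is a hypothetical helper, and the real conversion inside `PatientStratification` may differ.

```python
def build_bags(patient_ids, labels, features):
    """Group per-patch rows into one bag per patient.

    Returns {patient_id: {"label": label, "instances": [feature_vector, ...]}}.
    """
    bags = {}
    for pid, label, feat in zip(patient_ids, labels, features):
        bag = bags.setdefault(pid, {"label": label, "instances": []})
        if bag["label"] != label:
            # every patch of a patient must carry the same patient-level label
            raise ValueError(f"patient {pid} has conflicting labels")
        bag["instances"].append(feat)
    return bags


# Toy data: three patches from two patients
toy_ids = ["P1", "P1", "P2"]
toy_labels = ["R", "R", "NR"]
toy_features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
bags = build_bags(toy_ids, toy_labels, toy_features)
bag_sizes = {pid: len(bag["instances"]) for pid, bag in bags.items()}
print(bag_sizes)  # {'P1': 2, 'P2': 1}
```

The consistency check matters because a patient whose patches disagree on the label usually points to a metadata join problem upstream.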


## Step 5: Cross-Validation Analysis Setup
Before running the full analysis, let's examine our data and set up the cross-validation framework.

In [6]:
print("=== Step 5: Cross-Validation Analysis Setup ===")

# Load and examine the H5AD data
import scanpy as sc
adata = sc.read_h5ad(mil_config["h5ad_path"])

print(f"Dataset Overview:")
print(f"  Total observations (patches): {adata.shape[0]:,}")
print(f"  Feature dimensions: {adata.shape[1]:,}")
print(f"  Patients/cores: {adata.obs[mil_config['patient_col']].nunique()}")

# Examine label distribution
if mil_config["label_col"] in adata.obs.columns:
    label_counts = adata.obs.groupby(mil_config["patient_col"])[mil_config["label_col"]].first().value_counts()
    print(f"\nPatient-level label distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count} patients")
    
    # Check for class imbalance
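    # Ignore unused categories so the ratio never divides by zero
    observed_counts = label_counts[label_counts > 0]
    min_class = observed_counts.min()
    max_class = observed_counts.max()
    imbalance_ratio = max_class / min_class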
    print(f"  Class imbalance ratio: {imbalance_ratio:.2f}")
    if imbalance_ratio > 3:
        print(" Significant class imbalance detected - consider stratified sampling")
else:
    print(f"Warning: Label column '{mil_config['label_col']}' not found in data")

# Examine patches per patient distribution
patches_per_patient = adata.obs[mil_config["patient_col"]].value_counts()
print(f"\nPatches per patient statistics:")
print(f"  Mean: {patches_per_patient.mean():.1f}")
print(f"  Median: {patches_per_patient.median():.1f}")
print(f"  Min: {patches_per_patient.min()}")
print(f"  Max: {patches_per_patient.max()}")
print(f"  Std: {patches_per_patient.std():.1f}")

# Display a few examples
print(f"\nExample patients and their patch counts:")
for patient, count in patches_per_patient.head(5).items():
    label = adata.obs[adata.obs[mil_config["patient_col"]] == patient][mil_config["label_col"]].iloc[0]
    print(f"  {patient}: {count} patches, label: {label}")

print("\n✓ Data exploration completed - ready for MIL analysis!")

=== Step 5: Cross-Validation Analysis Setup ===
Dataset Overview:
  Total observations (patches): 875
  Feature dimensions: 19,584
  Patients/cores: 14

Patient-level label distribution:
  NR: 7 patients
  R: 7 patients
  Class imbalance ratio: 1.00

Patches per patient statistics:
  Mean: 62.5
  Median: 70.0
  Min: 35
  Max: 105
  Std: 20.3

Example patients and their patch counts:
  Patients_8: 105 patches, label: NR
  Patients_1: 70 patches, label: R
  Patients_5: 70 patches, label: R
  Patients_4: 70 patches, label: NR
  Patients_6: 70 patches, label: NR

✓ Data exploration completed - ready for MIL analysis!
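
Because labels exist only at the patient level, the cross-validation splits must be made over patients, not patches, so that all patches of one patient land on the same side of a split. The following is a minimal sketch of repeated stratified k-fold splitting in plain Python. It is not the tutorial's implementation, which lives inside `PatientStratificationClassifier`, and the function name, the seed, and the toy patient IDs are assumptions.

```python
import random
from collections import defaultdict


def repeated_stratified_folds(patient_labels, n_folds, n_repeats, seed=0):
    """Yield (repeat, fold, train_ids, test_ids), spreading each class evenly over folds."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for pid, label in patient_labels.items():
        by_class[label].append(pid)
    for repeat in range(n_repeats):
        folds = [[] for _ in range(n_folds)]
        offset = 0
        for label in sorted(by_class):
            ids = by_class[label][:]
            rng.shuffle(ids)  # a fresh shuffle per repeat gives different splits
            # deal patients round-robin; carrying the offset keeps fold sizes balanced
            for i, pid in enumerate(ids):
                folds[(offset + i) % n_folds].append(pid)
            offset += len(ids)
        for fold, test_ids in enumerate(folds):
            train_ids = [p for k, f in enumerate(folds) if k != fold for p in f]
            yield repeat, fold, train_ids, test_ids


# Toy cohort mirroring the class balance above: 7 responders, 7 non-responders
toy_labels = {f"P{i}": ("R" if i % 2 else "NR") for i in range(1, 15)}
splits = list(repeated_stratified_folds(toy_labels, n_folds=5, n_repeats=2))
print(len(splits))                              # 10
print(sorted({len(t) for _, _, _, t in splits}))  # [2, 3]
```

With 14 patients and 5 folds, each test fold holds only 2 or 3 patients, so per-fold metrics are necessarily coarse.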


## Step 6: MIL Model Training and Evaluation
Now we'll run the complete MIL benchmarking across different feature types with repeated cross-validation.


In [7]:
print("=== Step 6: MIL Model Training and Evaluation ===")

# Initialize the MIL classifier
classifier = PatientStratificationClassifier(mil_config)

# Store results for all feature types
all_results = {}
summary_statistics = []

print(f"Starting comprehensive MIL benchmarking...")
print(f"This will run {len(mil_config['feature_types'])} feature types × {mil_config['n_repeats']} repeats × {mil_config['n_folds']} folds")
print(f"Total model training runs: {len(mil_config['feature_types']) * mil_config['n_repeats'] * mil_config['n_folds']}")

# Run benchmarking for each feature type
for i, feature_type in enumerate(mil_config['feature_types']):
    print(f"\n{'='*60}")
    print(f"BENCHMARKING FEATURE TYPE {i+1}/{len(mil_config['feature_types'])}: {feature_type.upper()}")
    print(f"{'='*60}")
    
    try:
        # Run repeated cross-validation for this feature type
        results_df = classifier.run_cross_validation(
            h5ad_path=mil_config["h5ad_path"],
            feature_type=feature_type
        )
        
        # Save detailed results and compute summary statistics
        summary_stats = classifier.save_results_and_plot(
            results_df=results_df,
            feature_type=feature_type,
            output_dir=mil_config["output_dir"]
        )
        
        # Store results
        all_results[feature_type] = results_df
        summary_statistics.append(summary_stats)
        
        # Display immediate results for this feature type
        mean_auc = results_df['test_auc'].mean()
        std_auc = results_df['test_auc'].std()
        print(f"\n {feature_type.upper()} Results Summary:")
        print(f"   Mean AUC: {mean_auc:.4f} ± {std_auc:.4f}")
        print(f"   95% CI: [{summary_stats['ci_lower']:.4f}, {summary_stats['ci_upper']:.4f}]")
        print(f"   Best fold AUC: {results_df['test_auc'].max():.4f}")
        print(f"   Worst fold AUC: {results_df['test_auc'].min():.4f}")
        
    except Exception as e:
        print(f" Error processing {feature_type}: {str(e)}")
        continue

print(f"\n{'='*60}")
print(" MIL BENCHMARKING COMPLETED!")
print(f"{'='*60}")

=== Step 6: MIL Model Training and Evaluation ===
Starting comprehensive MIL benchmarking...
This will run 2 feature types × 2 repeats × 5 folds
Total model training runs: 20

BENCHMARKING FEATURE TYPE 1/2: PCA50

 PCA50 Results Summary:
   Mean AUC: 0.5500 ± 0.3689
   95% CI: [0.2861, 0.8139]
   Best fold AUC: 1.0000
   Worst fold AUC: 0.0000

BENCHMARKING FEATURE TYPE 2/2: PCA100

 PCA100 Results Summary:
   Mean AUC: 0.5500 ± 0.3689
   95% CI: [0.2861, 0.8139]
   Best fold AUC: 1.0000
   Worst fold AUC: 0.0000

 MIL BENCHMARKING COMPLETED!
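
The spread between a best fold AUC of 1.0 and a worst fold AUC of 0.0 is easier to interpret once you see how AUC behaves on very small test folds. The sketch below computes AUC as the fraction of correctly ranked (positive, negative) pairs, in plain Python. It is an illustration and not the metric code used by the tutorial; the scores are made up for the example.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# A 3-patient test fold (1 responder, 2 non-responders) with untied scores
# can only produce an AUC of 0, 0.5, or 1.
auc_best = roc_auc([1, 0, 0], [0.9, 0.2, 0.4])
auc_mid = roc_auc([1, 0, 0], [0.3, 0.2, 0.4])
auc_worst = roc_auc([1, 0, 0], [0.1, 0.2, 0.4])
print(auc_best, auc_mid, auc_worst)  # 1.0 0.5 0.0
```

On a cohort of 14 patients split into 5 folds, a single misranked patient can move a fold's AUC from 1.0 to 0.5 or 0.0, so a large standard deviation across folds is expected and the mean should be read with that in mind.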


## Step 7: Results Analysis and Visualization
Finally, let's analyze and visualize the comprehensive results across all feature types.

In [8]:
print("=== Step 7: Results Analysis and Visualization ===")

# Convert to DataFrame for easy analysis
comparison_df = pd.DataFrame(summary_statistics)

# Create and save comprehensive visualization
patient_stratification._create_comparison_plot(comparison_df, mil_config["output_dir"])

    
print(f"  • Per-fold results: *_fold_auc.csv")
print(f"  • Summary statistics: *_summary.csv")
print(f"  • Combined comparison: combined_feature_comparison.csv")
print(f"  • Visualization: feature_type_comparison.png")

=== Step 7: Results Analysis and Visualization ===
  • Per-fold results: *_fold_auc.csv
  • Summary statistics: *_summary.csv
  • Combined comparison: combined_feature_comparison.csv
  • Visualization: feature_type_comparison.png
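
When reading the summary statistics, it helps to know how an interval such as the `95% CI` printed in Step 6 can arise. The reported numbers (mean 0.5500, standard deviation 0.3689, 10 runs, interval [0.2861, 0.8139]) are numerically consistent with a two-sided Student-t interval on the mean with 9 degrees of freedom. This is an inference from the printed output, not something confirmed from the library's code. The sketch below reproduces the arithmetic; the function name is hypothetical and the critical value is hard-coded to avoid extra dependencies.

```python
import math


def t_interval_from_summary(mean, std, n, t_critical):
    """Confidence interval for a mean from its sample std and sample size."""
    half_width = t_critical * std / math.sqrt(n)
    return mean - half_width, mean + half_width


# Two-sided 95% Student-t critical value for 9 degrees of freedom (n = 10 runs)
T_CRIT_95_DF9 = 2.262

ci_lower, ci_upper = t_interval_from_summary(0.5500, 0.3689, 10, T_CRIT_95_DF9)
print(f"[{ci_lower:.4f}, {ci_upper:.4f}]")  # [0.2861, 0.8139]
```

An interval this wide that also contains 0.5 means the runs shown here do not distinguish the model from chance; more repeats or a larger cohort would be needed to narrow it.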
