In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
import seaborn as sns

from lifelines import CoxPHFitter
from lifelines.statistics import logrank_test
from lifelines.utils import concordance_index
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.metrics import r2_score, mean_squared_error
from scipy.stats import pearsonr, f_oneway

import warnings
warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [3]:
# Data Loading
data_path = './data/'
curated_mri_vif_df = pd.read_csv(data_path + '3_baseline_vif.csv')
curated_mri_vif_df = curated_mri_vif_df.dropna()
curated_mri_vif_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 406 entries, 0 to 529
Data columns (total 64 columns):
 #   Column                                   Non-Null Count  Dtype  
---  ------                                   --------------  -----  
 0   age                                      406 non-null    float64
 1   fampd                                    406 non-null    int64  
 2   race_black                               406 non-null    bool   
 3   race_asian                               406 non-null    bool   
 4   race_other                               406 non-null    bool   
 5   sex                                      406 non-null    int64  
 6   educyrs                                  406 non-null    float64
 7   subgroup_gba                             406 non-null    bool   
 8   subgroup_lrrk2                           406 non-null    bool   
 9   subgroup_prkn                            406 non-null    bool   
 10  apoe_e4                                  406 non-null  

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from lifelines.utils import concordance_index
from datetime import datetime
import os
import random
import warnings
import sys
import argparse

warnings.filterwarnings('ignore')

# Data Loading
data_path = './data/'
curated_mri_vif_df = pd.read_csv(data_path + '3_baseline_vif.csv')
curated_mri_vif_df = curated_mri_vif_df.dropna()
curated_mri_vif_df.info()

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from lifelines.utils import concordance_index
from datetime import datetime
import os
import random
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
#
# RESEARCH QUESTION 4: MULTI-TASK LEARNING (MTL) ANALYSIS - COMPLETE VERSION
# WITH TRAINING DATA PRESERVATION FOR ACCURATE EVALUATION
#
# Author: GitHub Copilot (Complete Version)
# Date: 9 September 2025
#
# =============================================================================

print("🔧 RQ4 CONFIGURATION: COMPLETE MTL WITH DATA PRESERVATION")
print("=" * 70)

# Define modality groups
modality_groups = {
    'demographic_clinical': [
        'age', 'fampd', 'race_black', 'race_asian', 'race_other', 'sex',
        'educyrs', 'duration_yrs', 'ledd', 'mseadlg', 'hy', 'domside_left', 
        'domside_symmetric', 'hvltrdly', 'lns', 'vltanim', 'quip', 'ess', 'pigd', 
        'updrs2_score', 'updrs3_score', 'sdmtotal', 'stai', 'moca', 'rem', 'gds', 
        'bmi', 'updrs1_score', 'bjlot', 'scopa', 'upsit_pctl'
    ],
    'genetic': ['apoe_e4', 'subgroup_gba', 'subgroup_lrrk2', 'subgroup_prkn'],
    'biomarkers': ['urate', 'csfsaa_positive_lbd_like', 'csfsaa_positive_msa_like', 'csfsaa_inconclusive'],
    'datscan': ['mia_caudate_l', 'mia_caudate_r', 'mia_putamen_l', 'mia_putamen_r'],
    'mri': [f'mri_pc{i+1}' for i in range(10)]
}

# Configuration
REGRESSION_TARGET = 'moca_slope_iqr_cleaned'
SURVIVAL_TIME_TARGET = 'time_to_hy3_plus'
SURVIVAL_EVENT_TARGET = 'event_occurred'
ALL_MODALITIES = ['demographic_clinical', 'genetic', 'biomarkers', 'datscan', 'mri']

# Baseline performance
BASELINE_PERFORMANCE = {
    'regression_r2': 0.21,
    'survival_c_index': 0.83,
    'regression_r2_baseline': 0.0,
    'survival_c_index_baseline': 0.5
}

# Hyperparameter ranges
N_RANDOM_SEARCH = 5000  # Adjust as needed
EPOCHS = 1000
RANDOM_STATE = 42
TEST_SIZE = 0.2

HYPERPARAMETER_RANGES = {
    'learning_rate': [0.0001, 0.0005, 0.001, 0.002],
    'batch_size': [32, 64, 128],
    'hidden_dim': [64, 128, 256],
    'alpha_regression': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
    'alpha_survival': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
    'l2_reg': [0.0, 0.0001, 0.001, 0.01],
    'dropout_rate': [0.2, 0.4],
    'n_layers': [2, 3],
    'activation': ['relu'],
    'batch_norm': [True, False],
    'early_stopping_patience': [80]
}

features_to_standardize = [
    'age', 'fampd', 'educyrs', 'apoe_e4', 'urate', 'mia_caudate_l', 'mia_caudate_r', 
    'mia_putamen_l', 'mia_putamen_r', 'duration_yrs', 'ledd', 'mseadlg', 'hy', 
    'hvltrdly', 'lns', 'vltanim', 'quip', 'ess', 'pigd', 'updrs2_score', 
    'updrs3_score', 'sdmtotal', 'stai', 'moca', 'rem', 'gds', 'bmi', 
    'updrs1_score', 'bjlot', 'scopa', 'upsit_pctl', 'mri_pc1', 'mri_pc2', 
    'mri_pc3', 'mri_pc4', 'mri_pc5', 'mri_pc6', 'mri_pc7', 'mri_pc8', 
    'mri_pc9', 'mri_pc10'
]

print(f"🎯 Targets: {REGRESSION_TARGET} + {SURVIVAL_TIME_TARGET}")
print(f"📊 Total features: {sum(len(modality_groups[mod]) for mod in ALL_MODALITIES)}")
print(f"⚙️ Experiments: {N_RANDOM_SEARCH}")

# Create directories
os.makedirs('results/mtl', exist_ok=True)
os.makedirs('results/models/mtl', exist_ok=True)

# Set seeds
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)

# --- COMBINED SCORE CALCULATION ---

def calculate_corrected_combined_score(test_r2, test_c_index, method='normalized_improvement'):
    """Calculate corrected combined score"""
    
    if method == 'normalized_improvement':
        r2_baseline = BASELINE_PERFORMANCE['regression_r2_baseline']  # 0.0
        c_index_baseline = BASELINE_PERFORMANCE['survival_c_index_baseline']  # 0.5
        
        # Calculate relative improvement
        r2_improvement = (test_r2 - r2_baseline) / (1.0 - r2_baseline)
        c_index_improvement = (test_c_index - c_index_baseline) / (1.0 - c_index_baseline)
        
        # Handle negative values
        r2_improvement = max(0, r2_improvement)
        c_index_improvement = max(0, c_index_improvement)
        
        # Equal weight average
        combined_score = 0.5 * r2_improvement + 0.5 * c_index_improvement
        
        return combined_score, r2_improvement, c_index_improvement
    
    else:
        # Simple method
        return test_r2 + test_c_index, test_r2, test_c_index

# --- PARAMETER GENERATOR ---

def generate_random_params():
    """Generate random hyperparameter combination"""
    return {
        'learning_rate': random.choice(HYPERPARAMETER_RANGES['learning_rate']),
        'batch_size': random.choice(HYPERPARAMETER_RANGES['batch_size']),
        'hidden_dim': random.choice(HYPERPARAMETER_RANGES['hidden_dim']),
        'alpha_regression': random.choice(HYPERPARAMETER_RANGES['alpha_regression']),
        'alpha_survival': random.choice(HYPERPARAMETER_RANGES['alpha_survival']),
        'l2_reg': random.choice(HYPERPARAMETER_RANGES['l2_reg']),
        'dropout_rate': random.choice(HYPERPARAMETER_RANGES['dropout_rate']),
        'n_layers': random.choice(HYPERPARAMETER_RANGES['n_layers']),
        'activation': random.choice(HYPERPARAMETER_RANGES['activation']),
        'batch_norm': random.choice(HYPERPARAMETER_RANGES['batch_norm']),
        'early_stopping_patience': random.choice(HYPERPARAMETER_RANGES['early_stopping_patience'])
    }

# --- DATA PREPROCESSING ---

class CompleteMTLDataPreprocessor:
    def __init__(self):
        self.scaler = StandardScaler()
        self.feature_names = None
        self.fitted = False
        
    def prepare_and_split_data(self, df, test_size=0.2, random_state=42):
        """Complete data processing with train/test split preservation"""
        try:
            # 1. Get feature list
            feature_list = []
            for modality in ALL_MODALITIES:
                if modality in modality_groups:
                    feature_list.extend(modality_groups[modality])
            
            available_features = [f for f in feature_list if f in df.columns]
            
            # 2. Create complete dataset
            required_cols = available_features + [REGRESSION_TARGET, SURVIVAL_TIME_TARGET, SURVIVAL_EVENT_TARGET]
            complete_data = df[required_cols].copy()
            
            # 3. Convert targets to numeric
            for col in [REGRESSION_TARGET, SURVIVAL_TIME_TARGET, SURVIVAL_EVENT_TARGET]:
                complete_data[col] = pd.to_numeric(complete_data[col], errors='coerce')
            
            # 4. Remove NaN targets
            complete_data = complete_data.dropna(subset=[REGRESSION_TARGET, SURVIVAL_TIME_TARGET, SURVIVAL_EVENT_TARGET])
            
            if len(complete_data) < 100:
                print(f"   ❌ Insufficient samples: {len(complete_data)}")
                return None
            
            # 5. Stratified split
            try:
                y_quartiles = pd.qcut(complete_data[REGRESSION_TARGET], q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'], duplicates='drop')
                train_idx, test_idx = train_test_split(
                    range(len(complete_data)), 
                    test_size=test_size, 
                    random_state=random_state, 
                    stratify=y_quartiles
                )
            except:
                train_idx, test_idx = train_test_split(
                    range(len(complete_data)), 
                    test_size=test_size, 
                    random_state=random_state
                )
            
            # 6. Split data
            train_data = complete_data.iloc[train_idx].copy()
            test_data = complete_data.iloc[test_idx].copy()
            
            # 7. Process features separately
            X_train = self._process_features(train_data[available_features], fit_scaler=True)
            X_test = self._process_features(test_data[available_features], fit_scaler=False)
            
            # 8. Return complete data splits
            return {
                'X_train': X_train,
                'X_test': X_test,
                'y_reg_train': train_data[REGRESSION_TARGET],
                'y_reg_test': test_data[REGRESSION_TARGET],
                'y_time_train': train_data[SURVIVAL_TIME_TARGET],
                'y_time_test': test_data[SURVIVAL_TIME_TARGET],
                'y_event_train': train_data[SURVIVAL_EVENT_TARGET],
                'y_event_test': test_data[SURVIVAL_EVENT_TARGET],
                'feature_names': available_features,
                'train_indices': train_idx,
                'test_indices': test_idx,
                'complete_data': complete_data
            }
            
        except Exception as e:
            print(f"   ❌ Data preparation error: {str(e)}")
            return None
    
    def _process_features(self, X, fit_scaler=True):
        """Process features: boolean conversion + standardization"""
        X_processed = X.copy()
        
        # Boolean conversion
        for col in X_processed.columns:
            if X_processed[col].dtype == 'bool':
                X_processed[col] = X_processed[col].astype(int)
            elif X_processed[col].dtype == 'object':
                X_processed[col] = pd.to_numeric(X_processed[col], errors='coerce')
        
        X_processed = X_processed.fillna(0)
        
        # Standardization
        features_to_scale = [f for f in features_to_standardize if f in X_processed.columns]
        
        if features_to_scale:
            if fit_scaler:
                X_processed[features_to_scale] = self.scaler.fit_transform(X_processed[features_to_scale])
                self.fitted = True
            elif self.fitted:
                X_processed[features_to_scale] = self.scaler.transform(X_processed[features_to_scale])
        
        return X_processed.astype(np.float32)

# --- PYTORCH COMPONENTS ---

class MTLDataset(Dataset):
    def __init__(self, X, y_regression, y_time, y_event):
        # Ensure data is numeric
        X_processed = X.copy()
        for col in X_processed.columns:
            if X_processed[col].dtype == 'bool':
                X_processed[col] = X_processed[col].astype(int)
            elif X_processed[col].dtype == 'object':
                X_processed[col] = pd.to_numeric(X_processed[col], errors='coerce')
        
        X_processed = X_processed.fillna(0).astype(np.float32)
        
        self.X = torch.tensor(X_processed.values, dtype=torch.float32)
        self.y_regression = torch.tensor(y_regression.values.astype(np.float32), dtype=torch.float32).view(-1, 1)
        self.y_time = torch.tensor(y_time.values.astype(np.float32), dtype=torch.float32).view(-1, 1)
        self.y_event = torch.tensor(y_event.values.astype(np.float32), dtype=torch.float32).view(-1, 1)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y_regression[idx], self.y_time[idx], self.y_event[idx]

class CompleteMTLNet(nn.Module):
    def __init__(self, n_features, params):
        super(CompleteMTLNet, self).__init__()
        
        hidden_dim = params['hidden_dim']
        n_layers = params['n_layers']
        dropout_rate = params['dropout_rate']
        batch_norm = params['batch_norm']
        
        self.activation_fn = nn.ReLU()
        
        # Shared encoder
        layers = []
        input_dim = n_features
        
        for i in range(n_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            if batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(self.activation_fn)
            layers.append(nn.Dropout(dropout_rate))
            input_dim = hidden_dim
        
        self.shared_encoder = nn.Sequential(*layers)
        
        # Task-specific heads
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            self.activation_fn,
            nn.Dropout(dropout_rate / 2),
            nn.Linear(hidden_dim // 2, 1)
        )
        
        self.survival_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            self.activation_fn,
            nn.Dropout(dropout_rate / 2),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x):
        shared_features = self.shared_encoder(x)
        regression_output = self.regression_head(shared_features)
        survival_output = self.survival_head(shared_features)
        return regression_output, survival_output

def enhanced_cox_loss(log_hazard, durations, events, eps=1e-8):
    """Enhanced Cox loss with numerical stability"""
    if torch.sum(events) == 0:
        return torch.tensor(0.0, requires_grad=True, device=log_hazard.device)
    
    sorted_indices = torch.argsort(durations.view(-1), descending=True)
    log_hazard_sorted = log_hazard.view(-1)[sorted_indices]
    events_sorted = events.view(-1)[sorted_indices]
    
    exp_hazard = torch.exp(torch.clamp(log_hazard_sorted, -20, 20))
    cumsum_exp = torch.cumsum(exp_hazard, dim=0)
    log_risk_set = torch.log(cumsum_exp + eps)
    
    event_mask = events_sorted == 1
    if torch.sum(event_mask) == 0:
        return torch.tensor(0.0, requires_grad=True, device=log_hazard.device)
    
    pll = torch.sum(log_hazard_sorted[event_mask] - log_risk_set[event_mask])
    return -pll

def compute_l2_regularization(model, l2_reg):
    """Compute L2 regularization term"""
    if l2_reg == 0:
        return torch.tensor(0.0, device=next(model.parameters()).device)
    
    l2_loss = torch.tensor(0.0, device=next(model.parameters()).device)
    for param in model.parameters():
        l2_loss += torch.norm(param, p=2) ** 2
    return l2_reg * l2_loss

def compute_balanced_mtl_loss(reg_pred, surv_pred, reg_target, time_target, event_target, 
                             alpha_regression, alpha_survival, model, l2_reg):
    """Balanced MTL loss function"""
    # Regression loss
    mse_loss = nn.MSELoss()(reg_pred, reg_target)
    
    # Survival loss
    cox_loss = enhanced_cox_loss(surv_pred, time_target, event_target)
    
    # L2 regularization
    l2_loss = compute_l2_regularization(model, l2_reg)
    
    # Dynamic balance
    if not torch.isnan(mse_loss) and not torch.isnan(cox_loss) and mse_loss > 0 and cox_loss > 0:
        mse_scale = mse_loss.detach()
        cox_scale = cox_loss.detach()
        
        if mse_scale > 0 and cox_scale > 0:
            scale_factor = mse_scale / cox_scale
            adjusted_alpha_survival = alpha_survival * max(1.0, scale_factor / 10.0)
            total_loss = alpha_regression * mse_loss + adjusted_alpha_survival * cox_loss + l2_loss
        else:
            total_loss = alpha_regression * mse_loss + alpha_survival * cox_loss + l2_loss
    else:
        total_loss = alpha_regression * mse_loss + alpha_survival * cox_loss + l2_loss
    
    return total_loss, mse_loss, cox_loss, l2_loss

class EarlyStopping:
    def __init__(self, patience=50, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        
    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

print("✅ Complete components defined!")

# --- COMPLETE MTL PIPELINE ---

class CompleteMTLPipeline:
    def __init__(self, params):
        self.params = params
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preprocessor = CompleteMTLDataPreprocessor()
        
    def train_and_evaluate(self, df):
        """Train and evaluate complete MTL model with data preservation"""
        # Use complete data processing
        data = self.preprocessor.prepare_and_split_data(df)
        if data is None:
            return None
        
        try:
            # Create datasets
            train_dataset = MTLDataset(
                data['X_train'], data['y_reg_train'], 
                data['y_time_train'], data['y_event_train']
            )
            
            train_loader = DataLoader(
                train_dataset, 
                batch_size=min(self.params['batch_size'], len(train_dataset)), 
                shuffle=True,
                drop_last=False
            )
            
            # Initialize model
            model = CompleteMTLNet(
                n_features=train_dataset.X.shape[1], 
                params=self.params
            ).to(self.device)
            
            optimizer = torch.optim.Adam(
                model.parameters(), 
                lr=self.params['learning_rate']
            )
            
            early_stopping = EarlyStopping(patience=self.params['early_stopping_patience'])
            
            # Training loop
            model.train()
            train_losses = []
            
            for epoch in range(EPOCHS):
                epoch_losses = []
                epoch_mse_losses = []
                epoch_cox_losses = []
                
                for X_batch, y_reg_batch, y_time_batch, y_event_batch in train_loader:
                    X_batch = X_batch.to(self.device)
                    y_reg_batch = y_reg_batch.to(self.device)
                    y_time_batch = y_time_batch.to(self.device)
                    y_event_batch = y_event_batch.to(self.device)
                    
                    optimizer.zero_grad()
                    
                    reg_pred, surv_pred = model(X_batch)
                    
                    # Compute balanced MTL loss
                    total_loss, mse_loss, cox_loss, l2_loss = compute_balanced_mtl_loss(
                        reg_pred, surv_pred, y_reg_batch, y_time_batch, y_event_batch,
                        self.params['alpha_regression'], self.params['alpha_survival'], 
                        model, self.params['l2_reg']
                    )
                    
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    
                    epoch_losses.append(total_loss.item())
                    epoch_mse_losses.append(mse_loss.item())
                    epoch_cox_losses.append(cox_loss.item())
                
                if epoch_losses:
                    avg_epoch_loss = np.mean(epoch_losses)
                    train_losses.append(avg_epoch_loss)
                    
                    if early_stopping(avg_epoch_loss):
                        break
            
            # Evaluation
            model.eval()
            with torch.no_grad():
                X_test_processed = data['X_test'].copy()
                
                for col in X_test_processed.columns:
                    if X_test_processed[col].dtype == 'bool':
                        X_test_processed[col] = X_test_processed[col].astype(int)
                    elif X_test_processed[col].dtype == 'object':
                        X_test_processed[col] = pd.to_numeric(X_test_processed[col], errors='coerce')
                
                X_test_processed = X_test_processed.fillna(0).astype(np.float32)
                X_test_tensor = torch.tensor(X_test_processed.values, dtype=torch.float32).to(self.device)
                
                reg_pred_test, surv_pred_test = model(X_test_tensor)
                
                # Regression metrics
                reg_pred_test_np = reg_pred_test.cpu().numpy().flatten()
                test_r2 = r2_score(data['y_reg_test'], reg_pred_test_np)
                test_mae = mean_absolute_error(data['y_reg_test'], reg_pred_test_np)
                test_rmse = np.sqrt(mean_squared_error(data['y_reg_test'], reg_pred_test_np))
                
                # Survival metrics
                surv_pred_test_np = surv_pred_test.cpu().numpy().flatten()
                test_c_index = concordance_index(
                    data['y_time_test'], 
                    -surv_pred_test_np,
                    data['y_event_test']
                )
            
            # Calculate corrected combined score
            combined_score, r2_normalized, c_index_normalized = calculate_corrected_combined_score(
                test_r2, test_c_index, method='normalized_improvement'
            )
            
            # Calculate improvements
            r2_improvement = test_r2 - BASELINE_PERFORMANCE['regression_r2']
            c_index_improvement = test_c_index - BASELINE_PERFORMANCE['survival_c_index']
            
            return {
                'n_features': len(data['feature_names']),
                'n_train': len(data['X_train']),
                'n_test': len(data['X_test']),
                'test_r2': test_r2,
                'test_mae': test_mae,
                'test_rmse': test_rmse,
                'test_c_index': test_c_index,
                'r2_improvement': r2_improvement,
                'c_index_improvement': c_index_improvement,
                'combined_score': combined_score,
                'combined_score_simple': test_r2 + test_c_index,
                'r2_normalized': r2_normalized,
                'c_index_normalized': c_index_normalized,
                'epochs_trained': len(train_losses),
                'final_train_loss': train_losses[-1] if train_losses else np.nan,
                'avg_mse_loss': np.mean(epoch_mse_losses) if epoch_mse_losses else np.nan,
                'avg_cox_loss': np.mean(epoch_cox_losses) if epoch_cox_losses else np.nan,
                'model': model,
                'scaler': self.preprocessor.scaler,
                'feature_names': data['feature_names'],
                # 🔥 KEY: Preserve training data for accurate evaluation
                'training_data': {
                    'X_train': data['X_train'],
                    'X_test': data['X_test'],
                    'y_reg_train': data['y_reg_train'],
                    'y_reg_test': data['y_reg_test'],
                    'y_time_train': data['y_time_train'],
                    'y_time_test': data['y_time_test'],
                    'y_event_train': data['y_event_train'],
                    'y_event_test': data['y_event_test'],
                    'train_indices': data['train_indices'],
                    'test_indices': data['test_indices'],
                    'reg_predictions': reg_pred_test_np,
                    'surv_predictions': surv_pred_test_np
                },
                **self.params
            }
            
        except Exception as e:
            print(f"      Training error: {str(e)}")
            return None

# --- MAIN EXECUTION ---

def run_complete_mtl_experiment(df):
    """Run complete MTL experiment with data preservation"""
    
    print("🚀 STARTING COMPLETE MTL EXPERIMENT")
    print("=" * 80)
    
    all_results = []
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    print(f"📊 Experiments: {N_RANDOM_SEARCH}")
    print(f"📅 Epochs per experiment: {EPOCHS}")
    
    print(f"\n🎯 BASELINE COMPARISON:")
    print(f"   • Regression R²: {BASELINE_PERFORMANCE['regression_r2']:.3f}")
    print(f"   • Survival C-index: {BASELINE_PERFORMANCE['survival_c_index']:.3f}")
    
    print(f"\n🔧 COMPLETE MTL FEATURES:")
    print(f"   • Training data preservation")
    print(f"   • Corrected data processing")
    print(f"   • Balanced loss function")
    print(f"   • Normalized combined score")
    
    # Track best results
    best_combined_score = -np.inf
    best_result = None
    successful_experiments = 0
    
    for experiment_id in range(1, N_RANDOM_SEARCH + 1):
        print(f"\n--- Experiment {experiment_id}/{N_RANDOM_SEARCH} ---")
        
        params = generate_random_params()
        print(f"Params: lr={params['learning_rate']}, bs={params['batch_size']}, "
              f"hd={params['hidden_dim']}, αr={params['alpha_regression']}, αs={params['alpha_survival']}")
        
        try:
            pipeline = CompleteMTLPipeline(params)
            result = pipeline.train_and_evaluate(df)
            
            if result is not None:
                successful_experiments += 1
                
                result.update({
                    'timestamp': timestamp,
                    'experiment_id': experiment_id
                })
                
                # Display key metrics
                print(f"📊 R²: {result['test_r2']:.4f} | C-idx: {result['test_c_index']:.4f}")
                print(f"📊 R² Δ: {result['r2_improvement']:+.4f} | C-idx Δ: {result['c_index_improvement']:+.4f}")
                print(f"📊 Combined (norm): {result['combined_score']:.4f} | Combined (simple): {result['combined_score_simple']:.4f}")
                print(f"📊 Epochs: {result['epochs_trained']}")
                
                all_results.append(result)
                
                # Use normalized combined score for best model selection
                if result['combined_score'] > best_combined_score:
                    best_combined_score = result['combined_score']
                    best_result = result.copy()
                    
                    # 🔥 Save best model with complete training data
                    model_path = f"results/models/mtl/best_complete_mtl_{timestamp}.pt"
                    torch.save({
                        'model_state_dict': result['model'].state_dict(),
                        'params': params,
                        'metrics': {
                            'test_r2': result['test_r2'],
                            'test_c_index': result['test_c_index'],
                            'combined_score': result['combined_score'],
                            'combined_score_simple': result['combined_score_simple'],
                            'r2_improvement': result['r2_improvement'],
                            'c_index_improvement': result['c_index_improvement']
                        },
                        'feature_names': result['feature_names'],
                        # 🔥 KEY: Save complete training data for accurate evaluation
                        'training_data': result['training_data'],
                        'scaler_params': {
                            'scaler_mean': result['scaler'].mean_ if hasattr(result['scaler'], 'mean_') else None,
                            'scaler_scale': result['scaler'].scale_ if hasattr(result['scaler'], 'scale_') else None,
                            'features_to_standardize': [f for f in features_to_standardize if f in result['feature_names']]
                        },
                        'architecture_config': {
                            'input_dim': len(result['feature_names']),
                            'hidden_dim': params['hidden_dim'],
                            'dropout_rate': params['dropout_rate'],
                            'n_layers': params['n_layers'],
                            'batch_norm': params['batch_norm']
                        }
                    }, model_path)
                    print(f"💾 New best model + training data saved! {timestamp}")
            
            else:
                print("⚠️ Training failed")
            
        except Exception as e:
            print(f"❌ Error: {str(e)}")
            continue
    
    print(f"\n✅ Completed: {successful_experiments}/{N_RANDOM_SEARCH} successful")
    return all_results, best_result, timestamp

# Execute experiment
if __name__ == '__main__':
    # Load data
    data_path = './data/'
    curated_mri_vif_df = pd.read_csv(data_path + '3_baseline_vif.csv')
    curated_mri_vif_df = curated_mri_vif_df.dropna()
    
    print("🎯 Starting Complete MTL Experiment")
    print(f"Data shape: {curated_mri_vif_df.shape}")
    
    # Run experiment
    all_results, best_result, timestamp = run_complete_mtl_experiment(curated_mri_vif_df)
    
    # --- RESULTS ANALYSIS ---
    if all_results:
        results_df = pd.DataFrame(all_results)
        
        print("\n" + "="*80)
        print("🏆 COMPLETE MTL EXPERIMENT RESULTS")
        print("="*80)
        
        # Best model
        if best_result:
            print(f"\n🥇 BEST COMPLETE MTL MODEL:")
            print(f"   • Test R²: {best_result['test_r2']:.4f} (Δ = {best_result['r2_improvement']:+.4f})")
            print(f"   • Test C-index: {best_result['test_c_index']:.4f} (Δ = {best_result['c_index_improvement']:+.4f})")
            print(f"   • Combined Score (normalized): {best_result['combined_score']:.4f}")
            print(f"   • Combined Score (simple): {best_result['combined_score_simple']:.4f}")
            print(f"   • R² normalized: {best_result['r2_normalized']:.4f}")
            print(f"   • C-index normalized: {best_result['c_index_normalized']:.4f}")
            print(f"   • Test samples: {best_result['n_test']}")
            print(f"   • Hyperparameters:")
            print(f"     - α_regression: {best_result['alpha_regression']}")
            print(f"     - α_survival: {best_result['alpha_survival']}")
            print(f"     - Learning rate: {best_result['learning_rate']}")
            print(f"     - Hidden dim: {best_result['hidden_dim']}")
            print(f"     - L2 reg: {best_result['l2_reg']}")
        
        # Summary statistics
        print(f"\n📊 EXPERIMENT SUMMARY:")
        print(f"   • Mean R²: {results_df['test_r2'].mean():.4f} ± {results_df['test_r2'].std():.4f}")
        print(f"   • Mean C-index: {results_df['test_c_index'].mean():.4f} ± {results_df['test_c_index'].std():.4f}")
        print(f"   • Mean Combined (norm): {results_df['combined_score'].mean():.4f} ± {results_df['combined_score'].std():.4f}")
        print(f"   • Mean Combined (simple): {results_df['combined_score_simple'].mean():.4f} ± {results_df['combined_score_simple'].std():.4f}")
        
        # Improvement analysis
        r2_improvements = results_df['r2_improvement'] > 0
        c_index_improvements = results_df['c_index_improvement'] > 0
        both_improvements = r2_improvements & c_index_improvements
        
        print(f"\n📈 IMPROVEMENT ANALYSIS:")
        print(f"   • R² improvements: {r2_improvements.sum()}/{len(results_df)} ({100*r2_improvements.mean():.1f}%)")
        print(f"   • C-index improvements: {c_index_improvements.sum()}/{len(results_df)} ({100*c_index_improvements.mean():.1f}%)")
        print(f"   • Both improvements: {both_improvements.sum()}/{len(results_df)} ({100*both_improvements.mean():.1f}%)")
        
        # Top 10 models
        top_10 = results_df.nlargest(10, 'combined_score')[
            ['test_r2', 'test_c_index', 'combined_score', 'combined_score_simple', 
             'r2_improvement', 'c_index_improvement', 'alpha_regression', 'alpha_survival', 
             'learning_rate', 'hidden_dim', 'l2_reg']
        ]
        
        print(f"\n🏆 TOP 10 MODELS (by normalized combined score):")
        print(top_10.round(4).to_string())
        
        # Save results
        results_file = f'results/mtl/complete_mtl_results_{timestamp}.csv'
        # Remove training_data from DataFrame for CSV export
        results_df_clean = results_df.drop(columns=['model', 'scaler', 'training_data'], errors='ignore')
        results_df_clean.to_csv(results_file, index=False)
        
        # Final conclusion
        print(f"\n🎯 FINAL CONCLUSION:")
        if best_result:
            if best_result['r2_improvement'] > 0 and best_result['c_index_improvement'] > 0:
                print("✅ COMPLETE MTL OUTPERFORMS on BOTH tasks!")
            elif best_result['r2_improvement'] > 0:
                print("⚠️ COMPLETE MTL OUTPERFORMS on regression only")
            elif best_result['c_index_improvement'] > 0:
                print("⚠️ COMPLETE MTL OUTPERFORMS on survival only")
            else:
                print("❌ COMPLETE MTL does NOT outperform baselines")
        
        print(f"\n💾 Results saved: {results_file}")
        
        # Training data verification
        if best_result and 'training_data' in best_result:
            training_data = best_result['training_data']
            print(f"\n📊 TRAINING DATA PRESERVED:")
            print(f"   • Training samples: {len(training_data['X_train'])}")
            print(f"   • Test samples: {len(training_data['X_test'])}")
            print(f"   • Features: {len(best_result['feature_names'])}")
            print(f"   • Predictions available: regression & survival")
        
    print("\n" + "🎉" * 30)
    print("COMPLETE MTL EXPERIMENT FINISHED")
    print("🎉" * 30)

🔧 RQ4 CONFIGURATION: MTL WITH CORRECTED COMBINED SCORE
🎯 Targets: moca_slope_iqr_cleaned + time_to_hy3_plus
📊 Total features: 53
⚙️ Experiments: 100
✅ Fixed components defined!
🎯 Starting Fixed MTL Experiment
Data shape: (406, 64)
🚀 STARTING FIXED MTL EXPERIMENT
📊 Experiments: 100
📅 Epochs per experiment: 200

🎯 BASELINE COMPARISON:
   • Regression R²: 0.210
   • Survival C-index: 0.830

🔧 FIXED MTL FEATURES:
   • Corrected data processing (no leakage)
   • Balanced loss function
   • Normalized combined score

--- Experiment 1/100 ---
Params: lr=0.0001, bs=32, hd=256, αr=0.7, αs=0.5
📊 R²: 0.0781 | C-idx: 0.7483
📊 R² Δ: -0.1319 | C-idx Δ: -0.0817
📊 Combined (norm): 0.2873 | Combined (simple): 0.8264
📊 Epochs: 200
💾 New best model saved!

--- Experiment 2/100 ---
Params: lr=0.0001, bs=32, hd=64, αr=0.5, αs=1.5
📊 R²: 0.0387 | C-idx: 0.8160
📊 R² Δ: -0.1713 | C-idx Δ: -0.0140
📊 Combined (norm): 0.3354 | Combined (simple): 0.8547
📊 Epochs: 200
💾 New best model saved!

--- Experiment 3/100 -

In [None]:
def inspect_model_file(model_path):
    import torch
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        print(f"🔍 Model file contents:")
        print(f"Keys in checkpoint: {list(checkpoint.keys())}")
        
        for key, value in checkpoint.items():
            if isinstance(value, dict):
                print(f"  {key}: {list(value.keys())}")
            else:
                print(f"  {key}: {type(value)}")
                
        return checkpoint
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None
    
model_path = "./results/models/mtl/best_final_mtl.pt"
checkpoint = inspect_model_file(model_path)

❌ Error loading model: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to 

In [None]:
def diagnose_feature_mismatch(model_path, df):
    """
    Detailed diagnosis of feature mismatch issues
    """
    import torch
    import pandas as pd
    import numpy as np
    
    print("🔍 DIAGNOSING FEATURE MISMATCH")
    print("=" * 60)
    
    # 1. Load model file
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        expected_input_dim = checkpoint['architecture_config']['input_dim']
        print(f"✅ Model expects: {expected_input_dim} features")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return
    
    # 2. Get theoretical ALL_MODALITIES feature list
    print(f"\n📊 THEORETICAL FEATURES (ALL_MODALITIES):")
    
    all_theoretical_features = []
    for modality in ALL_MODALITIES:
        if modality in modality_groups:
            features = modality_groups[modality]
            all_theoretical_features.extend(features)
            print(f"   • {modality}: {len(features)} features")
            print(f"     {features[:5]}..." if len(features) > 5 else f"     {features}")
    
    print(f"\n   📋 Total theoretical features: {len(all_theoretical_features)}")
    
    # 3. Check features actually present in dataframe
    print(f"\n📈 ACTUAL FEATURES IN DATAFRAME:")
    available_features = [f for f in all_theoretical_features if f in df.columns]
    missing_features = [f for f in all_theoretical_features if f not in df.columns]
    
    print(f"   • Available in df: {len(available_features)} features")
    print(f"   • Missing from df: {len(missing_features)} features")
    
    # 4. Display missing features
    if missing_features:
        print(f"\n❌ MISSING FEATURES:")
        for modality in ALL_MODALITIES:
            modality_features = modality_groups[modality]
            missing_in_modality = [f for f in modality_features if f in missing_features]
            if missing_in_modality:
                print(f"   • {modality} missing: {missing_in_modality}")
    
    # 5. Check data type issues
    print(f"\n🔍 FEATURE TYPE ANALYSIS:")
    numeric_features = df[available_features].select_dtypes(include=[np.number]).columns.tolist()
    non_numeric_features = df[available_features].select_dtypes(exclude=[np.number]).columns.tolist()
    
    print(f"   • Numeric features: {len(numeric_features)}")
    print(f"   • Non-numeric features: {len(non_numeric_features)}")
    
    if non_numeric_features:
        print(f"   • Non-numeric: {non_numeric_features}")
        
        # Check data types and samples of non-numeric features
        for col in non_numeric_features[:5]:  # Only show first 5
            print(f"     - {col}: dtype={df[col].dtype}, unique={df[col].nunique()}")
            print(f"       Sample values: {df[col].unique()[:3].tolist()}")
    
    # 6. Calculate final available feature count
    final_available = len(numeric_features)
    print(f"\n📊 FINAL SUMMARY:")
    print(f"   • Model expects: {expected_input_dim} features")
    print(f"   • Theoretical total: {len(all_theoretical_features)} features")
    print(f"   • Actually available: {len(available_features)} features")
    print(f"   • Numeric available: {final_available} features")
    print(f"   • Shortfall: {expected_input_dim - final_available} features")
    
    # 7. Suggest solutions
    print(f"\n💡 RECOMMENDATIONS:")
    if missing_features:
        print(f"   1. Add missing features to dataframe or remove from model training")
    if non_numeric_features:
        print(f"   2. Convert non-numeric features to numeric")
    if final_available < expected_input_dim:
        print(f"   3. Use feature padding (add zeros) to match expected dimensions")
        print(f"   4. Or retrain model with available features only")
    
    # 8. Return detailed information
    return {
        'expected_features': expected_input_dim,
        'theoretical_features': all_theoretical_features,
        'available_features': available_features,
        'missing_features': missing_features,
        'numeric_features': numeric_features,
        'non_numeric_features': non_numeric_features,
        'final_count': final_available,
        'shortfall': expected_input_dim - final_available
    }

# Run diagnosis
model_path = "results/models/rq4_mtl_experiment_20250909_161400.pth"
diagnosis = diagnose_feature_mismatch(model_path, curated_mri_vif_df)

🔍 DIAGNOSING FEATURE MISMATCH
✅ Model expects: 53 features

📊 THEORETICAL FEATURES (ALL_MODALITIES):
   • demographic_clinical: 31 features
     ['age', 'fampd', 'race_black', 'race_asian', 'race_other']...
   • genetic: 4 features
     ['apoe_e4', 'subgroup_gba', 'subgroup_lrrk2', 'subgroup_prkn']
   • biomarkers: 4 features
     ['urate', 'csfsaa_positive_lbd_like', 'csfsaa_positive_msa_like', 'csfsaa_inconclusive']
   • datscan: 4 features
     ['mia_caudate_l', 'mia_caudate_r', 'mia_putamen_l', 'mia_putamen_r']
   • mri: 10 features
     ['mri_pc1', 'mri_pc2', 'mri_pc3', 'mri_pc4', 'mri_pc5']...

   📋 Total theoretical features: 53

📈 ACTUAL FEATURES IN DATAFRAME:
   • Available in df: 53 features
   • Missing from df: 0 features

🔍 FEATURE TYPE ANALYSIS:
   • Numeric features: 42
   • Non-numeric features: 11
   • Non-numeric: ['race_black', 'race_asian', 'race_other', 'domside_left', 'domside_symmetric', 'subgroup_gba', 'subgroup_lrrk2', 'subgroup_prkn', 'csfsaa_positive_lbd_like