# In [21]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from sklearn.preprocessing import StandardScaler
import umap.umap_ as umap
from tqdm.auto import tqdm
import os
import torch
from collections import Counter

# Set the random seed for reproducibility
# (seeds both NumPy's global RNG and PyTorch's RNG with the same fixed value)
np.random.seed(42)
torch.manual_seed(42)

# Constants
# Results directory; it is created as a side effect of running this module.
# exist_ok=True makes re-running safe when the directory already exists.
OUTPUT_DIR = "multimodal_analysis_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# "cuda" when PyTorch reports an available CUDA device, otherwise "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Selected cancer types
# Matched against the 'project_id' column when filtering clinical data; the
# first entry also serves as the default label when a cancer type is unknown.
TARGET_CANCER_TYPES = ['TCGA-KIRC', 'TCGA-OV', 'TCGA-BRCA']

def get_id_column(df):
    """
    Return the name of the patient-identifier column of ``df``.

    Candidate names are tried in a fixed priority order and the first one
    present among ``df.columns`` wins. If none of them is present, the
    available columns are printed and a ``ValueError`` is raised.
    """
    candidates = ('case_submitter_id', 'PatientID', 'patient_id', 'ID')
    match = next((name for name in candidates if name in df.columns), None)
    if match is not None:
        return match

    # Nothing matched: show what the frame offers before failing
    print(f"Available columns: {df.columns.tolist()}")
    raise ValueError("No patient ID column found in dataframe.")

def process_clinical_embeddings():
    """
    Load the clinical embeddings and decode them into NumPy arrays.

    The "clinical" configuration (split "gatortron") of the Lab-Rasool/TCGA
    dataset is loaded, restricted to ``TARGET_CANCER_TYPES`` when a
    'project_id' column exists, and rows with a null embedding are dropped.
    Each embedding buffer is read as float32 and, when an 'embedding_shape'
    column exists, reshaped to that shape.

    Rows that fail to decode are reported and replaced by a zero array shaped
    like the first successfully decoded embedding (or 1024 float32 zeros when
    nothing has been decoded yet), so the output stays row-aligned with the
    filtered input.

    Returns:
        DataFrame with columns 'patient_id', 'cancer_type', 'modality'
        (always 'clinical') and 'embeddings'. 'cancer_type' is empty (None)
        when the dataset has no 'project_id' column.
    """
    print("Loading clinical data...")
    clinical_data = load_dataset("Lab-Rasool/TCGA", "clinical", split="gatortron").to_pandas()
    print(f"Loaded {len(clinical_data)} total clinical samples")

    # Filter by cancer type
    has_project_id = 'project_id' in clinical_data.columns
    if has_project_id:
        clinical_data = clinical_data[clinical_data['project_id'].isin(TARGET_CANCER_TYPES)]
        print(f"After filtering for {TARGET_CANCER_TYPES}, found {len(clinical_data)} clinical samples")
    else:
        print("No 'project_id' column in clinical data; cancer types cannot be filtered or assigned")

    # Remove rows with null embeddings
    clinical_data = clinical_data.dropna(subset=["embedding"])
    print(f"After removing null embeddings, {len(clinical_data)} clinical samples remain")

    # Get ID column
    clinical_id_col = get_id_column(clinical_data)

    # Decide once whether buffers must be reshaped
    use_shape = "embedding_shape" in clinical_data.columns
    if use_shape:
        print("Using embedding_shape for clinical embeddings")
    else:
        print("No embedding_shape for clinical embeddings, using raw buffers")

    # Process embeddings
    processed_embeddings = []
    for idx, row in tqdm(clinical_data.iterrows(), desc="Processing clinical embeddings", total=len(clinical_data)):
        try:
            emb = np.frombuffer(row["embedding"], dtype=np.float32)
            if use_shape:
                emb = emb.reshape(row["embedding_shape"])
            processed_embeddings.append(emb)
        except Exception as e:
            # Best effort: keep one entry per row so the list stays aligned
            # with clinical_data, substituting a zero placeholder.
            print(f"Error processing clinical embedding at index {idx}: {e}")
            if processed_embeddings:
                processed_embeddings.append(np.zeros_like(processed_embeddings[0]))
            else:
                processed_embeddings.append(np.zeros(1024, dtype=np.float32))

    # Create DataFrame with patient IDs and embeddings.
    # 'project_id' is optional (see the filter above), so it must not be read
    # unconditionally here.
    clinical_df = pd.DataFrame({
        'patient_id': clinical_data[clinical_id_col],
        'cancer_type': clinical_data['project_id'] if has_project_id else None,
        'modality': 'clinical'
    })
    clinical_df['embeddings'] = processed_embeddings

    return clinical_df

def process_pathology_embeddings(target_patient_ids=None):
    """
    Load the pathology-report embeddings and decode them into NumPy arrays.

    The "pathology_report" configuration (split "gatortron") of the
    Lab-Rasool/TCGA dataset is loaded, optionally restricted to the given
    patients, and rows with a null embedding are dropped. Each embedding
    buffer is read as float32 and, when an 'embedding_shape' column exists,
    reshaped to that shape.

    Rows that fail to decode are reported and replaced by a zero array shaped
    like the first successfully decoded embedding (or 1024 float32 zeros when
    nothing has been decoded yet).

    Args:
        target_patient_ids: Optional collection of patient IDs to keep.
            ``None`` (the default) disables filtering. An empty collection
            keeps no rows.

    Returns:
        DataFrame with columns 'patient_id', 'modality' (always
        'pathology_report') and 'embeddings'.
    """
    print("Loading pathology report data...")
    pathology_data = load_dataset("Lab-Rasool/TCGA", "pathology_report", split="gatortron").to_pandas()
    print(f"Loaded {len(pathology_data)} total pathology samples")

    # Get ID column
    pathology_id_col = get_id_column(pathology_data)

    # Filter by patient IDs if provided. Test against None rather than
    # truthiness: an empty ID set means "no patient qualifies", not
    # "return every sample".
    if target_patient_ids is not None:
        original_count = len(pathology_data)
        pathology_data = pathology_data[pathology_data[pathology_id_col].isin(target_patient_ids)]
        print(f"After filtering by patient IDs, found {len(pathology_data)} pathology samples out of {original_count}")

    # Remove rows with null embeddings
    pathology_data = pathology_data.dropna(subset=["embedding"])
    print(f"After removing null embeddings, {len(pathology_data)} pathology samples remain")

    # Decide once whether buffers must be reshaped
    use_shape = "embedding_shape" in pathology_data.columns
    if use_shape:
        print("Using embedding_shape for pathology embeddings")
    else:
        print("No embedding_shape for pathology embeddings, using raw buffers")

    # Process embeddings
    processed_embeddings = []
    for idx, row in tqdm(pathology_data.iterrows(), desc="Processing pathology embeddings", total=len(pathology_data)):
        try:
            emb = np.frombuffer(row["embedding"], dtype=np.float32)
            if use_shape:
                emb = emb.reshape(row["embedding_shape"])
            processed_embeddings.append(emb)
        except Exception as e:
            # Best effort: keep one entry per row so the list stays aligned
            # with pathology_data, substituting a zero placeholder.
            print(f"Error processing pathology embedding at index {idx}: {e}")
            if processed_embeddings:
                processed_embeddings.append(np.zeros_like(processed_embeddings[0]))
            else:
                processed_embeddings.append(np.zeros(1024, dtype=np.float32))

    # Create DataFrame with patient IDs and embeddings
    pathology_df = pd.DataFrame({
        'patient_id': pathology_data[pathology_id_col],
        'modality': 'pathology_report'
    })
    pathology_df['embeddings'] = processed_embeddings

    return pathology_df

def process_radiology_embeddings(target_patient_ids=None):
    """
    Load the radiology embeddings and reduce each one to a 1-D vector.

    The "radiology" configuration (split "radimagenet") of the
    Lab-Rasool/TCGA dataset is loaded, optionally restricted to the given
    patients, and rows with a null embedding are dropped.

    * With an 'embedding_shape' column, each float32 buffer is reshaped to
      that shape and, if multi-dimensional, averaged over every axis except
      the last.
    * Without it, an embedding stored as a list of buffers is averaged
      element-wise across the list; any other value is read as one float32
      buffer.

    Rows that fail to decode are reported and replaced by a zero array shaped
    like the first successfully decoded embedding (or 1000 float32 zeros when
    nothing has been decoded yet). If the resulting vectors differ in length,
    all are zero-padded or truncated to the most common length.

    Args:
        target_patient_ids: Optional collection of patient IDs to keep.
            ``None`` (the default) disables filtering. An empty collection
            keeps no rows.

    Returns:
        DataFrame with columns 'patient_id', 'modality' (always 'radiology'),
        'cancer_type' (only when the dataset has 'project_id') and
        'embeddings'.
    """
    # Local import: only needed to parse string-encoded shapes safely
    import ast

    print("Loading radiology data...")
    radiology_data = load_dataset("Lab-Rasool/TCGA", "radiology", split="radimagenet").to_pandas()
    print(f"Loaded {len(radiology_data)} total radiology samples")

    # Get ID column
    radiology_id_col = get_id_column(radiology_data)

    # Filter by patient IDs if provided. Test against None rather than
    # truthiness: an empty ID set means "no patient qualifies", not
    # "return every sample".
    if target_patient_ids is not None:
        original_count = len(radiology_data)
        radiology_data = radiology_data[radiology_data[radiology_id_col].isin(target_patient_ids)]
        print(f"After filtering by patient IDs, found {len(radiology_data)} radiology samples out of {original_count}")

    # Remove rows with null embeddings
    radiology_data = radiology_data.dropna(subset=["embedding"])
    print(f"After removing null embeddings, {len(radiology_data)} radiology samples remain")

    # Process embeddings
    processed_embeddings = []

    def append_placeholder():
        # Keep one entry per row so the list stays aligned with radiology_data
        if processed_embeddings:
            processed_embeddings.append(np.zeros_like(processed_embeddings[0]))
        else:
            processed_embeddings.append(np.zeros(1000, dtype=np.float32))

    if "embedding_shape" in radiology_data.columns:
        print("Using embedding_shape for radiology embeddings")
        for idx, row in tqdm(radiology_data.iterrows(), desc="Processing radiology embeddings", total=len(radiology_data)):
            try:
                emb = np.frombuffer(row["embedding"], dtype=np.float32)
                shape = row["embedding_shape"]

                # Sometimes embedding_shape is a string, check and convert if needed.
                # SECURITY: this text comes from a downloaded dataset, so it is
                # parsed with ast.literal_eval (literals only) instead of eval().
                # A string that is not a plain literal now fails here and the
                # row receives the zero placeholder below.
                if isinstance(shape, str):
                    shape = ast.literal_eval(shape)

                reshaped_emb = emb.reshape(shape)

                # If multi-dimensional, collapse to 1D for consistency by
                # averaging over every axis except the last
                if reshaped_emb.ndim > 1:
                    flattened_emb = np.mean(reshaped_emb, axis=tuple(range(reshaped_emb.ndim - 1)))
                else:
                    flattened_emb = reshaped_emb

                processed_embeddings.append(flattened_emb)
            except Exception as e:
                print(f"Error processing radiology embedding at index {idx}: {e}")
                append_placeholder()
    else:
        print("No embedding_shape for radiology embeddings, using raw buffers")
        for idx, row in tqdm(radiology_data.iterrows(), desc="Processing radiology embeddings", total=len(radiology_data)):
            try:
                raw = row["embedding"]
                if isinstance(raw, list):
                    # A list holds several buffers; skip null entries and
                    # average the rest element-wise
                    slice_embs = [np.frombuffer(e, dtype=np.float32) for e in raw if e is not None]
                    if slice_embs:
                        processed_embeddings.append(np.mean(slice_embs, axis=0))
                    else:
                        # Default empty embedding
                        processed_embeddings.append(np.zeros(1000, dtype=np.float32))
                else:
                    # Process single embedding
                    processed_embeddings.append(np.frombuffer(raw, dtype=np.float32))
            except Exception as e:
                print(f"Error processing radiology embedding at index {idx}: {e}")
                append_placeholder()

    # Ensure all embeddings have the same length
    lengths = [len(emb) for emb in processed_embeddings]
    unique_lengths = set(lengths)

    if len(unique_lengths) > 1:
        print(f"Found {len(unique_lengths)} different embedding lengths: {unique_lengths}")
        # Standardize to the most common length
        most_common_length = Counter(lengths).most_common(1)[0][0]
        print(f"Standardizing all embeddings to length {most_common_length}")

        standardized_embeddings = []
        for emb in processed_embeddings:
            if len(emb) < most_common_length:
                # Pad with zeros
                padded = np.zeros(most_common_length, dtype=np.float32)
                padded[:len(emb)] = emb
                standardized_embeddings.append(padded)
            else:
                # Truncate (a no-op for embeddings already at the target length)
                standardized_embeddings.append(emb[:most_common_length])

        processed_embeddings = standardized_embeddings

    # Create DataFrame with patient IDs and embeddings
    radiology_df = pd.DataFrame({
        'patient_id': radiology_data[radiology_id_col],
        'modality': 'radiology'
    })

    # Add cancer_type if available
    if 'project_id' in radiology_data.columns:
        radiology_df['cancer_type'] = radiology_data['project_id']

    radiology_df['embeddings'] = processed_embeddings

    return radiology_df

def process_molecular_embeddings(target_patient_ids=None):
    """
    Load the molecular embeddings and decode them into NumPy arrays.

    The "molecular" configuration (split "senmo") of the Lab-Rasool/TCGA
    dataset is loaded and optionally restricted to the given patients. The
    embedding column is resolved as 'embedding', then 'Embeddings', then the
    first column whose name contains 'embed' (case-insensitive). Rows with a
    null embedding are dropped. Each buffer is read as float32 and, when an
    'embedding_shape' column exists, reshaped to that shape.

    Rows that fail to decode are reported and replaced by a zero array shaped
    like the first successfully decoded embedding (or 48 float32 zeros when
    nothing has been decoded yet).

    Args:
        target_patient_ids: Optional collection of patient IDs to keep.
            ``None`` (the default) disables filtering. An empty collection
            keeps no rows.

    Returns:
        DataFrame with columns 'patient_id', 'modality' (always 'molecular'),
        'cancer_type' (only when the dataset has 'project_id') and
        'embeddings'. When no embedding column can be found, an empty
        DataFrame with columns 'patient_id', 'modality', 'cancer_type' and
        'embeddings' is returned instead.
    """
    print("Loading molecular data...")
    molecular_data = load_dataset("Lab-Rasool/TCGA", "molecular", split="senmo").to_pandas()
    print(f"Loaded {len(molecular_data)} total molecular samples")
    print(f"Molecular data columns: {molecular_data.columns.tolist()}")

    # Get ID column
    molecular_id_col = get_id_column(molecular_data)

    # Filter by patient IDs if provided. Test against None rather than
    # truthiness: an empty ID set means "no patient qualifies", not
    # "return every sample".
    if target_patient_ids is not None:
        original_count = len(molecular_data)
        molecular_data = molecular_data[molecular_data[molecular_id_col].isin(target_patient_ids)]
        print(f"After filtering by patient IDs, found {len(molecular_data)} molecular samples out of {original_count}")

    # Determine embedding column
    if "embedding" in molecular_data.columns:
        embedding_col = "embedding"
    elif "Embeddings" in molecular_data.columns:
        embedding_col = "Embeddings"
    else:
        # Try to find a column with 'embed' in the name
        embed_cols = [col for col in molecular_data.columns if 'embed' in col.lower()]
        if not embed_cols:
            print("No embedding column found in molecular data")
            return pd.DataFrame(columns=['patient_id', 'modality', 'cancer_type', 'embeddings'])
        embedding_col = embed_cols[0]
        print(f"Using {embedding_col} for molecular embeddings")

    # Remove rows with null embeddings
    molecular_data = molecular_data.dropna(subset=[embedding_col])
    print(f"After removing null embeddings, {len(molecular_data)} molecular samples remain")

    # Decide once whether buffers must be reshaped
    use_shape = "embedding_shape" in molecular_data.columns
    if use_shape:
        print("Using embedding_shape for molecular embeddings")
    else:
        print("No embedding_shape for molecular embeddings, using raw buffers")

    # Process embeddings
    processed_embeddings = []
    for idx, row in tqdm(molecular_data.iterrows(), desc="Processing molecular embeddings", total=len(molecular_data)):
        try:
            emb = np.frombuffer(row[embedding_col], dtype=np.float32)
            if use_shape:
                emb = emb.reshape(row["embedding_shape"])
            processed_embeddings.append(emb)
        except Exception as e:
            # Best effort: keep one entry per row so the list stays aligned
            # with molecular_data, substituting a zero placeholder.
            print(f"Error processing molecular embedding at index {idx}: {e}")
            if processed_embeddings:
                processed_embeddings.append(np.zeros_like(processed_embeddings[0]))
            else:
                processed_embeddings.append(np.zeros(48, dtype=np.float32))

    # Create DataFrame with patient IDs and embeddings
    molecular_df = pd.DataFrame({
        'patient_id': molecular_data[molecular_id_col],
        'modality': 'molecular'
    })

    # Add cancer_type if available
    if 'project_id' in molecular_data.columns:
        molecular_df['cancer_type'] = molecular_data['project_id']

    molecular_df['embeddings'] = processed_embeddings

    return molecular_df

def load_multimodal_data():
    """
    Load every modality and make sure each frame carries a cancer type.

    Clinical data is processed first; its patient IDs are then used to filter
    the pathology, radiology and molecular data. Any of those three frames
    that lacks a 'cancer_type' column receives one, looked up per patient
    from the clinical frame and defaulting to the first entry of
    ``TARGET_CANCER_TYPES`` for patients not found there.

    Returns:
        Tuple ``(clinical_df, pathology_df, radiology_df, molecular_df)``.
    """
    # Clinical data defines the patient cohort
    clinical_df = process_clinical_embeddings()

    target_patient_ids = set(clinical_df['patient_id'])
    print(f"Found {len(target_patient_ids)} unique patients in clinical data")

    # Remaining modalities, restricted to the clinical cohort
    pathology_df = process_pathology_embeddings(target_patient_ids)
    radiology_df = process_radiology_embeddings(target_patient_ids)
    molecular_df = process_molecular_embeddings(target_patient_ids)

    # Patient -> cancer type lookup taken from the clinical frame
    cancer_by_patient = dict(zip(clinical_df['patient_id'], clinical_df['cancer_type']))
    fallback_type = TARGET_CANCER_TYPES[0]

    # Back-fill the cancer type wherever a modality did not bring its own
    for modality_df in (pathology_df, radiology_df, molecular_df):
        if 'cancer_type' in modality_df.columns:
            continue
        modality_df['cancer_type'] = modality_df['patient_id'].map(
            lambda pid: cancer_by_patient.get(pid, fallback_type)
        )

    # Report how many samples each modality ended up with
    print("\nFinal summary:")
    summary_rows = (
        ("Clinical", clinical_df),
        ("Pathology", pathology_df),
        ("Radiology", radiology_df),
        ("Molecular", molecular_df),
    )
    for label, frame in summary_rows:
        print(f"{label} data: {len(frame)} samples")

    return clinical_df, pathology_df, radiology_df, molecular_df

def _first_value_by_patient(df, column):
    """Map each patient_id in ``df`` to the value of ``column`` in its first row."""
    lookup = {}
    for patient_id, value in zip(df['patient_id'], df[column]):
        lookup.setdefault(patient_id, value)
    return lookup


def align_patient_data(clinical_df, pathology_df, radiology_df, molecular_df):
    """
    Select a patient cohort and build one row per patient with all modalities.

    The cohort is the first of these sets that contains at least 50 patients:
    patients present in all four modalities; patients with clinical data plus
    at least two other modalities; patients with clinical data plus at least
    one other modality. If none is large enough, all clinical patients are
    used. No upper limit is applied to the cohort size.

    Args:
        clinical_df, pathology_df, radiology_df, molecular_df: Per-modality
            frames with a 'patient_id' and an 'embeddings' column;
            'cancer_type' is read from ``clinical_df`` and, where present,
            from the other frames.

    Returns:
        DataFrame sorted by 'patient_id' with exactly one row per selected
        patient and the columns 'patient_id', 'cancer_type',
        'clinical_embedding', 'pathology_embedding', 'radiology_embedding'
        and 'molecular_embedding'. When a patient has several rows in a
        modality, the first one is used. A modality the patient lacks is
        filled with a zero array shaped like that modality's first embedding
        (or 1024 / 1024 / 1000 / 48 zeros for clinical / pathology /
        radiology / molecular when the modality has no rows at all). Missing
        cancer types fall back to ``TARGET_CANCER_TYPES[0]``.
    """
    print("Aligning patient data across modalities...")

    # Find common patients across modalities
    clinical_patients = set(clinical_df['patient_id'])
    pathology_patients = set(pathology_df['patient_id'])
    radiology_patients = set(radiology_df['patient_id'])
    molecular_patients = set(molecular_df['patient_id'])

    print(f"Clinical patients: {len(clinical_patients)}")
    print(f"Pathology patients: {len(pathology_patients)}")
    print(f"Radiology patients: {len(radiology_patients)}")
    print(f"Molecular patients: {len(molecular_patients)}")

    # Find patients with data in multiple modalities
    clinical_and_pathology = clinical_patients & pathology_patients
    clinical_and_radiology = clinical_patients & radiology_patients
    clinical_and_molecular = clinical_patients & molecular_patients
    print(f"Patients in clinical and pathology: {len(clinical_and_pathology)}")
    print(f"Patients in clinical and radiology: {len(clinical_and_radiology)}")
    print(f"Patients in clinical and molecular: {len(clinical_and_molecular)}")

    # Find patients with data in all modalities
    all_modalities = clinical_and_pathology & radiology_patients & molecular_patients
    print(f"Found {len(all_modalities)} patients with data in all four modalities")

    # Candidate cohorts from strictest to loosest; the first one that is
    # large enough wins.
    min_cohort_size = 50
    at_least_three = (
        (clinical_and_pathology & radiology_patients) |
        (clinical_and_pathology & molecular_patients) |
        (clinical_and_radiology & molecular_patients)
    )
    clinical_plus = clinical_and_pathology | clinical_and_radiology | clinical_and_molecular
    cohort_tiers = [
        (all_modalities, "patients with data in all modalities"),
        (at_least_three, "patients with data in at least 3 modalities"),
        (clinical_plus, "patients with clinical data + at least one other modality"),
    ]
    for candidates, description in cohort_tiers:
        if len(candidates) >= min_cohort_size:
            common_patients = candidates
            break
    else:
        # Fallback to just clinical patients
        common_patients, description = clinical_patients, "patients with clinical data"
    print(f"Using {len(common_patients)} {description}")

    # Get cancer type distribution for chosen patients
    patient_cancer_types = clinical_df[clinical_df['patient_id'].isin(common_patients)]['cancer_type'].value_counts()
    print("\nCancer type distribution for selected patients:")
    print(patient_cancer_types)

    # Filter dataframes to keep only common patients
    clinical_filtered = clinical_df[clinical_df['patient_id'].isin(common_patients)]
    pathology_filtered = pathology_df[pathology_df['patient_id'].isin(common_patients)]
    radiology_filtered = radiology_df[radiology_df['patient_id'].isin(common_patients)]
    molecular_filtered = molecular_df[molecular_df['patient_id'].isin(common_patients)]

    # Create aligned dataframe with consistent patient ordering
    common_patients_list = sorted(common_patients)
    aligned_data = pd.DataFrame({'patient_id': common_patients_list})

    # Merge cancer type. The clinical labels are de-duplicated first: a
    # patient with several clinical rows would otherwise be repeated by the
    # left merge, producing extra rows that never receive embeddings.
    clinical_labels = clinical_filtered[['patient_id', 'cancer_type']].drop_duplicates(
        subset='patient_id', keep='first'
    )
    aligned_data = aligned_data.merge(clinical_labels, on='patient_id', how='left')

    # Fill missing cancer types from the other modalities, in order
    for modality_df in (pathology_filtered, radiology_filtered, molecular_filtered):
        if not aligned_data['cancer_type'].isna().any():
            break
        if len(modality_df) == 0 or 'cancer_type' not in modality_df.columns:
            continue
        label_by_patient = _first_value_by_patient(modality_df, 'cancer_type')
        aligned_data['cancer_type'] = aligned_data['cancer_type'].fillna(
            aligned_data['patient_id'].map(label_by_patient)
        )

    # Fill any still missing cancer types with default
    missing_mask = aligned_data['cancer_type'].isna()
    if missing_mask.any():
        print(f"Still missing cancer type for {missing_mask.sum()} patients. Using default.")
        aligned_data.loc[missing_mask, 'cancer_type'] = TARGET_CANCER_TYPES[0]

    # Per modality: the shape used for zero placeholders (taken from the first
    # embedding when there is one) and a patient -> first-embedding lookup.
    # The lookup replaces a full boolean-mask scan per patient and modality.
    modality_frames = {
        'clinical': clinical_filtered,
        'pathology': pathology_filtered,
        'radiology': radiology_filtered,
        'molecular': molecular_filtered,
    }
    placeholder_shapes = {
        'clinical': (1024,),
        'pathology': (1024,),
        'radiology': (1000,),
        'molecular': (48,),
    }
    embedding_lookups = {}
    for modality, frame in modality_frames.items():
        if len(frame) > 0:
            placeholder_shapes[modality] = frame['embeddings'].iloc[0].shape
        embedding_lookups[modality] = _first_value_by_patient(frame, 'embeddings')
        # Initialize embedding columns as object type to hold arrays
        aligned_data[f'{modality}_embedding'] = None

    # Add embeddings for each modality
    for row_idx, patient_id in tqdm(
        zip(aligned_data.index, aligned_data['patient_id']),
        total=len(aligned_data),
        desc="Aligning embeddings",
    ):
        for modality, lookup in embedding_lookups.items():
            if patient_id in lookup:
                embedding = lookup[patient_id]
            else:
                # Fresh zero array per cell so rows never share a buffer
                embedding = np.zeros(placeholder_shapes[modality])
            aligned_data.at[row_idx, f'{modality}_embedding'] = embedding

    return aligned_data

def create_multimodal_embeddings(aligned_data):
    """
    Build one concatenated feature matrix from the per-modality embeddings.

    Every embedding is first brought to a single vector: a missing value
    becomes zeros (1024 for clinical and pathology, 1000 for radiology, 48
    for molecular), a 2-D array is averaged over its first axis, an array
    with more dimensions is flattened, and anything else is used as is. Each
    modality is then standardized on its own with ``StandardScaler`` and the
    four results are stacked side by side in the order clinical, pathology,
    radiology, molecular.

    Returns:
        2-D array with one row per row of ``aligned_data``.
    """
    print("Creating multimodal embeddings...")

    def as_vector(raw, fallback_dim):
        # Missing entry -> zero vector of the modality's default size
        if raw is None:
            return np.zeros(fallback_dim)
        arr = np.array(raw)
        if arr.ndim <= 1:
            return arr
        if arr.ndim == 2:
            # Average over the first axis for 2D arrays
            return np.mean(arr, axis=0)
        # Higher-dimensional input is flattened to 1D
        return arr.flatten()

    # (label used in log output, source column, zero-vector size)
    modality_specs = (
        ("Clinical", 'clinical_embedding', 1024),
        ("Pathology", 'pathology_embedding', 1024),
        ("Radiology", 'radiology_embedding', 1000),
        ("Molecular", 'molecular_embedding', 48),
    )

    # One (n_patients, dim) matrix per modality
    matrices = [
        np.array([as_vector(raw, fallback_dim) for raw in aligned_data[column]])
        for _, column, fallback_dim in modality_specs
    ]

    # Print shapes for debugging
    for (label, _, _), matrix in zip(modality_specs, matrices):
        print(f"{label} embeddings shape: {matrix.shape}")

    # Standardize each modality separately, then concatenate column-wise
    scaled_blocks = [StandardScaler().fit_transform(matrix) for matrix in matrices]
    multimodal_embeddings = np.hstack(scaled_blocks)

    print(f"Multimodal embeddings shape: {multimodal_embeddings.shape}")

    return multimodal_embeddings

import numpy as np
from sklearn import metrics
from sklearn.cluster import KMeans, DBSCAN
from sklearn.manifold import TSNE
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial import distance
import os

# Add this function to the existing code
def evaluate_embedding_quality(aligned_data, umap_embeddings_dict, output_dir):
    """
    Quantitatively evaluate the quality of different embeddings (unimodal vs multimodal)
    based on clustering metrics.

    For every modality the following are computed against the cancer-type labels:
    silhouette score, Davies-Bouldin index, Calinski-Harabasz index, agreement
    (AMI / ARI / V-measure) of K-means and DBSCAN clusterings with the labels,
    and the mean Euclidean distance between per-cancer-type centroids.
    A metric that cannot be computed is recorded as NaN and the reason is printed.

    Args:
        aligned_data: DataFrame with patient data and cancer type labels
            (a 'cancer_type' column, one row per embedding row)
        umap_embeddings_dict: Dictionary containing UMAP embeddings for each modality
        output_dir: Directory to save the evaluation results

    Returns:
        DataFrame with one row per evaluated modality and one column per metric.
        It is also written to 'embedding_quality_metrics.csv' and summarised in
        'embedding_quality_report.txt' inside output_dir.
    """
    print("\nEvaluating embedding quality for each modality...")

    nan = float('nan')
    # Fixed column order so the frame (and CSV header) is well-formed even when
    # every modality is skipped.
    metric_columns = [
        'Modality',
        'Silhouette Score',
        'Davies-Bouldin Index',
        'Calinski-Harabasz Index',
        'K-means AMI',
        'K-means ARI',
        'K-means V-measure',
        'DBSCAN AMI',
        'DBSCAN ARI',
        'DBSCAN V-measure',
        'Avg Intercluster Distance',
    ]

    # Get true labels (cancer types)
    true_labels = aligned_data['cancer_type'].values
    unique_labels = aligned_data['cancer_type'].unique()
    num_clusters = len(unique_labels)

    # Create numeric labels for clustering metrics
    label_map = {label: i for i, label in enumerate(unique_labels)}
    numeric_labels = np.array([label_map[label] for label in true_labels])

    # Initialize clustering algorithms
    kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
    dbscan = DBSCAN(eps=0.5, min_samples=5)

    def agreement_with_truth(predicted_labels):
        """Return (AMI, ARI, V-measure) of a clustering against the cancer types."""
        return (
            metrics.adjusted_mutual_info_score(numeric_labels, predicted_labels),
            metrics.adjusted_rand_score(numeric_labels, predicted_labels),
            metrics.v_measure_score(numeric_labels, predicted_labels),
        )

    # Rows are collected in a list and turned into a frame once; concatenating
    # onto an empty DataFrame inside the loop is deprecated in recent pandas.
    rows = []

    # Evaluate each modality
    for modality, embeddings in umap_embeddings_dict.items():
        print(f"Evaluating {modality} embeddings...")

        # Skip if embeddings are empty
        if len(embeddings) == 0:
            print(f"  Skipping {modality} - no embeddings available")
            continue

        # 1. Silhouette Score - measures how well samples are clustered with samples of the same class
        try:
            silhouette_avg = metrics.silhouette_score(embeddings, numeric_labels)
        except Exception as exc:
            silhouette_avg = nan
            print(f"  Could not compute silhouette score for {modality}: {exc}")

        # 2. Davies-Bouldin Index - lower values indicate better separation
        try:
            davies_bouldin = metrics.davies_bouldin_score(embeddings, numeric_labels)
        except Exception as exc:
            davies_bouldin = nan
            print(f"  Could not compute Davies-Bouldin score for {modality}: {exc}")

        # 3. Calinski-Harabasz Index - higher values indicate better defined clusters
        try:
            calinski_harabasz = metrics.calinski_harabasz_score(embeddings, numeric_labels)
        except Exception as exc:
            calinski_harabasz = nan
            print(f"  Could not compute Calinski-Harabasz score for {modality}: {exc}")

        # 4. Apply K-means clustering and compute metrics
        try:
            kmeans_ami, kmeans_ari, kmeans_v_measure = agreement_with_truth(
                kmeans.fit_predict(embeddings)
            )
        except Exception as exc:
            kmeans_ami = kmeans_ari = kmeans_v_measure = nan
            print(f"  Could not compute K-means metrics for {modality}: {exc}")

        # 5. Apply DBSCAN clustering and compute metrics
        dbscan_ami = dbscan_ari = dbscan_v_measure = nan
        try:
            dbscan_labels = dbscan.fit_predict(embeddings)
            # Only compute metrics if DBSCAN found multiple clusters and
            # flagged no sample as noise (label -1)
            if len(np.unique(dbscan_labels)) > 1 and -1 not in dbscan_labels:
                dbscan_ami, dbscan_ari, dbscan_v_measure = agreement_with_truth(dbscan_labels)
            else:
                print(f"  DBSCAN did not find valid clusters for {modality}")
        except Exception as exc:
            dbscan_ami = dbscan_ari = dbscan_v_measure = nan
            print(f"  Could not compute DBSCAN metrics for {modality}: {exc}")

        # 6. Compute average distance between cancer types
        try:
            centroid_vectors = []
            for cancer in unique_labels:
                # Plain boolean array: avoids any index alignment with the Series
                mask = (aligned_data['cancer_type'] == cancer).to_numpy()
                if mask.any():
                    centroid_vectors.append(np.mean(embeddings[mask], axis=0))

            # Compute pairwise distances between centroids
            if len(centroid_vectors) > 1:
                pairwise_distances = distance.pdist(np.array(centroid_vectors), 'euclidean')
                avg_intercluster_distance = np.mean(pairwise_distances)
            else:
                avg_intercluster_distance = nan
        except Exception as exc:
            avg_intercluster_distance = nan
            print(f"  Could not compute intercluster distances for {modality}: {exc}")

        rows.append({
            'Modality': modality,
            'Silhouette Score': silhouette_avg,
            'Davies-Bouldin Index': davies_bouldin,
            'Calinski-Harabasz Index': calinski_harabasz,
            'K-means AMI': kmeans_ami,
            'K-means ARI': kmeans_ari,
            'K-means V-measure': kmeans_v_measure,
            'DBSCAN AMI': dbscan_ami,
            'DBSCAN ARI': dbscan_ari,
            'DBSCAN V-measure': dbscan_v_measure,
            'Avg Intercluster Distance': avg_intercluster_distance,
        })

    metrics_df = pd.DataFrame(rows, columns=metric_columns)

    # Save metrics to CSV
    metrics_df.to_csv(os.path.join(output_dir, 'embedding_quality_metrics.csv'), index=False)

    # Create a formatted table for display
    with open(os.path.join(output_dir, 'embedding_quality_report.txt'), 'w') as f:
        f.write("Embedding Quality Evaluation\n")
        f.write("===========================\n\n")
        f.write(f"Number of cancer types: {num_clusters}\n")
        f.write(f"Total samples: {len(aligned_data)}\n\n")
        f.write("Metrics (higher is better except for Davies-Bouldin Index):\n\n")
        f.write(metrics_df.to_string(index=False))

    print(f"Evaluation results saved to {output_dir}")

    # Create visualization of metrics (nothing to draw when every modality was skipped)
    if metrics_df.empty:
        print("No modality had embeddings to evaluate; skipping metric plots")
    else:
        plot_embedding_metrics(metrics_df, output_dir)

    return metrics_df

def plot_embedding_metrics(metrics_df, output_dir):
    """
    Create visualizations of the embedding quality metrics for comparison.

    Draws a 3x2 grid of bar charts (one per metric, modalities sorted by the
    metric value) and saves it as 'embedding_metrics_comparison.png' in
    output_dir, then delegates to create_radar_chart for the radar plot.

    Args:
        metrics_df: DataFrame with a 'Modality' column and one column per
            plotted metric (see metrics_to_plot below)
        output_dir: Directory the PNG files are written to
    """
    # Create a multi-panel figure for metrics comparison
    plt.figure(figsize=(18, 15))

    metrics_to_plot = [
        ('Silhouette Score', 'Higher is better'),
        ('Davies-Bouldin Index', 'Lower is better'),
        ('Calinski-Harabasz Index', 'Higher is better'),
        ('K-means AMI', 'Higher is better'),
        ('K-means ARI', 'Higher is better'),
        ('K-means V-measure', 'Higher is better'),
    ]

    for panel, (metric, subtitle) in enumerate(metrics_to_plot, start=1):
        plt.subplot(3, 2, panel)

        # Sort by metric value for better visualization
        df_sorted = metrics_df.sort_values(by=metric)

        # Create barplot
        ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

        # Add value labels at the outer end of each bar
        for bar in ax.patches:
            height = bar.get_height()
            # A metric that could not be computed is NaN; there is no finite
            # position to anchor a label to, so leave that bar unlabelled.
            if not np.isfinite(height):
                continue
            # Several metrics can be negative: put the label below such bars
            # instead of drawing it on top of them.
            below = height < 0
            ax.annotate(
                f"{height:.3f}",
                (bar.get_x() + bar.get_width() / 2., height),
                ha='center', va='top' if below else 'bottom',
                fontsize=10, color='black',
                xytext=(0, -5 if below else 5),
                textcoords='offset points',
            )

        plt.title(f"{metric}\n{subtitle}")
        plt.xticks(rotation=45, ha='right')

    plt.suptitle("Embedding Quality Metrics Comparison", fontsize=16, y=1.02)
    # One layout pass over the finished figure is enough
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'embedding_metrics_comparison.png'), bbox_inches='tight', dpi=300)
    plt.close()

    # Create a radar chart for multi-metric comparison
    create_radar_chart(metrics_df, output_dir)

def create_radar_chart(metrics_df, output_dir):
    """
    Create a radar chart to compare multiple metrics across different modalities.

    Each selected metric is rescaled by its column maximum (when that maximum
    is positive) so the axes are comparable, NaN values are drawn as 0, and one
    closed polygon is drawn per row of metrics_df. The figure is saved as
    'embedding_radar_comparison.png' in output_dir.

    Args:
        metrics_df: DataFrame with a 'Modality' column and the metric columns
            listed in metrics_for_radar below
        output_dir: Directory the PNG file is written to
    """
    # Select metrics for radar chart
    metrics_for_radar = [
        'Silhouette Score',
        'K-means AMI',
        'K-means ARI',
        'K-means V-measure',
        'Avg Intercluster Distance'
    ]
    # Metrics whose scale is inverted after normalisation so that "further out"
    # always means "better". None of the metrics above is in this set today;
    # it only takes effect if such a metric is added to metrics_for_radar.
    lower_is_better = {'Davies-Bouldin Index'}

    # Filter and clean data
    radar_df = metrics_df[['Modality'] + metrics_for_radar].copy()

    # Replace NaN with 0 for visualization
    radar_df = radar_df.fillna(0)

    # Normalize metrics to 0-1 scale for comparison
    for metric in metrics_for_radar:
        column_max = radar_df[metric].max()
        if not column_max > 0:
            # Nothing sensible to divide by (all values <= 0, or no rows)
            continue
        scaled = radar_df[metric] / column_max
        radar_df[metric] = 1 - scaled if metric in lower_is_better else scaled

    # Set up radar chart
    n_metrics = len(metrics_for_radar)
    angles = np.linspace(0, 2*np.pi, n_metrics, endpoint=False).tolist()
    angles += angles[:1]  # Close the loop

    _, ax = plt.subplots(figsize=(12, 10), subplot_kw=dict(polar=True))

    # Add one polygon per row. Iterating rows directly (rather than filtering
    # by modality name) keeps exactly n_metrics values per polygon even if two
    # rows share a modality name.
    for _, row in radar_df.iterrows():
        values = row[metrics_for_radar].tolist()
        values += values[:1]  # Close the loop

        # Plot values
        ax.plot(angles, values, linewidth=2, label=row['Modality'])
        ax.fill(angles, values, alpha=0.1)

    # Set labels and title
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetagrids(np.degrees(angles[:-1]), metrics_for_radar)

    # Add legend
    ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))

    plt.title('Multimodal vs. Unimodal Embedding Quality Comparison', size=15, y=1.1)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'embedding_radar_comparison.png'), bbox_inches='tight', dpi=300)
    plt.close()

def compute_cluster_separability(aligned_data, umap_embeddings_dict, output_dir):
    """
    Compute separability of cancer type clusters for each modality.

    For each modality, samples are grouped by cancer type and three numbers
    are derived:
      * average within-cluster scatter: mean, over cancer types, of the mean
        Euclidean distance from a sample to its cancer-type centroid
      * average between-cluster distance: mean pairwise Euclidean distance
        between cancer-type centroids (0 when fewer than two centroids exist)
      * separation ratio: between / within (0 when the scatter is 0)

    Results are written to 'cluster_separability_metrics.csv' and plotted to
    'cluster_separability.png' inside output_dir.

    Args:
        aligned_data: DataFrame with a 'cancer_type' column, one row per
            embedding row
        umap_embeddings_dict: Dictionary mapping modality name to its embedding
            array; empty entries are skipped
        output_dir: Directory the CSV and PNG are written to

    Returns:
        DataFrame with columns 'Modality', 'Avg Within-Cluster Scatter',
        'Avg Between-Cluster Distance' and 'Separation Ratio'.
    """
    print("\nComputing cluster separability metrics...")

    result_columns = [
        'Modality',
        'Avg Within-Cluster Scatter',
        'Avg Between-Cluster Distance',
        'Separation Ratio',
    ]

    # Get unique cancer types and one plain boolean membership mask per type
    # (loop-invariant across modalities)
    cancer_types = aligned_data['cancer_type'].unique()
    membership = {
        cancer: (aligned_data['cancer_type'] == cancer).to_numpy()
        for cancer in cancer_types
    }

    rows = []

    # Compute metrics for each modality
    for modality, embeddings in umap_embeddings_dict.items():
        print(f"Analyzing {modality} embeddings...")

        # Skip if embeddings are empty
        if len(embeddings) == 0:
            print(f"  Skipping {modality} - no embeddings available")
            continue

        points = np.asarray(embeddings)

        # Centroid and mean distance-to-centroid for every populated cancer type
        centroids = []
        scatters = []
        for cancer in cancer_types:
            mask = membership[cancer]
            if not mask.any():
                continue
            members = points[mask]
            centroid = np.mean(members, axis=0)
            centroids.append(centroid)
            scatters.append(
                np.mean(np.sqrt(np.sum((members - centroid) ** 2, axis=1)))
            )

        # Average within-cluster scatter over the clusters that actually have members
        avg_within_scatter = sum(scatters) / len(scatters) if scatters else 0

        # Average between-cluster distance (needs at least two centroids)
        if len(centroids) > 1:
            avg_between_distance = np.mean(distance.pdist(np.array(centroids), 'euclidean'))
        else:
            avg_between_distance = 0

        # Compute separation ratio (higher is better)
        separation_ratio = avg_between_distance / avg_within_scatter if avg_within_scatter > 0 else 0

        # NOTE: no per-cancer-type silhouette is computed here. A silhouette
        # needs at least two distinct labels, so scoring the samples of a single
        # cancer type on their own cannot succeed, and nothing in the returned
        # frame uses such a value.

        rows.append({
            'Modality': modality,
            'Avg Within-Cluster Scatter': avg_within_scatter,
            'Avg Between-Cluster Distance': avg_between_distance,
            'Separation Ratio': separation_ratio,
        })

    separability_df = pd.DataFrame(rows, columns=result_columns)

    # Save results
    separability_df.to_csv(os.path.join(output_dir, 'cluster_separability_metrics.csv'), index=False)

    if separability_df.empty:
        # Every modality was skipped: there is nothing to plot
        print("No modality had embeddings to analyze; skipping separability plot")
        return separability_df

    # Create visualizations
    plt.figure(figsize=(12, 6))

    # Plot separation ratio
    plt.subplot(1, 2, 1)
    df_sorted = separability_df.sort_values(by='Separation Ratio', ascending=False)
    sns.barplot(x='Modality', y='Separation Ratio', data=df_sorted, palette='viridis')
    plt.title('Cluster Separation Ratio\n(Higher is Better)')
    plt.xticks(rotation=45, ha='right')

    # Plot scatter vs distance
    plt.subplot(1, 2, 2)
    for _, row in separability_df.iterrows():
        plt.scatter(
            row['Avg Within-Cluster Scatter'],
            row['Avg Between-Cluster Distance'],
            s=100,
            label=row['Modality']
        )
        plt.text(
            row['Avg Within-Cluster Scatter'] + 0.01,
            row['Avg Between-Cluster Distance'] + 0.01,
            row['Modality']
        )

    # Add diagonal line where ratio = 1
    max_val = max(separability_df['Avg Within-Cluster Scatter'].max(),
                  separability_df['Avg Between-Cluster Distance'].max()) * 1.1
    plt.plot([0, max_val], [0, max_val], 'k--', alpha=0.5)

    plt.xlabel('Average Within-Cluster Scatter (Lower is Better)')
    plt.ylabel('Average Between-Cluster Distance (Higher is Better)')
    plt.title('Cluster Separation Analysis')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'cluster_separability.png'), bbox_inches='tight', dpi=300)
    plt.close()

    return separability_df

# This function will perform more advanced clustering analysis
def perform_clustering_analysis(aligned_data, umap_embeddings_dict, output_dir):
    """
    Perform clustering analysis to compare how well different embedding spaces
    correlate with cancer type clusters.

    For every modality, K-means is run with k = n, n+1 and n+2 clusters (n being
    the number of distinct cancer types). Each clustering is compared with the
    cancer-type labels via AMI, ARI and V-measure, and a two-panel scatter plot
    (clusters on top, true cancer types below) is saved as
    '<modality>_clusters_k<k>.png'. All scores are written to
    'clustering_analysis_metrics.csv' and summarised as three heatmaps in
    'clustering_metrics_comparison.png'.

    Args:
        aligned_data: DataFrame with a 'cancer_type' column, one row per
            embedding row
        umap_embeddings_dict: Dictionary mapping modality name to a 2-D
            embedding array; the first two columns are used as plot axes
        output_dir: Directory all CSV/PNG outputs are written to

    Returns:
        DataFrame with columns 'Modality', 'K', 'AMI', 'ARI', 'V-measure',
        one row per (modality, k) combination that could be clustered.
    """
    print("\nPerforming clustering analysis...")

    metric_columns = ['Modality', 'K', 'AMI', 'ARI', 'V-measure']

    # Get cancer types and create color mapping
    cancer_types = aligned_data['cancer_type'].unique()
    colors = sns.color_palette("tab10", len(cancer_types))
    color_map = {cancer: color for cancer, color in zip(cancer_types, colors)}

    # Ground truth and per-type membership masks do not depend on the modality
    # or on k, so they are computed once. Plain boolean arrays are used because
    # they are combined with a column index when slicing the embeddings.
    true_labels = aligned_data['cancer_type'].values
    type_masks = {
        cancer: (aligned_data['cancer_type'] == cancer).to_numpy()
        for cancer in cancer_types
    }

    # Try different cluster counts
    k_values = [len(cancer_types), len(cancer_types)+1, len(cancer_types)+2]

    # Initialize result metrics
    cluster_metrics = []

    # Analyze each modality
    for modality, embeddings in umap_embeddings_dict.items():
        print(f"Analyzing {modality} embeddings...")

        # Skip if embeddings are empty
        if len(embeddings) == 0:
            print(f"  Skipping {modality} - no embeddings available")
            continue

        points = np.asarray(embeddings)

        for k in k_values:
            # K-means cannot form more clusters than there are samples
            if k > len(points):
                print(f"  Skipping k={k} for {modality} - only {len(points)} samples")
                continue

            # Apply K-means clustering
            kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
            cluster_labels = kmeans.fit_predict(points)

            # Calculate agreement metrics comparing clusters to cancer types
            cluster_metrics.append({
                'Modality': modality,
                'K': k,
                'AMI': metrics.adjusted_mutual_info_score(true_labels, cluster_labels),
                'ARI': metrics.adjusted_rand_score(true_labels, cluster_labels),
                'V-measure': metrics.v_measure_score(true_labels, cluster_labels),
            })

            # Create visualization of clusters vs cancer types
            plt.figure(figsize=(12, 10))

            # First subplot: clusters
            plt.subplot(2, 1, 1)
            for cluster_id in range(k):
                in_cluster = cluster_labels == cluster_id
                if in_cluster.any():
                    plt.scatter(
                        points[in_cluster, 0],
                        points[in_cluster, 1],
                        label=f'Cluster {cluster_id}',
                        alpha=0.7,
                        s=50,
                        edgecolor='none'
                    )

            plt.title(f'{modality} - K-means Clusters (k={k})')
            plt.xlabel('UMAP 1')
            plt.ylabel('UMAP 2')
            plt.legend(title="Clusters", bbox_to_anchor=(1.05, 1), loc='upper left')

            # Second subplot: actual cancer types
            plt.subplot(2, 1, 2)
            for cancer_type in cancer_types:
                of_type = type_masks[cancer_type]
                if of_type.any():
                    plt.scatter(
                        points[of_type, 0],
                        points[of_type, 1],
                        label=cancer_type,
                        color=color_map[cancer_type],
                        alpha=0.7,
                        s=50,
                        edgecolor='none'
                    )

            plt.title(f'{modality} - Actual Cancer Types')
            plt.xlabel('UMAP 1')
            plt.ylabel('UMAP 2')
            plt.legend(title="Cancer Types", bbox_to_anchor=(1.05, 1), loc='upper left')

            plt.tight_layout()
            plt.savefig(
                os.path.join(output_dir, f'{modality}_clusters_k{k}.png'),
                bbox_inches='tight',
                dpi=300
            )
            plt.close()

    # Convert metrics to DataFrame (explicit columns keep it well-formed when empty)
    cluster_metrics_df = pd.DataFrame(cluster_metrics, columns=metric_columns)

    # Save metrics to CSV
    cluster_metrics_df.to_csv(os.path.join(output_dir, 'clustering_analysis_metrics.csv'), index=False)

    if cluster_metrics_df.empty:
        # Nothing was clustered, so there is nothing to pivot or plot
        print("No modality could be clustered; skipping comparison heatmaps")
        return cluster_metrics_df

    # Create comparison visualization: one heatmap (modality x k) per metric
    plt.figure(figsize=(15, 8))
    heatmaps = [
        ('AMI', 'Adjusted Mutual Information (Higher is Better)'),
        ('ARI', 'Adjusted Rand Index (Higher is Better)'),
        ('V-measure', 'V-measure (Higher is Better)'),
    ]
    for panel, (metric, title) in enumerate(heatmaps, start=1):
        plt.subplot(1, 3, panel)
        pivot = cluster_metrics_df.pivot(index='Modality', columns='K', values=metric)
        sns.heatmap(pivot, annot=True, cmap='viridis', fmt='.3f')
        plt.title(title)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'clustering_metrics_comparison.png'), bbox_inches='tight', dpi=300)
    plt.close()

    return cluster_metrics_df

In [22]:
# Load data from all modalities.
# The result is unpacked into four objects, one per modality, in the order
# clinical, pathology, radiology, molecular.
# NOTE(review): load_multimodal_data is defined outside this section; presumably
# each returned object is a DataFrame - confirm against its definition.
clinical_df, pathology_df, radiology_df, molecular_df = load_multimodal_data()

Loading clinical data...
Loaded 11428 total clinical samples
After filtering for ['TCGA-KIRC', 'TCGA-OV', 'TCGA-BRCA'], found 2243 clinical samples
After removing null embeddings, 2243 clinical samples remain
Using embedding_shape for clinical embeddings


Processing clinical embeddings:   0%|          | 0/2243 [00:00<?, ?it/s]

Found 2243 unique patients in clinical data
Loading pathology report data...
Loaded 11208 total pathology samples
After filtering by patient IDs, found 2250 pathology samples out of 11208
After removing null embeddings, 2250 pathology samples remain
Using embedding_shape for pathology embeddings


Processing pathology embeddings:   0%|          | 0/2250 [00:00<?, ?it/s]

Loading radiology data...
Loaded 11870 total radiology samples
After filtering by patient IDs, found 5532 radiology samples out of 11870
After removing null embeddings, 5257 radiology samples remain
Using embedding_shape for radiology embeddings


Processing radiology embeddings:   0%|          | 0/5257 [00:00<?, ?it/s]

Loading molecular data...
Loaded 13804 total molecular samples
Molecular data columns: ['PatientID', 'SampleID', 'Embeddings']
After filtering by patient IDs, found 2970 molecular samples out of 13804
After removing null embeddings, 2970 molecular samples remain
No embedding_shape for molecular embeddings, using raw buffers


Processing molecular embeddings:   0%|          | 0/2970 [00:00<?, ?it/s]


Final summary:
Clinical data: 2243 samples
Pathology data: 2250 samples
Radiology data: 5257 samples
Molecular data: 2970 samples


In [23]:
# Align patient data across modalities.
# NOTE(review): align_patient_data is defined outside this section; judging by the
# cell output below it keeps only patients present in all four modalities - confirm.
aligned_data = align_patient_data(clinical_df, pathology_df, radiology_df, molecular_df)

Aligning patient data across modalities...
Clinical patients: 2243
Pathology patients: 2225
Radiology patients: 548
Molecular patients: 2192
Patients in clinical and pathology: 2225
Patients in clinical and radiology: 548
Patients in clinical and molecular: 2192
Found 547 patients with data in all four modalities
Using 547 patients with data in all modalities

Cancer type distribution for selected patients:
cancer_type
TCGA-KIRC    267
TCGA-OV      141
TCGA-BRCA    139
Name: count, dtype: int64


Aligning embeddings:   0%|          | 0/547 [00:00<?, ?it/s]

In [24]:
# Create multimodal embeddings.
# create_multimodal_embeddings standardizes each modality separately and
# concatenates the four scaled matrices column-wise into one array.
multimodal_embeddings = create_multimodal_embeddings(aligned_data)

Creating multimodal embeddings...
Clinical embeddings shape: (547, 1024)
Pathology embeddings shape: (547, 1024)
Radiology embeddings shape: (547, 1000)
Molecular embeddings shape: (547, 48)
Multimodal embeddings shape: (547, 3096)


In [26]:
def visualize_multimodal_integration(aligned_data, multimodal_embeddings):
    """
    Visualize the integration of multimodal data using UMAP and perform 
    quantitative clustering analysis
    """
    print("Visualizing multimodal integration and performing quantitative analysis...")
    
    # Handle multi-dimensional embeddings by extracting and flattening if needed
    clinical_embeddings = []
    pathology_embeddings = []
    radiology_embeddings = []
    molecular_embeddings = []
    
    for _, row in aligned_data.iterrows():
        # Process clinical embedding
        emb = row['clinical_embedding']
        if emb is None:
            clinical_embeddings.append(np.zeros(1024))
        elif len(np.array(emb).shape) > 1:
            # If multi-dimensional, flatten or take mean
            emb_array = np.array(emb)
            if len(emb_array.shape) == 2:
                # Take mean along first dimension for 2D arrays
                clinical_embeddings.append(np.mean(emb_array, axis=0))
            else:
                # For higher dimensions, flatten to 1D
                clinical_embeddings.append(np.array(emb).flatten())
        else:
            clinical_embeddings.append(np.array(emb))
        
        # Process pathology embedding
        emb = row['pathology_embedding']
        if emb is None:
            pathology_embeddings.append(np.zeros(1024))
        elif len(np.array(emb).shape) > 1:
            emb_array = np.array(emb)
            if len(emb_array.shape) == 2:
                pathology_embeddings.append(np.mean(emb_array, axis=0))
            else:
                pathology_embeddings.append(np.array(emb).flatten())
        else:
            pathology_embeddings.append(np.array(emb))
        
        # Process radiology embedding
        emb = row['radiology_embedding']
        if emb is None:
            radiology_embeddings.append(np.zeros(1000))
        elif len(np.array(emb).shape) > 1:
            emb_array = np.array(emb)
            if len(emb_array.shape) == 2:
                radiology_embeddings.append(np.mean(emb_array, axis=0))
            else:
                radiology_embeddings.append(np.array(emb).flatten())
        else:
            radiology_embeddings.append(np.array(emb))
        
        # Process molecular embedding
        emb = row['molecular_embedding']
        if emb is None:
            molecular_embeddings.append(np.zeros(48))
        elif len(np.array(emb).shape) > 1:
            emb_array = np.array(emb)
            if len(emb_array.shape) == 2:
                molecular_embeddings.append(np.mean(emb_array, axis=0))
            else:
                molecular_embeddings.append(np.array(emb).flatten())
        else:
            molecular_embeddings.append(np.array(emb))
    
    # Convert lists to numpy arrays
    clinical_embeddings = np.array(clinical_embeddings)
    pathology_embeddings = np.array(pathology_embeddings)
    radiology_embeddings = np.array(radiology_embeddings)
    molecular_embeddings = np.array(molecular_embeddings)
    
    print(f"Clinical embeddings processed shape: {clinical_embeddings.shape}")
    print(f"Pathology embeddings processed shape: {pathology_embeddings.shape}")
    print(f"Radiology embeddings processed shape: {radiology_embeddings.shape}")
    print(f"Molecular embeddings processed shape: {molecular_embeddings.shape}")
    
    # Check for and replace NaN values in embeddings
    def check_and_fix_nans(arr, name):
        nan_count = np.isnan(arr).sum()
        if nan_count > 0:
            print(f"WARNING: Found {nan_count} NaN values in {name} embeddings. Replacing with zeros.")
            arr = np.nan_to_num(arr, nan=0.0)
        return arr
    
    clinical_embeddings = check_and_fix_nans(clinical_embeddings, "clinical")
    pathology_embeddings = check_and_fix_nans(pathology_embeddings, "pathology")
    radiology_embeddings = check_and_fix_nans(radiology_embeddings, "radiology")
    molecular_embeddings = check_and_fix_nans(molecular_embeddings, "molecular")
    multimodal_embeddings = check_and_fix_nans(multimodal_embeddings, "multimodal")
    
    # Apply UMAP to each modality separately - using correct UMAP import
    reducer = umap.UMAP(random_state=42)
    
    # Apply StandardScaler and safely handle UMAP transformation
    def safe_umap_transform(data, name):
        try:
            # Apply StandardScaler, ensuring no NaNs
            scaled_data = StandardScaler().fit_transform(data)
            scaled_data = np.nan_to_num(scaled_data, nan=0.0)
            
            # Apply UMAP
            transformed = reducer.fit_transform(scaled_data)
            return transformed
        except Exception as e:
            print(f"Error applying UMAP to {name} embeddings: {e}")
            print(f"Using random 2D coordinates for {name} embeddings instead")
            return np.random.rand(data.shape[0], 2) * 10
    
    clinical_umap = safe_umap_transform(clinical_embeddings, "clinical")
    pathology_umap = safe_umap_transform(pathology_embeddings, "pathology")
    radiology_umap = safe_umap_transform(radiology_embeddings, "radiology")
    molecular_umap = safe_umap_transform(molecular_embeddings, "molecular")
    multimodal_umap = safe_umap_transform(multimodal_embeddings, "multimodal")
    
    # Get unique cancer types
    cancer_types = aligned_data['cancer_type'].unique()
    print(f"Cancer types in visualization: {cancer_types}")
    
    # Limit to top 8 cancer types by frequency if needed
    if len(cancer_types) > 8:
        top_cancer_types = aligned_data['cancer_type'].value_counts().nlargest(8).index.tolist()
        print(f"Limiting visualization to top 8 cancer types: {top_cancer_types}")
        
        # Create plot_cancer_type for visualization
        aligned_data['plot_cancer_type'] = aligned_data['cancer_type'].apply(
            lambda x: x if x in top_cancer_types else 'Other'
        )
        plot_cancer_types = aligned_data['plot_cancer_type'].unique()
    else:
        aligned_data['plot_cancer_type'] = aligned_data['cancer_type']
        plot_cancer_types = cancer_types
    
    # Define color mapping
    colors = sns.color_palette("tab10", len(plot_cancer_types))
    color_map = {cancer: color for cancer, color in zip(plot_cancer_types, colors)}
    
    # Create a 3x2 grid with special layout
    fig = plt.figure(figsize=(24, 16), dpi=300)
    
    # Fix: Create grid spec for 2 rows and 3 columns
    gs = plt.GridSpec(2, 3, figure=fig)
    
    # Create axes for each subplot
    ax_clinical = fig.add_subplot(gs[0, 0])       # Row 1, Col 1
    ax_pathology = fig.add_subplot(gs[0, 1])      # Row 1, Col 2
    ax_molecular = fig.add_subplot(gs[1, 0])      # Row 2, Col 1
    ax_radiology = fig.add_subplot(gs[1, 1])      # Row 2, Col 2
    ax_multimodal = fig.add_subplot(gs[:, 2])     # Row 1-2, Col 3
    
    # Plot each modality with its respective subplot
    # Clinical
    for cancer_type in plot_cancer_types:
        mask = aligned_data['plot_cancer_type'] == cancer_type
        if sum(mask) > 0:
            ax_clinical.scatter(
                clinical_umap[mask, 0],
                clinical_umap[mask, 1],
                c=[color_map[cancer_type]],
                alpha=0.7,
                s=200,
                edgecolor='black'
            )
    
    ax_clinical.set_title('Clinical', fontsize=30, fontweight='bold')
    ax_clinical.set_xlabel('UMAP 1', fontsize=20)
    ax_clinical.set_ylabel('UMAP 2', fontsize=20)
    
    # Pathology Report
    for cancer_type in plot_cancer_types:
        mask = aligned_data['plot_cancer_type'] == cancer_type
        if sum(mask) > 0:
            ax_pathology.scatter(
                pathology_umap[mask, 0],
                pathology_umap[mask, 1],
                c=[color_map[cancer_type]],
                alpha=0.7,
                s=200,
                edgecolor='black'
            )
    
    ax_pathology.set_title('Pathology Report', fontsize=30, fontweight='bold')
    ax_pathology.set_xlabel('UMAP 1', fontsize=20)
    ax_pathology.set_ylabel('UMAP 2', fontsize=20)
    
    # Molecular
    for cancer_type in plot_cancer_types:
        mask = aligned_data['plot_cancer_type'] == cancer_type
        if sum(mask) > 0:
            ax_molecular.scatter(
                molecular_umap[mask, 0],
                molecular_umap[mask, 1],
                c=[color_map[cancer_type]],
                alpha=0.7,
                s=200,
                edgecolor='black'
            )
    
    ax_molecular.set_title('Molecular', fontsize=30, fontweight='bold')
    ax_molecular.set_xlabel('UMAP 1', fontsize=20)
    ax_molecular.set_ylabel('UMAP 2', fontsize=20)
    
    # Radiology
    for cancer_type in plot_cancer_types:
        mask = aligned_data['plot_cancer_type'] == cancer_type
        if sum(mask) > 0:
            ax_radiology.scatter(
                radiology_umap[mask, 0],
                radiology_umap[mask, 1],
                c=[color_map[cancer_type]],
                alpha=0.7,
                s=200,
                edgecolor='black'
            )
    
    ax_radiology.set_title('Radiology', fontsize=30, fontweight='bold')
    ax_radiology.set_xlabel('UMAP 1', fontsize=20)
    ax_radiology.set_ylabel('UMAP 2', fontsize=20)
    
    # Multimodal Integration (larger plot at the bottom)
    for cancer_type in plot_cancer_types:
        mask = aligned_data['plot_cancer_type'] == cancer_type
        if sum(mask) > 0:
            ax_multimodal.scatter(
                multimodal_umap[mask, 0],
                multimodal_umap[mask, 1],
                c=[color_map[cancer_type]],
                alpha=0.7,
                s=200,
                edgecolor='black',
                label=cancer_type
            )
    
    ax_multimodal.set_title('Multimodal Integration', fontsize=30, fontweight='bold')
    ax_multimodal.set_xlabel('UMAP 1', fontsize=18)
    ax_multimodal.set_ylabel('UMAP 2', fontsize=18)
    
    # Add legend to the multimodal plot
    ax_multimodal.legend(fontsize=16, title="Cancer Types", title_fontsize=18, loc='best')
    
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/multimodal_umap_visualization.pdf", bbox_inches='tight')
    plt.close()
    
    # Also save a text file with dataset information
    info_text = (
        f"Total patients: {len(aligned_data)}\n\n"
        f"Cancer type counts:\n"
        f"{aligned_data['cancer_type'].value_counts().to_string()}\n\n"
        f"Embedding dimensions:\n"
        f"Clinical: {clinical_embeddings.shape[1]}\n"
        f"Pathology: {pathology_embeddings.shape[1]}\n"
        f"Radiology: {radiology_embeddings.shape[1]}\n"
        f"Molecular: {molecular_embeddings.shape[1]}\n"
        f"Multimodal: {multimodal_embeddings.shape[1]}"
    )
    
    with open(f"{OUTPUT_DIR}/dataset_information.txt", "w") as f:
        f.write(info_text)
    
    print(f"Visualizations saved to {OUTPUT_DIR}")
    
    # ============ NEW CODE FOR QUANTITATIVE ANALYSIS =============
    # Collect UMAP embeddings for analysis
    umap_embeddings_dict = {
        'Clinical': clinical_umap,
        'Pathology': pathology_umap,
        'Radiology': radiology_umap,
        'Molecular': molecular_umap,
        'Multimodal': multimodal_umap
    }
    
    # Perform quantitative embedding quality evaluation
    metrics_df = evaluate_embedding_quality(aligned_data, umap_embeddings_dict, OUTPUT_DIR)
    
    # Compute cluster separability metrics
    separability_df = compute_cluster_separability(aligned_data, umap_embeddings_dict, OUTPUT_DIR)
    
    # Perform detailed clustering analysis
    clustering_df = perform_clustering_analysis(aligned_data, umap_embeddings_dict, OUTPUT_DIR)
    
    # Create a summary visualization
    plt.figure(figsize=(15, 12))
    
    # Compare best metrics across modalities
    metrics_to_plot = ['Silhouette Score', 'K-means AMI', 'Separation Ratio']
    
    for i, metric in enumerate(metrics_to_plot):
        plt.subplot(3, 1, i+1)
        
        if metric == 'Separation Ratio':
            # Use separability_df for this metric
            df_sorted = separability_df.sort_values(by=metric, ascending=False)
            ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')
            
            # Add value labels on top of bars
            for j, p in enumerate(ax.patches):
                ax.annotate(f"{p.get_height():.3f}", 
                            (p.get_x() + p.get_width() / 2., p.get_height()),
                            ha='center', va='bottom',
                            fontsize=10, color='black',
                            xytext=(0, 5),
                            textcoords='offset points')
        else:
            # Use metrics_df for other metrics
            df_sorted = metrics_df.sort_values(by=metric, ascending=False)
            ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')
            
            # Add value labels on top of bars
            for j, p in enumerate(ax.patches):
                ax.annotate(f"{p.get_height():.3f}", 
                            (p.get_x() + p.get_width() / 2., p.get_height()),
                            ha='center', va='bottom',
                            fontsize=10, color='black',
                            xytext=(0, 5),
                            textcoords='offset points')
        
        plt.title(f"{metric} Comparison", fontsize=16)
        plt.xticks(rotation=45, ha='right')
        plt.xlabel('Modality', fontsize=14)
        plt.ylabel(metric, fontsize=14)
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'summary_metrics_comparison.png'), bbox_inches='tight', dpi=300)
    plt.close()
    
    # Create an integrated metric - optional
    if all(metric in metrics_df.columns for metric in ['Silhouette Score', 'K-means AMI', 'K-means ARI']):
        # Normalize metrics to 0-1 scale
        metrics_to_normalize = ['Silhouette Score', 'K-means AMI', 'K-means ARI']
        normalized_df = metrics_df.copy()
        
        for col in metrics_to_normalize:
            if normalized_df[col].max() > 0:
                normalized_df[col] = normalized_df[col] / normalized_df[col].max()
        
        # Compute integrated score (mean of normalized metrics)
        normalized_df['Integrated Score'] = normalized_df[metrics_to_normalize].mean(axis=1)
        
        # Plot integrated score
        plt.figure(figsize=(10, 6))
        ax = sns.barplot(x='Modality', y='Integrated Score', 
                     data=normalized_df.sort_values('Integrated Score', ascending=False),
                     palette='viridis')
        
        # Add value labels
        for j, p in enumerate(ax.patches):
            ax.annotate(f"{p.get_height():.3f}", 
                        (p.get_x() + p.get_width() / 2., p.get_height()),
                        ha='center', va='bottom',
                        fontsize=12, color='black',
                        xytext=(0, 5),
                        textcoords='offset points')
        
        plt.title('Integrated Clustering Performance Score', fontsize=18)
        plt.xlabel('Modality', fontsize=14)
        plt.ylabel('Integrated Score (0-1)', fontsize=14)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, 'integrated_score_comparison.png'), bbox_inches='tight', dpi=300)
        plt.close()
        
        # Save integrated metrics to CSV
        normalized_df.to_csv(os.path.join(OUTPUT_DIR, 'integrated_metrics.csv'), index=False)
    
    print(f"Quantitative analysis complete. Results saved to {OUTPUT_DIR}")
    
    return metrics_df, separability_df, clustering_df

# Run the multimodal visualization and quantitative analysis on aligned_data
# and multimodal_embeddings. The function saves its figures and summary files
# under OUTPUT_DIR and returns (metrics_df, separability_df, clustering_df);
# that tuple is discarded here.
visualize_multimodal_integration(aligned_data, multimodal_embeddings)

# Final status message naming the directory the results were written to.
print(f"Multimodal integration analysis complete. Results saved to {OUTPUT_DIR}")

Visualizing multimodal integration and performing quantitative analysis...
Clinical embeddings processed shape: (547, 1024)
Pathology embeddings processed shape: (547, 1024)
Radiology embeddings processed shape: (547, 1000)
Molecular embeddings processed shape: (547, 48)


  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


Cancer types in visualization: ['TCGA-OV' 'TCGA-BRCA' 'TCGA-KIRC']
Visualizations saved to multimodal_analysis_results

Evaluating embedding quality for each modality...
Evaluating Clinical embeddings...
Evaluating Pathology embeddings...
  DBSCAN did not find valid clusters for Pathology
Evaluating Radiology embeddings...
  DBSCAN did not find valid clusters for Radiology
Evaluating Molecular embeddings...
Evaluating Multimodal embeddings...
Evaluation results saved to multimodal_analysis_results



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

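Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.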


Computing cluster separability metrics...
Analyzing Clinical embeddings...
Analyzing Pathology embeddings...
Analyzing Radiology embeddings...
Analyzing Molecular embeddings...
Analyzing Multimodal embeddings...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x='Modality', y='Separation Ratio', data=df_sorted, palette='viridis')



Performing clustering analysis...
Analyzing Clinical embeddings...
Analyzing Pathology embeddings...
Analyzing Radiology embeddings...
Analyzing Molecular embeddings...
Analyzing Multimodal embeddings...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y=metric, data=df_sorted, palette='viridis')

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  ax = sns.barplot(x='Modality', y='Integrated Score',


Quantitative analysis complete. Results saved to multimodal_analysis_results
Multimodal integration analysis complete. Results saved to multimodal_analysis_results
