In [None]:
# Survival analysis in Melanoma SKCM 
# Estef Vazquez


# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from statsmodels.stats.multitest import multipletests
import os
import re
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Analysis-wide settings used by the cells below
CONFIG = dict(
    genes_of_interest=["CD44", "CD68"],
    expression_cutoff=0.5,  # Median cutoff
    data_dir="data",
    results_dir="results",
    figures_dir="figures",
)

# Set style for plots.
# The 'seaborn-whitegrid' style sheet was renamed to 'seaborn-v0_8-whitegrid'
# in matplotlib 3.6 and the old name was removed in later releases, so use
# whichever name this matplotlib installation provides. If neither exists,
# fall back to seaborn's own whitegrid theme instead of crashing.
for _style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
else:
    sns.set_style("whitegrid")
sns.set_context("paper", font_scale=1.5)
sns.set_palette("colorblind")

# Create the output directories used by the cells below
for _directory in ("data/processed", "results", "figures"):
    os.makedirs(_directory, exist_ok=True)

# Report the interpreter and key package versions for reproducibility
print("Python session info:")
import sys
print(f"Python version: {sys.version}")
print("\nKey packages:")
import lifelines
import sklearn
import pandas
for _label, _package in (
    ("lifelines", lifelines),
    ("scikit-learn", sklearn),
    ("pandas", pandas),
):
    print(f"{_label}: {_package.__version__}")
del _label, _package

In [None]:
# 2. Data Loading

def load_expression_data(file_path="SKCM_counts.csv"):
    """
    Read the RNA-seq count matrix from a CSV file.

    Parameters:
    -----------
    file_path : str
        Location of the count file. The first column is used as the row
        index (gene IDs); the remaining columns are read as samples.

    Returns:
    --------
    pandas.DataFrame
        Count matrix indexed by the first CSV column

    Raises:
    -------
    FileNotFoundError
        Re-raised after printing guidance when file_path does not exist
    """
    print(f"Loading expression data from {file_path}...")

    try:
        counts = pd.read_csv(file_path, index_col=0)
    except FileNotFoundError:
        # Explain the expected input before propagating the error
        print(f"ERROR: Could not find file: {file_path}")
        print("Please ensure the expression data file exists in the specified location")
        print("Required format: CSV with genes in rows, samples in columns")
        raise

    print(f"Expression data loaded with shape: {counts.shape}")
    return counts

def get_clinical_data(file_path="clinical_data_complete_cbioportal.csv"):
    """
    Read the clinical table from a CSV file and check its schema.

    Parameters:
    -----------
    file_path : str
        Location of the clinical data file (CSV format)

    Returns:
    --------
    pandas.DataFrame
        Clinical table exactly as read from disk

    Raises:
    -------
    FileNotFoundError
        Re-raised after printing guidance when file_path does not exist
    ValueError
        When any of PATIENT_ID, SAMPLE_ID, SAMPLE_TYPE, OS_STATUS or
        OS_MONTHS is absent from the columns
    """
    print(f"Loading clinical data from {file_path}...")

    try:
        clinical = pd.read_csv(file_path)
    except FileNotFoundError:
        print(f"ERROR: Could not find file: {file_path}")
        print("Please ensure the clinical data file exists in the specified location")
        raise

    # Fail early if a column needed downstream is absent
    expected = ['PATIENT_ID', 'SAMPLE_ID', 'SAMPLE_TYPE', 'OS_STATUS', 'OS_MONTHS']
    absent = [name for name in expected if name not in clinical.columns]
    if absent:
        print(f"ERROR: Missing required columns in clinical data: {', '.join(absent)}")
        print("Required columns: PATIENT_ID, SAMPLE_ID, SAMPLE_TYPE, OS_STATUS, OS_MONTHS")
        raise ValueError("Missing required columns in clinical data")

    print(f"Clinical data loaded with shape: {clinical.shape}")
    return clinical

def get_braf_mutations(file_path="data/BRAF_mutations_cbioportal.txt"):
    """
    Read the BRAF mutation table from a tab-delimited file and check its schema.

    Parameters:
    -----------
    file_path : str
        Location of the BRAF mutation file (tab-delimited text)

    Returns:
    --------
    pandas.DataFrame
        Mutation table exactly as read from disk

    Raises:
    -------
    FileNotFoundError
        Re-raised after printing guidance when file_path does not exist
    ValueError
        When the SAMPLE_ID or BRAF column is absent
    """
    print(f"Loading BRAF mutation data from {file_path}...")

    try:
        mutations = pd.read_csv(file_path, sep='\t')
    except FileNotFoundError:
        print(f"ERROR: Could not find file: {file_path}")
        print("Please ensure the BRAF mutation data file exists in the specified location")
        raise

    # Both columns are consumed by the classification step
    expected = ['SAMPLE_ID', 'BRAF']
    absent = [name for name in expected if name not in mutations.columns]
    if absent:
        print(f"ERROR: Missing required columns in BRAF data: {', '.join(absent)}")
        print("Required columns: SAMPLE_ID, BRAF")
        raise ValueError("Missing required columns in BRAF data")

    print(f"BRAF mutation data loaded with shape: {mutations.shape}")
    return mutations


In [None]:
# 3. Clinical Data Processing

def process_clinical_data(data):
    """
    Process clinical data for analysis.

    Parameters:
    -----------
    data : pandas.DataFrame
        Raw clinical data. After lower-casing the column names it must
        contain os_status, ajcc_pathologic_tumor_stage, os_months, age,
        breslow_depth, sample_type and sex.

    Returns:
    --------
    pandas.DataFrame
        Copy of the input with lower-case column names plus the derived
        columns os_status_recode and stage. The input frame is not modified.

    Notes:
    ------
    This function:
    1. Converts column names to lowercase
    2. Recodes survival status to 0/1 (0=living/alive, 1=deceased);
       missing or unrecognised statuses become NaN
    3. Collapses tumor staging strings to '0'-'4' (sub-stages a/b/c are
       merged, 'I/II NOS' is mapped to '2', anything else becomes NaN)
    4. Converts os_months, age and breslow_depth to numeric (unparseable
       values become NaN)
    5. Converts sample_type, sex and stage to categorical dtype
    """
    print("Processing clinical data...")

    # Work on a copy so the caller's frame is left untouched
    processed_data = data.copy()

    # Convert column names to lowercase
    processed_data.columns = processed_data.columns.str.lower()

    def recode_status(status):
        # Missing statuses arrive as NaN (a float); a substring test on them
        # would raise TypeError, so anything that is not text maps to NaN.
        if not isinstance(status, str):
            return np.nan
        status = status.lower()
        if 'deceased' in status:
            return 1
        if 'living' in status or 'alive' in status:
            return 0
        return np.nan

    # Convert survival status to 0/1
    processed_data['os_status_recode'] = processed_data['os_status'].map(recode_status)

    def clean_stage(stage_str):
        if pd.isna(stage_str):
            return np.nan

        # Strip surrounding whitespace so the anchored patterns below still match
        stage_str = str(stage_str).lower().strip()
        if 'stage 0' in stage_str:
            return '0'
        elif re.match(r'stage i[abc]?$', stage_str):
            return '1'
        elif re.match(r'stage ii[abc]?$', stage_str):
            return '2'
        elif re.match(r'stage iii[abc]?$', stage_str):
            return '3'
        elif 'stage iv' in stage_str:
            return '4'
        elif 'i/ii nos' in stage_str:
            return '2'
        else:
            return np.nan

    # Clean stage info
    processed_data['stage'] = processed_data['ajcc_pathologic_tumor_stage'].apply(clean_stage)

    # Convert numeric fields; unparseable entries become NaN
    for col in ('os_months', 'age', 'breslow_depth'):
        processed_data[col] = pd.to_numeric(processed_data[col], errors='coerce')

    # Convert to categorical
    for col in ('sample_type', 'sex', 'stage'):
        processed_data[col] = processed_data[col].astype('category')

    print("Clinical data processing complete")
    return processed_data

In [None]:
# 4. BRAF mutation processing

def _braf_mutation_type(mutation):
    """Label one BRAF annotation string with a coarse mutation category."""
    if not isinstance(mutation, str):
        return np.nan
    if mutation == 'WT':
        return 'Wild_Type'
    if 'V600E' in mutation:
        # Several mutations in one sample are separated by spaces
        return 'V600E_Single' if ' ' not in mutation else 'V600E_Double'
    if 'V600' in mutation:
        return 'Other_V600'
    if any(m in mutation for m in ['G469', 'K601', 'L597']):
        return 'Class_2'
    if any(m in mutation for m in ['G466', 'N581']):
        return 'Other_Activating'
    return 'Other'


def _braf_mutation_count(mutation):
    """Describe how many space-separated mutations one BRAF annotation lists."""
    if not isinstance(mutation, str):
        return np.nan
    if mutation == 'WT':
        return 'None'
    spaces = mutation.count(' ')
    if spaces == 0:
        return 'Single'
    if spaces == 1:
        return 'Double'
    if spaces == 2:
        return 'Triple'
    return 'Multiple'


def classify_braf_mutations(braf_mutation, include_details=False):
    """
    Classify BRAF mutations into meaningful categories.

    Parameters:
    -----------
    braf_mutation : pandas.DataFrame
        Raw BRAF mutation data with SAMPLE_ID and BRAF columns
    include_details : bool
        When True, also return the mutation_type and multiple_mutations
        columns. Defaults to False, which returns the same four columns
        as before.

    Returns:
    --------
    pandas.DataFrame
        One row per input row with columns sample_id, BRAF_mutation,
        mutation_status_class1 and mutation_status_class2 (both categorical),
        plus the detail columns when requested

    Notes:
    ------
    Creates two classification schemes:
    1. mutation_status_class1: V600E specific (1=V600E, 0=other)
    2. mutation_status_class2: Any mutation vs Wild Type (1=any mutation, 0=wild type)

    A missing (non-text) BRAF value yields NaN in every derived column
    instead of raising or being counted as mutated. Both schemes are computed
    on the same frame, so repeated sample IDs no longer multiply rows the way
    a self-merge on sample_id does.
    """
    # rename() returns a new frame, so the caller's data is not modified
    braf = braf_mutation.rename(columns={'SAMPLE_ID': 'sample_id', 'BRAF': 'BRAF_mutation'})
    annotations = braf['BRAF_mutation']

    # Classification 1 - BRAF V600E only
    braf['mutation_status_class1'] = annotations.map(
        lambda m: (1 if 'V600E' in m else 0) if isinstance(m, str) else np.nan
    ).astype('category')

    # Classification 2 - all mutations vs wild type
    braf['mutation_status_class2'] = annotations.map(
        lambda m: (0 if m == 'WT' else 1) if isinstance(m, str) else np.nan
    ).astype('category')

    columns = ['sample_id', 'BRAF_mutation', 'mutation_status_class1', 'mutation_status_class2']

    if include_details:
        # More detailed classification, only computed on request
        braf['mutation_type'] = annotations.map(_braf_mutation_type)
        braf['multiple_mutations'] = annotations.map(_braf_mutation_count)
        columns += ['mutation_type', 'multiple_mutations']

    combined_braf = braf[columns].reset_index(drop=True)

    print("BRAF classification complete")
    return combined_braf

def integrate_braf_with_clinical(clinical_data, combined_braf):
    """
    Attach BRAF mutation calls to the clinical table.

    Parameters:
    -----------
    clinical_data : pandas.DataFrame
        Processed clinical data; must contain sample_id, os_status_recode,
        age, breslow_depth and stage
    combined_braf : pandas.DataFrame
        Processed BRAF mutation data keyed by sample_id

    Returns:
    --------
    pandas.DataFrame
        Left join of the clinical table with the BRAF table on sample_id,
        so every clinical row is kept. os_status_recode, age and
        breslow_depth are cast to float and stage is converted to numeric
        (values that cannot be parsed become NaN).
    """
    # Left join keeps clinical rows that have no mutation call
    merged = pd.merge(clinical_data, combined_braf, how='left', on='sample_id')

    # Cast the model covariates to float
    for column in ('os_status_recode', 'age', 'breslow_depth'):
        merged[column] = merged[column].astype(float)

    # Stage labels are stored as text; turn them into numbers
    merged['stage'] = pd.to_numeric(merged['stage'], errors='coerce')

    print("BRAF and clinical data integrated")
    return merged


In [None]:
# 5: Processing Expression Data

def process_expression_data(expression_data):
    """
    Process RNA-seq counts to normalized expression values.

    Parameters:
    -----------
    expression_data : pandas.DataFrame
        Raw RNA-seq count data with genes as rows and samples as columns

    Returns:
    --------
    pandas.DataFrame
        log2(normalized count + 1) for the genes that pass the low-count
        filter, same sample columns as the input

    Notes:
    ------
    This is a simple library-size normalization followed by a log2
    transform with a pseudocount of 1. It is a rough stand-in for a
    DESeq2 VST, not an implementation of it.

    A gene is kept when it has at least 10 counts in at least
    max(10, 10% of samples) samples. That requirement is capped at the
    number of samples available, so cohorts with fewer than 10 samples do
    not lose every gene.

    Size factors are each sample's total count divided by the geometric mean
    of the totals. Samples whose total is zero are left out of the geometric
    mean and returned as NaN, rather than driving the mean to zero and
    flattening every other sample.
    """
    print("Processing expression data...")

    # Filter low counts (at least 10 counts in at least 10% of samples,
    # never fewer than 10 samples unless the cohort itself is smaller)
    n_samples = expression_data.shape[1]
    min_samples = min(n_samples, max(10, int(0.1 * n_samples)))
    keep = (expression_data >= 10).sum(axis=1) >= min_samples
    filtered_data = expression_data.loc[keep]

    # Normalization: size factors = column sums / geometric mean of column sums
    library_sizes = filtered_data.sum()
    has_counts = library_sizes > 0
    if not has_counts.all():
        print(f"WARNING: {int((~has_counts).sum())} samples have zero total counts; "
              "their normalized values are set to NaN")
    geometric_mean = np.exp(np.log(library_sizes[has_counts]).mean())
    size_factors = (library_sizes / geometric_mean).where(has_counts)

    # Apply size factors and log2 transform
    normalized = filtered_data.divide(size_factors, axis=1)
    vst_matrix = np.log2(normalized + 1)  # Add pseudocount of 1

    print(f"Expression data processed: {vst_matrix.shape[0]} genes retained")
    return vst_matrix

In [None]:
# 6. Data matching and integration

def extract_patient_info(barcode):
    """
    Split a hyphen-delimited TCGA barcode into patient ID and sample code.

    Parameters:
    -----------
    barcode : str
        TCGA barcode with at least four hyphen-separated fields
        (e.g., 'TCGA-01-1234-01')

    Returns:
    --------
    dict
        'patient_id' is the first three fields re-joined with hyphens;
        'sample_type' is the first two characters of the fourth field
    """
    fields = barcode.split('-')
    info = {'patient_id': '-'.join(fields[0:3])}
    # Any trailing vial letter in the fourth field is dropped by the slice
    info['sample_type'] = fields[3][:2]
    return info

def match_clinical_expression(expression_data, clinical_data):
    """
    Match expression and clinical data by sample IDs.

    Parameters:
    -----------
    expression_data : pandas.DataFrame
        Processed expression data whose column names are TCGA barcodes
    clinical_data : pandas.DataFrame
        Processed clinical data with a sample_id column

    Returns:
    --------
    dict
        'expression': expression columns (renamed to
        '<patient_id>-<sample_type>') that have a clinical record,
        'clinical': clinical rows whose sample_id has expression data,
        'n_matched': number of distinct matched sample IDs

    Notes:
    ------
    Matched expression columns keep the order in which they appear in
    expression_data. Building the list from a set intersection would make
    the column order depend on string hashing and change between runs.
    """
    print("Matching clinical and expression data...")

    # Shorten each barcode to '<patient_id>-<sample_type>' so it is compared
    # against the clinical sample_id values
    sample_info = [extract_patient_info(x) for x in expression_data.columns]
    new_colnames = [f"{x['patient_id']}-{x['sample_type']}" for x in sample_info]

    # Create new expression matrix with updated column names
    expression_new = expression_data.copy()
    expression_new.columns = new_colnames

    # Find matching samples: distinct IDs, in expression column order
    clinical_ids = set(clinical_data['sample_id'])
    matching_ids = list(dict.fromkeys(name for name in new_colnames if name in clinical_ids))

    # Create matched datasets
    expression_matched = expression_new[matching_ids]
    clinical_matched = clinical_data[clinical_data['sample_id'].isin(matching_ids)].copy()

    print(f"Matched data: {len(matching_ids)} samples")
    return {
        'expression': expression_matched,
        'clinical': clinical_matched,
        'n_matched': len(matching_ids)
    }

In [None]:
# 7. Gene Specific Processing
def process_gene_expression(expression_data, genes_of_interest):
    """
    Process expression data to extract genes of interest.

    Parameters:
    -----------
    expression_data : pandas.DataFrame
        Processed expression data with genes as rows
    genes_of_interest : list
        List of gene symbols to extract

    Returns:
    --------
    pandas.DataFrame
        Expression data for genes of interest only. When the genes are found
        by symbol or through the mapping file, the index is named
        'hgnc_symbol', which is the column name the long-format reshaping
        step melts on.

    Raises:
    -------
    ValueError
        When the genes cannot be located and no mapping file is available,
        or when the mapping file has no entry for any requested gene

    Notes:
    ------
    Lookup order:
    1. Gene symbols present in the index are selected directly.
    2. Otherwise <data_dir>/gene_id_mapping.csv (columns 'ensembl_gene_id'
       and 'hgnc_symbol') is used to translate symbols to Ensembl IDs.
    3. If that file is missing, the genes are searched for in the column
       names (transposed input) and then as substrings of the index labels.
       The substring search can also pick up labels that merely contain a
       symbol.
    """
    print(f"Processing expression for genes of interest: {', '.join(genes_of_interest)}")

    try:
        #  1: If expression data has gene symbols as index
        if expression_data.index.name == 'hgnc_symbol' or any(gene in expression_data.index for gene in genes_of_interest):
            gene_subset = expression_data.loc[expression_data.index.isin(genes_of_interest)]
            if len(gene_subset) > 0:
                print(f"Found {len(gene_subset)} out of {len(genes_of_interest)} genes of interest")
                # Name the index like the mapping branch does, so both
                # branches return the same layout
                return gene_subset.rename_axis(index='hgnc_symbol')

        #  2: If expression data has Ensembl IDs and needs mapping
        # gene mapping file
        mapping_file = os.path.join(CONFIG['data_dir'], "gene_id_mapping.csv")

        try:
            #  load mapping file
            gene_mapping = pd.read_csv(mapping_file)
            print(f"Loaded gene mapping with {len(gene_mapping)} entries")
        except FileNotFoundError:
            print(f"WARNING: Gene mapping file {mapping_file} not found")
            print("Will attempt to find genes directly in the expression data")

            # Check if genes are in column names (transposed)
            if any(gene in expression_data.columns for gene in genes_of_interest):
                print("Found genes in column names, assuming transposed expression data")
                return expression_data[expression_data.columns.intersection(genes_of_interest)]

            # Look for index labels that contain a gene symbol as a substring
            found_genes = []
            for gene in genes_of_interest:
                matches = [idx for idx in expression_data.index if gene in str(idx)]
                if matches:
                    found_genes.extend(matches)

            if found_genes:
                print(f"Found potential matches for {len(found_genes)} genes")
                return expression_data.loc[found_genes]

            print("ERROR: Could not find genes of interest in expression data")
            print("Please provide a gene ID mapping file or ensure gene symbols are in the data")
            raise ValueError("Genes not found and no mapping available")

        # gene ID mapping
        # mapping file has 'ensembl_gene_id' and 'hgnc_symbol' columns
        ensembl_ids = []
        for gene in genes_of_interest:
            matches = gene_mapping[gene_mapping['hgnc_symbol'] == gene]['ensembl_gene_id'].tolist()
            ensembl_ids.extend(matches)

        if not ensembl_ids:
            print(f"ERROR: Could not find Ensembl IDs for any genes in {genes_of_interest}")
            raise ValueError("No matching Ensembl IDs found")

        # Extract genes of interest
        gene_expression = expression_data.loc[expression_data.index.isin(ensembl_ids)]

        # Add gene symbols
        if 'hgnc_symbol' not in gene_expression.columns:
            # Create mapping dictionary
            id_to_symbol = dict(zip(
                gene_mapping['ensembl_gene_id'],
                gene_mapping['hgnc_symbol']
            ))

            # Create new dataframe with gene symbols as index; IDs without a
            # mapping entry keep their original label
            gene_expression_with_symbols = gene_expression.copy()
            gene_expression_with_symbols['hgnc_symbol'] = [
                id_to_symbol.get(idx, str(idx)) for idx in gene_expression.index
            ]
            gene_expression = gene_expression_with_symbols.set_index('hgnc_symbol')

        print(f"Found {len(gene_expression)} out of {len(genes_of_interest)} genes of interest")
        return gene_expression

    except Exception as e:
        # Add context for the notebook user, then let the error propagate
        print(f"ERROR in gene processing: {str(e)}")
        print("Ensure expression data is properly formatted and gene IDs are compatible")
        raise

In [None]:
# 8. Separating Datasets

def prepare_analysis_data(expression_data, clinical_data, sample_type):
    """
    Prepare data for a specific sample type (Primary or Metastasis).

    Parameters:
    -----------
    expression_data : pandas.DataFrame
        Processed expression data with genes as rows and samples as columns
    clinical_data : pandas.DataFrame
        Processed clinical data with sample_id and sample_type columns
    sample_type : str
        Sample type to filter ('Primary' or 'Metastasis')

    Returns:
    --------
    pandas.DataFrame
        Long-format table with one row per gene and sample (columns
        hgnc_symbol, sample_id, expression) joined to the clinical columns
        of the requested sample type. Samples missing from either table are
        dropped.

    Notes:
    ------
    The gene index is labelled 'hgnc_symbol' before reshaping, so a matrix
    whose index is unnamed (or named differently) no longer fails with a
    KeyError.
    """
    # Filter clinical data by sample type
    clinical_filtered = clinical_data[clinical_data['sample_type'] == sample_type].copy()

    # Move the gene labels into a column called 'hgnc_symbol'. If the matrix
    # already has such a column, reset the index unchanged to avoid a name clash.
    if 'hgnc_symbol' in expression_data.columns:
        expression_wide = expression_data.reset_index()
    else:
        expression_wide = expression_data.rename_axis(index='hgnc_symbol').reset_index()

    # Convert wide to long format
    expression_long = expression_wide.melt(
        id_vars='hgnc_symbol',
        var_name='sample_id',
        value_name='expression'
    )

    # Merge with clinical
    merged_data = expression_long.merge(clinical_filtered, on='sample_id', how='inner')

    print(f"Prepared data for {sample_type}: {len(merged_data['sample_id'].unique())} samples")
    return merged_data

In [None]:
# 9. Classification

def classify_expression(data, cutoff_percentile=0.5, expression_col='expression', gene='CD44'):
    """
    Classify samples into high and low expression groups based on percentile cutoff.

    Parameters:
    -----------
    data : pandas.DataFrame
        Processed data in long format with an hgnc_symbol column
    cutoff_percentile : float
        Quantile (between 0 and 1) used to split high vs low (0.5 = median)
    expression_col : str
        Column name containing expression values
    gene : str
        Gene to classify

    Returns:
    --------
    dict
        'classified_data': rows for the gene with an added expression_class
        column ('High' when the value is >= the cutoff, 'Low' when below,
        NaN when the expression value is missing);
        'summary': cutoff_value, n_high, n_low, total_samples, na_samples

    Notes:
    ------
    Samples with missing expression are left unclassified and counted in
    na_samples. A plain `x >= cutoff` test is False for NaN, which would
    place them in the 'Low' group.
    """
    # Filter for specific gene
    gene_data = data[data['hgnc_symbol'] == gene].copy()
    values = gene_data[expression_col]

    # Calculate cutoff (quantile ignores missing values)
    cutoff_value = values.quantile(cutoff_percentile)

    # Apply classification, then blank out rows without an expression value
    labels = pd.Series(
        np.where(values >= cutoff_value, 'High', 'Low'),
        index=gene_data.index,
        dtype=object,
    )
    gene_data['expression_class'] = labels.where(values.notna())

    # Summary statistics
    summary_stats = {
        'cutoff_value': cutoff_value,
        'n_high': int((gene_data['expression_class'] == 'High').sum()),
        'n_low': int((gene_data['expression_class'] == 'Low').sum()),
        'total_samples': len(gene_data),
        'na_samples': int(gene_data['expression_class'].isna().sum())
    }

    print(f"Classified {gene} expression (cutoff: {cutoff_percentile*100}%)")
    print(f"High: {summary_stats['n_high']}, Low: {summary_stats['n_low']}")

    return {
        'classified_data': gene_data,
        'summary': summary_stats
    }

def create_expression_histogram(data, gene='CD44', title=None, filename=None):
    """
    Draw a density histogram of one gene's expression with mean/median markers.

    Parameters:
    -----------
    data : pandas.DataFrame
        Processed data in long format (hgnc_symbol and expression columns)
    gene : str
        Gene to visualize
    title : str
        Plot title (optional); defaults to "<gene> Expression Distribution"
    filename : str
        Path to save the figure (optional); nothing is saved when omitted

    Returns:
    --------
    module
        The matplotlib.pyplot module, after the figure has been shown
    """
    subset = data.loc[data['hgnc_symbol'] == gene].copy()
    expression = subset['expression']

    # Reference lines drawn on top of the histogram: (label, value, colour)
    markers = (
        ('Mean', expression.mean(), 'darkgreen'),
        ('Median', expression.median(), 'red'),
    )

    plt.figure(figsize=(10, 6))

    # Histogram with a kernel density overlay
    sns.histplot(data=subset, x='expression', stat='density',
                 kde=True, color='#2E86C1', edgecolor='black', alpha=0.7)

    for label, value, colour in markers:
        plt.axvline(value, color=colour, linestyle='--', linewidth=1.5,
                    label=f'{label} = {value:.2f}')

    heading = title or f"{gene} Expression Distribution"
    plt.title(heading, fontsize=14, fontweight='bold')
    plt.xlabel(f"{gene} expression (normalized)", fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.legend()

    # Save before show() so the file is written from the populated figure
    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight')

    plt.show()
    return plt

In [None]:
# 10. Final Data Prep

def prepare_final_dataset(classification_data, expression_data, clinical_processed, gene='CD44'):
    """
    Assemble the per-sample table used for survival analysis.

    Parameters:
    -----------
    classification_data : dict
        Output from classify_expression; its 'classified_data' frame must
        contain sample_id and expression_class
    expression_data : pandas.DataFrame
        Expression data in long format (sample_id, hgnc_symbol, expression)
    clinical_processed : pandas.DataFrame
        Processed clinical data with a sample_id column
    gene : str
        Main gene for classification (used to name the class column
        '<gene>_expression_class')

    Returns:
    --------
    pandas.DataFrame
        Inner join on sample_id of the class labels, one expression column
        per gene, and the clinical columns; the class column is categorical
    """
    classified = classification_data['classified_data']
    class_column = f'{gene}_expression_class'

    # One row per sample, one column per gene
    wide_expression = pd.pivot_table(
        expression_data,
        index='sample_id',
        columns='hgnc_symbol',
        values='expression',
    ).reset_index()

    # Clinical rows restricted to samples that were classified
    keep_ids = classified['sample_id'].unique()
    clinical_subset = clinical_processed.loc[
        clinical_processed['sample_id'].isin(keep_ids)
    ].copy()

    # Class labels under a gene-specific column name
    labels = classified[['sample_id', 'expression_class']].rename(
        columns={'expression_class': class_column}
    )

    # Join labels -> expression -> clinical
    final_data = pd.merge(labels, wide_expression, on='sample_id')
    final_data = pd.merge(final_data, clinical_subset, on='sample_id')

    final_data[class_column] = final_data[class_column].astype('category')

    print(f"Final dataset prepared with {len(final_data)} samples")
    return final_data

In [None]:
# 11. Correlation analysis

def analyze_gene_correlations(data, genes=None):
    """
    Analyze correlation between genes and create visualizations.

    Parameters:
    -----------
    data : pandas.DataFrame
        Final dataset with one expression column per gene
    genes : list
        Genes to include in the correlation analysis. Defaults to
        ['CD44', 'CD68'] when omitted. The default is built inside the
        function so no list object is shared between calls.

    Returns:
    --------
    dict
        Dictionary with Pearson and Spearman correlation matrices
        (keys 'pearson' and 'spearman')

    Notes:
    ------
    Both heatmaps are saved to figures/gene_correlation_matrix.pdf and shown.
    """
    # Always index with a list: a tuple would be read as a single column key
    genes = ['CD44', 'CD68'] if genes is None else list(genes)
    expression = data[genes]

    # Calculate Pearson and Spearman correlations
    pearson_matrix = expression.corr(method='pearson')
    spearman_matrix = expression.corr(method='spearman')

    # Create visualizations: one heatmap per method, side by side
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        (axes[0], pearson_matrix, 'Pearson Correlation'),
        (axes[1], spearman_matrix, 'Spearman Correlation'),
    )
    for axis, matrix, heading in panels:
        # Fixed -1..1 colour scale keeps the two panels comparable
        sns.heatmap(matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1,
                    square=True, ax=axis, cbar=True)
        axis.set_title(heading, fontsize=14)

    plt.tight_layout()
    plt.savefig('figures/gene_correlation_matrix.pdf', dpi=300, bbox_inches='tight')
    plt.show()

    return {
        'pearson': pearson_matrix,
        'spearman': spearman_matrix
    }

In [None]:
# 12. Survival Analysis

def validate_and_prepare_data(data):
    """
    Coerce the survival columns to numeric and drop rows lacking follow-up.

    Parameters:
    -----------
    data : pandas.DataFrame
        Final dataset containing os_status_recode, os_months, age and
        breslow_depth

    Returns:
    --------
    pandas.DataFrame
        Copy of the input in which os_status_recode is float; age,
        breslow_depth and os_months are numeric (unparseable values become
        NaN); and rows with a missing os_months or os_status_recode are
        removed
    """
    prepared = data.copy()

    # Event indicator as float
    prepared['os_status_recode'] = prepared['os_status_recode'].astype(float)

    # Numeric covariates and follow-up time; bad entries turn into NaN
    for column in ('age', 'breslow_depth', 'os_months'):
        prepared[column] = pd.to_numeric(prepared[column], errors='coerce')

    # Drop rows with no survival time or no event status
    prepared = prepared.dropna(subset=['os_months', 'os_status_recode'])

    print(f"Data prepared for survival analysis: {len(prepared)} valid samples")
    return prepared


def create_expression_boxplot(data, gene='CD44', class_column='CD44_expression_class', title=None):
    """
    Draw a High-vs-Low boxplot of a gene's expression with a Welch t-test p-value.

    Parameters:
    -----------
    data : pandas.DataFrame
        Final dataset containing the gene column and the class column
    gene : str
        Gene (column name) to visualize
    class_column : str
        Column with classification (High/Low)
    title : str
        Plot title (optional); defaults to "<gene> Expression by Group"

    Returns:
    --------
    module
        The matplotlib.pyplot module with the figure drawn (not shown or saved)
    """
    plt.figure(figsize=(8, 6))

    # Boxes per group, with the individual samples overlaid as jittered points
    group_colours = {'High': 'blue', 'Low': 'red'}
    sns.boxplot(x=class_column, y=gene, data=data, palette=group_colours)
    sns.stripplot(x=class_column, y=gene, data=data,
                  size=4, color='black', alpha=0.4, jitter=True)

    # Unequal-variance (Welch) t-test between the two groups
    groups = data[class_column]
    expression = data[gene]
    t_stat, p_value = stats.ttest_ind(
        expression[groups == 'High'],
        expression[groups == 'Low'],
        equal_var=False,
    )

    # Very small p-values are reported as a bound
    p_text = 'p < 0.001' if p_value < 0.001 else f'p = {p_value:.3f}'

    # Place the annotation just above the largest observed value
    annotation_height = expression.max() * 1.05
    plt.text(0.5, annotation_height, p_text, horizontalalignment='center', size=12, weight='bold')

    plt.title(title or f"{gene} Expression by Group", fontsize=14, fontweight='bold')
    plt.xlabel(f"{gene} Group", fontsize=12)
    plt.ylabel("Gene Expression (normalized)", fontsize=12)
    plt.tight_layout()

    return plt

def create_survival_plot(data, gene='CD44', class_column='CD44_expression_class', title=None):
    """
    Plot Kaplan-Meier curves for the High and Low groups and run a log-rank test.

    Parameters:
    -----------
    data : pandas.DataFrame
        Final dataset containing os_months, os_status_recode and the
        class column
    gene : str
        Gene name used in the legend labels, title and output file name
    class_column : str
        Column with classification (High/Low)
    title : str
        Plot title (optional); defaults to "<gene> Survival Analysis"

    Returns:
    --------
    tuple
        (matplotlib.pyplot module, log-rank test result). The figure is
        also saved to figures/<gene>_survival_plot.pdf.
    """
    # Split the cohort once and reuse the pieces below
    is_high = data[class_column] == 'High'
    is_low = data[class_column] == 'Low'
    high_data = data[is_high]
    low_data = data[is_low]
    n_high = int(is_high.sum())
    n_low = int(is_low.sum())

    kmf = KaplanMeierFitter()
    plt.figure(figsize=(10, 6))

    # The same fitter is refit for each group; each curve is drawn right
    # after its own fit
    kmf.fit(high_data['os_months'], high_data['os_status_recode'],
            label=f"{gene} High (n={n_high})")
    ax = kmf.plot(ci_show=False, color='#E41A1C', linewidth=2)

    kmf.fit(low_data['os_months'], low_data['os_status_recode'],
            label=f"{gene} Low (n={n_low})")
    kmf.plot(ax=ax, ci_show=False, color='blue', linewidth=2)

    # Log-rank test between the two groups
    results = logrank_test(
        high_data['os_months'], low_data['os_months'],
        high_data['os_status_recode'], low_data['os_status_recode'],
    )

    # Very small p-values are reported as a bound
    p_text = (
        "Log-rank p < 0.001"
        if results.p_value < 0.001
        else f"Log-rank p = {results.p_value:.3f}"
    )

    # Annotation is positioned in axes coordinates (bottom centre)
    plt.text(0.5, 0.1, p_text, transform=plt.gca().transAxes,
             horizontalalignment='center', size=12, weight='bold')

    plt.title(title or f"{gene} Survival Analysis", fontsize=14, fontweight='bold')
    plt.xlabel("Time (months)", fontsize=12)
    plt.ylabel("Overall survival probability", fontsize=12)
    plt.ylim(0, 1.05)
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()

    # Save
    plt.savefig(f"figures/{gene}_survival_plot.pdf", dpi=300, bbox_inches='tight')

    return plt, results


In [None]:

# 13. Cox Regression Models 

def perform_cox_analysis(data, gene='CD44', filename=None):
    """
    Perform Cox proportional hazards analysis.

    Fits a univariate Cox model on the gene column, then a multivariate model
    that adds whichever of the clinical covariates are present in `data`, and
    saves a forest plot of the univariate hazard ratio.

    Parameters:
    -----------
    data : pandas.DataFrame
        Final dataset. Must contain the `gene` column plus 'os_months'
        (duration) and 'os_status_recode' (event indicator).
    gene : str
        Gene to analyze
    filename : str
        Output path for the forest plot (optional). Defaults to
        "figures/{gene}_cox_forest_plot.pdf".

    Returns:
    --------
    dict
        Dictionary with Cox model results: 'univariate' is the fitted
        CoxPHFitter; 'multivariate' is the fitted CoxPHFitter, or None when
        the multivariate fit raised an error.

    Raises:
    -------
    KeyError
        If `gene`, 'os_months' or 'os_status_recode' is not a column of `data`.
    """
    print(f"\n=== Cox Proportional Hazards Analysis for {gene} ===\n")

    survival_cols = ['os_months', 'os_status_recode']
    missing = [col for col in [gene] + survival_cols if col not in data.columns]
    if missing:
        raise KeyError(f"Missing required column(s) for Cox analysis: {missing}")

    # Keep only samples with a known survival time and event status
    cox_data = data.dropna(subset=survival_cols)

    # Univariate model for gene expression.
    # Samples without an expression value are dropped as well, because the
    # fitter refuses input that contains NaNs.
    uni_data = cox_data[[gene] + survival_cols].dropna()
    cph_uni = CoxPHFitter()
    cph_uni.fit(uni_data,
                duration_col='os_months',
                event_col='os_status_recode')

    print(f"\n--- Univariate Cox model for {gene} ---")
    print(cph_uni.summary)

    # Multivariate model: use only the covariates that exist in this dataset
    covariates = [gene, 'mutation_status_class1', 'age', 'breslow_depth', 'stage']
    available_covariates = [cov for cov in covariates if cov in cox_data.columns]
    available_covariates.extend(survival_cols)

    # Complete-case analysis: remove rows with a missing value in any model column
    cox_data_multi = cox_data[available_covariates].dropna()

    # The multivariate fit is best-effort; a failure is reported and the
    # univariate results are still returned.
    # NOTE(review): if any covariate is stored as text (e.g. 'stage'), the fit
    # presumably fails here and ends up as None - confirm the upstream encoding.
    cph_multi = CoxPHFitter()
    try:
        cph_multi.fit(cox_data_multi,
                      duration_col='os_months',
                      event_col='os_status_recode')

        print(f"\n--- Multivariate Cox model with {gene} ---")
        print(cph_multi.summary)
    except Exception as e:
        print(f"Error fitting multivariate model: {e}")
        cph_multi = None

    # Create forest plot for univariate
    if filename is None:
        filename = f"figures/{gene}_cox_forest_plot.pdf"
    _save_univariate_forest_plot(cph_uni, gene, filename)

    return {
        'univariate': cph_uni,
        'multivariate': cph_multi
    }


def _save_univariate_forest_plot(cph_uni, gene, filename):
    """Plot the hazard ratio and 95% CI of a fitted univariate Cox model and save it to `filename`."""
    # A univariate model has a single covariate, so its summary has one row
    row = cph_uni.summary.iloc[0]
    hr = row['exp(coef)']
    ci_low = row['exp(coef) lower 95%']
    ci_high = row['exp(coef) upper 95%']
    p_value = row['p']

    plt.figure(figsize=(10, 6))

    # Asymmetric error bars: distance from the HR down to / up to the CI bounds
    plt.errorbar(x=hr,
                 y=[0],
                 xerr=[[hr - ci_low], [ci_high - hr]],
                 fmt='o',
                 color='red' if p_value < 0.05 else 'black',
                 capsize=5,
                 markersize=10)

    # Add vertical line at HR=1 (no effect)
    plt.axvline(x=1, color='black', linestyle='--')

    # Format plot
    plt.xlabel('Hazard Ratio (95% CI)', fontsize=12)
    plt.title(f'Univariate Cox Regression: {gene}', fontsize=14)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    # Add HR and p-value as text
    hr_text = f"HR = {hr:.2f} ({ci_low:.2f}-{ci_high:.2f})"
    p_text = "p < 0.001" if p_value < 0.001 else f"p = {p_value:.4f}"
    plt.text(1.5, 0, f"{hr_text}\n{p_text}", verticalalignment='center', fontsize=12)

    # Save (the figure is left open so it stays the current pyplot figure)
    plt.savefig(filename, dpi=300, bbox_inches='tight')



In [None]:
# 14. Run pipeline

def run_complete_analysis(genes_of_interest=None):
    """
    Run the complete analysis pipeline

    Calls the loading, processing and matching helpers once, then for every
    gene produces histograms, boxplots, Kaplan-Meier plots and Cox models for
    the 'Primary' and 'Metastasis' sample sets. When more than one gene is
    given, the gene-correlation step is run at the end.

    Parameters:
    -----------
    genes_of_interest : list
        List of genes to analyze (optional). Defaults to
        CONFIG["genes_of_interest"].

    Returns:
    --------
    dict
        One entry per gene, each with 'primary' and 'metastatic' sub-dicts
        holding 'classified_data', 'final_data', 'km_results' (log-rank
        result) and 'cox_results'.
    """
    # None instead of a list literal avoids sharing one mutable default
    # object between calls.
    if genes_of_interest is None:
        genes_of_interest = list(CONFIG["genes_of_interest"])

    # Quantile used to split samples into High/Low expression groups
    cutoff = CONFIG["expression_cutoff"]

    print(f"Starting {'/'.join(genes_of_interest)} Survival Analysis Pipeline\n")
    print("-" * 50)

    # 1. Load data
    expression_data = load_expression_data()
    clinical_data = get_clinical_data()
    braf_mutation = get_braf_mutations()

    # 2. Process data
    clinical_processed = process_clinical_data(clinical_data)
    braf_classified = classify_braf_mutations(braf_mutation)
    clinical_with_braf = integrate_braf_with_clinical(clinical_processed, braf_classified)

    # 3. Process expression data
    expression_processed = process_expression_data(expression_data)

    # 4. Match clinical and expression data
    matched_data = match_clinical_expression(expression_processed, clinical_with_braf)

    # 5. Extract genes of interest
    gene_expression = process_gene_expression(matched_data['expression'], genes_of_interest)

    # 6. Prepare data by sample type
    primary_data = prepare_analysis_data(gene_expression, matched_data['clinical'], 'Primary')
    metastatic_data = prepare_analysis_data(gene_expression, matched_data['clinical'], 'Metastasis')

    # 7. Analyze each gene
    results = {}

    for gene in genes_of_interest:
        print(f"\n{'=' * 30}")
        print(f"Analysis for {gene}")
        print(f"{'=' * 30}\n")

        # Create histograms
        create_expression_histogram(primary_data[primary_data['hgnc_symbol'] == gene],
                                    gene=gene,
                                    title=f"{gene} Expression in Primary Tumors",
                                    filename=f"figures/{gene}_primary_histogram.pdf")

        create_expression_histogram(metastatic_data[metastatic_data['hgnc_symbol'] == gene],
                                    gene=gene,
                                    title=f"{gene} Expression in Metastatic Tumors",
                                    filename=f"figures/{gene}_metastatic_histogram.pdf")

        # Classify expression (median cutoff)
        primary_classified = classify_expression(primary_data, cutoff, 'expression', gene)
        metastatic_classified = classify_expression(metastatic_data, cutoff, 'expression', gene)

        # Prepare final datasets
        primary_final = prepare_final_dataset(primary_classified, primary_data, clinical_with_braf, gene)
        metastatic_final = prepare_final_dataset(metastatic_classified, metastatic_data, clinical_with_braf, gene)

        # Validate data for survival analysis
        primary_final_valid = validate_and_prepare_data(primary_final)
        metastatic_final_valid = validate_and_prepare_data(metastatic_final)

        # Create expression boxplots
        create_expression_boxplot(primary_final_valid, gene, f'{gene}_expression_class',
                                  f"{gene} Expression in Primary Tumors by Group")

        create_expression_boxplot(metastatic_final_valid, gene, f'{gene}_expression_class',
                                  f"{gene} Expression in Metastatic Tumors by Group")

        # Survival analysis.
        # create_survival_plot writes to figures/{gene}_survival_plot.pdf for
        # both cohorts, so the second call overwrites the first. It returns
        # pyplot with its figure still current, so a cohort-specific copy is
        # saved right after each call.
        primary_km, primary_logrank = create_survival_plot(
            primary_final_valid, gene, f'{gene}_expression_class',
            f"{gene} Survival Analysis in Primary Tumors"
        )
        primary_km.savefig(f"figures/{gene}_primary_survival_plot.pdf",
                           dpi=300, bbox_inches='tight')

        metastatic_km, metastatic_logrank = create_survival_plot(
            metastatic_final_valid, gene, f'{gene}_expression_class',
            f"{gene} Survival Analysis in Metastatic Tumors"
        )
        metastatic_km.savefig(f"figures/{gene}_metastatic_survival_plot.pdf",
                              dpi=300, bbox_inches='tight')

        # Cox regression.
        # Same overwrite problem: the forest plot path depends only on the
        # gene. The forest plot is the last figure opened by
        # perform_cox_analysis and is left open, so it is still the current
        # figure when the cohort-specific copy is saved.
        primary_cox = perform_cox_analysis(primary_final_valid, gene)
        plt.savefig(f"figures/{gene}_primary_cox_forest_plot.pdf",
                    dpi=300, bbox_inches='tight')

        metastatic_cox = perform_cox_analysis(metastatic_final_valid, gene)
        plt.savefig(f"figures/{gene}_metastatic_cox_forest_plot.pdf",
                    dpi=300, bbox_inches='tight')

        # Store results
        results[gene] = {
            'primary': {
                'classified_data': primary_classified,
                'final_data': primary_final_valid,
                'km_results': primary_logrank,
                'cox_results': primary_cox
            },
            'metastatic': {
                'classified_data': metastatic_classified,
                'final_data': metastatic_final_valid,
                'km_results': metastatic_logrank,
                'cox_results': metastatic_cox
            }
        }

    # 8. Compare genes (correlation)
    if len(genes_of_interest) > 1:
        # NOTE(review): the correlation step receives the final datasets built
        # for the LAST gene in the list (unchanged behavior, now explicit).
        # Confirm those datasets carry expression columns for every gene.
        last_gene_results = results[genes_of_interest[-1]]
        analyze_gene_correlations(last_gene_results['primary']['final_data'], genes_of_interest)
        analyze_gene_correlations(last_gene_results['metastatic']['final_data'], genes_of_interest)

    print("\nAnalysis complete! Results saved to 'figures' and 'results' directories.")
    return results


# Entry point: run the whole pipeline for CD44 and CD68 when this module is
# executed directly (not when it is imported)
if __name__ == "__main__":
    results = run_complete_analysis(genes_of_interest=['CD44', 'CD68'])