In [None]:
import pandas as pd
import altair as alt
import numpy as np
import scipy.stats as stats
from sklearn.metrics import roc_curve, auc, confusion_matrix
from typing import Optional, Tuple, Dict

In [None]:
# ClinVar exports (tab-delimited text): SNVs and deletions.
# Paths are relative to the notebook's working directory.
file = '../Data/20250912_BARD1_ClinVarSNVs_1StarPlus.txt'
clinvar_dels = '../Data/20250912_BARD1_ClinVarDels_1StarPlus.txt'

# SGE score tables: SNVs (Excel) and deletions (TSV).
sge = '../Data/20250825_BARD1snvscores_filtered.xlsx'
sge_dels = '../Data/20250829_BARD1delscores.tsv'

# Derive the BARD1 score thresholds from the SNV score table.
bard1df = pd.read_excel(sge)
# Density value at which a threshold is read off
# (presumably a GMM posterior probability cutoff; confirm with the scoring pipeline).
target_value = 0.950
# Absolute distance of every row's "normal" density from the target.
diffN = (bard1df['gmm_density_normal'] - target_value).abs()
# Index label of the row closest to the target (first one on ties).
closest_index = diffN.idxmin()
# Row whose normal density is closest to the target.
closest_row_n = bard1df.loc[closest_index]

# Same lookup for the "abnormal" density.
# Absolute distance of every row's abnormal density from the target.
diffA = (bard1df['gmm_density_abnormal'] - target_value).abs()
# Index label of the closest row (closest_index is reused and overwritten here).
closest_index = diffA.idxmin()
# Row whose abnormal density is closest to the target.
closest_row_a = bard1df.loc[closest_index]

# Scores of those two rows become the (n)ormal and (a)bnormal thresholds.
uppr = closest_row_n['score']
lwr = closest_row_a['score']

# Notebook-level names read by later cells: scores <= path_max are treated as
# abnormal and scores >= benign_min as normal; thresholds is [lower, upper].
path_max = lwr
benign_min = uppr
thresholds = [lwr,uppr]

In [None]:
def read_data(file, del_file): #Reads ClinVar data
    """Load the ClinVar SNV and deletion exports.

    Parameters
    ----------
    file : str or path-like
        Tab-delimited ClinVar SNV export. Must contain the columns 'Name',
        'Protein change', 'GRCh38Chromosome', 'GRCh38Location' and
        'Germline classification'.
    del_file : str or path-like
        Tab-delimited ClinVar deletion export. Must contain 'GRCh38Location'
        (ranges written as "start - end") and 'Germline classification'.

    Returns
    -------
    (df, del_df) : tuple of DataFrame
        ``df`` holds the SNVs that have a GRCh38 coordinate, indexed 0..n-1,
        with an integer 'GRCh38Location' and an empty 'Base Change' column to
        be filled in later. ``del_df`` holds only the deletions whose
        end - start equals 2 (i.e. 3 bp), reduced to 'Base Change'
        ("start-end") and 'Germline classification'.
    """
    snv_columns = ['Name', 'Protein change', 'GRCh38Chromosome', 'GRCh38Location', 'Germline classification']
    df = pd.read_csv(file, delimiter='\t')[snv_columns]
    # Drop variants without a genomic coordinate and renumber the rows so that
    # positional access downstream cannot hit a missing index label.
    df = df.dropna(subset=['GRCh38Location']).reset_index(drop=True)
    df['GRCh38Location'] = df['GRCh38Location'].astype(int)
    df['Base Change'] = None #filled in by get_base_changes

    del_df = pd.read_csv(del_file, sep='\t')
    # Keep only rows whose location is a "start - end" range. Parsing through a
    # regex tolerates missing values, single coordinates and varying spacing
    # instead of raising on them.
    span = del_df['GRCh38Location'].astype(str).str.extract(r'^\s*(\d+)\s*-\s*(\d+)\s*$')
    is_span = span.notna().all(axis=1)
    del_df = del_df.loc[is_span].copy()
    del_df['start'] = span.loc[is_span, 0].astype(int)
    del_df['end'] = span.loc[is_span, 1].astype(int)

    del_df['del_length'] = del_df['end'] - del_df['start']

    # Inclusive coordinates: end - start == 2 means three deleted bases.
    del_df = del_df.loc[del_df['del_length'] == 2].copy()
    # 'Base Change' is the span covered by the deletion, matching the SGE pos_id format.
    del_df['Base Change'] = del_df['start'].astype(str) + '-' + del_df['end'].astype(str)
    del_df = del_df[['Base Change', 'Germline classification']]

    return df, del_df

In [None]:
def get_pair(base): #ClinVar gives base changes on negative sense strand, SGE pos_id on positive sense
    """Return the complementary base for ``base``.

    'A', 'T' and 'C' map to 'T', 'A' and 'G'. Every other input (including
    'G' and any non-base character) yields 'C'.
    """
    for original, partner in (('A', 'T'), ('T', 'A'), ('C', 'G')):
        if base == original:
            return partner
    # Fallback covers 'G' as well as anything unrecognised.
    return 'C'

In [None]:
def prep_sge(sge, sge_dels, thresholds): #reads SGE data
    """Load the SGE SNV and deletion score tables and stack them into one frame.

    Parameters
    ----------
    sge : path to the SNV score workbook (read with ``pd.read_excel``).
    sge_dels : path to the tab-separated deletion score table.
    thresholds : two-element sequence ``[lower, upper]``; deletions scoring
        ``<= lower`` are labelled 'Pathogenic' and those ``>= upper`` 'Benign'.

    Returns
    -------
    DataFrame
        SNV rows (columns 'target', 'Consequence', 'pos_id', 'snv_score',
        'Function Type') followed by the deletion rows, which keep all of
        their own columns.
    """
    column_names = {'consequence': 'Consequence', 'score': 'snv_score'}

    def relabel(frame, rules):
        # Rules run strictly in the order given; a row that an earlier rule
        # renamed no longer matches the later raw patterns.
        for how, pattern, label in rules:
            current = frame['Consequence']
            if how == 'contains':
                hits = current.str.contains(pattern)
            else:
                hits = current == pattern
            frame.loc[hits, 'Consequence'] = label

    #SNVs: display names for the consequence categories
    snvs = pd.read_excel(sge).rename(columns = column_names)
    relabel(snvs, (
        ('contains', 'missense', 'Missense'),
        ('equals', 'synonymous_variant', 'Synonymous'),
        ('equals', 'intron_variant', 'Intron'),
        ('equals', 'stop_gained', 'Stop Gained'),
        ('equals', 'stop_lost', 'Stop Lost'),
        ('contains', 'site', 'Canonical Splice'),
        ('contains', 'ing_var', 'Splice Region'),
        ('contains', 'UTR', 'UTR Variant'),
        ('equals', 'start_lost', 'Start Lost'),
    ))

    #SNVs: functional class comes from the precomputed call in the table
    snvs['Function Type'] = 'Indeterminate'
    for raw_call, label in (('functionally_normal', 'Benign'), ('functionally_abnormal', 'Pathogenic')):
        snvs.loc[snvs['functional_consequence'] == raw_call, 'Function Type'] = label
    snvs = snvs[['target', 'Consequence', 'pos_id', 'snv_score', 'Function Type']]

    #3bp deletions: same renaming with their own rule order
    dels = pd.read_csv(sge_dels, sep = '\t').rename(columns = column_names)
    relabel(dels, (
        ('contains', 'site', 'Canonical Splice'),
        ('contains', 'ing_var', 'Splice Region'),
        ('contains', 'UTR', 'UTR Variant'),
        ('equals', 'stop_gained', 'Stop Gained'),
        ('equals', 'stop_lost', 'Stop Lost'),
        ('equals', 'start_lost', 'Start Lost'),
        ('equals', 'intron_variant', 'Intron'),
        ('equals', 'inframe_indel', 'Inframe Indel'),
    ))

    #Position ID of a deletion is the "start-end" span as a string
    dels['pos_id'] = dels['start'].astype(str) + '-' + dels['end'].astype(str)

    #Deletions: functional class from the score thresholds (Benign is applied last)
    lower, upper = thresholds[0], thresholds[1]
    dels['Function Type'] = 'Indeterminate'
    dels.loc[dels['snv_score'] <= lower, 'Function Type'] = 'Pathogenic'
    dels.loc[dels['snv_score'] >= upper, 'Function Type'] = 'Benign'

    return pd.concat([snvs, dels])

In [None]:
def get_base_changes(df): #Creates pos_id column in format of SGE datafile for ClinVar data
    """Fill 'Base Change' with "<GRCh38Location>:<base>" for every SNV row.

    The base is the complement (via ``get_pair``) of the character that
    follows the last usable '>' in the variant 'Name' (a '>' that has at
    least one character on each side). Rows whose name has no such '>' keep
    whatever 'Base Change' value they already had (None if the column is new).

    Rows are processed by position, so the frame's index labels do not matter,
    and each row only ever receives the value computed from its own name and
    coordinate.

    Parameters
    ----------
    df : DataFrame with 'Name' and 'GRCh38Location' columns.

    Returns
    -------
    DataFrame
        The same object that was passed in, updated in place.
    """
    if 'Base Change' in df.columns:
        base_changes = df['Base Change'].tolist()
    else:
        base_changes = [None] * len(df)

    for row, (name, location) in enumerate(zip(df['Name'], df['GRCh38Location'])):
        if not isinstance(name, str):
            continue #missing name: nothing to parse
        # Last '>' that still has a character before and after it.
        arrow = name.rfind('>', 1, len(name) - 1)
        if arrow == -1:
            continue
        base_changes[row] = str(location) + ":" + get_pair(name[arrow + 1])

    df['Base Change'] = base_changes
    return df

In [None]:
def merge(clin,clin_del, sge, nf_cutoff, func_cutoff):
    """Join the ClinVar records to the SGE scores on their shared position ID.

    ClinVar SNVs and deletions are stacked, then inner-joined to ``sge`` where
    ClinVar 'Base Change' equals SGE 'pos_id'. The result carries a fixed set
    of columns in report order, with 'snv_score' renamed to 'SGE Score'.

    ``nf_cutoff`` and ``func_cutoff`` are accepted but not used here.
    """
    report_columns = [
        'target',
        'Name',
        'Protein change',
        'Base Change',
        'Consequence',
        'snv_score',
        'Germline classification',
        'Function Type',
    ]

    clinvar_all = pd.concat([clin, clin_del])
    matched = clinvar_all.merge(sge, how = 'inner', left_on = 'Base Change', right_on = 'pos_id')

    #df.to_excel('/Users/ivan/Desktop/test_excel_outputs/20250718_BARD1SGE_ClinVar_vars.xlsx', index = None)
    return matched[report_columns].rename(columns = {'snv_score': 'SGE Score'})

In [None]:
def rename_germline(df): #Renames ClinVar germline classifications to fit on plot better
    """Return a copy of ``df`` with short 'Germline classification' labels.

    Pathogenic-type labels become 'P/LP', benign-type labels 'B/LB',
    'Uncertain significance' becomes 'VUS' and conflicting classifications
    become 'Conflicting'. Any other value is left as it is, and the input
    frame is not modified.
    """
    short_labels = {
        'Pathogenic': 'P/LP',
        'Likely pathogenic': 'P/LP',
        'Pathogenic/Likely pathogenic': 'P/LP',
        'Benign': 'B/LB',
        'Likely benign': 'B/LB',
        'Benign/Likely benign': 'B/LB',
        'Uncertain significance': 'VUS',
        'Conflicting classifications of pathogenicity': 'Conflicting',
    }

    relabeled = df.copy() #so the original df isn't overwritten
    relabeled['Germline classification'] = relabeled['Germline classification'].replace(short_labels)
    return relabeled

In [None]:
def histogram(df, nf_cutoff, func_cutoff): #creates histogram and strip plot showing distribution of scores for variants in ClinVar
    """Plot ClinVar-classified variants against their SGE scores.

    Draws and displays (1) a stacked histogram of SGE scores coloured by
    germline classification and (2) a strip plot of SGE score per germline
    classification coloured by consequence, then prints the mean B/LB and
    P/LP scores with a one-sample t-test of the B/LB scores against the
    P/LP mean.

    Parameters
    ----------
    df : DataFrame
        Merged ClinVar/SGE table with at least 'Base Change', 'Consequence',
        'Germline classification' and 'SGE Score'. The caller's frame is not
        modified; all relabelling happens on a copy.
    nf_cutoff : number
        Score at which the red vertical rule is drawn.
    func_cutoff : number
        Score at which the blue vertical rule is drawn.

    Returns
    -------
    (plp_hist, plp_strip) : the two configured Altair charts.
    """
    df = df.copy() #relabel on a private copy so the caller's frame stays untouched

    #Deletion IDs are "start-end" spans, which is what marks them here
    df.loc[df['Base Change'].str.contains('-', na = False), 'Consequence'] = '3bp Deletion'

    #Color palette
    palette = [
    '#006616', # dark green,
    '#81B4C7', # dusty blue
    '#ffcd3a', # yellow
    '#6AA84F', # med green
    '#93C47D', # light green
    '#888888', # med gray
    '#000000', # black
    '#1170AA', # darker blue
    '#CFCFCF', # light gray
    '#FF9A00'   #orange
    ]

    variant_types = [
        'Synonymous',
        'Missense',
        'Stop Gained',
        'Intron',
        'UTR Variant',
        'Stop Lost',
        'Start Lost',
        'Canonical Splice',
        'Splice Region',
        '3bp Deletion'
    ]

    #Lines for functional and non-functional cutoffs
    nf_line = alt.Chart(pd.DataFrame({'x': [nf_cutoff]})).mark_rule(color = 'red').encode(
        x = 'x')

    func_lin = alt.Chart(pd.DataFrame({'x': [func_cutoff]})).mark_rule(color = 'blue').encode(
        x = 'x')

    #maximum number of histogram bins
    bins = 50

    #Keep classified variants only and collapse the combined/conflicting labels
    classifications = ['Benign', 'Benign/Likely benign', 'Likely benign', 'Pathogenic', 'Likely pathogenic', 'Pathogenic/Likely pathogenic', 'Uncertain significance', 'Conflicting classifications of pathogenicity']
    p_df = df.loc[df['Germline classification'].isin(classifications)].copy()
    p_df['Germline classification'] = p_df['Germline classification'].replace({
        'Benign/Likely benign': 'Likely benign',
        'Pathogenic/Likely pathogenic': 'Likely pathogenic',
        'Conflicting classifications of pathogenicity': 'Uncertain significance',
    })

    #p_df.to_excel('/Users/ivan/Desktop/test_excel_outputs/20250912_BARD1SGE_SNVsDels_inClinVar.xlsx', index = False)

    #Creates histogram showing distribution of variants in ClinVar vs. SGE score
    plp_hist = alt.Chart(p_df).mark_bar().encode(
        alt.X('SGE Score', axis = alt.Axis(title = 'SGE Score', labelFontSize = 18, titleFontSize = 20), bin = alt.Bin(maxbins = bins)),
        alt.Y('count()', axis = alt.Axis(title = 'Number of Variants', labelFontSize = 18, titleFontSize = 20)),
        color = alt.Color('Germline classification:N',
                          legend = alt.Legend(titleFontSize = 16,
                                              labelFontSize = 14
                                             ),
                          scale = alt.Scale(
                              domain = ['Benign', 'Likely benign', 'Pathogenic', 'Likely pathogenic', 'Uncertain significance'],
                              range = ['#1D7AAB', '#63A1C4', '#CA7682','#E6B1B8', '#A0A0A0']
                          )
                         )
    ).properties(
        width = 800,
        height = 400,
        title = alt.TitleParams(text = 'BARD1 ClinVar Variants in SGE ' + ' (n= ' + str(len(p_df)) + ')', fontSize = 22)
    )
    plp_hist = plp_hist + nf_line + func_lin
    plp_hist = plp_hist.configure_axis(
        grid = False
    )
    #display() renders inline, the same way the strip plot below is shown
    plp_hist.display()

    #Creates strip plot showing distribution of variants in ClinVar vs. SGE score
    plp_strip = alt.Chart(p_df).mark_tick().encode(
        x = alt.X('SGE Score',
                  title = 'SGE Score',
                  axis = alt.Axis(labelFontSize = 16,
                                  titleFontSize = 20
                                 )
                 ),
        y = alt.Y('Germline classification',
                 sort = ['Benign', 'Likely benign', 'Uncertain significance', 'Likely pathogenic', 'Pathogenic'],
                 axis = alt.Axis(
                     labelFontSize = 16,
                     titleFontSize = 20,
                     labelLimit = 1000
                 )
                 ),
        color = alt.Color('Consequence',
                         scale = alt.Scale(
                             range = palette,
                             domain = variant_types
                         ),
                          legend = alt.Legend(titleFontSize = 16,
                                              labelFontSize = 14
                                             )
                         ),
        tooltip = ['Base Change']
    ).properties(
        width = 800,
        height = 400,
        title = alt.TitleParams(text = 'BARD1 ClinVar Variants in SGE ' + ' (n= ' + str(len(p_df)) + ')', fontSize = 22)
    )

    #Builds summary df with variant counts for each germline classification category.
    #Column names are set explicitly because the default names that
    #value_counts().reset_index() produces differ between pandas versions.
    summary_df = (
        p_df['Germline classification']
        .value_counts()
        .rename_axis('Germline classification')
        .reset_index(name = 'count')
    )
    summary_df['count'] = summary_df['count'].astype(str)
    summary_df['text'] = '(n = ' + summary_df['count'] + ')'

    #Builds y-axis labels with variant counts
    counts = alt.Chart(summary_df).mark_text(
        align='right',
        dx=-10,  # Slight offset to the left of the y-axis
        dy=20,  # Offset below the category label
        fontSize=16,
        color='black'
    ).encode(
        y=alt.Y('Germline classification:N', sort = ['Benign', 'Likely benign', 'Uncertain significance', 'Likely pathogenic', 'Pathogenic']),
        text='text',
        x=alt.value(0)  # Position at the left edge
    )

    #Builds final strip plot
    plp_strip = (plp_strip + counts + nf_line + func_lin).configure_axis(
        grid = False
    ).configure_view(
        stroke = None
    )

    plp_strip.display()

    #Score comparison between the benign-type and pathogenic-type groups.
    #Labels must match the collapsed classification names exactly ('Likely benign').
    blb_df = p_df.loc[p_df['Germline classification'].isin(['Benign', 'Likely benign'])]
    plp_df = p_df.loc[p_df['Germline classification'].isin(['Pathogenic', 'Likely pathogenic'])]
    blb_scores = blb_df['SGE Score']
    plp_scores = plp_df['SGE Score']

    plp_mean = plp_scores.mean()
    blb_mean = blb_scores.mean()
    # NOTE(review): this treats the P/LP mean as a fixed constant; a two-sample
    # test (e.g. stats.ttest_ind with equal_var=False) may be what is wanted - confirm.
    t_statistic, p_value = stats.ttest_1samp(a = blb_scores, popmean = plp_mean)

    print("Mean BLB Score: ", blb_mean, '\n',
          'Mean PLP Score: ', plp_mean, '\n',
          'One Sample T-Test P-value: ', p_value)

    #extracts VUS and conflicting data from ClinVar (only used by the optional export below)
    vus = ['Uncertain significance', 'Conflicting classifications of pathogenicity']
    v_df = df.loc[df['Germline classification'].isin(vus)].copy()
    v_df.loc[v_df['Germline classification'] == 'Conflicting classifications of pathogenicity', 'Germline classification'] = 'Conflicting'
    #v_df.to_excel('output.xlsx')

    return plp_hist, plp_strip

In [None]:
def pr_df(df): #makes dataframe needed for ROC
    """Build the per-variant table used for the ROC / concordance analysis.

    Only variants with a benign-type or pathogenic-type ClinVar classification
    are kept. Three numeric columns are added:

    * 'Germline Num' - 1.0 for B/LB-type classifications, 0.0 for P/LP-type.
    * 'SGE Num'      - 0.0 when 'Function Type' is 'Pathogenic', 1.0 when it is
                       'Benign' or the in-between class, NaN for anything else.
    * 'target'       - 1.0 when the two numbers agree on that row, else 0.0
                       (a NaN 'SGE Num' never agrees).

    Parameters
    ----------
    df : DataFrame with 'Base Change', 'Germline classification' and
        'Function Type' columns. It is not modified.

    Returns
    -------
    DataFrame indexed 0..n-1 with columns 'Base Change',
    'Germline classification', 'Function Type', 'Germline Num', 'SGE Num',
    'target'.
    """
    benign_labels = ['Benign', 'Likely benign', 'Benign/Likely benign']
    pathogenic_labels = ['Pathogenic', 'Likely pathogenic', 'Pathogenic/Likely pathogenic']

    #filter out non P/LP and B/LB variants and renumber the rows
    nv_df = df[['Base Change', 'Germline classification', 'Function Type']]
    keep = nv_df['Germline classification'].isin(benign_labels + pathogenic_labels)
    nv_df = nv_df.loc[keep].reset_index(drop = True).copy()

    #ClinVar side: every remaining row is either benign-type (1) or pathogenic-type (0)
    nv_df['Germline Num'] = nv_df['Germline classification'].isin(benign_labels).astype(float)

    #SGE side. The in-between class is written 'Indeterminate' by prep_sge in this
    #notebook; the older spelling 'Intermediate' is still accepted.
    # NOTE(review): grouping the in-between class with benign (1) follows the
    # original condition - confirm this is the intended analysis choice.
    sge_codes = {
        'Pathogenic': 0.0,
        'Benign': 1.0,
        'Indeterminate': 1.0,
        'Intermediate': 1.0,
    }
    nv_df['SGE Num'] = nv_df['Function Type'].map(sge_codes).astype(float)

    #determines if ClinVar and SGE agree (1 - yes, 0 - no), row by row
    nv_df['target'] = (nv_df['Germline Num'] == nv_df['SGE Num']).astype(float)

    return nv_df
    

In [None]:
def roc_qc(df):
    """Print how many rows of ``df`` are concordant and how many are not.

    A row counts as concordant when its 'target' value equals 1; every other
    row (including missing values) counts as discordant. Rows are counted
    regardless of the frame's index labels.
    """
    total = len(df)
    concor = int((df['target'] == 1).sum())
    print('Concordant: ' + str(concor))
    print('Discordant: ' + str(total - concor))

In [None]:
def concor_stats(df):
    """Print concordance counts overall and per ClinVar class.

    A row is concordant when its 'target' value equals 1. Rows are grouped by
    'Germline classification' into P/LP-type and B/LB-type; rows with any
    other classification are ignored rather than stalling the count, and the
    frame's index labels are irrelevant.
    """
    labels = df['Germline classification']
    is_plp = labels.isin(['Pathogenic', 'Likely pathogenic', 'Pathogenic/Likely pathogenic'])
    is_blb = labels.isin(['Benign', 'Likely benign', 'Benign/Likely benign'])
    is_concordant = df['target'] == 1

    plp_total = int(is_plp.sum())
    plp_concor = int((is_plp & is_concordant).sum())
    blb_total = int(is_blb.sum())
    blb_concor = int((is_blb & is_concordant).sum())

    total = plp_total + blb_total
    total_concor = plp_concor + blb_concor

    print(str(total_concor),' of ', str(total), ' variants concordant')
    print(str(plp_concor), ' of ', str(plp_total), ' P/LP variants concordant')
    print(str(blb_concor), ' of ', str(blb_total), ' B/LB variants concordant')

In [None]:
def generate_roc_curve(df: pd.DataFrame,
                      prediction_column: str = 'SGE Num',
                      true_label_column: str = 'Germline Num',
                      model_name: str = 'SGE Classifier') -> Tuple[pd.DataFrame, float, Dict]:
    """
    Generate ROC curve data comparing SGE predictions to Germline truth.

    Rows with a missing prediction or a missing true label are dropped first.
    Label 1 is the positive class for every metric reported here. With the
    encoding produced by pr_df in this notebook, label 1 is the benign-type
    class, so "Sensitivity" refers to benign-type variants.

    Parameters:
    -----------
    df : DataFrame
        Input dataframe with classifications
    prediction_column : str
        Column name containing predictions (default: 'SGE Num')
    true_label_column : str
        Column name containing true labels (default: 'Germline Num')
    model_name : str
        Name for this model/classification

    Returns:
    --------
    Tuple of (ROC DataFrame, AUC value, metrics dictionary)

    Raises:
    -------
    ValueError
        If no row has both a prediction and a true label.
    """
    # Remove any rows with NaN values
    has_both = df[prediction_column].notna() & df[true_label_column].notna()
    clean_df = df[has_both].copy()
    if clean_df.empty:
        raise ValueError(
            f"No rows with both '{prediction_column}' and '{true_label_column}' present"
        )

    y_true = clean_df[true_label_column].values
    y_pred = clean_df[prediction_column].values

    # Calculate ROC curve
    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    # Calculate additional metrics. Passing labels fixes the matrix at 2x2 even
    # when only one class occurs, so the four-way unpacking cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    metrics = {
        'AUC': roc_auc,
        'Sensitivity (TPR)': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'Specificity (TNR)': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'PPV (Precision)': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'NPV': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'Accuracy': (tp + tn) / (tp + tn + fp + fn),
        'True Positives': tp,
        'True Negatives': tn,
        'False Positives': fp,
        'False Negatives': fn,
        'Total Samples': len(y_true)
    }

    # Create DataFrame with ROC data (one row per threshold returned by roc_curve)
    roc_df = pd.DataFrame({
        'FPR': fpr,
        'TPR': tpr,
        'Threshold': thresholds,
        'Model': model_name,
        'AUC': roc_auc
    })

    return roc_df, roc_auc, metrics


In [None]:
def plot_roc_altair(df: pd.DataFrame,
                    prediction_column: str = 'SGE Num',
                    true_label_column: str = 'Germline Num',
                    model_name: str = 'SGE Classifier',
                    width: int = 500,
                    height: int = 500,
                    title: Optional[str] = None,
                    include_diagonal: bool = True,
                    show_metrics: bool = True,
                    show_optimal: bool = False) -> alt.Chart:
    """
    Create an Altair ROC curve visualization comparing SGE to Germline.

    The chart layers the ROC line, invisible hover points for tooltips, an
    optional chance diagonal, an optional marker at the best operating point,
    and an AUC annotation.

    Parameters:
    -----------
    df : DataFrame
        Input dataframe
    prediction_column : str
        Column with SGE predictions (0 or 1)
    true_label_column : str
        Column with Germline truth labels (0 or 1)
    model_name : str
        Name for the model in the visualization
    width : int
        Chart width in pixels
    height : int
        Chart height in pixels
    title : str, optional
        Chart title (auto-generated if None)
    include_diagonal : bool
        Whether to include diagonal reference line
    show_metrics : bool
        Whether to print performance metrics
    show_optimal : bool
        Whether to mark the ROC point that maximises TPR - FPR (Youden's J).
        Off by default, which leaves the chart exactly as it was before this
        option existed.

    Returns:
    --------
    Altair Chart object
    """
    # Generate ROC data and metrics
    roc_df, auc_value, metrics = generate_roc_curve(
        df, prediction_column, true_label_column, model_name
    )

    # Print metrics if requested
    if show_metrics:
        print(f"\n{'='*50}")
        print(f"Performance Metrics: {model_name}")
        print(f"{'='*50}")
        print(f"AUC: {metrics['AUC']:.4f}")
        print(f"Accuracy: {metrics['Accuracy']:.4f}")
        print(f"Sensitivity (TPR): {metrics['Sensitivity (TPR)']:.4f}")
        print(f"Specificity (TNR): {metrics['Specificity (TNR)']:.4f}")
        print(f"PPV (Precision): {metrics['PPV (Precision)']:.4f}")
        print(f"NPV: {metrics['NPV']:.4f}")
        print(f"\nConfusion Matrix:")
        print(f"  True Positives: {metrics['True Positives']}")
        print(f"  True Negatives: {metrics['True Negatives']}")
        print(f"  False Positives: {metrics['False Positives']}")
        print(f"  False Negatives: {metrics['False Negatives']}")
        print(f"  Total Samples: {metrics['Total Samples']}")
        print(f"{'='*50}\n")

    # Add AUC to model name for legend
    roc_df['Model_Label'] = f"{model_name} (AUC = {auc_value:.3f})"

    # Create the ROC curve
    roc_line = alt.Chart(roc_df).mark_line(
        strokeWidth=3,
        color='darkorange'
    ).encode(
        x=alt.X('FPR:Q',
                title='False Positive Rate (1 - Specificity)',
                scale=alt.Scale(domain=[0, 1]),
                axis=alt.Axis(format='.1f', grid=True, tickCount=6, labelFontSize = 18, titleFontSize = 20, titleFontWeight = 'bold')),
        y=alt.Y('TPR:Q',
                title='True Positive Rate (Sensitivity)',
                scale=alt.Scale(domain=[0, 1]),
                axis=alt.Axis(format='.1f', grid=True, tickCount=6, labelFontSize = 18, titleFontSize = 20, titleFontWeight = 'bold')),
        tooltip=[
            alt.Tooltip('FPR:Q', format='.3f', title='False Positive Rate'),
            alt.Tooltip('TPR:Q', format='.3f', title='True Positive Rate'),
            alt.Tooltip('Threshold:Q', format='.3f', title='Threshold'),
            alt.Tooltip('AUC:Q', format='.4f', title='AUC')
        ]
    )

    # Add interactive points for detailed hover (fully transparent; tooltip only)
    points = alt.Chart(roc_df).mark_circle(
        size=30,
        color='steelblue',
        opacity=0
    ).encode(
        x='FPR:Q',
        y='TPR:Q',
        tooltip=[
            alt.Tooltip('FPR:Q', format='.3f', title='FPR'),
            alt.Tooltip('TPR:Q', format='.3f', title='TPR (Sensitivity)'),
            alt.Tooltip('Threshold:Q', format='.3f', title='Threshold')
        ]
    ).add_params(
        alt.selection_point(on='mouseover', nearest=True, empty=False)
    )

    # Combine line and hover points
    roc_chart = roc_line + points

    # Optionally mark the operating point with the largest TPR - FPR (Youden's J)
    if show_optimal:
        optimal_idx = int(np.argmax((roc_df['TPR'] - roc_df['FPR']).to_numpy()))
        optimal_row = roc_df.iloc[optimal_idx]
        optimal_point = alt.Chart(pd.DataFrame({
            'FPR': [optimal_row['FPR']],
            'TPR': [optimal_row['TPR']],
            'Label': [f'Optimal (FPR={optimal_row["FPR"]:.3f}, TPR={optimal_row["TPR"]:.3f})']
        })).mark_point(
            size=150,
            color='red',
            filled=True
        ).encode(
            x='FPR:Q',
            y='TPR:Q',
            tooltip=['Label:N']
        )
        roc_chart = roc_chart + optimal_point

    # Add diagonal reference line if requested
    if include_diagonal:
        diagonal_df = pd.DataFrame({'x': [0, 1], 'y': [0, 1]})

        diagonal = alt.Chart(diagonal_df).mark_line(
            strokeDash=[5, 5],
            color='gray',
            opacity=0.5,
            strokeWidth=1.5
        ).encode(
            x=alt.X('x:Q', scale=alt.Scale(domain=[0, 1])),
            y=alt.Y('y:Q', scale=alt.Scale(domain=[0, 1]))
        )

        # Annotation layer for the random-classifier line (its text is currently empty)
        random_text = alt.Chart(pd.DataFrame({
            'x': [0.5],
            'y': [0.48],
            'text': ['']
        })).mark_text(
            angle=315,
            fontSize=11,
            color='gray',
            opacity=0.7
        ).encode(
            x='x:Q',
            y='y:Q',
            text='text:N'
        )

        final_chart = diagonal + random_text + roc_chart
    else:
        final_chart = roc_chart

    # Add AUC text annotation (bottom-right of the plot area)
    auc_text = alt.Chart(pd.DataFrame({
        'x': [0.95],
        'y': [0.05],
        'text': [f'AUC = {auc_value:.4f}']
    })).mark_text(
        align='right',
        baseline='top',
        fontSize=24,
        fontWeight='bold',
        color='black'
    ).encode(
        x=alt.X('x:Q', scale=alt.Scale(domain=[0, 1])),
        y=alt.Y('y:Q', scale=alt.Scale(domain=[0, 1])),
        text='text:N'
    )

    final_chart = final_chart + auc_text

    # Set title
    if title is None:
        title = f'ROC Curve: {prediction_column} vs {true_label_column} (Ground Truth)'

    # Configure and finalize the chart
    final_chart = final_chart.properties(
        width=width,
        height=height,
        title={
            'text': title,
            'fontSize': 16,
            'fontWeight': 'bold'
        }
    ).configure_axis(
        gridOpacity=0.25,
        labelFontSize=12,
        titleFontSize=13,
        titleFontWeight='normal'
    ).configure_view(
        strokeWidth=1,
        stroke='lightgray'
    )

    return final_chart

In [None]:
def create_roc_chart(df: pd.DataFrame,
                    show_metrics: bool = True) -> alt.Chart:
    """
    Convenience wrapper: ROC chart of 'SGE Num' against 'Germline Num'.

    Parameters:
    -----------
    df : DataFrame
        Table holding the 'SGE Num' and 'Germline Num' columns
    show_metrics : bool
        Whether the performance metrics are printed

    Returns:
    --------
    Altair chart produced by plot_roc_altair
    """
    # Fixed settings for the SGE-vs-ClinVar comparison; everything else is
    # left at plot_roc_altair's defaults.
    fixed_settings = {
        'prediction_column': 'SGE Num',
        'true_label_column': 'Germline Num',
        'model_name': 'SGE Classifier',
    }
    return plot_roc_altair(df=df, show_metrics=show_metrics, **fixed_settings)

In [None]:
def main():
    """Run the whole ClinVar-vs-SGE comparison.

    Reads the input paths and thresholds from the notebook-level names
    (file, clinvar_dels, sge, sge_dels, thresholds, path_max, benign_min),
    then loads, merges, prints the concordance counts, draws the histogram
    and strip plot, and finally displays the ROC chart.
    """
    alt.data_transformers.disable_max_rows()

    # Load ClinVar and give every SNV an SGE-style position ID
    snv_records, deletion_records = read_data(file, clinvar_dels)
    print('ClinVar Variants: ', len(snv_records))
    snv_with_ids = get_base_changes(snv_records)

    # Load SGE scores and join them to ClinVar
    sge_scores = prep_sge(sge, sge_dels, thresholds)
    matched = merge(snv_with_ids, deletion_records, sge_scores, path_max, benign_min)

    # Short-label copy; not consumed by anything further down in this function
    germ_relabeled = rename_germline(matched)

    # Concordance table and summary
    roc_input = pr_df(matched)
    #df_pr.to_excel('/Users/ivan/Desktop/test_excel_outputs/20250912_BARD1vClinVar.xlsx', index = False)
    concor_stats(roc_input)

    # Distribution plots (displayed inside histogram)
    final_histogram, final_stripplot = histogram(matched, path_max, benign_min)

    # ROC chart
    roc = create_roc_chart(roc_input, show_metrics = True)
    roc.display()
    #roc.save('/Users/ivan/Desktop/BARD1_draft_figs/fig_4b_ClinVar_ROC.png', ppi = 500)
    #final_stripplot.save('/Users/ivan/Desktop/BARD1_draft_figs/fig_4b_clinvar_stripplot_.png', ppi = 500)

In [None]:
# Run the full analysis; all function-definition cells above must have been executed first.
main()