In [1]:
import warnings
import pandas as pd
import xgboost as xgb
import sys
sys.path.append('../codes')
import utils as utils

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
# Fully trained OncoNPC booster, loaded once and shared through the module namespace.
xgb_onconpc = xgb.Booster()
xgb_onconpc.load_model('../model/xgboost_v1.7.6_OncoNPC_full.json')

# Cancer types passed to the prediction/explanation helpers.
# The ORDER matters: update_image() uses a label's position in this list as the
# index into the per-class SHAP arrays, so do not sort or reorder it.
cancer_types_to_consider = [
    'Acute Myeloid Leukemia',
    'Bladder Urothelial Carcinoma',
    'Cholangiocarcinoma',
    'Colorectal Adenocarcinoma',
    'Diffuse Glioma',
    'Endometrial Carcinoma',
    'Esophagogastric Adenocarcinoma',
    'Gastrointestinal Neuroendocrine Tumors',
    'Gastrointestinal Stromal Tumor',
    'Head and Neck Squamous Cell Carcinoma',
    'Invasive Breast Carcinoma',
    'Melanoma',
    'Meningothelial Tumor',
    'Non-Hodgkin Lymphoma',
    'Non-Small Cell Lung Cancer',
    'Ovarian Epithelial Tumor',
    'Pancreatic Adenocarcinoma',
    'Pancreatic Neuroendocrine Tumor',
    'Pleural Mesothelioma',
    'Prostate Adenocarcinoma',
    'Renal Cell Carcinoma',
    'Well-Differentiated Thyroid Cancer',
]

In [3]:
def _read_uploaded_tsv(upload, label):
    """Read one uploaded tab-separated file into a DataFrame.

    Args:
        upload: Either an object exposing the file path as ``.name`` or a plain
            path string.
        label: Human-readable name of the input, used in the error message.

    Raises:
        ValueError: If no file was provided.
    """
    if upload is None:
        raise ValueError(f'No {label} file was provided.')
    return pd.read_csv(getattr(upload, 'name', upload), sep='\t')


def get_preds(patients_file, samples_file, mutations_file, cna_file, tumor_id):
    """
    Generates predictions and explanations for given tumor samples using OncoNPC model.

    This function processes patient, sample, mutation, and CNA data to predict primary sites of
    Cancer of Unknown Primary (CUP) tumors. It also provides a bar chart of SHAP values to explain the predictions.

    Side effects: publishes the module-level names ``sample_id``, ``features`` and
    ``preds_df`` so that ``update_image`` can build further plots for the same sample.
    Relies on the module-level ``xgb_onconpc`` model and ``cancer_types_to_consider``
    list (the same objects ``update_image`` uses) instead of reloading them per call.

    Args:
        patients_file: Uploaded tab-separated file (object with ``.name``, or a path) with patient data.
        samples_file: Uploaded tab-separated file with sample data; must contain a ``SAMPLE_ID`` column.
        mutations_file: Uploaded tab-separated file with mutation data.
        cna_file: Uploaded tab-separated file with CNA (Copy Number Alterations) data.
        tumor_id: The ID of the tumor. Surrounding whitespace is ignored.

    Returns:
        A tuple containing:
            A string containing the top 3 most probable cancers along with their predicted probabilities.
            The filepath to the SHAP value bar chart explaining the prediction for the given tumor ID.

    Raises:
        ValueError: If one of the input files is missing.
        KeyError: If ``tumor_id`` is not listed in the samples file.
    """
    # declared as global variables to generate plots in update_image function
    global sample_id
    global features
    global preds_df

    # convert files to data frames
    patients_df = _read_uploaded_tsv(patients_file, 'clinical patients')
    samples_df = _read_uploaded_tsv(samples_file, 'clinical samples')
    mutations_df = _read_uploaded_tsv(mutations_file, 'mutations')
    cna_df = _read_uploaded_tsv(cna_file, 'CNA')

    # IDs typed into a text box often carry stray whitespace
    if isinstance(tumor_id, str):
        tumor_id = tumor_id.strip()

    # Fail fast, before the expensive feature/SHAP computation: the explanation
    # results below are keyed by these IDs, so an unknown ID can never succeed.
    query_ids = list(samples_df.SAMPLE_ID.values)
    if tumor_id not in query_ids:
        raise KeyError(
            f'Tumor sample ID {tumor_id!r} was not found in the SAMPLE_ID column of the samples file.'
        )

    # get features for OncoNPC predictive inference (the returned labels are not needed here)
    feature_df, _ = utils.get_onconpc_features_from_raw_data(
        patients_df,
        samples_df,
        mutations_df,
        cna_df,
        features_onconpc_path='../data/features_onconpc.pkl',
        combined_cohort_age_stats_path='../data/combined_cohort_age_stats.pkl',
        mut_sig_weights_filepath='../data/mutation_signatures/sigProfiler*.csv'
    )

    # predict primary sites of CUP tumors
    predictions = utils.get_xgboost_latest_cancer_type_preds(xgb_onconpc,
                                                             feature_df,
                                                             cancer_types_to_consider)

    # publish a consistent snapshot for update_image
    sample_id = tumor_id
    features = feature_df
    preds_df = predictions

    # Silence warnings only while computing SHAP values and explanations,
    # rather than disabling them for the rest of the kernel session.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        # get SHAP values for CUP tumors
        shaps = utils.obtain_shap_values_with_latest_xgboost(xgb_onconpc, feature_df)

        # results is structured such that:
        # results_dict[query_id] = {'pred_prob': pred_prob,'pred_cancer': pred_cancer,'explanation_plot': full_filename}
        results = utils.get_onconpc_prediction_explanations(query_ids, predictions, shaps,
                                                            feature_df,
                                                            cancer_types_to_consider,
                                                            save_plot=True)

    return get_top3(predictions, tumor_id), results[tumor_id]['explanation_plot'] + '.png'

In [4]:
def get_top3(predictions, tumor_sample_id):
    """
    Build a short text summary of the three most probable cancer types for one sample.

    Args:
        predictions: A DataFrame indexed by sample ID whose columns are the per-cancer-type
            probabilities plus the summary columns 'cancer_type' and 'max_posterior'.
        tumor_sample_id: The index label of the sample to summarise.

    Returns:
        A string with one "<cancer type>: <probability>" line per prediction (probability
        printed with two decimals), highest probability first, without a trailing newline.
    """
    # pull the sample's row and discard the two summary entries, keeping only per-class scores
    class_scores = predictions.loc[tumor_sample_id].drop(['cancer_type', 'max_posterior'])

    # the row has mixed dtypes, so force the scores to numbers (non-numeric values become NaN)
    class_scores = pd.to_numeric(class_scores, errors='coerce')

    # the index holds the cancer type names, the values the probabilities
    leaders = class_scores.nlargest(3)

    return '\n'.join(f'{cancer}: {prob:.2f}' for cancer, prob in leaders.items())

In [5]:
import gradio as gr

# NOTE: a `global` statement at module level has no effect in Python; these names are
# already module-level. The two lines only act as markers for state shared between cells.
global image  # bound to the gr.Image output component inside the gr.Blocks layout below
global features  # feature table published by get_preds and read by update_image


def update_image(target):
    """
    Build the SHAP explanation plot of the current tumor sample for a chosen cancer type.

    Reads the module-level state published by ``get_preds`` (``features``, ``sample_id``,
    ``preds_df``) together with ``xgb_onconpc`` and ``cancer_types_to_consider``.

    Args:
        target: Cancer type selected in the dropdown; must be one of
            ``cancer_types_to_consider``. ``None`` or an empty string (a cleared
            dropdown) is accepted and produces no plot.

    Returns:
        Whatever ``utils.get_individual_pred_interpretation`` returns for the plot, or
        ``None`` when no cancer type is selected.
        NOTE(review): this value is sent to a ``gr.Image`` output - confirm the helper
        returns something that component accepts (file path / PIL image / array).

    Raises:
        gr.Error: If no prediction has been run yet, so there is no sample to explain.
        ValueError: If ``target`` is not in ``cancer_types_to_consider``.
    """
    # a cleared dropdown fires the change event with an empty value; nothing to draw
    if not target:
        return None

    # get_preds must have run at least once to publish the shared state
    try:
        feature_table, current_sample_id, predictions = features, sample_id, preds_df
    except NameError:
        raise gr.Error('Submit the input files and a tumor sample ID before choosing a cancer type.') from None

    shaps_cup = utils.obtain_shap_values_with_latest_xgboost(xgb_onconpc, feature_table)  # get shap values

    # position of the cancer type in the label list doubles as the class index into the SHAP output
    target_idx = cancer_types_to_consider.index(target)

    # Get SHAP-based explanation for the prediction
    feature_sample_df = feature_table.loc[current_sample_id]  # find the exact tumor sample we're predicting for
    shap_pred_cancer_df = pd.DataFrame(shaps_cup[target_idx],
                                       index=feature_table.index,
                                       columns=feature_table.columns)
    shap_pred_sample_df = shap_pred_cancer_df.loc[current_sample_id]
    probability = predictions.loc[current_sample_id, target]

    # Generate explanation plot
    sample_info = f'Prediction: {target}\nPrediction probability: {probability:.3f}'
    feature_group_to_features_dict, feature_to_feature_group_dict = utils.partition_feature_names_by_group(
        feature_table.columns
    )
    fig = utils.get_individual_pred_interpretation(
        shap_pred_sample_df,
        feature_sample_df,
        feature_group_to_features_dict,
        feature_to_feature_group_dict,
        sample_info=sample_info,
        filename=f'{target}_plot.png',
        filepath='../others_prediction_explanation',
        save_plot=True,
    )
    return fig


# UI layout: inputs on the left, prediction summary / explanation plot on the right.
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column():
            patients_file = gr.File(label="Upload clinical patients data")
            samples_file = gr.File(label="Upload clinical samples data")
            mutations_file = gr.File(label="Upload mutations data")
            cna_file = gr.File(label="Upload CNA data")
            tumor_sample_id = gr.Textbox(label="Tumor Sample ID")
            submit_button = gr.Button("Submit")

        with gr.Column():
            predictions_output = gr.Textbox(label="Top 3 Predicted Cancer Types")
            image = gr.Image(label="Image Display")
            # Offer every cancer type the model can predict. With an empty choice list the
            # selector could never fire update_image. value=None keeps it unselected at start.
            output_selector = gr.Dropdown(choices=cancer_types_to_consider, value=None,
                                          label="Output Options", filterable=True)

    # Submit runs the full pipeline and shows the top-3 summary plus the explanation plot
    submit_button.click(get_preds, inputs=[patients_file, samples_file, mutations_file, cna_file, tumor_sample_id], outputs=[predictions_output, image])
    # Picking a cancer type redraws the explanation plot for that type
    output_selector.change(update_image, inputs=output_selector, outputs=image)

def launch_app(share=True, debug=True):
    """
    Start the Gradio demo server.

    Args:
        share: Forwarded to ``demo.launch``. When True a public share link is requested
            in addition to the local URL. The app accepts uploaded patient data, so pass
            ``share=False`` to keep it on the local machine.
        debug: Forwarded to ``demo.launch``.

    The defaults reproduce the previous hard-coded behaviour (``debug=True, share=True``).
    """
    demo.launch(debug=debug, share=share)

In [6]:
# Start the demo UI; by default launch_app() calls demo.launch(debug=True, share=True).
launch_app()

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://04aad362cf29ad4461.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://04aad362cf29ad4461.gradio.live



R[write to console]: In mut.to.sigs.input(mut.ref = mutationData, sample.id = sample_id,  :
R[write to console]: 
 
R[write to console]:  Some samples have fewer than 50 mutations:
  GENIE-S-001, GENIE-S-002, GENIE-S-003

