In [22]:
import os
import numpy as np
import pandas as pd
import xgboost as xgb
import joblib
import sys
sys.path.append('../codes')
import utils
import importlib
import process_features
importlib.reload(utils)
import matplotlib

import shap
import gradio as gr
import io
from collections import defaultdict

In [23]:
# Load the trained OncoNPC XGBoost classifier from its JSON export.
xgb_onconpc = xgb.Booster()
xgb_onconpc.load_model('../model/xgboost_model.json')

# Cancer types handled by the classifier. The position of each entry is used as
# the class index when selecting SHAP values in get_plot.
cancer_types_to_consider = [
    'Acute Myeloid Leukemia',
    'Bladder Urothelial Carcinoma',
    'Cholangiocarcinoma',
    'Colorectal Adenocarcinoma',
    'Diffuse Glioma',
    'Endometrial Carcinoma',
    'Esophagogastric Adenocarcinoma',
    'Gastrointestinal Neuroendocrine Tumors',
    'Gastrointestinal Stromal Tumor',
    'Head and Neck Squamous Cell Carcinoma',
    'Invasive Breast Carcinoma',
    'Melanoma',
    'Meningothelial Tumor',
    'Non-Hodgkin Lymphoma',
    'Non-Small Cell Lung Cancer',
    'Ovarian Epithelial Tumor',
    'Pancreatic Adenocarcinoma',
    'Pancreatic Neuroendocrine Tumor',
    'Pleural Mesothelioma',
    'Prostate Adenocarcinoma',
    'Renal Cell Carcinoma',
    'Well-Differentiated Thyroid Cancer',
]

# The processed CUP dataset's column names define the full feature list that
# user input is padded out to (construct_input_data fills missing ones with 0).
onconpc_processed_cups_df = pd.read_csv('../data/onconpc_processed_cups_data.csv')
all_features = list(onconpc_processed_cups_df.columns)

In [24]:
def parse_inputs(age, gender, CNA_events, mutations):
    """Parse raw Gradio form values into the structures the model pipeline uses.

    Args:
        age: Patient age as entered in the form; passed through unchanged.
            TODO: check how age is normalized for the model.
        gender: Gender string from the form, matched case-insensitively against
            "male" (so the UI's "Male" maps to 1). Anything else, including
            None, maps to -1.
        CNA_events: '|'-separated copy-number events, each "GENE VALUE"
            (e.g. "KCNQ1 2 | BRAF -1"). VALUE is cast to int; expected values
            are -2, -1, 0, 1, 2. Blank segments are ignored; empty/None -> [].
        mutations: '|'-separated mutation records, each a comma-separated
            "SAMPLE_ID, CHROMOSOME, POSITION, REF_ALLELE, ALT_ALLELE".
            Blank segments are ignored; empty/None -> [].

    Returns:
        Tuple ``(age, gender, CNA_events, mutations)``:
        - gender is 1 or -1;
        - CNA_events is a list of ``[gene, int_value]`` pairs;
        - mutations is a list of lists of whitespace-stripped string fields.

    Raises:
        ValueError: if a CNA event is not exactly two tokens, or its value is
            not an integer.
    """
    # Case-insensitive so the UI's capitalised "Male" is recognised.
    gender = 1 if str(gender).strip().lower() == 'male' else -1

    cna_list = []
    for segment in (CNA_events or '').split('|'):
        tokens = segment.split()
        if not tokens:
            continue  # tolerate blank segments, e.g. a trailing '|'
        if len(tokens) != 2:
            raise ValueError(f"Malformed CNA event {segment.strip()!r}; expected 'GENE VALUE'")
        gene, value = tokens
        cna_list.append([gene, int(value)])

    mutation_list = []
    for segment in (mutations or '').split('|'):
        if not segment.strip():
            continue
        # Strip every field: the '|' separators are usually surrounded by
        # spaces, which would otherwise leak into the sample ID / ALT allele.
        mutation_list.append([field.strip() for field in segment.split(',')])

    return age, gender, cna_list, mutation_list

In [25]:
import deconstruct_sigs_from_user_input as deconstructSigs

def get_mut_signatures(mutations):
    """Compute mutational-signature features for the user's mutations.

    The mutations are written to ./mutation_input.csv. Then
    deconstructSigs.get_base_substitutions() is called (with no arguments), the
    file path it returns is read as a CSV, and the result is passed to
    process_features.obtain_mutation_signatures.

    NOTE(review): get_base_substitutions() presumably reads ./mutation_input.csv
    implicitly; confirm. The two are coupled only through this hard-coded
    relative path, so overlapping calls would share the same file.

    Args:
        mutations: list of 5-field records
            (UNIQUE_SAMPLE_ID, CHROMOSOME, POSITION, REF_ALLELE, ALT_ALLELE),
            e.g. as produced by parse_inputs; may be empty.

    Returns:
        Whatever process_features.obtain_mutation_signatures returns. Callers
        in this notebook use it as a DataFrame and read its first row.
    """
    mutation_columns = ["UNIQUE_SAMPLE_ID", "CHROMOSOME", "POSITION", "REF_ALLELE", "ALT_ALLELE"]
    # An explicit column list keeps the CSV header even when there are no mutations.
    mutation_df = pd.DataFrame(mutations or [], columns=mutation_columns)
    mutation_df.to_csv('./mutation_input.csv', index=False)

    base_sub_file = deconstructSigs.get_base_substitutions()
    df_trinuc_feats = pd.read_csv(base_sub_file)
    return process_features.obtain_mutation_signatures(df_trinuc_feats)

In [26]:
def get_top3_min_info(predictions, n=3):
    """Return the ``n`` most probable cancer types from the first prediction row.

    Args:
        predictions: DataFrame whose first row holds one probability column per
            cancer type, plus the 'cancer_type' and 'max_posterior' columns
            (these two are excluded from the ranking).
        n: number of top entries to return (default 3).

    Returns:
        List of ``[cancer_type, probability]`` pairs sorted by descending
        probability. Probabilities are Python floats. Values that cannot be
        parsed as numbers are coerced to NaN and never ranked.
    """
    first_row = predictions.iloc[0].drop(['cancer_type', 'max_posterior'])
    # The row mixes strings and floats (object dtype), so coerce before ranking.
    probabilities = pd.to_numeric(first_row, errors='coerce')
    top = probabilities.nlargest(n)
    return [[cancer, prob] for cancer, prob in zip(top.index, top.tolist())]

In [27]:
def top3_str(top3_list):
    """Render ``[cancer, prob]`` pairs as text, one ``"cancer: prob"`` line each.

    Every line, including the last, ends with a newline; an empty list gives ''.
    """
    return ''.join(cancer + ': ' + str(prob) + '\n' for cancer, prob in top3_list)

In [28]:
def get_top3_plots(data, top3_list):
    """Build one explanation figure (via get_plot) per ``(cancer, probability)`` pair.

    The figures are returned in the same order as ``top3_list``.
    """
    return [get_plot(data, cancer, prob) for cancer, prob in top3_list]

In [29]:
def construct_input_data(age, gender, CNA_events, mutation_signatures, feature_names=None):
    """Assemble a single model-input row as an ordered ``{feature: value}`` dict.

    Args:
        age: value stored under 'AGE' (TODO: check how age is normalized).
        gender: currently NOT written to any feature. NOTE(review): confirm
            whether the model has a gender/sex feature that should be set here.
        CNA_events: iterable of ``[name, value]`` pairs; each sets ``data[name] = value``.
        mutation_signatures: DataFrame whose first row supplies one value per
            column (only read when it has columns).
        feature_names: full feature list, in model order. Defaults to the
            notebook-level ``all_features``.

    Returns:
        Dict with every name in ``feature_names``, in that order; features not
        supplied default to 0. Supplied keys that are not in ``feature_names``
        (e.g. an unknown gene name) are kept, appended at the end.
    """
    if feature_names is None:
        feature_names = all_features

    provided = {'AGE': age}

    if len(mutation_signatures.columns) > 0:
        signature_row = mutation_signatures.iloc[0]  # hoisted: same row for every column
        for column in mutation_signatures.columns:
            provided[column] = signature_row[column]

    for CNA_event in CNA_events:
        provided[CNA_event[0]] = CNA_event[1]

    # Emit features in feature_names order so a DataFrame built from this dict
    # has the training column order. If utils builds an XGBoost DMatrix from
    # these columns, XGBoost's default feature validation compares names in
    # order, so a different order would fail.
    data = {column: provided.get(column, 0) for column in feature_names}
    for key, value in provided.items():
        if key not in data:
            data[key] = value
    return data

In [30]:
def get_plot(features, target, probability, shap_values=None):
    '''Build a SHAP-based explanation figure for one predicted cancer type.

    features: DataFrame of model inputs; only row 0 is explained
        (assumed to be a single-row frame, as built in get_preds_min_info).
    target: cancer type to explain. It must be an entry of
        cancer_types_to_consider, otherwise list.index raises ValueError.
    probability: predicted probability for `target`, shown in the caption
        with 3 decimals.
    shap_values: optional result of
        utils.obtain_shap_values_with_latest_xgboost(xgb_onconpc, features).
        Pass it to avoid recomputing SHAP when plotting several targets for the
        same input; it is computed here when None.

    Returns whatever utils.get_individual_pred_interpretation returns. That
    helper is also given filename=f'{target}_plot.png'.
    '''
    if shap_values is None:
        shap_values = utils.obtain_shap_values_with_latest_xgboost(xgb_onconpc, features)

    # SHAP values are indexed by class position in cancer_types_to_consider.
    target_idx = cancer_types_to_consider.index(target)

    # Explain the first (and presumably only) tumor sample.
    feature_sample_df = features.iloc[0]
    shap_pred_cancer_df = pd.DataFrame(shap_values[target_idx],
                                       index=features.index,
                                       columns=features.columns)
    shap_pred_sample_df = shap_pred_cancer_df.iloc[0]

    sample_info = f'Prediction: {target}\nPrediction probability: {probability:.3f}'
    feature_group_to_features_dict = utils.partiton_feature_names_by_group(features.columns)
    fig = utils.get_individual_pred_interpretation(shap_pred_sample_df, feature_sample_df, feature_group_to_features_dict, sample_info=sample_info, filename=f'{target}_plot.png')
    return fig

In [None]:
def get_preds_min_info(age, gender, CNA_events, mutations, output):
    """Gradio submit handler: predict the cancer type and build explanation image(s).

    Args:
        age, gender, CNA_events, mutations: raw form values (see parse_inputs).
        output: dropdown selection. It is 'Top Prediction', 'Top 3 Predictions',
            or a specific cancer type name. An empty selection (None/'') is
            treated as 'Top Prediction' instead of failing on a missing column.

    Returns:
        ``(top3_text, first_image)`` for the predictions textbox and image display.

    Side effects:
        Sets the module globals ``images``, ``current_image_index`` (reset to 0),
        ``data`` and ``predictions``. ``images`` and ``current_image_index`` are
        read by change_image for Previous/Next navigation.
    """
    global images, current_image_index, data, predictions

    age, gender, CNA_events, mutations = parse_inputs(age, gender, CNA_events, mutations)
    current_image_index = 0

    mutation_signatures = get_mut_signatures(mutations)

    data = pd.DataFrame([construct_input_data(age, gender, CNA_events, mutation_signatures)])
    predictions = pd.DataFrame(utils.get_xgboost_latest_cancer_type_preds(xgb_onconpc, data, cancer_types_to_consider))
    top3_pred_list = get_top3_min_info(predictions)
    # SHAP values are not computed here: get_plot computes them itself, so an
    # extra computation at this point would be discarded.

    if not output:
        output = 'Top Prediction'

    if output == 'Top Prediction':
        top_row = predictions.iloc[0]
        images = [get_plot(data, top_row['cancer_type'], top_row['max_posterior'])]
    elif output == 'Top 3 Predictions':
        images = get_top3_plots(data, top3_pred_list)
    else:
        images = [get_plot(data, output, predictions.iloc[0][output])]

    return top3_str(top3_pred_list), images[current_image_index]

In [None]:
import gradio as gr

# Dropdown choices: each individual cancer type, then the two summary modes.
output_graph_options = [*cancer_types_to_consider, 'Top 3 Predictions', 'Top Prediction']

def change_image(direction):
    """Step through the explanation images from the last submit, wrapping around.

    direction: "Next" or "Previous"; any other value re-shows the current image.

    Returns a gr.Image with the selected image. Returns None (leaving the
    display empty) when no images exist yet, e.g. if a navigation button is
    clicked before Submit; previously that raised NameError.
    """
    global current_image_index, images
    if not globals().get('images'):
        return None

    step = {"Next": 1, "Previous": -1}.get(direction, 0)
    current_image_index = (current_image_index + step) % len(images)
    return gr.Image(images[current_image_index])


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            age = gr.Number(label="Age")
            # (label, value) pairs: show "Male"/"Female" but submit lowercase
            # values, because parse_inputs compares gender against 'male'.
            gender = gr.Radio(choices=[("Male", "male"), ("Female", "female")], label="Gender")
            # Labels describe the format parse_inputs actually expects
            # ('|' between entries, not commas).
            cna_events = gr.Textbox(
                lines=2,
                placeholder="e.g. KCNQ1 2 | BRAF -1",
                label="CNA events ('|'-separated 'GENE VALUE' pairs, VALUE in -2..2)",
            )
            mutations = gr.Textbox(
                lines=5,
                placeholder="e.g. SAMPLE_1, chr17, 7577539, G, A | SAMPLE_1, chr3, 178936091, G, A",
                label="Mutations ('|'-separated 'SAMPLE_ID, CHROMOSOME, POSITION, REF, ALT' records)",
            )
            # Default to the top prediction so Submit works without touching the dropdown.
            output_selector = gr.Dropdown(choices=output_graph_options, value="Top Prediction", label="Output Options", filterable=True, multiselect=False)
            submit_button = gr.Button("Submit")

        with gr.Column():
            predictions_output = gr.Textbox(label="Top 3 Predicted Cancer Types")
            image = gr.Image(label="Image Display")
            with gr.Row():
                prev_button = gr.Button("Previous")
                next_button = gr.Button("Next")

    # Hidden constant inputs carrying the navigation direction to change_image.
    prev_direction = gr.Textbox(visible=False, value="Previous")
    next_direction = gr.Textbox(visible=False, value="Next")

    submit_button.click(get_preds_min_info, inputs=[age, gender, cna_events, mutations, output_selector], outputs=[predictions_output, image])
    prev_button.click(change_image, inputs=prev_direction, outputs=image)
    next_button.click(change_image, inputs=next_direction, outputs=image)

# share=True publishes a temporary public *.gradio.live URL. Consider dropping
# it when entering real patient data.
demo.launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://d702b9674a762cb8d5.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


In [21]:
# Example inputs for the Gradio form (paste each value into the matching textbox):
# CNA events: KCNQ1 2 | BRAF -1 | SLX1B 1 | CBLB -2
# mutations: GENIE-MSK-P-0000015-T01-IM3, chr17, 7577539, G, A | GENIE-MSK-P-0000015-T01-IM3, chr3, 178936091, G, A | GENIE-MSK-P-0000015-T01-IM3, chr6, 152419920, T, A