In [None]:
from functools import partial
from importlib import reload
from os import path as osp

import ipywidgets as widgets
import pandas as pd
from ast import literal_eval
from IPython.display import display

import src.multi_label_plotting
import src.single_label_plotting
import src.results_viz

In [None]:
# Directory that holds the result CSV files listed in RESULTS_FILES.
RESULTS_DIR = './results_with_annotations'

# Task identifier -> CSV file name (joined onto RESULTS_DIR when loading).
RESULTS_FILES = {
    # Variant effect prediction, zero-shot scores
    "vep_zero_shot_causal_eqtl": "cleaned_annotated_combined_zero_shot_scores_labels_variant_effect_causal_eqtl.csv",
    "vep_zero_shot_pathogenic_clinvar": "cleaned_annotated_combined_zero_shot_scores_labels_variant_effect_pathogenic_clinvar.csv",
    "vep_zero_shot_pathogenic_omim": "cleaned_annotated_combined_zero_shot_scores_labels_variant_effect_pathogenic_omim.csv",
    # Variant effect prediction, fine-tuned predictions
    "vep_finetune_causal_eqtl": "cleaned_annotated_combined_predictions_labels_variant_effect_causal_eqtl.csv",
    "vep_finetune_pathogenic_clinvar": "cleaned_annotated_combined_predictions_labels_variant_effect_pathogenic_clinvar.csv",
    # Chromatin features
    "chromatin_features_dna_accessibility": "cleaned_annotated_combined_predictions_labels_chromatin_features_dna_accessibility.csv",
    "chromatin_features_histone_marks": "cleaned_annotated_combined_predictions_labels_chromatin_features_histone_marks.csv",
    # Regulatory elements
    "regulatory_element_promoters": "cleaned_annotated_combined_predictions_labels_regulatory_element_promoter.csv",
    "regulatory_element_enhancers": "cleaned_annotated_combined_predictions_labels_regulatory_element_enhancer.csv",
    # Bulk RNA expression
    "bulk_rna_expression": "cleaned_annotated_combined_predictions_labels_bulk_rna_expression.csv",
}

# Task identifiers in declaration order; used as the dropdown options.
TASKS = list(RESULTS_FILES)

### Choose a task using the dropdown, then (re-)run the cells below

In [None]:
# Task selector: the cells below read `task_dropdown.value`, so re-run them
# after changing the selection.
task_dropdown = widgets.Dropdown(description='Task:', options=TASKS)
display(task_dropdown)

In [None]:
# Load the CSV registered for the currently selected task; the first CSV
# column is used as the DataFrame index.
results_csv_path = osp.join(RESULTS_DIR, RESULTS_FILES[task_dropdown.value])
results_df_w_annotations = pd.read_csv(results_csv_path, index_col=0, low_memory=False)

In [None]:
# Pick the plotting callables and the reference ("base") models for the task
# currently selected in `task_dropdown`.
#
# Outputs (module-level names consumed by later cells):
#   base_models            -- tuple of baseline model names for the task
#   plot_fxn               -- overall metric plot
#   plot_by_annotation_fxn -- metric plot split by annotation
#   plot_by_tss_dist_fxn   -- metric plot split by distance bucket (None when unavailable)
#   plot_by_maf_fxn        -- always None here; kept so the name stays defined
#
# Raises NotImplementedError when the selected task matches no branch.
selected_task = task_dropdown.value  # read the widget once so every check sees the same value

base_models = None
plot_fxn, plot_by_tss_dist_fxn, plot_by_annotation_fxn, plot_by_maf_fxn = None, None, None, None

# Each plotting module is reloaded *before* the `from ... import` below so the
# re-import binds the freshly reloaded functions (reload alone does not update
# names that were previously imported with `from module import name`).
reload(src.results_viz)
if 'vep' in selected_task or 'regulatory_element' in selected_task:
    reload(src.single_label_plotting)
    from src.single_label_plotting import plot_aucroc_auprc as plot_fxn
    from src.single_label_plotting import plot_aucroc_auprc_by_annotation as plot_by_annotation_fxn

    # Distance-bucketed plots are only wired up for the eQTL variant-effect tasks.
    if 'regulatory_element' not in selected_task and 'eqtl' in selected_task:
        from src.single_label_plotting import plot_aucroc_auprc_by_bucket as plot_by_tss_dist_fxn
        # NOTE(review): for the fine-tuned eQTL task the bucket function is used
        # with its own defaults, whereas the zero-shot path below passes
        # `bucket_col`/`bucket_display_str` explicitly -- confirm the defaults match.

    if 'zero_shot' in selected_task:
        # Zero-shot result files carry predictions in the 'Score' column.
        plot_fxn = partial(plot_fxn, pred_col='Score')
        plot_by_annotation_fxn = partial(plot_by_annotation_fxn, pred_col='Score')
        if plot_by_tss_dist_fxn is not None:
            plot_by_tss_dist_fxn = partial(
                plot_by_tss_dist_fxn,
                pred_col='Score',
                bucket_col='distance_to_nearest_TSS',
                bucket_display_str='Distance to TSS',
            )
        base_models = ('CADD', 'PhyloP')
    else:
        base_models = ('Enformer',)

elif 'bulk_rna_expression' == selected_task:
    # Fix: reload here as well, matching the other branches; previously edits to
    # `src.multi_label_plotting` were not picked up for this task.
    reload(src.multi_label_plotting)
    base_models = ('Enformer',)
    from src.multi_label_plotting import plot_r2 as plot_fxn
    from src.multi_label_plotting import plot_r2_by_bucket as plot_by_tss_dist_fxn
    from src.multi_label_plotting import plot_r2_by_annotation as plot_by_annotation_fxn

elif 'chromatin_features' in selected_task:
    reload(src.multi_label_plotting)
    base_models = ('Deep Sea',)
    from src.multi_label_plotting import plot_aucroc_auprc as plot_fxn
    from src.multi_label_plotting import plot_aucroc_auprc_by_annotation as plot_by_annotation_fxn

else:
    # All plotting names are still None from the initialisation above.
    raise NotImplementedError(f'Plotting for task `{selected_task}` not implemented!')

def is_annotation(column_name):
    """Return True when ``column_name`` counts as an annotation column.

    A column is rejected when it exactly matches one of the known
    non-annotation column names below, or when its name contains any of the
    substrings 'probability', 'label', 'prediction' or 'dist'
    (case-sensitive). Everything else is accepted.
    """
    non_annotation_columns = (
        'chromosome', 'position', 'start', 'stop', 'model', 'annotations',
        'tissues', 'Score', 'REF', 'ALT', 'split', 'distance_to_nearest_TSS', 'SNP', 'tissue',
        'SOURCE', 'CONSEQUENCE', 'ID', 'REVIEW_STATUS', 'GENOMIC_MUTATION_ID', 'N_SAMPLES', 'TOTAL_SAMPLES', 'FREQ',
        'OMIM', 'GENE', 'PMID', 'AC', 'AN', 'AF', 'MAF', 'MAC', 'INT_LABEL',
    )
    if column_name in non_annotation_columns:
        return False

    # Substring checks are case-sensitive, which is why upper-case names such
    # as 'INT_LABEL' are listed explicitly above.
    excluded_fragments = ('probability', 'label', 'prediction', 'dist')
    return not any(fragment in column_name for fragment in excluded_fragments)
# Custom annotations: the columns of the results table accepted by
# `is_annotation`, in their original column order.
annotations = list(filter(is_annotation, results_df_w_annotations.columns))

from src.results_viz import ResultsViz

In [None]:
# Build the ResultsViz object for the selected task from the loaded table and
# the plotting callables chosen above. The trailing `;` suppresses the cell's
# echoed return value.
model_names = sorted(results_df_w_annotations['model'].unique())
if task_dropdown.value == 'bulk_rna_expression':
    distance_reference = 'enhancer'
else:
    distance_reference = 'tss'

ResultsViz(
    df=results_df_w_annotations,
    models=model_names,
    base_models=base_models,
    annotations=sorted(annotations),
    plot_fxn=plot_fxn,
    plot_by_annotation_fxn=plot_by_annotation_fxn,
    plot_by_tss_dist_fxn=plot_by_tss_dist_fxn,
    distance_to=distance_reference,
);