In [11]:
# Imports
import numpy as np
import pandas as pd
import os

from matplotlib.colors import LinearSegmentedColormap, Normalize

# Specific imports from ccf_medication modules
from ccf_medication.plotting.heatmaps import render_grid

# Constants
from ccf_medication.constants.pathing import (
    ACTIVE_PATIENT_TABLE_PATH,
    REMISSION_PATIENT_TABLE_PATH,
    MED_VS_NO_MED_SIGNIF_TABLE_PATH,
    RESULTS_DIR,
    REM_VS_ACT_SIGNIF_TABLE_PATH,
    RESPONSE_TABLE_PATH,
)

from ccf_medication.constants.thresholds import (
    MIN_PATIENTS_PER_DRUG_FAMILY,
    MIN_PATIENTS_PER_SEVERITY_GROUP,
    MIN_PATIENTS_PER_SEVERITY_GROUP,
)

from ccf_medication.constants.plotting import (
    BASE_COLOR,
    REMISSION_COLOR,
    ACTIVE_COLOR,
    PLOTTING_DRUG_FAMILIES_MAP,
)



In [12]:
# All heatmaps rendered below are saved under <RESULTS_DIR>/paper_plots/count_heatmaps/.
paper_results_dir = os.path.join(RESULTS_DIR, 'paper_plots')
counts_path = os.path.join(paper_results_dir, 'count_heatmaps/')

# exist_ok=True keeps this cell safe to re-run when the folder already exists.
os.makedirs(name=counts_path, exist_ok=True)

# Universal Utils

In [13]:
def table_universal_formatting(table, use_plotting_drug_families=True,
                               excluded_families=('Calcineurin inhibitors',)):
    """Normalise a drug-family count/result table before plotting it as a heatmap.

    Steps, in order:
      1. Relabel the third column level: 'all' -> '', 'small_intestine' ->
         'Small Intestine', 'colon' -> 'Colon'.
      2. Optionally relabel the row index through PLOTTING_DRUG_FAMILIES_MAP.
         Labels without an entry in the map keep their original name.
      3. Drop any rows whose label is in ``excluded_families`` (matched after
         step 2); labels that are not present are ignored.
      4. Sort rows alphabetically and replace missing values with 0.

    Parameters
    ----------
    table : pd.DataFrame
        Rows are drug families. The columns are expected to be a MultiIndex with
        at least three levels, because the relabelling in step 1 targets level 2.
    use_plotting_drug_families : bool, default True
        If true, map row labels to their plotting names via PLOTTING_DRUG_FAMILIES_MAP.
    excluded_families : str or iterable of str, default ('Calcineurin inhibitors',)
        Row labels to remove. The default reproduces the previously hard-coded
        exclusion; pass ``()`` to keep every row.

    Returns
    -------
    pd.DataFrame
        A new, formatted frame. The input ``table`` is not modified.
    """
    level2_labels = {'all': '', 'small_intestine': 'Small Intestine', 'colon': 'Colon'}
    formatted = table.rename(columns=level2_labels, level=2)

    if use_plotting_drug_families:
        # Unlike Index.map(dict), rename leaves labels missing from the mapping
        # unchanged instead of turning them into NaN row labels.
        formatted = formatted.rename(index=PLOTTING_DRUG_FAMILIES_MAP)

    if isinstance(excluded_families, str):
        excluded_families = [excluded_families]
    formatted = formatted.drop(index=list(excluded_families), errors='ignore')

    # Alphabetize by drug family, then treat missing entries as zero.
    return formatted.sort_index(axis=0).fillna(0)

def easy_render_heatmap(table, title, path, show_bottom_labels=True, top_color_hex=None, bottom_color_hex="#ffffff", max_value=None, colorbar_label="Sample count"):
    """Render ``table`` as a two-colour heatmap via ``render_grid``.

    Values are coloured on a linear scale from 0 (``bottom_color_hex``) to
    ``max_value`` (``top_color_hex``).

    Parameters
    ----------
    table : pd.DataFrame
        Values to colour. It is passed to render_grid unchanged; how NaN/NA
        cells are drawn is up to render_grid.
    title : str
        Heatmap title, forwarded to render_grid.
    path : str
        Output location, forwarded to render_grid.
    show_bottom_labels : bool, default True
        Forwarded to render_grid.
    top_color_hex : str or None
        Colour at ``max_value``; defaults to BASE_COLOR.
    bottom_color_hex : str, default "#ffffff"
        Colour at 0.
    max_value : float or None
        Upper limit of the colour scale. When None, it is the largest finite
        value in ``table``. If there is no finite positive value (empty,
        all-missing or all-zero table), 1 is used so the scale stays well-defined.
    colorbar_label : str, default "Sample count"
        Forwarded to render_grid.
    """
    if top_color_hex is None:
        top_color_hex = BASE_COLOR

    if max_value is None:
        # na_value=np.nan converts pd.NA cells as well, so the float cast cannot fail on them.
        values = table.to_numpy(dtype=float, na_value=np.nan)
        finite = values[np.isfinite(values)]
        max_value = finite.max() if finite.size else 0.0
        if max_value <= 0:
            # np.nanmax would raise on an empty table or return NaN on an all-missing one,
            # and vmin == vmax == 0 collapses the colour scale; fall back to [0, 1].
            max_value = 1.0

    cmap = LinearSegmentedColormap.from_list("white_to_remission", [bottom_color_hex, top_color_hex])
    norm = Normalize(vmin=0, vmax=max_value)

    render_grid(table,
                title,
                path,
                show_bottom_labels=show_bottom_labels,
                norm=norm,
                cmap=cmap,
                colorbar_label=colorbar_label)

In [14]:
def correct_fill_values(results_df,
                        remission_df=None,
                        active_df=None,
                        rem_patient_thresh=None,
                        active_patient_thresh=None):
    """Mask result cells that are backed by too few patients.

    Missing values in ``results_df`` are first set to 0. Then every cell whose
    remission count is below ``rem_patient_thresh`` is set to NaN. If
    ``active_df`` is given, so is every cell whose active count is below
    ``active_patient_thresh``. Cells that meet every threshold keep their value,
    so a well-powered cell with no result reads as 0 rather than NaN.

    Example: if ``results_df`` holds the number of significant genes from the
    disease-severity analysis, cells without enough remission or active patients
    become NaN, while sufficiently powered cells with no hits become 0.

    Patient-count tables are aligned to ``results_df``'s rows and columns. Any
    cell missing from a count table, or holding NaN there, is treated as 0
    patients and is therefore masked.

    Parameters
    ----------
    results_df : pd.DataFrame
        Values to correct.
    remission_df : pd.DataFrame
        Remission patient counts per cell. Required.
    active_df : pd.DataFrame or None
        Active patient counts per cell; optional.
    rem_patient_thresh : number
        Minimum remission patients needed to keep a cell. Required.
    active_patient_thresh : number or None
        Minimum active patients needed; required when ``active_df`` is given.

    Returns
    -------
    pd.DataFrame
        A new frame; ``results_df`` is not modified.

    Raises
    ------
    ValueError
        If ``remission_df`` is missing, or if a required threshold is None.
    """
    if remission_df is None:
        raise ValueError("remission_df is required")
    if rem_patient_thresh is None:
        raise ValueError("rem_patient_thresh is required when remission_df is given")
    if active_df is not None and active_patient_thresh is None:
        raise ValueError("active_patient_thresh is required when active_df is given")

    filled = results_df.fillna(0)

    def _insufficient(counts, thresh):
        # Skip reindexing when the labels already match (this also avoids reindex
        # errors on duplicate labels); otherwise missing cells count as 0 patients.
        if not (counts.index.equals(filled.index) and counts.columns.equals(filled.columns)):
            counts = counts.reindex(index=filled.index, columns=filled.columns, fill_value=0)
        return counts.fillna(0) < thresh

    too_few = _insufficient(remission_df, rem_patient_thresh)
    if active_df is not None:
        too_few = too_few | _insufficient(active_df, active_patient_thresh)

    return filled.mask(too_few)

In [15]:
def complete_heatmap_workflow(table_path, title, save_path, show_bottom_labels=True, top_color_hex=None, bottom_color_hex="#ffffff", max_value=None,  colorbar_label="Sample count", use_plotting_drug_families=True, remission_path=None, active_path=None, rem_patient_thresh=None, active_patient_thresh=None):
    """Load a parquet table, format it, optionally mask under-powered cells, and render a heatmap.

    The table at ``table_path`` is formatted with table_universal_formatting.
    When ``remission_path`` is given, the remission patient counts (and, if
    ``active_path`` is also given, the active counts) are formatted the same way.
    They are then passed to correct_fill_values with the given thresholds, so
    cells below threshold become NaN. The result is drawn with
    easy_render_heatmap.

    Parameters
    ----------
    table_path : str
        Parquet file holding the values to plot.
    title, save_path, show_bottom_labels, top_color_hex, bottom_color_hex, max_value, colorbar_label
        Forwarded to easy_render_heatmap (``save_path`` as its ``path``).
    use_plotting_drug_families : bool, default True
        Forwarded to table_universal_formatting for every table loaded here.
    remission_path : str or None
        Parquet file of remission patient counts; enables masking.
    active_path : str or None
        Parquet file of active patient counts; only valid together with ``remission_path``.
    rem_patient_thresh, active_patient_thresh : number or None
        Thresholds forwarded to correct_fill_values.

    Raises
    ------
    ValueError
        If ``active_path`` is given without ``remission_path``. Previously the
        active counts were silently ignored in that case.
    """
    if active_path is not None and remission_path is None:
        raise ValueError("active_path requires remission_path; active counts are only applied together with remission counts")

    def _load_formatted(path):
        # Every table gets identical formatting so rows and columns line up for masking.
        return table_universal_formatting(pd.read_parquet(path), use_plotting_drug_families)

    table = _load_formatted(table_path)

    if remission_path is not None:
        remission_df = _load_formatted(remission_path)
        active_df = _load_formatted(active_path) if active_path is not None else None
        table = correct_fill_values(table, remission_df, active_df, rem_patient_thresh, active_patient_thresh)

    easy_render_heatmap(table,
                        title,
                        save_path,
                        show_bottom_labels=show_bottom_labels,
                        top_color_hex=top_color_hex,
                        bottom_color_hex=bottom_color_hex,
                        max_value=max_value,
                        colorbar_label=colorbar_label)

# Patient Sample Counts

### Active Patient Table

In [16]:
# The same text is used for the title and the colour-bar label.
active_label = "Number of Active Patients"
complete_heatmap_workflow(
    table_path=ACTIVE_PATIENT_TABLE_PATH,
    save_path=os.path.join(counts_path, 'active_counts.png'),
    title=active_label,
    colorbar_label=active_label,
    top_color_hex=ACTIVE_COLOR,
)

### Remission Patient Table

In [17]:
# The same text is used for the title and the colour-bar label.
remission_label = "Number of Remission Patients"
complete_heatmap_workflow(
    table_path=REMISSION_PATIENT_TABLE_PATH,
    save_path=os.path.join(counts_path, 'remission_counts.png'),
    title=remission_label,
    colorbar_label=remission_label,
    top_color_hex=REMISSION_COLOR,
)

# 

# Med vs. No Med Significant Gene Counts

In [18]:
# Only the remission-count threshold applies here (active_path / active_patient_thresh keep their None defaults).
complete_heatmap_workflow(
    table_path=MED_VS_NO_MED_SIGNIF_TABLE_PATH,
    title="Significant Transcripts/Proteins per Medication\n(Remission - Med vs. No Med) ",
    save_path=os.path.join(counts_path, 'med_vs_no_med_signif_counts.png'),
    colorbar_label="Number of Transcripts/Proteins",
    top_color_hex=BASE_COLOR,
    remission_path=REMISSION_PATIENT_TABLE_PATH,
    rem_patient_thresh=MIN_PATIENTS_PER_DRUG_FAMILY,
)

# Disease Severity Significant Gene Counts

In [19]:
# Cells are masked unless both the remission and the active counts reach MIN_PATIENTS_PER_SEVERITY_GROUP.
complete_heatmap_workflow(
    table_path=REM_VS_ACT_SIGNIF_TABLE_PATH,
    title="Significant Transcripts/Proteins per Disease Severity\n (Active vs. Remission)",
    save_path=os.path.join(counts_path, 'rem_vs_act_signif_counts.png'),
    colorbar_label="Number of Transcripts/Proteins",
    top_color_hex=BASE_COLOR,
    remission_path=REMISSION_PATIENT_TABLE_PATH,
    rem_patient_thresh=MIN_PATIENTS_PER_SEVERITY_GROUP,
    active_path=ACTIVE_PATIENT_TABLE_PATH,
    active_patient_thresh=MIN_PATIENTS_PER_SEVERITY_GROUP,
)

# Response Gene Counts

In [20]:
# Masked with the same remission/active thresholds as the disease-severity heatmap.
complete_heatmap_workflow(
    table_path=RESPONSE_TABLE_PATH,
    title="Number of Medication Response Biomarkers",
    save_path=os.path.join(counts_path, 'response_signif_counts.png'),
    colorbar_label="Number of Biomarkers",
    top_color_hex=BASE_COLOR,
    remission_path=REMISSION_PATIENT_TABLE_PATH,
    rem_patient_thresh=MIN_PATIENTS_PER_SEVERITY_GROUP,
    active_path=ACTIVE_PATIENT_TABLE_PATH,
    active_patient_thresh=MIN_PATIENTS_PER_SEVERITY_GROUP,
)