In [None]:
import os
import random

import pandas as pd
from scipy.stats import pearsonr



# Locations of the true-vs-predicted CSVs for each evaluation setting (relative to the notebook).
_TRUE_VS_PRED_PATHS = {
    "LPO": 'to_the_universe_new/true_vs_pred.csv',
    "LCO": 'to_the_stars/true_vs_pred.csv',
    "LDO": 'ldo_results/true_vs_pred.csv',
}
_DRUG_NAMES_PATH = '../data/CTRPv2/drug_names.csv'
_NAIVE_ALGORITHM = "NaiveMeanEffectsPredictor"
# Drug names (as spelled in drug_names.csv) highlighted when sample == "picked".
_PICKED_DRUG_NAMES = [
    "bafilomycin A1", "chlorambucil", "NSC 74859", "hyperforin", "Docetaxel",
    "C6-ceramide", "obatoclax", "Entinostat", "etomoxir",
]


def _select_viz_drugs(results, sample, drug_names):
    """Return the drug IDs (strings) to highlight according to the ``sample`` strategy.

    Raises:
        ValueError: unknown ``sample`` or a picked drug name missing from ``drug_names``.
    """
    if sample == "picked":
        ids = []
        for name in _PICKED_DRUG_NAMES:
            match = drug_names.loc[drug_names['drug_name'] == name, 'pubchem_id']
            if match.empty:
                raise ValueError(f"Drug {name!r} not found in drug name mapping")
            ids.append(match.iloc[0])
        return ids
    if sample == "top":
        # The 10 drugs with the most prediction rows.
        return results['drug'].value_counts().index[:10].tolist()
    if sample == "random":
        unique_drugs = list(results['drug'].unique())
        # Local RNG: same draw as random.seed(42); random.sample(...), without
        # reseeding the global generator. Capped so small datasets don't crash.
        return random.Random(42).sample(unique_drugs, min(12, len(unique_drugs)))
    if sample == "lowest_pearson":
        if results.empty:
            return []

        def _pearson(group):
            # pearsonr needs at least two points; constant inputs give NaN (dropped below).
            if len(group) < 2:
                return float("nan")
            return pearsonr(group['y_true'], group['y_pred'])[0]

        pearsons = results.groupby('drug')[['y_true', 'y_pred']].apply(_pearson).dropna()
        # The 50 drugs with the lowest per-drug Pearson correlation.
        return pearsons.nsmallest(50).index.tolist()
    raise ValueError("Invalid sample option")


def _subtract_naive_baseline(predictions, algorithm):
    """Return ``algorithm``'s rows with the naive-model prediction subtracted.

    The naive prediction is matched on (drug, cell_line, CV_split) and kept as
    column ``naive_y_pred``; rows without a naive prediction are dropped.
    """
    naive = (
        predictions.loc[predictions["algorithm"] == _NAIVE_ALGORITHM,
                        ["drug", "cell_line", "CV_split", "y_pred"]]
        .rename(columns={"y_pred": "naive_y_pred"})
    )
    merged = (
        predictions[predictions["algorithm"] == algorithm]
        # Merging on CV_split as well keeps a single 'CV_split' column (no _x/_y).
        .merge(naive, on=["drug", "cell_line", "CV_split"], how="left")
        .dropna(subset=["naive_y_pred"])
    )
    return merged.assign(
        y_pred=merged["y_pred"] - merged["naive_y_pred"],
        y_true=merged["y_true"] - merged["naive_y_pred"],
    )


def get_results_and_drug_vis(algorithm, sample, test_mode, normalized):
    """Load one algorithm's test predictions and label a subset of drugs for plotting.

    Args:
        algorithm: value of the ``algorithm`` column to keep (e.g. "RandomForest").
        sample: drug-selection strategy: "picked", "top", "random" or "lowest_pearson".
        test_mode: evaluation setting: "LPO", "LCO" or "LDO" (selects the CSV).
        normalized: if True, subtract the NaiveMeanEffectsPredictor prediction for the
            same (drug, cell_line, CV_split) from both ``y_true`` and ``y_pred``.

    Returns:
        DataFrame of the algorithm's ``rand_setting == "predictions"`` rows plus a
        ``drug_name_viz`` column (drug name for selected drugs, "Other" otherwise).
        Normalized output also carries ``naive_y_pred``.

    Raises:
        ValueError: invalid ``test_mode``/``sample`` or an unknown picked drug name.
    """
    if test_mode not in _TRUE_VS_PRED_PATHS:
        raise ValueError("Invalid test mode")
    predictions = pd.read_csv(_TRUE_VS_PRED_PATHS[test_mode],
                              dtype={"drug": str, "cell_line": str, "CV_split": int})
    predictions = predictions[predictions['rand_setting'] == "predictions"]
    results = predictions[predictions['algorithm'] == algorithm].copy()

    # Read IDs as strings so they match the string-typed 'drug' column above.
    drug_names = pd.read_csv(_DRUG_NAMES_PATH, dtype={"pubchem_id": str})
    # Drug selection is always based on the raw (unnormalized) predictions.
    viz_drugs = _select_viz_drugs(results, sample, drug_names)
    shown = drug_names[drug_names['pubchem_id'].isin(viz_drugs)]
    id_to_name = dict(zip(shown['pubchem_id'], shown['drug_name']))

    if normalized:
        results = _subtract_naive_baseline(predictions, algorithm)

    # Drugs outside the selection are grouped under "Other".
    results["drug_name_viz"] = results['drug'].map(id_to_name).fillna("Other")
    return results
#results_lpo = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode="LPO", normalized=False)
#results_lpo_norm = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode="LPO", normalized=True)
#results_lco = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode="LCO", normalized=False)
#results_lco_norm = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode="LCO", normalized=True)



In [None]:
# Variance-decomposition helpers: compare a model's predictions against the NaiveMeanEffectsPredictor baseline.


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def compute_variance_decomposition(y_true, y_pred_rf, y_pred_nm):
    """Split the variance of ``y_true`` into naive-explained, RF-gain and unexplained parts.

    Args:
        y_true: observed values (any 1-D array-like).
        y_pred_rf: predictions of the model under study, aligned with ``y_true``.
        y_pred_nm: predictions of the naive baseline, aligned with ``y_true``.

    Returns:
        dict with ``tss`` (population variance of y_true), ``ess_nm``/``ess_rf``
        (tss minus each model's mean squared error), ``extra_ess`` (ess_rf - ess_nm),
        ``unexplained`` (tss - ess_rf), ``r2_nm``/``r2_rf`` (NaN when tss == 0) and
        ``relative_gain`` (extra_ess as a fraction of the naive model's residual, 0
        when that residual is not positive).
    """
    # asarray: accept lists/Series and avoid pandas index alignment in the subtraction.
    y_true = np.asarray(y_true, dtype=float)
    y_pred_rf = np.asarray(y_pred_rf, dtype=float)
    y_pred_nm = np.asarray(y_pred_nm, dtype=float)

    tss = np.var(y_true, ddof=0)
    rss_rf = np.mean((y_true - y_pred_rf) ** 2)
    rss_nm = np.mean((y_true - y_pred_nm) ** 2)
    ess_rf = tss - rss_rf
    ess_nm = tss - rss_nm
    if tss > 0:
        r2_rf = ess_rf / tss
        r2_nm = ess_nm / tss
    else:
        # R² is undefined for a constant target; report NaN instead of inf/nan warnings.
        r2_rf = r2_nm = float("nan")
    extra_ess = ess_rf - ess_nm
    # tss - ess_nm is the naive model's residual error; the gain is the share RF recovers.
    naive_residual = tss - ess_nm
    relative_gain = extra_ess / naive_residual if naive_residual > 0 else 0
    unexplained = tss - ess_rf

    return {
        "tss": tss,
        "ess_nm": ess_nm,
        "ess_rf": ess_rf,
        "extra_ess": extra_ess,
        "unexplained": unexplained,
        "r2_nm": r2_nm,
        "r2_rf": r2_rf,
        "relative_gain": relative_gain
    }

def plot_variance_decomposition(ax, tss, ess_nm, extra_ess, unexplained, max_tss,
                                xlabel=None):
    """Draw one stacked variance-decomposition bar on a matplotlib Axes.

    A white outline bar of height ``tss`` is drawn first. The three segments
    (naive, RF gain, unexplained) are then stacked inside it, each labelled
    with its percentage of ``tss``. Segments taller than 3% of ``tss`` get the
    label inside the bar; smaller ones get it just right of the bar. Segments
    with non-positive height are still drawn (so they appear in the legend) but
    get no label and do not advance the stack. The y-axis is scaled to
    ``max_tss`` so several panels share a scale, and the absolute ``tss`` is
    annotated with a short tick on the left.
    """
    font_adder = 10  # uniform bump applied to every font size below
    bar_width = 0.6
    segments = (
        (ess_nm, 'Naive', '#999999'),
        (extra_ess, 'RF gain', '#00BFC4'),
        (unexplained, 'Unexplained', '#F8766D'),
    )

    # Outline bar spanning the full TSS.
    ax.bar(0, tss, color='white', edgecolor='black', width=bar_width)

    base = 0
    for value, name, colour in segments:
        ax.bar(0, value, bottom=base, label=name, color=colour, width=bar_width)
        if value <= 0:
            continue
        share = (value / tss) * 100 if tss > 0 else 0
        midpoint = base + value / 2
        if value > 0.03 * tss:
            text_x, alignment = 0, 'center'      # fits inside the segment
        else:
            text_x, alignment = 0.35, 'left'     # too thin: place just right of the bar
        ax.text(text_x, midpoint, f"{share:.1f}%",
                ha=alignment, va='center',
                fontsize=5 + font_adder, fontweight='bold', color='black')
        base += value

    ax.set_xlim(-0.7, 0.7)
    ax.set_ylim(0, max_tss * 1.1)

    # Short tick plus numeric label marking the absolute TSS in variance units.
    ax.plot([-0.35, -0.25], [tss, tss], color='black', linewidth=1)
    ax.text(-0.4, tss, f"{tss:.2f}", ha='right', va='center',
            fontsize=8 + font_adder, color='black')

    ax.set_xticks([0])
    ax.set_xticklabels([xlabel if xlabel else ''])
    ax.set_ylabel("Variance", fontsize=10 + font_adder)
    ax.tick_params(axis='both', labelsize=9 + font_adder)

def get_model_preds_and_merge(rf_results, nm_results, merge_cols, pred_col_rf='y_pred_rf',
                              pred_col_nm='y_pred_nm', split=0):
    """Inner-join one CV split of two models' predictions side by side.

    Args:
        rf_results: predictions of the model under study (needs ``y_pred`` and a
            ``CV_split`` or ``CV_split_x`` column).
        nm_results: predictions of the baseline model (same requirements).
        merge_cols: columns identifying a prediction (e.g. ['drug', 'cell_line']).
        pred_col_rf: name given to ``rf_results['y_pred']`` in the output.
        pred_col_nm: name given to ``nm_results['y_pred']`` in the output.
        split: CV split to keep (default 0, the previous hard-coded choice).

    Returns:
        All columns of the selected ``rf_results`` rows plus ``pred_col_nm``.

    Raises:
        ValueError: a frame has neither ``CV_split`` nor ``CV_split_x``.
    """
    def _split_column(df):
        # Frames produced by a suffixed merge may carry 'CV_split_x' instead of 'CV_split'.
        for col in ('CV_split', 'CV_split_x'):
            if col in df.columns:
                return col
        raise ValueError("No CV_split column found.")

    keys = list(merge_cols)
    # Each frame is checked independently, so mixed column names are handled.
    rf = rf_results[rf_results[_split_column(rf_results)] == split].rename(columns={'y_pred': pred_col_rf})
    nm = nm_results[nm_results[_split_column(nm_results)] == split].rename(columns={'y_pred': pred_col_nm})
    return rf.merge(nm[keys + [pred_col_nm]], on=keys, how='inner')
def compute_splitwise_mean_decomposition(rf_df, nm_df):
    """Average ``compute_variance_decomposition`` metrics over shared CV splits.

    For every CV split present in both frames, ``rf_df`` and ``nm_df`` are
    inner-joined on (drug, cell_line) and decomposed. The resulting metric dicts
    are then averaged key by key. Splits whose join is empty are skipped.

    Args:
        rf_df: model predictions with ``drug``, ``cell_line``, ``y_true``, ``y_pred``
            and a ``CV_split`` or ``CV_split_x`` column.
        nm_df: baseline predictions with ``drug``, ``cell_line``, ``y_pred`` and a
            ``CV_split`` or ``CV_split_x`` column.

    Returns:
        dict mapping each metric name to its mean across splits.

    Raises:
        ValueError: a frame lacks a split column, or no split yields joined rows.
    """
    def _split_column(df):
        # Frames produced by a suffixed merge may carry 'CV_split_x' instead of 'CV_split'.
        for col in ('CV_split', 'CV_split_x'):
            if col in df.columns:
                return col
        raise ValueError("No CV_split column found.")

    rf_split_col = _split_column(rf_df)
    nm_split_col = _split_column(nm_df)
    split_ids = sorted(set(rf_df[rf_split_col]) & set(nm_df[nm_split_col]))

    metrics_list = []
    for split in split_ids:
        merged = rf_df[rf_df[rf_split_col] == split].merge(
            nm_df.loc[nm_df[nm_split_col] == split, ['drug', 'cell_line', 'y_pred']],
            on=['drug', 'cell_line'],
            suffixes=('_rf', '_nm')
        )
        if merged.empty:
            continue  # nothing to decompose for this split
        metrics_list.append(compute_variance_decomposition(
            merged['y_true'].values,
            merged['y_pred_rf'].values,
            merged['y_pred_nm'].values,
        ))

    if not metrics_list:
        raise ValueError("rf_df and nm_df share no CV split with overlapping (drug, cell_line) pairs.")

    # Average each metric across splits.
    return {k: np.mean([m[k] for m in metrics_list]) for k in metrics_list[0]}



In [None]:
test_mode = "LCO"          # evaluation setting: "LPO", "LCO" or "LDO"
SAMPLE = "picked"          # drug selection: "picked", "top", "random" or "lowest_pearson"
ALGORITHM = "RandomForest"             # model under study
BASELINE = "NaiveMeanEffectsPredictor"  # reference model for the decomposition

# --- Load raw predictions ---
rf_raw = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode=test_mode, normalized=False)
nm_raw = get_results_and_drug_vis(algorithm=BASELINE, sample=SAMPLE, test_mode=test_mode, normalized=False)
merged_raw = get_model_preds_and_merge(rf_raw, nm_raw, merge_cols=['drug', 'cell_line'])

# --- Load normalized predictions (naive-model prediction subtracted) ---
rf_norm = get_results_and_drug_vis(algorithm=ALGORITHM, sample=SAMPLE, test_mode=test_mode, normalized=True)
nm_norm = get_results_and_drug_vis(algorithm=BASELINE, sample=SAMPLE, test_mode=test_mode, normalized=True)
merged_norm = get_model_preds_and_merge(rf_norm, nm_norm, merge_cols=['drug', 'cell_line'])


In [None]:

# --- Compute decompositions (metrics per CV split, averaged across splits) ---
raw_metrics, norm_metrics = (
    compute_splitwise_mean_decomposition(rf_raw, nm_raw),
    compute_splitwise_mean_decomposition(rf_norm, nm_norm),
)

def safe_components(metrics):
    """Turn averaged decomposition metrics into non-negative stacked-bar segments.

    Args:
        metrics: dict with at least ``tss``, ``ess_nm`` and ``extra_ess``.

    Returns:
        (tss, naive, gain, unexplained). The last three are >= 0, and their sum is
        capped at tss. With an averaged decomposition they partition tss.
    """
    tss = metrics['tss']
    raw_nm = metrics['ess_nm']
    # ess_nm is negative when the naive model is worse than predicting the mean.
    ess_nm = max(raw_nm, 0)
    # Once a negative naive ESS is clipped to 0, the RF gain must also be measured
    # from 0 (i.e. becomes ess_rf = ess_nm + extra_ess). Otherwise the stack overshoots TSS.
    extra_ess = max(metrics['extra_ess'] + min(raw_nm, 0), 0)
    unexplained = max(tss - ess_nm - extra_ess, 0)
    return tss, ess_nm, extra_ess, unexplained
# --- Print metrics ---
def print_metrics(label, metrics):
    """Print a short text summary of decomposition ``metrics`` under a ``[label]`` header."""
    print(f"\n[{label}]")
    # (caption, metric key, format spec, multiplier, suffix); formatted lazily row by row.
    report_rows = (
        ("Total variance (TSS)", 'tss', '.4f', 1, ''),
        ("Naive Mean R²", 'r2_nm', '.4f', 1, ''),
        ("Random Forest R²", 'r2_rf', '.4f', 1, ''),
        ("Extra variance explained by RF (ESS)", 'extra_ess', '.4f', 1, ''),
        ("Relative gain over unexplained variance", 'relative_gain', '.2f', 100, '%'),
    )
    for caption, key, spec, factor, suffix in report_rows:
        print(f"{caption}: {format(metrics[key] * factor, spec)}{suffix}")

print_metrics("Raw", raw_metrics)
print_metrics("Normalized", norm_metrics)

fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
# Shared y-limit so the raw and normalized bars are drawn on the same scale.
max_tss = max(raw_metrics['tss'], norm_metrics['tss'])

for panel_ax, panel_metrics, panel_label in zip(axes, (raw_metrics, norm_metrics), ('Raw', 'Normalized')):
    panel_tss, panel_nm, panel_gain, panel_unexplained = safe_components(panel_metrics)
    plot_variance_decomposition(
        ax=panel_ax,
        tss=panel_tss,
        ess_nm=panel_nm,
        extra_ess=panel_gain,
        unexplained=panel_unexplained,
        max_tss=max_tss,
        xlabel=panel_label,
    )

axes[1].legend(loc='upper right', frameon=True, fontsize=18)
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figures", exist_ok=True)
plt.savefig(f"figures/variance_decomposition_{test_mode}.pdf", bbox_inches='tight')
plt.show()

