In [1]:
import pandas as pd

rf_path = "rf_10_fold_pseudobulk_30_pcs_tissue_growth.csv"
en_path = "en_10_fold_pseudobulk_30_pcs_tissue_growth.csv"

rf_df = pd.read_csv(rf_path)
en_df = pd.read_csv(en_path)

print("RF shape:", rf_df.shape)
print("EN shape:", en_df.shape)

display(rf_df.head())
display(en_df.head())


RF shape: (625, 6)
EN shape: (625, 6)


Unnamed: 0,DRUG_ID,n_samples,r2_global,pearson_r,pearson_pval,rmse
0,133,132,0.07961,0.287642,0.000825,1.625511
1,134,132,0.046691,0.236615,0.006304,1.867207
2,135,131,-0.010608,0.176148,0.044162,2.623994
3,136,133,-0.021383,0.131129,0.132457,1.522899
4,140,133,0.10288,0.321402,0.000162,1.551783


Unnamed: 0,DRUG_ID,n_samples,r2_global,pearson_r,pearson_pval,rmse
0,133,132,0.100553,0.333985,9.1e-05,1.60691
1,134,132,0.092744,0.30928,0.000308,1.821548
2,135,131,0.039429,0.203729,0.019597,2.55821
3,136,133,0.023601,0.173349,0.045997,1.488985
4,140,133,0.069461,0.269579,0.001702,1.580423


In [2]:
N = 15

rf_top = rf_df.sort_values("r2_global", ascending=False).head(N).reset_index(drop=True)
en_top = en_df.sort_values("r2_global", ascending=False).head(N).reset_index(drop=True)

print("Top RF drugs:")
display(rf_top)

print("Top EN drugs:")
display(en_top)


Top RF drugs:


Unnamed: 0,DRUG_ID,n_samples,r2_global,pearson_r,pearson_pval,rmse
0,2544,138,0.390003,0.624615,2.687842e-16,1.682686
1,1073,138,0.380763,0.628753,1.49438e-16,1.281825
2,1526,126,0.371317,0.60949,3.627422e-14,1.312988
3,2564,138,0.365279,0.604825,3.964475e-15,1.799092
4,1086,138,0.359762,0.605919,3.433342e-15,1.821091
5,1564,138,0.342724,0.58599,4.343442e-14,1.425836
6,1498,131,0.342151,0.58532,2.100281e-13,1.387545
7,1386,135,0.339534,0.583112,1.164711e-13,1.067452
8,2500,114,0.33444,0.584769,8.407215e-12,1.664698
9,1037,135,0.329568,0.574111,3.349384e-13,0.92933


Top EN drugs:


Unnamed: 0,DRUG_ID,n_samples,r2_global,pearson_r,pearson_pval,rmse
0,2564,138,0.382163,0.618804,6.041637e-16,1.775002
1,2544,138,0.377509,0.615307,9.758079e-16,1.69983
2,1089,138,0.374964,0.61376,1.204036e-15,1.201692
3,1073,138,0.363047,0.602664,5.260025e-15,1.300032
4,1079,138,0.355912,0.59693,1.102233e-14,1.672416
5,1086,138,0.353557,0.595367,1.344989e-14,1.829893
6,1036,135,0.350063,0.591672,4.137324e-14,1.175512
7,2156,112,0.349474,0.596815,3.796828e-12,1.139667
8,2545,137,0.347747,0.589846,3.348751e-14,1.106923
9,1526,126,0.34691,0.592482,2.711585e-13,1.338232


In [3]:
rf_set = set(rf_top["DRUG_ID"].astype(str))
en_set = set(en_top["DRUG_ID"].astype(str))

print("Overlap of top", N, "RF vs EN drug IDs:")
print(rf_set & en_set)

print("\nRF-only top drugs:")
print(rf_set - en_set)

print("\nEN-only top drugs:")
print(en_set - rf_set)


Overlap of top 15 RF vs EN drug IDs:
{'1086', '2544', '1073', '2156', '1526', '2564'}

RF-only top drugs:
{'2145', '2500', '2543', '1037', '1498', '1564', '1372', '1386', '1373'}

EN-only top drugs:
{'1036', '1089', '1061', '1708', '1392', '1079', '1378', '2545', '2562'}


In [4]:
# Merge RF and EN results by DRUG_ID
merged = rf_df[["DRUG_ID","r2_global"]].rename(columns={"r2_global":"r2_RF"}).merge(
    en_df[["DRUG_ID","r2_global"]].rename(columns={"r2_global":"r2_EN"}),
    on="DRUG_ID", how="inner"
)

# Add ranks
merged["rank_RF"] = merged["r2_RF"].rank(ascending=False, method="min")
merged["rank_EN"] = merged["r2_EN"].rank(ascending=False, method="min")

# Sort by average rank
merged["avg_rank"] = merged[["rank_RF","rank_EN"]].mean(axis=1)
merged_top = merged.sort_values("avg_rank").head(N)

print("Consensus top drugs (by average rank across RF and EN):")
display(merged_top)


Consensus top drugs (by average rank across RF and EN):


Unnamed: 0,DRUG_ID,r2_RF,r2_EN,rank_RF,rank_EN,avg_rank
601,2544,0.390003,0.377509,1.0,2.0,1.5
614,2564,0.365279,0.382163,4.0,1.0,2.5
238,1073,0.380763,0.363047,2.0,4.0,3.0
244,1086,0.359762,0.353557,5.0,6.0,5.5
371,1526,0.371317,0.34691,3.0,10.0,6.5
544,2156,0.321443,0.349474,13.0,8.0,10.5
384,1564,0.342724,0.319906,6.0,16.0,11.0
246,1089,0.304584,0.374964,22.0,3.0,12.5
210,1036,0.305665,0.350063,21.0,7.0,14.0
602,2545,0.306325,0.347747,19.0,9.0,14.0


In [5]:
# Load your drug ID -> name mapping
drug_map = pd.read_csv("drug_id_name_map.csv")

# Attach names to the consensus table
merged_top_named = merged_top.merge(drug_map, on="DRUG_ID", how="left")

print("Consensus top drugs with names:")
display(merged_top_named)


Consensus top drugs with names:


Unnamed: 0,DRUG_ID,r2_RF,r2_EN,rank_RF,rank_EN,avg_rank,DRUG_NAME
0,2544,0.390003,0.377509,1.0,2.0,1.5,
1,2564,0.365279,0.382163,4.0,1.0,2.5,
2,1073,0.380763,0.363047,2.0,4.0,3.0,5-Fluorouracil
3,1086,0.359762,0.353557,5.0,6.0,5.5,BI-2536
4,1526,0.371317,0.34691,3.0,10.0,6.5,Refametinib
5,2156,0.321443,0.349474,13.0,8.0,10.5,5-azacytidine
6,1564,0.342724,0.319906,6.0,16.0,11.0,SCH772984
7,1089,0.304584,0.374964,22.0,3.0,12.5,Oxaliplatin
8,1036,0.305665,0.350063,21.0,7.0,14.0,PLX-4720
9,2545,0.306325,0.347747,19.0,9.0,14.0,


In [6]:
from datasets import load_dataset

# Load Tahoe drug metadata (379 drugs)
tahoe_drug_md = load_dataset("tahoebio/Tahoe-100M", name="drug_metadata", split="train").to_pandas()

# Normalize both sides for string matching
def normalize_name(x):
    return str(x).strip().lower().replace("-", "").replace(" ", "")

merged_top_named["DRUG_NAME_norm"] = merged_top_named["DRUG_NAME"].map(normalize_name)
tahoe_drug_md["drug_norm"] = tahoe_drug_md["drug"].map(normalize_name)

# Mark direct overlap
merged_top_named["In_Tahoe"] = merged_top_named["DRUG_NAME_norm"].isin(set(tahoe_drug_md["drug_norm"]))

print("Top drugs with Tahoe availability:")
display(merged_top_named[["DRUG_ID","DRUG_NAME","r2_RF","r2_EN","rank_RF","rank_EN","avg_rank","In_Tahoe"]])


  from .autonotebook import tqdm as notebook_tqdm


Top drugs with Tahoe availability:


Unnamed: 0,DRUG_ID,DRUG_NAME,r2_RF,r2_EN,rank_RF,rank_EN,avg_rank,In_Tahoe
0,2544,,0.390003,0.377509,1.0,2.0,1.5,False
1,2564,,0.365279,0.382163,4.0,1.0,2.5,False
2,1073,5-Fluorouracil,0.380763,0.363047,2.0,4.0,3.0,True
3,1086,BI-2536,0.359762,0.353557,5.0,6.0,5.5,False
4,1526,Refametinib,0.371317,0.34691,3.0,10.0,6.5,False
5,2156,5-azacytidine,0.321443,0.349474,13.0,8.0,10.5,True
6,1564,SCH772984,0.342724,0.319906,6.0,16.0,11.0,False
7,1089,Oxaliplatin,0.304584,0.374964,22.0,3.0,12.5,True
8,1036,PLX-4720,0.305665,0.350063,21.0,7.0,14.0,False
9,2545,,0.306325,0.347747,19.0,9.0,14.0,False


In [7]:
merged_top_named.to_csv("consensus_top_drugs_named.csv", index=False)
print("Saved: consensus_top_drugs_named.csv")


Saved: consensus_top_drugs_named.csv
