In [66]:
import os
import glob
import json
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns


In [69]:
def action_term_preprocessing(x):
    """Reduce a raw drug-target action string to one display label.

    Parameters
    ----------
    x : object
        A single action cell, typically a comma-separated string such as
        "inhibitor, antagonist". Missing values (anything ``pd.isna`` flags)
        are accepted.

    Returns
    -------
    str
        ``"NA"`` for a missing value. Otherwise the whitespace-stripped string.
        When it contains a comma, only the second comma-separated term is kept
        (stripped). The terms "positive/negative allosteric modulator" are
        abbreviated to ``"PAM"``/``"NAM"`` (case-insensitive); any other term
        keeps its original casing.
    """
    if pd.isna(x):
        return "NA"

    term = str(x).strip()

    if "," in term:
        # Splitting on a separator that is present always yields >= 2 pieces,
        # so index 1 is safe here.
        term = term.split(",")[1].strip()

    abbreviations = {
        "positive allosteric modulator": "PAM",
        "negative allosteric modulator": "NAM",
    }
    return abbreviations.get(term.lower(), term)


In [70]:
# Paths: `workdir` is the parent of the notebook's working directory.
workdir = os.path.dirname(os.getcwd())
input_dir = Path(workdir) / "9_drug_protein_moa" / "data"
# NOTE(review): absolute, user-specific path. It looks like it may equal
# Path(workdir) / "3_module_expansion" / "data" / "categories_methods" — confirm,
# then derive it from `workdir` so the notebook runs on other machines.
module_path = Path("/home/bbc8731/HSV/3_module_expansion/data/categories_methods")
plot_dir = input_dir / "plot"

# NeDRex drug-target edges: remove the "drugbank." / "uniprot." namespace tags
# and build a "drug|protein" key used to join against candidate pairs later.
drug_target = pd.read_csv(input_dir / "drug_target_nedrex.csv")
drug_target["drugbank_id"] = drug_target["sourceDomainId"].str.replace("drugbank.", "", regex=False)
drug_target["uniprot_id"] = drug_target["targetDomainId"].str.replace("uniprot.", "", regex=False)
drug_target["drug|protein"] = drug_target["drugbank_id"].astype(str) + "|" + drug_target["uniprot_id"].astype(str)

# Drop edges whose action list is the literal "[]". `.copy()` detaches the
# filtered frame so the column rewrite below does not raise SettingWithCopyWarning.
drug_target = drug_target.loc[
    drug_target["actions"] != "[]", ["drug|protein", "interaction_label", "actions"]
].copy()
# Actions appear to be stored as a stringified list (e.g. "['a', 'b']");
# strip the brackets and quotes to get "a, b".
drug_target["actions"] = drug_target["actions"].str.strip("[]").str.replace("'", "", regex=False)
drug_target["interaction_label"].unique()

array(['target_negativeEffect', 'target', 'target_positiveEffect'],
      dtype=object)

In [72]:
# Category directories to process; set to None to process every BP_/CC_
# category directory found under `module_path`.
SELECTED_CATEGORIES = ["BP_Egress_and_Envelopment"]

CANDIDATES_CSV = "drug_repurposing/trustrank/uniprot_ppi/validation/approved_drugs/candidate_drugs_scored_atc_code.csv"
DRUGABILITY_CSV = "drugability/protein_drugability_cutoff_4.csv"


def _load_drug_protein_pairs(category_dir):
    """Pair every candidate drug with every approved-drug module protein and attach NeDRex actions.

    Reads the category's scored candidate drugs and its druggability table and
    keeps proteins with ``druggability_rank == 1``. It then forms the full
    drug x protein cross product and left-joins the notebook-level
    ``drug_target`` table on the ``"drug|protein"`` key, so pairs with no NeDRex
    record carry NaN in ``actions``.
    """
    candidate_drugs = pd.read_csv(category_dir / CANDIDATES_CSV).rename(
        columns={
            "score": "drug_score",
            "neg_log_score": "drug_neg_log_score",
            "cmpdname": "drug_name",
            "atc_code": "drug_atc_code",
            "atc_label": "drug_atc_label",
        }
    )

    protein_modules = pd.read_csv(category_dir / DRUGABILITY_CSV)
    # proteins with approved drugs: druggability_rank == 1
    protein_modules = protein_modules.loc[
        protein_modules["druggability_rank"] == 1, ["uniprot_id", "symbol", "ensembl_id"]
    ]

    pairs = candidate_drugs.merge(protein_modules, how="cross")
    pairs["drug|protein"] = pairs["drugbank_id"].astype(str) + "|" + pairs["uniprot_id"].astype(str)
    return pairs.merge(drug_target, on="drug|protein", how="left")


def _build_action_pivot(drug_protein):
    """Build a drug x protein-symbol table of cleaned action labels, rows ordered by TrustRank.

    Rows missing a drug name, symbol, TrustRank, ATC label or action are dropped.
    TrustRank is included in that filter because it is cast to int for the
    label. Drugs are labelled "<name> (Rank <trustrank>)". Drug-protein pairs
    with no action stay NaN. Returns None when no row survives the filter.
    """
    required = ["drug_name", "symbol", "trustrank", "drug_atc_label", "actions"]
    df_plot = drug_protein[required].dropna(subset=required)
    if df_plot.empty:
        return None

    df_plot = df_plot.assign(
        action_clean=df_plot["actions"].map(action_term_preprocessing),
        drug_label=df_plot["drug_name"] + " (Rank " + df_plot["trustrank"].astype(int).astype(str) + ")",
    )
    rank_order = (
        df_plot[["drug_label", "trustrank"]]
        .drop_duplicates()
        .sort_values("trustrank")["drug_label"]
    )
    return df_plot.pivot_table(
        index="drug_label", columns="symbol", values="action_clean", aggfunc="first"
    ).loc[rank_order]


def _plot_action_heatmap(pivot, title, out_path):
    """Save `pivot` as a categorical heatmap: one colour per action, cells annotated with the action text.

    Cells that are NaN in `pivot` (no drug-target action) are masked and left empty.
    """
    actions = [a for a in pd.unique(pivot.to_numpy().ravel()) if pd.notna(a)]
    action_codes = {action: code for code, action in enumerate(actions, start=1)}
    pivot_num = pivot.apply(lambda col: col.map(action_codes))

    n_actions = len(action_codes)
    # tab10 only has 10 distinct colours; beyond that use an evenly spaced palette.
    palette = sns.color_palette("tab10" if n_actions <= 10 else "husl", n_colors=n_actions)

    fig, ax = plt.subplots(figsize=(8, 10))
    sns.heatmap(
        pivot_num,
        cmap=palette,
        # centre each integer code in its own colour bin
        vmin=0.5,
        vmax=n_actions + 0.5,
        linewidths=0.5,
        linecolor="lightgray",
        cbar=False,
        annot=pivot,
        fmt="",
        mask=pivot.isna(),
        ax=ax,
    )
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(title)
    ax.tick_params(axis="x", labelrotation=0)
    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)


if SELECTED_CATEGORIES is None:
    # iterdir + is_dir: `glob("*/")` only restricts to directories on Python >= 3.11
    files = sorted(
        p for p in module_path.iterdir() if p.is_dir() and p.name.startswith(("BP_", "CC_"))
    )
else:
    files = [module_path / category for category in SELECTED_CATEGORIES]

# savefig does not create missing directories
plot_dir.mkdir(parents=True, exist_ok=True)

for p in files:
    name = p.name
    pivot = _build_action_pivot(_load_drug_protein_pairs(p))
    if pivot is None:
        print(f"Skipping {name} — no valid drug-target actions.")
        continue

    # "BP_Egress_and_Envelopment" -> "Egress and Envelopment"; fall back to the
    # raw name if it has no "_" prefix.
    clean_name = name.partition("_")[2].replace("_", " ") or name
    _plot_action_heatmap(
        pivot,
        title=f"candidates for '{clean_name}' category",
        out_path=plot_dir / f"actions_drug_protein_candidates_{name}.pdf",
    )
    