# Plot protein druggability

In [108]:
import os
import glob
import pandas as pd
from pathlib import Path

In [109]:
# Project root: the parent of the directory the notebook is started from.
workdir = str(Path.cwd().parent)

# Inputs and outputs of this analysis step; plots go in a sub-folder of the
# step's data directory.
input_dir = Path(f"{workdir}/7_protein_drugability/data/")
plot_dir = input_dir / "plot"

# Root that is searched recursively for per-module druggability tables.
module_path = Path(
    "/home/bbc8731/HSV/3_module_expansion/data/categories_methods"
)

In [115]:
# Exploratory pass over every druggability table found below module_path.
# Nothing is written or plotted here; the loop only leaves the values of the
# last file in `base_path`, `modules` and `tract_cols`.
thresholds = [4]
module_files = sorted(module_path.rglob("*/drugability/*.csv"))
for p in module_files:
    # Folder that contains the "drugability" directory of this CSV.
    base_path = p.parents[1]
    # sep=None lets the python engine sniff the delimiter; '#' lines are skipped.
    modules = pd.read_csv(p, comment="#", engine="python", sep=None)
    # Every column that is not one of the three identifier columns.
    tract_cols = modules.columns.difference(
        [
            "ensembl_id",
            "symbol",
            "uniprot_id",
        ]
    )

In [116]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def _draw_check_marks(ax, flags, genes_on_x):
    """Write a check mark into every positive cell of a blank heatmap grid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that already holds the (all-white) heatmap grid.
    flags : pandas.DataFrame
        0/1 table with one row per gene and one column per evidence
        category, columns already in display order.
    genes_on_x : bool
        True when the grid was drawn transposed (genes along the x-axis),
        False when genes run along the y-axis.

    For each gene, the first positive category in column order additionally
    gets a light-gray background.
    """
    for gene_pos, gene_flags in enumerate(flags.to_numpy()):
        highlighted = False
        for cat_pos, flag in enumerate(gene_flags):
            if flag != 1:
                continue

            # Heatmap cells are unit squares whose lower-left corner is
            # (column index, row index) in data coordinates.
            if genes_on_x:
                x, y = gene_pos, cat_pos
            else:
                x, y = cat_pos, gene_pos

            if not highlighted:
                ax.add_patch(
                    plt.Rectangle(
                        (x, y), 1, 1,
                        facecolor="lightgray",
                        edgecolor="lightgray",
                        zorder=1,
                    )
                )
                highlighted = True

            ax.text(
                x + 0.5,
                y + 0.5,
                "✓",
                ha="center",
                va="center",
                fontsize=12,
                fontweight="bold",
                color="black",
            )


def _tidy_rank_colorbar(rank_ax, min_rank, max_rank):
    """Put integer ticks on the rank colorbar and strip its labels.

    `rank_ax` is the Axes returned by the `sns.heatmap` call that drew the
    rank strip with `cbar=True`.
    """
    cbar = rank_ax.collections[0].colorbar
    cbar.set_ticks(range(int(min_rank), int(max_rank) + 1))
    cbar.set_label("")
    cbar.ax.set_xlabel("")
    cbar.ax.set_ylabel("")


module_files = sorted(module_path.rglob("*/drugability/*.csv"))

# Evidence categories in the order they are displayed. Only these columns are
# drawn, so the figure size and grid ratios are derived from this list rather
# than from whatever other columns a CSV may carry.
desired_order = [
    "Approved Drug",
    "Advanced Clinical",
    "Phase 1 Clinical",
    "High-Quality Ligand",
    "Structure with Ligand",
    "High-Quality Pocket",
    "Med-Quality Pocket",
    "Druggable Family",
]
n_categories = len(desired_order)

# Keyword arguments shared by both layouts for the blank, gridded heatmap that
# the check marks are drawn on.
grid_kwargs = {
    "cmap": ["white"],  # no color
    "cbar": False,
    "linewidths": 0.5,
    "linecolor": "lightgray",
}

# savefig does not create missing directories.
plot_dir.mkdir(parents=True, exist_ok=True)

for p in module_files:

    modules = pd.read_csv(p, sep=None, engine="python", comment="#")

    # One row per gene symbol, ordered by descending rank value.
    df = (
        modules
        .drop_duplicates(subset="symbol")
        .set_index("symbol")
        .sort_values("druggability_rank", ascending=False)
    )
    df.index.name = None
    n_proteins = len(df)

    if n_proteins == 0:
        # An empty table cannot be plotted (zero-size figure, min/max of an
        # empty array), so skip it instead of aborting the whole run.
        print(f"Skipping {p}: no proteins to plot")
        continue

    flags = df[desired_order].astype(int)
    rank_values = df["druggability_rank"].to_numpy()
    min_rank, max_rank = rank_values.min(), rank_values.max()

    if n_proteins <= 30:
        # Wide layout: genes along x, categories along y, rank strip below.
        fig, axes = plt.subplots(
            2, 1,
            figsize=(0.5 * n_proteins, 0.6 * n_categories + 2),
            gridspec_kw={"height_ratios": [n_categories, 1]},
        )

        sns.heatmap(flags.T, ax=axes[0], **grid_kwargs)
        _draw_check_marks(axes[0], flags, genes_on_x=True)

        axes[0].xaxis.tick_top()
        axes[0].tick_params(axis="x", rotation=90)

        rank_data = df[["druggability_rank"]].T
        rank_data.index = [""]

        rank_hm = sns.heatmap(
            rank_data,
            cmap="Blues_r",
            vmin=min_rank,
            vmax=max_rank,
            cbar=True,
            cbar_kws={"orientation": "horizontal", "pad": 0.4},
            ax=axes[1],
        )

        axes[1].set_xticks([])
        axes[1].set_ylabel("druggability_rank", rotation=0, labelpad=40)
        # Keep the rank strip aligned with the gene columns above it.
        axes[1].set_xlim(axes[0].get_xlim())

    else:
        # Tall layout: genes along y, categories along x, rank strip at right.
        fig, axes = plt.subplots(
            1, 2,
            figsize=(0.6 * n_categories + 2, 0.25 * n_proteins),
            gridspec_kw={"width_ratios": [n_categories, 1]},
        )

        sns.heatmap(flags, ax=axes[0], **grid_kwargs)
        _draw_check_marks(axes[0], flags, genes_on_x=False)

        rank_data = df[["druggability_rank"]].copy()
        rank_data.columns = [""]

        rank_hm = sns.heatmap(
            rank_data,
            cmap="Blues_r",
            vmin=min_rank,
            vmax=max_rank,
            cbar=True,
            ax=axes[1],
        )

        axes[1].set_yticks([])
        # Keep the rank strip aligned with the gene rows beside it.
        axes[1].set_ylim(axes[0].get_ylim())

    _tidy_rank_colorbar(rank_hm, min_rank, max_rank)

    # The folder two levels above the CSV names both the title and the file.
    # NOTE(review): several CSVs inside the same "drugability" folder map to
    # the same title and would overwrite each other's PDF - confirm that each
    # folder holds a single table.
    title = p.parents[1].name
    fig.suptitle(title, y=0.95)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(plot_dir / f"{title}.pdf", format="pdf")
    plt.close(fig)
    # plt.show()

