In [None]:
import pandas as pd
import numpy as np
import re
from itertools import combinations
import json
from bs4 import BeautifulSoup
import seaborn as sns
import plotly.graph_objects as go
import statsmodels.formula.api as smf
from plotly.subplots import make_subplots
import plotly.express as px
from matplotlib import pyplot as plt

from src.scripts import smiles
from src.scripts import targets
from src.scripts.save_plotly import save_plotly


from src.utils.utils import group_categories
from src.utils.utils import count_classified_rows
from src.utils.utils import collinear_cols

from src.scripts.data_description import (
    plot_distributions_plotly,
    plot_overlaps_plotly,
    categorical_countplot_plotly,
    plot_availability_plotly,
)

import warnings

warnings.filterwarnings("ignore")

In [None]:
# Columns to read from the BindingDB TSV export. The list is passed as `usecols`
# to pd.read_csv and is also scanned by substring to build the column groups
# (target / ligand / ID / name / metadata) used by the descriptive plots.
usecols = [
    "BindingDB Reactant_set_id",
    "Ligand SMILES",
    "Ligand InChI",
    "Ligand InChI Key",
    "BindingDB MonomerID",
    "BindingDB Ligand Name",
    "Target Name",
    "Target Source Organism According to Curator or DataSource",
    "Ki (nM)",
    "IC50 (nM)",
    "Kd (nM)",
    "EC50 (nM)",
    "kon (M-1-s-1)",
    "koff (s-1)",
    "pH",
    "Temp (C)",
    "Curation/DataSource",
    "Article DOI",
    "BindingDB Entry DOI",
    "PMID",
    "PubChem AID",
    "Patent Number",
    "Authors",
    "Institution",
    "Ligand HET ID in PDB",
    "PDB ID(s) for Ligand-Target Complex",
    "PubChem CID",
    "PubChem SID",
    "ChEBI ID of Ligand",
    "ChEMBL ID of Ligand",
    "DrugBank ID of Ligand",
    "IUPHAR_GRAC ID of Ligand",
    "KEGG ID of Ligand",
    "ZINC ID of Ligand",
    "Number of Protein Chains in Target (>1 implies a multichain complex)",
    "BindingDB Target Chain Sequence",
    "PDB ID(s) of Target Chain",
    "UniProt (SwissProt) Recommended Name of Target Chain",
    "UniProt (SwissProt) Entry Name of Target Chain",
    "UniProt (SwissProt) Primary ID of Target Chain",
    "UniProt (TrEMBL) Primary ID of Target Chain",
]

In [None]:
# Load only the selected columns of the BindingDB dump.
# The path is relative to the notebook (same location the ligand analysis
# cell reads from) so the notebook is not tied to one user's home directory.
df = pd.read_csv(
    "../data/BindingDB_All.tsv",
    sep="\t",
    usecols=usecols,
)

In [None]:
# Column groups used by the descriptive plots.
# Measured binding / kinetic parameters plus the assay conditions.
binding_kinetics = [
    "Ki (nM)",
    "IC50 (nM)",
    "Kd (nM)",
    "EC50 (nM)",
    "kon (M-1-s-1)",
    "koff (s-1)",
    "pH",
    "Temp (C)",
]


def _columns_containing(*needles):
    """Return the entries of `usecols` (in order) containing any of `needles`."""
    return [name for name in usecols if any(needle in name for needle in needles)]


target_related = _columns_containing("Target")
ligand_related = _columns_containing("Ligand")
# The InChI key is appended by hand because its name has no "ID"/"id" substring.
id_columns = _columns_containing("ID", "id") + ["Ligand InChI Key"]
names = _columns_containing("Name")

# Everything that did not land in one of the groups above counts as metadata.
_already_grouped = (
    set(target_related)
    | set(ligand_related)
    | set(id_columns)
    | set(names)
    | set(binding_kinetics)
)
metadata = [name for name in usecols if name not in _already_grouped]

## Binding Kinetics

In [None]:
# Overlap figure for the binding-kinetics columns, built on the frame as loaded
# (the numeric cleaning of these columns only happens in a later cell).
fig_overlaps_bk = plot_overlaps_plotly(df=df, group=binding_kinetics, annot=True)
fig_overlaps_bk

In [None]:
# NOTE(review): this cell repeats the previous cell verbatim and rebuilds the
# same figure under the same name; it looks like a leftover copy — confirm
# whether it can be removed.
fig_overlaps_bk = plot_overlaps_plotly(df=df, group=binding_kinetics, annot=True)
fig_overlaps_bk

In [None]:
# Turn the binding-kinetics columns into floats:
#   * the placeholder " NV," becomes NaN (whole-frame replace),
#   * the " C" suffix and the ">" / "<" qualifiers are stripped from the text,
#   * what remains is parsed as float ("nan" strings parse back to NaN).
df.replace(" NV,", np.nan, inplace=True)
for kinetics_col in binding_kinetics:
    as_text = df[kinetics_col].astype(str)
    for token in (" C", ">", "<"):
        as_text = as_text.str.replace(token, "")
    df[kinetics_col] = as_text.astype(float)

In [None]:
# 3x3 grid holding one distribution panel per binding-kinetics column.
rows, cols = 3, 3
fig_distributions = make_subplots(rows=rows, cols=cols, subplot_titles=binding_kinetics)

# pH and temperature are shown on their natural scale; every other column is
# log-transformed by the plotting helper.
linear_scale_columns = ("pH", "Temp (C)")

for position, column in enumerate(binding_kinetics):
    # Fill the grid row by row (plotly rows/cols are 1-based).
    grid_row, grid_col = divmod(position, cols)
    grid_row += 1
    grid_col += 1

    panel = plot_distributions_plotly(
        df,
        column,
        log_transform=column not in linear_scale_columns,
        show_legend=(position == 0),
    )
    for trace in panel.data:
        fig_distributions.add_trace(trace, row=grid_row, col=grid_col)
    # Carry the axis settings of the stand-alone figure over to the grid cell.
    fig_distributions.update_xaxes(panel.layout.xaxis, row=grid_row, col=grid_col)
    fig_distributions.update_yaxes(panel.layout.yaxis, row=grid_row, col=grid_col)

fig_distributions.update_layout(
    # height=1000, width=1000, No figsize for better website results
    title_text="Binding Kinetics Parameters Distributions",
    title_x=0.5,
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    font=dict(color="white"),
    showlegend=False,
)

fig_distributions.show()

## Target and Ligand Related Data

In [None]:
# Availability figures for the ligand and target column groups, stacked in one
# two-row figure (ligand on top, target below).
# De-duplicated union of both groups; not referenced again within this cell.
ligands_and_targets = list(set(ligand_related + target_related))

combined_fig_availability_ligand_target = make_subplots(
    rows=2, cols=1, subplot_titles=("Ligand Related", "Target Related")
)

# Build each stand-alone figure, then copy its traces into the matching row.
fig_ligand = plot_availability_plotly(df=df, group=ligand_related, step=100)
for trace in fig_ligand.data:
    combined_fig_availability_ligand_target.add_trace(trace, row=1, col=1)
fig_target = plot_availability_plotly(df=df, group=target_related, step=100)
for trace in fig_target.data:
    combined_fig_availability_ligand_target.add_trace(trace, row=2, col=1)

# No row/col given, so this applies to the y-axes of both subplots.
combined_fig_availability_ligand_target.update_yaxes(
    autorange="reversed",
    showgrid=False,
    tickfont=dict(color="white"),
    showticklabels=False,
)

# Dark theme matching the other figures of the notebook.
combined_fig_availability_ligand_target.update_layout(
    title="Availability Matrix",
    xaxis_title="",
    yaxis_title="Observations",
    margin=dict(l=60, r=20, t=60, b=100),
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    # width=800,
    # height=800,
    title_x=0.5,
    font=dict(color="white"),
)

combined_fig_availability_ligand_target.show()

In [None]:
# Overlap figures for the ligand and target column groups, stacked in one
# two-row figure (ligand on top, target below).
combined_fig_o = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("Ligand Related", "Target Related"),
    horizontal_spacing=0.3,
)

# Build each stand-alone overlap figure and copy its traces into its row.
fig_ligand_o = plot_overlaps_plotly(df, ligand_related)
for trace in fig_ligand_o.data:
    combined_fig_o.add_trace(trace, row=1, col=1)
fig_target_o = plot_overlaps_plotly(df, target_related)
for trace in fig_target_o.data:
    combined_fig_o.add_trace(trace, row=2, col=1)


# Dark theme matching the other figures of the notebook.
combined_fig_o.update_layout(
    title_text="Proportion of Overlap within Data directly Relating to Ligands and Targets",
    xaxis_title="",
    yaxis_title="Observations",
    # margin=dict(l=60, r=20, t=60, b=100),
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    title_x=0.5,
    # height=800,
    # width= 2200,
    font=dict(color="white"),
)

combined_fig_o.show()

## Target Organism

In [None]:
# cleaning target organism data
# Canonical organism name -> list of alias substrings handed to group_categories.
# An empty list means only the key itself is used (check_key_for_in_mapping=True).
# NOTE(review): "cervisiae" looks like a deliberate catch for a common
# misspelling of "cerevisiae" — confirm it is not a typo in the mapping.
in_mapping = {
    "Human immunodeficiency virus": ["immunodeficiency virus", "hiv"],
    "Sars coronavirus": ["severe acute respiratory", "sars"],
    "Human herpes virus": ["herpes"],
    "Homo sapiens": ["h. sapiens"],
    "Mus musculus": ["mouse", "m. musculus"],
    "Rattus norvegicus": ["rattus", "r. norvegicus"],
    "Cavia porcellus": ["cavia"],
    "Hepatitis C": [],
    "Escherichia coli": ["coli"],
    "Caenorhabditis elegans": ["elegans"],
    "Influenza virus": ["influenza"],
    "Oryctolagus cuniculus": ["cuniculus"],
    "Streptococcus pyogenes": ["pyogenes"],
    "Plasmodium falciparum": [],
    "Saccharomyces cerevisiae": ["cervisiae"],
    "Streptococcus pneumoniae": [],
    "Mycobacterium tuberculosis": [],
}

# Values are stringified first, so missing entries reach the helper as "nan".
f = lambda x: group_categories(
    str(x),
    in_mapping=in_mapping,
    check_key_for_in_mapping=True,
)

# Group, normalise the casing, then turn the capitalised "Nan" placeholder
# back into a real missing value (presumably "nan" passes through the helper
# unchanged — verify against group_categories).
df["Target Source Organism According to Curator or DataSource"] = (
    df["Target Source Organism According to Curator or DataSource"]
    .apply(f)
    .apply(str.capitalize)
    .replace("Nan", np.nan)
)

In [None]:
# NOTE(review): the lambda and the column clean-up below repeat the previous
# cell, so the organism column is passed through group_categories a second
# time. Whether that is a no-op depends on group_categories being idempotent
# on already-cleaned values — confirm before removing the repetition.
f = lambda x: group_categories(
    str(x),
    in_mapping=in_mapping,
    check_key_for_in_mapping=True,
)

df["Target Source Organism According to Curator or DataSource"] = (
    df["Target Source Organism According to Curator or DataSource"]
    .apply(f)
    .apply(str.capitalize)
    .replace("Nan", np.nan)
)

# Count plot of the cleaned organism column (N=50, percentile=0.8, log x-scale).
fig_target_source_org_distrib = categorical_countplot_plotly(
    df,
    "Target Source Organism According to Curator or DataSource",
    N=50,
    percentile=0.8,
    x_scale="log",
)
fig_target_source_org_distrib.show()

## Metadata

In [None]:
# Side-by-side metadata figure: availability on the left (col 1), overlap on
# the right (col 2).
fig_combined_metadata = make_subplots(rows=1, cols=2, horizontal_spacing=0.2)

fig_metadata_overlap = plot_overlaps_plotly(df, metadata)
for trace in fig_metadata_overlap.data:
    fig_combined_metadata.add_trace(trace, row=1, col=2)
fig_metadata_availability = plot_availability_plotly(df, metadata, 100)
for trace in fig_metadata_availability.data:
    fig_combined_metadata.add_trace(trace, row=1, col=1)

# Traces do not carry layout annotations, so they are copied over separately.
# The annotation objects of the stand-alone figure are re-anchored in place
# before being added to the combined figure.
availability_annotations = fig_metadata_availability.layout.annotations
for annotation in availability_annotations:
    annotation["xref"] = "x1"  # Align to subplot in col=1
    annotation["yref"] = "paper"  # Use paper coordinates for global alignment
    fig_combined_metadata.add_annotation(annotation)

# Only the availability panel (row 1, col 1) gets the reversed, tick-less y-axis.
fig_combined_metadata.update_yaxes(
    autorange="reversed",
    showgrid=False,
    tickfont=dict(color="white"),
    showticklabels=False,
    row=1,
    col=1,
)

# Dark theme matching the other figures of the notebook.
fig_combined_metadata.update_layout(
    title_text="Availability and Overlap of Metadata",
    xaxis_title="",
    yaxis_title="Observations",
    # margin=dict(l=60, r=20, t=60, b=100),
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    title_x=0.5,
    # height=600,
    # width= 1200,
    font=dict(color="white"),
)

fig_combined_metadata.show()

In [None]:
# 2x2 figure of count plots: articles in the left column, patents in the right
# column; top 10 in the first row, top 100 in the second row.
fig_combined_distrib_patent_doi = make_subplots(
    rows=2,
    cols=2,
    shared_xaxes=True,
    subplot_titles=[
        "Top 10 Articles by Compounds",
        "Top 10 Patents by Compounds",
        "Top 100 Articles by Compounds",
        "Top 100 Patents by Compounds",
    ],
    # horizontal_spacing=0.2,
    # vertical_spacing=0.3,
)

# Top N values for each plot
N1 = 100
N2 = 10

# Each stand-alone count plot is built first and its traces are then copied
# into the matching cell of the grid.
fig_article_10 = categorical_countplot_plotly(
    df,
    "Article DOI",
    N=N2,
    percentile=None,
    x_scale="log",
)
for trace in fig_article_10.data:
    fig_combined_distrib_patent_doi.add_trace(trace, row=1, col=1)

fig_article_100 = categorical_countplot_plotly(
    df,
    "Article DOI",
    N=N1,
    percentile=None,
    x_scale="log",
)
for trace in fig_article_100.data:
    fig_combined_distrib_patent_doi.add_trace(trace, row=2, col=1)

fig_patent_10 = categorical_countplot_plotly(
    df,
    "Patent Number",
    N=N2,
    percentile=None,
    x_scale="log",
)
for trace in fig_patent_10.data:
    fig_combined_distrib_patent_doi.add_trace(trace, row=1, col=2)

fig_patent_100 = categorical_countplot_plotly(
    df,
    "Patent Number",
    N=N1,
    percentile=None,
    x_scale="log",
)
for trace in fig_patent_100.data:
    fig_combined_distrib_patent_doi.add_trace(trace, row=2, col=2)


fig_combined_distrib_patent_doi.update_layout(
    title="Distributions of Compounds Studied per Article and Covered per Patent (Top 10 and 100)",
    # height=800,
    # width=1200,
    showlegend=True,  # Legend stays enabled for the combined figure
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    font=dict(color="white"),
    title_x=0.5,
)

# Bottom row: hide the tick labels and put the category name on the axis.
fig_combined_distrib_patent_doi.update_yaxes(
    showticklabels=False, title_text="Article DOI", row=2, col=1
)
fig_combined_distrib_patent_doi.update_yaxes(
    showticklabels=False, title_text="Patent Number ", row=2, col=2
)
# No row/col given: the y-axis of every panel is reversed.
fig_combined_distrib_patent_doi.update_yaxes(autorange="reversed")

fig_combined_distrib_patent_doi.show()

In [None]:
def add_to_categorical_plot(df, column, N, row, col, fig_combined):
    """Build a count plot for `column` and copy its traces into a subplot cell.

    The count plot is created with `categorical_countplot_plotly` (top `N`
    categories, no percentile cut, log x-scale); every trace of that figure is
    then added to `fig_combined` at position (`row`, `col`). `fig_combined` is
    modified in place and nothing is returned.
    """
    source_fig = categorical_countplot_plotly(
        df, column, N=N, percentile=None, x_scale="log"
    )
    cell = dict(row=row, col=col)
    for source_trace in source_fig.data:
        fig_combined.add_trace(source_trace, **cell)


# 2x2 figure of count plots: articles in the left column, patents in the right
# column; top 10 in the first row, top 100 in the second row.
fig_combined_distrib_patent_doi = make_subplots(
    rows=2,
    cols=2,
    shared_xaxes=True,
    subplot_titles=[
        "Top 10 Articles by Compounds",
        "Top 10 Patents by Compounds",
        "Top 100 Articles by Compounds",
        "Top 100 Patents by Compounds",
    ],
)

N_values = [(10, 1, 1), (100, 2, 1), (10, 1, 2), (100, 2, 2)]  # N, row, col
columns = ["Article DOI", "Article DOI", "Patent Number", "Patent Number"]


for (N, row, col), column in zip(N_values, columns):
    add_to_categorical_plot(df, column, N, row, col, fig_combined_distrib_patent_doi)

fig_combined_distrib_patent_doi.update_layout(
    title="Distributions of Compounds Studied per Article and Covered per Patent (Top 10 and 100)",
    showlegend=True,
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    font=dict(color="white"),
    title_x=0.5,
)

# Address the axes by row/col instead of by layout key: make_subplots numbers
# the axes row by row, so `yaxis2` is row 1 / col 2 (Top-10 patents), not
# row 2 / col 1. Using row/col puts each title on the intended bottom-row panel.
fig_combined_distrib_patent_doi.update_yaxes(
    showticklabels=False, title_text="Article DOI", row=2, col=1
)
fig_combined_distrib_patent_doi.update_yaxes(
    showticklabels=False, title_text="Patent Number", row=2, col=2
)
# Reverse every panel, as the long-hand version of this figure does.
fig_combined_distrib_patent_doi.update_yaxes(autorange="reversed")

fig_combined_distrib_patent_doi.show()

In [None]:
# Canonical institution name -> alias substrings handed to group_categories.
# An empty list means only the key itself is used (check_key_for_in_mapping=True).
# Note that `in_mapping` and `f` reuse (and overwrite) the names of the
# organism clean-up earlier in the notebook.
# NOTE(review): "Dohme" is the only capitalised alias; if group_categories
# matches on lower-cased text it would never hit — verify against the helper.
in_mapping = {
    "Pfizer": [],
    "MSD": ["Dohme"],
    "Bristol-Myers Squibb": [],
    "Amgen": [],
    "Novartis": [],
    "Janssen": [],
    "Eli Lilly": ["lilly"],
    "Roche": [],
    "Incyte": [],
    "Gilead": [],
    "Bayer": [],
    "Abbott": [],
    "Scripps Research Institute": ["scripps"],
    "The Burnham Institute": ["burnham"],
    "Genentech": [],
    "GlaxoSmithKline": ["gsk"],
    "Astrazeneca": [],
    "Abbvie": [],
    "Merck": [],
    "Boehring": [],
}

# Values are stringified first, so missing entries reach the helper as "nan".
f = lambda x: group_categories(
    str(x),
    in_mapping=in_mapping,
    check_key_for_in_mapping=True,
)

# Group the institutions, then map the "TBA" and "nan" placeholders to NaN.
df["Institution"] = (
    df["Institution"].apply(f).replace("TBA", np.nan).replace("nan", np.nan)
)  # tba = to be attributed

In [None]:
# Count plot of the 20 most frequent (cleaned) institutions.
fig_institution = categorical_countplot_plotly(df, "Institution", N=20)


# Centre the title above the plot.
fig_institution.update_layout(
    title={
        "text": "Distribution of Top 20 Institutions",
        "x": 0.5,
        "xanchor": "center",
        "yanchor": "top",
    },
)

fig_institution.show()

## Molecular and Chemical Features 

In [None]:
# Fixed-seed random sample of 10,000 ligands that have an ID, a SMILES string
# and a name; used for the molecular-property plots.
# The path is relative to the notebook (same location the ligand analysis
# cell reads from) so the notebook is not tied to one user's home directory.
mol_df = (
    pd.read_csv(
        "../data/BindingDB_All.tsv",
        sep="\t",
        usecols=[
            "BindingDB Reactant_set_id",
            "Ligand SMILES",
            "BindingDB Ligand Name",
        ],
    )
    .dropna()
    .sample(10000, random_state=0)
)
mol_df.head()

In [None]:
# Derive four molecular properties from each SMILES string via the project's
# `smiles` helpers, one new column per property.
mol_df["H-Bond Donors"] = mol_df["Ligand SMILES"].apply(smiles.get_Hdonors)
mol_df["H-Bond Acceptors"] = mol_df["Ligand SMILES"].apply(smiles.get_Hacceptors)
mol_df["Molecular Weight"] = mol_df["Ligand SMILES"].apply(smiles.get_MW)
mol_df["C LogP"] = mol_df["Ligand SMILES"].apply(smiles.get_LogP)

# Scatter-matrix of the four properties (the labels map each column to itself).
fig_pairplot = px.scatter_matrix(
    mol_df,
    dimensions=["H-Bond Donors", "H-Bond Acceptors", "Molecular Weight", "C LogP"],
    title="Pairwise Relationships of Molecular Properties",
    labels={
        "H-Bond Donors": "H-Bond Donors",
        "H-Bond Acceptors": "H-Bond Acceptors",
        "Molecular Weight": "Molecular Weight",
        "C LogP": "C LogP",
    },
)

# Dark theme matching the other figures of the notebook.
fig_pairplot.update_layout(
    # width=1000,
    # height=1000,
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    font=dict(color="white"),
    title_x=0.5,
)

fig_pairplot.show()

## Lipinski Rules

In [None]:
# Flag each ligand with the result of smiles.lipinski, first as 0/1 and then
# as a "No"/"Yes" label. Results that are neither True nor False map to NaN
# and are left out of the percentages below.
mol_df["Lipinski"] = mol_df["Ligand SMILES"].apply(smiles.lipinski)
mol_df["Lipinski"] = mol_df["Lipinski"].map({True: 1, False: 0})
mol_df["Lipinski Label"] = mol_df["Lipinski"].map({0: "No", 1: "Yes"})

# Percentage per label. Re-indexing on both labels guarantees a row for "No"
# and for "Yes" even when one of them does not occur in the sample (the
# previous `.values[0]` lookup raised IndexError in that case).
lipinski_percent = (
    mol_df["Lipinski Label"]
    .value_counts(normalize=True)
    .mul(100)
    .reindex(["No", "Yes"], fill_value=0.0)
)
lipinski_counts = pd.DataFrame(
    {"Lipinski": lipinski_percent.index, "Percent": lipinski_percent.to_numpy()}
)

print(lipinski_counts)

# One bar per label, each with its own colour and a percentage annotation.
bar_colors = {"No": "#ff7f7e", "Yes": "#7fc080"}
fig = go.Figure()
for label, percent in zip(lipinski_counts["Lipinski"], lipinski_counts["Percent"]):
    fig.add_trace(
        go.Bar(
            x=[label],
            y=[percent],
            name=label,
            marker_color=bar_colors[label],
            texttemplate="%{y:.1f}%",
            textposition="outside",
        )
    )

# Update layout
fig.update_layout(
    title="Does the drug respect Lipinski's Rule?",
    title_x=0.5,
    xaxis_title="",
    yaxis_title="%",
    plot_bgcolor="rgb(34, 37, 41)",
    paper_bgcolor="rgb(34, 37, 41)",
    font=dict(color="white"),
    showlegend=False,
    # height=600,
    # width=400
)

# Show the plot
fig.show()

## Molecular Features

In [None]:
from src.scripts import smiles
from src.scripts import targets
import src.utils

import statsmodels.api as sm
from statsmodels.discrete.discrete_model import MNLogit
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA


from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs, Descriptors, Draw

import json
from tqdm import tqdm

tqdm.pandas()
from src.utils.utils import collinear_cols
from src.scripts.data_description import (
    plot_distributions_plotly,
    plot_overlaps_plotly,
    categorical_countplot_plotly,
    plot_availability_plotly,
)

In [None]:
# Reload only the columns needed for the descriptor / IC50 modelling.
# No index column is set, so the frame keeps the default integer index.
ligand_df = pd.read_csv(
    "../data/BindingDB_All.tsv",
    sep="\t",
    usecols=[
        "Ligand SMILES",
        "IC50 (nM)",
        "UniProt (SwissProt) Primary ID of Target Chain",
        "Article DOI",
    ],
)

In [None]:
# Molecular descriptors for every ligand measured against one target.
# presumably "P07949" is the UniProt accession of the target under study — confirm
TARGET_UNIPROT_ID = "P07949"
target_id_col = "UniProt (SwissProt) Primary ID of Target Chain"

# Build the row filter on ligand_df itself and select with .loc, so the mask
# and the frame share one index. (Previously the mask was computed on the
# full frame and applied to the dropna() result, relying on pandas silently
# re-aligning the boolean key.) Rows need a SMILES string and the target ID.
has_smiles_for_target = (
    ligand_df["Ligand SMILES"].notna()
    & ligand_df[target_id_col].eq(TARGET_UNIPROT_ID)
)
descriptors = ligand_df.loc[has_smiles_for_target, "Ligand SMILES"].progress_apply(
    smiles.get_MolDescriptors
)

# Keep only the results that are dicts; anything else the helper returned is
# discarded. The original row labels are preserved as the index.
is_descriptor_dict = descriptors.apply(lambda x: isinstance(x, dict))
descriptors_df = pd.DataFrame(
    descriptors[is_descriptor_dict].to_list(),
    index=descriptors.index[is_descriptor_dict],
)

In [None]:
# Standardise the descriptors (rows with any missing descriptor are dropped).
X = (
    StandardScaler()
    .set_output(transform="pandas")
    .fit_transform(descriptors_df.dropna())
)
# X carries ligand_df's index labels, so the matching rows are selected by label.
y = smiles.get_IC50(ligand_df.loc[X.index]).dropna()

# Nine log-spaced IC50 classes covering (0, max]. np.logspace works in base 10,
# so the exponent has to be a base-10 logarithm; the natural log used before
# placed the top edge at 10**ln(max), far beyond the data. The "+ 1" inside the
# log compensates for the "- 1" shift that makes the first edge 0.
ic50_bin_edges = np.logspace(0, np.log10(y.max() + 1), num=10) - 1
ic50_bin_edges[-1] = y.max()  # guard against round-off excluding the maximum
y = pd.cut(y, bins=ic50_bin_edges, labels=np.arange(9)).dropna()
X = X.loc[y.index]

In [None]:
# Train the random-forest classifier on the binned IC50 classes, after
# dropping the columns reported by collinear_cols at threshold 0.95.
rf = RandomForestClassifier(n_estimators=1000, max_depth=5, random_state=42)
to_drop = collinear_cols(X, threshold=0.95)
rf.fit(X.drop(to_drop, axis=1), y)

In [None]:
# One row per tree of the forest, one column per feature the forest was fit on:
# the per-tree feature importances.
mdi_df = pd.DataFrame(
    data=[tree.feature_importances_ for tree in rf.estimators_],
    columns=X.drop(to_drop, axis=1).columns,
)


def bootstrap_ci(data, n_bootstrap=1000, ci=95, random_state=None):
    """Bootstrap a percentile confidence interval for the mean of `data`.

    Parameters
    ----------
    data : 1-D sequence of numbers (list, numpy array or pandas Series).
    n_bootstrap : number of resamples drawn with replacement.
    ci : confidence level in percent (95 -> 2.5th and 97.5th percentiles).
    random_state : optional seed or numpy Generator for reproducible
        resampling. When None, numpy's global random state is used, exactly
        as before this parameter existed.

    Returns
    -------
    (mean, (mean - lower, upper - mean)) : the sample mean and the distances
        from it to the lower and upper percentile bounds.
    """
    values = np.asarray(data, dtype=float)
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    # Mean of each resample (same size as the data, drawn with replacement).
    boot_means = np.array(
        [
            rng.choice(values, size=values.size, replace=True).mean()
            for _ in range(n_bootstrap)
        ]
    )

    tail = (100 - ci) / 2
    lower = np.percentile(boot_means, tail)
    upper = np.percentile(boot_means, 100 - tail)
    center = values.mean()
    return (center, (center - lower, upper - center))


# Calculate mean and confidence intervals for each feature
mean_importances = mdi_df.mean()
ci_importances = [bootstrap_ci(mdi_df[col]) for col in mdi_df.columns]

# bootstrap_ci returns (mean, (distance to lower bound, distance to upper bound)).
# Both distances are plotted so the bars show the actual interval; previously
# only the upper distance was used and mirrored below the bar.
error_array = np.array(
    [ci[1][1] for ci in ci_importances]
)  # Upper confidence interval range
error_array_minus = np.array(
    [ci[1][0] for ci in ci_importances]
)  # Lower confidence interval range

fig = go.Figure(
    data=[
        go.Bar(
            x=mean_importances.index,
            y=mean_importances,
            error_y=dict(
                type="data",
                symmetric=False,
                array=error_array,
                arrayminus=error_array_minus,
                visible=True,
                thickness=1.5,
                width=5,
            ),
        )
    ]
)

fig.update_layout(
    title="Feature Importances with Bootstrapped Confidence Intervals",
    title_x=0.5,
    xaxis=dict(title="Features", tickangle=90, tickfont=dict(size=10)),
    yaxis_title="MDI (Mean Decrease in Impurity)",
    bargap=0.1,
    template="plotly_dark",
    height=600,
    margin=dict(l=50, r=20, t=50, b=120),
)

fig.show()

In [None]:
# Standardise the descriptors (rows with any missing descriptor are dropped).
X = (
    StandardScaler()
    .set_output(transform="pandas")
    .fit_transform(descriptors_df.dropna())
)
# X carries ligand_df's index labels, so the matching rows are selected by label.
y = smiles.get_IC50(ligand_df.loc[X.index]).dropna()
y = y[y < 2e5]
# Four log-spaced classes with edges 10**0 - 1 ... 10**5 - 1; values outside
# the edges become NaN in pd.cut and are dropped.
y = pd.cut(y, bins=np.logspace(0, 5, num=5, base=10) - 1, labels=np.arange(4)).dropna()
X = X.loc[y.index]

# Fit a logistic-regression classifier on the binned IC50 classes, after
# dropping the columns reported by collinear_cols at threshold 0.95.
lr = LogisticRegression(max_iter=1000, random_state=42)
to_drop = collinear_cols(X, threshold=0.95)
lr.fit(X.drop(to_drop, axis=1), y)

In [None]:
# Standardise the descriptors (rows with any missing descriptor are dropped).
X = (
    StandardScaler()
    .set_output(transform="pandas")
    .fit_transform(descriptors_df.dropna())
)
# X carries ligand_df's index labels, so the matching rows are selected by label.
y = smiles.get_IC50(ligand_df.loc[X.index]).dropna()
y = y[y < 2e5]
# Four log-spaced classes with edges 10**0 - 1 ... 10**5 - 1; values outside
# the edges become NaN in pd.cut and are dropped.
y = pd.cut(y, bins=np.logspace(0, 5, num=5, base=10) - 1, labels=np.arange(4)).dropna()
X = X.loc[y.index]

# Multinomial logit (statsmodels) on the descriptors that survive the 0.9
# collinearity filter; two further descriptors are excluded by name.
# The unused `lr = LogisticRegression(...)` assignment that used to sit here
# was removed: it replaced the fitted `lr` of the previous cell with an
# unfitted estimator.
to_drop = collinear_cols(X, threshold=0.9)
log_reg = MNLogit(y, X.drop(to_drop + ["NumRadicalElectrons", "Ipc"], axis=1)).fit(
    maxiter=100
)

print(log_reg.summary())

In [None]:
# p-values and coefficients of the fitted MNLogit model.
p_values = log_reg.pvalues
weights = log_reg.params


# Conventional star notation for p-values
def significance_stars(p):
    """Return "***", "**" or "*" for p <= 0.001, 0.01 or 0.05, else ""."""
    # Cutoffs are checked from the strictest to the loosest; the first hit wins.
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p <= cutoff:
            return stars
    return ""


# Star annotation for every p-value of the model
significance = p_values.applymap(significance_stars)


def _to_long(frame, value_name):
    """Melt a predictor-by-class table into long form (index, Class, value)."""
    return frame.reset_index().melt(
        id_vars="index", var_name="Class", value_name=value_name
    )


# Long-format versions of the three tables, one row per (predictor, class)
weights_melt = _to_long(weights, "Weight")
p_values_melt = _to_long(p_values, "P-Value")
significance_melt = _to_long(significance, "Significance")

# Join them on predictor and class for plotting
_merge_keys = ["index", "Class"]
combined = weights_melt.merge(p_values_melt, on=_merge_keys).merge(
    significance_melt, on=_merge_keys
)

In [None]:
# Custom legend mapping
# Keys are the column labels of the coefficient table (`Class` after melting).
# NOTE(review): the target has four classes (labels 0-3) while the coefficient
# table is expected to hold one column per non-reference class. If those
# columns are numbered from 0, column 0 would describe class 1 relative to
# class 0 and these range labels would be shifted by one — confirm against
# the printed model summary.
legend_mapping = {
    0: "low IC50 [0,17[",
    1: "medium IC50 [17,315[",
    2: "high IC50 [315,5600[",
}


# Same merge as at the end of the previous cell, rebuilt here from the melted tables.
combined = weights_melt.merge(p_values_melt, on=["index", "Class"]).merge(
    significance_melt, on=["index", "Class"]
)

fig = go.Figure()

# Iterate over the classes to create bars for each class with new legend names
for cls in combined["Class"].unique():
    df_class = combined[combined["Class"] == cls]
    fig.add_trace(
        go.Bar(
            x=df_class["index"],
            y=df_class["Weight"],
            name=legend_mapping[int(cls)],  # Use custom legend names
            text=df_class["Significance"],  # Significance annotations
            textposition="outside",
            marker=dict(line=dict(width=0.5, color="black")),
        )
    )

# Add a horizontal line at y=0
fig.add_hline(y=0, line_dash="dash", line_color="black", line_width=1)

# Update layout
fig.update_layout(
    title="Predictor Weights with Significance Levels (Multinomial Logistic Regression)",
    xaxis_title="Predictor",
    yaxis_title="Weight (Coefficient)",
    barmode="group",
    legend_title="IC50 Range",
    template="plotly_dark",
)

# Customize x-axis tick labels
fig.update_xaxes(tickangle=45, tickmode="array", tickvals=combined["index"].unique())

# Show the figure
fig.show()

### Cheng Prusoff

In [None]:
from src.scripts import cheng_prusoff_classification as cp

# Columns needed for the Cheng-Prusoff analysis. They get their own name so
# the notebook-wide `usecols` list (used to derive the column groups) is not
# overwritten by this cell.
cp_usecols = [
    "Ki (nM)",
    "IC50 (nM)",
]
# load data (path relative to the notebook, as in the ligand analysis cell)
df_cp = pd.read_csv(
    "../data/BindingDB_All.tsv",
    sep="\t",
    usecols=cp_usecols,
)

# clean data: " NV," placeholder -> NaN, strip " C" / ">" / "<", parse as float
df_cp.replace(" NV,", np.nan, inplace=True)
for col in df_cp.columns:
    df_cp[col] = df_cp[col].astype(str).str.replace(" C", "")
    df_cp[col] = (
        df_cp[col].astype(str).str.replace(">", "").str.replace("<", "").astype(float)
    )

# prepare data for cheng-prusoff
both_log = cp.cheng_prusoff_data(df_cp)

b_max = cp.cheng_prusoff_classifier(
    df=both_log, min_seperator_interecpt=0, show_evaluation=True
)

# assign labels with optimal seperator intercept; cluster 1 = below line
alpha = 0.71  # as found above
b = b_max[alpha]
classified_action = both_log.copy().drop(columns=["label"])
# 1 where log_IC50 lies below log_Ki + b, 0 otherwise; then renamed to "A" / "B"
classified_action["cluster"] = (both_log["log_IC50"] < both_log["log_Ki"] + b).astype(
    int
)
classified_action["cluster"] = classified_action["cluster"].replace({1: "A", 0: "B"})
fam0 = classified_action.query("cluster == 'A'")
fam1 = classified_action.query("cluster == 'B'")

# One model per cluster
res0 = cp.cheng_prusoff_model(fam0)
res1 = cp.cheng_prusoff_model(fam1)

In [None]:
# Scatter of log Ki against log IC50 for both clusters, each with its fitted
# regression line, the equation as an annotation and KDE density contours.
colors = ["lightcoral", "lightblue"]
scatter_traces = []
clusters = [fam0, fam1]
models = [res0, res1]

annotations = []

for idx, (cluster, model) in enumerate(zip(clusters, models)):
    # Scatter points
    scatter_traces.append(
        go.Scatter(
            x=cluster["log_IC50"],
            y=cluster["log_Ki"],
            mode="markers",
            name=f"Cluster {chr(65+idx)}",
            marker=dict(color=colors[idx], size=5),
        )
    )

    # Regression line over the observed log IC50 range of the cluster
    x_vals = np.linspace(cluster["log_IC50"].min(), cluster["log_IC50"].max(), num=50)
    intercept, slope = model.params
    y_vals = slope * x_vals + intercept

    scatter_traces.append(
        go.Scatter(
            x=x_vals,
            y=y_vals,
            mode="lines",
            name=f"Regression {chr(65+idx)}",
            line=dict(color=colors[idx], dash="dash"),
        )
    )

    # Add regression equation as annotation
    equation = f"y = {slope:.2f}x + {intercept:.2f}"
    annotations.append(
        dict(
            x=cluster["log_IC50"].max(),
            y=cluster["log_Ki"].max(),
            text=f"Cluster {chr(65+idx)}: {equation}",
            showarrow=False,
            font=dict(size=12, color=colors[idx]),
        )
    )

    # Density contours. `kde` was referenced here without ever being defined
    # (NameError on a fresh kernel). The contours are now computed per cluster
    # on a throw-away matplotlib axis and their vertices copied into plotly.
    kde_fig, kde_ax = plt.subplots()
    kde = sns.kdeplot(
        x=cluster["log_IC50"],
        y=cluster["log_Ki"],
        levels=5,
        color=colors[idx],
        linewidths=1,
        ax=kde_ax,
    )
    for line in kde.collections:
        for path in line.get_paths():
            # A path may hold several separate contour loops; to_polygons
            # returns one vertex array per loop so each is drawn on its own.
            for verts in path.to_polygons(closed_only=False):
                if len(verts) == 0:
                    continue
                scatter_traces.append(
                    go.Scatter(
                        x=verts[:, 0],
                        y=verts[:, 1],
                        mode="lines",
                        line=dict(color=colors[idx], width=1),
                        name=f"Density Cluster {chr(65+idx)}",
                        showlegend=False,  # Avoid redundancy
                    )
                )
    # The matplotlib figure only served to compute the contours.
    plt.close(kde_fig)

# Combine scatter plot, regression lines, and density contours
fig = go.Figure(data=scatter_traces)
fig.update_layout(
    title=dict(
        text="Log Ki vs Log IC50 Clusters with Regression Lines", x=0.5
    ),  # Center title
    xaxis_title="Log IC50",
    yaxis_title="Log Ki",
    legend_title="Cluster Details",
    template="plotly_dark",
    annotations=annotations,  # Add the annotations for regression equations
)


# Show plots
fig.show()