In [None]:
%load_ext autoreload
# Reload edited project modules (e.g. src.modules.cluster) before every cell runs.
%autoreload 2

import sys

# Make the repository root importable. The path is relative to the kernel's working
# directory, so this assumes the notebook is launched two levels below the repo root
# (NOTE(review): confirm the expected working directory).
sys.path.append("../..")
from notebooks import notebook_setup

# Project-specific notebook initialisation; what it configures is defined in
# notebooks/notebook_setup (not visible here).
notebook_setup.setup_notebook()

In [None]:
from pathlib import Path

import ipywidgets as widgets
import nglview as nv
import numpy as np
import pandas as pd
import yaml
from IPython.display import clear_output, display
from nglview.shape import Shape
from sklearn.cluster import AgglomerativeClustering, MeanShift
from notebooks.notebook_utils import render_protein
from tqdm.auto import tqdm

from src.modules.cluster import cluster

In [None]:
# Evaluation run to inspect: its Hydra output directory holds the predictions and config.
log_path = Path("logs/eval/runs/2025-10-13_16-06-43-908995")
preds_path = log_path / "predictions.csv"

# The dataset used by the run is read back from the run's own Hydra config.
# NOTE(review): assumes the `dataloader` entry is a plain dataset name (it is
# interpolated into a directory path below) — confirm against the config schema.
with open(log_path / ".hydra/config.yaml", "r", encoding="utf-8") as f:
    test_dataset = yaml.safe_load(f)["dataloader"]

# Per-protein raw inputs (e.g. <protein>/binding.npz) for that dataset.
protein_path = Path(f"data/equipocket/{test_dataset}/raw")

# Protein names are used as directory names, so read them verbatim as strings.
# Without this, pandas' type inference could turn numeric-looking IDs into ints or
# floats (dropping leading zeros) and the paths built from them would not exist.
df = pd.read_csv(preds_path, dtype={"protein_name": str})
protein_names = sorted(df["protein_name"].dropna().astype(str).unique().tolist())

In [None]:
out = widgets.Output()
protein_dropdown = widgets.Combobox(
    options=protein_names,
    description="Protein",
    ensure_option=True,  # only allow values from options
    placeholder="Type to search...",
    # Start on the first protein; with no predictions, show an empty box instead of
    # failing with an IndexError.
    value=protein_names[0] if protein_names else "",
)


def _render_protein(protein_name):
    """Render ``protein_name`` into the shared output area, replacing the previous view."""
    with out:
        clear_output(wait=True)
        view = render_protein(protein_name, df, protein_path)
        display(view)


def _on_dropdown_change(change):
    """Re-render the viewer when the combobox value becomes a known protein name.

    The observer is registered with ``names="value"``, so only value changes arrive
    here. Values that are not in the options, such as an empty string, are ignored
    rather than passed to ``render_protein``.
    """
    new_name = change["new"]
    if new_name in protein_dropdown.options:
        _render_protein(new_name)


protein_dropdown.observe(_on_dropdown_change, names="value")

if protein_dropdown.value:
    _render_protein(protein_dropdown.value)
display(protein_dropdown, out)

In [None]:
# Clusterer passed to `cluster(...)` in the evaluation cell when `cluster_preds` is True,
# to merge nearby predicted points before ranking them by confidence.
# bandwidth=None is MeanShift's default (the bandwidth is estimated from the data).
# Alternative with a fixed merge distance:
#     AgglomerativeClustering(n_clusters=None, distance_threshold=4.0)
cluster_algorithm = MeanShift(bandwidth=None)

In [None]:
# Success-rate counts per protein:
#   DCA — a ligand is found if a ranked prediction lies within `threshold` of any of its atoms;
#   DCC — a binding site is found if a ranked prediction lies within `threshold` of its center.
# Both are evaluated on the top-(num_ligs + i) predictions for i = 0 .. num_global_nodes - 1.
cluster_preds = True  # merge nearby predictions with `cluster_algorithm` before ranking
num_global_nodes = 8  # number of rank offsets i evaluated
threshold = 4  # hit cutoff in coordinate units (presumably Å — TODO confirm)


def _hits_per_rank(query_coords, ranked_pred_coords, base_k, num_offsets, threshold):
    """For each k = base_k + i, report whether any top-k prediction hits the query points.

    Args:
        query_coords: (m, 3) array of points, or a single (3,) point.
        ranked_pred_coords: (p, 3) predicted coordinates ordered by descending confidence.
        base_k: number of top predictions considered at offset i = 0.
        num_offsets: number of offsets i = 0 .. num_offsets - 1 to evaluate.
        threshold: a prediction is a hit if it is within this distance of any query point.

    Returns:
        Bool array of length ``num_offsets``. Entry i is True if any of the top
        ``base_k + i`` predictions is a hit, and False when that set is empty.
    """
    query = np.atleast_2d(query_coords)
    top = ranked_pred_coords[: base_k + num_offsets - 1]
    # (m, k) pairwise distances -> per prediction: is any query point within threshold?
    dist = np.linalg.norm(query[:, None, :] - top[None, :, :], axis=-1)
    is_hit = (dist <= threshold).any(axis=0)
    # Slicing clamps at len(top), so offsets beyond the available predictions reuse all of them.
    return np.array([is_hit[:k].any() for k in range(base_k, base_k + num_offsets)], dtype=bool)


res = []
n_proteins = df["protein_name"].nunique()
# groupby yields each protein's rows once, instead of re-filtering the whole frame per
# protein. Like the widget's name list, it skips rows with a missing protein_name.
for p, df_selected_protein in tqdm(df.groupby("protein_name", sort=False), total=n_proteins):
    # An NpzFile keeps the archive open and re-reads an array on every key access.
    # Read each array once and let the context manager close the file.
    with np.load(protein_path / f"{p}/binding.npz") as binding:
        bindingsite_centers = binding["binding_site_centers"]
        ligand_ids = binding["ligand_ids"]
        ligand_coords = binding["ligand_coords"]

    pred_coords = df_selected_protein[["x", "y", "z"]].values
    confs = df_selected_protein["confidence_0"].values
    if cluster_preds:
        pred_coords, confs = cluster(pred_coords, confs, algorithm=cluster_algorithm)

    # Predictions ordered from most to least confident.
    ranked_coords = np.asarray(pred_coords)[np.argsort(confs)[::-1]]

    lig_ids = np.unique(ligand_ids)
    num_ligs = len(lig_ids)

    # DCA: one count per ligand, using all of that ligand's atom coordinates.
    dca_hits = np.zeros(num_global_nodes, dtype=int)
    for lig_id in lig_ids:
        dca_hits += _hits_per_rank(
            ligand_coords[ligand_ids == lig_id], ranked_coords, num_ligs, num_global_nodes, threshold
        )

    # DCC: one count per binding-site center.
    dcc_hits = np.zeros(num_global_nodes, dtype=int)
    for center in bindingsite_centers:
        dcc_hits += _hits_per_rank(center, ranked_coords, num_ligs, num_global_nodes, threshold)

    res.append(
        {
            "protein_name": p,
            "num_ligs": num_ligs,
            # Number of centers scored for DCC; this is the DCC denominator and may differ from num_ligs.
            "num_sites": len(bindingsite_centers),
            **{f"n_rank_dca_{i}": int(n) for i, n in enumerate(dca_hits)},
            **{f"n_rank_dcc_{i}": int(n) for i, n in enumerate(dcc_hits)},
        }
    )

In [None]:
df_res = pd.DataFrame(res)

# Raw DCA hit counts from the evaluation loop (n_rank_dca_0, n_rank_dca_1, ...).
# Matching the exact prefix guarantees that derived columns are never picked up.
dca_rank_cols = [c for c in df_res.columns if c.startswith("n_rank_dca_")]
for c in dca_rank_cols:
    # Keep full precision: rounding per protein before averaging biases the mean
    # success rate. Proteins with num_ligs == 0 give NaN, which .mean() skips.
    df_res[f"dca_ratio_{c}"] = df_res[c] / df_res["num_ligs"]

dca_ratio_cols = [f"dca_ratio_{c}" for c in dca_rank_cols]
df_rank_dca = df_res[["protein_name", "num_ligs"] + dca_ratio_cols]

In [None]:
# Mean DCA success rate across proteins, one value per rank offset.
df_rank_dca.loc[:, dca_ratio_cols].mean()

In [None]:
# Raw DCC hit counts from the evaluation loop (n_rank_dcc_0, n_rank_dcc_1, ...).
# The exact-prefix match keeps this cell idempotent. A bare `"dcc" in c` test also
# matches the dcc_ratio_* columns this cell adds, so a re-run would create
# dcc_ratio_dcc_ratio_* columns.
dcc_rank_cols = [c for c in df_res.columns if c.startswith("n_rank_dcc_")]

# DCC hits are counted per binding-site center, so normalise by the number of centers
# when the evaluation loop recorded it; otherwise fall back to the ligand count.
dcc_denominator = df_res["num_sites"] if "num_sites" in df_res.columns else df_res["num_ligs"]
for c in dcc_rank_cols:
    # Unrounded, for the same reason as the DCA ratios: round only for display.
    df_res[f"dcc_ratio_{c}"] = df_res[c] / dcc_denominator

dcc_ratio_cols = [f"dcc_ratio_{c}" for c in dcc_rank_cols]
df_rank_dcc = df_res[["protein_name", "num_ligs"] + dcc_ratio_cols]

In [None]:
# Mean DCC success rate across proteins, one value per rank offset.
df_rank_dcc.loc[:, dcc_ratio_cols].mean()