In [56]:
import pickle
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path
import pandas as pd

In [57]:
def get_stat_str(v, ids):
    """Summarise `v` per group and overall as a one-line string.

    Values are first averaged within each id in `ids`. The mean, standard
    deviation (numpy default, ddof=0) and median are then taken across those
    per-id averages. The trailing parenthesised pair is the plain mean and
    median of `v` over all entries, ignoring grouping.

    Args:
        v: per-sample values, aligned with `ids`.
        ids: group label for each sample (callers pass target ids).

    Returns:
        "<mean> +- <std> (median: <median>)    (<raw mean>, <raw median>)"
    """
    per_group_means = (
        pd.DataFrame({"i": ids, "v": v})
        .groupby('i')['v']
        .mean()
        .values
    )
    g_mean, g_std, g_median = (
        np.mean(per_group_means),
        np.std(per_group_means),
        np.median(per_group_means),
    )
    raw_mean, raw_median = np.mean(v), np.median(v)
    return f"{g_mean:.3f} +- {g_std:.3f} (median: {g_median:.3f})    ({raw_mean:.2f}, {raw_median:.2f})"

In [58]:
def get_metrics(d):
    """Derive the reported metrics from a raw per-method metrics dict.

    The geometry checks compare each count in `d` with its "valid" counterpart.
    `bust_pass` is the conjunction of the ring, bond, angle and clash checks.
    Optional entries (`diversity`, `docking_score`,
    `docking_score_better_than_ref`) are copied only when present in `d`.

    Args:
        d: mapping with at least the keys read below. NOTE(review): values are
            presumably numpy arrays (one entry per generated molecule), since
            `==` and `&` are applied to them elementwise; confirm against the
            code that writes metrics.pkl.

    Returns:
        dict of metric name -> values. Key order is fixed, because callers
        print in this order: length, ring_cond, is_tree, bond_cond,
        angle_cond, clash_cond, qed, lipinski, tansim, sa, [diversity],
        bust_pass, [docking_score], [better_docking].

    Raises:
        KeyError: if a required key is missing from `d`.
    """
    ring_cond = d['num_rings'] == d['num_valid_rings']
    bond_cond = d['num_bonds'] == d['num_valid_bonds']
    angle_cond = d['num_angles'] == d['num_valid_angles']
    clash_cond = d['num_noncov'] == d['num_valid_noncov']

    result = {
        "length": d['length'],
        "ring_cond": ring_cond,
        "is_tree": d['num_rings'] == 0,
        "bond_cond": bond_cond,
        "angle_cond": angle_cond,
        "clash_cond": clash_cond,
        "qed": d['qed'],
        "lipinski": d['lipinski'],
        "tansim": d['tansim'],
        "sa": d['sa'],
    }

    if "diversity" in d:
        result["diversity"] = d['diversity']

    # Each geometry check appears exactly once (bond_cond was previously ANDed twice).
    result["bust_pass"] = ring_cond & bond_cond & angle_cond & clash_cond

    if "docking_score" in d:
        result["docking_score"] = d['docking_score']
    if "docking_score_better_than_ref" in d:
        # Reported under the shorter name used by print_stats.
        result["better_docking"] = d['docking_score_better_than_ref']

    return result

# Metrics shown when print_stats(..., only_core=True).
_CORE_METRIC_KEYS = frozenset({
    "length", "diversity", "bust_pass", "ring_cond", "bond_cond", "angle_cond",
    "clash_cond", "qed", "lipinski", "sa", "docking_score", "better_docking",
})
# Boolean pass/fail checks; these are reported inverted, as "<key>_fail".
_FAILURE_REPORTED_KEYS = frozenset({
    "ring_cond", "bond_cond", "angle_cond", "clash_cond", "bust_pass",
})


def print_stats(d, only_core=False):
    """Print a per-target summary of every metric derived from `d`.

    Prints the number of distinct targets and the average number of samples
    per target. It then prints one line per metric, formatted by get_stat_str
    with statistics grouped by `d["target_id"]`. Pass/fail checks are inverted
    (`1 - vals`), so their averages read as failure rates.

    Args:
        d: raw metrics dict as accepted by get_metrics, plus a "target_id"
            entry aligned with the per-sample values.
        only_core: if True, print only the metrics in _CORE_METRIC_KEYS.
    """
    metrics = get_metrics(d)

    target_ids = d["target_id"]
    unique_target_ids = np.unique(target_ids)
    print("# of targets:", len(unique_target_ids))
    if len(unique_target_ids) == 0:
        # Avoid ZeroDivisionError and all-NaN statistics on empty input.
        print("average_n: n/a (no samples)")
        return
    print("average_n:", len(target_ids) / len(unique_target_ids))

    for key, vals in metrics.items():
        if only_core and key not in _CORE_METRIC_KEYS:
            continue
        if key in _FAILURE_REPORTED_KEYS:
            key = f"{key}_fail"
            vals = 1 - vals
        print(f"{key}:", get_stat_str(vals, target_ids))

In [59]:
# Display name -> cache directory suffix; results are read from
# test_cache/test_<suffix>_outputs/metrics.pkl by the loading loop below.
info = {
    "ligan": "ligan",
    "ar": "ar",
    "graphbp": "graphbp",
    "flag": "flag",  # previously "flag:" (stray colon leaked into the printed title)
    "targetdiff": "targetdiff",
    "pocket2mol": "pocket2mol",
    "p2mrl": "p2mrl",
    "ref": "ref",
}

def load_metrics(file):
    """Unpickle and return the object stored at `file`, or None if it does not exist.

    Args:
        file: a pathlib.Path (its `.exists()` method is called).
    """
    if not file.exists():
        return None
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted cache files.
    with open(file, "rb") as handle:
        return pickle.load(handle)
    
# Load each method's cached test metrics. Methods without a cache file are
# skipped, but listed so a missing run is not silently dropped from the report.
test_metrics = {}
missing_methods = []
for method, cache_name in info.items():
    metrics_path = Path(f"test_cache/test_{cache_name}_outputs/metrics.pkl")
    loaded = load_metrics(metrics_path)
    if loaded is None:
        missing_methods.append(method)
    else:
        test_metrics[method] = loaded
if missing_methods:
    print("No cached metrics for:", ", ".join(missing_methods))

In [60]:
# Core-metric summary for every method whose metrics were loaded.
print("<< test Results >>")
if not test_metrics:
    # Make an empty report explicit rather than printing only the header.
    print("(no cached metrics found under test_cache/)")
for title, method_metrics in test_metrics.items():
    print("<", title, ">")
    print_stats(method_metrics, only_core=True)
    print("----------------------------------------")

<< test Results >>
< ligan >
# of targets: 100
average_n: 96.9
length: 19.976 +- 8.190 (median: 20.201)    (19.93, 20.00)
ring_cond_fail: 0.140 +- 0.183 (median: 0.033)    (0.14, 0.00)
bond_cond_fail: 0.969 +- 0.089 (median: 1.000)    (0.97, 1.00)
angle_cond_fail: 0.930 +- 0.137 (median: 1.000)    (0.93, 1.00)
clash_cond_fail: 0.564 +- 0.317 (median: 0.624)    (0.56, 1.00)
qed: 0.391 +- 0.192 (median: 0.421)    (0.39, 0.39)
lipinski: 4.168 +- 1.233 (median: 4.893)    (4.16, 5.00)
sa: 0.591 +- 0.123 (median: 0.580)    (0.59, 0.57)
diversity: 0.345 +- 0.114 (median: 0.334)    (0.34, 0.33)
bust_pass_fail: 0.985 +- 0.057 (median: 1.000)    (0.98, 1.00)
docking_score: -6.359 +- 1.564 (median: -6.453)    (-6.36, -6.40)
better_docking: 0.131 +- 0.188 (median: 0.051)    (0.13, 0.00)
----------------------------------------
< ar >
# of targets: 97
average_n: 89.17525773195877
length: 17.558 +- 6.001 (median: 16.820)    (17.23, 16.00)
ring_cond_fail: 0.330 +- 0.309 (median: 0.221)    (0.32, 0.00