In [12]:
from dataclasses import dataclass
from typing import Callable, Dict, Tuple
import math
import functools

from rdkit import Chem
from rdkit.Chem import QED as qed_module

import dockstring

@dataclass
class BenchmarkObjective:
    """
    Objective built from a set of independent per-molecule base functions.

    Scoring a SMILES string happens in two stages:
    1) every entry of ``base_functions`` is applied to the SMILES string
    2) the resulting values are handed, by keyword, to ``aggregation_function``,
       which combines them into one score
    """
    # Keyword name (as expected by aggregation_function) -> function SMILES -> float.
    base_functions: Dict[str, Callable[[str], float]]
    # Called with every base-function value as a keyword argument.
    aggregation_function: Callable[..., float]

    def _eval_base_functions(self, smiles: str) -> Dict[str, float]:
        # Evaluate in the dict's insertion order and keep the same keys.
        component_values: Dict[str, float] = {}
        for fn_name, fn in self.base_functions.items():
            component_values[fn_name] = fn(smiles)
        return component_values

    def __call__(self, smiles: str) -> Tuple[float, Dict[str, float]]:
        """Score ``smiles``; return (aggregated score, individual base-function values)."""
        component_values = self._eval_base_functions(smiles)
        aggregated = self.aggregation_function(**component_values)
        return aggregated, component_values


def safe_dock_function(smiles: str, target_name: str, **dock_kwargs):
    """
    Dock ``smiles`` against ``target_name`` with dockstring, never raising DockstringError.

    Extra keyword arguments are forwarded unchanged to ``target.dock``.
    Only the docking call itself is guarded: errors from ``load_target``
    still propagate.

    Returns:
        Element 0 of the value returned by ``target.dock`` (the score), or
        ``float("nan")`` if docking raised ``dockstring.DockstringError``.
    """
    docking_target = dockstring.load_target(target_name)
    try:
        return docking_target.dock(smiles, **dock_kwargs)[0]
    except dockstring.DockstringError:
        return float("nan")


def QED(smiles: str) -> float:
    """
    Calculate the QED drug-likeness score of a molecule given as SMILES.

    Returns NaN when RDKit cannot parse ``smiles``. In that case
    ``Chem.MolFromSmiles`` returns None, and passing None to ``qed`` would
    raise. This mirrors ``safe_dock_function``: an invalid molecule yields a
    NaN objective instead of aborting the whole evaluation.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return float("nan")
    return qed_module.qed(mol)

# Raw functions
def QED_penalty(qed: float) -> float:
    """Linear drug-likeness penalty: 0 at QED == 1, growing by 10 per unit of QED below 1."""
    shortfall = 1.0 - qed
    return shortfall * 10.0


def F2_score(*, F2: float, QED: float) -> float:
    """F2 objective: the F2 docking score plus the QED drug-likeness penalty."""
    qed_term = QED_penalty(QED)
    return F2 + qed_term


def promiscuous_PPAR_score(*, PPARA: float, PPARD: float, PPARG: float, QED: float) -> float:
    """
    Promiscuous PPAR objective: the largest of the three PPAR docking scores
    plus the QED penalty, or NaN if any docking score is NaN.

    Taking the max scores the molecule by its weakest PPAR docking result
    (assuming lower docking scores are better).
    """
    ppar_scores = (PPARA, PPARD, PPARG)
    # max() over values containing NaN does not reliably return NaN (the
    # result depends on where the NaN sits), so missing scores are caught here.
    if not any(math.isnan(s) for s in ppar_scores):
        return max(ppar_scores) + QED_penalty(QED)
    return math.nan


def selective_JAK2_score(*, JAK2: float, LCK: float, QED: float) -> float:
    """
    Selective JAK2 objective.

    JAK2 docking score, plus a penalty when the off-target LCK score falls
    below its median, plus the QED penalty. Returns NaN if either docking
    score is NaN.
    """
    # Explicit NaN check, matching promiscuous_PPAR_score / GFR_score; without
    # it NaN only propagated because of min()'s argument order.
    if any(math.isnan(v) for v in (JAK2, LCK)):
        return math.nan

    # Note: there was a small error in the formula for this objective
    # in our JCIM publication. *THIS* is the correct formula,
    # which matches the numbers in all of our tables
    lck_median_score = -8.1
    # min(..., 0) is negative only when LCK is below the median (assuming lower
    # means stronger binding), so subtracting it adds a penalty; LCK scores
    # above the median contribute nothing.
    return JAK2 - min(LCK - lck_median_score, 0) + QED_penalty(QED)

# Our addition

def GFR_score(
    *,
    FGFR1: float,
    EGFR: float,
    KDR: float,
    QED: float,
    off_target_threshold: float = -8.0,
) -> float:
    """
    Scoring function for selective FGFR1 inhibitor design (lower is better).

    Args:
        FGFR1 (float): Docking score for the desired target FGFR1
        EGFR (float): Docking score for the off-target EGFR
        KDR (float): Docking score for the off-target KDR
        QED (float): Quantitative Estimation of Drug-likeness (QED) score
        off_target_threshold (float): Off-target docking score below which a
            penalty is applied (default -8.0).

    Returns:
        float: FGFR1 score plus off-target penalties plus the QED penalty,
        or NaN if any docking score is NaN.
    """
    # max() does not reliably propagate NaN, so check docking scores explicitly.
    if any(math.isnan(v) for v in (FGFR1, EGFR, KDR)):
        return math.nan

    # Off-target penalties are >= 0 and non-zero only when the off-target
    # score is below the threshold. This mirrors the LCK term of
    # selective_JAK2_score: JAK2 - min(LCK - m, 0) == JAK2 + max(m - LCK, 0).
    # The earlier form, FGFR1 - max(score - threshold, 0), never penalised
    # strong off-target binding and instead rewarded weak binding without bound.
    egfr_penalty = max(off_target_threshold - EGFR, 0)
    kdr_penalty = max(off_target_threshold - KDR, 0)

    return FGFR1 + egfr_penalty + kdr_penalty + QED_penalty(QED)


def _docking_functions(target_names, **dock_kwargs) -> Dict[str, Callable[[str], float]]:
    """Map each target name to a ``safe_dock_function`` partial bound to that target and dock_kwargs."""
    return {
        target_name: functools.partial(safe_dock_function, target_name=target_name, **dock_kwargs)
        for target_name in target_names
    }


def get_benchmark_functions(**dock_kwargs) -> Dict[str, BenchmarkObjective]:
    """
    Returns the benchmark objectives, keyed by benchmark name.

    Contains the original benchmarks ("F2", "promiscuous_PPAR",
    "selective_JAK2") plus our additional "GFR" objective.

    dock_kwargs specifies kwargs to pass to the `target.dock` function
    (e.g. pH, num_cpus)
    """
    output: Dict[str, BenchmarkObjective] = dict()

    # F2 (docking score listed before QED; dict order is evaluation order)
    output["F2"] = BenchmarkObjective(
        base_functions=dict(**_docking_functions(["F2"], **dock_kwargs), QED=QED),
        aggregation_function=F2_score,
    )

    # Promiscuous PPAR
    output["promiscuous_PPAR"] = BenchmarkObjective(
        base_functions=dict(
            QED=QED,
            **_docking_functions(["PPARA", "PPARD", "PPARG"], **dock_kwargs),
        ),
        aggregation_function=promiscuous_PPAR_score,
    )

    # Selective JAK2
    output["selective_JAK2"] = BenchmarkObjective(
        base_functions=dict(
            QED=QED,
            **_docking_functions(["JAK2", "LCK"], **dock_kwargs),
        ),
        aggregation_function=selective_JAK2_score,
    )

    # GFR (our addition)
    output["GFR"] = BenchmarkObjective(
        base_functions=dict(
            QED=QED,
            **_docking_functions(["FGFR1", "EGFR", "KDR"], **dock_kwargs),
        ),
        aggregation_function=GFR_score,
    )

    return output

In [13]:
# Pass dock kwargs here (e.g. num_cpus, pH) to customise `target.dock`.
benchmark_function_dict = get_benchmark_functions()
objective_names = benchmark_function_dict.keys()
print(f"Keys: {objective_names}")
benchmark_function_dict

NameError: name 'QED_function' is not defined

In [7]:
from rdkit import Chem
# Load every record of the ABL1 SDF file into `mols`, in file order.
# NOTE(review): RDKit suppliers presumably yield None for records they cannot
# parse; those are kept here, so confirm downstream code handles None.
with Chem.SDMolSupplier("./data/ABL1.sdf") as supplier:
    mols = [molecule for molecule in supplier]