In [None]:
# !export PYTHONPATH=/Users/ilariasartori/syntheseus:/Users/ilariasartori/syntheseus/tutorials/search

In [None]:
# !echo $PYTHONPATH

In [None]:
from datetime import datetime
from uuid import uuid4
import os

# Unique identifier of this run: a timestamp followed by a random UUID.
eventid = datetime.now().strftime('%Y%m-%d%H-%M%S-') + str(uuid4())
print(eventid)

# Every artefact of this run (log file, pickles, CSV) is written below this folder.
output_folder = f"Results/{eventid}"

# exist_ok=True makes the creation idempotent and avoids the check-then-create
# race of testing os.path.exists() first.
os.makedirs(output_folder, exist_ok=True)



In [None]:
"""Basic code for nearest-neighbour value functions."""
from __future__ import annotations

from enum import Enum

import numpy as np
from rdkit.Chem import DataStructs, AllChem

from syntheseus.search.graph.and_or import OrNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import ExplicitMolInventory

# from Users.ilariasartori.syntheseus.search.graph.and_or import OrNode


class DistanceToCost(Enum):
    """Selects how TanimotoNNCostEstimator turns a distance into a node cost."""

    NOTHING = 0  # cost is the distance itself
    EXP = 1  # cost = exp(distance) - 1
    SQRT = 2  # cost = sqrt(distance)
    TIMES10 = 3  # cost = 10 * distance
    TIMES100 = 4  # cost = 100 * distance
    # The distance is derived from how many of the most similar purchasable
    # molecules are needed for their similarities to sum to 0.6 (instead of
    # the single nearest neighbour); it is then used as the cost unchanged.
    NUM_NEIGHBORS_TO_06 = 5


class TanimotoNNCostEstimator(NoCacheNodeEvaluator):
    """Estimates cost of a node from Tanimoto similarity to purchasable molecules.

    Morgan fingerprints (radius 3) of every purchasable molecule in the inventory
    are computed once at construction. For each evaluated node a distance is
    derived from the similarities between the node's molecule and those
    fingerprints, and that distance is mapped to a cost as selected by
    ``distance_to_cost``.
    """

    # Similarity mass the most similar neighbours must reach in
    # DistanceToCost.NUM_NEIGHBORS_TO_06 mode.
    _SIMILARITY_SUM_THRESHOLD = 0.6

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        **kwargs,
    ):
        """
        Args:
            inventory: inventory whose purchasable molecules are the reference set.
            distance_to_cost: transformation applied to the distance.
            **kwargs: forwarded unchanged to the parent evaluator.
        """
        super().__init__(**kwargs)
        self.distance_to_cost = distance_to_cost
        self._set_fingerprints([mol.smiles for mol in inventory.purchasable_mols()])

    def get_fingerprint(self, mol: AllChem.Mol):
        """Return the Morgan fingerprint (radius 3) of ``mol``."""
        return AllChem.GetMorganFingerprint(mol, radius=3)

    def _set_fingerprints(self, smiles_list: list[str]) -> None:
        """Initialize fingerprint cache."""
        mols = list(map(AllChem.MolFromSmiles, smiles_list))
        assert None not in mols, "Invalid SMILES encountered."
        self._fps = list(map(self.get_fingerprint, mols))

    def find_min_num_elem_summing_to_threshold(self, array, threshold):
        """Return how many of the largest elements are needed to reach ``threshold``.

        The elements are taken in descending order; the result is the smallest
        count whose running sum is >= ``threshold``. If even the sum of all
        elements stays below ``threshold``, ``len(array)`` is returned (0 for an
        empty array).
        """
        # Largest values first, so the running sum grows as fast as possible.
        largest_first = np.sort(array)[::-1]
        running_sum = np.cumsum(largest_first)

        # First position where the running sum is >= threshold
        # (equals len(array) when the threshold is never reached).
        first_reaching = int(np.searchsorted(running_sum, threshold))

        # +1 converts the 0-based position into a count; cap at the array size.
        return min(first_reaching + 1, len(array))

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Return the distance of ``smiles`` to the purchasable molecules.

        Raises:
            ValueError: if ``smiles`` cannot be parsed or no reference
                fingerprints are available.
        """
        query_mol = AllChem.MolFromSmiles(smiles)
        if query_mol is None:
            # Fail with a readable message instead of an opaque fingerprint error.
            raise ValueError(f"Could not parse SMILES: {smiles!r}")
        if len(self._fps) == 0:
            # Would otherwise surface as ZeroDivisionError / max() of empty sequence.
            raise ValueError("No purchasable molecules to compare against.")

        fp_query = self.get_fingerprint(query_mol)
        tanimoto_sims = DataStructs.BulkTanimotoSimilarity(fp_query, self._fps)
        if self.distance_to_cost == DistanceToCost.NUM_NEIGHBORS_TO_06:
            num_needed = self.find_min_num_elem_summing_to_threshold(
                array=tanimoto_sims, threshold=self._SIMILARITY_SUM_THRESHOLD
            )
            # NOTE(review): this value is largest when a single neighbour suffices,
            # i.e. molecules with a very similar purchasable neighbour get the
            # highest "distance" - the opposite ordering of the nearest-neighbour
            # branch below. Behaviour kept as-is; confirm this is intended.
            return 1 - num_needed / len(tanimoto_sims)
        else:
            # Tanimoto distance to the single most similar purchasable molecule.
            return 1 - max(tanimoto_sims)

    def _evaluate_nodes(self, nodes: list[OrNode], graph=None) -> list[float]:
        """Return one cost per node, in the order of ``nodes``.

        Raises:
            NotImplementedError: if ``distance_to_cost`` is not a supported mode.
        """
        if len(nodes) == 0:
            return []

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(node.mol.smiles) for node in nodes]
        )
        assert np.min(nn_dists) >= 0

        # Turn into costs
        mode = self.distance_to_cost
        if mode in (DistanceToCost.NOTHING, DistanceToCost.NUM_NEIGHBORS_TO_06):
            values = nn_dists
        elif mode == DistanceToCost.EXP:
            values = np.exp(nn_dists) - 1
        elif mode == DistanceToCost.SQRT:
            values = np.sqrt(nn_dists)
        elif mode == DistanceToCost.TIMES10:
            values = 10.0 * nn_dists
        elif mode == DistanceToCost.TIMES100:
            values = 100.0 * nn_dists
        else:
            raise NotImplementedError(self.distance_to_cost)

        # Plain Python floats, matching the declared return type.
        return values.tolist()


In [None]:
# Module-level switch read inside run_algorithm: the route-diversity analysis
# (estimating the number of distinct routes per target) only runs when it is True.
route_div = True # Count number of diverse routes found

In [None]:
## Faster implementation


from __future__ import annotations

from collections.abc import Sequence

from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.search.graph.and_or import ANDOR_NODE, AndOrGraph, OrNode


class ReduceValueFunctionCallsRetroStar(RetroStarSearch):
    """
    Retro* variant that evaluates the value function lazily.

    Plain retro* calls the value function (i.e. reaction number estimator) on every
    leaf node. Here each new leaf instead receives a placeholder estimate of 0, and
    the real value function is only called the first time the node is visited;
    the node is expanded on a later visit.

    This essentially leaves the behaviour of retro* unchanged while saving value
    function calls: retro* greedily expands nodes on the current lowest-cost route,
    using the estimate as the node cost. A node that is not visited with an estimate
    of 0 would definitely not be visited with a non-zero estimate, so for such a node
    the actual value function output does not matter.
    """

    def setup(self, graph: AndOrGraph) -> None:
        # A graph holding only the root is treated as already "visited" once:
        # bump the visit count and give the root the placeholder estimate.
        has_only_root = len(graph) == 1
        if has_only_root:
            graph.root_node.num_visit += 1
            graph.root_node.data.setdefault("reaction_number_estimate", 0.0)

        return super().setup(graph)

    def visit_node(self, node: OrNode, graph: AndOrGraph) -> Sequence[ANDOR_NODE]:
        """
        First visit: compute and store the value function estimate, create no nodes.
        Any later visit: expand the node via the parent implementation.
        """
        assert node.num_visit >= 0  # should not be negative
        node.num_visit += 1

        if node.num_visit != 1:
            return super().visit_node(node, graph)

        # First visit: replace the placeholder by the real estimate.
        estimates = self.reaction_number_estimator([node])
        node.data["reaction_number_estimate"] = estimates[0]
        return []

    def _set_reaction_number_estimate(
        self, or_nodes: Sequence[OrNode], graph: AndOrGraph
    ) -> None:
        # Placeholder only; existing estimates are left untouched.
        for or_node in or_nodes:
            or_node.data.setdefault("reaction_number_estimate", 0.0)


In [None]:
"""
Demo script comparing nearest neighbour cost function with constant value function on PaRoutes.
"""
from __future__ import annotations

import argparse
import logging
import sys
import numpy as np

from tqdm.auto import tqdm

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch, MolIsPurchasableCost
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.node_evaluation.base import (
    BaseNodeEvaluator,
    NoCacheNodeEvaluator,
)
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator

from paroutes import PaRoutesInventory, PaRoutesModel, get_target_smiles
# from neighbour_value_functions import TanimotoNNCostEstimator, DistanceToCost

from syntheseus.search.analysis import diversity
# from syntheseus.search.algorithms.best_first.retro_star import MolIsPurchasableCost

class SearchResult:
    """Plain container for the outcome of one search configuration.

    Each ``*_dict`` argument is stored unchanged as an attribute of the same name.
    """

    def __init__(
        self,
        name,
        soln_time_dict,
        num_different_routes_dict,
        final_num_rxn_model_calls_dict,
        final_num_value_function_calls_dict,
        output_graph_dict,
        routes_dict,
    ):
        # Identification of the configuration
        self.name = name

        # Per-target metrics
        self.soln_time_dict = soln_time_dict
        self.num_different_routes_dict = num_different_routes_dict
        self.final_num_rxn_model_calls_dict = final_num_rxn_model_calls_dict
        self.final_num_value_function_calls_dict = final_num_value_function_calls_dict

        # Per-target raw search outputs
        self.output_graph_dict = output_graph_dict
        self.routes_dict = routes_dict


class PaRoutesRxnCost(NoCacheNodeEvaluator[AndNode]):
    """Cost of a reaction is its negative log softmax, clipped to [0.1, 10.0]."""

    def _evaluate_nodes(self, nodes: list[AndNode], graph=None) -> list[float]:
        """Return one cost per AND node, read from ``reaction.metadata["softmax"]``."""
        softmaxes = np.asarray([node.reaction.metadata["softmax"] for node in nodes])
        # The clip keeps every cost strictly positive (>= 0.1) and bounded (<= 10.0),
        # including the infinite value produced by a softmax of 0.
        costs = np.clip(-np.log(softmaxes), 1e-1, 10.0)
        return costs.tolist()


def run_algorithm(
    name: str,
    smiles_list: list[str],
    value_function: BaseNodeEvaluator,
    rxn_model: BackwardReactionModel,
    inventory: BaseMolInventory,
    and_node_cost_fn: BaseNodeEvaluator[AndNode],
    or_node_cost_fn: BaseNodeEvaluator[OrNode],
    max_expansion_depth: int = 15,
    prevent_repeat_mol_in_trees: bool= True,
    use_tqdm: bool = False,
    limit_rxn_model_calls: int = 100,
    limit_iterations: int = 1_000_000,
    logger: logging.RootLogger = logging.getLogger(),
) -> SearchResult:
    """
    Do search on a list of SMILES strings and collect per-target statistics.

    Args:
        name: label of this configuration (used in log messages and the result).
        smiles_list: target molecules, one search per entry.
        value_function: value function handed to the search algorithm.
        rxn_model: backward reaction model used for expansions.
        inventory: inventory of purchasable molecules.
        and_node_cost_fn: cost function for AND (reaction) nodes.
        or_node_cost_fn: cost function for OR (molecule) nodes.
        max_expansion_depth: forwarded to the search algorithm.
        prevent_repeat_mol_in_trees: forwarded to the search algorithm.
        use_tqdm: wrap the iteration over targets in a tqdm progress bar.
        limit_rxn_model_calls: budget of reaction model calls per target.
        limit_iterations: budget of algorithm iterations per target.
        logger: logger receiving progress messages.

    Returns:
        A SearchResult whose dictionaries are keyed by SMILES string. A SMILES
        occurring more than once in ``smiles_list`` keeps only its last result.

    Note:
        The route-diversity analysis is controlled by the module-level flag
        ``route_div``.
    """
    # Upper bound on the number of routes extracted per target.
    MAX_ROUTES = 10000

    # Initialize algorithm.
    common_kwargs = dict(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        limit_iterations=limit_iterations,
        max_expansion_depth=max_expansion_depth,  # prevent overly-deep solutions
        prevent_repeat_mol_in_trees=prevent_repeat_mol_in_trees,  # original paper did this
    )
    # Forward the cost functions supplied by the caller; previously they were
    # accepted but ignored in favour of a hard-coded PaRoutesRxnCost().
    alg = ReduceValueFunctionCallsRetroStar(
        and_node_cost_fn=and_node_cost_fn,
        or_node_cost_fn=or_node_cost_fn,
        value_function=value_function,
        **common_kwargs,
    )

    # Do search
    logger.info(f"Start search with {name}")
    smiles_iter = tqdm(smiles_list) if use_tqdm else smiles_list

    output_graph_dict = {}
    soln_time_dict = {}
    routes_dict = {}
    final_num_rxn_model_calls_dict = {}
    final_num_value_function_calls_dict = {}
    num_different_routes_dict = {}

    for i, smiles in enumerate(smiles_iter):
        logger.debug(f"Start search {i}/{len(smiles_list)}. SMILES: {smiles}")
        alg.reset()
        output_graph, _ = alg.run_from_mol(Molecule(smiles))

        # Analyze solution time, measured in reaction model calls
        for node in output_graph.nodes():
            node.data["analysis_time"] = node.data["num_calls_rxn_model"]
        soln_time = get_first_solution_time(output_graph)

        # Analyze number of routes
        routes = list(min_cost_routes(output_graph, MAX_ROUTES))

        num_rxn_model_calls = alg.reaction_model.num_calls()
        num_value_function_calls = alg.value_function.num_calls
        if num_rxn_model_calls < limit_rxn_model_calls:
            note = " (NOTE: this was less than the maximum budget)"
        else:
            note = ""
        logger.debug(
            f"Done {name}: nodes={len(output_graph)}, solution time = {soln_time}, "
            f"num routes = {len(routes)} (capped at {MAX_ROUTES}), "
            f"final num rxn model calls = {num_rxn_model_calls}{note}, "
            f"final num value model calls = {num_value_function_calls}."
        )

        # Analyze route diversity
        if len(routes) > 0 and route_div:
            route_objects = [output_graph.to_synthesis_graph(nodes) for nodes in routes]
            packing_set = diversity.estimate_packing_number(
                routes=route_objects,
                distance_metric=diversity.reaction_jaccard_distance,
                radius=0.999  # because comparison is > not >=
            )
            logger.debug((f"number of distinct routes = {len(packing_set)}"))
        else:
            packing_set = []

        # Save results
        soln_time_dict[smiles] = soln_time
        final_num_rxn_model_calls_dict[smiles] = num_rxn_model_calls
        final_num_value_function_calls_dict[smiles] = num_value_function_calls
        num_different_routes_dict[smiles] = len(packing_set)
        output_graph_dict[smiles] = output_graph
        routes_dict[smiles] = routes

    return SearchResult(
        name=name,
        soln_time_dict=soln_time_dict,
        num_different_routes_dict=num_different_routes_dict,
        final_num_rxn_model_calls_dict=final_num_rxn_model_calls_dict,
        final_num_value_function_calls_dict=final_num_value_function_calls_dict,
        output_graph_dict=output_graph_dict,
        routes_dict=routes_dict,
    )
    



In [None]:
# Arguments

# COMMAND LINE
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--limit_num_smiles",
#         type=int,
#         default=None,
#         help="Maximum number of SMILES to run.",
#     )
#     parser.add_argument(
#         "--limit_iterations",
#         type=int,
#         default=500,
#         help="Maximum number of algorithm iterations.",
#     )
#     parser.add_argument(
#         "--limit_rxn_model_calls",
#         type=int,
#         default=25,
#         help="Allowed number of calls to reaction model.",
#     )
#     parser.add_argument(
#         "--paroutes_n",
#         type=int,
#         default=5,
#         help="Which PaRoutes benchmark to use.",
#     )
#     args = parser.parse_args()

# NOTEBOOK
class Args:
    """Notebook stand-in for command-line arguments; settings are plain class attributes."""
    # If not None, the list of test SMILES is truncated to this many entries.
    limit_num_smiles = None
    # Maximum number of algorithm iterations per target (passed to run_algorithm).
    limit_iterations = 500 # 100000
    # Allowed number of reaction model calls per target (passed to run_algorithm).
    limit_rxn_model_calls = 100 # 500
    # Passed to get_target_smiles and to PaRoutesInventory.
    paroutes_n = 5
    # Passed to run_algorithm as max_expansion_depth.
    max_expansion_depth = 20
    # Passed to PaRoutesModel as max_num_templates.
    max_num_templates = 10  # Default 50
    # Passed to run_algorithm as prevent_repeat_mol_in_trees.
    prevent_repeat_mol_in_trees = True
    # Component selectors: only the values below are implemented; any other
    # value makes the setup code raise NotImplementedError.
    rxn_model = 'PAROUTES'
    inventory = 'PAROUTES'
    and_node_cost_fn='PAROUTES'
    or_node_cost_fn = 'MOL_PURCHASABLE' 


# Single settings object used by the rest of the notebook.
args=Args()


## Create dataframe for time to solution and number of routes found

In [None]:
import json

# Logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')

# Re-running this cell must not stack handlers on the root logger (that would
# duplicate every log line and leak the previously opened log file), so the
# handlers installed by an earlier run are removed first. Handlers are
# recognised by name; handlers installed by anyone else are left alone.
_EXPERIMENT_HANDLER_NAMES = ("experiment-stdout", "experiment-file")
for stale_handler in [
    h for h in logger.handlers if h.get_name() in _EXPERIMENT_HANDLER_NAMES
]:
    logger.removeHandler(stale_handler)
    stale_handler.close()

# INFO and above go to the notebook output ...
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.set_name("experiment-stdout")
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)

# ... while everything from DEBUG up is kept in the run's log file.
file_handler = logging.FileHandler(f'{output_folder}/logs.txt', mode='w')
file_handler.set_name("experiment-file")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stdout_handler)


# Full list of benchmark targets (kept as an independent copy of the loader's result)
test_smiles_all = get_target_smiles(args.paroutes_n).copy()

## Work on the first `dim_test` targets only
dim_test = 100 # 500
test_smiles = test_smiles_all[:dim_test]
##

# Optional further truncation requested through the arguments
if args.limit_num_smiles is not None:
    test_smiles = test_smiles[: args.limit_num_smiles]

# Build the search components selected in `args`.
# Each selector is validated right before its component is constructed;
# an unknown selector aborts with NotImplementedError.

# Cost of AND (reaction) nodes
if args.and_node_cost_fn != 'PAROUTES':
    raise NotImplementedError(f'and_node_cost_fn: {args.and_node_cost_fn}')
and_node_cost_fn = PaRoutesRxnCost()

# Cost of OR (molecule) nodes
if args.or_node_cost_fn != 'MOL_PURCHASABLE':
    raise NotImplementedError(f'or_node_cost_fn: {args.or_node_cost_fn}')
or_node_cost_fn = MolIsPurchasableCost()

# Inventory of purchasable molecules
if args.inventory != 'PAROUTES':
    raise NotImplementedError(f'inventory: {args.inventory}')
inventory = PaRoutesInventory(n=args.paroutes_n)

# Backward reaction model
if args.rxn_model != 'PAROUTES':
    raise NotImplementedError(f'rxn_model: {args.rxn_model}')
rxn_model = PaRoutesModel(max_num_templates=args.max_num_templates)


# Value functions to evaluate, as (name, evaluator) pairs; one experiment is run
# per active entry. Commented-out entries are alternative configurations that
# are currently disabled.
value_fns = [
#     ("constant-0", ConstantNodeEvaluator(0.0)),
#     (
#         "Tanimoto-distance",
#         TanimotoNNCostEstimator(
#             inventory=inventory, distance_to_cost=DistanceToCost.NOTHING
#         ),
#     ),
#     (
#         "Tanimoto-distance-TIMES10",
#         TanimotoNNCostEstimator(
#             inventory=inventory, distance_to_cost=DistanceToCost.TIMES10
#         ),
#     ),
# #     (
# #         "Tanimoto-distance-TIMES100",
# #         TanimotoNNCostEstimator(
# #             inventory=inventory, distance_to_cost=DistanceToCost.TIMES100
# #         ),
# #     ),
#     (
#         "Tanimoto-distance-EXP",
#         TanimotoNNCostEstimator(
#             inventory=inventory, distance_to_cost=DistanceToCost.EXP
#         ),
#     ),
#     (
#         "Tanimoto-distance-SQRT",
#         TanimotoNNCostEstimator(
#             inventory=inventory, distance_to_cost=DistanceToCost.SQRT
#         ),
#     ),
    (
        "Tanimoto-distance-NUM_NEIGHBORS_TO_06",
        TanimotoNNCostEstimator(
            inventory=inventory, distance_to_cost=DistanceToCost.NUM_NEIGHBORS_TO_06
        ),
    ),
]

# Short aliases for the value-function names above.
# NOTE(review): not referenced in the visible part of the notebook; presumably
# used for plot/table labels elsewhere - confirm.
labelalias = {
    'constant-0': 'constant-0',
    'Tanimoto-distance': 'Tanimoto',
    'Tanimoto-distance-TIMES10': 'Tanimoto_times10',
    'Tanimoto-distance-TIMES100': 'Tanimoto_times100',
    'Tanimoto-distance-EXP': 'Tanimoto_exp',
    'Tanimoto-distance-SQRT': 'Tanimoto_sqrt',
    "Tanimoto-distance-NUM_NEIGHBORS_TO_06": "Tanimoto_nn_to_06",
}



# Record the experiment id and its full configuration in the log
logger.info(f"Start experiment {eventid}")

# One "name: value" line per setting; callables and dunder attributes are skipped.
args_string = ""
for attr in dir(args):
    if callable(getattr(args, attr)) or attr.startswith("__"):
        continue
    args_string += "\n" + f"{attr}: {getattr(args, attr)}"

logger.info(f"Args: {args_string}")
logger.info(f"dim_test: {dim_test}")


import pickle

# Run one full experiment per configured value function.
# `result` maps the value-function name to its SearchResult.
result={}
for name, fn in value_fns:
    alg_result = run_algorithm(
        name=name,
        smiles_list=test_smiles, 
        value_function=fn, 
        rxn_model=rxn_model,
        inventory=inventory,
        and_node_cost_fn=and_node_cost_fn,
        or_node_cost_fn=or_node_cost_fn, 
        max_expansion_depth=args.max_expansion_depth, 
        prevent_repeat_mol_in_trees=args.prevent_repeat_mol_in_trees, 
        use_tqdm=True,
        limit_rxn_model_calls=args.limit_rxn_model_calls, 
        limit_iterations=args.limit_iterations,
        logger=logger,
    )
    result[name] = alg_result
    
    # Persist each result as soon as it is available, one pickle file per
    # value function inside the run's output folder.
    with open(f'{output_folder}/result_{name}.pickle', 'wb') as handle:
        pickle.dump(alg_result, handle, protocol=pickle.HIGHEST_PROTOCOL)




In [None]:
import pandas as pd

def create_result_df(result, name):
    """Build a long-format table of the per-target results of one algorithm.

    Args:
        result: mapping from algorithm name to its search result object.
        name: key of the algorithm to tabulate; must equal ``result[name].name``.

    Returns:
        DataFrame with columns ``algorithm``, ``similes``, ``property`` and
        ``value``: first one ``'sol_time'`` row per entry of ``soln_time_dict``,
        then one ``'diff_routes'`` row per entry of ``num_different_routes_dict``.
        Each of the two parts is indexed from 0.

    Raises:
        AssertionError: if ``name`` differs from ``result[name].name``.
    """
    assert name == result[name].name, f"name: {name} is different from result[name].name: {result[name].name}"

    # NOTE(review): 'similes' looks like a misspelling of 'smiles'; it is kept
    # because the column name is part of the exported CSV schema.
    columns = ['algorithm', 'similes', 'property', 'value']
    search_result = result[name]

    def _to_long_frame(values_by_smiles, property_name):
        # Build all rows first and create the frame once, instead of growing
        # a DataFrame row by row (quadratic).
        rows = [
            {'algorithm': name, 'similes': smiles, 'property': property_name, 'value': value}
            for smiles, value in values_by_smiles.items()
        ]
        return pd.DataFrame(rows, columns=columns)

    df_soln_time = _to_long_frame(search_result.soln_time_dict, 'sol_time')
    df_different_routes = _to_long_frame(search_result.num_different_routes_dict, 'diff_routes')

    df_results_tot = pd.concat([df_soln_time, df_different_routes], axis=0)
    return df_results_tot



# Tabulate every algorithm, then concatenate once: repeatedly concatenating onto
# a growing frame is quadratic, and seeding it with an empty frame makes the
# resulting dtypes depend on pandas' handling of empty inputs.
result_frames = []
for name in result.keys():
    df_results_alg = create_result_df(result, name)
    result_frames.append(df_results_alg)

if result_frames:
    df_results_tot = pd.concat(result_frames, axis=0)
else:
    # No algorithm was run: keep an empty table with the expected columns.
    df_results_tot = pd.DataFrame(columns=['algorithm', 'similes', 'property', 'value'])
    
    
    
    

In [None]:
# Show the combined long-format results table as the cell output.
df_results_tot

In [None]:
# Save to csv
# df_results_tot.to_csv(f'Results/Compare/compare_times_{dim_test}.csv', index=False)
# The table is written into this run's output folder; index=False leaves the
# row index out of the file.
df_results_tot.to_csv(f'{output_folder}/results_all.csv', index=False)
