In [None]:
# Before launching Jupyter, set PYTHONPATH in the terminal so that `syntheseus`
# and the tutorial's `paroutes` module are importable:
# export PYTHONPATH=/Users/ilariasartori/syntheseus:/Users/ilariasartori/syntheseus/tutorials/search

In [None]:
# Sanity check: print the PYTHONPATH this kernel inherited; it should contain
# the directories from the setup note in the first cell.
!echo $PYTHONPATH

In [None]:
from datetime import datetime
from uuid import uuid4

# Unique id for this notebook session: a fixed-width timestamp prefix (so ids
# sort chronologically as strings) followed by a random UUID4.
eventid = f"{datetime.now():%Y%m-%d%H-%M%S-}{uuid4()}"
eventid

In [None]:
"""Basic code for nearest-neighbour value functions."""
from __future__ import annotations

from enum import Enum

import numpy as np
from rdkit.Chem import DataStructs, AllChem

from syntheseus.search.graph.and_or import OrNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import ExplicitMolInventory

# from Users.ilariasartori.syntheseus.search.graph.and_or import OrNode


class DistanceToCost(Enum):
    """How a nearest-neighbour Tanimoto distance is converted into a cost."""

    NOTHING = 0  # cost is the distance itself
    EXP = 1  # cost is exp(distance) - 1


class TanimotoNNCostEstimator(NoCacheNodeEvaluator):
    """Estimates cost of a node using Tanimoto distance to purchasable molecules.

    For each molecule node, the nearest-neighbour distance is
    ``1 - max Tanimoto similarity`` between the node's Morgan fingerprint and
    the fingerprints of all purchasable molecules in the inventory. That
    distance is then turned into a cost according to ``distance_to_cost``.

    Args:
        inventory: inventory whose ``purchasable_mols()`` are fingerprinted
            once, at construction time.
        distance_to_cost: ``NOTHING`` uses the distance as the cost; ``EXP``
            uses ``exp(distance) - 1``.
        radius: Morgan fingerprint radius (default 3, the value previously
            hard-coded).
        **kwargs: forwarded to ``NoCacheNodeEvaluator.__init__``.

    Raises:
        ValueError: if a purchasable molecule's SMILES cannot be parsed.
    """

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        radius: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.distance_to_cost = distance_to_cost
        # Must be set before _set_fingerprints, which calls get_fingerprint.
        self.radius = radius
        self._set_fingerprints([mol.smiles for mol in inventory.purchasable_mols()])

    def get_fingerprint(self, mol: AllChem.Mol):
        """Return the Morgan fingerprint of ``mol`` with radius ``self.radius``."""
        return AllChem.GetMorganFingerprint(mol, radius=self.radius)

    def _set_fingerprints(self, smiles_list: list[str]) -> None:
        """Parse ``smiles_list`` and cache one fingerprint per molecule in ``self._fps``.

        Raises:
            ValueError: if any SMILES cannot be parsed by RDKit.
        """
        mols = [AllChem.MolFromSmiles(smiles) for smiles in smiles_list]
        invalid = [smiles for smiles, mol in zip(smiles_list, mols) if mol is None]
        if invalid:
            # Explicit error (not `assert`) so the check survives `python -O`.
            raise ValueError(f"Invalid SMILES encountered: {invalid[:5]}")
        self._fps = [self.get_fingerprint(mol) for mol in mols]

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Return ``1 - max Tanimoto similarity`` of ``smiles`` to the cached fingerprints.

        Returns 1.0 (maximal distance) when no fingerprints are cached, e.g.
        for an inventory without purchasable molecules.

        Raises:
            ValueError: if ``smiles`` cannot be parsed by RDKit.
        """
        mol = AllChem.MolFromSmiles(smiles)
        if mol is None:
            # Without this check RDKit fails later with an opaque argument error.
            raise ValueError(f"Invalid SMILES: {smiles!r}")
        if not self._fps:
            # Nothing to compare against: treat the similarity as 0 instead of
            # letting max() fail on an empty sequence.
            return 1.0
        tanimoto_sims = DataStructs.BulkTanimotoSimilarity(
            self.get_fingerprint(mol), self._fps
        )
        return 1 - max(tanimoto_sims)

    def _evaluate_nodes(self, nodes: list[OrNode], graph=None) -> list[float]:
        """Return one cost estimate per OrNode, in the same order as ``nodes``.

        ``graph`` is accepted for interface compatibility and is not used.

        Raises:
            NotImplementedError: for an unsupported ``distance_to_cost``.
        """
        if len(nodes) == 0:
            return []

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(node.mol.smiles) for node in nodes]
        )
        # Distances are 1 - similarity; Tanimoto similarity is at most 1.
        assert np.min(nn_dists) >= 0

        # Turn into costs
        if self.distance_to_cost == DistanceToCost.NOTHING:
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.EXP:
            # Maps distance 0 to cost 0 and grows faster than the distance.
            values = np.exp(nn_dists) - 1
        else:
            raise NotImplementedError(self.distance_to_cost)

        # Plain Python floats rather than numpy scalars.
        return values.tolist()


In [None]:
"""
Demo script comparing nearest neighbour cost function with constant value function on PaRoutes.
"""
from __future__ import annotations

import argparse
import logging
import sys
import numpy as np

from tqdm.auto import tqdm

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.node_evaluation.base import (
    BaseNodeEvaluator,
    NoCacheNodeEvaluator,
)
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator

from paroutes import PaRoutesInventory, PaRoutesModel, get_target_smiles
# from neighbour_value_functions import TanimotoNNCostEstimator, DistanceToCost


class PaRoutesRxnCost(NoCacheNodeEvaluator[AndNode]):
    """Cost of reaction is -log(softmax), clipped to the range [0.1, 10].

    The softmax score is read from ``node.reaction.metadata["softmax"]``. This
    presumably comes from the PaRoutes reaction model (TODO confirm). The lower
    clip keeps every reaction cost strictly positive. The upper clip caps very
    unlikely reactions, including softmax == 0, whose -log is +inf.
    """

    def _evaluate_nodes(self, nodes: list[AndNode], graph=None) -> list[float]:
        """Return one cost per AndNode, in order; ``graph`` is unused."""
        softmaxes = np.asarray([node.reaction.metadata["softmax"] for node in nodes])
        # softmax == 0 gives -log(0) = inf, which the clip below maps to 10;
        # silence numpy's divide-by-zero warning for that expected case.
        with np.errstate(divide="ignore"):
            neg_log_softmax = -np.log(softmaxes)
        costs = np.clip(neg_log_softmax, 1e-1, 10.0)
        return costs.tolist()

    
def run_algorithm(
    smiles_list: list[str],
    value_function: tuple[str, BaseNodeEvaluator],
    rxn_model: BackwardReactionModel,
    inventory: BaseMolInventory,
    use_tqdm: bool = False,
    limit_rxn_model_calls: int = 100,
    limit_iterations: int = 1_000_000,
    max_routes: int = 10_000,
):
    """
    Run Retro* search on each SMILES and collect per-target results.

    A single ``RetroStarSearch`` is built once and reset before every target.
    It uses ``PaRoutesRxnCost`` as the reaction cost and ``value_function[1]``
    as the value function.

    Args:
        smiles_list: target molecules as SMILES strings.
        value_function: ``(name, evaluator)`` pair. ``name`` is only used in
            log messages; ``evaluator`` is the search's value function.
        rxn_model: backward reaction model used for expansions.
        inventory: purchasable-molecule inventory.
        use_tqdm: if True, show a progress bar over ``smiles_list``.
        limit_rxn_model_calls: reaction-model call budget per search.
        limit_iterations: iteration budget per search.
        max_routes: cap on the number of min-cost routes extracted per target.

    Returns:
        Four dicts keyed by SMILES, in this order:
        - the output search graph;
        - the first-solution time from ``get_first_solution_time``, with each
          node's analysis time set to its ``num_calls_rxn_model``;
        - the list of extracted min-cost routes;
        - the final reaction-model call count.
        Duplicate SMILES overwrite earlier entries.
    """
    name, evaluator = value_function

    # One algorithm instance, reused (and reset) for every target.
    alg = RetroStarSearch(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        limit_iterations=limit_iterations,
        max_expansion_depth=15,  # prevent overly-deep solutions
        prevent_repeat_mol_in_trees=True,  # original paper did this
        and_node_cost_fn=PaRoutesRxnCost(),
        value_function=evaluator,
    )

    logger = logging.getLogger("RUN ALGORITHM")
    smiles_iter = tqdm(smiles_list) if use_tqdm else smiles_list

    output_graph_dict = {}
    soln_time_dict = {}
    routes_dict = {}
    final_num_rxn_model_calls_dict = {}
    for i, smiles in enumerate(smiles_iter):
        logger.debug(f"Start search {i}/{len(smiles_list)}. SMILES: {smiles}")

        alg.reset()
        output_graph, _ = alg.run_from_mol(Molecule(smiles))

        # Analyze solution time, measured in reaction-model calls.
        for node in output_graph.nodes():
            node.data["analysis_time"] = node.data["num_calls_rxn_model"]
        soln_time = get_first_solution_time(output_graph)

        # Analyze number of routes (capped at max_routes).
        routes = list(min_cost_routes(output_graph, max_routes))

        num_calls = alg.reaction_model.num_calls()
        if num_calls < limit_rxn_model_calls:
            note = " (NOTE: this was less than the maximum budget)"
        else:
            note = ""
        logger.debug(
            f"Done {name}: nodes={len(output_graph)}, solution time = {soln_time}, "
            f"num routes = {len(routes)} (capped at {max_routes}), "
            f"final num rxn model calls = {num_calls}{note}."
        )

        output_graph_dict[smiles] = output_graph
        soln_time_dict[smiles] = soln_time
        routes_dict[smiles] = routes
        final_num_rxn_model_calls_dict[smiles] = num_calls

    return output_graph_dict, soln_time_dict, routes_dict, final_num_rxn_model_calls_dict





In [None]:
# Arguments

# COMMAND LINE
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--limit_num_smiles",
#         type=int,
#         default=None,
#         help="Maximum number of SMILES to run.",
#     )
#     parser.add_argument(
#         "--limit_iterations",
#         type=int,
#         default=500,
#         help="Maximum number of algorithm iterations.",
#     )
#     parser.add_argument(
#         "--limit_rxn_model_calls",
#         type=int,
#         default=25,
#         help="Allowed number of calls to reaction model.",
#     )
#     parser.add_argument(
#         "--paroutes_n",
#         type=int,
#         default=5,
#         help="Which PaRoutes benchmark to use.",
#     )
#     args = parser.parse_args()

# NOTEBOOK
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    """Notebook stand-in for the commented-out argparse options above.

    This is a dataclass so that ``logging.getLogger().info(args)`` prints the
    actual settings instead of ``<__main__.Args object at ...>``. Attribute
    access (``args.limit_iterations`` etc.) is unchanged.
    """

    limit_num_smiles: Optional[int] = None  # max number of SMILES to run (None = all)
    limit_iterations: int = 500  # max number of algorithm iterations
    limit_rxn_model_calls: int = 50  # allowed number of calls to the reaction model
    paroutes_n: int = 5  # which PaRoutes benchmark to use


args = Args()


In [None]:


# Logging: send DEBUG-level messages (including the per-target summaries from
# run_algorithm) to stdout so they show up in the notebook. `filemode` is not
# passed because basicConfig only uses it together with `filename`.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
logging.getLogger().info(args)

# Load all SMILES to test
test_smiles_all = get_target_smiles(args.paroutes_n).copy()

# Quick test on a smaller dataset: only the first QUICK_TEST_N targets.
# Set QUICK_TEST_N = None to run the whole benchmark (`[:None]` keeps all).
QUICK_TEST_N = 5
test_smiles = test_smiles_all[:QUICK_TEST_N]

if args.limit_num_smiles is not None:
    test_smiles = test_smiles[: args.limit_num_smiles]

# Make reaction model, inventory, value functions
rxn_model = PaRoutesModel()
inventory = PaRoutesInventory(n=args.paroutes_n)

# (name, evaluator) pairs. Names must be distinct: they label log lines and
# the Results/<name>/... output folders.
value_fn1 = ("constant-0", ConstantNodeEvaluator(0.0))  # Retro*-0 baseline
value_fn2 = (
    "Tanimoto-distance",
    TanimotoNNCostEstimator(
        inventory=inventory, distance_to_cost=DistanceToCost.NOTHING
    ),
)
value_fn3 = (
    "Tanimoto-distance-exp",
    TanimotoNNCostEstimator(
        inventory=inventory, distance_to_cost=DistanceToCost.EXP
    ),
)



In [None]:
# Choose the value function: value_fn1 = constant-0 (Retro*-0);
# value_fn2 / value_fn3 = Tanimoto nearest-neighbour distance (raw / exp).
value_fn = value_fn2

# Run Retro* with the chosen value function on every target in `test_smiles`.
output_graph_dict, soln_time_dict, routes_dict, final_num_rxn_model_calls_dict = run_algorithm(
    smiles_list=test_smiles,
    value_function=value_fn,
    limit_rxn_model_calls=args.limit_rxn_model_calls,
    limit_iterations=args.limit_iterations,
    use_tqdm=True,
    rxn_model=rxn_model,
    inventory=inventory,
)




In [None]:
# Target whose search results are inspected below. It must be one of the
# SMILES that were searched (i.e. in `test_smiles`).
single_molecule = "O=C(O)CCNC(=O)c1nc(-c2ccncc2)c2c(cc(-c3ccccc3)c(=O)n2CC2CCCCC2)c1O"

In [None]:
# Look up the search graph for the inspected target. Fail with a clear hint if
# it was not among the SMILES that were actually searched.
if single_molecule not in output_graph_dict:
    raise KeyError(
        f"{single_molecule!r} is not in output_graph_dict; "
        "make sure it is one of the SMILES in `test_smiles`."
    )
output_graph = output_graph_dict[single_molecule]
output_graph

In [None]:
# # WHOLE GRAPH

# from syntheseus.search import visualization
# visualization.visualize_andor(output_graph, f"Results/{value_fn[0]}/{single_molecule}/whole-graph.pdf")

In [None]:
# def min_cost_routes(
#     graph: RetrosynthesisSearchGraph,
#     max_routes: int,
#     stop_cost: Optional[float] = None,
# ) -> Iterator[Collection[BaseGraphNode]]:
#     """
#     Return solved routes from "graph" with the lowest possible cost.
#     Graph can be AND/OR or MolSet graph.
#     The cost of each route is the sum of node.data["route_cost"] for each
#     node in the route. It is assumed that this is set beforehand.
#     It is also assumed that "node.has_solution" is set beforehand.

#     Args:
#         graph: graph whose routes to extract
#         max_routes: maximum number of routes to yield.
#         stop_cost: if provided, iterator will terminate once a route of cost
#             >= stop_cost is encountered
#     """

#     for cost, route in _iter_top_routes(
#         graph=graph,
#         cost_fn=_min_route_cost,
#         cost_lower_bound=_min_route_partial_cost,
#         max_routes=max_routes,
#         yield_partial_routes=False,
#     ):
#         if stop_cost is not None and cost >= stop_cost:
#             break
#         else:
#             yield route

# def _min_route_cost(nodes: Collection[BaseGraphNode], graph: RetrosynthesisSearchGraph) -> float:
#     if _route_has_solution(nodes, graph):
#         return sum(n.data.get("route_cost", 0.0) for n in nodes)
#     else:
#         return math.inf

In [None]:
# Inspect the search graph object's instance attributes.
vars(output_graph)

In [None]:
import pandas as pd

# One row per molecule node (OrNode) in the search graph: its SMILES, its
# `is_purchasable` metadata flag, and every entry of its `node.data` dict.
subgraph = output_graph._graph
or_node_rows = []
for node in subgraph.nodes:
    if isinstance(node, OrNode):  # Molecule
        row_data = {
            "smiles": node.mol.smiles,
            "is_purchasable": node.mol.metadata["is_purchasable"],
        }
        row_data.update(node.data)
        or_node_rows.append(row_data)

# Build the frame once. Growing it with pd.concat per row was quadratic, and it
# left `df_nodes` undefined (or stale) when the graph had no OrNodes.
df_nodes = pd.DataFrame(or_node_rows)
df_nodes

In [None]:
# Per-node "route_cost" values read by route extraction: every reaction
# (AndNode) costs 1.0 and every other node 0.0, so a route's cost counts its
# reactions. NOTE(review): for a MolSet graph, MolSetNode presumably needs 1.0 too.
output_graphs = [output_graph]
for graph in output_graphs:
    for node in graph.nodes():
        node.data["route_cost"] = 1.0 if isinstance(node, AndNode) else 0.0


In [None]:
# INDIVIDUAL ROUTES

# Extract up to 5 min-cost synthesis routes with "route extraction". Save each
# as a PDF highlighted on the full search graph, and print its cost.
from pathlib import Path

from syntheseus.search.analysis import route_extraction
from syntheseus.search import visualization
from syntheseus.search.analysis.route_extraction import _min_route_cost

# SMILES may contain "/" or "\" (stereo bonds), which would be read as path
# separators. Replace them so each target maps to a single directory, and make
# sure that directory exists before writing into it.
safe_smiles = single_molecule.replace("/", "_").replace("\\", "_")
routes_dir = Path("Results") / value_fn[0] / safe_smiles
routes_dir.mkdir(parents=True, exist_ok=True)

# This iterator returns sets of nodes which constitute synthesis routes
for i, route_nodes in enumerate(route_extraction.min_cost_routes(output_graph, max_routes=5)):
    visualization.visualize_andor(
        output_graph, filename=str(routes_dir / f"Route_{i+1}.pdf"), nodes=route_nodes
    )
    print('MIN ROUTE COST: ', _min_route_cost(route_nodes, output_graph))

In [None]:
# Print the cost of every extracted route (at most 10000), without saving figures.
all_routes = route_extraction.min_cost_routes(output_graph, max_routes=10000)
for i, route_nodes in enumerate(all_routes):
    route_cost = _min_route_cost(route_nodes, output_graph)
    print('MIN ROUTE COST: ', route_cost)
    
    

In [None]:
# Materialize the routes into a list. min_cost_routes yields lazily, so a bare
# generator would be exhausted by the first loop over it, and re-running the
# next cell would print nothing.
aa = list(route_extraction.min_cost_routes(output_graph, max_routes=10000))

In [None]:
# Debug check: print one marker line per route in `aa`.
# NOTE(review): if `aa` is a generator, this loop consumes it, so a second run
# prints nothing until `aa` is rebuilt.
for i in aa:
    print("hi")