In [None]:
# !export PYTHONPATH=/Users/ilariasartori/syntheseus:/Users/ilariasartori/syntheseus/tutorials/search

In [None]:
# !echo $PYTHONPATH

In [None]:
from datetime import datetime
from uuid import uuid4
import os

# Unique, time-ordered identifier for this run: "YYYYMM-DDHH-MMSS-" followed by a random UUID.
eventid = f"{datetime.now().strftime('%Y%m-%d%H-%M%S-')}{uuid4()}"
print(eventid)

# Every artefact of this run (logs, pickles, CSVs, figures) is written below this folder.
output_folder = "Results/" + eventid
folder_exists = os.path.exists(output_folder)
if not folder_exists:
    os.makedirs(output_folder)



In [None]:
"""Basic code for nearest-neighbour value functions."""
from __future__ import annotations

from enum import Enum

import numpy as np
from rdkit.Chem import DataStructs, AllChem

from syntheseus.search.graph.and_or import OrNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import ExplicitMolInventory

# from Users.ilariasartori.syntheseus.search.graph.and_or import OrNode


class DistanceToCost(Enum):
    """How ``TanimotoNNCostEstimator`` maps a per-molecule distance to a cost.

    For every option except NUM_NEIGHBORS_TO_06 the distance is
    ``1 - max(Tanimoto similarity to the purchasable inventory)``.
    """

    NOTHING = 0  # cost = distance
    EXP = 1  # cost = exp(distance) - 1
    SQRT = 2  # cost = sqrt(distance)
    TIMES10 = 3  # cost = 10 * distance
    TIMES100 = 4  # cost = 100 * distance
    NUM_NEIGHBORS_TO_06 = 5  # cost = normalised number of inventory similarities summed to reach 0.6


class TanimotoNNCostEstimator(NoCacheNodeEvaluator):
    """Estimates cost of a node using Tanimoto distance to purchasable molecules.

    The Morgan fingerprint (radius 3) of each node's molecule is compared with
    the fingerprints of every purchasable molecule in ``inventory``. The
    resulting Tanimoto similarities are turned into a non-negative cost
    according to ``distance_to_cost``.
    """

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.distance_to_cost = distance_to_cost
        self._set_fingerprints([mol.smiles for mol in inventory.purchasable_mols()])

    def get_fingerprint(self, mol: AllChem.Mol):
        """Return the Morgan fingerprint of ``mol`` with radius 3."""
        return AllChem.GetMorganFingerprint(mol, radius=3)

    def _set_fingerprints(self, smiles_list: list[str]) -> None:
        """Initialize fingerprint cache (one fingerprint per inventory molecule)."""
        mols = list(map(AllChem.MolFromSmiles, smiles_list))
        assert None not in mols, "Invalid SMILES encountered."
        self._fps = list(map(self.get_fingerprint, mols))

    @staticmethod
    def find_min_num_elem_summing_to_threshold(array, threshold):
        """Smallest number of elements of ``array`` whose sum reaches ``threshold``.

        Elements are taken greedily from largest to smallest. For non-negative
        values (e.g. Tanimoto similarities) this greedy order gives the minimum
        count.

        Args:
            array: sequence of non-negative numbers.
            threshold: target value for the running sum (reached means ``>=``).

        Returns:
            The minimal count as an int. If even the sum of all elements stays
            below ``threshold``, returns the sentinel ``2.0 * len(array)``.
        """
        # Largest elements first: fewest elements are needed to reach the threshold.
        sorted_array = np.sort(np.asarray(array, dtype=float))[::-1]

        # With non-negative entries the running sum is non-decreasing, as searchsorted requires.
        cum_sum = np.cumsum(sorted_array)

        # First position at which the running sum is >= threshold.
        index = int(np.searchsorted(cum_sum, threshold))

        if index < len(sorted_array):
            return index + 1  # convert 0-based position to an element count

        # Threshold never reached: penalise with a value larger than any reachable count.
        return 2.0 * len(sorted_array)

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Distance-like score of ``smiles`` to the purchasable inventory.

        For NUM_NEIGHBORS_TO_06, the score is the minimal number of inventory
        similarities summing to 0.6, divided by the inventory size (2.0 if 0.6
        is never reached). Otherwise it is ``1 - max`` Tanimoto similarity.
        """
        fp_query = self.get_fingerprint(AllChem.MolFromSmiles(smiles))
        tanimoto_sims = DataStructs.BulkTanimotoSimilarity(fp_query, self._fps)
        if self.distance_to_cost == DistanceToCost.NUM_NEIGHBORS_TO_06:
            num_needed = self.find_min_num_elem_summing_to_threshold(tanimoto_sims, 0.6)
            return num_needed / len(tanimoto_sims)
        return 1 - max(tanimoto_sims)

    def _evaluate_nodes(self, nodes: list[OrNode], graph=None) -> list[float]:
        """Return one non-negative cost per node, in the order of ``nodes``."""
        if len(nodes) == 0:
            return []

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(node.mol.smiles) for node in nodes]
        )
        assert np.min(nn_dists) >= 0

        # Turn into costs
        if self.distance_to_cost in (DistanceToCost.NOTHING, DistanceToCost.NUM_NEIGHBORS_TO_06):
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.EXP:
            values = np.exp(nn_dists) - 1  # shifted so that distance 0 gives cost 0
        elif self.distance_to_cost == DistanceToCost.SQRT:
            values = np.sqrt(nn_dists)
        elif self.distance_to_cost == DistanceToCost.TIMES10:
            values = 10.0 * nn_dists
        elif self.distance_to_cost == DistanceToCost.TIMES100:
            values = 100.0 * nn_dists
        else:
            raise NotImplementedError(self.distance_to_cost)

        return list(values)


In [None]:
route_div = True  # If True, run_algorithm also counts diverse routes per target (packing number); else that count is 0

In [None]:
"""
Demo script comparing nearest neighbour cost function with constant value function on PaRoutes.
"""
from __future__ import annotations

import argparse
import logging
import sys
import numpy as np

from tqdm.auto import tqdm

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch, MolIsPurchasableCost
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.node_evaluation.base import (
    BaseNodeEvaluator,
    NoCacheNodeEvaluator,
)
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator

from paroutes import PaRoutesInventory, PaRoutesModel, get_target_smiles
# from neighbour_value_functions import TanimotoNNCostEstimator, DistanceToCost

from syntheseus.search.analysis import diversity
# from syntheseus.search.algorithms.best_first.retro_star import MolIsPurchasableCost

class SearchResult:
    """Per-target outputs of one search algorithm.

    Every ``*_dict`` attribute maps a target SMILES string to the
    corresponding quantity for that target.
    """

    def __init__(
        self,
        name,
        soln_time_dict,
        num_different_routes_dict,
        final_num_rxn_model_calls_dict,
        output_graph_dict,
        routes_dict,
    ):
        # Label of the algorithm / value function that produced these results.
        self.name = name
        # Time of the first solution found for each target.
        self.soln_time_dict = soln_time_dict
        # Number of mutually diverse routes found for each target.
        self.num_different_routes_dict = num_different_routes_dict
        # Reaction-model calls used by the search of each target.
        self.final_num_rxn_model_calls_dict = final_num_rxn_model_calls_dict
        # Final search graph of each target.
        self.output_graph_dict = output_graph_dict
        # Extracted routes of each target.
        self.routes_dict = routes_dict


class PaRoutesRxnCost(NoCacheNodeEvaluator[AndNode]):
    """Cost of a reaction: negative log of its softmax, clipped to [0.1, 10.0].

    The softmax is read from ``node.reaction.metadata["softmax"]``. A softmax of
    0 gives an infinite negative log, which the clip maps to the upper bound 10.0.
    """

    def _evaluate_nodes(self, nodes: list[AndNode], graph=None) -> list[float]:
        softmaxes = np.asarray(
            [node.reaction.metadata["softmax"] for node in nodes], dtype=float
        )
        # log(0) = -inf is expected and handled by the clip; don't emit a RuntimeWarning for it.
        with np.errstate(divide="ignore"):
            neg_log_softmax = -np.log(softmaxes)
        costs = np.clip(neg_log_softmax, 1e-1, 10.0)
        return costs.tolist()


def run_algorithm(
    name: str,
    smiles_list: list[str],
    value_function: BaseNodeEvaluator,
    rxn_model: BackwardReactionModel,
    inventory: BaseMolInventory,
    and_node_cost_fn: BaseNodeEvaluator[AndNode],
    or_node_cost_fn: BaseNodeEvaluator[OrNode],
    max_expansion_depth: int = 15,
    prevent_repeat_mol_in_trees: bool = True,
    use_tqdm: bool = False,
    limit_rxn_model_calls: int = 100,
    limit_iterations: int = 1_000_000,
    logger: logging.RootLogger = logging.getLogger(),
) -> SearchResult:
    """
    Run Retro* search on every SMILES in ``smiles_list`` and collect statistics.

    For each target this records:
    - the time of the first solution (``get_first_solution_time`` with each
      node's ``analysis_time`` set to its ``num_calls_rxn_model``);
    - the number of reaction-model calls used;
    - the number of diverse routes, only if the notebook-level flag
      ``route_div`` is truthy (0 otherwise);
    - the search graph and up to ``MAX_ROUTES`` minimum-cost routes.

    Args:
        name: label of this algorithm, stored in the result and used in logs.
        smiles_list: target molecules.
        value_function: Retro* value function for OR (molecule) nodes.
        rxn_model: backward reaction model used for expansions.
        inventory: purchasable-molecule inventory.
        and_node_cost_fn: cost of reactions (AND nodes).
        or_node_cost_fn: cost of molecules (OR nodes).
        max_expansion_depth: maximum depth of expansions.
        prevent_repeat_mol_in_trees: forbid a molecule appearing twice on a path.
        use_tqdm: show a progress bar over targets.
        limit_rxn_model_calls: reaction-model call budget per target.
        limit_iterations: iteration budget per target.
        logger: destination for progress and per-target debug logs.

    Returns:
        A ``SearchResult`` whose dicts are keyed by target SMILES.
    """
    MAX_ROUTES = 10000  # cap on the number of min-cost routes extracted per target

    # Initialize algorithm with the cost functions requested by the caller.
    alg = RetroStarSearch(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        limit_iterations=limit_iterations,
        max_expansion_depth=max_expansion_depth,  # prevent overly-deep solutions
        prevent_repeat_mol_in_trees=prevent_repeat_mol_in_trees,  # original paper did this
        and_node_cost_fn=and_node_cost_fn,
        or_node_cost_fn=or_node_cost_fn,
        value_function=value_function,
    )

    # Do search
    logger.info(f"Start search with {name}")
    smiles_iter = tqdm(smiles_list) if use_tqdm else smiles_list

    output_graph_dict = {}
    soln_time_dict = {}
    routes_dict = {}
    final_num_rxn_model_calls_dict = {}
    num_different_routes_dict = {}

    for i, smiles in enumerate(smiles_iter):
        logger.debug(f"Start search {i}/{len(smiles_list)}. SMILES: {smiles}")
        alg.reset()
        output_graph, _ = alg.run_from_mol(Molecule(smiles))

        # Analyze solution time, measured in reaction-model calls.
        for node in output_graph.nodes():
            node.data["analysis_time"] = node.data["num_calls_rxn_model"]
        soln_time = get_first_solution_time(output_graph)

        # Analyze number of routes
        routes = list(min_cost_routes(output_graph, MAX_ROUTES))

        num_calls = alg.reaction_model.num_calls()
        if num_calls < limit_rxn_model_calls:
            note = " (NOTE: this was less than the maximum budget)"
        else:
            note = ""
        logger.debug(
            f"Done {name}: nodes={len(output_graph)}, solution time = {soln_time}, "
            f"num routes = {len(routes)} (capped at {MAX_ROUTES}), "
            f"final num rxn model calls = {num_calls}{note}."
        )

        # Analyze route diversity (controlled by the notebook-level `route_div` flag)
        if len(routes) > 0 and route_div:
            route_objects = [output_graph.to_synthesis_graph(nodes) for nodes in routes]
            packing_set = diversity.estimate_packing_number(
                routes=route_objects,
                distance_metric=diversity.reaction_jaccard_distance,
                radius=0.999,  # because comparison is > not >=
            )
            logger.debug(f"number of distinct routes = {len(packing_set)}")
        else:
            packing_set = []

        # Save results
        soln_time_dict[smiles] = soln_time
        final_num_rxn_model_calls_dict[smiles] = num_calls
        num_different_routes_dict[smiles] = len(packing_set)
        output_graph_dict[smiles] = output_graph
        routes_dict[smiles] = routes

    return SearchResult(
        name=name,
        soln_time_dict=soln_time_dict,
        num_different_routes_dict=num_different_routes_dict,
        final_num_rxn_model_calls_dict=final_num_rxn_model_calls_dict,
        output_graph_dict=output_graph_dict,
        routes_dict=routes_dict,
    )


    



In [None]:
# Arguments

# COMMAND LINE
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--limit_num_smiles",
#         type=int,
#         default=None,
#         help="Maximum number of SMILES to run.",
#     )
#     parser.add_argument(
#         "--limit_iterations",
#         type=int,
#         default=500,
#         help="Maximum number of algorithm iterations.",
#     )
#     parser.add_argument(
#         "--limit_rxn_model_calls",
#         type=int,
#         default=25,
#         help="Allowed number of calls to reaction model.",
#     )
#     parser.add_argument(
#         "--paroutes_n",
#         type=int,
#         default=5,
#         help="Which PaRoutes benchmark to use.",
#     )
#     args = parser.parse_args()

# NOTEBOOK
class Args:
    """Notebook replacement for the command-line arguments of the experiment."""

    # --- Benchmark selection ---
    paroutes_n = 5
    limit_num_smiles = None  # None -> no extra limit on the number of targets

    # --- Search budget ---
    limit_iterations = 500
    limit_rxn_model_calls = 100
    max_expansion_depth = 20
    prevent_repeat_mol_in_trees = True

    # --- Components (string keys resolved when building the experiment) ---
    rxn_model = 'PAROUTES'
    max_num_templates = 10  # Default 50
    inventory = 'PAROUTES'
    and_node_cost_fn = 'PAROUTES'
    or_node_cost_fn = 'MOL_PURCHASABLE'


args = Args()


In [None]:
# `value_fns` is only defined in the experiment cell further down. Guard the lookup so
# that Restart & Run All on a fresh kernel does not stop here with a NameError.
try:
    alg_names = [entry[0] for entry in value_fns]
except NameError:
    alg_names = []
alg_names

## Load from pickle

In [None]:
# # # Load pickle
# import pickle
# import os

# # eventid= "202305-1412-3438-d0f3baee-c6ce-4444-830e-e38d536c9bfa"
# # output_folder = f"Results/{eventid}"

# result = {}
# for file_name in [file for file in os.listdir(output_folder) if 'pickle' in file]:
#     name = file_name.replace('.pickle','').replace('result_','')
#     with open(f'{output_folder}/{file_name}', 'rb') as handle:
#         result[name] = pickle.load(handle)



## Create dataframe for time to solution and number of routes found

In [None]:
import json

# Logging: DEBUG and above go to a per-run log file, INFO and above to stdout.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')

# Re-running this cell would otherwise stack extra handlers on the root logger, so every
# message would be printed/written several times. Remove (and close, releasing the
# previous log file) the handlers added by an earlier run of this cell.
_EXPERIMENT_HANDLER_NAMES = {"experiment_stdout", "experiment_file"}
for _old_handler in list(logger.handlers):
    if _old_handler.name in _EXPERIMENT_HANDLER_NAMES:
        logger.removeHandler(_old_handler)
        _old_handler.close()

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.name = "experiment_stdout"
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)

file_handler = logging.FileHandler(f'{output_folder}/logs.txt', mode='w')
file_handler.name = "experiment_file"
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stdout_handler)


# Load all SMILES to test (kept untouched in `test_smiles_all`).
test_smiles_all = get_target_smiles(args.paroutes_n).copy()

## Test on smaller dataset: only the first `dim_test` targets are searched.
dim_test = 100
test_smiles = test_smiles_all[:dim_test]

# Optional further cap from the experiment arguments.
if args.limit_num_smiles is not None:
    test_smiles = test_smiles[: args.limit_num_smiles]

# Make reaction model, inventory and cost functions.
# Each setting supports a single option; anything else raises NotImplementedError.
if args.and_node_cost_fn != 'PAROUTES':
    raise NotImplementedError(f'and_node_cost_fn: {args.and_node_cost_fn}')
and_node_cost_fn = PaRoutesRxnCost()

if args.or_node_cost_fn != 'MOL_PURCHASABLE':
    raise NotImplementedError(f'or_node_cost_fn: {args.or_node_cost_fn}')
or_node_cost_fn = MolIsPurchasableCost()

if args.inventory != 'PAROUTES':
    raise NotImplementedError(f'inventory: {args.inventory}')
inventory = PaRoutesInventory(n=args.paroutes_n)

if args.rxn_model != 'PAROUTES':
    raise NotImplementedError(f'rxn_model: {args.rxn_model}')
rxn_model = PaRoutesModel(max_num_templates=args.max_num_templates)


# Value functions to compare, as (label, evaluator) pairs.
# Alternatives currently disabled (add them to the list to include them):
#   ("constant-0", ConstantNodeEvaluator(0.0))
#   ("Tanimoto-distance", TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=DistanceToCost.NOTHING))
#   ("Tanimoto-distance-TIMES10", TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=DistanceToCost.TIMES10))
#   ("Tanimoto-distance-TIMES100", TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=DistanceToCost.TIMES100))
#   ("Tanimoto-distance-EXP", TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=DistanceToCost.EXP))
#   ("Tanimoto-distance-SQRT", TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=DistanceToCost.SQRT))
value_fns = [
    (
        "Tanimoto-distance-NUM_NEIGHBORS_TO_06",
        TanimotoNNCostEstimator(
            inventory=inventory,
            distance_to_cost=DistanceToCost.NUM_NEIGHBORS_TO_06,
        ),
    ),
]

# Short display name for every possible label; the key order also fixes the x-axis order in the plots.
labelalias = {
    'constant-0': 'constant-0',
    'Tanimoto-distance': 'Tanimoto',
    'Tanimoto-distance-TIMES10': 'Tanimoto_times10',
    'Tanimoto-distance-TIMES100': 'Tanimoto_times100',
    'Tanimoto-distance-EXP': 'Tanimoto_exp',
    'Tanimoto-distance-SQRT': 'Tanimoto_sqrt',
    "Tanimoto-distance-NUM_NEIGHBORS_TO_06": "Tanimoto_nn_to_06",
}



# Run
logger.info(f"Start experiment {eventid}")
# One "\n<name>: <value>" entry per plain (non-callable, non-dunder) attribute of `args`.
args_string = "".join(
    "\n" + f"{attr}: {getattr(args, attr)}"
    for attr in dir(args)
    if not (callable(getattr(args, attr)) or attr.startswith("__"))
)
logger.info(f"Args: {args_string}")
logger.info(f"dim_test: {dim_test}")


import pickle

# Settings shared by every value function under comparison.
run_kwargs = dict(
    smiles_list=test_smiles,
    rxn_model=rxn_model,
    inventory=inventory,
    and_node_cost_fn=and_node_cost_fn,
    or_node_cost_fn=or_node_cost_fn,
    max_expansion_depth=args.max_expansion_depth,
    prevent_repeat_mol_in_trees=args.prevent_repeat_mol_in_trees,
    use_tqdm=True,
    limit_rxn_model_calls=args.limit_rxn_model_calls,
    limit_iterations=args.limit_iterations,
    logger=logger,
)

result = {}
for name, fn in value_fns:
    alg_result = run_algorithm(name=name, value_function=fn, **run_kwargs)
    result[name] = alg_result

    # Checkpoint each algorithm's results as soon as they are available.
    pickle_path = f'{output_folder}/result_{name}.pickle'
    with open(pickle_path, 'wb') as handle:
        pickle.dump(alg_result, handle, protocol=pickle.HIGHEST_PROTOCOL)




In [None]:
import pandas as pd

def create_result_df(result, name):
    """Long-format table of one algorithm's per-target metrics.

    Args:
        result: mapping from algorithm label to its ``SearchResult``.
        name: key into ``result``; must equal ``result[name].name``.

    Returns:
        DataFrame with columns ``algorithm``, ``similes`` (target SMILES),
        ``property`` and ``value``. It holds one ``'sol_time'`` row per target
        (from ``soln_time_dict``), followed by one ``'diff_routes'`` row per
        target (from ``num_different_routes_dict``). The index restarts at 0
        for each of the two property blocks.

    Raises:
        AssertionError: if ``name`` does not match ``result[name].name``.
    """
    alg_result = result[name]
    assert name == alg_result.name, f"name: {name} is different from result[name].name: {alg_result.name}"

    columns = ['algorithm', 'similes', 'property', 'value']

    def _property_rows(metric_dict, property_name):
        # Build the frame once from records; growing it with pd.concat per row is quadratic.
        records = [
            {'algorithm': name, 'similes': smiles, 'property': property_name, 'value': value}
            for smiles, value in metric_dict.items()
        ]
        return pd.DataFrame(records, columns=columns)

    df_soln_time = _property_rows(alg_result.soln_time_dict, 'sol_time')
    df_different_routes = _property_rows(alg_result.num_different_routes_dict, 'diff_routes')

    df_results_tot = pd.concat([df_soln_time, df_different_routes], axis=0)
    return df_results_tot



# Stack the per-algorithm long-format tables. Concatenate once instead of repeatedly
# concatenating onto an empty seed frame, which is quadratic and degrades the dtypes.
result_frames = [create_result_df(result, name) for name in result.keys()]
if result_frames:
    df_results_tot = pd.concat(result_frames, axis=0)
else:
    # No algorithm was run: keep an empty table with the expected columns.
    df_results_tot = pd.DataFrame({'algorithm': [], 'similes': [], 'property': [], 'value': []})
    
    
    
    

In [None]:
# Combined long-format table: one row per algorithm x target x property ('sol_time' / 'diff_routes').
df_results_tot

In [None]:
# Save to csv (inside this run's output folder); the alternative path below is disabled.
# df_results_tot.to_csv(f'Results/Compare/compare_times_{dim_test}.csv', index=False)
df_results_tot.to_csv(f'{output_folder}/results_all.csv', index=False)


In [None]:
# Load csv
import pandas as pd
import numpy as np

# To analyse an earlier run instead, uncomment and set its eventid / folder:
# eventid= "202305-1310-3717-7e7e984c-8c3e-4a18-ad67-5c4b29743282"
# output_folder = f"Results/{eventid}"

# Reload the combined results from the CSV in `output_folder` so the analysis below works from disk.
df_results_tot = pd.read_csv(f'{output_folder}/results_all.csv')

### 1. Solution times

In [None]:
# Keep only the time-to-first-solution rows.
results_solution_times = df_results_tot.loc[df_results_tot['property']=='sol_time']

In [None]:
# Independent copy, so the indicator column added next does not modify `results_solution_times`.
df_result = results_solution_times.copy()

In [None]:
# 1 if the target was not solved (solution time is inf), else 0.
df_result["value_is_inf"] = (df_result['value'] == np.inf) * 1


In [None]:
# Number of unsolved targets per algorithm.
df_results_grouped = df_result.groupby(["algorithm", "property"], as_index=False).agg(nr_mol_not_solved=pd.NamedAgg(column="value_is_inf", aggfunc="sum"))
df_results_grouped


In [None]:
# Persist the per-algorithm count of unsolved targets.
df_results_grouped.to_csv(f'{output_folder}/num_mol_not_solved.csv', index=False)

In [None]:
import plotly.express as px
# Distribution of time to first solution per algorithm (unsolved targets have value inf).
fig = px.box(
    df_result,
    x="algorithm",
    y="value",
    width=1000,
    height=600,
    labels={"value": "Time to first solution"},
)
# No x-axis title; short algorithm names, ordered as in `labelalias`.
fig.update_layout(xaxis_title=None)
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_time_first_solution.png')
fig.show()

### 2. Solution diversity

In [None]:
# Keep only the number-of-diverse-routes rows.
results_diff_routes = df_results_tot.loc[df_results_tot['property']=='diff_routes']

In [None]:
# Independent copy, so the indicator column added next does not modify `results_diff_routes`.
df_result = results_diff_routes.copy()

In [None]:
# 1 if no diverse route was counted for the target (count is 0), else 0.
df_result["value_is_zero"] = (df_result['value'] == 0) * 1


In [None]:
# Number of targets with zero diverse routes per algorithm (reuses the `nr_mol_not_solved` column name).
df_results_grouped = df_result.groupby(["algorithm", "property"], as_index=False).agg(nr_mol_not_solved=pd.NamedAgg(column="value_is_zero", aggfunc="sum"))
df_results_grouped


In [None]:
import plotly.express as px
# Distribution of the number of diverse routes per algorithm.
fig = px.box(
    df_result,
    x="algorithm",
    y="value",
    width=1000,
    height=600,
    labels={"value": "Number of different routes"},
)
# No x-axis title; short algorithm names, ordered as in `labelalias`.
fig.update_layout(xaxis_title=None)
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_num_different_routes.png')
fig.show()

In [None]:
# Same distribution restricted to targets with at least one diverse route.
nonzero_routes = df_result.loc[df_result['value'] != 0]
fig = px.box(
    nonzero_routes,
    x="algorithm",
    y="value",
    width=1000,
    height=600,
    labels={"value": "Number of different routes (removing zeros)"},
)
fig.update_layout(xaxis_title=None)
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_num_different_routes_no_zero.png')
fig.show()


## Correlation: value function - actual cost

In [None]:
# Algorithms whose search graphs are analysed below. The preferred ones are used only if
# they were actually run in this session. Asking for an algorithm missing from `result`
# would raise a KeyError in the following cells, so fall back to every available
# algorithm when none of the preferred ones is present.
preferred_algs = ['Tanimoto-distance-TIMES10']
algs_to_consider = [alg for alg in preferred_algs if alg in result]
if not algs_to_consider:
    algs_to_consider = list(result.keys())

algs_string = '_'.join(algs_to_consider)
algs_string

### 1. Assign costs

In [None]:
import pandas as pd

cost_type = "cost_1_react"
# cost_type = "cost_react_from_data"
# cost_type = "cost_react_from_data_pow01"

# Cost given to each reaction (AndNode) for the route-cost analysis; molecules (OrNodes) cost 0.
_and_node_route_cost_fns = {
    # Every reaction costs 1, so a route's cost is its number of reactions.
    "cost_1_react": lambda node: 1.0,
    # Reaction cost assigned by Retro* during search.
    "cost_react_from_data": lambda node: node.data["retro_star_rxn_cost"],
    # Retro* reaction cost raised to the power 0.1.
    "cost_react_from_data_pow01": lambda node: np.power(node.data["retro_star_rxn_cost"], 0.1),
}
# Validate once, up front; checking inside the loop skipped validation when no algorithm was selected.
if cost_type not in _and_node_route_cost_fns:
    raise NotImplementedError(f'Cost type {cost_type}')
and_node_route_cost = _and_node_route_cost_fns[cost_type]

for name in algs_to_consider:
    for target_smiles, graph in result[name].output_graph_dict.items():
        for node in graph.nodes():
            if isinstance(node, AndNode):
                node.data["route_cost"] = and_node_route_cost(node)
            else:
                node.data["route_cost"] = 0.0

### 2. Create dataframe with values and costs

In [None]:
from syntheseus.search.analysis import route_extraction
from syntheseus.search import visualization
from syntheseus.search.analysis.route_extraction import _min_route_cost, _min_route_partial_cost
import networkx as nx

import heapq
import math
from collections.abc import Collection, Iterator
from typing import Callable, Optional, TypeVar

# def get_descendants(graph, node):
#     descendants_set = set(graph.successors(node))
#     for graph.successors(node)


def custom_cost_min_route(
    graph: "RetrosynthesisSearchGraph",
    start_node,
    cost_fn: "Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float]",
    cost_lower_bound: "Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float]",
    max_routes: int,
    yield_partial_routes: bool = False,
) -> "Optional[tuple[float, Collection[NodeType]]]":
    """
    Return the lowest-cost minimal tree (route) rooted at ``start_node``.

    This is a best-first search over partial routes, adapted so that the search
    starts from ``start_node`` instead of the graph's root. Only AND/OR graphs
    are supported: every node that gets expanded must be an ``OrNode``.

    NOTE: despite the name of ``max_routes``, only the first (cheapest) route is
    returned; callers rely on the ``(cost, nodes)`` / ``None`` return value.

    Args:
        graph: graph to search. Could be tree, but does not need to be.
        start_node: OR node the routes are rooted at.
        cost_fn: Gives the cost of a route (specified by the set of nodes).
            A cost of inf means the route will not be returned.
        cost_lower_bound: A lower bound of the cost. The lower bound means that
            if the function is evaluated on a set A, the cost of a set B >= A
            will always exceed this lower bound.
            This function will always be evaluated on partial routes.
        max_routes: if ``<= 0`` no search is done and ``None`` is returned;
            any positive value behaves like 1.
        yield_partial_routes: if True, also accept routes whose leaves
            have children in the full graph. This could be useful if, for example,
            there are purchasable molecules which have children.
            Typically this will be undesirable though.

    Returns:
        ``(cost, route_nodes)`` for the cheapest complete route with finite cost,
        or ``None`` if there is no such route.

    Raises:
        TypeError: if a node to expand is not an ``OrNode``.
    """
    if max_routes <= 0:
        return None

    # Priority queue items: cost, whether the cost is the true cost or a lower bound,
    # tie-breaking integer (since sets cannot be ordered),
    # set of nodes in partial route, list of nodes on the route's frontier.
    queue = [(-math.inf, False, 0, {start_node}, [start_node])]
    tie_breaker = 1

    while queue:
        cost, is_true_cost, _, partial_route, route_frontier = heapq.heappop(queue)
        assert cost < math.inf, "Infinite cost routes should not be in the queue."

        # A complete route popped first is cheaper than every remaining partial route's bound.
        if is_true_cost:
            assert len(route_frontier) == 0
            return cost, partial_route

        # Choose the first node in the frontier to be "expanded" and re-add to the queue
        assert len(route_frontier) > 0
        node_to_expand = route_frontier[0]
        remaining_frontier = route_frontier[1:]
        possible_new_routes = []

        # Potentially add this node without any of its children
        if len(list(graph.successors(node_to_expand))) == 0 or yield_partial_routes:
            possible_new_routes.append((partial_route, remaining_frontier))

        if not isinstance(node_to_expand, OrNode):
            raise TypeError(f"Unknown node type {type(node_to_expand)}.")

        # For AND/OR trees, add each And child together with all of its children
        for and_child in graph.successors(node_to_expand):
            and_child_children = list(graph.successors(and_child))
            new_partial_route = partial_route | {and_child} | set(and_child_children)
            # Nodes already in the partial route are either expanded or already on the frontier
            new_frontier = remaining_frontier + [
                n for n in and_child_children if n not in partial_route
            ]
            possible_new_routes.append((new_partial_route, new_frontier))

        # Add all possible routes with finite (bound on) cost onto the queue
        for new_partial_route, new_frontier in possible_new_routes:
            if len(new_frontier) == 0:
                new_cost = cost_fn(new_partial_route, graph)
                assert new_cost >= cost, "lower bound not satisfied"
                new_cost_is_full = True
            else:
                new_cost = cost_lower_bound(new_partial_route, graph)
                new_cost_is_full = False

            if new_cost < math.inf:
                heapq.heappush(
                    queue,
                    (new_cost, new_cost_is_full, tie_breaker, new_partial_route, new_frontier),
                )
                tie_breaker += 1

    return None


# def reachable_nodes(G, n):
#     """
#     Returns the set of nodes that can be reached starting from node n in graph G.
#     """
#     visited = set()  # Set to keep track of visited nodes
#     stack = [n]  # Stack to keep track of nodes to explore
    
#     while stack:
#         node = stack.pop()
#         visited.add(node)
#         successors = G.successors(node)
#         for s in successors:
#             if s not in visited:
#                 stack.append(s)
    
#     return visited

# One row per molecule (OrNode) of every analysed search graph. Each row holds basic node
# attributes, every entry of `node.data`, and the cost of the cheapest route below the node
# (computed with the `route_cost` values assigned earlier; inf if there is no route).
rows = []
for name in algs_to_consider:
    output_graph_dict = result[name].output_graph_dict
    for target_smiles, graph in output_graph_dict.items():
        for node in graph.nodes():
            if not isinstance(node, OrNode):  # only molecules
                continue
            row_data = {
                'name': name,
                'smiles': node.mol.smiles,
                'is_purchasable': node.mol.metadata["is_purchasable"],
                'node_is_expanded': node.is_expanded,
                'node_depth': node.depth,
            }
            row_data.update(node.data)

            min_route_result = custom_cost_min_route(
                graph=graph,
                start_node=node,
                cost_fn=_min_route_cost,
                cost_lower_bound=_min_route_partial_cost,
                max_routes=1,
                yield_partial_routes=False,
            )
            min_cost = np.inf if min_route_result is None else min_route_result[0]
            row_data['minimal_cost_forward'] = min_cost
            # append is O(1); `rows = rows + [row_data]` copied the whole list every time.
            rows.append(row_data)
df_nodes = pd.DataFrame(rows)

df_nodes

In [None]:
# 1.0 when the node's first_solution_time is not inf (a solution was found), else 0.0.
df_nodes['is_solved'] = df_nodes['first_solution_time'].ne(np.inf).astype(float)


In [None]:
# Persist the per-node table, named after the analysed algorithms.
df_nodes.to_csv(f'{output_folder}/{algs_string}_df_nodes.csv', index=False)



In [None]:
# Nodes of the analysed algorithms. `.copy()` makes this an independent frame, so adding
# the indicator columns below does not raise pandas' SettingWithCopyWarning (assignment
# to a frame derived from a `.loc` selection of `df_nodes`).
df_nodes_red = df_nodes.loc[df_nodes['name'].isin(algs_to_consider)].copy()

solved_mask = df_nodes_red['is_solved'] == 1.0
df_nodes_red_solved = df_nodes_red.loc[solved_mask]
df_nodes_red_not_solved = df_nodes_red.loc[~solved_mask]

x_axis_var = 'minimal_cost_forward'
# y_axis_var = 'retro_star_value'
y_axis_var = 'reaction_number'

# 1.0 where the plotted quantity is infinite, else 0.0.
df_nodes_red['x_var_inf'] = (df_nodes_red[x_axis_var] == np.inf) * 1.0
df_nodes_red['y_var_inf'] = (df_nodes_red[y_axis_var] == np.inf) * 1.0


In [None]:
# Number of nodes per (solved?, x-variable infinite?) combination.
df_nodes_red.groupby(['is_solved', 'x_var_inf']).agg(
    count=pd.NamedAgg(column="smiles", aggfunc="count"))


In [None]:
# Number of nodes per (solved?, y-variable infinite?) combination.
df_nodes_red.groupby(['is_solved', 'y_var_inf']).agg(
    count=pd.NamedAgg(column="smiles", aggfunc="count"))



In [None]:
import plotly.express as px

# Distribution of the x-axis variable across solved nodes only (displayed, not saved).
fig = px.box(
    data_frame=df_nodes_red_solved,
    y=x_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.show()

In [None]:
# Compare the y-axis variable between unsolved (0.0) and solved (1.0) nodes; save as PNG and display.
fig = px.box(
    data_frame=df_nodes_red,
    x='is_solved',
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.write_image(f'{output_folder}/{y_axis_var}_{algs_string}_solved_not_solved.png')
fig.show()



In [None]:
# fig = px.box(df_nodes_red_not_solved, y=y_axis_var, 
#              width=1000, height=600,
#              labels={
# #                      "value": "Number of different routes (removing zeros)",
#                  },
#             )

# # fig.update_layout(xaxis_title=None)
# # fig.update_xaxes(labelalias=labelalias, categoryorder='array', categoryarray=list(labelalias.keys()))
# # fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}.png') 
# fig.show() 






In [None]:
import plotly.express as px

# x vs. y over solved nodes; saved as PNG and displayed.
fig = px.scatter(
    data_frame=df_nodes_red_solved,
    x=x_axis_var,
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}.png')
fig.show()



In [None]:
import plotly.express as px

# Same x vs. y scatter over solved nodes, coloured by node_depth; saved as PNG and displayed.
fig = px.scatter(
    data_frame=df_nodes_red_solved,
    x=x_axis_var,
    y=y_axis_var,
    color="node_depth",
    labels={},
    width=1000,
    height=600,
)
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}_by_node_depth.png')
fig.show()




In [None]:
# x vs. y scatter restricted to solved nodes with node_depth < 5; saved as PNG and displayed.
fig = px.scatter(
    data_frame=df_nodes_red_solved[df_nodes_red_solved['node_depth'].lt(5)],
    x=x_axis_var,
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}_depth_below_5.png')
fig.show()



In [None]:
# Box plot of y grouped by x value over all solved nodes; saved as PNG and displayed.
fig = px.box(
    data_frame=df_nodes_red_solved,
    x=x_axis_var,
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
# NOTE(review): categoryorder only takes effect on a categorical axis; if plotly
# infers a numeric (linear) x axis here, this call is likely a no-op — confirm.
fig.update_xaxes(categoryorder='category ascending')
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}_boxplot.png')
fig.show()



In [None]:

# Box plot of y grouped by x, restricted to solved nodes with node_depth < 5; saved as PNG and displayed.
fig = px.box(
    data_frame=df_nodes_red_solved[df_nodes_red_solved['node_depth'].lt(5)],
    x=x_axis_var,
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.update_xaxes(categoryorder='category ascending')
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}_boxplot_depth_below_5.png')
fig.show()

In [None]:

# Box plot of y grouped by x, restricted to solved nodes with node_depth >= 5
# (the complement of the "< 5" plot); saved as PNG and displayed.
fig = px.box(
    data_frame=df_nodes_red_solved[df_nodes_red_solved['node_depth'].ge(5)],
    x=x_axis_var,
    y=y_axis_var,
    labels={},
    width=1000,
    height=600,
)
fig.update_xaxes(categoryorder='category ascending')
# Note: despite the "above_5" suffix, this subset includes depth 5 itself.
fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}_{algs_string}_boxplot_depth_above_5.png')
fig.show()