In [1]:
# !export PYTHONPATH=/Users/ilariasartori/syntheseus:/Users/ilariasartori/syntheseus/tutorials/search

In [2]:
# !echo $PYTHONPATH

In [3]:
from datetime import datetime
from uuid import uuid4
import os

# To start a fresh run, generate a new id instead of the pinned one below:
# eventid = datetime.now().strftime('%Y%m-%d%H-%M%S-') + str(uuid4())
# The id is pinned so that re-running the notebook writes into the same output folder.
eventid = '202306-2611-1616-0f701fe7-bbee-4d2c-83f7-bba18beb858a'
print(eventid)

# Logs and pickled results of this experiment are written under this folder.
output_folder = f"CompareTanimotoLearnt/{eventid}"

# exist_ok=True avoids the check-then-create race and is a no-op if the folder exists.
os.makedirs(output_folder, exist_ok=True)



202306-2611-1616-0f701fe7-bbee-4d2c-83f7-bba18beb858a


In [4]:
import json
import pickle

In [5]:
"""Basic code for nearest-neighbour value functions."""
from __future__ import annotations

from enum import Enum

import numpy as np
from rdkit.Chem import DataStructs, AllChem

from syntheseus.search.graph.and_or import OrNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import ExplicitMolInventory

import torch
import torch.nn.functional as F

# from Users.ilariasartori.syntheseus.search.graph.and_or import OrNode


class DistanceToCost(Enum):
    """Selects how a nearest-neighbour distance is turned into a node cost.

    The mapping is applied in the estimators' ``_evaluate_nodes`` methods.
    """

    NOTHING = 0  # cost = distance
    EXP = 1  # cost = exp(distance) - 1
    SQRT = 2  # cost = sqrt(distance)
    TIMES10 = 3  # cost = 10 * distance
    TIMES100 = 4  # cost = 100 * distance
    NUM_NEIGHBORS_TO_1 = 5  # distance based on how many neighbours' similarities sum to 1


class TanimotoNNCostEstimator(NoCacheNodeEvaluator):
    """Estimates cost of a node using Tanimoto distance to purchasable molecules.

    Morgan fingerprints (radius 3) of every purchasable molecule are computed once at
    construction. A node is scored by comparing its molecule's fingerprint against all
    of them and mapping the resulting distance to a cost via ``distance_to_cost``.
    """

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.distance_to_cost = distance_to_cost
        self._set_fingerprints([mol.smiles for mol in inventory.purchasable_mols()])

    def get_fingerprint(self, mol: AllChem.Mol):
        """Morgan fingerprint of ``mol`` with radius 3."""
        return AllChem.GetMorganFingerprint(mol, radius=3)

    def _set_fingerprints(self, smiles_list: list[str]) -> None:
        """Initialize fingerprint cache (one fingerprint per purchasable SMILES)."""
        mols = list(map(AllChem.MolFromSmiles, smiles_list))
        assert None not in mols, "Invalid SMILES encountered."
        self._fps = list(map(self.get_fingerprint, mols))

    def find_min_num_elem_summing_to_threshold(self, array, threshold):
        """Return the fewest elements of ``array`` whose sum reaches ``threshold``.

        Elements are taken greedily from largest to smallest. If even the whole array
        does not reach ``threshold``, ``len(array)`` is returned (0 for an empty array).
        """
        # Descending order so the cumulative sum grows as fast as possible.
        cum_sum = np.cumsum(np.sort(array)[::-1])
        # cum_sum is non-decreasing for non-negative inputs (Tanimoto similarities are),
        # so searchsorted finds the first position where it is >= threshold.
        index = int(np.searchsorted(cum_sum, threshold))
        # index == len(array) means the threshold is never reached.
        return min(index + 1, len(array))

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Distance in [0, 1] from ``smiles`` to the purchasable molecules.

        ``NUM_NEIGHBORS_TO_1``: (k - 1) / N, where k is the number of most-similar
        purchasable molecules whose similarities first sum to at least 1 and N is the
        inventory size. Otherwise: 1 - (max Tanimoto similarity).
        Both are 0 for an exact fingerprint match and increase as similarity to the
        inventory decreases.

        Raises:
            ValueError: if ``smiles`` cannot be parsed or the inventory is empty.
        """
        mol = AllChem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {smiles}")
        if not self._fps:
            raise ValueError("No purchasable molecules to compare against.")
        tanimoto_sims = DataStructs.BulkTanimotoSimilarity(self.get_fingerprint(mol), self._fps)
        if self.distance_to_cost == DistanceToCost.NUM_NEIGHBORS_TO_1:
            num_needed = self.find_min_num_elem_summing_to_threshold(array=tanimoto_sims, threshold=1)
            # The previous formula, 1 - k/N, was inverted: an exact match (k=1) scored ~1,
            # while a molecule unlike everything (k=N) scored 0, i.e. looked most promising.
            return (num_needed - 1) / len(tanimoto_sims)
        return 1 - max(tanimoto_sims)

    def _evaluate_nodes(self, nodes: list[OrNode], graph=None) -> list[float]:
        """Cost of each node's molecule: its nearest-neighbour distance mapped by ``distance_to_cost``."""
        if len(nodes) == 0:
            return []

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(node.mol.smiles) for node in nodes]
        )
        assert np.min(nn_dists) >= 0, f'Negative distance: {np.min(nn_dists)} '

        # Turn into costs
        if self.distance_to_cost in (DistanceToCost.NOTHING, DistanceToCost.NUM_NEIGHBORS_TO_1):
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.EXP:
            values = np.exp(nn_dists) - 1
        elif self.distance_to_cost == DistanceToCost.SQRT:
            values = np.sqrt(nn_dists)
        elif self.distance_to_cost == DistanceToCost.TIMES10:
            values = 10.0 * nn_dists
        elif self.distance_to_cost == DistanceToCost.TIMES100:
            values = 100.0 * nn_dists
        else:
            raise NotImplementedError(self.distance_to_cost)

        return list(values)


class Emb_from_fingerprints_NNCostEstimator(NoCacheNodeEvaluator):
    """Estimates cost of a node using embedding distance to purchasable molecules.

    Morgan bit-vector fingerprints (radius 3) are passed through ``model`` to obtain
    embeddings. All purchasable molecules are embedded once at construction. A node's
    cost is derived from the minimum ``distance_type`` distance between its embedding
    and the purchasable embeddings, mapped via ``distance_to_cost``.
    """

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        model,
        distance_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model = model
        # Inference only: puts the model in eval mode (affects dropout/batch-norm, if any).
        self.model.eval()

        self.distance_to_cost = distance_to_cost
        self.distance_type = distance_type
        self._set_fingerprints_vect([mol.smiles for mol in inventory.purchasable_mols()])
        # Embed the inventory once so each query needs a single forward pass.
        self.emb_purch_molecules = torch.stack(
            [self.compute_embedding_from_fingerprint(fp) for fp in self._fps], dim=0
        )

    def get_fingerprint_vect(self, mol: AllChem.Mol):
        """Morgan bit-vector fingerprint of ``mol`` with radius 3."""
        return AllChem.GetMorganFingerprintAsBitVect(mol, radius=3)

    def _set_fingerprints_vect(self, smiles_list: list[str]) -> None:
        """Initialize fingerprint cache (one bit vector per purchasable SMILES)."""
        mols = list(map(AllChem.MolFromSmiles, smiles_list))
        assert None not in mols, "Invalid SMILES encountered."
        self._fps = list(map(self.get_fingerprint_vect, mols))

    def compute_embedding_from_fingerprint(self, mol_fingerprints):
        """Run ``model`` on a fingerprint converted to a float64 tensor, without gradients."""
        with torch.no_grad():
            return self.model(torch.tensor(mol_fingerprints, dtype=torch.double))

    def embedding_distance(self, emb_1, emb_2):
        """Distance between ``emb_1`` and each row of ``emb_2`` (reduced over dim 1).

        Presumably ``emb_1`` is a single embedding and ``emb_2`` stacks one embedding
        per row; ``emb_1`` is broadcast against ``emb_2``.

        - 'Euclidean': L2 norm of the difference.
        - 'cosine': 1 - cosine similarity, clamped to [0, 1]. NOTE(review): 1 - cos lies
          in [0, 2], so the upper clamp maps every negatively-correlated pair to 1.

        Raises:
            NotImplementedError: for any other ``distance_type``.
        """
        if self.distance_type == 'Euclidean':
            return torch.norm(emb_1 - emb_2, dim=1)
        if self.distance_type == 'cosine':
            cosine_similarity = F.cosine_similarity(emb_1, emb_2, dim=1)
            # The lower clamp also removes tiny negative values caused by rounding.
            return torch.clamp(1 - cosine_similarity, min=0, max=1)
        raise NotImplementedError(f"Distance type '{self.distance_type}' is not implemented.")

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Minimum embedding distance from ``smiles`` to any purchasable molecule.

        Raises:
            ValueError: if ``smiles`` cannot be parsed.
        """
        mol = AllChem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {smiles}")
        emb_target = self.compute_embedding_from_fingerprint(self.get_fingerprint_vect(mol))
        distances = self.embedding_distance(emb_target, self.emb_purch_molecules)
        return torch.min(distances).item()

    def _evaluate_nodes(self, nodes: list[OrNode], graph=None) -> list[float]:
        """Cost of each node's molecule: its nearest-neighbour distance mapped by ``distance_to_cost``.

        ``NUM_NEIGHBORS_TO_1`` is not supported here and raises NotImplementedError.
        """
        if len(nodes) == 0:
            return []

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(node.mol.smiles) for node in nodes]
        )
        assert np.min(nn_dists) >= 0, f'Negative distance: {np.min(nn_dists)} '

        # Turn into costs
        if self.distance_to_cost == DistanceToCost.NOTHING:
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.EXP:
            values = np.exp(nn_dists) - 1
        elif self.distance_to_cost == DistanceToCost.SQRT:
            values = np.sqrt(nn_dists)
        elif self.distance_to_cost == DistanceToCost.TIMES10:
            values = 10.0 * nn_dists
        elif self.distance_to_cost == DistanceToCost.TIMES100:
            values = 100.0 * nn_dists
        else:
            raise NotImplementedError(self.distance_to_cost)

        return list(values)

In [6]:
# When True, run_algorithm also estimates how many diverse routes were found per target.
route_div: bool = True

In [7]:
## Faster implementation


from __future__ import annotations

from collections.abc import Sequence

from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.search.graph.and_or import ANDOR_NODE, AndOrGraph, OrNode


class ReduceValueFunctionCallsRetroStar(RetroStarSearch):
    """
    Retro* variant that defers value-function (reaction number estimator) calls.

    Standard retro* evaluates the value function on every new leaf. Here a new leaf
    only gets a placeholder estimate of 0. The real estimate is computed on the
    node's first visit, and the node is expanded only on a later visit.

    Rationale: retro* greedily follows the currently cheapest route, using the
    estimate as the node's cost. Assuming estimates are non-negative, a leaf that is
    not selected while its estimate is the placeholder 0 would not be selected with
    its real estimate either. Skipping the call for such leaves therefore essentially
    leaves the search behaviour unchanged while saving value-function calls.
    """

    def setup(self, graph: AndOrGraph) -> None:
        # A graph with only the root: treat the root as already visited (estimate 0),
        # so that its first real visit expands it instead of just evaluating it.
        if len(graph) == 1:
            root = graph.root_node
            root.num_visit += 1
            root.data.setdefault("reaction_number_estimate", 0.0)

        return super().setup(graph)

    def visit_node(self, node: OrNode, graph: AndOrGraph) -> Sequence[ANDOR_NODE]:
        """
        First visit: compute the real value-function estimate and add no children.
        Any later visit: expand normally.
        """
        assert node.num_visit >= 0  # visit counts must never be negative
        node.num_visit += 1
        if node.num_visit != 1:
            return super().visit_node(node, graph)

        estimates = self.reaction_number_estimator([node])
        node.data["reaction_number_estimate"] = estimates[0]
        return []

    def _set_reaction_number_estimate(
        self, or_nodes: Sequence[OrNode], graph: AndOrGraph
    ) -> None:
        # Placeholder only; existing estimates are left untouched.
        for or_node in or_nodes:
            or_node.data.setdefault("reaction_number_estimate", 0.0)

In [8]:
"""
Demo script comparing nearest neighbour cost function with constant value function on PaRoutes.
"""
from __future__ import annotations

import argparse
import logging
import sys
import numpy as np

from tqdm.auto import tqdm

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch, MolIsPurchasableCost
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.node_evaluation.base import (
    BaseNodeEvaluator,
    NoCacheNodeEvaluator,
)
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator

from paroutes import PaRoutesInventory, PaRoutesModel, get_target_smiles
# from neighbour_value_functions import TanimotoNNCostEstimator, DistanceToCost

from syntheseus.search.analysis import diversity
# from syntheseus.search.algorithms.best_first.retro_star import MolIsPurchasableCost

class SearchResult:
    """Outputs of one ``run_algorithm`` call; every ``*_dict`` is keyed by target SMILES."""

    def __init__(
        self,
        name,
        soln_time_dict,
        num_different_routes_dict,
        final_num_rxn_model_calls_dict,
        final_num_value_function_calls_dict,
        output_graph_dict,
        routes_dict,
    ):
        # Label of the configuration (value function) that produced these results.
        self.name = name
        # Time to first solution per target.
        self.soln_time_dict = soln_time_dict
        # Number of mutually distinct routes per target.
        self.num_different_routes_dict = num_different_routes_dict
        # Reaction model calls used per target.
        self.final_num_rxn_model_calls_dict = final_num_rxn_model_calls_dict
        # Final search graph per target.
        self.output_graph_dict = output_graph_dict
        # Extracted min-cost routes per target.
        self.routes_dict = routes_dict
        # Value function calls used per target.
        self.final_num_value_function_calls_dict = final_num_value_function_calls_dict


class PaRoutesRxnCost(NoCacheNodeEvaluator[AndNode]):
    """Cost of a reaction: negative log of its ``softmax`` metadata, clipped to [0.1, 10].

    Every reaction costs at least 0.1. The cost of near-zero softmax values (including
    exactly 0, whose negative log is +inf) is capped at 10.
    """

    def _evaluate_nodes(self, nodes: list[AndNode], graph=None) -> list[float]:
        softmaxes = np.asarray([node.reaction.metadata["softmax"] for node in nodes], dtype=float)
        # -log(0) is +inf; silence numpy's divide warning since the clip maps it to 10.0.
        with np.errstate(divide="ignore"):
            costs = np.clip(-np.log(softmaxes), 1e-1, 10.0)
        return costs.tolist()


def run_algorithm(
    name: str,
    smiles_list: list[str],
    value_function: BaseNodeEvaluator,
    rxn_model: BackwardReactionModel,
    inventory: BaseMolInventory,
    and_node_cost_fn: BaseNodeEvaluator[AndNode],
    or_node_cost_fn: BaseNodeEvaluator[OrNode],
    max_expansion_depth: int = 15,
    prevent_repeat_mol_in_trees: bool = True,
    use_tqdm: bool = False,
    limit_rxn_model_calls: int = 100,
    limit_iterations: int = 1_000_000,
    logger: logging.RootLogger = logging.getLogger(),
    max_routes: int = 10_000,
) -> SearchResult:
    """
    Run retro* (with deferred value-function calls) on each SMILES and collect statistics.

    Args:
        name: label for this configuration; used in log messages and stored on the result.
        smiles_list: target molecules; one search is run per entry.
        value_function: value function (reaction number estimator) for OR nodes.
        rxn_model: backward reaction model used to expand molecules.
        inventory: inventory of purchasable molecules.
        and_node_cost_fn: cost function for AND (reaction) nodes.
        or_node_cost_fn: cost function for OR (molecule) nodes.
        max_expansion_depth: passed to the algorithm to prevent overly-deep solutions.
        prevent_repeat_mol_in_trees: passed to the algorithm (the original paper did this).
        use_tqdm: show a progress bar over ``smiles_list``.
        limit_rxn_model_calls: per-target budget of reaction model calls.
        limit_iterations: per-target cap on algorithm iterations.
        logger: receives an INFO message at start and DEBUG details per target.
        max_routes: cap on the number of min-cost routes extracted per target.

    Returns:
        SearchResult whose dicts are keyed by SMILES. Solution time is measured in
        reaction model calls. Route diversity is computed only when routes were found
        and the notebook-level ``route_div`` flag is true; otherwise 0 is recorded.
        A SMILES that appears several times in ``smiles_list`` keeps only its last result.
    """
    alg = ReduceValueFunctionCallsRetroStar(
        reaction_model=rxn_model,
        mol_inventory=inventory,
        limit_reaction_model_calls=limit_rxn_model_calls,
        limit_iterations=limit_iterations,
        max_expansion_depth=max_expansion_depth,  # prevent overly-deep solutions
        prevent_repeat_mol_in_trees=prevent_repeat_mol_in_trees,  # original paper did this
        value_function=value_function,
        # These were previously ignored (PaRoutesRxnCost() was hard-coded and the
        # OR-node cost was never passed), so the caller's choices had no effect.
        and_node_cost_fn=and_node_cost_fn,
        or_node_cost_fn=or_node_cost_fn,
    )

    logger.info(f"Start search with {name}")
    smiles_iter = tqdm(smiles_list) if use_tqdm else smiles_list

    output_graph_dict = {}
    soln_time_dict = {}
    routes_dict = {}
    final_num_rxn_model_calls_dict = {}
    final_num_value_function_calls_dict = {}
    num_different_routes_dict = {}

    for i, smiles in enumerate(smiles_iter):
        logger.debug(f"Start search {i}/{len(smiles_list)}. SMILES: {smiles}")
        alg.reset()
        output_graph, _ = alg.run_from_mol(Molecule(smiles))

        # Solution time is measured in number of reaction model calls.
        for node in output_graph.nodes():
            node.data["analysis_time"] = node.data["num_calls_rxn_model"]
        soln_time = get_first_solution_time(output_graph)

        # Extract (up to max_routes) minimum-cost routes.
        routes = list(min_cost_routes(output_graph, max_routes))

        num_rxn_model_calls = alg.reaction_model.num_calls()
        num_value_function_calls = alg.value_function.num_calls
        if num_rxn_model_calls < limit_rxn_model_calls:
            note = " (NOTE: this was less than the maximum budget)"
        else:
            note = ""
        logger.debug(
            f"Done {name}: nodes={len(output_graph)}, solution time = {soln_time}, "
            f"num routes = {len(routes)} (capped at {max_routes}), "
            f"final num rxn model calls = {num_rxn_model_calls}{note}, "
            f"final num value model calls = {num_value_function_calls}."
        )

        # Route diversity: size of a packing set under reaction Jaccard distance.
        if routes and route_div:
            route_objects = [output_graph.to_synthesis_graph(nodes) for nodes in routes]
            packing_set = diversity.estimate_packing_number(
                routes=route_objects,
                distance_metric=diversity.reaction_jaccard_distance,
                radius=0.999,  # because comparison is > not >=
            )
            logger.debug(f"number of distinct routes = {len(packing_set)}")
            num_distinct_routes = len(packing_set)
        else:
            num_distinct_routes = 0

        soln_time_dict[smiles] = soln_time
        final_num_rxn_model_calls_dict[smiles] = num_rxn_model_calls
        final_num_value_function_calls_dict[smiles] = num_value_function_calls
        num_different_routes_dict[smiles] = num_distinct_routes
        output_graph_dict[smiles] = output_graph
        routes_dict[smiles] = routes

    return SearchResult(
        name=name,
        soln_time_dict=soln_time_dict,
        num_different_routes_dict=num_different_routes_dict,
        final_num_rxn_model_calls_dict=final_num_rxn_model_calls_dict,
        final_num_value_function_calls_dict=final_num_value_function_calls_dict,
        output_graph_dict=output_graph_dict,
        routes_dict=routes_dict,
    )
    



In [9]:
# Arguments

# COMMAND LINE
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "--limit_num_smiles",
#         type=int,
#         default=None,
#         help="Maximum number of SMILES to run.",
#     )
#     parser.add_argument(
#         "--limit_iterations",
#         type=int,
#         default=500,
#         help="Maximum number of algorithm iterations.",
#     )
#     parser.add_argument(
#         "--limit_rxn_model_calls",
#         type=int,
#         default=25,
#         help="Allowed number of calls to reaction model.",
#     )
#     parser.add_argument(
#         "--paroutes_n",
#         type=int,
#         default=5,
#         help="Which PaRoutes benchmark to use.",
#     )
#     args = parser.parse_args()

# NOTEBOOK
class Args:
    """Notebook stand-in for the command-line arguments sketched in the comments above."""

    # Targets and per-target search budget
    limit_num_smiles = 100
    limit_iterations = 500  # 100000
    limit_rxn_model_calls = 100  # 500
    max_expansion_depth = 20
    prevent_repeat_mol_in_trees = True

    # Benchmark components, selected by name
    paroutes_n = 5
    max_num_templates = 10  # Default 50
    rxn_model = 'PAROUTES'
    inventory = 'PAROUTES'
    and_node_cost_fn = 'PAROUTES'
    or_node_cost_fn = 'MOL_PURCHASABLE'


args = Args()


### Load embedding models

In [10]:
# Training run whose embedding model (and config.json) is loaded below.
experiment_name = 'fingerprints_v1'
emb_model_input_folder = f'GraphRuns/{experiment_name}'  # folder holding that run's artefacts

In [11]:
# 1. Read json.config
with open(f'{emb_model_input_folder}/config.json', 'r') as config_file:
    # Training configuration saved alongside the model; only the commented-out
    # checkpoint-loading path below reads it.
    config = json.load(config_file)

# 2. Load model
import torch.nn as nn
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class GNNModel(nn.Module):
    """Graph model: two GCNConv layers, each followed by ReLU, then a linear layer.

    NOTE(review): ``GCNConv`` is not imported anywhere in this notebook, so
    instantiating this class raises NameError. It presumably comes from
    ``torch_geometric.nn`` -- add that import before using the 'gnn' model type.
    In this notebook the class is referenced only by commented-out code.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        # assumes x holds node features and edge_index the graph connectivity
        # (PyG-style conventions) -- TODO confirm
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.fc(x)
        return x
    
class FingerprintModel(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) in float64 mapping fingerprints to embeddings."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # float64 weights: fingerprints are fed in as torch.double tensors.
        self.fc1 = nn.Linear(input_dim, hidden_dim, dtype=torch.double)
        self.fc2 = nn.Linear(hidden_dim, output_dim, dtype=torch.double)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

# # 2-OPTION 1 - Load from checkpoint (last for now)
# if config["model_type"] == 'gnn':
#     with open(f'{emb_model_input_folder}/input_dim.pickle', 'rb') as f:
#         input_dim_dict = pickle.load(f)
#         gnn_input_dim = input_dim_dict['input_dim']
# #     gnn_input_dim = preprocessed_targets[0].node_features.shape[1]
#     gnn_hidden_dim = config["hidden_dim"]
#     gnn_output_dim = config["output_dim"]
# elif config["model_type"] == 'fingerprints':
#     with open(f'{emb_model_input_folder}/input_dim.pickle', 'rb') as f:
#         input_dim_dict = pickle.load(f)
#         fingerprint_input_dim = input_dim_dict['input_dim']
# #     fingerprint_input_dim = (preprocessed_targets[0].size()[0])
#     fingerprint_hidden_dim = config["hidden_dim"]
#     fingerprint_output_dim = config["output_dim"]
# else:
#     raise NotImplementedError(f'Model type {config["model_type"]}')

# if config["model_type"] == 'gnn':
#     model = GNNModel(
#         input_dim=gnn_input_dim, 
#         hidden_dim=gnn_hidden_dim, 
#         output_dim=gnn_output_dim).to(device)
    
# elif config["model_type"] == 'fingerprints':
#     model = FingerprintModel(
#         input_dim=fingerprint_input_dim, 
#         hidden_dim=fingerprint_hidden_dim, 
#         output_dim=fingerprint_output_dim).to(device)
# else:
#     raise NotImplementedError(f'Model type {config["model_type"]}')
    
    
# checkpoint_path = f'{emb_model_input_folder}/epoch_100_checkpoint.pth'

# checkpoint = torch.load(checkpoint_path)

# # Load the model state dict from the checkpoint
# model.load_state_dict(checkpoint['model_state_dict'])
# model_fingerprints_v1 = model

# 2-OPTION 2 - Load from pickle (best)
# Presumably the checkpoint with the lowest validation loss -- confirm against training code.
# NOTE: pickle.load can execute arbitrary code; only load model files you produced yourself.
model_fingerprints_v1_path = f'{emb_model_input_folder}/model_min_val.pkl'
with open(model_fingerprints_v1_path, 'rb') as model_file:
    model_fingerprints_v1 = pickle.load(model_file)



## Run the searches and record time to solution and number of routes found

In [12]:
# Embedding distance for Emb_from_fingerprints_NNCostEstimator: 'Euclidean' or 'cosine'.
distance_type_fingerprints_v1 = 'cosine'

In [13]:
import json

# Logging
# Root logger at DEBUG: INFO and above also go to stdout; everything goes to logs.txt.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')

# Without this, re-running the cell stacks another pair of handlers on the root logger
# and every message is emitted several times. Remove the handlers from a previous run.
_EXPERIMENT_HANDLER_NAMES = ("experiment_stdout", "experiment_file")
for old_handler in list(logger.handlers):
    if old_handler.name in _EXPERIMENT_HANDLER_NAMES:
        logger.removeHandler(old_handler)
        old_handler.close()

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.name = "experiment_stdout"
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)

file_handler = logging.FileHandler(f'{output_folder}/logs.txt', mode='w')
file_handler.name = "experiment_file"
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stdout_handler)


# Load all SMILES to test
# test_smiles = get_target_smiles(args.paroutes_n) # Paroutes

# guacamol_file_path = '/scratch/is541/retrosynthesis_dev/Data/Guacamol/guacamol_v1_test_10ksample.txt'
guacamol_file_path = 'guacamol_v1_test_10ksample.txt'
with open(guacamol_file_path) as f:
    # One SMILES per line, no header. Blank lines are skipped so that no empty
    # SMILES is passed to the search as a target.
    test_smiles = [line.strip() for line in f if line.strip()]

if args.limit_num_smiles is not None:
    test_smiles = test_smiles[: args.limit_num_smiles]

# Make reaction model, inventory, cost functions and value functions
# Each component is selected by name from `args`; only the PaRoutes setup is supported.
if args.and_node_cost_fn != 'PAROUTES':
    raise NotImplementedError(f'and_node_cost_fn: {args.and_node_cost_fn}')
and_node_cost_fn = PaRoutesRxnCost()

if args.or_node_cost_fn != 'MOL_PURCHASABLE':
    raise NotImplementedError(f'or_node_cost_fn: {args.or_node_cost_fn}')
or_node_cost_fn = MolIsPurchasableCost()

if args.inventory != 'PAROUTES':
    raise NotImplementedError(f'inventory: {args.inventory}')
inventory = PaRoutesInventory(n=args.paroutes_n)

if args.rxn_model != 'PAROUTES':
    raise NotImplementedError(f'rxn_model: {args.rxn_model}')
rxn_model = PaRoutesModel(max_num_templates=args.max_num_templates)


# (label, distance-to-cost transform) for each Tanimoto variant under comparison.
# Currently disabled variants: "constant-0" (ConstantNodeEvaluator(0.0)),
# "Tanimoto-distance" (DistanceToCost.NOTHING), and the embedding estimator with
# NOTHING ("Embedding-from-fingerprints") / TIMES10 ("Embedding-from-fingerprints-TIMES10").
_tanimoto_variants = [
    ("Tanimoto-distance-TIMES10", DistanceToCost.TIMES10),
    ("Tanimoto-distance-TIMES100", DistanceToCost.TIMES100),
    ("Tanimoto-distance-EXP", DistanceToCost.EXP),
    ("Tanimoto-distance-SQRT", DistanceToCost.SQRT),
    ("Tanimoto-distance-NUM_NEIGHBORS_TO_1", DistanceToCost.NUM_NEIGHBORS_TO_1),
]
value_fns = [
    (label, TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=transform))
    for label, transform in _tanimoto_variants
]
value_fns.append(
    (
        "Embedding-from-fingerprints-TIMES100",
        Emb_from_fingerprints_NNCostEstimator(
            inventory=inventory,
            distance_to_cost=DistanceToCost.TIMES100,
            model=model_fingerprints_v1,
            distance_type=distance_type_fingerprints_v1,
        ),
    )
)

# Short aliases for the value-function names used in `value_fns`; each alias is unique.
labelalias = {
    'constant-0': 'constant-0',
    'Tanimoto-distance': 'Tanimoto',
    'Tanimoto-distance-TIMES10': 'Tanimoto_times10',
    'Tanimoto-distance-TIMES100': 'Tanimoto_times100',
    'Tanimoto-distance-EXP': 'Tanimoto_exp',
    'Tanimoto-distance-SQRT': 'Tanimoto_sqrt',
    "Tanimoto-distance-NUM_NEIGHBORS_TO_1": "Tanimoto_nn_to_1",
    "Embedding-from-fingerprints": "Emb_fnps",
    "Embedding-from-fingerprints-TIMES10": "Emb_fnps_times10",
    # Was "Emb_fnps_times10" (copy-paste), which collided with the TIMES10 alias.
    "Embedding-from-fingerprints-TIMES100": "Emb_fnps_times100",
}



# Run
logger.info(f"Start experiment {eventid}")
# Public, non-callable attributes of args (dir() returns names sorted), one per line.
args_string = "".join(
    f"\n{attr}: {getattr(args, attr)}"
    for attr in dir(args)
    if not callable(getattr(args, attr)) and not attr.startswith("__")
)
logger.info(f"Args: {args_string}")
logger.info(f"dim_test: {len(test_smiles)}")


import pickle

### RUN
result = {}
# Arguments shared by every run; only the value function (and its label) varies.
shared_run_kwargs = dict(
    smiles_list=test_smiles,
    rxn_model=rxn_model,
    inventory=inventory,
    and_node_cost_fn=and_node_cost_fn,
    or_node_cost_fn=or_node_cost_fn,
    max_expansion_depth=args.max_expansion_depth,
    prevent_repeat_mol_in_trees=args.prevent_repeat_mol_in_trees,
    use_tqdm=True,
    limit_rxn_model_calls=args.limit_rxn_model_calls,
    limit_iterations=args.limit_iterations,
    logger=logger,
)
for name, fn in value_fns:
    alg_result = run_algorithm(name=name, value_function=fn, **shared_run_kwargs)
    result[name] = alg_result

    # Each result is pickled as soon as it is available, one file per value function.
    with open(f'{output_folder}/result_{name}.pickle', 'wb') as handle:
        pickle.dump(alg_result, handle, protocol=pickle.HIGHEST_PROTOCOL)


# ### LOAD result dict from pickles
# result = {}
# for name, fn in value_fns:
#     pickle_name = f'{output_folder}/result_{name}.pickle'
#     print(pickle_name)
#     with open(pickle_name, 'rb') as handle:
#         result[name] = pickle.load(handle)

Metal device set to: Apple M1 Pro
Stat initialization Emb
End initialization Emb
2023-06-27 09:28:35,661 root INFO Start experiment 202306-2611-1616-0f701fe7-bbee-4d2c-83f7-bba18beb858a
2023-06-27 09:28:35,662 root INFO Args: 
and_node_cost_fn: PAROUTES
inventory: PAROUTES
limit_iterations: 500
limit_num_smiles: 100
limit_rxn_model_calls: 100
max_expansion_depth: 20
max_num_templates: 10
or_node_cost_fn: MOL_PURCHASABLE
paroutes_n: 5
prevent_repeat_mol_in_trees: True
rxn_model: PAROUTES
2023-06-27 09:28:35,662 root INFO dim_test: 100
2023-06-27 09:28:35,663 root INFO Start search with Tanimoto-distance-TIMES10


  0%|          | 0/100 [00:00<?, ?it/s]

  parent.data["retro_star_value"]


2023-06-27 09:39:39,337 root INFO Start search with Tanimoto-distance-TIMES100


  0%|          | 0/100 [00:00<?, ?it/s]

2023-06-27 09:50:20,526 root INFO Start search with Tanimoto-distance-EXP


  0%|          | 0/100 [00:00<?, ?it/s]

2023-06-27 10:05:49,158 root INFO Start search with Tanimoto-distance-SQRT


  0%|          | 0/100 [00:00<?, ?it/s]

2023-06-27 10:21:09,178 root INFO Start search with Tanimoto-distance-NUM_NEIGHBORS_TO_1


  0%|          | 0/100 [00:00<?, ?it/s]

2023-06-27 10:36:39,584 root INFO Start search with Embedding-from-fingerprints-TIMES100


  0%|          | 0/100 [00:00<?, ?it/s]

In [14]:
# import pandas as pd

# def create_result_df(result, name):
#     assert name == result[name].name, f"name: {name} is different from result[name].name: {result[name].name}"
    
#     soln_time_dict = result[name].soln_time_dict
#     num_different_routes_dict = result[name].num_different_routes_dict
#     final_num_rxn_model_calls_dict = result[name].final_num_rxn_model_calls_dict
#     output_graph_dict = result[name].output_graph_dict
#     routes_dict = result[name].routes_dict

#     # df_results = pd.DataFrame()
#     df_soln_time = pd.DataFrame({'algorithm': [], 'similes': [], 'property':[], 'value': []})
#     df_different_routes = pd.DataFrame({'algorithm': [], 'similes': [], 'property':[], 'value': []})

#     #     for name_alg, value_dict  in soln_time_dict.items():
#     for smiles, value  in soln_time_dict.items():
#         row_soln_time = {'algorithm': name, 'similes': smiles, 'property':'sol_time', 'value': value}

#         df_soln_time = pd.concat([df_soln_time, pd.DataFrame([row_soln_time])], ignore_index=True)

#     #     for name_alg, value_dict  in num_different_routes_dict.items():
#     for smiles, value  in num_different_routes_dict.items():
#         row_different_routes = {'algorithm': name, 'similes': smiles, 'property':'diff_routes', 'value': value}

#         df_different_routes = pd.concat([df_different_routes, pd.DataFrame([row_different_routes])], ignore_index=True)

#     df_results_tot = pd.concat([df_soln_time, df_different_routes], axis=0)
#     return df_results_tot



# df_results_tot = pd.DataFrame({'algorithm': [], 'similes': [], 'property':[], 'value': []})
# for name in tqdm(result.keys()):
#     df_results_alg = create_result_df(result, name)
#     df_results_tot = pd.concat([df_results_tot, df_results_alg], axis=0)
    
    
    
    

In [15]:
# df_results_tot

In [16]:
# # Save to csv
# # df_results_tot.to_csv(f'Results/Compare/compare_times_{dim_test}.csv', index=False)
# df_results_tot.to_csv(f'{output_folder}/results_all.csv', index=False)
