In [None]:
import pandas as pd
import os

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch, MolIsPurchasableCost
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.analysis.route_extraction import min_cost_routes
from syntheseus.search.reaction_models.base import BackwardReactionModel
from syntheseus.search.mol_inventory import BaseMolInventory
from syntheseus.search.node_evaluation.base import (
    BaseNodeEvaluator,
    NoCacheNodeEvaluator,
)
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator

from paroutes import PaRoutesInventory, PaRoutesModel, get_target_smiles



In [None]:
# Run identifier: names the sub-folder of `Input/` whose CSVs are processed and
# the matching sub-folder of `Results_from_input/` where results are written.
input_folder = '20230524_1010'
output_folder = f'Results_from_input/{input_folder}'

# Create the output folder together with any missing parents. `exist_ok=True`
# replaces the separate existence check, so there is no gap between checking
# and creating, and re-running the cell is a no-op.
os.makedirs(output_folder, exist_ok=True)


## Load data

In [None]:
# Read every CSV of the run's input folder into a dict keyed by the file name
# without its `.csv` extension.
input_dir = f'Input/{input_folder}'
csv_suffix = '.csv'

data_dict = {}
# `sorted` makes the processing order independent of the file system.
for test_db in sorted(os.listdir(input_dir)):
    # Test the extension, not a substring: names that merely contain ".csv"
    # (for example "foo.csv.bak") are not tables to load.
    if not test_db.endswith(csv_suffix):
        continue
    # Strip only the trailing extension so a ".csv" elsewhere in the name survives.
    test_db_name = test_db[:-len(csv_suffix)]
    data_dict[test_db_name] = pd.read_csv(f'{input_dir}/{test_db}')


## Compute value functions

#### Auxiliary functions

In [None]:
"""Basic code for nearest-neighbour value functions."""
from __future__ import annotations

from enum import Enum

import numpy as np
from rdkit.Chem import DataStructs, AllChem

from syntheseus.search.graph.and_or import OrNode
from syntheseus.search.node_evaluation.base import NoCacheNodeEvaluator
from syntheseus.search.mol_inventory import ExplicitMolInventory

# from Users.ilariasartori.syntheseus.search.graph.and_or import OrNode


class DistanceToCost(Enum):
    """Selects how `TanimotoNNCostEstimator` converts a distance into a cost.

    For the first five members the distance is `1 - max Tanimoto similarity`
    to the purchasable molecules. For the two `NUM_NEIGHBORS_TO_1*` members it
    is `(k - 1) / N`, where `k` is the number of largest similarities needed
    for their sum to reach 1 and `N` is the number of purchasable molecules.
    """
    NOTHING = 0  # cost = distance, unchanged
    EXP = 1  # cost = exp(distance) - 1
    SQRT = 2  # cost = sqrt(distance)
    TIMES10 = 3  # cost = 10 * distance
    TIMES100 = 4  # cost = 100 * distance
    NUM_NEIGHBORS_TO_1 = 5  # cost = (k - 1) / N, unchanged
    NUM_NEIGHBORS_TO_1_TIMES1000 = 6  # cost = 1000 * (k - 1) / N


class TanimotoNNCostEstimator():
    """Estimates cost of a molecule using Tanimoto distance to purchasable molecules.

    Morgan fingerprints (radius 3) of all purchasable molecules are computed
    once at construction. Each query molecule is reduced to a non-negative
    distance, which is then mapped to a cost according to `distance_to_cost`.
    """

    def __init__(
        self,
        inventory: ExplicitMolInventory,
        distance_to_cost: DistanceToCost,
        **kwargs,
    ):
        """Precompute the fingerprints of the inventory's purchasable molecules.

        Args:
            inventory: provides `purchasable_mols()`, whose items expose `.smiles`.
            distance_to_cost: transformation applied to the computed distance.
            **kwargs: forwarded unchanged to `super().__init__`.
        """
        super().__init__(**kwargs)
        self.distance_to_cost = distance_to_cost
        self._set_fingerprints([mol.smiles for mol in inventory.purchasable_mols()])

    def get_fingerprint(self, mol: AllChem.Mol):
        """Return the radius-3 Morgan fingerprint of an RDKit molecule."""
        return AllChem.GetMorganFingerprint(mol, radius=3)

    def _set_fingerprints(self, smiles_list: list[str]) -> None:
        """Initialize fingerprint cache."""
        mols = list(map(AllChem.MolFromSmiles, smiles_list))
        assert None not in mols, "Invalid SMILES encountered."
        self._fps = list(map(self.get_fingerprint, mols))

    def find_min_num_elem_summing_to_threshold(self, array, threshold):
        """Return the smallest number of elements of `array` whose sum reaches `threshold`.

        Elements are accumulated from largest to smallest. If even the sum of
        all elements stays below `threshold`, `len(array)` is returned; this is
        the same value as when every element is needed to reach it.
        """
        # Largest values first (descending order), so the running total grows
        # as fast as possible.
        sorted_array = np.sort(array)[::-1]

        # Running total of the sorted values.
        cum_sum = np.cumsum(sorted_array)

        # First position whose running total is >= threshold
        # (`searchsorted` uses side='left' by default).
        index = np.searchsorted(cum_sum, threshold)

        # A position inside the array means the threshold was reached there.
        if index < len(array):
            return index + 1  # convert 0-based position to a count

        # The threshold is never reached: report the full length.
        return len(array)

    def _get_nearest_neighbour_dist(self, smiles: str) -> float:
        """Return the distance of `smiles` to the purchasable molecules.

        For the `NUM_NEIGHBORS_TO_1*` modes this is `(k - 1) / N`, with `k` the
        number of largest similarities needed to sum to 1 and `N` the number of
        cached fingerprints; otherwise it is `1 - max similarity`.

        Raises:
            ValueError: if `smiles` cannot be parsed into a molecule.
        """
        mol = AllChem.MolFromSmiles(smiles)
        # `MolFromSmiles` signals a parse failure by returning None. Fail here
        # with the offending SMILES instead of inside the fingerprint call.
        if mol is None:
            raise ValueError(f"Invalid SMILES encountered: {smiles!r}")

        fp_query = self.get_fingerprint(mol)
        tanimoto_sims = DataStructs.BulkTanimotoSimilarity(fp_query, self._fps)
        if self.distance_to_cost in [DistanceToCost.NUM_NEIGHBORS_TO_1, DistanceToCost.NUM_NEIGHBORS_TO_1_TIMES1000]:
            num_needed = self.find_min_num_elem_summing_to_threshold(array=tanimoto_sims, threshold=1)
            return (num_needed - 1) / len(tanimoto_sims)
        return 1 - max(tanimoto_sims)

    def evaluate_molecules(self, molecules_smiles: list[str]) -> dict:
        """Returns a dictionary of {molecule_smiles:molecule_value}.

        Raises:
            ValueError: if one of the SMILES cannot be parsed.
            NotImplementedError: if `distance_to_cost` is not a handled mode.
        """
        # An explicit length check (rather than truthiness) also works when the
        # SMILES arrive as a numpy array.
        if len(molecules_smiles) == 0:
            return {}

        # Get distances to nearest neighbours
        nn_dists = np.asarray(
            [self._get_nearest_neighbour_dist(mol_smiles) for mol_smiles in molecules_smiles]
        )
        assert np.min(nn_dists) >= 0

        # Turn into costs
        if self.distance_to_cost == DistanceToCost.NOTHING:
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.EXP:
            values = np.exp(nn_dists) - 1
        elif self.distance_to_cost == DistanceToCost.SQRT:
            values = np.sqrt(nn_dists)
        elif self.distance_to_cost == DistanceToCost.TIMES10:
            values = 10.0*nn_dists
        elif self.distance_to_cost == DistanceToCost.TIMES100:
            values = 100.0*nn_dists
        elif self.distance_to_cost == DistanceToCost.NUM_NEIGHBORS_TO_1:
            values = nn_dists
        elif self.distance_to_cost == DistanceToCost.NUM_NEIGHBORS_TO_1_TIMES1000:
            values = 1000.0*nn_dists
        else:
            raise NotImplementedError(self.distance_to_cost)

        return dict(zip(molecules_smiles, values))
    
    
class ConstantMolEvaluator():
    """Assigns one fixed value to every molecule it is asked to evaluate."""

    def __init__(
        self,
        constant_value: float,
        **kwargs,
    ):
        # Extra keyword arguments are accepted but not used.
        self.constant_value = constant_value

    def evaluate_molecules(self, molecules_smiles: list[str]) -> dict:
        """Returns a dictionary of {molecule_smiles: constant_value}."""
        return dict.fromkeys(molecules_smiles, self.constant_value)
    

#### Which value functions to consider

In [None]:
# Inventory whose purchasable molecules feed the Tanimoto estimators and the
# `is_purchasable` column further down.
# NOTE(review): `n=5` presumably selects the PaRoutes n=5 benchmark; confirm in `paroutes`.
inventory=PaRoutesInventory(n=5)

In [None]:

# (column name, DistanceToCost mode) of each Tanimoto-based estimator, in the
# order in which they are evaluated.
tanimoto_variants = [
    ("Tanimoto-distance", DistanceToCost.NOTHING),
    ("Tanimoto-distance-TIMES10", DistanceToCost.TIMES10),
    # ("Tanimoto-distance-TIMES100", DistanceToCost.TIMES100),  # disabled
    ("Tanimoto-distance-EXP", DistanceToCost.EXP),
    ("Tanimoto-distance-SQRT", DistanceToCost.SQRT),
    ("Tanimoto-distance-NUM_NEIGHBORS_TO_1", DistanceToCost.NUM_NEIGHBORS_TO_1),
    (
        "Tanimoto-distance-NUM_NEIGHBORS_TO_1_TIMES1000",
        DistanceToCost.NUM_NEIGHBORS_TO_1_TIMES1000,
    ),
]

# List of (name, evaluator) pairs: the constant baseline first, then one
# Tanimoto estimator per variant, all sharing the same inventory.
value_fns = [("constant-0", ConstantMolEvaluator(0.0))]
value_fns.extend(
    (
        variant_name,
        TanimotoNNCostEstimator(inventory=inventory, distance_to_cost=variant_mode),
    )
    for variant_name, variant_mode in tanimoto_variants
)

# Short display alias for each value-function name (the disabled TIMES100
# variant keeps its alias).
labelalias = {
    "constant-0": "constant-0",
    "Tanimoto-distance": "Tanimoto",
    "Tanimoto-distance-TIMES10": "Tanimoto_times10",
    "Tanimoto-distance-TIMES100": "Tanimoto_times100",
    "Tanimoto-distance-EXP": "Tanimoto_exp",
    "Tanimoto-distance-SQRT": "Tanimoto_sqrt",
    "Tanimoto-distance-NUM_NEIGHBORS_TO_1": "Tanimoto_nn_to_1",
    "Tanimoto-distance-NUM_NEIGHBORS_TO_1_TIMES1000": "Tanimoto_nn_to_1_times1000",
}



#### Replace infs with -1

In [None]:
# Swap every +inf entry of each loaded table for the sentinel value -1.
for name, frame in data_dict.items():
    data_dict[name] = frame.replace(np.inf, -1)

In [None]:
# data_dict_copy = data_dict.copy()

In [None]:

# Show the entries of the output folder whose name contains ".csv".
[entry for entry in os.listdir(output_folder) if '.csv' in entry]

#### Add some features

In [None]:
#### CAREFUL: this cell discards the tables currently held in `data_dict` and
#### replaces them with the result files found in `output_folder`. Run it only
#### if we need to edit results that were already written.
result_suffix = '_result.csv'

data_dict = {}
# `sorted` makes the processing order independent of the file system.
for test_db in sorted(os.listdir(output_folder)):
    # Only files carrying the result suffix are reloaded. Any other CSV would
    # keep ".csv" inside its key and later be saved as "<name>.csv_result.csv".
    if not test_db.endswith(result_suffix):
        continue
    # Strip only the trailing suffix to recover the dataset name.
    test_db_name = test_db[:-len(result_suffix)]
    data_dict[test_db_name] = pd.read_csv(f'{output_folder}/{test_db}')



In [None]:
# Binned cost (lowest_cost_route_found)
# Per dataset, build bin edges and labels for the cost column. Two leading bins
# are reserved: [-1, -0.5] for the -1 sentinel ('NotSolved') and (-0.5, 0] for a
# cost of exactly 0 ('000'); the remaining bins span 0 .. max cost.
cost_variable = 'lowest_cost_route_found'
binned_var_name = 'lowest_cost_route_(binned)'
num_bins = 20
lc_bin_ranges_dict = {}
lc_bin_labels_dict = {}
for test_db_name, test_data in data_dict.items():
    min_value = 0
    # Round the maximum UP so that a fractional maximum still falls inside the
    # last bin (truncating would leave it outside every bin), and never let it
    # drop below `min_value` (e.g. when every entry is the -1 sentinel).
    max_value = max(min_value, int(np.ceil(test_data[cost_variable].max())))
    # Integer edges between 0 and the maximum. When the maximum is smaller than
    # `num_bins` the integer cast repeats edges, which `pd.cut` rejects, so
    # duplicates are dropped (fewer than `num_bins` cost bins result then).
    cost_edges = np.unique(np.linspace(min_value, max_value, num_bins + 1, dtype=int))
    bin_range = np.append(np.array([-1, -0.5]), cost_edges)
    # Labels look like "005-010"; the two reserved bins get dedicated names.
    bin_labels = [
        f'{str(int(round(lower, 0))).zfill(3)}-{str(int(round(upper, 0))).zfill(3)}'
        for lower, upper in zip(bin_range[:-1], bin_range[1:])
    ]
    bin_labels[0] = 'NotSolved'
    bin_labels[1] = '000'
    lc_bin_ranges_dict[test_db_name] = bin_range
    lc_bin_labels_dict[test_db_name] = bin_labels


# Is purchasable
purchasable_mols_smiles = [mol.smiles for mol in inventory.purchasable_mols()]

for name in data_dict.keys():
    # Is purchasable: 1.0 when the SMILES is in the inventory, else 0.0
    data_dict[name]['is_purchasable'] = (data_dict[name]['smiles'].isin(purchasable_mols_smiles)) * 1.0

    # Binned cost (lowest_cost_route_found). The bins are looked up with THIS
    # dataset's key; using the variable left over from the loop above would
    # apply the last dataset's bins to every table.
    data_dict[name][binned_var_name] = pd.cut(
        data_dict[name][cost_variable],
        bins=lc_bin_ranges_dict[name],
        labels=lc_bin_labels_dict[name],
        include_lowest=True,
    )
    # Make the two reserved categories explicit for the sentinel and zero cost.
    data_dict[name].loc[data_dict[name][cost_variable] == -1, binned_var_name] = 'NotSolved'
    data_dict[name].loc[data_dict[name][cost_variable] == 0, binned_var_name] = '000'



In [None]:
# Display the labels left in `bin_labels` by the binning loop, i.e. those of
# the last dataset it processed.
bin_labels

In [None]:
from paroutes import PaRoutesInventory 
from tqdm.auto import tqdm

# Evaluate each value function once per distinct SMILES of a table, then map
# the resulting {smiles: value} lookup back onto the rows as a new column
# named after the value function.
for name, frame in data_dict.items():
    smiles_list = frame['smiles'].unique()
    for value_function_name, value_function in tqdm(value_fns):
        smiles_value_fn_dict = value_function.evaluate_molecules(smiles_list)
        frame[value_function_name] = frame['smiles'].map(smiles_value_fn_dict)
    
    

In [None]:
# test_data_copy = test_data.copy()

In [None]:
# Print the (rows, columns) shape of every table as a quick sanity check.
for test_data in data_dict.values():
    print(test_data.shape)

In [None]:
# Columns written to the result files, in output order: search outputs and
# derived features first, then one column per value function.
leading_columns = [
    'smiles',
    'n_iter',
    'first_soln_time',
    'lowest_cost_route_found',
    'lowest_cost_route_(binned)',
    'best_route_cost_lower_bound',
    'lowest_depth_route_found',
    'best_route_depth_lower_bound',
    'num_calls_rxn_model',
    'num_nodes_in_tree',
    'is_purchasable',
]
value_function_columns = [
    'constant-0',
    'Tanimoto-distance',
    'Tanimoto-distance-TIMES10',
    'Tanimoto-distance-EXP',
    'Tanimoto-distance-SQRT',
    'Tanimoto-distance-NUM_NEIGHBORS_TO_1',
    'Tanimoto-distance-NUM_NEIGHBORS_TO_1_TIMES1000',
]
column_order = leading_columns + value_function_columns

# Write each table, restricted to and ordered by `column_order`, to
# "<output_folder>/<dataset>_result.csv" without the index.
for test_db_name, test_data in data_dict.items():
    result_path = f'{output_folder}/{test_db_name}_result.csv'
    test_data = test_data.loc[:, column_order]
    test_data.to_csv(result_path, index=False)

In [None]:
# aa = pd.read_csv('Results_from_input/20230524_1010/paroutes_n5_result.csv')
# aa = aa.rename(columns={'cost_binned': 'lowest_cost_route_(binned)'})
# aa.to_csv('Results_from_input/20230524_1010/paroutes_n5_result.csv', index=False)