In [None]:
# !export PYTHONPATH=/Users/ilariasartori/syntheseus:/Users/ilariasartori/syntheseus/tutorials/search

In [None]:
# !echo $PYTHONPATH

In [None]:
import numpy as np
import pandas as pd

class SearchResult:
    """Plain container bundling the outputs stored for one named search run.

    The constructor only stores its arguments; no validation or copying is done.

    Attributes:
        name: identifier of the run.
        soln_time_dict: mapping iterated elsewhere in this notebook as
            (smiles, solution time) pairs.
        num_different_routes_dict: mapping iterated elsewhere in this notebook
            as (smiles, number of different routes) pairs.
        final_num_rxn_model_calls_dict: presumably the number of reaction-model
            calls per target (not read in this notebook; confirm with producer).
        output_graph_dict: mapping iterated elsewhere in this notebook as
            (target smiles, search graph) pairs.
        routes_dict: presumably extracted routes per target (not read in this
            notebook; confirm with producer).
    """

    def __init__(self, name, soln_time_dict, num_different_routes_dict,
                 final_num_rxn_model_calls_dict, output_graph_dict, routes_dict):
        self.name = name
        (
            self.soln_time_dict,
            self.num_different_routes_dict,
            self.final_num_rxn_model_calls_dict,
            self.output_graph_dict,
            self.routes_dict,
        ) = (
            soln_time_dict,
            num_different_routes_dict,
            final_num_rxn_model_calls_dict,
            output_graph_dict,
            routes_dict,
        )

# Display label for each algorithm name. The insertion order matters: the plots
# below pass list(labelalias.keys()) as the category order of the x axis.
labelalias = dict([
    ('constant-0', 'constant-0'),
    ('Tanimoto-distance', 'Tanimoto'),
    ('Tanimoto-distance-TIMES10', 'Tanimoto_times10'),
    ('Tanimoto-distance-TIMES100', 'Tanimoto_times100'),
    ('Tanimoto-distance-EXP', 'Tanimoto_exp'),
    ('Tanimoto-distance-SQRT', 'Tanimoto_sqrt'),
    ('Tanimoto-distance-NUM_NEIGHBORS_TO_06', 'Tanimoto_nn_to_06'),
])




In [None]:
# alg_names = [x[0] for x in value_fns]
# alg_names

## Load from pickle

In [None]:
# # Load pickle
import pickle
import os

# Identifier of the run whose results are analysed; all outputs of this
# notebook are written into the same folder.
eventid= "202305-1709-5350-ZZ-and_202305-1412-3438"
output_folder = f"Results/{eventid}"

# Load every pickled result found in the output folder, keyed by the file name
# with the '.pickle' suffix and the 'result_' prefix stripped.
# SECURITY NOTE: pickle.load can execute arbitrary code; only point this at
# folders whose contents you produced or trust.
# NOTE(review): the pickles presumably hold SearchResult instances, in which
# case the class defined above must exist before loading -- confirm.
result = {}
# sorted(): os.listdir returns entries in arbitrary order; sorting makes the
# key order of `result` (and everything derived from it) reproducible.
for file_name in sorted(os.listdir(output_folder)):
    # Suffix match instead of substring match, so that e.g. 'x.pickle.bak' or
    # 'pickle_notes.txt' are not fed to the unpickler.
    if not file_name.endswith('.pickle'):
        continue
    name = file_name.replace('.pickle','').replace('result_','')
    with open(f'{output_folder}/{file_name}', 'rb') as handle:
        result[name] = pickle.load(handle)



In [None]:
import pandas as pd

def create_result_df(result, name):
    """Build a long-format table of per-molecule results for one algorithm.

    Args:
        result: mapping from algorithm name to an object exposing ``name``,
            ``soln_time_dict`` and ``num_different_routes_dict`` (both
            mappings from SMILES to a value).
        name: key of ``result`` to convert; must equal ``result[name].name``.

    Returns:
        DataFrame with columns ``algorithm``, ``similes``, ``property`` and
        ``value``. Rows with property ``'sol_time'`` come first, followed by
        rows with property ``'diff_routes'``; each group keeps its own
        0-based index, as the two groups are concatenated without re-indexing.

    Raises:
        AssertionError: if ``name`` differs from ``result[name].name``.
    """
    search_result = result[name]
    assert name == search_result.name, f"name: {name} is different from result[name].name: {search_result.name}"

    # NOTE: the column is spelled 'similes' (sic); kept as-is because it is
    # part of the produced table's schema.
    columns = ['algorithm', 'similes', 'property', 'value']

    def _to_frame(values_by_smiles, property_name):
        # Build all rows first and construct the frame once, instead of
        # concatenating one single-row frame per molecule.
        records = [
            {'algorithm': name, 'similes': smiles, 'property': property_name, 'value': value}
            for smiles, value in values_by_smiles.items()
        ]
        return pd.DataFrame(records, columns=columns)

    df_soln_time = _to_frame(search_result.soln_time_dict, 'sol_time')
    df_different_routes = _to_frame(search_result.num_different_routes_dict, 'diff_routes')

    # Skip empty parts so that they cannot influence the dtypes of the result.
    non_empty = [frame for frame in (df_soln_time, df_different_routes) if not frame.empty]
    if not non_empty:
        return pd.DataFrame(columns=columns)
    return pd.concat(non_empty, axis=0)



# Stack the per-algorithm tables into one long-format frame. The frames are
# collected first and concatenated once, rather than re-copying a growing
# frame on every iteration.
result_frames = [create_result_df(result, name) for name in result.keys()]
if result_frames:
    df_results_tot = pd.concat(result_frames, axis=0)
else:
    # pd.concat raises on an empty list; fall back to an empty table with the
    # expected columns when no results were loaded.
    df_results_tot = pd.DataFrame({'algorithm': [], 'similes': [], 'property':[], 'value': []})
    
    
    
    

## Load from df_results

In [None]:
# # Load csv
# import pandas as pd
# import numpy as np

# # eventid= "202305-1310-3717-7e7e984c-8c3e-4a18-ad67-5c4b29743282"
# # output_folder = f"Results/{eventid}"

# df_results_tot = pd.read_csv(f'{output_folder}/results_all.csv')

## 1. SOLUTIONS

### 1a. Solution times

In [None]:
# Keep only the rows that hold the time-to-first-solution measurements.
is_sol_time_row = df_results_tot['property'] == 'sol_time'
results_solution_times = df_results_tot.loc[is_sol_time_row]

In [None]:
# Work on a copy so that the columns added below do not touch the filtered frame.
df_result = results_solution_times.copy()

In [None]:
# 1 where the recorded solution time is infinite, 0 otherwise.
df_result["value_is_inf"] = df_result['value'].eq(np.inf) * 1


In [None]:
# Per algorithm: number of molecules whose solution time is infinite.
df_results_grouped = (
    df_result
    .groupby(["algorithm", "property"], as_index=False)
    .agg(nr_mol_not_solved=("value_is_inf", "sum"))
)
df_results_grouped


In [None]:
# Persist the per-algorithm counts in the run's output folder.
df_results_grouped.to_csv(f'{output_folder}/num_mol_not_solved.csv', index=False)

In [None]:
import plotly.express as px
# Box plot of the time to first solution, one box per algorithm.
fig = px.box(
    df_result,
    x="algorithm",
    y="value",
    labels={"value": "Time to first solution"},
    width=1000,
    height=600,
)
fig.update_layout(xaxis_title=None)
# Show the short aliases on the axis and order the boxes as in `labelalias`.
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_time_first_solution.png')
fig.show()

In [None]:
# Names of the loaded runs (keys were derived from the pickle file names).
result.keys()

### 1b. Solution diversity

In [None]:
# Keep only the rows that hold the number of different routes found.
is_diff_routes_row = df_results_tot['property'] == 'diff_routes'
results_diff_routes = df_results_tot.loc[is_diff_routes_row]

In [None]:
# Work on a copy so that the columns added below do not touch the filtered frame.
# Note: this rebinds `df_result`, which previously held the solution-time rows.
df_result = results_diff_routes.copy()

In [None]:
# 1 where no route was found (route count equals zero), 0 otherwise.
df_result["value_is_zero"] = df_result['value'].eq(0) * 1


In [None]:
# Per algorithm: number of molecules for which the route count is zero.
df_results_grouped = (
    df_result
    .groupby(["algorithm", "property"], as_index=False)
    .agg(nr_mol_not_solved=("value_is_zero", "sum"))
)
df_results_grouped


In [None]:
import plotly.express as px
# Box plot of the number of different routes, one box per algorithm.
fig = px.box(
    df_result,
    x="algorithm",
    y="value",
    labels={"value": "Number of different routes"},
    width=1000,
    height=600,
)
fig.update_layout(xaxis_title=None)
# Show the short aliases on the axis and order the boxes as in `labelalias`.
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_num_different_routes.png')
fig.show()

In [None]:
# Same box plot as above, restricted to molecules with at least one route.
df_result_nonzero = df_result.loc[df_result['value'].ne(0)]
fig = px.box(
    df_result_nonzero,
    x="algorithm",
    y="value",
    labels={"value": "Number of different routes (removing zeros)"},
    width=1000,
    height=600,
)

fig.update_layout(xaxis_title=None)
# Show the short aliases on the axis and order the boxes as in `labelalias`.
fig.update_xaxes(
    labelalias=labelalias,
    categoryorder='array',
    categoryarray=list(labelalias.keys()),
)
fig.write_image(f'{output_folder}/Boxplot_num_different_routes_no_zero.png')
fig.show()


## 2. CORRELATION: value function - actual cost

In [None]:
# algs_to_consider = list(result.keys())
# Algorithms (keys of `result`) included in the correlation analysis below.
algs_to_consider = ['Tanimoto-distance-TIMES10']

# Underscore-joined names; used as a component of the output file names below.
algs_string = '_'.join(algs_to_consider)
algs_string

### 2a. Assign costs

In [None]:
from syntheseus.search.graph.and_or import AndNode


# Which per-reaction cost to store in node.data["route_cost"].
cost_type = "cost_1_react"
# cost_type = "cost_react_from_data"
# cost_type = "cost_react_from_data_pow01"

# Cost of a single reaction (AndNode) for each supported cost type.
# The two "from_data" variants read node.data["retro_star_rxn_cost"], which is
# expected to have been set already by the search algorithm.
reaction_cost_fns = {
    # Every reaction costs 1, so a route's cost is its number of reactions.
    "cost_1_react": lambda node: 1.0,
    # Reuse the reaction cost stored on the node.
    "cost_react_from_data": lambda node: node.data["retro_star_rxn_cost"],
    # Stored reaction cost raised to the power 0.1.
    "cost_react_from_data_pow01": lambda node: np.power(node.data["retro_star_rxn_cost"], 0.1),
}

# Validate once, up front, so that an unsupported cost type is reported even
# when there is nothing to iterate over.
if cost_type not in reaction_cost_fns:
    raise NotImplementedError(f'Cost type {cost_type}')
reaction_cost_fn = reaction_cost_fns[cost_type]

for name in algs_to_consider:
    for target_smiles, graph in result[name].output_graph_dict.items():
        for node in graph._graph.nodes():
            # Only reactions (AndNodes) carry a cost; all other nodes are free.
            if isinstance(node, (AndNode,)):
                node.data["route_cost"] = reaction_cost_fn(node)
            else:
                node.data["route_cost"] = 0.0

### 2b. Create dataframe with values and costs

In [None]:
from syntheseus.search.analysis import route_extraction
from syntheseus.search import visualization
from syntheseus.search.analysis.route_extraction import _min_route_cost, _min_route_partial_cost
from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph
import networkx as nx

import heapq
import math
from collections.abc import Collection, Iterator
from syntheseus.search.graph.and_or import AndOrGraph, OrNode
from typing import Callable, Optional, TypeVar

from syntheseus.search.graph.node import BaseGraphNode

NodeType = TypeVar("NodeType", bound=BaseGraphNode)

# def get_descendants(graph, node):
#     descendants_set = set(graph.successors(node))
#     for graph.successors(node)


def custom_cost_min_route(
    graph: RetrosynthesisSearchGraph,
    start_node,
    cost_fn: Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float],
    cost_lower_bound: Callable[[Collection[NodeType], RetrosynthesisSearchGraph[NodeType]], float],
    max_routes: int,
    yield_partial_routes: bool = False,
) -> Optional[tuple[float, Collection[NodeType]]]:
    """
    Return the lowest-cost minimal tree (route) rooted at `start_node`.

    Best-first search over partial routes, ordered by a lower bound on their
    cost. Unlike a root-based search, the search is seeded with `start_node`
    rather than `graph.root_node`, and the first complete route popped from
    the queue is returned directly (this function is NOT a generator).

    Args:
        graph: graph to search. Could be a tree, but does not need to be.
        start_node: node at which the returned route is rooted.
        cost_fn: Gives the cost of a route (specified by the set of nodes).
            A cost of inf means the route will not be returned.
        cost_lower_bound: A lower bound of the cost. The lower bound means that
            if the function is evaluated on a set A, the cost of a set B >= A
            will always exceed this lower bound.
            This function will always be evaluated on partial routes.
        max_routes: kept for interface compatibility. At most one route is
            ever returned; if `max_routes` is not positive, no search is done
            and None is returned.
        yield_partial_routes: if True, routes whose leaves have children in
            the full graph are also considered complete. This could be useful
            if, for example, there are purchasable molecules which have
            children. Typically this will be undesirable though.

    Returns:
        A tuple (cost, set of route nodes) for the cheapest route found, or
        None if no route with finite cost exists (or `max_routes` <= 0).

    Raises:
        TypeError: if a node on the frontier is not an OrNode.
    """
    if max_routes <= 0:
        return None

    # Priority queue items are: cost, whether the cost is the true cost or a
    # lower bound, tie-breaking integer (since sets cannot be ordered),
    # set of nodes in partial route, list of nodes on the route's frontier.
    queue: list[tuple[float, bool, int, set[NodeType], list[NodeType]]] = [
        (-math.inf, False, 0, {start_node}, [start_node])
    ]
    tie_breaker = 1

    # Do best-first search
    while len(queue) > 0:
        # Pop route
        cost, is_true_cost, _, partial_route, route_frontier = heapq.heappop(queue)
        assert cost < math.inf, "Infinite cost routes should not be in the queue."

        # Scenario 1: it is a full route. Its cost cannot exceed the lower
        # bound of anything still in the queue, so it is the answer.
        if is_true_cost:
            assert len(route_frontier) == 0
            return (cost, partial_route)

        # Scenario 2: choose the first node in the frontier to be "expanded"
        # and re-add the resulting routes to the queue.
        assert len(route_frontier) > 0
        node_to_expand = route_frontier[0]
        remaining_frontier = route_frontier[1:]
        possible_new_routes: list[tuple[set[NodeType], list[NodeType]]] = []

        # Materialise the successors once; they are needed twice below.
        successors = list(graph.successors(node_to_expand))

        # Potentially add this node without any of its children
        if len(successors) == 0 or yield_partial_routes:
            possible_new_routes.append((partial_route, remaining_frontier))

        # Add all children routes, 1 at a time.
        # Only AND/OR graphs are handled; any other node type is rejected.
        if isinstance(node_to_expand, OrNode):
            # Add each And child together with all of its children
            for and_child in successors:
                and_child_children = list(graph.successors(and_child))
                new_partial_route = partial_route | {and_child} | set(and_child_children)
                # New frontier excludes nodes already in partial route which would
                # either already be expanded or be in the frontier already
                new_frontier = remaining_frontier + [
                    n for n in and_child_children if n not in partial_route
                ]
                possible_new_routes.append((new_partial_route, new_frontier))
        else:
            raise TypeError(f"Unknown node type {type(node_to_expand)}.")

        # Add all possible routes onto the queue
        for new_partial_route, new_frontier in possible_new_routes:
            if len(new_frontier) == 0:
                # Nothing left to expand: this is a complete route.
                new_cost = cost_fn(new_partial_route, graph)
                assert new_cost >= cost, "lower bound not satisfied"
                new_cost_is_full = True
            else:
                new_cost = cost_lower_bound(new_partial_route, graph)
                new_cost_is_full = False

            if new_cost < math.inf:
                heapq.heappush(
                    queue,
                    (new_cost, new_cost_is_full, tie_breaker, new_partial_route, new_frontier),
                )
                tie_breaker += 1

    # Queue exhausted without finding a complete, finite-cost route.
    return None


# def reachable_nodes(G, n):
#     """
#     Returns the set of nodes that can be reached starting from node n in graph G.
#     """
#     visited = set()  # Set to keep track of visited nodes
#     stack = [n]  # Stack to keep track of nodes to explore
    
#     while stack:
#         node = stack.pop()
#         visited.add(node)
#         successors = G.successors(node)
#         for s in successors:
#             if s not in visited:
#                 stack.append(s)
    
#     return visited

# Build one row per molecule node (OrNode) of every search graph, holding the
# node's attributes, everything stored in node.data, and the cost of the
# cheapest route that starts from that node.
rows = []
for name in algs_to_consider:
    output_graph_dict = result[name].output_graph_dict
    for target_smiles, graph in output_graph_dict.items():
        for node in graph._graph.nodes:
            if not isinstance(node, OrNode):  # Only molecules get a row
                continue
            row_data = {'name': name,
                        'smiles': node.mol.smiles,
                        'is_purchasable': node.mol.metadata["is_purchasable"],
                        'node_is_expanded': node.is_expanded,
                        'node_depth': node.depth
                       }
            # Note: keys in node.data override the fields above on collision.
            row_data.update(node.data)

            # Minimal route cost, searching forward from this node rather
            # than from the root of the graph.
            min_route_result = custom_cost_min_route(
                graph=graph,
                start_node=node,
                cost_fn=_min_route_cost,
                cost_lower_bound=_min_route_partial_cost,
                max_routes=1,
                yield_partial_routes= False,
            )
            if min_route_result is not None:
                # Result is (cost, route nodes); only the cost is needed here.
                min_cost = min_route_result[0]
            else:
                # No finite-cost route starts from this node.
                min_cost = np.inf
            row_data['minimal_cost_forward'] = min_cost
            # append() instead of `rows = rows + [...]`, which copies the
            # whole list on every iteration.
            rows.append(row_data)
df_nodes = pd.DataFrame(rows)

df_nodes

In [None]:
# is_solved is 1.0 where the first solution time is finite-or-missing-from-inf
# comparison, i.e. wherever it differs from +inf, and 0.0 otherwise.
df_nodes['is_solved'] = df_nodes['first_solution_time'].ne(np.inf) * 1.0

# Multiply the remaining flag columns by 1.0 to store them as numbers.
for flag_column in ('is_purchasable', 'node_is_expanded'):
    df_nodes[flag_column] = df_nodes[flag_column] * 1.0



In [None]:
# NOTE(review): these three statements repeat the preceding cell verbatim.
# Re-running them does not change the values, so this cell looks redundant --
# confirm and remove.
df_nodes['is_solved'] = (df_nodes['first_solution_time'] != np.inf)*1.0
df_nodes['is_purchasable'] = (df_nodes['is_purchasable'])*1.0
df_nodes['node_is_expanded'] = (df_nodes['node_is_expanded'])*1.0



##### Save result

In [None]:
# Persist the node table in the run's output folder, tagged with the algorithm names.
df_nodes.to_csv(f'{output_folder}/{algs_string}_df_nodes.csv', index=False)



In [None]:
# select_algorithm = 'Tanimoto-distance-TIMES10'

# Restrict the node table to the algorithms under consideration.
# .copy(): columns are added to this frame below; without an explicit copy
# pandas treats it as a slice of df_nodes and emits SettingWithCopyWarning.
df_nodes_red = df_nodes.loc[df_nodes['name'].isin(algs_to_consider)].copy()

# Split by solved status. These two frames are taken before the *_inf columns
# are added below, so they do not contain them.
solved_mask = df_nodes_red['is_solved']==1.0
df_nodes_red_solved = df_nodes_red.loc[solved_mask]
df_nodes_red_not_solved = df_nodes_red.loc[~solved_mask]

# Columns compared in the plots below.
x_axis_var = 'minimal_cost_forward'
# y_axis_var = 'retro_star_value'
# y_axis_var = 'reaction_number'
y_axis_var = 'reaction_number_estimate'

# 1.0 where the respective variable is infinite, 0.0 otherwise.
df_nodes_red['x_var_inf'] = (df_nodes_red[x_axis_var] == np.inf)*1.0
df_nodes_red['y_var_inf'] = (df_nodes_red[y_axis_var] == np.inf)*1.0


#### Molecule solved - Molecule cost

In [None]:
# Number of nodes for each combination of solved status and infinite x variable.
df_nodes_red.groupby(['is_solved', 'x_var_inf']).agg(count=("smiles", "count"))


In [None]:
# Distribution of `y_axis_var`, grouped by purchasability and coloured by
# whether the node was expanded.
# NOTE(review): the output file name is built from `x_axis_var` and ends in
# "solved_not_solved", and the section heading refers to the molecule cost,
# yet the figure plots `y_axis_var` against 'is_purchasable'. Confirm which
# variable / grouping was intended.
fig = px.box(
    df_nodes_red,
    x='is_purchasable',
    y=y_axis_var,
    color='node_is_expanded',
    labels={},
    width=1000,
    height=600,
)

fig.write_image(f'{output_folder}/{x_axis_var}_{algs_string}_solved_not_solved.png')
fig.show()



#### Molecule solved - Molecule value

In [None]:
# Number of nodes for each combination of solved status and infinite y variable.
df_nodes_red.groupby(['is_solved', 'y_var_inf']).agg(count=("smiles", "count"))



In [None]:
# import plotly.express as px


# fig = px.box(df_nodes_red_solved, y=x_axis_var, 
#              width=1000, height=600,
#              labels={
# #                      "value": "Number of different routes (removing zeros)",
#                  },
#             )

# # fig.update_layout(xaxis_title=None)
# # fig.update_xaxes(labelalias=labelalias, categoryorder='array', categoryarray=list(labelalias.keys()))
# # fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}.png') 
# fig.show()

In [None]:
# Distribution of `y_axis_var` for solved vs. unsolved nodes, coloured by
# purchasability.
fig = px.box(
    df_nodes_red,
    x='is_solved',
    y=y_axis_var,
    color='is_purchasable',
    labels={},
    width=1000,
    height=600,
)

fig.write_image(f'{output_folder}/{y_axis_var}_{algs_string}_solved_not_solved.png')
fig.show()



In [None]:
# fig = px.box(df_nodes_red_not_solved, y=y_axis_var, 
#              width=1000, height=600,
#              labels={
# #                      "value": "Number of different routes (removing zeros)",
#                  },
#             )

# # fig.update_layout(xaxis_title=None)
# # fig.update_xaxes(labelalias=labelalias, categoryorder='array', categoryarray=list(labelalias.keys()))
# # fig.write_image(f'{output_folder}/Correlation_{x_axis_var}_{y_axis_var}.png') 
# fig.show() 






#### Correlation molecule cost and value

In [None]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.express as px

# # Sample DataFrame for demonstration
# df_nodes_red = pd.DataFrame({'minimal_cost_forward': [1, 2, 3],
#                              'reaction_number': [4, 5, 6],
#                              'reaction_number_estimate': [7, 8, 9],
#                              'is_solved': ['X', 'Y', 'X'],
#                              'is_purchasable': ['M', 'N', 'M'],
#                              'retro_star_value': [10, 11, 12]})

app = dash.Dash(__name__)

# Choices offered by each control, built up front to keep the layout readable.
x_axis_choices = [{'label': col, 'value': col} for col in df_nodes_red.columns]
y_axis_choices = [
    {'label': 'Reaction Number', 'value': 'reaction_number'},
    {'label': 'Reaction Number Estimate', 'value': 'reaction_number_estimate'},
    {'label': 'Retro Star Value', 'value': 'retro_star_value'},
]
# One radio button per distinct value present in the respective column.
is_solved_choices = [
    {'label': val, 'value': val} for val in df_nodes_red['is_solved'].unique()
]
is_purchasable_choices = [
    {'label': val, 'value': val} for val in df_nodes_red['is_purchasable'].unique()
]
color_choices = [
    {'label': 'Same color for all points', 'value': 'same_color'},
    {'label': 'Color by "is_solved" column', 'value': 'is_solved'},
    {'label': 'Color by "is_purchasable" column', 'value': 'is_purchasable'},
    {'label': 'Color by "node_depth" column', 'value': 'node_depth'},
]

# Page structure: axis selectors, two row filters, colour selector, scatter plot.
app.layout = html.Div([
    dcc.Dropdown(
        id='x-axis-dropdown',
        options=x_axis_choices,
        value='minimal_cost_forward',
        placeholder='Select X-axis variable',
    ),
    dcc.Dropdown(
        id='y-axis-dropdown',
        options=y_axis_choices,
        value='reaction_number',
        placeholder='Select Y-axis variable',
    ),
    html.Label('Filter by value in column "is_solved":'),
    dcc.RadioItems(
        id='filter-radio-is-solved',
        options=is_solved_choices,
        value=None,
        labelStyle={'display': 'block'},
    ),
    html.Label('Filter by value in column "is_purchasable":'),
    dcc.RadioItems(
        id='filter-radio-is-purchasable',
        options=is_purchasable_choices,
        value=None,
        labelStyle={'display': 'block'},
    ),
    dcc.RadioItems(
        id='color-column-radio',
        options=color_choices,
        value='same_color',
        labelStyle={'display': 'block'},
    ),
    dcc.Graph(id='scatter-plot'),
])

@app.callback(
    Output('scatter-plot', 'figure'),
    Input('x-axis-dropdown', 'value'),
    Input('y-axis-dropdown', 'value'),
    Input('filter-radio-is-solved', 'value'),
    Input('filter-radio-is-purchasable', 'value'),
    Input('color-column-radio', 'value')
)
def update_scatter_plot(x_axis_var, y_axis_var, filter_value_solved, filter_value_purchasable, color_option):
    """Dash callback: rebuild the scatter plot from the current control values.

    Args:
        x_axis_var: column of `df_nodes_red` shown on the x axis.
        y_axis_var: column of `df_nodes_red` shown on the y axis.
        filter_value_solved: keep only rows whose 'is_solved' equals this
            value; None (nothing selected) disables the filter.
        filter_value_purchasable: keep only rows whose 'is_purchasable' equals
            this value; None (nothing selected) disables the filter.
        color_option: 'same_color' for a single colour, otherwise the name of
            the column used to colour the points.

    Returns:
        The plotly figure to display in the 'scatter-plot' graph.
    """
    filtered_df = df_nodes_red.copy()

    # Compare against None rather than relying on truthiness: the flag columns
    # hold 0.0 / 1.0, and selecting 0.0 must still apply the filter.
    if filter_value_solved is not None:
        filtered_df = filtered_df[filtered_df['is_solved'] == filter_value_solved]

    if filter_value_purchasable is not None:
        filtered_df = filtered_df[filtered_df['is_purchasable'] == filter_value_purchasable]

    # Show every column in the hover tooltip.
    hover_data = {x: True for x in df_nodes_red.columns}

    if color_option is None or color_option == 'same_color':
        color = None
    elif color_option in ['is_solved', 'is_purchasable']:
        # Cast the flag to str so that it is treated as a category (discrete
        # colours) instead of a number.
        filtered_df[color_option] = filtered_df[color_option].astype(str)
        color = color_option
    else:
        # Any other offered column (e.g. 'node_depth') is passed through
        # unchanged. Previously this case left `color` unassigned and the
        # callback failed with UnboundLocalError.
        color = color_option

    fig = px.scatter(filtered_df, x=x_axis_var, y=y_axis_var, hover_data=hover_data,
                     color=color,
                     width=1000, height=600)

    return fig

# Start the Dash server on port 8049 when this code runs as the main module
# (which is the case inside a notebook kernel). The reloader is disabled;
# debug mode is on.
if __name__ == '__main__':
    app.run_server(port=8049, use_reloader=False, debug=True)