# Testing Observable Plot

In [1]:
import polars as pl
from pyobsplot import Plot, js, Obsplot

In [2]:
# Load the experiment annotations table.
df = pl.read_parquet('../../dgym-data/experiment_annotations.parquet')

# Keep only rows whose status is "Tested", reduce them to the
# (Inspiration, SMILES) column pair, replace nulls in those columns with the
# sentinel 'start', and convert the result to a NumPy array.
is_tested = pl.col("Current Status") == "Tested"
tested_adjacency = (
    df.filter(is_tested)[['Inspiration', 'SMILES']]
    .fill_null('start')
    .to_numpy()
)

In [3]:
import networkx as nx

def create_graph(edge_list):
    """
    Build a directed graph whose edges are the pairs in ``edge_list``.

    Parameters:
    - edge_list (iterable of pairs): Edges given as (source, target); every
      edge points from 'source' to the 'target' it inspired.

    Returns:
    - networkx.DiGraph: A new graph containing exactly the supplied edges
      (nodes are created implicitly from the edge endpoints).
    """
    digraph = nx.DiGraph()
    digraph.add_edges_from(edge_list)
    return digraph

def find_paths(graph, delimiter='^'):
    """
    Finds all simple paths from root nodes to leaf nodes in the graph.

    A root is a node with in-degree 0 and a leaf is a node with out-degree 0.
    If the graph has no roots or no leaves, the result is an empty list.

    Parameters:
    - graph (networkx.DiGraph): The graph to traverse.
    - delimiter (str): Separator placed between node labels in each path
      string. Defaults to '^', which keeps existing callers unchanged.

    Returns:
    - paths (list of strings): One string per root-to-leaf simple path, with
      the node labels joined by 'delimiter'. Paths are grouped by root, then
      by leaf, in the graph's node iteration order.
    """
    roots = [node for node, degree in graph.in_degree() if degree == 0]
    leaves = [node for node, degree in graph.out_degree() if degree == 0]
    paths = []

    for root in roots:
        for leaf in leaves:
            for path in nx.all_simple_paths(graph, source=root, target=leaf):
                # str() lets non-string node labels be joined as well.
                paths.append(delimiter.join(map(str, path)))
    return paths

# Build the graph from the tested (Inspiration, SMILES) pairs, enumerate every
# root-to-leaf path, and keep only the path strings that contain 'start'.
# Note this is a substring test on the joined path string.
G = create_graph(tested_adjacency.tolist())
all_paths = [candidate for candidate in find_paths(G) if 'start' in candidate]

In [4]:
# Options for the tree mark; '^' matches the separator used in the path strings.
tree_options = {
    'textStroke': "white",
    'delimiter': '^',
    # 'curve': 'step-before',
    'treeSort': "node:height",
}

# Draw every path except the last 20 as a tree, with axes hidden.
Plot.plot({
    'axis': None,
    'margin': 10,
    'marginLeft': 200,
    'marginRight': 200,
    'width': 1800,
    'height': 2400,
    'marks': [Plot.tree(all_paths[:-20], tree_options)],
})

ObsplotWidget(spec={'data': [], 'code': {'axis': None, 'margin': 10, 'marginLeft': 200, 'marginRight': 200, 'w…

In [313]:
# Look up the utility of the leaf molecule of every path. The list is kept
# index-aligned with all_paths: a failed lookup contributes NaN instead of
# being skipped, so the two sequences always have the same length.
utilities = []
for path in all_paths:
    # The last '^'-separated component of a path is the leaf's SMILES string.
    smiles = path.split('^')[-1]
    try:
        utility = df.filter(pl.col("SMILES") == smiles)['utility'].item()
    except ValueError:
        # .item() needs exactly one matching row; report the SMILES that
        # matched zero or several rows and record a missing value for it.
        print(smiles)
        utility = float('nan')
    utilities.append(utility)

In [334]:
import seaborn as sns

# The 'magma' colour scheme, once as a seaborn colour palette and once as a
# colormap object.
palette = sns.color_palette('magma')
colormap = sns.palettes.get_colormap('magma')

In [349]:
import numpy as np

def normalize_array(arr):
    """
    Min-max scale an array so its smallest value maps to 0 and its largest to 1.

    NaN entries are ignored when computing the minimum and maximum and remain
    NaN in the output. Non-float input (ints, bools, object arrays holding
    ``None``) is converted to float first.

    Parameters
    ----------
    arr : array_like
        The input values to normalize.

    Returns
    -------
    np.ndarray
        A new float array of the same shape as the input. An empty or all-NaN
        input is returned unchanged (as a copy). If every non-NaN value is
        identical, those entries are returned as 0 rather than dividing by
        zero.

    Example
    -------
    >>> arr = np.array([10, 20, 30, 40, 50])
    >>> normalize_array(arr)
    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
    """
    values = np.array(arr)
    if values.dtype.kind not in 'fc':
        # Integer/bool/object input: work in float so NaN can be represented
        # and None entries become NaN.
        values = values.astype(float)

    # Nothing to scale: nanmin/nanmax would raise on empty input and warn on
    # all-NaN input.
    if values.size == 0 or np.isnan(values).all():
        return values

    min_val = np.nanmin(values)
    span = np.nanmax(values) - min_val
    if span == 0:
        # Constant input: avoid 0/0; finite entries become 0, NaN stays NaN.
        return values - min_val
    return (values - min_val) / span

# Rescale the collected leaf utilities with normalize_array.
normalized_utility = normalize_array(utilities)

In [392]:
# Two-column table: 'Path' holds the path strings and 'Utility' the normalized
# utilities (one row per path, so both inputs are expected to be equally long).
path_utility = pl.DataFrame({'Path': all_paths, 'Utility': normalized_utility})

In [424]:
# JavaScript layout passed as 'treeLayout': it walks the hierarchy with
# root.eachBefore and sets each node's x to its visit index and y to its depth.
# Defined before use so this cell runs in top-to-bottom order instead of
# depending on a later cell having been executed first.
indent = js("""
() => {
  return (root) => {
    root.eachBefore((node, i) => {
      node.x = i;
      node.y = node.depth;
    });
  };
}
""")

# Tree mark over the 'Path' column, using the custom layout above.
plot_dict = Plot.tree(path_utility, {
    'path': 'Path',
    'strokeWidth': 1,
    'curve': "step-before",
    'treeLayout': indent,
    'treeSort': "node:height",
    'delimiter': '^',
})

In [510]:
# op = Obsplot(renderer="jsdom")

# JavaScript layout passed as 'treeLayout': it walks the hierarchy with
# root.eachBefore and sets each node's x to its visit index and y to its depth.
indent = js("""
() => {
  return (root) => {
    root.eachBefore((node, i) => {
      node.x = i;
      node.y = node.depth;
    });
  };
}
""")

# Tree mark built from the 'Path' column of path_utility.
tree_mark = Plot.tree(path_utility, {
    'path': 'Path',
    'strokeWidth': 1,
    'curve': "step-before",
    'treeLayout': indent,
    'treeSort': "node:height",
    'delimiter': '^',
    # "fill": "red"
})

# Render the tree on a tall canvas with axes hidden.
Plot.plot({
    'axis': None,
    'inset': 10,
    'insetRight': 800,
    'round': True,
    'width': 1000,
    'height': 6000,
    'marks': [
        tree_mark,
        # Plot.text(path_utility, {
        #     # 'x': 800,  # Static x-offset, adjust as needed to align with your tree layout
        #     # 'y': 'node_y',  # Assume 'node_y' is calculated as part of the tree layout or is a column in the DataFrame
        #     # 'x': 10,
        #     'y': 'Utility',
        #     'text': 'Utility',
        #     'text-anchor': 'end'
        # })
    ],
})

ObsplotWidget(spec={'data': [{'pyobsplot-type': 'DataFrame', 'value': b'ARROW1\x00\x00\xff\xff\xff\xff\xa8\x00…