In [1]:
import os, glob
import base64
import graphviz
import pandas as pd
import numpy as np

import json
from dash import Dash, dcc, html
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import matplotlib.pyplot as plt
from matplotlib.colors import ColorConverter
from rdkit import Chem
from rdkit.Chem import Draw

# Display the updated DataFrame in the notebook
from IPython.display import display

# Notebook-wide settings: the seed value and the directory (relative to the
# notebook) that holds the cached input files.
random_seed, data_folder = 42, 'data'

In [2]:
# Restore the cached DataFrames from the pickle files stored in `data_folder`.
# The targets on the left are matched, in order, with the file names below.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
(
    select_properties_df,
    yields_df,
    yield_data_df,
    select_properties_data_df,
    select_properties_data_removed_highlycorr_df,
) = [
    pd.read_pickle(os.path.join(data_folder, pickle_name))
    for pickle_name in (
        'Select_properties.pkl',
        'Yields.pkl',
        'yield_data_df.pkl',
        'select_properties_data_df.pkl',
        'select_properties_data_removed_highlycorr_df.pkl',
    )
]

In [3]:
def _read_json(filename):
    """Parse one JSON file from ``data_folder`` and return the decoded object.

    Args:
        filename: Name of the JSON file inside ``data_folder``.

    Returns:
        Whatever ``json.load`` produces for the file contents.
    """
    # Pin the encoding: without it open() falls back to the platform locale,
    # so the same file could decode differently on another machine.
    with open(os.path.join(data_folder, filename), "r", encoding="utf-8") as f:
        return json.load(f)


# Load the dicts from the JSON files written alongside the pickled DataFrames.
mol_image_paths = _read_json("mol_image_paths.json")
mol_image_data = _read_json("mol_image_data.json")
mol_image_paths_captioned = _read_json("mol_image_paths_captioned.json")
mol_image_data_captioned = _read_json("mol_image_data_captioned.json")

In [None]:
# Inspect the dict loaded from mol_image_paths.json. The heatmap cell below
# indexes it with the values of yields_df["id"] to get each image file path.
mol_image_paths

In [None]:
# Yield heatmap: one row per yield column (labelled "Method"), one column per
# compound id, with a molecule image placed above each column.
yield_columns = yield_data_df.columns

# Transposed yield values, computed once and shared by the colour scale (z)
# and the per-cell labels (text).
yield_matrix = yields_df[yield_columns].T.to_numpy()

# Build the heatmap
fig = go.Figure(
    data=go.Heatmap(
        z=yield_matrix,
        x=yields_df["id"],
        y=yield_columns,
        colorscale="Reds",
        colorbar=dict(
            title="Yield",
            dtick=20,
            tickvals=[0, 20, 40, 60, 80, 100],
            ticktext=["0%", "20%", "40%", "60%", "80%", "100%"],
        ),
        text=yield_matrix,
        texttemplate="%{text:.2f}",
        hovertemplate="""
        Compound ID: %{x}<br>
        Method: %{y}<br>
        Yield: %{text:.2f}%<extra></extra>
        """,
        showscale=True,
    )
)

# Add molecule images to the x-axis
for i, compound_id in enumerate(yields_df["id"]):
    img_path = mol_image_paths[compound_id]  # Path to molecule image
    # Read through a context manager so each image file is closed right away
    # instead of leaving one open handle per compound until garbage collection.
    with open(img_path, "rb") as img_file:
        img = base64.b64encode(img_file.read()).decode()
    fig.add_layout_image(
        dict(
            source=f"data:image/png;base64,{img}",
            xref="x",
            yref="paper",
            # Column position by index; assumes the ids are drawn as
            # categories in this order - TODO confirm for numeric ids.
            x=i,
            y=1,  # Top edge of the plot area (paper coordinates)
            xanchor="center",
            yanchor="bottom",  # Image bottom rests on that edge, i.e. above the heatmap
            sizex=1,  # Adjust size
            sizey=1,  # Adjust size
            sizing="contain",  # Maintain aspect ratio
            layer="above",  # Place above heatmap
        )
    )

# Update layout: the figure is sized at 50 px per method row and per compound column.
fig.update_layout(
    xaxis_title="Compound ID",
    yaxis_title="Method",
    height=len(yield_columns) * 50,
    width=len(yields_df["id"]) * 50,
    xaxis=dict(tickangle=30),
    template="plotly",
)


fig.show()

In [None]:
# Serve the figure built above from a minimal Dash app with a single graph.
app = Dash()
app.layout = html.Div(
    children=[
        dcc.Graph(figure=fig, id='testing'),
    ]
)

# Launch the server; jupyter_mode="tab" is the notebook display mode requested.
app.run(jupyter_mode="tab")

In [None]:
from dash import Dash, dcc, html, Input, Output, no_update, callback
import plotly.graph_objects as go
import pandas as pd

# Small-molecule DrugBank dataset, read from a local copy.
# Source: https://raw.githubusercontent.com/plotly/dash-sample-apps/main/apps/dash-drug-discovery/data/small_molecule_drugbank.csv
data_path = 'data/small_molecule_drugbank.csv'
df = pd.read_csv(data_path, index_col=0, header=0)

# Marker styling: the MW column drives both the colour and the size of each point.
mw_marker = {
    "colorscale": 'viridis',
    "color": df["MW"],
    "size": df["MW"],
    "colorbar": {"title": "Molecular<br>Weight"},
    "line": {"color": "#444"},
    "reversescale": True,
    "sizeref": 45,
    "sizemode": "diameter",
    "opacity": 0.8,
}

# Scatter of the LOGP column (x) against the PKA column (y).
fig = go.Figure(
    data=[
        go.Scatter(x=df["LOGP"], y=df["PKA"], mode="markers", marker=mw_marker),
    ]
)

# Disable the native plotly.js hover labels. hoverinfo must be "none" and not
# "skip", because "skip" would also stop the hover events from firing.
fig.update_traces(hovertemplate=None, hoverinfo="none")

# Axis titles and a near-transparent white plot background.
fig.update_layout(
    plot_bgcolor='rgba(255,255,255,0.1)',
    xaxis={"title": 'Log P'},
    yaxis={"title": 'pkA'},
)

# Dash app for the hover demo: the scatter graph plus the tooltip component
# that the hover callback fills in.
app = Dash()
app.layout = html.Div(
    children=[
        dcc.Graph(figure=fig, id="graph-basic-2", clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip"),
    ]
)


@callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("graph-basic-2", "hoverData"),
)
def display_hover(hoverData):
    """Build the tooltip card for the scatter point under the cursor.

    Args:
        hoverData: Value of the graph's ``hoverData`` property; ``None`` when
            nothing is hovered.

    Returns:
        A ``(show, bbox, children)`` tuple for the ``graph-tooltip``
        component. With no hover data the tooltip is hidden and the other two
        outputs are left untouched via ``no_update``.
    """
    if hoverData is None:
        return False, no_update, no_update

    # demo only shows the first point, but other points may also be available
    pt = hoverData["points"][0]
    bbox = pt["bbox"]
    num = pt["pointNumber"]

    # pointNumber is used as a positional index into the module-level df.
    df_row = df.iloc[num]
    img_src = df_row['IMG_URL']
    name = df_row['NAME']
    form = df_row['FORM']
    desc = df_row['DESC']
    # A blank DESC cell is read by pandas as NaN (a float), on which len()
    # raises TypeError; show an empty description for such rows instead.
    desc = '' if pd.isna(desc) else str(desc)
    # Descriptions longer than 300 characters are cut to their first 100.
    # NOTE(review): the two thresholds differ - confirm this is intentional.
    if len(desc) > 300:
        desc = desc[:100] + '...'

    children = [
        html.Div([
            html.Img(src=img_src, style={"width": "100%"}),
            html.H2(f"{name}", style={"color": "darkblue", "overflow-wrap": "break-word"}),
            html.P(f"{form}"),
            html.P(f"{desc}"),
        ], style={'width': '200px', 'white-space': 'normal'})
    ]

    return True, bbox, children


# Start the hover-demo server with Dash debug mode switched on.
app.run(debug=True, jupyter_mode="tab")