In [1]:
from glob import glob
from os.path import exists, join, basename
from tqdm import tqdm
from json import load, dump
from matplotlib import pyplot as plt
from collections import Counter

from umap import UMAP

import pandas as pd
import numpy as np

import importlib.util
from pathlib import Path

# Import the in-repo copy of wizmap from its source file instead of any
# installed package of the same name.
# NOTE: the path is resolved relative to the current working directory, so the
# kernel must be started from the folder that contains this notebook.
import sys

path = Path.cwd().parent.parent / "notebook_widget" / "wizmap" / "wizmap.py"
if not path.is_file():
    raise FileNotFoundError(f"Local wizmap module not found at {path}")

spec = importlib.util.spec_from_file_location("wizmap", path)
if spec is None or spec.loader is None:
    raise ImportError(f"Could not build an import spec for {path}")

wizmap = importlib.util.module_from_spec(spec)
# Register the module before executing it (as in the importlib recipe for
# importing a source file directly) so lookups by module name, e.g. from
# pickle or inspect, resolve to this local copy.
sys.modules["wizmap"] = wizmap
spec.loader.exec_module(wizmap)

# Load Data

In [2]:
# Load the scaffold table, plus only the two columns needed from each
# property file (Structure is the join key).
df = pd.read_csv("scaffold.csv")
sol_df = pd.read_csv("solubility.csv", usecols=["Structure", "Solubility"])
tox_df = pd.read_csv("toxicity.csv", usecols=["Structure", "Toxicity"])

# Structure -> property lookups. If a structure appears more than once in a
# property file, the last occurrence silently wins.
sol_map = dict(zip(sol_df["Structure"].to_numpy(), sol_df["Solubility"].to_numpy()))
tox_map = dict(zip(tox_df["Structure"].to_numpy(), tox_df["Toxicity"].to_numpy()))

# Attach the properties to the scaffold rows by structure.
df["Solubility"] = df["Structure"].map(sol_map)
df["Toxicity"] = df["Structure"].map(tox_map)

# Count rows that received no property value (0 means every structure matched).
# This does not detect duplicated structures in the property files.
print(df[["Solubility", "Toxicity"]].isna().sum())
print(df.shape)

# Slice by column position directly on the frame instead of first converting the
# whole mixed-dtype table into one object array: the embedding block becomes a
# compact numeric matrix, and only the descriptor columns stay as Python objects.
# NOTE(review): float32 is used because umap-learn converts its input to float32
# when fitting — confirm for the installed version.
emb = df.iloc[:, 0:32].to_numpy(dtype=np.float32)
# NOTE(review): column 32 is in neither slice (0:32 stops at 31, 33: starts at
# 33) — confirm that skipping it is intentional.
desc = df.iloc[:, 33:].to_numpy(dtype=object)
print(desc.shape)

# Release the intermediates; only emb and desc are used from here on.
del df, sol_df, tox_df, sol_map, tox_map

Solubility    0
Toxicity      0
dtype: int64
(2000000, 37)
(2000000, 4)


# Dimensionality Reduction

In [3]:
# Fit UMAP on the embedding matrix with cosine distance and keep the projected
# coordinates. Reported runtime: 9-10 minutes (about 17 minutes in low-power mode).
# NOTE(review): no random_state is set, so the layout may differ between runs.
reducer = UMAP(
    metric="cosine",
)
embeddings_2d = reducer.fit_transform(emb)

In [4]:
# Free the embedding matrix; later cells only use the UMAP coordinates and desc.
del emb

# WizMap

In [5]:
# Point coordinates as plain Python float lists.
xs = embeddings_2d[:, 0].astype(float).tolist()
ys = embeddings_2d[:, 1].astype(float).tolist()

# Build one display string per row from the four descriptor columns, in the form
# "<col 0>; Scaffold: <col 1>; Solubility: <col 2>; Toxicity: <col 3>".
# Rows are processed in chunks so the temporary string arrays stay small.
out = []
n = desc.shape[0]
chunk_size = 200000

# Pre-bind the loop temporaries so the cleanup below also works for zero rows.
texts = chunk = None

for i in range(0, n, chunk_size):
    chunk = desc[i:i + chunk_size, :]

    # np.char.add concatenates element-wise on both NumPy 1.x and 2.x; the `+`
    # operator on unicode arrays only works from NumPy 2.0 onwards.
    texts = np.char.add(chunk[:, 0].astype(str), "; Scaffold: ")
    texts = np.char.add(texts, chunk[:, 1].astype(str))
    texts = np.char.add(texts, "; Solubility: ")
    texts = np.char.add(texts, chunk[:, 2].astype(str))
    texts = np.char.add(texts, "; Toxicity: ")
    texts = np.char.add(texts, chunk[:, 3].astype(str))
    out.append(texts)

# np.concatenate rejects an empty list, so fall back to an empty string array.
texts_full = np.concatenate(out) if out else np.array([], dtype=str)
del texts, chunk, out

In [6]:
# xs and ys already hold the coordinates as lists, so the array can be released.
del embeddings_2d

In [7]:
# Create the data list from the point coordinates and the combined per-point
# text strings, using the local wizmap helper.
data_list = wizmap.generate_data_list(xs, ys, texts_full)
# The text array is not referenced again after this call.
del texts_full

Start generating data list...


In [10]:
# Prompt text handed to wizmap.generate_grid_dict later in this cell. It asks for
# a JSON object with "keywords" and "summary" describing what a group of SMILES
# strings has in common.
# NOTE(review): how the local wizmap uses this prompt is not visible here.
instructions = """You are a computational chemist analyzing chemical structures.

You will be given a set of chemical structures represented as SMILES strings.

Your task is to analyze these structures to identify:
- Structural similarity and shared substructures
- Common functional groups present in the SMILES representations
- Relevant chemical properties (such as solubility and toxicity)
- The chemical rationale for grouping these functional groups together, based on their structural, electronic, and physicochemical characteristics

Rather than explaining each structure individually, focus on the patterns that are highly common across the set. 
Consider functional group chemistry, aromaticity, polarity, hydrogen bonding, steric effects, and how these features influence chemical behavior and properties.

Where relevant, include concrete examples of functional groups or substructures inferred from the SMILES strings to illustrate these patterns.

Summarize the common structural and chemical patterns in 50 words or fewer, then list 2-5 key descriptors that best characterize the group.

Provide your response strictly in the following JSON format:
{
  "keywords": string[], // array of key descriptors that best characterize the group
  "summary": string // 50 words or fewer summary of the common structural and chemical patterns
}
"""

# First descriptor column only (the unlabelled leading field of each point's text).
structure = desc[:,0]
# Build the grid dictionary with the local wizmap, passing the coordinates, the
# structure column, the prompt and the name "Chemical Structures".
# NOTE(review): the local generate_grid_dict signature differs from upstream
# wizmap (original note: "replaced by llm format") — confirm the argument order
# against the local module.
grid_dict = wizmap.generate_grid_dict(xs, ys, structure, instructions, "Chemical Structures", max_zoom_scale=10)
# The column view is not needed after the call.
del structure

Start generating contours...
Start generating multi-level summaries...


2000000it [00:13, 148492.02it/s]


Level 1/4


100%|██████████| 13250/13250 [00:00<00:00, 14971.86it/s]


12764
Level 2/4


100%|██████████| 3748/3748 [00:00<00:00, 4815.59it/s]


3462
Level 3/4


100%|██████████| 1133/1133 [00:00<00:00, 1570.01it/s]


1003
Level 4/4


100%|██████████| 376/376 [00:00<00:00, 423.82it/s]


311
max Size =  26890


In [None]:
# Save the data list and grid dictionary through the local wizmap helper, using
# the current working directory as the output directory.
wizmap.save_json_files(data_list, grid_dict, output_dir="./")