<a href="https://colab.research.google.com/github/jfjoung/Mechanistic_dataset/blob/main/Mechanistic_dataset_visualization.ipynb" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
!pip install rdkit ipywidgets  # Install RDKit and ipywidgets if not already installed

In [2]:
from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display, Markdown
import ipywidgets as widgets

# Load the Reaction_templates.py content to get class_reaction_templates
import importlib.util, sys, types

repo_url = "https://raw.githubusercontent.com/jfjoung/Mechanistic_dataset/main/templates/Reaction_templates.py"
# NOTE(review): the downloaded file is executed below via exec_module, so this
# cell runs whatever code the URL serves -- only point it at a trusted source.
try:
    import requests
    # Bound the wait so a stalled connection cannot hang the cell indefinitely.
    res = requests.get(repo_url, timeout=30)
    res.raise_for_status()
    # Save to a local file. The context manager closes (and flushes) the handle
    # before the module is loaded below; UTF-8 is Python's default source encoding.
    with open("Reaction_templates.py", "w", encoding="utf-8") as template_file:
        template_file.write(res.text)
except Exception as e:
    # Best effort: fall through and try whatever copy is already on disk.
    print("Could not fetch Reaction_templates.py from GitHub. Using local file if available. Error:", e)

# Load the module from the file
spec = importlib.util.spec_from_file_location("Reaction_templates", "Reaction_templates.py")
rt_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(rt_module)
# The module should have class_reaction_templates defined
class_reaction_templates = rt_module.class_reaction_templates

# Count how many reaction classes are available:
all_classes = []
for class_key in class_reaction_templates:
    # class_key might be a tuple of class names or a single string;
    # a tuple contributes one entry per class name it contains.
    if isinstance(class_key, tuple):
        all_classes.extend(class_key)
    else:
        all_classes.append(class_key)
all_classes_count = len(all_classes)
print(f"Loaded reaction templates for {all_classes_count} reaction classes.")


Loaded reaction templates for 252 reaction classes.


In [16]:
import importlib.util
import re
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import rdChemReactions
import ipywidgets as widgets
from IPython.display import display, HTML

# Load Reaction_templates.py
# Executes the local file "Reaction_templates.py" as a module and keeps its
# `class_reaction_templates` mapping under the name `reaction_dict`, which the
# widgets and callbacks below index by reaction class and then by mechanism.
# NOTE(review): this repeats the load performed earlier into `rt_module`; the
# file is expected to already exist in the working directory at this point.
spec = importlib.util.spec_from_file_location("Reaction_templates", "Reaction_templates.py")
reaction_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(reaction_module)
reaction_dict = reaction_module.class_reaction_templates

# Extract atomic symbol
def extract_symbol(s):
    """Return the element symbol named at the start of a SMARTS atom primitive.

    Args:
        s: Text of the leading primitive of a bracketed atom, e.g. ``"#6"``,
            ``"Cl"`` or ``"c"``.

    Returns:
        The element symbol as a string (``"C"``, ``"Cl"``, ...), or ``"*"`` when
        no element can be recognised.

    Three spellings are understood:
      * ``#<n>``: atomic number, resolved through RDKit's periodic table;
      * an uppercase letter optionally followed by one lowercase letter;
      * a lowercase aromatic symbol (``b c n o p s as se``), which is returned
        capitalised so aromatic atoms are labelled with their element instead
        of the ``"*"`` wildcard.
    """
    if s.startswith("#"):
        number = re.search(r"#(\d+)", s)
        if number is None:
            return "*"
        try:
            return Chem.GetPeriodicTable().GetElementSymbol(int(number.group(1)))
        except Exception:
            # The lookup can fail (e.g. a number that is not an element); fall
            # back to the wildcard without swallowing KeyboardInterrupt/SystemExit.
            return "*"

    match = re.match(r"[A-Z][a-z]?", s)
    if match:
        return match.group(0)

    # Two-letter aromatic symbols are listed first so "se" is not read as "s".
    aromatic = re.match(r"as|se|[bcnops]", s)
    if aromatic:
        return aromatic.group(0).capitalize()
    return "*"

# Label atoms for both molecules and reactions
def _clean_query_label(inner):
    """Condense the inside of a bracketed atom SMARTS into a short display label.

    ``inner`` is the text between ``[`` and ``]``. It is split on ``;`` and
    ``&``; the first piece supplies the element (via ``extract_symbol``) and the
    remaining pieces are scanned for an ``H<n>`` count and a ``+``/``-`` charge.
    A ``:<map number>`` suffix on any piece is ignored.
    """
    parts = re.split(r"[;&]", inner)
    symbol = extract_symbol(parts[0].split(":")[0])

    hcount = None
    charge = None
    for part in parts[1:]:
        part = part.strip().split(":")[0]
        if re.fullmatch(r"H\d*", part):
            digits = part[1:]
            # In SMARTS a bare "H" means exactly one hydrogen.
            hcount = int(digits) if digits else 1
        elif re.fullmatch(r"[+-]\d*", part):
            charge = part

    label = symbol
    if hcount == 1:
        label += "H"
    elif hcount and hcount > 1:
        label += f"H{hcount}"
    # A zero charge adds nothing worth showing.
    if charge not in (None, "+0", "-0", "0"):
        label += charge
    return label


def set_query_atom_labels(mol, mode='clean'):
    """Attach a ``_displayLabel`` property to every query atom of ``mol``.

    Args:
        mol: Molecule whose atoms are labelled in place. Atoms for which
            ``HasQuery()`` is false are left untouched.
        mode: ``'query'`` stores the raw text inside the atom's SMARTS
            brackets; any other value (default ``'clean'``) stores a condensed
            label made of element, hydrogen count and charge.

    Returns:
        None. The molecule is modified through ``atom.SetProp``.
    """
    for atom in mol.GetAtoms():
        if not atom.HasQuery():
            continue

        smarts = Chem.MolFragmentToSmarts(mol, atomsToUse=[atom.GetIdx()])
        if not (smarts.startswith("[") and smarts.endswith("]")):
            # Unbracketed atom: there are no extra primitives to summarise.
            atom.SetProp("_displayLabel", atom.GetSymbol())
            continue

        inner = smarts[1:-1]
        label = inner if mode == 'query' else _clean_query_label(inner)
        atom.SetProp("_displayLabel", label)

# Draw reaction SMARTS
def draw_reaction_smarts(smarts, mode='clean'):
    """Render a reaction SMARTS string as an image.

    Every reactant template and then every product template is labelled with
    ``set_query_atom_labels`` using ``mode`` before the reaction is drawn.

    Returns:
        The object produced by ``Draw.ReactionToImage`` on success, or the
        string ``"Error: <message>"`` if any step raises an exception.
    """
    try:
        reaction = rdChemReactions.ReactionFromSmarts(smarts, useSmiles=False)
        # Reactants are fetched and labelled before products are fetched.
        for fetch_templates in (reaction.GetReactants, reaction.GetProducts):
            for template in fetch_templates():
                set_query_atom_labels(template, mode)
        picture = Draw.ReactionToImage(reaction, subImgSize=(300, 200))
    except Exception as err:
        return f"Error: {err}"
    return picture

# Draw SMARTS molecules in grid
def display_smarts_grid(smarts_list, mode='clean', title='SMARTS'):
    """Show a list of SMARTS patterns as an HTML table of captioned drawings.

    Args:
        smarts_list: List of SMARTS strings. Anything that is not a non-empty
            list makes the function return without output.
        mode: Labelling mode forwarded to ``set_query_atom_labels``.
        title: Heading printed (followed by a colon) above the table.

    Returns:
        None. The table is emitted through ``display(HTML(...))``, four
        patterns per row. A pattern that ``Chem.MolFromSmarts`` cannot parse
        is shown as a warning cell instead of an image.
    """
    import base64
    from html import escape
    from io import BytesIO

    if not isinstance(smarts_list, list) or not smarts_list:
        return
    print(f"{title}:")

    columns = 4
    cells = []
    for smarts in smarts_list:
        # SMARTS routinely contain '&' (and may contain '<'/'>'), which must be
        # escaped to show up literally in the rendered HTML.
        caption = escape(str(smarts))
        mol = Chem.MolFromSmarts(smarts)
        if mol:
            set_query_atom_labels(mol, mode)
            buffer = BytesIO()
            Draw.MolToImage(mol, size=(200, 150)).save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode()
            content = f"<img src='data:image/png;base64,{encoded}'/>"
        else:
            content = f"<div>⚠ {caption}</div>"
        cells.append(
            f"<td style='text-align:center; padding:10px'>{caption}<br/>{content}</td>"
        )

    # Chunking the cells avoids emitting an empty trailing row when the number
    # of patterns is an exact multiple of the column count.
    rows = [
        "<tr>" + "".join(cells[start:start + columns]) + "</tr>"
        for start in range(0, len(cells), columns)
    ]
    display(HTML("<table>" + "".join(rows) + "</table>"))

# Widgets
# Dropdown reads a list made up solely of 2-item tuples as (label, value)
# pairs, and reaction_dict keys may themselves be tuples of class names.
# Explicit (label, key) pairs keep every key intact as the dropdown value, so
# reaction_dict[class_dropdown.value] always resolves.
class_dropdown = widgets.Dropdown(
    options=[(str(class_key), class_key) for class_key in reaction_dict],
    description='Class:',
)
# Mechanism and step choices are filled in by the callbacks once a class is known.
mech_dropdown = widgets.Dropdown(description='Mechanism:')
step_dropdown = widgets.Dropdown(description='Step:')
# 'clean' shows condensed atom labels, 'query' shows the raw SMARTS of each atom.
mode_dropdown = widgets.Dropdown(options=['clean', 'query'], value='clean', description='Mode:')
# All rendering done by display_step is captured in this output area.
output = widgets.Output()

# Update mechanism list
def update_mechanisms(change):
    """Repopulate the mechanism dropdown for the newly selected reaction class.

    Args:
        change: Mapping whose ``'new'`` entry is the selected key of
            ``reaction_dict`` (the shape of a widget change notification).

    The first mechanism, if any, becomes the selection and ``update_steps`` is
    then called explicitly, because assigning an unchanged value does not
    notify observers.
    """
    mech_dict = reaction_dict[change['new']]
    mech_dropdown.options = list(mech_dict.keys())
    if mech_dropdown.options:
        mech_dropdown.value = mech_dropdown.options[0]
    # update_steps finishes by calling display_step(), so rendering again here
    # would only redraw the same content a second time.
    update_steps({'new': mech_dropdown.value})


# Update step list
def update_steps(change):
    """Repopulate the step dropdown for the mechanism named in ``change['new']``.

    The mechanism is looked up under the currently selected class; the keys of
    its ``"Stages"`` mapping (none if the entry is absent) become the step
    options. When there is at least one step, the first is selected. The
    display is refreshed at the end in every case.
    """
    mechanisms = reaction_dict[class_dropdown.value]
    stages = mechanisms[change['new']].get("Stages", {})
    step_dropdown.options = list(stages)
    available_steps = step_dropdown.options
    if available_steps:
        step_dropdown.value = available_steps[0]
    display_step()


# Display selected step
def display_step(change=None):
    """Render the currently selected class / mechanism / step into ``output``.

    Args:
        change: Ignored; accepted so the function can be used directly as a
            widget observer.

    The output area is cleared, the selection is printed, the mechanism's
    ``"Reagent"`` and ``"Exclude_reagent"`` SMARTS are shown as grids, and each
    template of the selected stage is printed and drawn. A stage may be either
    a dict with a ``"Templates"`` entry (optionally ``"Description"`` and
    ``"pKa"``) or a single template string.
    """
    with output:
        output.clear_output()
        cls = class_dropdown.value
        mech = mech_dropdown.value
        step = step_dropdown.value
        mode = mode_dropdown.value

        print(f"Reaction class: {cls}")
        print(f"Mechanism: {mech}")
        print(f"Step: {step}")

        data = reaction_dict[cls].get(mech)
        if data is None:
            # No mechanism is selected (or it does not belong to this class),
            # so there is nothing further to draw.
            return

        reagent = data.get("Reagent", [])
        exclude = data.get("Exclude_reagent", [])
        # Tolerate a mechanism without "Stages", matching update_steps.
        stage_data = data.get("Stages", {}).get(step, {})

        display_smarts_grid(reagent, mode, title="Reagent")
        display_smarts_grid(exclude, mode, title="Exclude")

        if isinstance(stage_data, dict) and "Templates" in stage_data:
            templates = stage_data["Templates"]
            description = stage_data.get("Description", "")
            pkas = stage_data.get("pKa") or []
        else:
            # A bare string stage is a single template; it carries no
            # description or pKa data (and has no .get to ask for them).
            templates = [stage_data] if isinstance(stage_data, str) else []
            description = ""
            pkas = []
        if isinstance(templates, str):
            # Iterating a lone string would yield characters, not a template.
            templates = [templates]

        if description:
            print(f"Description: {description}")
        for i, tpl in enumerate(templates):
            print(f"Template {i+1}: {tpl}")
            # Optional pKa annotation, matched to the template by position.
            if i < len(pkas) and isinstance(pkas[i], dict):
                pkadict = pkas[i]
                if pkadict.get("A") is not None:
                    print(f"Required pKa of acid for this reaction: {pkadict['A']}")
                elif pkadict.get("B") is not None:
                    print(f"Required pKa of base for this reaction: {pkadict['B']}")
            img = draw_reaction_smarts(tpl, mode)
            # draw_reaction_smarts reports failures as an "Error: ..." string.
            if isinstance(img, str):
                print(img)
            else:
                display(img)



# Bind widget events
# Each dropdown refreshes the level below it whenever its 'value' changes:
# class -> mechanism list, mechanism -> step list, step/mode -> redraw.
# The observers are registered before the initial trigger below, so the
# assignments made inside the callbacks can notify them as well.
class_dropdown.observe(update_mechanisms, names='value')
mech_dropdown.observe(update_steps, names='value')
step_dropdown.observe(display_step, names='value')
mode_dropdown.observe(display_step, names='value')

# Trigger initial state
# The callbacks are invoked by hand with a minimal {'new': ...} mapping, the
# only key they read. update_mechanisms already cascades into update_steps and
# display_step; the two calls after it repeat that refresh.
update_mechanisms({'new': class_dropdown.value})
update_steps({'new': mech_dropdown.value})
display_step()

# Show UI
# Stack the four selectors above the output area that display_step writes to.
ui = widgets.VBox([
    class_dropdown,
    mech_dropdown,
    step_dropdown,
    mode_dropdown,
    output
])
display(ui)


VBox(children=(Dropdown(description='Class:', options=(('Carboxylic acid + amine condensation', 'Carboxylic ac…