## DeepFrag2

DeepFrag2 is a deep-learning tool designed for lead optimization, the process of modifying a small-molecule ligand to improve its binding to a protein target. Given the 3D structure of a protein-ligand complex, DeepFrag2 suggests a molecular fragment that can be added to the ligand to improve binding. The model outputs the "fingerprint" of the theoretical fragment to add. This fingerprint is then used to search a database of known fragments to identify the most suitable real-world candidates.

In [None]:
# @title Download Deepfrag2 and set up the required dependencies
# @markdown This cell automates the setup process by first cloning the DeepFrag2 source
# @markdown code from its GitHub repository and installing the Miniconda package manager.
# @markdown It then uses the high-speed Mamba solver to create a dedicated and
# @markdown reproducible `deepfrag2` environment with all the specific library versions
# @markdown required to run the program.

# Clone the deepfrag2 git repo
# Skipped when a "deepfrag2" directory already exists, so re-running this cell
# does not clone again.
import os
if not os.path.exists("deepfrag2"):
  !git clone https://github.com/durrantlab/deepfrag2.git
  # NOTE(review): presumably each `!` line runs in its own subshell, so the
  # trailing `cd ..` has no effect on later lines — confirm.
  !cd deepfrag2; git checkout main; cd ..

# Install Miniconda
# Installs into /usr/local/miniconda, the prefix every later conda/mamba call
# (and the DeepFrag2 run cell's python path) assumes.
# NOTE(review): the guard checks for the downloaded installer, not for the
# installed prefix, so if the installer exists but installation failed,
# re-running this cell will not retry the install.
import os
if not os.path.exists("Miniconda3-latest-Linux-x86_64.sh"):
  !wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
  !chmod +x Miniconda3-latest-Linux-x86_64.sh
  !bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local/miniconda

# Install Mamba for faster environment solving
# This runs on every execution of the cell (there is no existence check), and
# the "complete" message below is printed whether or not the install succeeded.
print("Installing Mamba...")
!/usr/local/miniconda/bin/conda install mamba -n base -c conda-forge -y
print("Mamba installation complete.")

# Now check if conda is accessible
!/usr/local/miniconda/bin/conda --version

# Create an environment.yml file for the dependencies required to run DeepFrag2 (CPU version)
# This version is optimized for faster solving by prioritizing the conda-forge channel.
# The file is rewritten on every run. Most packages are pinned, but
# torch-geometric, regex, wget, fair-esm and mdanalysis are not, so the solved
# environment may drift over time.
with open("environment.yml", "w") as f:
  f.write("""name: deepfrag2

channels:
  - conda-forge
  - pytorch
  - pyg
  - nodefaults

dependencies:
  - python=3.9
  - pip=24.0
  - flake8=4.0.1
  - mypy=0.812
  - pytest=7.1.2
  - numpy=1.23.5
  - numba=0.56.4
  - prody=2.4.1
  - pytorch=1.11.0
  - torchvision=0.12.0
  - torchaudio=0.11.0
  - pytorch-lightning=1.7.1
  - scipy=1.11.4
  - pandas=1.4.4
  - rdkit=2024.3.1
  - k3d=2.15.2
  - py3dmol=2.4.0
  - h5py=3.7.0
  - scikit-learn=1.5.1
  - torchinfo=1.8.0
  - wandb=0.15.8
  - torch-geometric
  - filelock=3.13.1
  - regex
  - sacremoses=0.0.53
  - sentencepiece=0.2.0
  - tokenizers=0.20.0
  - torchmetrics=0.11.0
  - matplotlib=3.8.3
  - protobuf=3.20.3
  - wget
  - fair-esm
  - mdanalysis
  - pip:
      - Fancy-aggregations
""")

# Use Mamba to create the environment.
# The --force flag ensures that any existing environment with the same name is
# automatically removed, preventing interactive prompts on subsequent runs.
# NOTE(review): newer conda/mamba releases may have deprecated or changed
# `--force` for `env create`; confirm the installed Mamba still accepts it.
# The environment is named by the `name:` key above ("deepfrag2").
print("Creating environment with Mamba...")
!/usr/local/miniconda/bin/mamba env create --force -f environment.yml
print("Environment creation complete.")

In [None]:
# @title Load PDB file

# @markdown This cell provides an interactive widget to load the protein-ligand complex
# @markdown structure that will be used as the starting point for fragment growing. You
# @markdown can choose to upload a PDB file from your computer, download a structure
# @markdown directly from the Protein Data Bank via its PDB ID, or use the pre-configured
# @markdown 2HU4 tutorial example.
# @markdown <br /><br />Choose how you want to provide the PDB file of your protein/ligand
# @markdown complex:

import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
import os
import requests
import MDAnalysis as mda
from io import StringIO
import warnings

# (label shown to the user, value reported to on_method_change)
_input_choices = [
    ('Upload PDB file from computer', 'upload'),
    ('Download from Protein Data Bank by PDB ID', 'pdb_id'),
    ('Use example (2HU4)', 'example'),
]

# Radio buttons letting the user choose where the PDB structure comes from.
input_method = widgets.RadioButtons(
    options=_input_choices,
    description='Input method:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px'),
)

# Two panes refreshed by the callbacks: per-method instructions first, then the
# method-specific controls and status messages.
instructions_area, output_area = widgets.Output(), widgets.Output()

# Seconds to wait for RCSB before giving up. requests applies no timeout by
# default, so without one a stalled connection would leave the widget hanging.
_RCSB_TIMEOUT_S = 30


def _fetch_rcsb_pdb(pdb_id):
    """Request entry *pdb_id* in PDB format from RCSB and return the response.

    Network errors and timeouts propagate as ``requests`` exceptions; the HTTP
    status code is left for the caller to check.
    """
    url = f"https://files.rcsb.org/view/{pdb_id}.pdb"
    return requests.get(url, timeout=_RCSB_TIMEOUT_S)


def on_method_change(change):
    """Show the instructions and controls for the selected input method.

    Parameters
    ----------
    change : dict
        Change record from ``input_method.observe``; only ``change['new']`` is
        read. It is one of ``'upload'``, ``'pdb_id'`` or ``'example'``.

    Each control, when used, writes the chosen structure to ``system.pdb``,
    the file read by the ligand-selection cell.
    """
    method = change['new']

    # Rewrite the instructions pane for the chosen method.
    instructions_area.clear_output(wait=True)
    with instructions_area:
        if method == 'upload':
            print("Upload a PDB file of the protein/ligand complex from your computer.")
            print("   If you wish to replace a fragment (rather than just adding a fragment")
            print("   to the existing ligand), you must delete the fragment to be replaced")
            print("   before uploading the PDB.")
        elif method == 'pdb_id':
            print("Download a PDB structure directly from the Protein Data Bank.")
            print("   Enter a valid 4-character PDB ID (e.g., 2HU4, 1BVN, 3K3F).")
            print("   The structure will be downloaded and saved as system.pdb.")
        elif method == 'example':
            print("Use the provided example: 2HU4 (neuraminidase bound to oseltamivir).")
            print("   This will automatically download the structure and remove the")
            print("   oseltamivir carboxylate fragment, preparing it for fragment growing.")
        print("")

    # Completely clear and reset the output area
    output_area.clear_output(wait=True)

    with output_area:
        if method == 'upload':
            print("Click the button below to upload your PDB file:")
            upload_btn = widgets.Button(
                description="Upload PDB File",
                button_style='info',
                icon='upload'
            )

            def upload_file(b):
                output_area.clear_output(wait=True)
                with output_area:
                    print("Select your PDB file...")
                    uploaded = files.upload()

                    if uploaded:
                        original_name = next(iter(uploaded))
                        new_name = "system.pdb"
                        # os.replace overwrites an existing system.pdb on every platform.
                        os.replace(original_name, new_name)
                        print(f"File uploaded and saved as {new_name}")

            upload_btn.on_click(upload_file)
            display(upload_btn)

        elif method == 'pdb_id':
            pdb_id_input = widgets.Text(
                placeholder='Enter PDB ID (e.g., 2HU4)',
                description='PDB ID:',
                style={'description_width': 'initial'}
            )

            download_btn = widgets.Button(
                description="Download PDB",
                button_style='success',
                icon='download'
            )

            def download_pdb(b):
                pdb_id = pdb_id_input.value.strip().upper()
                if not pdb_id:
                    with output_area:
                        print("Please enter a PDB ID")
                    return

                output_area.clear_output(wait=True)
                with output_area:
                    print(f"Downloading PDB ID: {pdb_id}")

                    try:
                        response = _fetch_rcsb_pdb(pdb_id)

                        if response.status_code == 200:
                            with open("system.pdb", "w") as f:
                                f.write(response.text)
                            print(f"PDB {pdb_id} downloaded and saved as system.pdb")
                        else:
                            print(f"Failed to download PDB {pdb_id}. Status code: {response.status_code}")
                            print("Please check the PDB ID and try again.")

                    except Exception as e:
                        # Includes requests timeouts and connection errors.
                        print(f"Error downloading PDB: {str(e)}")

            download_btn.on_click(download_pdb)
            display(widgets.HBox([pdb_id_input, download_btn]))

        elif method == 'example':
            example_btn = widgets.Button(
                description="Load 2HU4 Example",
                button_style='warning',
                icon='flask'
            )

            def load_example(b):
                output_area.clear_output(wait=True)
                with output_area:
                    print("Loading 2HU4 example (neuraminidase bound to oseltamivir)...")
                    print("Removing oseltamivir carboxylate fragment...")

                    try:
                        # Get the PDB text of 2HU4
                        response = _fetch_rcsb_pdb("2HU4")

                        if response.status_code == 200:
                            pdb_txt = response.text
                        else:
                            raise RuntimeError(f"Failed to download file. Status code: {response.status_code}")

                        # Load into MDAnalysis
                        u = mda.Universe(StringIO(pdb_txt), format="PDB")

                        # Remove the oseltamivir carboxylate from this pdb (O1A, O1B, and C1 from resname G39)
                        atoms_to_remove = u.select_atoms("(resname G39) and (name O1A O1B C1)")

                        # Create a new selection that excludes them
                        atoms_to_keep = u.atoms.difference(atoms_to_remove)

                        # Write the filtered structure to system.pdb
                        with mda.Writer("system.pdb", atoms_to_keep.n_atoms) as W:
                            W.write(atoms_to_keep)

                        print("2HU4 example loaded with fragment removed, saved as system.pdb")

                    except Exception as e:
                        print(f"Error loading example: {str(e)}")

            example_btn.on_click(load_example)
            display(example_btn)

# Re-render the instructions and controls whenever the radio selection changes.
input_method.observe(on_method_change, names='value')

# Show the selector followed by its two output panes, in that order.
for _widget in (input_method, instructions_area, output_area):
    display(_widget)

# Populate both panes for the default selection right away.
on_method_change({'new': input_method.value})

In [None]:
# @title Select the ligand
# @markdown This cell analyzes the loaded PDB file to automatically identify potential
# @markdown ligand molecules, presenting them in a dropdown menu for your selection. A 3D
# @markdown viewer is displayed to help visualize your choice within the protein. Upon
# @markdown confirmation, the selected ligand and the protein are saved as separate files
# @markdown for the next step.

import ipywidgets as widgets
from IPython.display import display, clear_output
import MDAnalysis as mda
from io import StringIO
import py3Dmol
from rdkit import Chem

# Read the complex written by the previous cell. The raw text is kept because
# the 3D viewers below render it directly.
with open("system.pdb") as pdb_handle:
    system_pdb = pdb_handle.read()

# Parse the complex and pull out the protein atoms.
prot_lig = mda.Universe(StringIO(system_pdb), format="PDB")
protein = prot_lig.select_atoms("protein")

# Residue ids and names belonging to the protein selection.
protein_resids = set(protein.residues.resids)
protein_resnames = set(protein.residues.resnames)


def _is_candidate_ligand(residue):
    """Return True for non-water, non-protein residues with at least 6 atoms."""
    if residue.resname == "HOH" or residue.resname in protein_resnames:
        return False
    return residue.atoms.n_atoms >= 6


def _ligand_label(residue):
    """Build the dropdown label for *residue*; a blank segid is shown as chain A."""
    chain = residue.segid.strip() or 'A'
    return f"{residue.resname} (resid {residue.resid}, chain {chain}, {residue.atoms.n_atoms} atoms)"


ligand_residues = list(filter(_is_candidate_ligand, prot_lig.residues))

# The receptor file is written even when no ligand is found below.
protein.write("protein_only.pdb")

if not ligand_residues:
    raise ValueError("No suitable ligands (≥6 atoms) found.")

# Labels share indices with ligand_residues (update_ligand_display relies on it).
ligand_options = list(map(_ligand_label, ligand_residues))

dropdown = widgets.Dropdown(
    options=ligand_options,
    value=ligand_options[0],
    description='Ligand:',
    style={'description_width': 'initial'},
    layout={'width': '500px'}
)
out = widgets.Output()

# Module-level state filled in by update_ligand_display() and read by later cells.
ligand = ligand_resname = ligand_resid = ligand_chain = None
atom_options = atom_coords = None

def update_ligand_display(ligand_label):
    """Make *ligand_label* the active ligand, save it to disk, and build a 3D view.

    Sets the module-level ``ligand``, ``ligand_resname``, ``ligand_resid``,
    ``ligand_chain``, ``atom_options`` and ``atom_coords``. Writes
    ``selected_ligand.pdb`` and, when RDKit can read that file,
    ``selected_ligand.sdf``.

    Parameters
    ----------
    ligand_label : str
        One of the entries of ``ligand_options``.

    Returns
    -------
    py3Dmol view of the whole complex with the chosen ligand highlighted.
    """
    global ligand, ligand_resname, ligand_resid, ligand_chain, atom_options, atom_coords
    import os

    # Labels and residues share indices, so map the label back to its residue.
    selected_ligand = ligand_residues[ligand_options.index(ligand_label)]

    # Update global variables
    ligand = selected_ligand
    ligand_resname = selected_ligand.resname
    ligand_resid = selected_ligand.resid
    ligand_chain = selected_ligand.segid.strip() or "A"

    # Atom names (for the branch-atom dropdown) and their coordinates.
    # NOTE(review): atoms sharing a name collapse to a single atom_coords entry.
    atom_options = [atom.name for atom in ligand.atoms]
    atom_coords = {atom.name: atom.position for atom in ligand.atoms}

    # Save the selected ligand
    ligand.atoms.write("selected_ligand.pdb")

    # DeepFrag2 is given the ligand as SDF (--ligand=./selected_ligand.sdf in
    # the run cell), so convert the PDB with RDKit. MolFromPDBFile returns None
    # when RDKit cannot parse or sanitize the ligand.
    mol = Chem.MolFromPDBFile("selected_ligand.pdb")
    if mol is None:
        # Remove an SDF left over from an earlier selection so the run cannot
        # silently use the wrong ligand.
        if os.path.exists("selected_ligand.sdf"):
            os.remove("selected_ligand.sdf")
        print("WARNING: RDKit could not read the selected ligand, so "
              "selected_ligand.sdf was not written. DeepFrag2 cannot run on "
              "this ligand as-is; choose another ligand or fix its structure.")
    else:
        writer = Chem.SDWriter('selected_ligand.sdf')
        try:
            writer.write(mol)
        finally:
            # Always flush and close the file, even if writing fails.
            writer.close()

    # Create 3D visualization
    view = py3Dmol.view(width=600, height=500)
    view.addModel(system_pdb, "pdb")

    # Protein as ribbon
    view.setStyle({}, {'cartoon': {'color': 'spectrum'}})
    view.addStyle({}, {'stick': {'colorscheme': 'lightgrayCarbon', 'radius': 0.2}})

    # Selected ligand as highlighted sticks
    view.addStyle(
        {'resn': ligand_resname, 'resi': str(ligand_resid)},
        {'stick': {'colorscheme': 'greenCarbon', 'radius': 0.35}}
    )

    # Focus camera on the ligand
    view.zoomTo({'resn': ligand_resname, 'resi': str(ligand_resid)})

    # Print ligand information
    print(f"\nSelected ligand: {ligand_resname} (resid {ligand_resid}, chain {ligand_chain})\n")

    return view

def on_change(change):
    """Redraw the ligand viewer after the dropdown's value changes."""
    if change['type'] != 'change' or change['name'] != 'value':
        return
    with out:
        clear_output()
        update_ligand_display(change['new']).show()

# Draw the default (first) ligand immediately so the viewer is never empty.
with out:
    initial_view = update_ligand_display(dropdown.value)
    initial_view.show()

# Picking another ligand redraws the viewer through on_change.
dropdown.observe(on_change, names='value')

# Tell the user how many candidates were found, then show the controls.
print("\nFound", len(ligand_residues), "potential ligands. Please select one:")
display(dropdown, out)

In [None]:
# @title Select the ligand atom to serve as the branching point
# @markdown This cell creates a dropdown menu listing all atoms of the previously selected
# @markdown ligand, allowing you to choose the branching point for fragment generation. As
# @markdown you select an atom, the 3D visualization updates to highlight your choice with
# @markdown a yellow sphere, marking the precise location where the new fragment will be
# @markdown grown.

import ipywidgets as widgets
from IPython.display import display, clear_output
import MDAnalysis as mda
from io import StringIO
import py3Dmol

# Dropdown of the active ligand's atom names; the first atom is preselected.
dropdown = widgets.Dropdown(
    options=atom_options,
    value=atom_options[0],
    description='Atom:',
)
out = widgets.Output()

# Coordinates of the chosen branching atom; set by update_view() and read by
# the cell that runs DeepFrag2.
branch_atm_loc_xyz = None
def update_view(atom_name):
    """Record *atom_name* as the branching point and return a view marking it.

    Side effect: sets the global ``branch_atm_loc_xyz`` to the atom's
    coordinates. The returned py3Dmol view shows the complex, the ligand as
    green sticks with every atom labelled, and a yellow sphere on the chosen
    atom.
    """
    global branch_atm_loc_xyz
    picked = atom_coords[atom_name]
    x, y, z = picked
    branch_atm_loc_xyz = picked

    def as_point(vec):
        # py3Dmol position specs are plain floats keyed by axis name.
        return {'x': float(vec[0]), 'y': float(vec[1]), 'z': float(vec[2])}

    ligand_sel = {'resn': ligand_resname, 'resi': str(ligand_resid)}

    view = py3Dmol.view(width=600, height=500)
    view.addModel(system_pdb, "pdb")

    # Whole complex: spectrum-coloured cartoon plus thin grey sticks.
    view.setStyle({}, {'cartoon': {'color': 'spectrum'}})
    view.addStyle({}, {'stick': {'colorscheme': 'lightgrayCarbon', 'radius': 0.2}})

    # Ligand emphasised with thicker green sticks.
    view.addStyle(ligand_sel, {'stick': {'colorscheme': 'greenCarbon', 'radius': 0.35}})

    # Name label on every ligand atom so the user can match dropdown entries.
    for atom in ligand.atoms:
        view.addLabel(atom.name, {
            'position': as_point(atom.position),
            'backgroundColor': 'black',
            'fontColor': 'white',
            'fontSize': 14,
        })

    # Yellow marker on the chosen branching atom.
    view.addSphere({
        'center': as_point((x, y, z)),
        'radius': 1.0,
        'color': 'yellow',
        'alpha': 0.9,
    })

    view.zoomTo(ligand_sel)

    print(f"\nPosition of atom {atom_name}: {x:.2f}, {y:.2f}, {z:.2f}\n")
    return view

def on_change(change):
    """Redraw the viewer when a different branching atom is picked."""
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if is_value_change:
        with out:
            clear_output()
            update_view(change['new']).show()

# Draw the default (first) atom immediately so the viewer is populated and
# branch_atm_loc_xyz is set even if the user never touches the dropdown.
with out:
    initial_view = update_view(dropdown.value)
    initial_view.show()

# Observe only the 'value' trait, as the ligand cell does. Without `names`, the
# handler is also invoked for every other trait change (options, index, label,
# ...), which on_change would then have to filter out.
dropdown.observe(on_change, names='value')

# Display widgets
display(dropdown, out)

In [None]:
# @title Select model type and label set
# @markdown This cell provides a dropdown menu to select the desired DeepFrag2 model
# @markdown (specialized for different chemical properties like aromaticity or size). It
# @markdown also allows the user to select the corresponding set of molecular fragments to
# @markdown use. Based on your selections, the script automatically downloads the required
# @markdown pre-trained model and SMILES files or provides an interface for you to upload
# @markdown a custom fragment library.

import ipywidgets as widgets
from IPython.display import display, HTML, Markdown, clear_output
import os
import urllib.request
import time
from google.colab import files
import re

# Paths consumed by the "Run DeepFrag2" cell. model_local_path and
# smi_local_path are filled in by the dropdown/upload handlers below;
# custom_smi_path remembers the last uploaded custom SMILES file.
model_local_path = None
smi_local_path = None
custom_smi_path = None

# Model options with custom display labels and descriptions
# (label shown in the dropdown, internal value passed to construct_paths).
model_options = [
    ('All-fragment model', 'all'),
    ('Large-fragment (>=4) model', 'gte4'),
    ('Small-fragment (<=3) model', 'lte3'),
    ('Aromatic-fragment model', 'aromatic'),
    ('Aliphatic-fragment model', 'aliphatic'),
    ('Acidic-fragment model', 'acid'),
    ('Basic-fragment model', 'base')
]

# Descriptions shown under the dropdowns, keyed by model value.
# NOTE(review): construct_paths loads the aromatic/aliphatic/acid/base
# checkpoints from files prefixed "gte_4_", so these models appear to be
# subsets of the >=4-heavy-atom model; confirm and consider saying so here.
# NOTE(review): the 'base' description lists amides, which are generally not
# basic; confirm the wording against how the basic-fragment set was defined.
model_descriptions = {
    'all': 'Trained on all molecular fragments regardless of size or chemical properties. This is the most general model and works well for diverse molecular structures.',
    'gte4': 'Specialized for large molecular fragments containing 4 or more heavy atoms. Best for generating substantial structural additions to molecules.',
    'lte3': 'Optimized for small molecular fragments with 3 or fewer heavy atoms. Ideal for fine-tuning molecules with small chemical modifications.',
    'aromatic': 'Trained specifically on aromatic ring-containing fragments. Excellent for generating aromatic modifications and ring systems.',
    'aliphatic': 'Focused on non-aromatic (aliphatic) fragments including linear and branched carbon chains. Best for generating alkyl modifications.',
    'acid': 'Specialized for acidic functional groups and fragments (carboxylic acids, sulfonic acids, etc.). Useful for introducing acidic moieties.',
    'base': 'Trained on basic fragments containing nitrogen-based functional groups (amines, amides, etc.). Ideal for introducing basic moieties.'
}

# Label set options with custom display labels and descriptions
# ('custom' switches the UI to the SMILES upload button).
label_options = [
    ('Train/Val/Test-set fragments', 'all'),
    ('Test-set fragments', 'test'),
    ('Custom label set', 'custom')
]

label_descriptions = {
    'all': 'Fragments derived from the training, validation, and testing-set examples associated with this model. Fragment physicochemical properties match those the model was trained on.',
    'test': 'Fragments derived only from the testing set associated with this model. Fragment physicochemical properties match those the model was trained on.',
    'custom': 'Using your uploaded custom SMILES file. Allows targeted generation based on your specific chemical space once the file is provided.'
}

# Base URL
# The trailing slash matters: construct_paths appends file names directly.
base_url = 'https://durrantlab.pitt.edu/apps/deepfrag2/models/'

# Local directory to save files
# Also used as a plain string prefix, so it must end with a slash.
local_dir = 'deepfrag2_models/'

def _make_dropdown(options, description):
    """Return a 350px-wide dropdown over *options* with the 'all' entry preselected."""
    return widgets.Dropdown(
        options=options,
        value='all',
        description=description,
        layout=widgets.Layout(width='350px')
    )


model_dropdown = _make_dropdown(model_options, 'Model:')
label_dropdown = _make_dropdown(label_options, 'Label set:')

# Upload button for a user-supplied SMILES file. It starts hidden
# (display='none') and is shown only when the custom label set is chosen.
custom_smi_upload_button = widgets.Button(
    description="Upload Custom SMILES",
    button_style='info',
    icon='upload',
    layout=widgets.Layout(width='300px', display='none'),
    tooltip='Click to upload a custom SMILES file (.smi or .txt)'
)

# Output panes.
description_output = widgets.Output()          # selected model / label-set descriptions
custom_instructions_output = widgets.Output()  # custom SMILES format guide
output = widgets.Output()                      # download progress, file paths, general messages
upload_status_output = widgets.Output()        # messages from the custom upload button

# Create the download directory; exist_ok makes this a no-op when it already
# exists, without a separate (race-prone) existence check.
os.makedirs(local_dir, exist_ok=True)

# Markdown guide shown while the custom label set is selected.
_CUSTOM_SMILES_INSTRUCTIONS = """
Your custom SMILES file should be a text file (`.smi` or `.txt`) containing the
molecular fragments, one per line. Use an asterisk to mark the atom where each
fragment should connect to the parent molecule (i.e., the asterisk represents
the R-group that is the parent molecule). Here is an example:

<br />

```
*C1CC(CO)C(=O)C1O
*c1cnc(C)cn1
*O
```
<br />"""


def hide_custom_instructions():
    """Empty the custom-instructions pane."""
    with custom_instructions_output:
        custom_instructions_output.clear_output()


def show_custom_instructions():
    """Replace the custom-instructions pane's content with the SMILES format guide."""
    hide_custom_instructions()
    with custom_instructions_output:
        display(Markdown(_CUSTOM_SMILES_INSTRUCTIONS))

def construct_paths(model_type, label_set):
    """Return ``(ckpt_url, ckpt_local, smi_url, smi_local)`` for a model/label-set pair.

    Remote files live under ``base_url`` and are mirrored under ``local_dir``.
    Both SMILES entries are ``None`` for the ``'custom'`` label set, whose file
    is uploaded by the user rather than downloaded.
    """
    # File-name stem shared by the checkpoint and its SMILES label sets. The
    # chemistry-specific models are stored with a "gte_4_" prefix.
    if model_type in ('aromatic', 'aliphatic', 'acid', 'base'):
        stem = f"gte_4_{model_type}"
    else:
        stem = {'gte4': 'gte_4', 'lte3': 'lte_3'}.get(model_type, model_type)

    ckpt_name = f"{stem}_best.ckpt"
    ckpt_url = base_url + ckpt_name
    ckpt_local = local_dir + ckpt_name

    if label_set == 'custom':
        return ckpt_url, ckpt_local, None, None

    suffix = 'all_sets_frags.smi' if label_set == 'all' else 'test_set_frags.smi'
    smi_name = f"{stem}_{suffix}"
    return ckpt_url, ckpt_local, base_url + smi_name, local_dir + smi_name

class DownloadProgressBar:
    """Print throttled download progress; usable as a ``urlretrieve`` reporthook.

    ``urllib.request.urlretrieve`` calls the hook as
    ``hook(block_num, block_size, total_size)``, where ``total_size`` is -1
    when the server sends no Content-Length header. Progress lines are printed
    at most once per ``min_interval`` seconds, except that the final
    "Complete" line is always printed.
    """

    def __init__(self, description="Downloading", min_interval=0.5):
        self.description = description
        self.min_interval = min_interval
        self.start_time = None
        self.last_update = None

    def __call__(self, block_num, block_size, total_size):
        now = time.time()
        if self.start_time is None:
            self.start_time = now
            self.last_update = now

        downloaded = block_num * block_size
        size_known = total_size > 0
        if size_known:
            # block_num * block_size overshoots on the final, partial block.
            downloaded = min(downloaded, total_size)
        finished = size_known and downloaded >= total_size

        if not finished and now - self.last_update <= self.min_interval:
            return
        self.last_update = now

        mb_done = downloaded / 1024 / 1024
        if not size_known:
            # Unknown size (-1) or zero: report only the amount received so far.
            print(f"{self.description}: {mb_done:.1f} MB downloaded", end='\r')
            return

        elapsed = now - self.start_time
        speed = mb_done / elapsed if elapsed > 0 else 0
        mb_total = total_size / 1024 / 1024
        if finished:
            print(f"{self.description}: 100% | {mb_done:.1f}/{mb_total:.1f} MB | {speed:.1f} MB/s | Complete")
        else:
            percent = downloaded / total_size * 100
            print(f"{self.description}: {percent:.1f}% | {mb_done:.1f}/{mb_total:.1f} MB | {speed:.1f} MB/s", end='\r')

def download_if_needed(url, local_path, description):
    """Download *url* to *local_path* unless that file already exists.

    The data is first written to ``<local_path>.part`` and moved into place
    only after the transfer completes. An interrupted download (including a
    KeyboardInterrupt) therefore never leaves a truncated file that later
    calls would mistake for a finished one.

    Parameters
    ----------
    url : str
        Remote file to fetch.
    local_path : str
        Destination path; its presence means "already downloaded".
    description : str
        Human-readable name used in the error message.

    Returns
    -------
    bool
        True if the file is present afterwards, False if the download failed.
    """
    if os.path.exists(local_path):
        return True

    tmp_path = local_path + ".part"
    progress_bar = DownloadProgressBar(description=f"Downloading {os.path.basename(local_path)}")
    try:
        urllib.request.urlretrieve(url, tmp_path, reporthook=progress_bar)
        os.replace(tmp_path, local_path)
    except Exception as e:
        print(f"\nError downloading {description} from {url}: {e}")
        return False
    finally:
        # Runs on success (nothing left to remove), failure, or interruption.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    display(HTML(f'&nbsp;&nbsp;&nbsp;<a href="{url}" target="_blank">{url}</a>'))
    return True

def on_custom_smi_upload_button_clicked(b):
    """Upload a custom SMILES label set and fetch the matching model checkpoint.

    Click handler registered on ``custom_smi_upload_button`` below. ``b`` is
    the argument ipywidgets passes to on_click callbacks and is unused. The
    uploaded file is saved in ``local_dir`` as ``custom_<original name>``, and
    the checkpoint for the currently selected model is downloaded if missing.

    Sets the globals ``custom_smi_path`` and ``model_local_path``. Only when
    the checkpoint is available does it also set ``smi_local_path`` to the
    custom file and refresh the descriptions. ``Exception``s raised during
    upload or processing are reported in ``output`` rather than propagated.
    """
    global custom_smi_path, model_local_path, smi_local_path

    upload_status_output.clear_output(wait=True)
    with output:
        output.clear_output(wait=True) # Clears main output before new upload process

    with upload_status_output:
        print("Please select your custom SMILES file (.smi or .txt)...")
        print("NOTE: Because of Colab limitations, you must re-run this cell if you wish to switch back to non-custom label sets.")

    try:
        # files.upload() returns {filename: bytes}; an empty dict means nothing was chosen.
        uploaded = files.upload()

        if not uploaded:
            with upload_status_output:
                upload_status_output.clear_output(wait=True)
                print("Upload cancelled or no file selected.")
            return

        # Only the first uploaded file is used.
        colab_root_filename = next(iter(uploaded))
        file_content = uploaded[colab_root_filename]

        # Strip a trailing " (N)" from the base name. This is presumably the
        # suffix added when a same-named file was uploaded before (confirm).
        name_part, ext_part = os.path.splitext(colab_root_filename)
        stripped_name_part = re.sub(r'\s*\(\d+\)$', '', name_part)
        original_filename_for_custom_prefix = stripped_name_part + ext_part

        target_filename_in_local_dir = f"custom_{original_filename_for_custom_prefix}"
        new_custom_smi_path = os.path.join(local_dir, target_filename_in_local_dir)

        with open(new_custom_smi_path, 'wb') as f:
            f.write(file_content)
        custom_smi_path = new_custom_smi_path

        # The copy under local_dir is now authoritative; remove the one in the
        # working directory. Failure here is reported but not fatal.
        try:
            if os.path.exists(colab_root_filename): # Path in /content/
                os.remove(colab_root_filename)
        except Exception as e_remove:
            print(f"NOTE: Could not remove temporary uploaded file '{colab_root_filename}' from Colab root: {e_remove}")

        with output:
            print(f"\nCustom SMILES file '{original_filename_for_custom_prefix}' uploaded.")
            print(f"  Saved to: {custom_smi_path}")

            # The custom label set still needs the checkpoint of the selected model.
            model_type = model_dropdown.value
            ckpt_url, ckpt_local, _, _ = construct_paths(model_type, 'custom')
            model_local_path = ckpt_local

            print(f"\nDownloading model file (if needed for '{model_type}' model)...")
            ckpt_downloaded = download_if_needed(ckpt_url, ckpt_local, "Checkpoint file")

            if ckpt_downloaded:
                print(f"\n📁 Files ready for use:")
                print(f"   Model: {model_local_path}")
                print(f"   SMILES: {custom_smi_path}\n")
                smi_local_path = custom_smi_path
                update_descriptions()
            else:
                print(f"\nModel download failed. Descriptions may not reflect the latest custom file.")

        upload_status_output.clear_output(wait=True)

    except Exception as e:
        with output:
            print(f"\nAn error occurred during file upload or processing: {e}")
        upload_status_output.clear_output(wait=True)

# Run the upload flow whenever the (custom-only) upload button is clicked.
custom_smi_upload_button.on_click(on_custom_smi_upload_button_clicked)

def _display_label(options, value):
    """Return the human-readable label paired with *value* in *options*."""
    return [label for label, val in options if val == value][0]


def update_descriptions():
    """Show the selected model and label set, with their descriptions."""
    model_type, label_set = model_dropdown.value, label_dropdown.value
    with description_output:
        description_output.clear_output()
        model_display = _display_label(model_options, model_type)
        label_display = _display_label(label_options, label_set)
        lines = [
            f"\nSELECTED MODEL: {model_display}",
            f"DESCRIPTION: {model_descriptions.get(model_type, 'N/A')}",
            "",
            f"SELECTED LABEL SET: {label_display}",
            f"DESCRIPTION: {label_descriptions.get(label_set, 'N/A')}",
        ]
        print("\n".join(lines))


def clear_descriptions():
    """Empty the description pane."""
    with description_output:
        description_output.clear_output()

def on_dropdown_change(change):
    """Handle a change of the model or label-set dropdown.

    Built-in label sets: download (if not already present) the checkpoint and
    SMILES file for the selection, and store their local paths in the globals
    ``model_local_path`` and ``smi_local_path``.

    Custom label set: show the upload button and instructions. If a custom
    SMILES file was uploaded earlier and still exists, make it the active
    label set and fetch the checkpoint for the chosen model. Otherwise, clear
    ``smi_local_path`` and ask the user to upload a file.

    Parameters
    ----------
    change : dict
        ipywidgets change record; anything other than a ``'value'`` change is
        ignored.
    """
    global model_local_path, smi_local_path, custom_smi_path
    if change['type'] != 'change' or change['name'] != 'value':
        return

    # Start every selection from clean output panes.
    clear_descriptions()
    output.clear_output(wait=True)
    upload_status_output.clear_output(wait=True)

    model_type = model_dropdown.value
    label_set = label_dropdown.value

    if label_set == 'custom':
        custom_smi_upload_button.layout.display = 'block'
        show_custom_instructions()

        if custom_smi_path and os.path.exists(custom_smi_path):
            with output:
                print(f"\nNOTE: Using previously uploaded custom SMILES file: {os.path.basename(custom_smi_path)}")
                print(f"  Path: {custom_smi_path}")
                ckpt_url, ckpt_local, _, _ = construct_paths(model_type, 'custom')
                model_local_path = ckpt_local
                # smi_local_path may still name a built-in label set chosen
                # earlier; point it at the custom file reported here.
                smi_local_path = custom_smi_path
                print(f"\nDownloading model file (if needed for '{model_type}' model)...")
                ckpt_downloaded = download_if_needed(ckpt_url, ckpt_local, "Checkpoint file")
                if ckpt_downloaded:
                    print("\nFiles ready for use:")
                    print(f"   Model: {model_local_path}")
                    print(f"   SMILES: {custom_smi_path}\n")
                    update_descriptions()
                else:
                    print("\nModel download failed for existing custom SMILES.")
        else:
            with output:
                if custom_smi_path:
                    # A path was remembered, but the file no longer exists.
                    print(f"\nPreviously uploaded custom file not found at {custom_smi_path}.")
                print("\n  Please upload a custom SMILES file using the button above.")
            custom_smi_path = None
            # No custom file yet. Forget any built-in label set so it is not
            # used silently; the upload handler sets this after a successful upload.
            smi_local_path = None
    else:
        custom_smi_upload_button.layout.display = 'none'
        hide_custom_instructions()
        # custom_smi_path is kept so switching back to 'custom' reuses the upload.
        update_descriptions()
        with output:
            ckpt_url, ckpt_local, smi_url, smi_local_val = construct_paths(model_type, label_set)
            model_local_path = ckpt_local
            smi_local_path = smi_local_val
            print("\nDownloading model and SMILES files (if needed)...")
            ckpt_downloaded = download_if_needed(ckpt_url, ckpt_local, "Checkpoint file")
            smi_downloaded = False
            if smi_url:
                smi_downloaded = download_if_needed(smi_url, smi_local_path, "SMILES file")

            if ckpt_downloaded and (smi_downloaded or not smi_url):
                print("\nFiles ready for use:")
                print(f"   Model: {model_local_path}")
                if smi_local_path:
                    print(f"   SMILES: {smi_local_path}\n")
            elif not ckpt_downloaded:
                print("\nModel download failed.")
            elif not smi_downloaded and smi_url:
                print("\nSMILES file download failed.")

# One handler serves both dropdowns; it reads both current values itself.
for _selector in (model_dropdown, label_dropdown):
    _selector.observe(on_dropdown_change, names='value')

control_container = widgets.VBox([
    widgets.HBox([model_dropdown, label_dropdown]),
    custom_instructions_output,
    custom_smi_upload_button,
    upload_status_output,
])

# Controls first, then the description pane, then download/status messages.
for _pane in (control_container, description_output, output):
    display(_pane)

# Fire the handler once so the default selection is described and its files fetched.
on_dropdown_change({'type': 'change', 'name': 'value', 'old': None, 'new': model_dropdown.value})

In [None]:
# @title Specify key DeepFrag2 parameters
# @markdown Specify the number of fragment predictions DeepFrag2 should predict
# Number of fragment suggestions to request; passed to MainDF2.py as
# --num_inference_predictions by the "Run DeepFrag2" cell.
num_inference_predictions = 10 # @param {"type":"number","placeholder":"10"}

In [None]:
# @title Run DeepFrag2
# @markdown This cell executes the core DeepFrag2 prediction script, assembling all the
# @markdown previously defined inputs (e.g., the protein, ligand, branching atom, and
# @markdown selected model) into a final command. The script then runs to generate the new
# @markdown molecular fragments.

import os
import shlex
import shutil
import subprocess

# Clean up previous results so the visualization cell only sees output from
# this run (pure-Python equivalent of `rm -rf predictions_Single_Complex`).
shutil.rmtree("predictions_Single_Complex", ignore_errors=True)

# Debug information if you need it
# print(model_local_path)
# print(smi_local_path)
# print(branch_atm_loc_xyz)

# The branching-atom location is passed to --branch_atm_loc_xyz as a single
# comma-separated "x,y,z" string.
coord_str = f"{branch_atm_loc_xyz[0]},{branch_atm_loc_xyz[1]},{branch_atm_loc_xyz[2]}"

# Here we must use the conda-installed python (the deepfrag2 environment
# created during setup), not the default Colab python.
cmd = [
    "/usr/local/miniconda/envs/deepfrag2/bin/python3.9",
    "deepfrag2/MainDF2.py",
    "--mode=inference_single_complex",
    f"--load_checkpoint={model_local_path}",
    f"--default_root_dir={os.getcwd()}/",
    f"--inference_label_sets={smi_local_path}",
    "--rotations=1",
    "--receptor=./protein_only.pdb",
    "--ligand=./selected_ligand.sdf",
    f"--branch_atm_loc_xyz={coord_str}",
    f"--num_inference_predictions={num_inference_predictions}",
    "--cpu"
]

# Let the user know re. the command to be run. Each argument is shell-quoted
# and lines end with a backslash continuation, so the printed command can be
# copied into a terminal as-is.
print("Running command:\n\n" + " \\\n    ".join(shlex.quote(arg) for arg in cmd))

# Attempt to run DeepFrag2. A failure is reported (exit code plus both output
# streams) rather than raised, so the streams can be inspected in the notebook.
try:
    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
    print(f"DeepFrag2 failed with exit code {e.returncode}.")
    print("STDOUT:\n", e.stdout)
    print("STDERR:\n", e.stderr)
else:
    print(result.stdout)

In [None]:
# @title Visualize results (structures and table)
# @markdown This cell parses the output file from the DeepFrag2 run to extract the
# @markdown generated molecular fragments and their corresponding scores. The results are
# @markdown then presented in two complementary formats: a table that ranks each fragment
# @markdown by its score, and a visual grid that displays the 2D structure of each
# @markdown molecule for easy comparison.

import glob
import numpy as np
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import pandas as pd
from PIL import Image
import io
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Number of structures shown per row of the image grid.
N_GRID_COLS = 4

# Get the data from the output. Fail with an actionable message (rather than
# a bare IndexError) when the "Run DeepFrag2" cell produced no results file.
out_files = sorted(glob.glob("predictions_Single_Complex/*.results/*.tsv"))
if not out_files:
    raise FileNotFoundError(
        "No DeepFrag2 results were found under predictions_Single_Complex/. "
        "Re-run the 'Run DeepFrag2' cell and check its output for errors."
    )
out_file = out_files[0]
print(f"Loading data from: {out_file}")

with open(out_file) as f:
    lines = f.readlines()
# Keep only lines containing "*" (the marker atom in fragment SMILES). Each kept
# line is expected to split into exactly two fields: SMILES and score.
smis_infos = [l.strip().split() for l in lines if "*" in l]
print(f"\nFound {len(smis_infos)} fragments\n")

# Build DataFrame
df = pd.DataFrame(smis_infos, columns=["SMILES", "Score"])
df["Score"] = df["Score"].astype(float)

# Sort DataFrame by Score in descending order
df = df.sort_values(by="Score", ascending=False).reset_index(drop=True)

# Style the DataFrame for better display
styled_df = df.style.format({
    "Score": "{:.3f}"
}).set_properties(**{
    'text-align': 'left',
    'font-size': '12pt',
}).set_table_styles([
    {'selector': 'th', 'props': [('font-size', '14pt'), ('text-align', 'left'), ('background-color', '#f0f0f0')]},
    {'selector': 'caption', 'props': [('font-size', '16pt'), ('font-weight', 'bold')]}
]).set_caption("SMILES Structures Ranked by Score")

# Display the styled DataFrame
display(styled_df)


def _desalt(smi):
    """Return the '.'-separated component of *smi* that contains '*'.

    If *smi* has no '.' it is returned unchanged. If no component contains
    '*', the full SMILES is returned instead of raising.
    """
    if "." not in smi:
        return smi
    return next((part for part in smi.split(".") if "*" in part), smi)


def _parse_smiles(smi):
    """Parse *smi* into an RDKit Mol, trying progressively more lenient settings.

    Returns the first Mol obtained, or None if every approach fails.
    """
    # Approach 1: Try with default settings (full sanitization).
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        return mol

    # Approach 2: Parse without sanitization, then apply a partial
    # sanitization that skips the steps most likely to reject the SMILES.
    try:
        print(f"Trying to parse without sanitization: {smi}")
        mol = Chem.MolFromSmiles(smi, sanitize=False)
        # Only sanitize a molecule that was actually parsed; passing None would
        # raise and be misreported as a sanitization failure.
        if mol is not None:
            Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_FINDRADICALS |
                              Chem.SanitizeFlags.SANITIZE_KEKULIZE |
                              Chem.SanitizeFlags.SANITIZE_SETAROMATICITY |
                              Chem.SanitizeFlags.SANITIZE_SETCONJUGATION |
                              Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION |
                              Chem.SanitizeFlags.SANITIZE_SYMMRINGS,
                              catchErrors=True)
    except Exception as e:
        print(f"Partial sanitization failed: {e}")
    if mol is not None:
        return mol

    # Approach 3: Try with modified SMILES for phosphates.
    if 'P' in smi:
        try:
            # Replace problematic phosphate patterns
            fixed_smi = smi.replace('PH', 'P')  # Remove explicit H on P
            print(f"Trying modified phosphate SMILES: {fixed_smi}")
            mol = Chem.MolFromSmiles(fixed_smi, sanitize=False)
        except Exception as e:
            print(f"Modified P SMILES failed: {e}")
    return mol


# Convert SMILES to molecules; keep the parallel lists aligned so that
# mols[i], smiles[i] and scores[i] describe the same fragment.
mols = []
smiles = []
scores = []

for smi, score in zip(df["SMILES"], df["Score"]):
    smi = _desalt(smi)
    mol = _parse_smiles(smi)
    if mol is None:
        print(f"Warning: Could not parse SMILES after all attempts: {smi}")
        continue

    # Generate 2D atom coordinates for drawing if the molecule has none.
    try:
        if mol.GetNumConformers() == 0:
            AllChem.Compute2DCoords(mol)
    except Exception as e:
        print(f"Could not compute coordinates: {e}")
        continue

    mols.append(mol)
    smiles.append(smi)
    scores.append(score)

# Draw the grid only when there is something to draw: plt.subplots rejects a
# zero-row layout.
if mols:
    n_rows = (len(mols) + N_GRID_COLS - 1) // N_GRID_COLS  # ceiling division

    # squeeze=False always yields a 2D axes array, even for a single row.
    fig, axes = plt.subplots(n_rows, N_GRID_COLS, figsize=(16, 4 * n_rows), squeeze=False)

    # Draw each molecule with highlighting on the atom with asterisk (*)
    for i, (mol, smi, score) in enumerate(zip(mols, smiles, scores)):
        current_ax = axes[i // N_GRID_COLS, i % N_GRID_COLS]

        # Find asterisk atoms to highlight
        highlight_atoms = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetSymbol() == "*"]

        try:
            if highlight_atoms:
                mol_img = Draw.MolToImage(
                    mol,
                    size=(300, 250),
                    highlightAtoms=highlight_atoms,
                    highlightColor=(0, 1, 0)  # Green
                )
            else:
                mol_img = Draw.MolToImage(mol, size=(300, 250))
        except Exception as e:
            print(f"Warning: Could not draw molecule {i}: {e}")
            # Create a blank white image as fallback
            mol_img = Image.new('RGB', (300, 250), color='white')

        # Display in the grid
        current_ax.imshow(mol_img)
        current_ax.set_xlabel(f"{smi}\nScore: {score:.3f}", fontsize=18)
        current_ax.set_xticks([])
        current_ax.set_yticks([])

        # Remove border
        for spine in current_ax.spines.values():
            spine.set_visible(False)

    # Hide unused subplots in the last row
    for unused_ax in axes.flat[len(mols):]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("No fragments could be visualized.")

# Display summary after the plot
print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"Total molecules processed: {len(df)}")
print(f"Successfully visualized: {len(mols)}")
if not df.empty:
    print(f"Top scoring molecule: {df.iloc[0]['SMILES']} (Score: {df.iloc[0]['Score']:.3f})")

# Optionally, save the DataFrame to a CSV file
# df.to_csv("molecule_scores.csv", index=False)