# FLOWR.root Tutorial: Pocket-Conditional Ligand Generation and Affinity Prediction

This tutorial demonstrates how to use the **FLOWR.root** model for various ligand generation tasks:

1. **De Novo Generation** - Generate ligands from scratch for a protein pocket
2. **Scaffold-Conditional Generation (Scaffold Hopping)** - Generate new scaffolds while preserving functional groups
3. **Scaffold Elaboration** - Generate new functional groups / R-groups while preserving the molecular scaffold
4. **Core Growing** - Grow around a selected ring system core (with ring system index selection)
5. **Fragment Growing** - Grow from a given fragment with optional size control and prior center
6. **Substructure Inpainting** - Regenerate specific parts of a ligand while preserving others
7. **Affinity Prediction** - Predict binding affinities for generated ligands

**Additional options** applicable to conditional generation modes:
- `--anisotropic_prior`: Use an anisotropic Gaussian prior shaped by the molecular geometry (recommended for inpainting modes)
- `--ref_ligand_com_prior`: Center the random prior at the reference ligand's center of mass

## 1. Setup and Configuration

First, let's import the required libraries and define our global configuration variables.

In [None]:
# Import required libraries
import subprocess
import glob
from pathlib import Path
from typing import Union, List, Optional

# Molecular visualization and manipulation
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import py3Dmol

# Data handling and visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, SVG

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("✓ All libraries imported successfully!")

In [None]:
# =============================================================================
#                         GLOBAL CONFIGURATION
# =============================================================================

# Target protein
TARGET = "ptp1b"

# Paths
ROOT_DIR = Path(".").resolve().parent.parent.parent  # flowr_root directory
EXAMPLES_DIR = ROOT_DIR / "examples"

# Input files
PROTEIN_PDB = EXAMPLES_DIR / f"{TARGET}_protein.pdb"
LIGANDS_SDF = EXAMPLES_DIR / f"{TARGET}_ligands.sdf"

# Model checkpoint - UPDATE THIS PATH!
CKPT_PATH = "YOUR_CKPT_PATH/flowr_root_v2.1.ckpt"

# Output base directory
SAVE_DIR = EXAMPLES_DIR / "tutorial_outputs" / TARGET
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Generation parameters
COORD_NOISE_STD = 0.1 # How much noise is added during generation to increase diversity
POCKET_CUTOFF = 6 # Determines size of the pocket, in Angstrom
LIGAND_IDX = 0  # Reference ligand index for conditional generation (if your SDF contains several ligands)
N_MOLECULES = 5  # Number of molecules to generate
BATCH_COST = 20 # Modify depending on your compute to avoid out-of-memory errors (SET TO 1 or 2 IF YOU RUN THIS LOCALLY ON A MAC!)
NUM_GPUS = 1 # NVIDIA GPU with at least 40GB VRAM recommended (otherwise reduce BATCH_COST)
NUM_WORKERS = 12 # Number of CPU workers for data loading (SET TO 0 IF YOU RUN THIS LOCALLY ON A MAC!)

print(f"{'='*60}")
print("   FlOWR.root Tutorial Configuration")
print(f"{'='*60}")
print(f"  Target:           {TARGET.upper()}")
print(f"  Protein PDB:      {PROTEIN_PDB}")
print(f"  Ligands SDF:      {LIGANDS_SDF}")
print(f"  Checkpoint:       {CKPT_PATH}")
print(f"  Output directory: {SAVE_DIR}")
print(f"  Reference ligand: index {LIGAND_IDX}")
print(f"  Molecules/target: {N_MOLECULES}")
print(f"{'='*60}")

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def load_molecules_from_sdf(sdf_path):
    """Load all parsable molecules from an SDF file.

    Hydrogens stored in the file are kept (``removeHs=False``). Records that
    RDKit cannot parse are left out of the result; when that happens a warning
    is printed, because positions in the returned list then no longer line up
    with the record positions in the file.

    Args:
        sdf_path: Location of the SDF file (string or path-like).

    Returns:
        List of RDKit molecules in file order, without unparsable records.
    """
    records = list(Chem.SDMolSupplier(str(sdf_path), removeHs=False))
    mols = [mol for mol in records if mol is not None]
    print(f"Loaded {len(mols)} molecules from {sdf_path}")

    # Make silent drops visible: a skipped record shifts every later list index
    # relative to the record index inside the SDF file.
    skipped = len(records) - len(mols)
    if skipped:
        print(
            f"Warning: skipped {skipped} unparsable record(s) in {sdf_path}; "
            "list indices may not match SDF record indices."
        )
    return mols


def visualize_molecules_2d(mols, legends=None, mols_per_row=4, img_size=(300, 300)):
    """Render the given molecules as a 2D grid image.

    Args:
        mols: Molecules to draw.
        legends: One caption per molecule; defaults to "Mol 1", "Mol 2", ...
        mols_per_row: Number of grid columns.
        img_size: Size of each grid cell, forwarded as ``subImgSize``.

    Returns:
        The object returned by ``Draw.MolsToGridImage``.
    """
    captions = legends
    if captions is None:
        # Default captions are numbered from 1
        captions = [f"Mol {number}" for number in range(1, len(mols) + 1)]

    return Draw.MolsToGridImage(
        mols,
        molsPerRow=mols_per_row,
        subImgSize=img_size,
        legends=captions,
    )


def visualize_ligand_3d(mol, style='stick', width=600, height=400):
    """Show one ligand in an interactive py3Dmol viewer.

    Args:
        mol: RDKit molecule to display (converted to a MolBlock).
        style: 'stick' or 'sphere'; any other value applies no explicit style.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The py3Dmol viewer; call ``.show()`` on it to render.
    """
    # Known style names and the py3Dmol style spec each one maps to
    presets = (
        ('stick', {'stick': {'colorscheme': 'greenCarbon'}}),
        ('sphere', {'sphere': {'scale': 0.3}}),
    )

    view = py3Dmol.view(width=width, height=height)
    view.addModel(Chem.MolToMolBlock(mol), 'mol')

    for preset_name, preset_spec in presets:
        if style == preset_name:
            view.setStyle(preset_spec)
            break

    view.zoomTo()
    return view


def visualize_ligand_in_pocket(ligand_mol, protein_pdb_path, width=800, height=600):
    """Visualize a ligand within the protein pocket with enhanced styling.

    The protein is loaded as model 0 and the ligand as model 1. Both models are
    added before any styling: the binding-site sticks and the pocket surface
    are selected by distance to model 1, so the ligand has to exist in the
    viewer when those selections are applied.

    Args:
        ligand_mol: RDKit molecule of the ligand (converted to a MolBlock).
        protein_pdb_path: Path to the protein PDB file.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.

    Returns:
        The configured py3Dmol viewer; call ``.show()`` on it to render.
    """
    viewer = py3Dmol.view(width=width, height=height)

    # Model 0: protein
    with open(protein_pdb_path, 'r') as f:
        protein_data = f.read()
    viewer.addModel(protein_data, 'pdb')

    # Model 1: ligand. Added up front so the distance-based selections below
    # ('within ... of model 1') have atoms to match against.
    mol_block = Chem.MolToMolBlock(ligand_mol)
    viewer.addModel(mol_block, 'mol')

    # Protein backbone as cartoon with spectrum coloring
    viewer.setStyle({'model': 0}, {
        'cartoon': {
            'color': 'spectrum',
            'opacity': 0.85,
            'thickness': 0.4
        }
    })

    # Highlight binding site residues (within 5Å of ligand)
    viewer.addStyle(
        {'model': 0, 'within': {'distance': 5, 'sel': {'model': 1}}},
        {
            'stick': {
                'colorscheme': 'whiteCarbon',
                'radius': 0.15
            }
        }
    )

    # Add transparent surface for binding pocket (protein atoms within 6Å of ligand)
    viewer.addSurface(
        py3Dmol.VDW,
        {
            'opacity': 0.25,
            'color': 'white',
            'wireframe': False
        },
        {'model': 0, 'within': {'distance': 6, 'sel': {'model': 1}}}
    )

    # Ligand as ball-and-stick with vibrant colors
    viewer.setStyle({'model': 1}, {
        'stick': {
            'colorscheme': 'cyanCarbon',
            'radius': 0.2
        },
        'sphere': {
            'colorscheme': 'cyanCarbon',
            'scale': 0.25
        }
    })

    # Add a single "Ligand" text label anchored to the selected ligand atom
    viewer.addLabel(
        "Ligand",
        {
            'fontSize': 12,
            'fontColor': 'white',
            'backgroundColor': 'rgba(0,150,150,0.7)',
            'backgroundOpacity': 0.7
        },
        {'model': 1, 'atom': 'C', 'serial': 1}
    )

    # White background; the second argument is the background alpha
    viewer.setBackgroundColor('0xffffff', 0.9)

    # Zoom and center on the ligand; the second argument is the animation
    # duration in milliseconds
    viewer.zoomTo({'model': 1}, 200)

    # Add outline for depth perception
    viewer.setViewStyle({'style': 'outline', 'color': 'black', 'width': 0.02})

    return viewer


def visualize_multiple_ligands_in_pocket(ligand_mols, protein_pdb_path, width=800, height=600, 
                                          show_labels=True, show_surface=True):
    """Overlay up to ten ligands on the protein in one py3Dmol viewer.

    The protein becomes model 0 and the ligands models 1..N in input order.
    Each ligand gets its own carbon color scheme; protein atoms close to any
    shown ligand are drawn as sticks, and an optional surface is drawn around
    the first ligand only.

    Args:
        ligand_mols: RDKit molecules to display; only the first ten are used.
        protein_pdb_path: Path to the protein PDB file.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.
        show_labels: Whether to attach a "Ligand N" label to each ligand.
        show_surface: Whether to draw the pocket surface near the first ligand.

    Returns:
        The configured py3Dmol viewer; call ``.show()`` on it to render.
    """
    # One carbon color scheme per ligand slot
    palette = (
        'greenCarbon',
        'cyanCarbon',
        'magentaCarbon',
        'orangeCarbon',
        'purpleCarbon',
        'blueCarbon',
        'yellowCarbon',
        'whiteCarbon',
        'pinkCarbon',
        'grayCarbon',
    )
    label_style = {
        'fontSize': 11,
        'fontColor': 'white',
        'backgroundColor': 'gray',
        'backgroundOpacity': 0.85,
        'borderRadius': 4,
        'padding': 2
    }

    view = py3Dmol.view(width=width, height=height)

    # Model 0: protein as a spectrum-colored cartoon
    with open(protein_pdb_path, 'r') as pdb_handle:
        pdb_text = pdb_handle.read()
    view.addModel(pdb_text, 'pdb')
    view.setStyle(
        {'model': 0},
        {'cartoon': {'color': 'spectrum', 'opacity': 1.0, 'thickness': 0.5, 'arrows': True}},
    )

    shown_ligands = ligand_mols[:min(len(ligand_mols), 10)]
    ligand_count = len(shown_ligands)

    # Models 1..N: ligands as ball-and-stick, one carbon color each
    for model_no, ligand in enumerate(shown_ligands, start=1):
        view.addModel(Chem.MolToMolBlock(ligand), 'mol')

        carbon_scheme = palette[(model_no - 1) % len(palette)]
        view.setStyle(
            {'model': model_no},
            {
                'stick': {'colorscheme': carbon_scheme, 'radius': 0.2},
                'sphere': {'colorscheme': carbon_scheme, 'scale': 0.25},
            },
        )

        if not show_labels:
            continue
        view.addLabel(f"Ligand {model_no}", label_style, {'model': model_no, 'serial': 1})

    # Protein atoms within 4.5 Å of each shown ligand are drawn as sticks
    for model_no in range(1, ligand_count + 1):
        near_ligand = {'model': 0, 'within': {'distance': 4.5, 'sel': {'model': model_no}}}
        view.addStyle(near_ligand, {'stick': {'colorscheme': 'default', 'radius': 0.12}})

    # Optional pocket surface, built from protein atoms near the first ligand
    if show_surface and ligand_count > 0:
        view.addSurface(
            py3Dmol.SAS,
            {'opacity': 0.6, 'color': '0x4682b4'},
            {'model': 0, 'within': {'distance': 7, 'sel': {'model': 1}}},
        )

    view.setBackgroundColor('0xf0f0f0')
    view.zoomTo()

    return view

def visualize_mol_with_atom_indices(mol, size=(500, 400)):
    """Display a 2D SVG depiction of the molecule with atom indices shown.

    Works on a copy, so the 2D coordinates computed for drawing do not
    overwrite the conformer of the molecule passed in.

    Args:
        mol: RDKit molecule to draw.
        size: (width, height) of the SVG canvas.
    """
    flat_mol = Chem.Mol(mol)
    AllChem.Compute2DCoords(flat_mol)

    canvas = rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
    canvas.drawOptions().addAtomIndices = True

    canvas.DrawMolecule(flat_mol)
    canvas.FinishDrawing()
    display(SVG(canvas.GetDrawingText()))

def run_command(cmd, description=""):
    """Run a command from the repository root and print its output.

    The command runs with ``ROOT_DIR`` as working directory and its output is
    captured, so nothing is shown until the process has finished.

    Args:
        cmd: Command and arguments as a sequence; entries are converted with
            ``str`` so path objects can be passed directly.
        description: Short label used in the progress messages.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
    """
    # Normalize once so both the echo below and subprocess see plain strings
    cmd = [str(part) for part in cmd]

    print(f"\n{'='*60}")
    print(f"Running: {description}")
    print(f"Command: {' '.join(cmd)}")
    print('='*60)

    result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(ROOT_DIR))

    if result.returncode != 0:
        # Anything the process printed to stdout before failing is often needed
        # to diagnose the error, so show it instead of discarding it.
        if result.stdout:
            print(f"Output:\n{result.stdout}")
        print(f"Error:\n{result.stderr}")
        raise RuntimeError(f"Command failed with return code {result.returncode}")

    print(result.stdout)
    print(f"✓ {description} completed successfully!")
    return result


print("✓ Helper functions defined!")

In [None]:
# Let's first examine the reference ligands in our dataset
ref_mols = load_molecules_from_sdf(LIGANDS_SDF)

# Display 2D structures
print(f"\nReference ligands for {TARGET.upper()}:")
visualize_molecules_2d(ref_mols[:8], mols_per_row=4)

## 2. De Novo Generation

De novo generation creates new ligands from scratch, using only the protein pocket structure as input. The model learns to generate molecules that are complementary to the binding site.

In [None]:
# =============================================================================
#                         DE NOVO GENERATION
# =============================================================================

DENOVO_SAVE_DIR = SAVE_DIR / "denovo"
DENOVO_SAVE_DIR.mkdir(parents=True, exist_ok=True)

denovo_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--batch_cost", str(BATCH_COST),
    "--num_workers", str(NUM_WORKERS),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(DENOVO_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
]

print("De Novo Generation Configuration:")
print(f"  Output: {DENOVO_SAVE_DIR}")
print(f"  N molecules: {N_MOLECULES}")
print(f"  Noise scale: {COORD_NOISE_STD}")

In [None]:
# Run de novo generation
# This may take several minutes depending on your hardware
run_command(denovo_cmd, "De Novo Ligand Generation")

## 3. Visualize De Novo Generated Ligands

After generation, we'll visualize the generated ligands both in isolation and within the protein pocket context.

In [None]:
# =============================================================================
# LOAD AND VISUALIZE DE NOVO GENERATED LIGANDS
# =============================================================================

# Find generated SDF files
denovo_sdf_pattern = DENOVO_SAVE_DIR / "*.sdf"
denovo_sdf_files = sorted(glob.glob(str(denovo_sdf_pattern)))

if denovo_sdf_files:
    # Load the generated molecules
    denovo_sdf = denovo_sdf_files[0]  # Take the first/main output file
    denovo_mols = load_molecules_from_sdf(denovo_sdf)
    
    print("\n2D Visualization of De Novo Generated Ligands:")
    display(visualize_molecules_2d(denovo_mols[:8], 
                                   legends=[f"DeNovo {i+1}" for i in range(min(8, len(denovo_mols)))]))
else:
    print("No de novo generated files found. Please run the generation step first.")
    denovo_mols = []

In [None]:
# Visualize a single de novo ligand in 3D (standalone)
# Index 0 is the first generated ligand; it exists whenever denovo_mols is non-empty.
idx = 0
if denovo_mols:
    print("3D Visualization of First De Novo Generated Ligand:")
    viewer = visualize_ligand_3d(denovo_mols[idx])
    viewer.show()

In [None]:
# Visualize de novo ligands within the protein pocket
if denovo_mols and PROTEIN_PDB.exists():
    print(f"\nDe Novo Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(denovo_mols[:5], PROTEIN_PDB)
    viewer.show()
elif not PROTEIN_PDB.exists():
    print(f"Protein structure not found at {PROTEIN_PDB}")
    print("The protein file is created during generation.")

## 4. Scaffold-Conditional Generation (Scaffold Hopping)

Scaffold hopping preserves the functional groups from a reference ligand while generating a new molecular scaffold. This is useful for exploring novel chemotypes while maintaining key interactions.

**Tip:** Add `--anisotropic_prior` to use a shape-aware Gaussian prior that better matches the molecular geometry, or `--ref_ligand_com_prior` to center the prior at the reference ligand's center of mass.

In [None]:
# =============================================================================
#                    SCAFFOLD-CONDITIONAL GENERATION (SCAFFOLD HOPPING)
# =============================================================================

SCAFFOLD_SAVE_DIR = SAVE_DIR / "scaffold"
SCAFFOLD_SAVE_DIR.mkdir(parents=True, exist_ok=True)

scaffold_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--batch_cost", str(BATCH_COST),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(SCAFFOLD_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
    "--filter_cond_substructure",
    "--filter_diversity",
    "--diversity_threshold", "0.7",
    "--optimize_gen_ligs_hs",   # Optimize hydrogens in generated ligands
    # "--optimize_gen_ligs",    # Optimize generated ligands (only use if you encounter highly strained structures)
    "--scaffold_hopping",       # <-- Enable scaffold hopping (replaces scaffold, keeps functional groups)
    # "--anisotropic_prior",    # Optional: use shape-aware anisotropic Gaussian prior
    # "--ref_ligand_com_prior", # Optional: center prior at reference ligand COM
    # "--ref_ligand_com_noise_std", "0.0",  # Optional: noise std for COM prior (default: 1.0)
]

print("Scaffold Hopping Configuration:")
print(f"  Reference ligand index: {LIGAND_IDX}")
print(f"  Output: {SCAFFOLD_SAVE_DIR}")

# Show reference ligand
print(f"\nReference ligand (index {LIGAND_IDX}) — functional groups will be preserved:")
display(visualize_molecules_2d([ref_mols[LIGAND_IDX]], legends=["Reference"]))

In [None]:
# Run scaffold-conditional generation
run_command(scaffold_cmd, "Scaffold-Conditional Generation")

## 5. Visualize Scaffold Hopping Ligands

Let's compare the scaffold hopping generated ligands with the reference.

In [None]:
# =============================================================================
# LOAD AND VISUALIZE SCAFFOLD-CONDITIONAL LIGANDS
# =============================================================================

scaffold_sdf_pattern = SCAFFOLD_SAVE_DIR / "*.sdf"
scaffold_sdf_files = sorted(glob.glob(str(scaffold_sdf_pattern)))

if scaffold_sdf_files:
    scaffold_sdf = scaffold_sdf_files[0]
    scaffold_mols = load_molecules_from_sdf(scaffold_sdf)
    
    # Compare with reference
    comparison_mols = [ref_mols[LIGAND_IDX]] + scaffold_mols[:7]
    comparison_legends = ["Reference"] + [f"Scaffold {i+1}" for i in range(min(7, len(scaffold_mols)))]
    
    print("\nComparison: Reference vs Scaffold-Conditional Generated Ligands:")
    display(visualize_molecules_2d(comparison_mols, legends=comparison_legends))
else:
    print("No scaffold-conditional generated files found.")
    scaffold_mols = []

In [None]:
# Visualize scaffold-conditional ligands in protein pocket
if scaffold_mols and PROTEIN_PDB.exists():
    print(f"\nScaffold-Conditional Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(scaffold_mols[:5], PROTEIN_PDB)
    viewer.show()

## 6. Scaffold Elaboration (R-Group Generation)

Scaffold elaboration preserves the core molecular scaffold from a reference ligand while generating new R-groups, decorations, and functional groups. This is useful for lead optimization where you want to keep the scaffold but explore different substituents.

**Tip:** Add `--anisotropic_prior` to use a shape-aware Gaussian prior, or `--ref_ligand_com_prior` to center the prior at the reference ligand's center of mass.

In [None]:
# =============================================================================
#                    SCAFFOLD ELABORATION (R-GROUP GENERATION)
# =============================================================================

FUNCGROUP_SAVE_DIR = SAVE_DIR / "elaboration"
FUNCGROUP_SAVE_DIR.mkdir(parents=True, exist_ok=True)

elaboration_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--batch_cost", str(BATCH_COST),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(FUNCGROUP_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
    "--filter_diversity",
    "--diversity_threshold", "0.7",
    "--optimize_gen_ligs_hs",   # Optimize hydrogens in generated ligands
    # "--optimize_gen_ligs",    # Optimize generated ligands (only use if you encounter highly strained structures)
    "--filter_cond_substructure",
    "--scaffold_elaboration",   # <-- Enable scaffold elaboration (keeps scaffold, replaces R-groups/functional groups)
    # "--anisotropic_prior",    # Optional: use shape-aware anisotropic Gaussian prior
    # "--ref_ligand_com_prior", # Optional: center prior at reference ligand COM
]

print("Scaffold Elaboration Configuration:")
print(f"  Reference ligand index: {LIGAND_IDX}")
print(f"  Output: {FUNCGROUP_SAVE_DIR}")

In [None]:
# Run scaffold elaboration
run_command(elaboration_cmd, "Scaffold Elaboration Generation")

## 7. Visualize Scaffold Elaboration Ligands

In [None]:
# =============================================================================
# LOAD AND VISUALIZE SCAFFOLD ELABORATION LIGANDS
# =============================================================================

elaboration_sdf_pattern = FUNCGROUP_SAVE_DIR / "*.sdf"
elaboration_sdf_files = sorted(glob.glob(str(elaboration_sdf_pattern)))

if elaboration_sdf_files:
    elaboration_sdf = elaboration_sdf_files[0]
    elaboration_mols = load_molecules_from_sdf(elaboration_sdf)
    
    # Compare with reference
    comparison_mols = [ref_mols[LIGAND_IDX]] + elaboration_mols[:7]
    comparison_legends = ["Reference"] + [f"Elaboration {i+1}" for i in range(min(7, len(elaboration_mols)))]
    
    print("\nComparison: Reference vs Scaffold Elaboration Generated Ligands:")
    display(visualize_molecules_2d(comparison_mols, legends=comparison_legends))
else:
    print("No scaffold elaboration generated files found.")
    elaboration_mols = []

In [None]:
# Visualize scaffold elaboration ligands in protein pocket
if elaboration_mols and PROTEIN_PDB.exists():
    print(f"\nScaffold Elaboration Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(elaboration_mols[:5], PROTEIN_PDB)
    viewer.show()

## 8. Core Growing

Core growing preserves a selected ring system (core) from the reference ligand and generates new substituents around it. Use `--ring_system_indexing` to select which ring system to keep (default: 0, i.e. the first/largest ring system).

**Tip:** Add `--anisotropic_prior` for a shape-aware prior that better matches the growth direction.

In [None]:
# =============================================================================
#                          CORE GROWING
# =============================================================================

CORE_SAVE_DIR = SAVE_DIR / "core_growing"
CORE_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Select which ring system to keep (0 = first/largest ring system)
RING_SYSTEM_INDEX = 0

core_growing_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--batch_cost", str(BATCH_COST),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(CORE_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
    "--filter_diversity",
    "--diversity_threshold", "0.7",
    "--optimize_gen_ligs_hs",       # Optimize hydrogens in generated ligands
    "--filter_cond_substructure",
    "--core_growing",               # <-- Enable core growing
    "--ring_system_indexing", str(RING_SYSTEM_INDEX),  # <-- Select which ring system to preserve
    # "--anisotropic_prior",        # Optional: use shape-aware anisotropic Gaussian prior
    # "--ref_ligand_com_prior",     # Optional: center prior at reference ligand COM
]

# Summarize the core-growing settings chosen above
print("Core Growing Configuration:")
print(f"  Reference ligand index: {LIGAND_IDX}")
print(f"  Ring system index: {RING_SYSTEM_INDEX}")
print(f"  Output: {CORE_SAVE_DIR}")

# Show the reference ligand as a plain 2D depiction (no atom or ring-system
# indices are drawn here) so its ring systems can be inspected visually
print(f"\nReference ligand (index {LIGAND_IDX}) — ring system {RING_SYSTEM_INDEX} will be preserved:")
display(visualize_molecules_2d([ref_mols[LIGAND_IDX]], legends=["Reference"]))

In [None]:
# Run core growing
run_command(core_growing_cmd, "Core Growing Generation")

## 9. Visualize Core Growing Ligands

In [None]:
# =============================================================================
# LOAD AND VISUALIZE CORE GROWING LIGANDS
# =============================================================================

core_sdf_pattern = CORE_SAVE_DIR / "*.sdf"
core_sdf_files = sorted(glob.glob(str(core_sdf_pattern)))

if core_sdf_files:
    core_sdf = core_sdf_files[0]
    core_mols = load_molecules_from_sdf(core_sdf)
    
    # Compare with reference
    comparison_mols = [ref_mols[LIGAND_IDX]] + core_mols[:7]
    comparison_legends = ["Reference"] + [f"CoreGrow {i+1}" for i in range(min(7, len(core_mols)))]
    
    print("\nComparison: Reference vs Core Growing Generated Ligands:")
    print(f"(Ring system {RING_SYSTEM_INDEX} preserved, substituents regenerated)")
    display(visualize_molecules_2d(comparison_mols, legends=comparison_legends))
else:
    print("No core growing generated files found.")
    core_mols = []

In [None]:
# Visualize core growing ligands in protein pocket
if core_mols and PROTEIN_PDB.exists():
    print(f"\nCore Growing Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(core_mols[:5], PROTEIN_PDB)
    viewer.show()

## 10. Fragment Growing

Fragment growing takes an input fragment and grows additional atoms around it. Key options:
- `--grow_size`: Specify the number of heavy atoms to add (if not set, molecule sizes are sampled)
- `--prior_center_file`: Provide an XYZ file with coordinates to center the random prior (controls growth direction)
- `--anisotropic_prior`: Use an anisotropic Gaussian prior shaped along the growth direction

In [None]:
# =============================================================================
#                         FRAGMENT GROWING
# =============================================================================

FRAGMENT_SAVE_DIR = SAVE_DIR / "fragment_growing"
FRAGMENT_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Fragment growing parameters
GROW_SIZE = 15  # Number of heavy atoms to add to the fragment (set to None to sample sizes)

# Pass --grow_size only when a size is given; with GROW_SIZE = None the flag is
# left out instead of sending the literal string "None" to the CLI.
grow_size_args = ["--grow_size", str(GROW_SIZE)] if GROW_SIZE is not None else []

fragment_growing_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--batch_cost", str(BATCH_COST),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(FRAGMENT_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
    "--filter_diversity",
    "--diversity_threshold", "0.7",
    "--optimize_gen_ligs_hs",       # Optimize hydrogens in generated ligands
    "--filter_cond_substructure",
    "--fragment_growing",           # <-- Enable fragment growing
    *grow_size_args,                # <-- Number of atoms to grow (omitted when GROW_SIZE is None)
    "--anisotropic_prior",          # Recommended: use anisotropic prior for directed growth
    # "--prior_center_file", "path/to/center.xyz",  # Optional: XYZ file to center prior at
    # "--ref_ligand_com_prior",     # Optional: center prior at reference ligand COM
    # "--ref_ligand_com_noise_std", "0.0",  # Optional: noise std for COM prior
]

# Human-readable grow size for the summary below
grow_size_label = f"{GROW_SIZE} atoms" if GROW_SIZE is not None else "sampled"

print("Fragment Growing Configuration:")
print(f"  Reference ligand index: {LIGAND_IDX}")
print(f"  Grow size: {grow_size_label}")
print(f"  Anisotropic prior: enabled")
print(f"  Output: {FRAGMENT_SAVE_DIR}")

In [None]:
# Run fragment growing
run_command(fragment_growing_cmd, "Fragment Growing Generation")

In [None]:
# =============================================================================
# LOAD AND VISUALIZE FRAGMENT GROWING LIGANDS
# =============================================================================

fragment_sdf_pattern = FRAGMENT_SAVE_DIR / "*.sdf"
fragment_sdf_files = sorted(glob.glob(str(fragment_sdf_pattern)))

if fragment_sdf_files:
    fragment_sdf = fragment_sdf_files[0]
    fragment_mols = load_molecules_from_sdf(fragment_sdf)
    
    # Compare with reference
    comparison_mols = [ref_mols[LIGAND_IDX]] + fragment_mols[:7]
    comparison_legends = ["Reference"] + [f"FragGrow {i+1}" for i in range(min(7, len(fragment_mols)))]
    
    print(f"\nComparison: Reference vs Fragment Growing Generated Ligands (grew {GROW_SIZE} atoms):")
    display(visualize_molecules_2d(comparison_mols, legends=comparison_legends))
else:
    print("No fragment growing generated files found.")
    fragment_mols = []

## 11. Visualize Fragment Growing Ligands in Pocket

In [None]:
# Visualize fragment growing ligands in protein pocket
if fragment_mols and PROTEIN_PDB.exists():
    print(f"\nFragment Growing Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(fragment_mols[:5], PROTEIN_PDB)
    viewer.show()

## 12. Substructure Inpainting

Substructure inpainting allows you to selectively regenerate specific parts of a molecule while preserving others. Specify the atom indices you want to **change** via `--substructure`. This is particularly useful for:
- Optimizing specific regions of a lead compound
- Exploring modifications to particular functional groups
- Maintaining critical binding interactions while varying other parts

In [None]:
# =============================================================================
#                       SUBSTRUCTURE INPAINTING
# =============================================================================

sub_mol = Chem.RemoveHs(ref_mols[LIGAND_IDX])
# Visualize reference ligand with atom indices
print(f"Reference ligand (index {LIGAND_IDX}) with atom indices:")
print(f"   Total heavy atoms: {sub_mol.GetNumHeavyAtoms()}")
print("    Select atom indices below that you want to CHANGE during inpainting.\n")
print("    NOTE: Atom indices need to be taken from the molecule without hydrogens!")

visualize_mol_with_atom_indices(sub_mol)

# Define which atoms to CHANGE (modify based on visualization above!)
CHANGE_ATOM_INDICES = [15, 16, 17, 18, 19]  # <-- Adjust these!

INPAINT_SAVE_DIR = SAVE_DIR / "substructure"
INPAINT_SAVE_DIR.mkdir(parents=True, exist_ok=True)

inpaint_cmd = [
    "python", "-m", "flowr.gen.generate_from_pdb",
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--ligand_idx", str(LIGAND_IDX),
    "--arch", "pocket",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF),
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--batch_cost", str(BATCH_COST),
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(INPAINT_SAVE_DIR),
    "--max_sample_iter", "20",
    "--coord_noise_scale", str(COORD_NOISE_STD),
    "--sample_n_molecules_per_target", str(N_MOLECULES),
    "--sample_mol_sizes",
    "--categorical_strategy", "uniform-sample",
    "--filter_valid_unique",
    "--filter_diversity",
    "--diversity_threshold", "0.7",
    "--optimize_gen_ligs_hs",       # Optimize hydrogens in generated ligands
    "--filter_cond_substructure",
    "--substructure_inpainting",    # <-- Enable substructure inpainting
    "--substructure",               # <-- Followed by atom indices to CHANGE
] + [str(idx) for idx in CHANGE_ATOM_INDICES]

print("Substructure Inpainting Configuration:")
print(f"  Reference ligand index: {LIGAND_IDX}")
print(f"  Atoms to change: {CHANGE_ATOM_INDICES}")
print(f"  Output: {INPAINT_SAVE_DIR}")

In [None]:
# Run substructure inpainting
run_command(inpaint_cmd, "Substructure Inpainting")

## 13. Visualize Inpainted Ligands

In [None]:
# =============================================================================
# LOAD AND VISUALIZE INPAINTED LIGANDS
# =============================================================================

inpaint_sdf_pattern = INPAINT_SAVE_DIR / "*.sdf"
inpaint_sdf_files = sorted(glob.glob(str(inpaint_sdf_pattern)))

if inpaint_sdf_files:
    inpaint_sdf = inpaint_sdf_files[0]
    inpaint_mols = load_molecules_from_sdf(inpaint_sdf)
    
    # Compare with reference
    comparison_mols = [ref_mols[LIGAND_IDX]] + inpaint_mols[:7]
    comparison_legends = ["Original"] + [f"Inpainted {i+1}" for i in range(min(7, len(inpaint_mols)))]
    
    print("\nComparison: Original vs Inpainted Ligands:")
    print("(Preserved regions remain similar, inpainted regions show variation)")
    display(visualize_molecules_2d(comparison_mols, legends=comparison_legends))
else:
    print("No inpainted generated files found.")
    inpaint_mols = []

In [None]:
# Visualize inpainted ligands in protein pocket
if inpaint_mols and PROTEIN_PDB.exists():
    print(f"\nInpainted Ligands in {TARGET.upper()} Binding Pocket:")
    viewer = visualize_multiple_ligands_in_pocket(inpaint_mols[:5], PROTEIN_PDB)
    viewer.show()

## 14. Affinity Prediction

Now let's predict binding affinities for ligands using the FLOWR.root affinity prediction model. We'll predict affinities for all ligands in the reference SDF file.

In [None]:
# =============================================================================
#                         AFFINITY PREDICTION
# =============================================================================

AFFINITY_SAVE_DIR = SAVE_DIR / "affinity"
AFFINITY_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Affinity prediction parameters
AFFINITY_COORD_NOISE = 0.05  # Lower noise for prediction stability
AFFINITY_BATCH_COST = 5
SEED = 42

# Build command based on predict.sh structure
affinity_cmd = [
    "python", "-m", "flowr.predict.predict_from_pdb",
    "--gpus", str(NUM_GPUS),
    "--num_workers", str(NUM_WORKERS),
    "--seed", str(SEED),
    "--batch_cost", str(AFFINITY_BATCH_COST),
    "--arch", "pocket",
    "--pocket_noise", "fix",
    "--pocket_type", "holo",
    "--cut_pocket",
    "--pocket_cutoff", str(POCKET_CUTOFF), 
    "--pdb_file", str(PROTEIN_PDB),
    "--ligand_file", str(LIGANDS_SDF),
    "--multiple_ligands",  # SDF contains multiple ligands; remove if you only have one
    "--ckpt_path", str(CKPT_PATH),
    "--save_dir", str(AFFINITY_SAVE_DIR),
    "--coord_noise_scale", str(AFFINITY_COORD_NOISE),
]

print("Affinity Prediction Configuration:")
print(f"  Protein: {PROTEIN_PDB.name}")
print(f"  Ligands: {LIGANDS_SDF.name}")
print(f"  Architecture: pocket")
print(f"  Coord noise scale: {AFFINITY_COORD_NOISE}")
print(f"  Batch cost: {AFFINITY_BATCH_COST}")
print(f"  Output: {AFFINITY_SAVE_DIR}")

In [None]:
# Run affinity prediction
run_command(affinity_cmd, "Affinity Prediction")

In [None]:
# =============================================================================
# LOAD AFFINITY PREDICTIONS
# =============================================================================

def _read_float_prop(mol, prop: str) -> float:
    """Return SDF property ``prop`` of ``mol`` as a float, or NaN if absent or non-numeric."""
    if not mol.HasProp(prop):
        return np.nan
    try:
        return float(mol.GetProp(prop))
    except ValueError:
        return np.nan


def load_affinity_sdf(
    sdf_path: Union[str, Path],
    affinity_props: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Load ligands from SDF file and extract affinity predictions into a DataFrame.

    Args:
        sdf_path: Path to the SDF file to read.
        affinity_props: SDF property names to extract. Defaults to
            ``["pic50", "pkd", "pki", "pec50"]``.

    Returns:
        DataFrame with one row per successfully parsed molecule and the
        columns ``ligand_idx``, ``smiles`` and one column per requested
        property. A property that is missing or not parseable as a float
        is stored as NaN.

    Raises:
        FileNotFoundError: If ``sdf_path`` does not exist.
    """
    if affinity_props is None:
        affinity_props = ["pic50", "pkd", "pki", "pec50"]

    sdf_path = Path(sdf_path)
    if not sdf_path.exists():
        raise FileNotFoundError(f"SDF file not found: {sdf_path}")

    suppl = Chem.SDMolSupplier(str(sdf_path), removeHs=False)
    # Keep each molecule's position in the file. Records that fail to parse
    # are skipped, and renumbering afterwards would make ligand_idx point at
    # the wrong SDF record for every molecule that follows a failure.
    records = [(idx, mol) for idx, mol in enumerate(suppl) if mol is not None]

    data = {
        "ligand_idx": [idx for idx, _ in records],
        "smiles": [Chem.MolToSmiles(mol) for _, mol in records],
    }
    for prop in affinity_props:
        data[prop] = [_read_float_prop(mol, prop) for _, mol in records]

    return pd.DataFrame(data)

# Read the predictions written by the affinity run and show them as a table
sdf_path = str(AFFINITY_SAVE_DIR) + "/gen_lig_with_aff.sdf"
affinity_df = load_affinity_sdf(sdf_path=sdf_path)
print(affinity_df)

## 15. Visualize Affinity Predictions

In [None]:
# =============================================================================
# AFFINITY VISUALIZATION
# =============================================================================

# Affinity column that the plotting and summary calls in this cell operate on.
AFFINITY_TYPE = "pic50"  # Choose from "pic50", "pkd", "pki", "pec50"

def visualize_affinity(
    affinity_df: pd.DataFrame,
    affinity_type: str = "pic50",
    save_dir: Optional[Union[str, Path]] = None,
    figsize: tuple = (16, 5),
    top_n: int = 10,
    dpi: int = 300,
    show_plot: bool = True,
) -> plt.Figure:
    """Visualize affinity predictions with publication-quality figures.

    Draws three panels side by side: a ranked horizontal bar chart (limited
    to the 30 highest values), a histogram with KDE plus mean/median markers,
    and a bar chart of the ``top_n`` highest-affinity ligands.

    Args:
        affinity_df: DataFrame containing a ``ligand_idx`` column and the
            column named by ``affinity_type``.
        affinity_type: Name of the affinity column to plot.
        save_dir: If given, the figure is written there as
            ``affinity_analysis_{affinity_type}.png`` (directory is created).
        figsize: Figure size passed to ``plt.subplots``.
        top_n: Maximum number of ligands shown in the third panel.
        dpi: Resolution used when saving the figure.
        show_plot: Whether to call ``plt.show()``.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``affinity_type`` is not a column of ``affinity_df``
            or the column holds no non-NaN values.
    """
    if affinity_type not in affinity_df.columns:
        available = [col for col in ["pic50", "pkd", "pki", "pec50"] if col in affinity_df.columns]
        raise ValueError(f"Affinity type '{affinity_type}' not found. Available: {available}")

    df = affinity_df.dropna(subset=[affinity_type]).copy()
    if len(df) == 0:
        raise ValueError(f"No valid data for affinity type '{affinity_type}'")

    # Keys are lower-case to match the column names used for affinity types;
    # mixed-case keys never matched, so every type got the fallback colour.
    palette = {"pic50": "#2ecc71", "pkd": "#3498db", "pki": "#9b59b6", "pec50": "#e74c3c"}
    main_color = palette.get(affinity_type.lower(), "#3498db")

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle(f"Affinity Analysis: {affinity_type}", fontsize=14, fontweight='bold', y=1.02)

    mean_val = df[affinity_type].mean()
    median_val = df[affinity_type].median()

    # Plot 1: Ranked Ligand Affinities
    ax1 = axes[0]
    sorted_df = df.sort_values(affinity_type, ascending=True)
    # Ascending sort + tail keeps the highest values; tail() returns the whole
    # frame when it has fewer rows than requested.
    plot_df = sorted_df.tail(30)
    # Min-max normalise for the colour map; the epsilon avoids 0/0 when all
    # values are equal.
    norm_vals = (plot_df[affinity_type] - plot_df[affinity_type].min()) / \
                (plot_df[affinity_type].max() - plot_df[affinity_type].min() + 1e-8)
    colors = plt.cm.RdYlGn(0.2 + 0.6 * norm_vals.values)
    y_pos = range(len(plot_df))
    ax1.barh(y_pos, plot_df[affinity_type], color=colors, edgecolor='white', linewidth=0.5)
    ax1.set_yticks(y_pos)
    ax1.set_yticklabels([f"Lig {i}" for i in plot_df['ligand_idx']], fontsize=8)
    ax1.set_xlabel(f'{affinity_type}', fontsize=11)
    ax1.set_title('Ranked Ligand Affinities', fontsize=12, fontweight='bold')
    ax1.axvline(x=mean_val, color='#e74c3c', linestyle='--', linewidth=2, alpha=0.8)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    # Plot 2: Distribution with KDE
    ax2 = axes[1]
    sns.histplot(df[affinity_type], bins='auto', kde=True, color=main_color, alpha=0.6, ax=ax2)
    ax2.axvline(x=mean_val, color='#e74c3c', linestyle='--', linewidth=2.5, label=f'Mean: {mean_val:.2f}')
    ax2.axvline(x=median_val, color='#f39c12', linestyle='-.', linewidth=2.5, label=f'Median: {median_val:.2f}')
    ax2.set_xlabel(f'{affinity_type}', fontsize=11)
    ax2.set_ylabel('Count', fontsize=11)
    ax2.set_title('Affinity Distribution', fontsize=12, fontweight='bold')
    ax2.legend(loc='upper right', fontsize=9)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    # Plot 3: Top N Ligands
    ax3 = axes[2]
    top_n_actual = min(top_n, len(df))
    top_ligands = df.nlargest(top_n_actual, affinity_type)
    # Reversed so the best ligand (first bar) gets the darkest green.
    top_colors = plt.cm.Greens(np.linspace(0.4, 0.9, top_n_actual))[::-1]
    x_pos = range(top_n_actual)
    bars = ax3.bar(x_pos, top_ligands[affinity_type], color=top_colors, edgecolor='white', width=0.7)
    ax3.set_xticks(x_pos)
    ax3.set_xticklabels([f"Lig {i}" for i in top_ligands['ligand_idx']], fontsize=9, rotation=45, ha='right')
    ax3.set_ylabel(f'{affinity_type}', fontsize=11)
    ax3.set_title(f'Top {top_n_actual} Highest Affinity Ligands', fontsize=12, fontweight='bold')
    # Write each value just above its bar.
    for bar, val in zip(bars, top_ligands[affinity_type]):
        ax3.annotate(f'{val:.2f}', xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                     xytext=(0, 4), textcoords="offset points", ha='center', va='bottom', fontsize=9)
    ax3.spines['top'].set_visible(False)
    ax3.spines['right'].set_visible(False)

    plt.tight_layout()
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_dir / f'affinity_analysis_{affinity_type}.png', dpi=dpi, bbox_inches='tight')
    if show_plot:
        plt.show()
    return fig


def print_affinity_summary(affinity_df: pd.DataFrame, affinity_type: str = "pic50"):
    """Print count, mean, std, min, median and max of one affinity column.

    NaN entries are dropped before the statistics are computed. Nothing is
    printed when ``affinity_type`` is not a column of ``affinity_df``.
    """
    if affinity_type in affinity_df.columns:
        series = affinity_df[affinity_type].dropna()
        rule = "=" * 50

        # Header
        print("\n" + rule)
        print(f"  {affinity_type} Summary Statistics")
        print(rule)

        # First row: size and spread around the mean
        count, mean, std = len(series), series.mean(), series.std()
        print(f"  Count:  {count}  |  Mean: {mean:.3f}  |  Std: {std:.3f}")

        # Second row: range and median
        low, mid, high = series.min(), series.median(), series.max()
        print(f"  Min:    {low:.3f}  |  Median: {mid:.3f}  |  Max: {high:.3f}")

        # Footer
        print(rule + "\n")


# Reload the predictions, then plot and summarize the chosen affinity type
sdf_path = str(AFFINITY_SAVE_DIR) + "/gen_lig_with_aff.sdf"
affinity_df = load_affinity_sdf(sdf_path=sdf_path)
visualize_affinity(affinity_df, AFFINITY_TYPE, save_dir=AFFINITY_SAVE_DIR)
print_affinity_summary(affinity_df, affinity_type=AFFINITY_TYPE)

In [None]:
# =============================================================================
# TUTORIAL COMPLETE
# =============================================================================

# Closing banner
print(f"\n{'=' * 60}")
print("TUTORIAL COMPLETE")
print("=" * 60)
print(f"\nAll outputs saved to: {SAVE_DIR}")
print("\nGenerated directories:")

# One line per expected output folder: "+" if it exists, "-" otherwise
for subdir in (
    'denovo',
    'scaffold',
    'elaboration',
    'core_growing',
    'fragment_growing',
    'substructure',
    'affinity',
):
    subdir_path = SAVE_DIR / subdir
    if not subdir_path.exists():
        print(f"  - {subdir_path} (run generation to create)")
        continue
    print(f"  + {subdir_path}")
print("=" * 60)