In [1]:
from omtra.load.quick import datamodule_from_config
import omtra.load.quick as quick_load
from omtra.utils import omtra_root
from rdkit import Chem
from pathlib import Path
from omtra.constants import num_condensed_atom_types
from omtra.data.condensed_atom_typing import CondensedAtomTyper
import torch

from omtra.tasks.register import task_name_to_class
import rdkit
from rdkit import Chem
import py3Dmol
# `import rdkit` / `from rdkit import Chem` do not load the Draw.IPythonConsole
# submodule, so import it explicitly instead of relying on another package
# having imported it as a side effect.
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_3d = True  # enable py3Dmol inline visualization

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_task_capabilities(task_name):
    """
    Report which modalities a task involves, based on its groups.

    The task class registered under ``task_name`` is looked up with
    ``task_name_to_class`` and its ``groups_present`` collection is checked
    for protein, pharmacophore and ligand group names.

    Args:
        task_name: Name of the task to look up.

    Returns:
        dict: Boolean flags ``has_protein``, ``has_pharmacophore`` and
        ``has_ligand``.
    """
    present = task_name_to_class(task_name).groups_present

    def contains_any(*group_names):
        # True as soon as one of the given group names is present.
        for name in group_names:
            if name in present:
                return True
        return False

    return {
        'has_protein': contains_any('protein_identity', 'protein_structure'),
        'has_pharmacophore': 'pharmacophore' in present,
        'has_ligand': contains_any(
            'ligand_identity',
            'ligand_identity_condensed',
            'ligand_identity_extra',
            'ligand_structure',
        ),
    }

In [3]:
import numpy as np

# Seed every RNG the notebook draws from. Sampling runs through a torch model,
# so seeding NumPy alone does not make the generated systems repeatable.
np.random.seed(0)
torch.manual_seed(0)

In [4]:
# Load the trained model and the dataset it will be sampled against.
ckpt = Path(
    '/net/galaxy/home/koes/icd3/moldiff/OMTRA/local/mlsb_runs_backup/mt_plinder/prot_protpharm_cond_2025-09-11_18-31-586141/checkpoints/last.ckpt'
)

# The Hydra config is read from the directory one level above the folder
# that holds the checkpoint file.
cfg_file = ckpt.parents[1] / '.hydra' / 'config.yaml'
cfg = quick_load.load_trained_model_cfg(cfg_file)

# Override the dataset locations stored in the config with these paths.
pharmit_path = "/net/galaxy/home/koes/icd3/moldiff/OMTRA/data/pharmit"
plinder_path = "/net/galaxy/home/koes/icd3/moldiff/OMTRA/data/plinder"
cfg.pharmit_path = str(pharmit_path)
cfg.plinder_path = str(plinder_path)

# Restore the model from the checkpoint, move it to the GPU and switch to eval mode.
model = quick_load.omtra_from_checkpoint(str(ckpt))
model = model.cuda().eval()

# Build the datamodule and pick the plinder 'no_links' dataset of the train split.
dm = quick_load.datamodule_from_config(cfg)
multiset = dm.load_dataset('train')
dataset = multiset.datasets['plinder']['no_links']

⚛ Instantiating datamodule <omtra.dataset.data_module.MultiTaskDataModule>


In [5]:

# ---- Task selection: leave exactly one `task_name` assignment uncommented ----

## conformer generation (unconditional and pharm conditioned)
#task_name = 'ligand_conformer_from_pharmacophore_condensed'
#task_name = 'ligand_conformer_condensed'

## denovo task (unconditional and pharm conditioned)
#task_name = "denovo_ligand_condensed"
#task_name = "denovo_ligand_from_pharmacophore_condensed"


## ====>> tasks that have proteins: <<=====

#====> denovo
#task_name = "fixed_protein_ligand_denovo_condensed"
#task_name = 'fixed_protein_pharmacophore_ligand_denovo_condensed'

#====> docking
#task_name = "rigid_docking_condensed"
task_name = "rigid_docking_pharmacophore_condensed"


# Dataset indices of the systems to condition on.
chosen_indices = [443] #,702,152]

# Fetch each (task, index) item from the dataset and move it to the GPU.
g_list = [dataset[(task_name, idx)].to('cuda') for idx in chosen_indices]

# Draw 3 replicates per input graph using 250 integration timesteps.
sampled_systems = model.sample(
    g_list=g_list,
    task_name=task_name,
    n_timesteps=250,
    n_replicates=3,
    # unconditional_n_atoms_dist='plinder',
)

In [6]:
def _parse_xyz_coords(lines):
    """Return the (x, y, z) float triples found in the lines of an XYZ file.

    The first line (the atom count) is skipped. Any later line with fewer than
    four whitespace-separated fields (blank or short comment line) is ignored.
    """
    coords = []
    for line in lines[1:]:
        fields = line.split()
        if len(fields) >= 4:
            coords.append((float(fields[1]), float(fields[2]), float(fields[3])))
    return coords


def view_pharm_and_ligand(sys, task_capabilities, prefix="", ground_truth_lig=False,):
    """Write a system's ligand (and optional pharmacophore/protein) to ./figures and show it in py3Dmol.

    Args:
        sys: Sampled system object. It must provide ``write_ligand``, and,
            depending on ``task_capabilities``, ``write_pharmacophore``,
            ``get_pharmacophore_from_graph`` and ``write_protein_pdb``.
        task_capabilities: Mapping with boolean entries ``'has_pharmacophore'``
            and ``'has_protein'`` selecting which extra modalities are drawn.
        prefix: String inserted into the names of the files written under ``./figures``.
        ground_truth_lig: Forwarded to ``sys.write_ligand`` as ``ground_truth``.

    Returns:
        None. The viewer is displayed via ``viewer.show()``.
    """
    has_pharmacophore = task_capabilities['has_pharmacophore']
    has_protein = task_capabilities['has_protein']

    conformer_file = f"./figures/{prefix}_temp.sdf"
    # Every file below is written under this directory; create it up front so a
    # fresh checkout does not fail on the first write.
    Path(conformer_file).parent.mkdir(parents=True, exist_ok=True)
    sys.write_ligand(conformer_file, ground_truth=ground_truth_lig)

    if has_pharmacophore:
        pharmacophore_file = f"./figures/{prefix}temp.xyz"
        sys.write_pharmacophore(pharmacophore_file, ground_truth=True)

        # get pharmacophore in memory (its 'types' entry drives the sphere colors)
        pharm = sys.get_pharmacophore_from_graph(kind='gt', xyz=False)
        pharm_types = pharm['types']

    if has_protein:
        protein_file = f"./figures/{prefix}.pdb"
        sys.write_protein_pdb("./", protein_file, ground_truth=True)

    pharm_type_to_color = {
        'HydrogenDonor': 'red',
        'HydrogenAcceptor': 'blue',
        'Hydrophobic': 'green',
        'Aromatic': 'purple',
    }

    # Read the first molecule from the SDF file.
    mol = Chem.SDMolSupplier(conformer_file)[0]
    if mol is not None:
        ligand_block = Chem.MolToMolBlock(mol)
    else:
        # The supplier yields None for a record RDKit cannot parse/sanitize,
        # which MolToMolBlock would reject; show the raw SDF text instead so
        # the geometry is still visible.
        with open(conformer_file, 'r') as f:
            ligand_block = f.read()

    # Create py3Dmol viewer and add the ligand as the first model.
    viewer = py3Dmol.view(width=400, height=300)
    viewer.addModel(ligand_block, 'sdf')
    viewer.setStyle({'stick': {'radius': 0.1}, 'sphere': {'radius': 0.3}})

    if has_pharmacophore:
        # Read pharmacophore coordinates back from the XYZ file.
        with open(pharmacophore_file, 'r') as f:
            lines = f.readlines()

        # Index the types by the number of *parsed* points, not by raw line
        # number, so a skipped comment/blank line cannot shift the colors.
        for pharm_idx, (x, y, z) in enumerate(_parse_xyz_coords(lines)):
            # Unknown types fall back to yellow.
            color = pharm_type_to_color.get(pharm_types[pharm_idx], 'yellow')
            viewer.addSphere({'center': {'x': x, 'y': y, 'z': z}, 'radius': 1.0, 'color': color, 'alpha': 0.5})

    ## add protein to py3dmol
    if has_protein:
        # NOTE(review): assumes write_protein_pdb appends "_0.pdb" to the name it is given - TODO confirm
        with open(protein_file + "_0.pdb", 'r') as f:
            pdb_data = f.read()

        viewer.addModel(pdb_data, 'pdb')
        # Model 1 is the protein (the ligand was added first as model 0).
        viewer.setStyle({'model': 1},
                        {'cartoon': {'color': 'lightblue', 'opacity': 0.8}})

    viewer.zoomTo()
    viewer.show()

# First show the ground-truth ligand, taken from the first sampled system.
print("Ground truth")
task_capabilities = get_task_capabilities(task_name)

view_pharm_and_ligand(
    sampled_systems[0],
    task_capabilities,
    ground_truth_lig=True,
)

# Then show every sampled system; the per-system prefix keeps their files apart.
for i, sys in enumerate(sampled_systems):
    print(f"System {i}")
    prefix = f"{task_name}_{i}"
    view_pharm_and_ligand(sys, task_capabilities, prefix=prefix)

Ground truth


System 0


System 1


System 2
