<a href="https://colab.research.google.com/github/catastropiyush/quetzal/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scalable Autoregressive 3D Molecule Generation

This notebook samples molecules from Quetzal models pretrained on the QM9 and GEOM datasets.

[Paper](https://arxiv.org/abs/2505.13791) | [GitHub](https://github.com/aspuru-guzik-group/quetzal)

## Setup

In [1]:
!git clone https://github.com/aspuru-guzik-group/quetzal.git
import os
os.chdir('quetzal')
os.environ['PYTHONPATH'] = '/env/python:/content/quetzal'
!pip uninstall datasets -y # name collision
!pip install -q rdkit py3Dmol lightning==2.5.0.post0

Cloning into 'quetzal'...
remote: Enumerating objects: 106, done.[K
remote: Counting objects: 100% (106/106), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 106 (delta 15), reused 101 (delta 13), pack-reused 0 (from 0)[K
Receiving objects: 100% (106/106), 27.07 MiB | 13.70 MiB/s, done.
Resolving deltas: 100% (15/15), done.
Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.6/36.6 MB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Download pretrained checkpoints
!wget https://huggingface.co/auhcheng/quetzal/resolve/main/original.ckpt # best qm9 model
# !wget https://huggingface.co/auhcheng/quetzal/resolve/main/geom.ckpt # best geom model

--2026-02-02 06:21:34--  https://huggingface.co/auhcheng/quetzal/resolve/main/original.ckpt
Resolving huggingface.co (huggingface.co)... 3.166.152.110, 3.166.152.44, 3.166.152.105, ...
Connecting to huggingface.co (huggingface.co)|3.166.152.110|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://us.gcp.cdn.hf.co/xet-bridge-us/68bae7d142a11efbc6d5c83e/ba408df2edf4e3ff5713075c8be0bac397487bc1a5aedf0601adba183f744a65?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27original.ckpt%3B+filename%3D%22original.ckpt%22%3B&Expires=1770016894&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiRXBvY2hUaW1lIjoxNzcwMDE2ODk0fX0sIlJlc291cmNlIjoiaHR0cHM6Ly91cy5nY3AuY2RuLmhmLmNvL3hldC1icmlkZ2UtdXMvNjhiYWU3ZDE0MmExMWVmYmM2ZDVjODNlL2JhNDA4ZGYyZWRmNGUzZmY1NzEzMDc1YzhiZTBiYWMzOTc0ODdiYzFhNWFlZGYwNjAxYWRiYTE4M2Y3NDRhNjVcXD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=Ir2ME2XC-vYRYlhiuseQoh-WEO9Xv7Y7K3KJwCVmvfoaBZ4Ek%7EXZSaUJZXNcQQLJ2sC

## Generate molecules

In [3]:
import sys

import torch

# `train` is imported from the current working directory; only add "." once so
# re-running the cell does not keep growing sys.path.
if "." not in sys.path:
    sys.path.append(".")
# NOTE(review): Config is not referenced in this cell; kept because it was
# imported originally — presumably needed when the checkpoint is loaded, confirm.
from train import Config, LitQuetzal

DIFF_STEPS = 60
SEED = 0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# max_len passed to generate for each checkpoint name.
MAX_LEN = {"original": 32, "geom": 192}

### qm9
ckpt_name = "original"

### geom
# ckpt_name = "geom"

kwargs = {
    "bsz": 25,
    "device": DEVICE,
    "num_steps": DIFF_STEPS,
    "pbar": True,
    "max_len": MAX_LEN[ckpt_name],  # chosen once, from the selected checkpoint
}

ckpt = f"{ckpt_name}.ckpt"
lit = LitQuetzal.load_from_checkpoint(ckpt, map_location=DEVICE)
model = lit.ema.module
model.eval();

# Seed torch's global RNG so a fresh "Run All" repeats the same samples.
# Assumes generate draws from the global generator — TODO confirm.
torch.manual_seed(SEED)
out = model.generate(**kwargs)
samples, all_traj = out

  _C._set_float32_matmul_precision(precision)
 78%|███████▊  | 25/32 [07:25<02:04, 17.83s/it]


In [4]:
from draw import show_grid

# Grid layout for the first batch (bsz=25 in the generation cell above).
GRID_ROWS = 5
GRID_COLS = 5
show_grid(samples, GRID_ROWS, GRID_COLS)

<py3Dmol.view at 0x7c0ba13b3f20>

In [6]:
# Second run: a small batch of 5 molecules with a larger max_len.
# A copy of `kwargs` is used so the settings of the first run are not
# silently changed by running this cell.
# You can also try changing other parameters, for example:
# "num_steps": 100  # Increase diffusion steps for potentially better quality
# NOTE(review): the generation cell above pairs max_len=32 with the QM9
# checkpoint and 192 with GEOM; confirm the loaded checkpoint supports 64.
kwargs_more = {**kwargs, "bsz": 5, "max_len": 64}

print(f"Generating {kwargs_more['bsz']} molecules with parameters: {kwargs_more}")

# `out` is rebound here; the trajectory cell at the end of the notebook reads `out`.
out = model.generate(**kwargs_more)
samples_more, all_traj_more = out


Generating 5 molecules with parameters: {'bsz': 5, 'device': 'cpu', 'num_steps': 60, 'pbar': True, 'max_len': 64}


 27%|██▋       | 17/64 [02:27<06:48,  8.68s/it]


In [11]:
from draw import show_grid

# nrows x ncols has to match the number of samples exactly (5 in this case).
grid_shape = (1, 5)
show_grid(samples_more, *grid_shape)

<py3Dmol.view at 0x7c0b4d5c64e0>

In [13]:
import os

# Helper function to convert a Molecule object to XYZ format
# Atomic number -> element symbol. Built once instead of on every call.
# Covers common atoms in the QM9 and GEOM datasets.
_ELEMENT_SYMBOLS = {
    1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F',
    15: 'P', 16: 'S', 17: 'Cl', 35: 'Br', 53: 'I',
}


def molecule_to_xyz_string(mol, comment="Generated by Quetzal"):
    """Serialize a molecule to XYZ-format text.

    The result has the atom count on the first line, ``comment`` on the
    second, and one ``symbol x y z`` line per atom with coordinates printed
    to six decimal places. The text always ends with a newline.

    Args:
        mol: Object with an ``atoms`` array of atomic numbers (one scalar per
            atom, read with ``.item()``) and a ``coords`` array whose i-th row
            unpacks into the x, y, z of atom i.
        comment: Text for the second line of the file. Must be a single line.

    Returns:
        The XYZ text as a ``str``.

    Raises:
        ValueError: If ``comment`` contains a line break, which would shift
            every atom line and corrupt the file.
    """
    if "\n" in comment or "\r" in comment:
        raise ValueError("comment must be a single line")

    num_atoms = mol.atoms.shape[0]
    lines = [str(num_atoms), comment]
    for i in range(num_atoms):
        # Atom types missing from the table are written as 'X'.
        # NOTE(review): the sample output in this notebook contains an 'X' row
        # near the origin — confirm whether such entries are padding/stop
        # tokens that should be dropped rather than written out.
        symbol = _ELEMENT_SYMBOLS.get(mol.atoms[i].item(), 'X')
        x, y, z = mol.coords[i].tolist()
        lines.append(f"{symbol} {x:.6f} {y:.6f} {z:.6f}")
    # Join once instead of growing a string with += inside the loop.
    return "\n".join(lines) + "\n"

# Destination folder for the XYZ files (created on first use).
output_dir = "generated_molecules_xyz"
os.makedirs(output_dir, exist_ok=True)

# Write one file per sample of the second batch, numbered from 1.
xyz_files = []
for number, sample in enumerate(samples_more, start=1):
    target = os.path.join(output_dir, f"molecule_{number}.xyz")
    # Serialize before opening so a failure does not leave an empty file behind.
    text = molecule_to_xyz_string(sample)
    with open(target, 'w') as handle:
        handle.write(text)
    xyz_files.append(target)

print(f"Generated {len(xyz_files)} XYZ files in the '{output_dir}' directory.")

# Print the first file as an example of the format.
if len(xyz_files) > 0:
    first_file = xyz_files[0]
    print(f"\nContent of {first_file}:")
    with open(first_file, 'r') as handle:
        print(handle.read())


Generated 5 XYZ files in the 'generated_molecules_xyz' directory.

Content of generated_molecules_xyz/molecule_1.xyz:
17
Generated by Quetzal
C -1.475023 -0.210866 1.093985
C -1.044911 -0.523002 -0.270959
C -0.676465 -0.783046 -1.386853
C -0.259353 -1.083754 -2.738675
O 1.152002 -0.951793 -3.013773
C 0.617299 -2.280130 -3.016505
C 1.084367 -3.153178 -1.956076
N 1.447656 -3.864648 -1.123330
H -2.535427 0.029155 1.146129
H -1.282333 -1.051384 1.765803
H -0.922798 0.656004 1.473987
H -0.893006 -0.718977 -3.533175
H 0.513713 -2.750571 -4.000368
X -0.000127 0.000217 -0.001378
C -0.323769 0.137752 -0.643849
N -0.846788 1.155487 -1.583995
H 0.237001 1.472288 -1.343668



To generate molecules with the GEOM-pretrained model, first download `geom.ckpt` (uncomment the corresponding `wget` line in the Setup section), then uncomment the code in the cell below and run it. The GEOM model is designed for larger molecules, so `max_len` is set higher (192 here, versus 32 for QM9).

In [None]:
# Optional: sample from the GEOM checkpoint instead of the QM9 one.
# Everything in this cell is commented out on purpose; remove the leading "# "
# from the code lines to run it. It needs geom.ckpt in the working directory
# (its wget line in the download cell is commented out by default) and reuses
# DEVICE, DIFF_STEPS and LitQuetzal from the first generation cell.
# To switch to the GEOM model (assuming 'geom.ckpt' is downloaded)
# ckpt_name_geom = "geom"
# kwargs_geom = {"bsz": 25, "device": DEVICE, "num_steps": DIFF_STEPS, "pbar": True, "max_len": 192} # GEOM typically has larger max_len

# ckpt_geom = f"{ckpt_name_geom}.ckpt"
# lit_geom = LitQuetzal.load_from_checkpoint(ckpt_geom, map_location=DEVICE)
# model_geom = lit_geom.ema.module
# model_geom.eval()

# print(f"Generating {kwargs_geom['bsz']} molecules with GEOM model and parameters: {kwargs_geom}")
# out_geom = model_geom.generate(**kwargs_geom)
# samples_geom, all_traj_geom = out_geom


In [None]:
from draw import show_traj

# `out` is whatever the most recently run `model.generate` cell assigned.
# b_idx=0 selects the first molecule of that batch.
traj_options = {"b_idx": 0, "interval": 10}
show_traj(out, **traj_options)

<py3Dmol.view at 0x7ade0075b740>