# ColabDesign STL Extension: Design Proteins from STL Shapes

This notebook designs protein sequences whose predicted structures match tube-like STL shapes, using centerline extraction and a per-index path loss.

## 1) Config (edit here)

In [None]:
from pathlib import Path
import json

# Name of the PRESETS entry to apply on top of the manual values below,
# or None to use the manual values only.
PRESET = "stl_centerline_cylinder"
OVERRIDES = {}  # Applied last, after the preset; e.g., {"TARGET_EXTENT": 40.0}

# Manual values. A selected preset replaces any of these that it defines;
# names the preset does not define keep the value set here.
STL_PATH = "examples/stl/cylinder.stl"
OUT_DIR = "results/cylinder"  # Use results/ for submission-ready outputs
PROTEIN_LENGTH = 80                    # Also used as the number of centerline points to extract
TARGET_EXTENT = 30.0                   # Bbox max-dimension scaling (Å); every built-in preset overrides this
TARGET_ARCLENGTH = None                # Optional; forwarded as target_arclength to centerline extraction and to the designer
CENTERLINE_SURFACE_SAMPLES = 10000     # Forwarded as surface_samples to centerline extraction
CENTERLINE_BINS = None                 # Forwarded as bins to centerline extraction
CENTERLINE_SMOOTH_WINDOW = 5           # Forwarded as smooth_window to centerline extraction
SAMPLE_SEED = 0                        # Centerline sampling seed; -1 is passed on as None
RUN_SEED = 0                           # Passed to designer.design(run_seed=...)
SOFT_ITERS = 300                       # Passed to designer.design(soft_iters=...)
TEMP_ITERS = 150                       # Passed to designer.design(temp_iters=...)
HARD_ITERS = 20                        # Passed to designer.design(hard_iters=...)
PATH_WEIGHT = 0.02                     # Loss weights handed to STLProteinDesigner
PLDDT_WEIGHT = 2.0
PAE_WEIGHT = 0.2
CON_WEIGHT = 0.5
NORMALIZE_TARGET = True               # If False, skip bbox scaling (center only)
DATA_DIR = "/content/data_dir"         # AlphaFold params location; if empty, the AF_DATA_DIR env var is tried next
AUTO_DOWNLOAD_PARAMS = True            # Download and extract the AlphaFold params when they are missing
FORCE_RECLONE = True                   # When src/ is not found under ROOT, delete any existing /content/colabdesign-stl and clone again

## 2) Presets

In [None]:
# Named parameter bundles. Each entry maps config-variable names (the
# upper-case names set in the Config cell) to the values that replace them
# when that entry is selected through PRESET.
PRESETS = {
    "stl_centerline_cylinder": {
        "STL_PATH": "examples/stl/cylinder.stl",
        "OUT_DIR": "results/cylinder",
        "PROTEIN_LENGTH": 80,
        "TARGET_EXTENT": 100.0,
        "CENTERLINE_SURFACE_SAMPLES": 10000,
        "PATH_WEIGHT": 0.02,
        "CON_WEIGHT": 0.2,
        "PLDDT_WEIGHT": 2.0,
    },
    "stl_centerline_sine": {
        "STL_PATH": "examples/stl/sine_tube.stl",
        "OUT_DIR": "results/sine_tube",
        "PROTEIN_LENGTH": 80,
        "TARGET_EXTENT": 120.0,
        "CENTERLINE_SURFACE_SAMPLES": 12000,
        "PATH_WEIGHT": 0.02,
        "CON_WEIGHT": 0.2,
        "PLDDT_WEIGHT": 2.0,
    },
    "stl_centerline_helix1turn": {
        "STL_PATH": "examples/stl/helix_tube_1turn.stl",
        "OUT_DIR": "results/helix_tube_1turn",
        "PROTEIN_LENGTH": 80,
        "TARGET_EXTENT": 100.0,
        "CENTERLINE_SURFACE_SAMPLES": 12000,
        "PATH_WEIGHT": 0.02,
        "CON_WEIGHT": 0.2,
        "PLDDT_WEIGHT": 2.0,
    },
}

# Apply preset
# The preset is applied first and OVERRIDES second, so explicit overrides
# always win. globals() is used rather than locals(): writing through the
# mapping returned by locals() is not a supported way to rebind names, while
# updating globals() is the documented way to set module-level variables.
if PRESET is not None:
    if PRESET in PRESETS:
        globals().update(PRESETS[PRESET])
    else:
        # A misspelled preset name used to be ignored silently, leaving the
        # manual values in effect without any hint; keep going but say so.
        print(
            f"WARNING: unknown PRESET {PRESET!r}; using the manual config values. "
            f"Available presets: {sorted(PRESETS)}"
        )
if OVERRIDES:
    globals().update(OVERRIDES)

## 3) Setup (clone, deps, data_dir)

In [None]:
import sys
import subprocess
import shutil
import os
import numpy as np

def pip_install(*packages):
    """Quietly pip-install ``packages`` using the interpreter running this code.

    The full command line is printed before it is executed. A non-zero pip
    exit status surfaces as ``subprocess.CalledProcessError``.
    """
    argv = [sys.executable, "-m", "pip", "install", "--quiet"]
    argv.extend(packages)
    print("Running:", " ".join(argv))
    subprocess.check_call(argv)

# Core deps
# Installed through pip_install at cell run time; a failing pip call raises
# subprocess.CalledProcessError and stops the cell.
pip_install("git+https://github.com/sokrypton/ColabDesign.git")
pip_install("trimesh", "py3Dmol", "matplotlib")

# Resolve ROOT (clone if missing)
# ROOT is the directory expected to contain the project's `src/` package; it is
# put on sys.path below so the `from src import ...` statements can resolve.
REPO_URL = "https://github.com/ib565/colabdesign-stl"
try:
    # Script execution: ROOT is the parent of the directory holding this file.
    ROOT = Path(__file__).resolve().parents[1]
except NameError:
    # __file__ is not defined (e.g. in a notebook kernel): use the working directory.
    ROOT = Path.cwd()
if not (ROOT / "src").exists():
    # No local checkout at ROOT: use a clone at a fixed location instead.
    clone_dir = Path("/content/colabdesign-stl")
    if FORCE_RECLONE and clone_dir.exists():
        # Deleting first guarantees the clone below fetches the current repository state.
        print("Forcing reclone. Deleting existing repo dir")
        shutil.rmtree(clone_dir)
    if not clone_dir.exists():
        print(f"src/ not found; cloning {REPO_URL} into {clone_dir} ...")
        subprocess.check_call(["git", "clone", REPO_URL, str(clone_dir)])
    ROOT = clone_dir
# Prepend (not append) so this checkout takes precedence over any other `src` on the path.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src import (
    STLProteinDesigner,
    normalize_points,
    plot_point_cloud,
    stl_to_centerline_points,
)
from examples.stl.generators.resolve_stl import resolve_or_generate_stl

## 4) Extract STL centerline

In [None]:
# Extract centerline from STL
# resolve_or_generate_stl returns the path that is actually loaded below; it is
# printed so the source of the target shape is visible in the cell output.
stl_resolved = resolve_or_generate_stl(STL_PATH)
print(f"Using STL centerline from: {stl_resolved}")

# One centerline point is requested per residue (num_points=PROTEIN_LENGTH).
# Seed handling: None or any negative SAMPLE_SEED is forwarded as seed=None.
# The explicit None check is required because `None >= 0` raises TypeError,
# and the design cell already accepts SAMPLE_SEED = None.
pts = stl_to_centerline_points(
    str(stl_resolved),
    num_points=PROTEIN_LENGTH,
    surface_samples=CENTERLINE_SURFACE_SAMPLES,
    bins=CENTERLINE_BINS,
    smooth_window=CENTERLINE_SMOOTH_WINDOW,
    seed=SAMPLE_SEED if SAMPLE_SEED is not None and SAMPLE_SEED >= 0 else None,
    target_arclength=TARGET_ARCLENGTH,
)

if NORMALIZE_TARGET:
    # Center the points and rescale them using TARGET_EXTENT.
    target_points = normalize_points(pts, target_extent=TARGET_EXTENT, center=True)
else:
    # Center only: subtract the mean point and cast to float32, no rescaling.
    target_points = (pts - pts.mean(axis=0)).astype(np.float32)

# Diagnostics: report arclength / avg step
def _polyline_arclength(poly):
    """Return the total length of the polyline through the points in ``poly``.

    ``poly`` is a sequence of points (one per row). The result is the sum of
    the Euclidean lengths of the consecutive segments, as a Python float;
    fewer than two points yields 0.0.
    """
    if len(poly) >= 2:
        steps = np.diff(poly, axis=0)
        step_lengths = np.linalg.norm(steps, axis=1)
        return float(step_lengths.sum())
    return 0.0

# Length of the final (centered / scaled) target polyline, and the mean
# distance between consecutive target points. max(..., 1) avoids a division
# by zero when there is only a single target point.
arc = _polyline_arclength(target_points)
avg_step = arc / max(len(target_points) - 1, 1)
print(f"Target arclength (post-scaling): {arc:.2f} Å")
print(f"Average step: {avg_step:.2f} Å")

# Quick visualization
# Shown inline only (save_path=None); connected=True is passed so the points
# are presumably drawn as a joined path — confirm against plot_point_cloud.
plot_point_cloud(
    target_points,
    title=f"Target centerline ({Path(STL_PATH).stem})",
    show=True,
    save_path=None,
    connected=True,
)

## 5) Download AlphaFold params

In [None]:
def ensure_af_params(data_dir: Path, auto_download: bool = False):
    """Return ``data_dir / "params"``, downloading the AlphaFold params if allowed.

    ``data_dir`` is created if needed. When ``data_dir / "params"`` already
    exists and is non-empty it is returned as-is. Otherwise the params tarball
    is downloaded into ``data_dir`` (unless a complete one is already there)
    and extracted into ``data_dir / "params"``.

    Args:
        data_dir: Directory that holds (or will hold) the ``params`` folder.
        auto_download: If False, never download; raise instead when the
            params are missing.

    Returns:
        Path of the ``params`` directory.

    Raises:
        FileNotFoundError: The params are missing and ``auto_download`` is False.
        subprocess.CalledProcessError: ``curl`` or ``tar`` exited with an error.
    """
    import shutil  # local import keeps the function self-contained

    AF_TAR_URL = "https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar"
    data_dir.mkdir(parents=True, exist_ok=True)
    marker = data_dir / "params"
    if marker.exists() and any(marker.iterdir()):
        print(f"AlphaFold params found at: {marker}")
        return marker
    if not auto_download:
        raise FileNotFoundError(f"AlphaFold params not found at {marker}.")
    tar_path = data_dir / "alphafold_params_2022-12-06.tar"
    if not tar_path.exists():
        # Download to a side file and rename only on success, so an interrupted
        # transfer can never be mistaken for a complete tarball on the next run.
        partial_path = tar_path.with_name(tar_path.name + ".part")
        # --fail makes curl exit non-zero on HTTP errors instead of saving the
        # server's error page as the "tarball".
        cmd = ["curl", "--fail", "-L", "-o", str(partial_path), AF_TAR_URL]
        print("Downloading AF params (several minutes)...")
        try:
            subprocess.check_call(cmd)
        except BaseException:
            # Also covers KeyboardInterrupt (cell stopped mid-download).
            if partial_path.exists():
                partial_path.unlink()
            raise
        partial_path.replace(tar_path)
    print("Extracting params...")
    marker.mkdir(exist_ok=True)
    try:
        subprocess.check_call(["tar", "-xf", str(tar_path), "-C", str(marker)])
    except BaseException:
        # marker was missing or empty when we got here, so removing it loses
        # nothing; leaving a half-extracted folder would make the "non-empty"
        # check above report the params as present on the next call.
        shutil.rmtree(marker, ignore_errors=True)
        raise
    print("Params ready at:", marker)
    return marker

# Pick the AlphaFold data directory. Precedence: the DATA_DIR config value,
# then the AF_DATA_DIR environment variable, then a "ColabDesign" directory
# next to ROOT if one exists. None means no candidate was found.
resolved_data_dir = None
if DATA_DIR:
    resolved_data_dir = Path(DATA_DIR)
elif os.environ.get("AF_DATA_DIR"):
    resolved_data_dir = Path(os.environ["AF_DATA_DIR"])
else:
    candidate = ROOT.parent / "ColabDesign"
    resolved_data_dir = candidate if candidate.exists() else None
if resolved_data_dir is None:
    print("No AlphaFold params directory found. Set DATA_DIR or AF_DATA_DIR, or enable AUTO_DOWNLOAD_PARAMS.")
else:
    print("Using data_dir:", resolved_data_dir)
    # When AUTO_DOWNLOAD_PARAMS is False nothing here checks that the params
    # actually exist in the chosen directory.
    if AUTO_DOWNLOAD_PARAMS:
        # NOTE(review): ensure_af_params returns <data_dir>/params, so after this
        # line resolved_data_dir points at the params sub-directory, whereas with
        # AUTO_DOWNLOAD_PARAMS = False it stays the parent directory. Confirm
        # which of the two STLProteinDesigner's data_dir argument expects.
        resolved_data_dir = ensure_af_params(resolved_data_dir, auto_download=True)

## 6) Run design

In [None]:
# Everything this cell produces is written under OUT_DIR.
out_dir = Path(OUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)

print("Initializing designer...")
# NOTE(review): the centerline cell loads the mesh from
# resolve_or_generate_stl(STL_PATH), while the designer receives the raw
# STL_PATH — confirm the designer resolves the path the same way.
designer = STLProteinDesigner(
    stl_path=STL_PATH,
    protein_length=PROTEIN_LENGTH,
    target_extent=TARGET_EXTENT,
    # Same rule as the centerline cell: None or any negative seed -> None.
    sample_seed=None if SAMPLE_SEED is None or SAMPLE_SEED < 0 else SAMPLE_SEED,
    path_weight=PATH_WEIGHT,
    con_weight=CON_WEIGHT,
    plddt_weight=PLDDT_WEIGHT,
    pae_weight=PAE_WEIGHT,
    data_dir=str(resolved_data_dir) if resolved_data_dir else None,
    # Scales with SOFT_ITERS and never drops below 1; presumably a logging
    # interval in iterations — confirm against STLProteinDesigner.
    verbose=max(1, SOFT_ITERS // 20),
    stl_target_mode="centerline",
    target_arclength=TARGET_ARCLENGTH,
    centerline_surface_samples=CENTERLINE_SURFACE_SAMPLES,
    centerline_bins=CENTERLINE_BINS,
    centerline_smooth_window=CENTERLINE_SMOOTH_WINDOW,
    normalize_target_points=NORMALIZE_TARGET,
)

print("Running design... (first JIT can take 30–90s on Colab GPU)")
seq = designer.design(
    soft_iters=SOFT_ITERS,
    temp_iters=TEMP_ITERS,
    hard_iters=HARD_ITERS,
    run_seed=RUN_SEED,
    save_best=True,
)

# Save outputs
(out_dir / "sequence.txt").write_text(seq)
pdb_path = out_dir / "structure.pdb"
designer.get_structure(save_path=str(pdb_path), get_best=True)
metrics = designer.get_metrics()

# Save the run configuration and metrics as JSON.
# The config records every tunable that was handed to the designer, so a run
# can be reproduced from config.json alone. The last four keys were missing
# before even though their values are passed to the designer above.
config = {
    "stl_path": STL_PATH,
    "protein_length": PROTEIN_LENGTH,
    "target_extent": TARGET_EXTENT,
    "centerline_surface_samples": CENTERLINE_SURFACE_SAMPLES,
    "sample_seed": SAMPLE_SEED,
    "run_seed": RUN_SEED,
    "soft_iters": SOFT_ITERS,
    "temp_iters": TEMP_ITERS,
    "hard_iters": HARD_ITERS,
    "path_weight": PATH_WEIGHT,
    "con_weight": CON_WEIGHT,
    "plddt_weight": PLDDT_WEIGHT,
    "pae_weight": PAE_WEIGHT,
    "target_arclength": TARGET_ARCLENGTH,
    "centerline_bins": CENTERLINE_BINS,
    "centerline_smooth_window": CENTERLINE_SMOOTH_WINDOW,
    "normalize_target": NORMALIZE_TARGET,
}
(out_dir / "config.json").write_text(json.dumps(config, indent=2))
(out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))

# Print results
print("\n" + "="*60)
print("Design Complete")
print("="*60)
print(f"Sequence length: {len(seq)}")
print(f"Path loss:       {metrics['path']:.3f} (squared Å)")
# 'path_aligned' is read with .get, so a missing key prints nan instead of failing.
print(f"Path aligned:    {metrics.get('path_aligned', float('nan')):.3f} (squared Å)")
print(f"pLDDT:           {metrics['plddt']:.3f}")
print(f"PAE:             {metrics['pae']:.3f}")
print(f"\nOutputs saved to: {out_dir.resolve()}")
print("  - sequence.txt")
print("  - structure.pdb")
print("  - config.json")
print("  - metrics.json")

## 7) Overlay plot

In [None]:
# Render the overlay exactly once. Previously the plotting call sat inside a
# broad try/except whose handler repeated the same call, so a plotting failure
# was retried identically and then blamed on IPython being unavailable.
plot_path = out_dir / "overlay.png"
designer.plot_overlay(save_path=str(plot_path), show=False)
print("Overlay saved to", plot_path)

# Inline display is optional: the PNG is already on disk at this point.
try:
    from IPython.display import Image, display  # type: ignore
except ImportError:
    print("IPython display not available in this environment.")
else:
    try:
        display(Image(filename=str(plot_path)))
    except Exception as e:
        # Best effort, as before: a display problem must not fail the cell.
        print("Could not display overlay inline:", e)

## 8) 3D visualization (py3Dmol)

In [None]:
# 8a) Protein structure (cartoon view)
# The import is checked on its own so that "py3Dmol not available" is only
# reported for a genuinely missing package; any other failure (for example an
# unreadable PDB file) is reported as a rendering problem instead.
try:
    import py3Dmol  # type: ignore
except ImportError as e:
    print("py3Dmol not available:", e)
else:
    try:
        pdb_str = Path(pdb_path).read_text()

        view = py3Dmol.view(width=720, height=720)
        view.addModel(pdb_str, "pdb")
        # Semi-transparent cartoon with a sphere on every CA atom.
        view.setStyle({"cartoon": {"color": "skyblue", "opacity": 0.55}})
        view.addStyle({"atom": "CA"}, {"sphere": {"color": "deepskyblue", "radius": 0.7}})

        view.zoomTo()
        view.show()
    except Exception as e:
        # Visualization is best effort; report and continue.
        print("3D structure view failed:", e)

In [None]:
# 8b) Aligned target vs predicted Cα (overlay)
# The import is checked on its own so that "py3Dmol not available" is only
# reported for a genuinely missing package; any other failure is reported as
# a rendering problem instead.
try:
    import py3Dmol  # type: ignore
except ImportError as e:
    print("py3Dmol not available:", e)
else:
    try:
        pdb_str = Path(pdb_path).read_text()
        tgt = np.asarray(designer.target_points, dtype=float)
        tgt = tgt - tgt.mean(axis=0)  # Center target

        ca_aligned = designer.get_ca_coords(get_best=True, aligned=True)

        view = py3Dmol.view(width=720, height=720)
        # NOTE(review): the cartoon is drawn from the saved PDB as-is; if the
        # aligned=True coordinates live in a different frame than the PDB, the
        # cartoon will not coincide with the spheres below — confirm.
        view.addModel(pdb_str, "pdb")
        view.setStyle({"cartoon": {"color": "skyblue", "opacity": 0.10}})  # Dim cartoon for clarity

        # Aligned Cα as spheres (second model added, hence {"model": 1}).
        # Points are written as carbon atoms in XYZ format: count line,
        # comment line, then one "C x y z" line per point.
        xyz_ca = "\n".join(f"C {x:.3f} {y:.3f} {z:.3f}" for x, y, z in ca_aligned)
        view.addModel(f"{len(ca_aligned)}\nca\n{xyz_ca}\n", "xyz")
        view.setStyle({"model": 1}, {"sphere": {"color": "deepskyblue", "radius": 0.8}})

        # Target centerline as spheres (third model added, hence {"model": 2}).
        xyz_tgt = "\n".join(f"C {x:.3f} {y:.3f} {z:.3f}" for x, y, z in tgt)
        view.addModel(f"{len(tgt)}\npoints\n{xyz_tgt}\n", "xyz")
        view.setStyle({"model": 2}, {"sphere": {"color": "red", "radius": 0.8}})

        view.zoomTo()
        view.show()
    except Exception as e:
        # Visualization is best effort; report and continue.
        print("3D overlay view failed:", e)


## 9) Inspect outputs

In [None]:
# Quick sanity check: show the designed sequence (at most its first 80
# characters) and the absolute path of the directory holding the outputs.
print("First 80 aa:", seq[:80])
print("Outputs in:", out_dir.resolve())

## 10) Download results (Colab only)


In [None]:
# Download all outputs from Colab
try:
    from google.colab import files
    import zipfile

    # The archive is written next to out_dir (not inside it), and entries are
    # stored relative to that parent so they keep the "<out_dir name>/" prefix.
    archive_root = out_dir.parent
    zip_path = archive_root / f"{out_dir.name}_results.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for entry in (p for p in out_dir.rglob("*") if p.is_file()):
            zf.write(entry, arcname=entry.relative_to(archive_root))

    print(f"Created archive: {zip_path}")
    print("Downloading...")
    files.download(str(zip_path))
    print("Download complete!")
except ImportError:
    # google.colab could not be imported: nothing to download, just point at the files.
    print("Not running in Colab. Files are saved locally at:", out_dir.resolve())
except Exception as e:
    # Archiving or the browser download failed; the outputs themselves are still on disk.
    print(f"Download failed: {e}")
    print("Files are saved locally at:", out_dir.resolve())
