## 🧬 Generating novel sequences and structures with ProFam-1

This notebook presents an end-to-end workflow for **protein sequence generation and structural validation** using **ProFam-1**, a family-aware generative model trained on protein domain and family information. The goal of the workshop is to guide you from an input protein sequence, through controlled generative modelling, to quantitative and visual structural comparison against a reference structure.

The notebook is designed to be fully reproducible in a Google Colab environment with GPU support and does not require Flash Attention. All steps are executed interactively, allowing participants to explore how generative parameters affect sequence quality and downstream structural fidelity.

We begin by setting up the computational environment and installing the required dependencies. The ProFam repository is cloned, and a pretrained ProFam-1 checkpoint is located and prepared for inference. Care is taken to ensure compatibility with standard Colab GPU instances, avoiding optional components that are not universally available.

Next, an input protein sequence is provided. By default, the notebook uses a known PETase sequence from UniProt as a single-sequence prompt, but participants may alternatively paste their own sequence directly into the notebook. Any extraneous whitespace or line breaks are removed automatically, and the sequence is written to FASTA format for downstream use.

Using ProFam-1, one or more novel protein sequences are then generated, conditioned on the input sequence. The generation step exposes a small number of adjustable parameters, such as the number of sequences to generate and the sampling temperature, while keeping most model settings fixed to sensible defaults. Additional quality filters are applied to enforce reasonable sequence lengths and minimum similarity to the prompt, helping to avoid pathological or truncated outputs.
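
The two length filters can be pictured as a window around the prompt length. The sketch below assumes that `--minimum_sequence_length_proportion` and `--max_sequence_length_multiplier` are both relative to the prompt length, as their names and the defaults used later (0.8 and 1.2) suggest. `passes_length_filter` is a hypothetical helper, not part of ProFam, so check `scripts/generate_sequences.py -h` for the exact behaviour.

```python
def passes_length_filter(gen_len: int, prompt_len: int,
                         min_prop: float = 0.8, max_mult: float = 1.2) -> bool:
    """Hypothetical helper: True if gen_len lies in
    [min_prop * prompt_len, max_mult * prompt_len].

    Assumes both thresholds are relative to the prompt length; ProFam's own
    filtering may differ in detail."""
    if prompt_len <= 0:
        raise ValueError("prompt_len must be positive")
    return min_prop * prompt_len <= gen_len <= max_mult * prompt_len


prompt_len = 290  # length of the default PETase prompt
for n in (200, 250, 290, 340, 360):
    print(n, passes_length_filter(n, prompt_len))
```

For the 290-residue default prompt, this window runs from about 232 to 348 residues.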

Structure prediction is performed outside the notebook using a tool such as AlphaFold or ESMFold. Once predicted, the structures of both the generated sequence and the original input sequence are uploaded back into the notebook in either PDB or mmCIF format. These structures form the basis for structural validation.

Structural comparison is carried out using **TM-align**, which provides quantitative measures of similarity including TM-score, RMSD, and aligned length. The TM-score is particularly informative: values above 0.5 generally indicate a shared fold, while values above 0.8 suggest strong structural agreement. These metrics allow an objective assessment of whether the generated sequence preserves the overall fold of the reference protein.
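
Each aligned residue pair contributes 1 / (1 + (d_i / d0)^2) to the TM-score, where d_i is the pair's distance after superposition and d0 = 1.24 * (L - 15)^(1/3) - 1.8 is a length-dependent scale; the sum is then divided by a chain length L. Because the two structures can differ in length, TM-align reports one score normalised by each chain. The sketch below only evaluates the score for a fixed superposition and a fixed set of aligned pairs (TM-align itself also optimises both). The 0.5 Å floor on d0 for short chains is an assumption about how TM-align clamps d0.

```python
def tm_d0(norm_length: int) -> float:
    """Distance scale d0 (Angstrom) for the normalising chain length.

    d0 = 1.24 * (L - 15)**(1/3) - 1.8, with an assumed 0.5 Angstrom floor
    for short chains."""
    if norm_length <= 21:
        return 0.5
    return max(0.5, 1.24 * (norm_length - 15) ** (1.0 / 3.0) - 1.8)


def tm_score_fixed(distances, norm_length: int) -> float:
    """TM-score for ONE fixed superposition and alignment.

    distances   : distances (Angstrom) between aligned residue pairs
    norm_length : chain length used for normalisation (chain A or chain B)"""
    if norm_length <= 0:
        raise ValueError("norm_length must be positive")
    d0 = tm_d0(norm_length)
    return sum(1.0 / (1.0 + (d / d0) ** 2) for d in distances) / norm_length


# 90 aligned pairs, each 1 Angstrom apart, normalised by chains of different length
dists = [1.0] * 90
print(round(tm_score_fixed(dists, 100), 3))  # normalised by a 100-residue chain
print(round(tm_score_fixed(dists, 120), 3))  # normalised by a 120-residue chain
```

Normalising the same well-superposed pairs by a longer chain gives a lower score, which is why the alignment cell prints both values and their maximum.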

Finally, the two structures are superposed and visualised together using **NGLView** in cartoon (ribbon/strand) representation. One structure is rigidly transformed onto the other using the rotation matrix and translation vector produced by TM-align (in the alignment cell, structure B, e.g. the reference, is moved onto structure A, e.g. the generated model), allowing direct visual inspection of conserved cores, flexible regions, and any structural deviations. This final step helps connect quantitative scores to an intuitive, three-dimensional understanding.
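
TM-align's `-m` file gives a translation `t` and a rotation `U` such that X = t + U x maps Chain_1 onto Chain_2. Because `U` is orthonormal, the reverse move is x = U^T X - U^T t, which is the inverse the alignment cell computes by hand before moving B onto A. A minimal NumPy sketch of the same algebra, using a made-up rotation and made-up coordinates for illustration:

```python
import numpy as np


def apply_rigid(coords: np.ndarray, U: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Apply X = t + U @ x to every row of an (N, 3) coordinate array."""
    return coords @ U.T + t


def invert_rigid(U: np.ndarray, t: np.ndarray):
    """Inverse of X = t + U x for an orthonormal U: x = U^T X - U^T t."""
    U_inv = U.T
    return U_inv, -U_inv @ t


# Illustrative transform: 90-degree rotation about z plus a translation
theta = np.pi / 2
U = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
t = np.array([1.0, -2.0, 0.5])

coords_a = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.5, 0.3]])
coords_on_b = apply_rigid(coords_a, U, t)             # Chain_1 -> Chain_2 frame
U_inv, t_inv = invert_rigid(U, t)
coords_back = apply_rigid(coords_on_b, U_inv, t_inv)  # Chain_2 frame -> Chain_1 frame
print(np.allclose(coords_back, coords_a))
```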

By the end of this notebook, you will have walked through a complete generative protein design loop, from sequence generation to structure-based validation, and gained practical insight into how modern protein language models can be evaluated using structural biology tools.

In [None]:
#@title 1) GPU sanity check

import torch, subprocess

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("CUDA version (torch):", torch.version.cuda)

# Optional: show nvidia-smi if available
try:
    out = subprocess.check_output(["nvidia-smi"], text=True)
    print("\n" + out[:1500])
except Exception:
    print("nvidia-smi not available")

In [None]:
#@title 2) Clone ProFam

import os, pathlib

BASE_DIR = pathlib.Path("/content").resolve()
WORK_DIR = BASE_DIR / "profam_workshop"
WORK_DIR.mkdir(exist_ok=True)

os.chdir(WORK_DIR)
print("Working directory:", WORK_DIR)

if not (WORK_DIR / "profam").exists():
    !git clone https://github.com/alex-hh/profam.git
else:
    print("ProFam repository already present")

%cd {WORK_DIR / "profam"}
!git rev-parse --short HEAD

In [None]:
#@title 3) Install ProFam (GPU runtime, no flash-attn, Colab-friendly)

import re, pathlib, os

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)
print("Repository:", repo)

# Ensure up-to-date build tools
!pip -q install -U pip setuptools wheel

req_in = repo / "requirements.txt"
if not req_in.exists():
    raise FileNotFoundError(f"Missing {req_in}")

# Packages we deliberately do not install from requirements.txt in Colab.
# Matched by exact (normalised) package name, so that e.g. torchmetrics or
# jaxtyping are not skipped by accident through prefix matching.
SKIP_NAMES = {
    "flash-attn",
    "numpy", "pandas", "requests",
    "torch", "torchvision", "torchaudio",
    "jax", "jaxlib",
}

filtered = []
removed = []

for ln in req_in.read_text().splitlines():
    s = ln.strip()
    if not s or s.startswith("#"):
        filtered.append(ln)
        continue

    # Package name = text before any version specifier, marker, extra or space
    head = re.split(r"[<>=!~;\s\[]", s, maxsplit=1)[0].lower().replace("_", "-")

    if head in SKIP_NAMES:
        removed.append(ln)
        continue

    filtered.append(ln)

req_out = repo / "requirements.colab_noflash.txt"
req_out.write_text("\n".join(filtered) + "\n")

print(f"Wrote {req_out.name}")

if removed:
    print("Skipped packages:")
    for r in removed:
        print(" ", r)

# Install filtered dependencies
!pip -q install -r requirements.colab_noflash.txt

# Patch common Colab/runtime dependencies
!pip -q install \
    "pandas==2.2.2" \
    "requests==2.32.4" \
    "packaging>=24.2" \
    "typing-extensions>=4.12.0" \
    "xxhash>=3.5.0" \
    "jedi>=0.16"

# Install ProFam in editable mode
!pip -q install -e .

print("ProFam installation complete.")

In [None]:
#@title 4) Download ProFam model checkpoint from Hugging Face

import os, pathlib

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)

os.environ["HF_HOME"] = "/content/hf_cache"
pathlib.Path(os.environ["HF_HOME"]).mkdir(exist_ok=True, parents=True)

!python scripts/hf_download_checkpoint.py

print("Checkpoint download complete.")

In [None]:
#@title 5) Fix Colab torch/torchvision mismatch (remove torchvision)

# Remove torchvision to prevent torchmetrics/lightning import issues
!pip -q uninstall -y torchvision || true

print("torchvision removed.")

In [None]:
#@title 6) Show ProFam CLI help (score + generate)

import pathlib, os

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)

print("=== score_sequences.py ===")
!python scripts/score_sequences.py -h | head -n 120

print("\n=== generate_sequences.py ===")
!python scripts/generate_sequences.py -h | head -n 120

In [None]:
#@title 7) Create prompt FASTA (default PETase or user-provided sequence)

import pathlib, textwrap, re, os

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)

# -----------------------------------------------------------------------------
# DISCLAIMER
# -----------------------------------------------------------------------------
# ProFam can also take a multi-sequence FASTA or an aligned MSA (e.g. A3M).
#
# - A single sequence is sufficient and recommended for sequence generation.
# - Multiple FASTA entries are treated as independent prompts.
# - Providing an MSA can be used to condition generation on a protein family,
#   but alignment is NOT required for generation.
# - MSAs are primarily important for variant scoring, not for generation.
#
# For this workshop, we use a single-sequence prompt by default.
# -----------------------------------------------------------------------------

# =========================
# USER CHOICES
# =========================
USE_DEFAULT = True  #@param {type:"boolean"}

# If USE_DEFAULT = False, paste either:
#  - FASTA (one or more records), OR
#  - a raw amino-acid sequence
USER_SEQUENCE_TEXT = "MNKFLALALAVSLSASAAPVPSQAFGDLGKDTVAV GDSGVPVSPQTDPSATVGRRLTAAALDALDAGADVV VPGSAGTFSVTLGATNATVVGVDLQLAGADATVTLA AGATGNSGGYVVWGGHGTQATQVVAGLPQLAVAGAD VVIVDNNRAGADVVAVSGGTTSTTTW"  #@param {type:"string"}

USER_SEQ_NAME = "TEST"  #@param {type:"string"}

out_fa = repo / "prompts.fasta"

# =========================
# DEFAULT: PETase (Ideonella sakaiensis)
# UniProt: QRG82925
# =========================
DEFAULT_FASTA = textwrap.dedent("""\
>PETase_QRG82925
MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPNPTAASLEASAGPFTVRSFTVSRP
SGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQP
SSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAA
PQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCA
NSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCS
""")

AA_RE = re.compile(r"^[ACDEFGHIKLMNPQRSTVWYBXZJUO\-\.\s]+$", re.IGNORECASE)

def normalize_raw_sequence(s: str) -> str:
    """
    Accept raw amino-acid sequences that may contain spaces or newlines.
    Convert to uppercase and strip everything except the letters A-Z.
    """
    s = s.strip().upper()
    s = re.sub(r"\s+", "", s)
    s = re.sub(r"[^A-Z]", "", s)
    return s

def to_fasta_from_raw(seq: str, name: str) -> str:
    seq = normalize_raw_sequence(seq)
    if not seq:
        raise ValueError("Sequence is empty after cleaning.")
    wrapped = "\n".join(seq[i:i+60] for i in range(0, len(seq), 60))
    return f">{name}\n{wrapped}\n"

def is_fasta(txt: str) -> bool:
    return txt.lstrip().startswith(">")

# =========================
# WRITE FASTA
# =========================
if USE_DEFAULT:
    out_fa.write_text(DEFAULT_FASTA)
    print("Using default PETase sequence (UniProt QRG82925)")
else:
    txt = (USER_SEQUENCE_TEXT or "").strip()
    if not txt:
        raise ValueError("USE_DEFAULT is False but USER_SEQUENCE_TEXT is empty.")

    if is_fasta(txt):
        out_fa.write_text(txt.strip() + "\n")
        print("Using user-provided FASTA (single or multiple sequences)")
    else:
        if not AA_RE.match(txt.replace("\n", "")):
            raise ValueError("Input is neither FASTA nor a valid amino-acid sequence.")
        out_fa.write_text(to_fasta_from_raw(txt, USER_SEQ_NAME))
        print("Using user-provided raw sequence")

print("\nNote:")
print("This FASTA can contain a single sequence, multiple sequences, or an MSA.")
print("For generation, alignment is optional; for scoring, MSAs are recommended.")

print("\n--- prompts.fasta ---")
print(out_fa.read_text())

In [None]:
#@title 8) Locate ProFam checkpoint_dir

import pathlib, os

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)

candidates = []

# Look inside the repository
for p in pathlib.Path(".").rglob(".hydra"):
    candidates.append(p.parent)

# Look in Hugging Face cache
hf_home = pathlib.Path(os.environ.get("HF_HOME", "/content/hf_cache"))
if hf_home.exists():
    for p in hf_home.rglob(".hydra"):
        candidates.append(p.parent)

# De-duplicate
uniq = []
seen = set()
for c in candidates:
    c = c.resolve()
    if c not in seen:
        seen.add(c)
        uniq.append(c)

if not uniq:
    raise RuntimeError(
        "Could not find a checkpoint directory containing '.hydra'. "
        "Run the checkpoint download cell first."
    )

# Select the most recent checkpoint
checkpoint_dir = sorted(uniq, key=lambda x: x.stat().st_mtime, reverse=True)[0]
print("Using checkpoint_dir:", checkpoint_dir)

## Generate sequences using ProFam-1

In this section, we use **ProFam-1**, a protein family-aware generative model, to generate new protein sequences from a single input sequence.

The model conditions on the input sequence and samples novel sequences that are consistent with the learned protein family constraints.

**Notes for the workshop:**
- The generation runs on **GPU** if available.
- **Flash Attention is disabled** for compatibility with Google Colab.
- Output sequences will be written as FASTA files in the `generated/` directory.



> **Runtime & Resource Notice**
>
> - Sequence generation may take **1-3 minutes** on a Google Colab GPU.
> - The first run may be slightly slower due to **model weight loading**.
> - If you see CUDA or cuDNN warnings, these are expected and can be ignored.
>
> If the notebook appears idle, please wait; the model is still running.

## Generation Parameters

The sequence generation step uses probabilistic sampling. The key parameters are:

| Parameter | Meaning |
|---------|--------|
| `temperature` | Controls randomness. Lower = more conservative, higher = more diverse |
| `top_p` | Nucleus sampling threshold (keeps top probability mass) |
| `num_samples` | Number of sequences generated per prompt |
| `max_tokens` | Maximum total tokens generated |
| `sampler` | `ensemble` uses multiple internal prompts for robustness |
| `device` | `cuda` uses GPU if available |
| `dtype` | `float16` reduces memory usage on GPU |
| `attn_implementation` | `sdpa` is used instead of Flash Attention for Colab compatibility |
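
`top_p` refers to nucleus sampling. The sketch below shows the standard nucleus filtering step on a toy distribution: keep the smallest set of most probable tokens whose cumulative probability reaches `top_p`, zero out the rest, and renormalise. It illustrates the general technique only and is not taken from ProFam's sampler.

```python
import numpy as np


def nucleus_filter(probs, top_p: float) -> np.ndarray:
    """Keep the smallest set of highest-probability tokens whose cumulative
    mass reaches top_p, zero out the rest, and renormalise."""
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]               # token indices, most probable first
    sorted_p = probs[order]
    mass_before = np.cumsum(sorted_p) - sorted_p  # mass of strictly higher-ranked tokens
    keep_sorted = mass_before < top_p             # the top token is always kept
    filtered = np.zeros_like(probs)
    filtered[order[keep_sorted]] = sorted_p[keep_sorted]
    return filtered / filtered.sum()


p = [0.5, 0.3, 0.15, 0.05]      # toy distribution over 4 residues
print(nucleus_filter(p, 0.9))   # keeps the three most probable tokens
print(nucleus_filter(p, 0.7))   # keeps only the two most probable tokens
```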

**Tip:**  
For more diversity, increase `temperature` (e.g. `1.2`).  
For safer, more conservative sequences, decrease it (e.g. `0.8`).
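
Temperature is conventionally applied by dividing the logits by `T` before the softmax, so `T < 1` sharpens the next-residue distribution and `T > 1` flattens it. A small sketch on toy logits, assuming ProFam follows this standard convention:

```python
import numpy as np


def softmax_with_temperature(logits, temperature: float) -> np.ndarray:
    """Convert logits to probabilities after dividing by the sampling temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def entropy_bits(p: np.ndarray) -> float:
    """Shannon entropy of a probability vector, in bits."""
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())


logits = [2.0, 1.0, 0.5, -1.0]  # toy next-residue logits
for T in (0.8, 1.0, 1.2):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: top prob={probs.max():.3f}, entropy={entropy_bits(probs):.3f} bits")
```

Lower temperatures concentrate probability on the top residue (lower entropy), which is why `0.8` tends to give more conservative sequences than `1.2`.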

In [None]:
#@title 9) Install MMseqs2 (required for --minimum_sequence_identity filtering)

import os, pathlib

WORK_DIR = pathlib.Path("/content/profam_workshop").resolve()
mmseqs_dir = WORK_DIR / "mmseqs"
mmseqs_bin = mmseqs_dir / "bin" / "mmseqs"

# Clean and install
!rm -rf "{mmseqs_dir}"
!mkdir -p "{mmseqs_dir}"
!wget -q https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz -O /tmp/mmseqs.tar.gz
!tar -xzf /tmp/mmseqs.tar.gz -C "{mmseqs_dir}" --strip-components=1
!rm /tmp/mmseqs.tar.gz

# Add to PATH
os.environ["PATH"] = f"{mmseqs_dir}/bin:" + os.environ["PATH"]

print("mmseqs binary:", mmseqs_bin)
!{mmseqs_bin} version

In [None]:
#@title 10) Device & dtype selection (GPU with safe fallback)

import torch

# Preferred defaults (workshop intent)
PREFERRED_DEVICE = "cuda"
PREFERRED_DTYPE  = "float16"

if torch.cuda.is_available():
    device = PREFERRED_DEVICE
    dtype  = PREFERRED_DTYPE
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Using GPU: {gpu_name}")
else:
    device = "cpu"
    dtype  = "float32"
    print(
        "No NVIDIA GPU detected.\n"
        "Falling back to CPU (this will be slower, but functional)."
    )

print(f"device = {device}")
print(f"dtype  = {dtype}")

In [None]:
#@title 11) Generate sequences with ProFam-1

import pathlib, os, shutil
import torch

# ---------------- User parameters ----------------
num_sequences = 10  #@param {type:"integer", min:1, max:50, step:1}
temperature   = 0.8  #@param {type:"number", min:0.5, max:1.5, step:0.1}

max_tokens = 2048  #@param {type:"integer", min:256, max:8192, step:256}
max_generated_length = 512  #@param {type:"integer", min:64, max:4096, step:64}

min_seq_len_prop = 0.8   #@param {type:"number", min:0.5, max:1.0, step:0.05}
max_len_mult     = 1.2   #@param {type:"number", min:1.0, max:2.0, step:0.1}
min_seq_identity = 0.8   #@param {type:"number", min:0.0, max:0.9, step:0.05}

# ---------------- Fixed workshop defaults ----------------
top_p      = 0.95
sampler    = "ensemble"
attn_impl  = "sdpa"   # no flash-attn
seed       = 1

# ---------------- Auto device / dtype ----------------
if torch.cuda.is_available():
    device = "cuda"
    dtype  = "float16"
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = "cpu"
    dtype  = "float32"
    print("No GPU detected; using CPU (slower).")

# ---------------- Paths ----------------
repo = pathlib.Path("/content/profam_workshop/profam").resolve()
os.chdir(repo)

prompt_fa = repo / "prompts.fasta"
save_dir  = repo / "generated"

if not prompt_fa.exists():
    raise FileNotFoundError("prompts.fasta not found. Run the prompt FASTA cell first.")

if "checkpoint_dir" not in globals():
    raise RuntimeError("checkpoint_dir not set. Run the checkpoint discovery cell first.")

# Overwrite previous results
if save_dir.exists():
    shutil.rmtree(save_dir)
save_dir.mkdir(exist_ok=True)

# ---------------- Run generation ----------------
cmd = (
    f"python scripts/generate_sequences.py "
    f"--checkpoint_dir '{checkpoint_dir}' "
    f"--file_path '{prompt_fa}' "
    f"--save_dir '{save_dir}' "
    f"--sampler {sampler} "
    f"--num_prompts_in_ensemble 1 "
    f"--num_samples {num_sequences} "
    f"--max_tokens {max_tokens} "
    f"--max_generated_length {max_generated_length} "
    f"--temperature {temperature} "
    f"--top_p {top_p} "
    f"--device {device} "
    f"--dtype {dtype} "
    f"--attn_implementation {attn_impl} "
    f"--minimum_sequence_length_proportion {min_seq_len_prop} "
    f"--max_sequence_length_multiplier {max_len_mult} "
    f"--seed {seed}"
)

# Only enable identity filter when mmseqs exists
mmseqs_ok = shutil.which("mmseqs") is not None
if mmseqs_ok:
    cmd += f" --minimum_sequence_identity {min_seq_identity}"
else:
    print("mmseqs not found in PATH; skipping --minimum_sequence_identity filtering.")

print("\nCommand:\n", cmd, "\n")
!bash -lc "{cmd}"

print("\nGenerated files:")
!ls -lh "{save_dir}" | head -n 50

In [None]:
#@title 12) Preview generated FASTA (first file)

import pathlib

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
save_dir = repo / "generated"

fasta_files = sorted(save_dir.glob("*.fa*"))
if not fasta_files:
    raise RuntimeError("No FASTA files found in generated/. Run the generation cell first.")

fp = fasta_files[0]
print("Showing:", fp, "\n")
print("\n".join(fp.read_text().splitlines()[:120]))

In [None]:
#@title 13) Rank generated sequences and pick the best candidate

import pathlib, re
import pandas as pd
from IPython.display import display

repo = pathlib.Path("/content/profam_workshop/profam").resolve()
prompt_fa = repo / "prompts.fasta"
gen_dir = repo / "generated"
out_best = repo / "best_generated.fasta"

if not prompt_fa.exists():
    raise FileNotFoundError("prompts.fasta not found.")
if not gen_dir.exists():
    raise FileNotFoundError("generated/ directory not found. Run generation first.")

AA = set("ACDEFGHIKLMNPQRSTVWYBXZJUO")

def read_fasta(path: pathlib.Path):
    records = []
    name, seq = None, []
    for line in path.read_text().splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if name is not None:
                records.append((name, "".join(seq)))
            name = line[1:].strip()
            seq = []
        else:
            seq.append(line)
    if name is not None:
        records.append((name, "".join(seq)))
    return records

def clean_seq(s: str) -> str:
    s = re.sub(r"\s+", "", s).upper()
    return "".join(c for c in s if c.isalpha())

def longest_run_fraction(s: str) -> float:
    if not s:
        return 0.0
    best = 1
    cur = 1
    for i in range(1, len(s)):
        if s[i] == s[i - 1]:
            cur += 1
            if cur > best:
                best = cur
        else:
            cur = 1
    return best / len(s)

def kmer_low_complexity(s: str, k: int = 3) -> float:
    if len(s) < k + 1:
        return 0.0
    kmers = [s[i:i + k] for i in range(len(s) - k + 1)]
    return 1.0 - (len(set(kmers)) / len(kmers))

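def identity_to_prompt(g: str, p: str) -> float:
    """Ungapped, position-by-position identity over the shorter sequence.

    No alignment is performed, so a single insertion or deletion shifts every
    downstream position and can lower the score sharply."""
    n = min(len(g), len(p))
    if n == 0:
        return 0.0
    matches = sum(1 for i in range(n) if g[i] == p[i])
    return matches / n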

prompt_records = read_fasta(prompt_fa)
if not prompt_records:
    raise RuntimeError("No FASTA records found in prompts.fasta.")
prompt_name, prompt_seq = prompt_records[0][0], clean_seq(prompt_records[0][1])

rows = []
for fp in sorted(gen_dir.glob("*.fa*")):
    for name, seq in read_fasta(fp):
        seq = clean_seq(seq)
        if not seq:
            continue
        if any(c not in AA for c in seq):
            continue

        rows.append(
            {
                "file": fp.name,
                "id": name,
                "length": len(seq),
                "len_ratio_to_prompt": len(seq) / max(1, len(prompt_seq)),
                "prefix_identity_to_prompt": identity_to_prompt(seq, prompt_seq),
                "longest_run_frac": longest_run_fraction(seq),
                "kmer_low_complexity": kmer_low_complexity(seq, k=3),
                "seq": seq,
            }
        )

if not rows:
    raise RuntimeError("No generated sequences found/parsed in generated/.")

df = pd.DataFrame(rows)

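# Preferred window for the (ungapped) prefix identity used in ranking.
# Note: this is not the same measure as the alignment-based MMseqs2 identity
# used by --minimum_sequence_identity during generation.
target_min_id = 0.20
target_max_id = 0.60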

def rank_score(r):
    len_pen = abs(r["len_ratio_to_prompt"] - 1.0)

    idv = r["prefix_identity_to_prompt"]
    if idv < target_min_id:
        id_pen = (target_min_id - idv) * 2.0
    elif idv > target_max_id:
        id_pen = (idv - target_max_id) * 2.0
    else:
        id_pen = 0.0

    rep_pen = r["kmer_low_complexity"] * 1.5 + r["longest_run_frac"] * 2.0

    return len_pen + id_pen + rep_pen

df["rank_score"] = df.apply(rank_score, axis=1)
df_sorted = df.sort_values("rank_score", ascending=True).reset_index(drop=True)

display_cols = [
    "file",
    "id",
    "length",
    "len_ratio_to_prompt",
    "prefix_identity_to_prompt",
    "kmer_low_complexity",
    "longest_run_frac",
    "rank_score",
]
display(df_sorted[display_cols].head(15))

best = df_sorted.iloc[0]
best_id = best["id"]
best_seq = best["seq"]

wrapped = "\n".join(best_seq[i:i + 60] for i in range(0, len(best_seq), 60))
out_best.write_text(f">{best_id}\n{wrapped}\n")

print("Recommended sequence:")
print("  ID:", best_id)
print("  Length:", len(best_seq))
print("  Score:", float(best["rank_score"]))
print("  Saved:", out_best)

In [None]:
#@title 14) Upload two structures for TM-score and visualisation

from google.colab import files
from pathlib import Path

UPLOAD_DIR = Path("/content/tmalign_uploads")
UPLOAD_DIR.mkdir(exist_ok=True)

def upload_one(label):
    print(f"\nUpload {label} (PDB .pdb/.ent or mmCIF .cif/.mmcif)")
    uploaded = files.upload()

    if len(uploaded) != 1:
        raise RuntimeError(
            f"{label}: Please upload exactly one file (got {len(uploaded)}). "
            "Re-run the cell."
        )

    name, data = next(iter(uploaded.items()))
    dst = UPLOAD_DIR / name

    with open(dst, "wb") as f:
        f.write(data)

    print(f"{label} saved to: {dst}")
    return str(dst)

# Upload structure A
struct_a = upload_one("Structure A (e.g. ProFam-generated model)")

# Upload structure B
struct_b = upload_one("Structure B (e.g. AFDB or reference structure)")

print("\nStructures ready for TM-score:")
print("  Structure A:", struct_a)
print("  Structure B:", struct_b)
print("\nRun the TM-score alignment cell next.")

In [None]:
#@title 15) TM-align: align A vs B, report TM-score, and visualise superposition (single output)

import re, subprocess
from pathlib import Path

SHOW_3DMOL = False  #@param {type:"boolean"}  # If True, shows 3Dmol output before NGLView

# --- Inputs from upload cells ------------------------------------------------
if "struct_a" not in globals() or "struct_b" not in globals():
    raise RuntimeError("struct_a / struct_b not set. Run the upload cell first.")

struct_a = str(struct_a)
struct_b = str(struct_b)

chain_a = ""  #@param {type:"string"}  # leave blank = first chain
chain_b = ""  #@param {type:"string"}  # leave blank = first chain

WORK_DIR = Path("/content/tmalign_work").resolve()
WORK_DIR.mkdir(exist_ok=True)

# Install deps (quiet)
!pip -q install gemmi nglview
if SHOW_3DMOL:
    !pip -q install py3Dmol

# Enable widget manager for NGLView in Colab
from google.colab import output
output.enable_custom_widget_manager()

import gemmi
import nglview as nv

def ensure_tmalign(bin_path: Path) -> Path:
    if bin_path.exists():
        return bin_path
    bin_path.parent.mkdir(parents=True, exist_ok=True)
    url = "https://zhanggroup.org/TM-align/TMalign"
    subprocess.run(["bash", "-lc", f"wget -q '{url}' -O '{bin_path}'"], check=True)
    subprocess.run(["bash", "-lc", f"chmod +x '{bin_path}'"], check=True)
    return bin_path

TMALIGN = ensure_tmalign(WORK_DIR / "bin" / "TMalign")

def first_chain_id(path: str) -> str:
    st = gemmi.read_structure(path)
    for model in st:
        for ch in model:
            return ch.name
    return ""

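def to_pdb(in_path: str, out_path: Path, chain_id: str = "") -> Path:
    """Convert PDB/mmCIF to PDB, optionally keeping only one chain."""
    st = gemmi.read_structure(str(in_path))
    if chain_id.strip():
        keep = chain_id.strip()
        for model in st:
            # Collect names first: removing chains while iterating over
            # references into the model can skip or misidentify chains
            drop = {ch.name for ch in model if ch.name != keep}
            for name in drop:
                model.remove_chain(name)
        st.remove_empty_chains()
    st.write_pdb(str(out_path))
    return out_path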

# Choose default chains if blank
if chain_a.strip() == "":
    chain_a = first_chain_id(struct_a)
if chain_b.strip() == "":
    chain_b = first_chain_id(struct_b)

# Prepare inputs
pdb_a = WORK_DIR / "A.pdb"
pdb_b = WORK_DIR / "B.pdb"
to_pdb(struct_a, pdb_a, chain_a)
to_pdb(struct_b, pdb_b, chain_b)

# --- Run TM-align with matrix output -----------------------------------------
matrix_file = WORK_DIR / "tmalign_matrix.txt"
if matrix_file.exists():
    matrix_file.unlink()

res = subprocess.run(
    [str(TMALIGN), str(pdb_a), str(pdb_b), "-m", str(matrix_file)],
    capture_output=True, text=True
)
out = (res.stdout or "") + "\n" + (res.stderr or "")
if res.returncode != 0:
    print(out)
    raise RuntimeError(f"TM-align failed with exit code {res.returncode}")

# --- Parse summary ------------------------------------------------------------
tm_scores = [float(x) for x in re.findall(r"TM-score=\s*([0-9]*\.[0-9]+)", out)]
rmsd_m = re.search(r"RMSD=\s*([0-9]*\.[0-9]+)", out)
aln_m  = re.search(r"Aligned length=\s*(\d+)", out)

print("TM-align summary")
print("  A:", Path(struct_a).name, f"(chain {chain_a})")
print("  B:", Path(struct_b).name, f"(chain {chain_b})")
if aln_m:
    print(f"  Aligned length: {int(aln_m.group(1))}")
if rmsd_m:
    print(f"  RMSD: {float(rmsd_m.group(1)):.3f}")
if tm_scores:
    if len(tm_scores) >= 2:
        print(f"  TM-score (norm by A length): {tm_scores[0]:.5f}")
        print(f"  TM-score (norm by B length): {tm_scores[1]:.5f}")
        print(f"  TM-score (max): {max(tm_scores[0], tm_scores[1]):.5f}")
    else:
        print(f"  TM-score: {tm_scores[0]:.5f}")
print()

# --- Robust parse of the 3x4 matrix from the -m file ---------------------------
# Depending on the TM-align version, matrix rows are numbered 0-2 or 1-3;
# accept either and take the first three numeric rows in file order.
if not matrix_file.exists() or matrix_file.stat().st_size == 0:
    raise RuntimeError("TM-align did not create a matrix file. Something went wrong with -m.")

lines = [ln.strip() for ln in matrix_file.read_text().splitlines() if ln.strip()]

mat_rows = []
for ln in lines:
    parts = ln.split()
    if len(parts) == 5 and parts[0] in ("0", "1", "2", "3"):
        try:
            mat_rows.append([float(x) for x in parts[1:]])
        except ValueError:
            continue  # not a numeric matrix row

if len(mat_rows) < 3:
    raise RuntimeError(
        "Could not parse the 3 rotation-matrix rows from TM-align -m output.\n"
        "First 30 non-empty lines:\n" + "\n".join(lines[:30])
    )

# Each row is: t[m], U[m][0], U[m][1], U[m][2]
t = [r[0] for r in mat_rows[:3]]
U = [r[1:4] for r in mat_rows[:3]]

# TM-align: X2 = t + U * x1  (Chain_1 -> Chain_2)
# We want to move B onto A => apply inverse transform to B:
UT = [
    [U[0][0], U[1][0], U[2][0]],
    [U[0][1], U[1][1], U[2][1]],
    [U[0][2], U[1][2], U[2][2]],
]
t_inv = [
    -(UT[0][0]*t[0] + UT[0][1]*t[1] + UT[0][2]*t[2]),
    -(UT[1][0]*t[0] + UT[1][1]*t[1] + UT[1][2]*t[2]),
    -(UT[2][0]*t[0] + UT[2][1]*t[1] + UT[2][2]*t[2]),
]

M_inv = [
    UT[0][0], UT[0][1], UT[0][2], t_inv[0],
    UT[1][0], UT[1][1], UT[1][2], t_inv[1],
    UT[2][0], UT[2][1], UT[2][2], t_inv[2],
    0,        0,        0,        1
]

# --- Write B superposed onto A (for NGLView) ---------------------------------
B_SUP = WORK_DIR / "B_superposed.pdb"

def apply_transform_to_structure(in_pdb: Path, out_pdb: Path, M):
    st = gemmi.read_structure(str(in_pdb))
    r11,r12,r13,t1, r21,r22,r23,t2, r31,r32,r33,t3, _,_,_,_ = M
    for model in st:
        for chain in model:
            for res in chain:
                for atom in res:
                    x,y,z = atom.pos.x, atom.pos.y, atom.pos.z
                    X = t1 + r11*x + r12*y + r13*z
                    Y = t2 + r21*x + r22*y + r23*z
                    Z = t3 + r31*x + r32*y + r33*z
                    atom.pos = gemmi.Position(X, Y, Z)
    st.write_pdb(str(out_pdb))

apply_transform_to_structure(pdb_b, B_SUP, M_inv)

# --- Optional: 3Dmol (only if requested) -------------------------------------
if SHOW_3DMOL:
    import py3Dmol
    v3 = py3Dmol.view(width=900, height=520)
    v3.addModel(pdb_a.read_text(), "pdb")
    v3.addModel(B_SUP.read_text(), "pdb")  # B already superposed onto A
    v3.setStyle({"model": 0}, {"cartoon": {"color": "dodgerblue"}})
    v3.setStyle({"model": 1}, {"cartoon": {"color": "tomato", "opacity": 0.70}})
    v3.zoomTo()
    v3.show()  # explicit show: a bare expression inside an if-block is not rendered

# --- Superposition viewer (NGLView; uniform colours override rainbow defaults) ---
b_sup = B_SUP

if not pdb_a.exists():
    raise FileNotFoundError(pdb_a)
if not b_sup.exists():
    raise FileNotFoundError(b_sup)

view = nv.NGLWidget()

# Add components
compA = view.add_component(str(pdb_a))
compB = view.add_component(str(b_sup))

compA.clear_representations()
compB.clear_representations()


compA.add_cartoon(colorScheme="uniform", color="red")
compB.add_cartoon(colorScheme="uniform", color="blue", opacity=0.7)

view.center()
view