## Protein generator diffusion trajectories

In [8]:
from __future__ import annotations

import glob
import os, sys
from pathlib import Path
from typing import Union, List, Tuple, Optional
from tqdm import tqdm
import warnings
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed

import imageio.v2 as imageio  # type: ignore
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import matplotlib.image as mpimg
from matplotlib import gridspec
import numpy as np
import pandas as pd
import torch
from matplotlib.colors import ListedColormap
from PIL import Image, ImageDraw, ImageOps

# Domain-specific imports – leave the heavy lifting to the model code-base
from model import util
from model.chemical import aa2long, num2aa
from model.utils import parsers_inference as parser  # noqa: F401 (kept for parity)
import logomaker
import pymol
from pymol import cmd, finish_launching

In [9]:
# One-letter residue alphabet (22 symbols: 20 amino acids, then "X" and "-").
# process_trajectory looks each amino-acid letter up with CONVERSION.index(aa)
# to pick the matching column of the per-position probability window, so the
# order of the characters matters.
CONVERSION = "ARNDCQEGHILKMFPSTWYVX-"
# Default for process_trajectory's `num_align` keyword.
# NOTE(review): that keyword is not read anywhere in the visible notebook code,
# so the alignment described in the trailing comment may be handled elsewhere
# (or not at all) — confirm.
NUM_ALIGN = 3  # residues used for Kabsch alignment to the 1st frame

In [10]:
# Ten grey levels as "#rrggbb" strings, from 0x1e to 0x81 in steps of 0x0b
# (dark to mid grey).
GREYS = [f"#{i:02x}{i:02x}{i:02x}" for i in range(0x1e, 0x88, 0x0b)]
# Twelve hand-picked hex colours running from purple towards pink.
PURPLES = [
    "#6F26AF",
    "#7D33B0",
    "#8A3FB1",
    "#974CB3",
    "#A458B4",
    "#B165B5",
    "#BE71B6",
    "#CC7DB7",
    "#D98AB8",
    "#E697BA",
    "#F3A4BB",
    "#FF7FCC",
]
# 22-entry discrete colormap (greys first, then purples); passed as `cmap=` to
# the imshow calls in process_trajectory.
CMAP = ListedColormap(GREYS + PURPLES)

In [11]:
def render_pdb(pdb_path, png_path, ref_path=None):
    """Render a PDB file to a PNG by running ``utils/render_pdb.py`` in a subprocess.

    The script is launched with the same interpreter as the notebook kernel
    (``sys.executable``) and is resolved relative to the current working
    directory.

    Args:
        pdb_path: Structure file to render; passed as the first script argument.
        png_path: Destination image; passed as the second script argument.
        ref_path: Optional third script argument, forwarded only when truthy.
            How the script uses it is not visible from this notebook
            (presumably a reference structure — confirm in the script).

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    # Build the script path from its components so the separator matches the
    # host OS; a hard-coded "utils\\render_pdb.py" only resolves on Windows.
    script = str(Path("utils") / "render_pdb.py")

    args = [sys.executable, script, str(pdb_path), str(png_path)]
    if ref_path:
        args.append(str(ref_path))

    subprocess.run(args, check=True)

def build_video_ffmpeg(
    frames_pattern: str,    # e.g. "output/graphs/traj_NAME/frame_%03d.png"
    video_path: Path,
    fps: int = 5
):
    """Encode a numbered PNG sequence into an H.264 video with the ``ffmpeg`` CLI.

    ffmpeg's combined stdout/stderr is echoed line by line while it runs.

    Args:
        frames_pattern: printf-style input pattern handed to ``ffmpeg -i``.
        video_path: Output file; an existing file is overwritten (``-y``).
        fps: Input frame rate passed to ``-framerate``.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable is not on ``PATH``.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status,
            so a failed encode is not mistaken for success.
    """
    # Named `ffmpeg_cmd` (not `cmd`) to avoid shadowing the `pymol.cmd` import.
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-framerate", str(fps),
        "-i", frames_pattern,
        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",  # Ensure even width and height
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        str(video_path)
    ]
    with subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
        for line in proc.stdout:
            print(line, end='')

    # Leaving the `with` block waits for the process, so returncode is final here.
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, ffmpeg_cmd)

In [12]:
def trim_transparency(img: Image.Image) -> Image.Image:
    """Crop fully transparent columns off the left and right of an image.

    Only the horizontal extent is trimmed: the crop box takes its left/right
    bounds from the alpha channel's bounding box but always spans the full
    image height. Images whose mode is not "RGBA"/"LA", or whose alpha
    bounding box is empty, are returned unchanged.
    """
    if img.mode not in ("RGBA", "LA"):
        return img

    alpha_bbox = img.getchannel("A").getbbox()
    if not alpha_bbox:
        return img

    left, _top, right, _bottom = alpha_bbox
    return img.crop((left, 0, right, img.height))


def make_frame(
    logo_png: Union[str, Path],
    heat_png: Union[str, Path],
    out_png: Union[str, Path],
    extra_png: Optional[Union[str, Path]] = None,
    zoom: float = 2,
    total_cols: int = 20,
    col_start: int = 10,
    col_end: int = 20,
    vpad: int = 150,
    line_width: int = 10,
) -> None:
    """Stack a logo, a heat-map and an optional extra image into one frame PNG.

    Layout, top to bottom: the logo (placed above a column slice of the
    heat-map), the heat-map at full width, then the optional extra image scaled
    to the heat-map width. Two "funnel" lines connect the logo's bottom corners
    to the edges of the slice and continue down through the heat-map.

    Args:
        logo_png: Logo image; transparent left/right margins are trimmed.
        heat_png: Heat-map image; its trimmed width is the canvas width.
        out_png: Output path; parent directories are created as needed.
        extra_png: Optional image pasted under the heat-map.
        zoom: Logo width as a multiple of the slice's pixel width.
        total_cols: Number of equal-width columns the heat-map is divided into.
        col_start: First column of the slice the logo refers to.
        col_end: Column at which the slice ends (its right pixel edge is
            ``col_end * column_width``).
        vpad: Vertical gap in pixels between the stacked images.
        line_width: Pixel width of the funnel lines.

    Raises:
        ValueError: If ``total_cols`` is not positive.
    """
    if total_cols <= 0:
        raise ValueError(f"total_cols must be positive, got {total_cols}")

    # 1 ─ Load & trim images (convert() returns an independent copy, so the
    #     source files can be closed immediately)
    with Image.open(logo_png) as src:
        logo_img = trim_transparency(src.convert("RGBA"))
    with Image.open(heat_png) as src:
        heat_img = trim_transparency(src.convert("RGBA"))
    extra_img = None
    if extra_png is not None:
        with Image.open(extra_png) as src:
            extra_img = src.convert("RGBA")

    heat_w, heat_h = heat_img.size

    # 2 ─ Column slice bounds (pixels)
    col_width = heat_w / total_cols
    x_left_px = round(col_start * col_width)
    x_right_px = round(col_end * col_width)
    subset_w = x_right_px - x_left_px

    # 3 ─ Resize logo according to zoom
    new_logo_w = max(1, round(subset_w * zoom))
    aspect = logo_img.height / logo_img.width
    new_logo_h = max(1, round(new_logo_w * aspect))
    logo_resized = logo_img.resize((new_logo_w, new_logo_h), Image.LANCZOS)

    # 4 ─ Resize extra image to heatmap width (if provided)
    if extra_img is not None:
        extra_aspect = extra_img.height / extra_img.width
        extra_new_h = max(1, round(heat_w * extra_aspect))
        extra_resized = extra_img.resize((heat_w, extra_new_h), Image.LANCZOS)
    else:
        extra_resized = None
        extra_new_h = 0

    # 5 ─ Prepare canvas & place images
    x_left_logo = x_left_px + (subset_w - new_logo_w) // 2  # center over slice
    # Keep the logo inside the canvas when it fits: a centred logo near either
    # edge would otherwise be clipped and its funnel line would start off-canvas.
    x_left_logo = max(0, min(x_left_logo, heat_w - new_logo_w))
    x_right_logo = x_left_logo + new_logo_w

    canvas_h = new_logo_h + vpad + heat_h
    if extra_resized is not None:
        canvas_h += vpad + extra_new_h
    canvas = Image.new("RGBA", (heat_w, canvas_h), (255, 255, 255, 0))

    # Paste logo
    canvas.paste(logo_resized, (x_left_logo, 0), mask=logo_resized)

    # Paste heatmap
    heat_y = new_logo_h + vpad
    canvas.paste(heat_img, (0, heat_y), mask=heat_img)

    # Paste extra image (if any)
    if extra_resized is not None:
        extra_y = heat_y + heat_h + vpad
        canvas.paste(extra_resized, (0, extra_y), mask=extra_resized)

    # 6 ─ Draw funnel call‑out (diagonal then vertical within heatmap only)
    draw = ImageDraw.Draw(canvas)
    line_kwargs = dict(fill=(0, 0, 0, 255), width=line_width)
    logo_bottom_y = new_logo_h - 1  # inside logo edge
    heatmap_top_y = heat_y
    heatmap_bottom_y = heat_y + heat_h  # stop here

    # Left edge
    draw.line([(x_left_logo, logo_bottom_y), (x_left_px, heatmap_top_y)], **line_kwargs)
    draw.line([(x_left_px, heatmap_top_y), (x_left_px, heatmap_bottom_y)], **line_kwargs)

    # Right edge
    draw.line([(x_right_logo - 1, logo_bottom_y), (x_right_px - 1, heatmap_top_y)], **line_kwargs)
    draw.line([(x_right_px - 1, heatmap_top_y), (x_right_px - 1, heatmap_bottom_y)], **line_kwargs)

    # 7 ─ Save
    out_path = Path(out_png)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    canvas.save(out_path)

In [13]:
def process_trajectory(pt_file: os.PathLike | str, *, num_align: int = NUM_ALIGN) -> None:
    """Generate per-step images and an MP4 animation for one *.pt* trajectory file.

    The file is loaded with ``torch.load`` and every entry whose key starts
    with ``"step"`` is unpacked as ``(xyz, logits, diffs)``. For each step, in
    stored order, the function writes into ``output/graphs/traj_<stem>/``:

    * a heat-map of the first 20 logit channels (``*_predXo.png``),
    * a heat-map of ``diffs`` (``*_Xt-1.png``),
    * a sequence logo for rows 10–19 of ``diffs`` (``*_aaLOGO.png``),
    * a PDB written by ``util.writepdb`` plus its rendered PNG,
    * a combined frame built by ``make_frame`` (``*_frame.png``).

    Finally all frames are encoded to ``<stem>.mp4`` at 20 fps.

    Args:
        pt_file: Path to the trajectory file. When the name ends in
            ``_trajectory.pt`` the sibling ``<prefix>.pdb`` is passed to the
            renderer as reference for the first step only.
        num_align: Currently unused; kept for interface compatibility.

    Raises:
        subprocess.CalledProcessError: Propagated from ``render_pdb`` when the
            render script fails.
    """
    # Work on the string form: Path objects have no str-style two-argument
    # .replace(), so the original code only worked for plain strings.
    pt_path = os.fspath(pt_file)
    name = Path(pt_path).stem  # strip extension
    out_dir = Path("output/graphs") / f"traj_{name}"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"[process] {name}")
    # SECURITY: torch.load unpickles the file — only open trusted trajectories.
    # Tensors are mapped to CPU because they are handed to matplotlib below.
    traj = torch.load(pt_path, map_location="cpu")

    # Reference PDB for the first render. Only derive it when the expected
    # suffix is present; a blind str.replace would otherwise hand the .pt file
    # itself to the renderer as the "reference".
    traj_suffix = "_trajectory.pt"
    if pt_path.endswith(traj_suffix):
        prev_ref_path = pt_path[: -len(traj_suffix)] + ".pdb"
    else:
        prev_ref_path = None

    steps = {k: v for k, v in traj.items() if k.startswith("step")}

    # Residue window shown in the sequence logo and highlighted on the heat-map.
    win_start, win_end = 10, 20
    n_aa = 20  # channels treated as the amino-acid classes

    # Loop-invariant: the amino-acid column order of logomaker's example matrix.
    logo_columns = list(
        logomaker.get_example_matrix("ww_information_matrix", print_description=False).columns
    )[:n_aa]

    frame_paths: List[Path] = []  # frames that will go into the video

    for step_idx, (step, (xyz, logits, diffs)) in enumerate(steps.items()):
        print(f"  step {step_idx:03d} → {step}")
        # Every key starts with "step", so this equals the former step[:4] prefix.
        tag = f"{name}_step{step_idx:03d}"

        # --------------------- PNG 1: logits ---------------------------------
        fig, ax = plt.subplots(dpi=400)
        ax.imshow(torch.permute(logits.float()[:, :n_aa], (1, 0)), cmap=CMAP)
        for sp in ("top", "right", "bottom", "left"):
            ax.spines[sp].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        logits_png = out_dir / f"{tag}_predXo.png"
        fig.savefig(logits_png, bbox_inches="tight", transparent=True)
        plt.close(fig)

        # --------------------- PNG 2: diffs ----------------------------------
        fig, ax = plt.subplots(dpi=400)
        # Cast like the logits above so reduced-precision tensors also plot.
        ax.imshow(torch.permute(diffs.float(), (1, 0)), cmap=CMAP)
        ax.axis("off")
        diff_png = out_dir / f"{tag}_Xt-1.png"
        fig.savefig(diff_png, bbox_inches="tight", transparent=True)
        plt.close(fig)

        # --------------------- PNG 3: seq logo -------------------------------
        fig, ax = plt.subplots(dpi=400)
        seq_window = torch.softmax(diffs.float()[win_start:win_end, :n_aa], dim=-1)
        aa_prob = {
            aa: [float(seq_window[i, CONVERSION.index(aa)]) for i in range(seq_window.size(0))]
            for aa in logo_columns
        }
        df_logo = pd.DataFrame(aa_prob)

        # draw the logo into our own axes
        logomaker.Logo(df_logo, ax=ax, color_scheme="grays", vpad=.1, width=0.8)
        for sp in ("top", "right", "bottom", "left"):
            ax.spines[sp].set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])

        logo_png = out_dir / f"{tag}_aaLOGO.png"
        fig.savefig(logo_png, bbox_inches="tight", transparent=True)
        plt.close(fig)

        # --------------------- PDB + render ----------------------------------
        pdb_path = out_dir / f"trajectory_{tag}.pdb"
        util.writepdb(pdb_path, xyz, torch.argmax(logits, dim=-1))
        prot_png = pdb_path.with_suffix('.png')
        render_pdb(pdb_path, prot_png, ref_path=prev_ref_path)
        # Only the first step is rendered with a reference; later steps get none.
        # NOTE(review): confirm this is intended rather than chaining references.
        prev_ref_path = None

        # --------------------- Combine ----------------------------------
        frame_png = out_dir / f"{tag}_frame.png"
        # Highlight exactly the residues shown in the logo. The old call passed
        # seq_window's shape (10, 20), which matched the window only by coincidence.
        make_frame(logo_png, diff_png, frame_png, prot_png,
                   col_start=win_start, col_end=win_start + seq_window.size(0),
                   total_cols=diffs.size(0))
        frame_paths.append(frame_png)

    if not frame_paths:
        print(f"[skip] no 'step*' entries in {pt_path}; no video written\n")
        return

    # Assemble the video once all frames are ready. The pattern is built from
    # this trajectory's own name; literal '%' must be doubled for ffmpeg.
    video_path = out_dir / f"{name}.mp4"
    frames_pattern = str(out_dir / f"{name}_step").replace("%", "%%") + "%03d_frame.png"
    build_video_ffmpeg(frames_pattern, video_path, fps=20)
    print(f"[done] Video written to {video_path}\n")

In [14]:
# Suppress all warnings for the rest of the kernel session (deliberate, but note
# that this also hides deprecation warnings from torch/matplotlib).
warnings.filterwarnings("ignore")

# Gather all .pt files and iterate with a progress bar. glob returns matches in
# arbitrary order, so sort them to make the processing order reproducible.
for pt in tqdm(sorted(glob.glob("output/designs/*.pt")), desc="Processing trajectories"):
    process_trajectory(pt)

Processing trajectories:   0%|          | 0/1 [00:00<?, ?it/s]

[process] design_000000_trajectory
  step 000 → step99
  step 001 → step98
  step 002 → step97
  step 003 → step96
  step 004 → step95
  step 005 → step94
  step 006 → step93
  step 007 → step92
  step 008 → step91
  step 009 → step90
  step 010 → step89
  step 011 → step88
  step 012 → step87
  step 013 → step86
  step 014 → step85
  step 015 → step84
  step 016 → step83
  step 017 → step82
  step 018 → step81
  step 019 → step80
  step 020 → step79
  step 021 → step78
  step 022 → step77
  step 023 → step76
  step 024 → step75
  step 025 → step74
  step 026 → step73
  step 027 → step72
  step 028 → step71
  step 029 → step70
  step 030 → step69
  step 031 → step68
  step 032 → step67
  step 033 → step66
  step 034 → step65
  step 035 → step64
  step 036 → step63
  step 037 → step62
  step 038 → step61
  step 039 → step60
  step 040 → step59
  step 041 → step58
  step 042 → step57
  step 043 → step56
  step 044 → step55
  step 045 → step54
  step 046 → step53
  step 047 → step52
  ste

Processing trajectories: 100%|██████████| 1/1 [05:02<00:00, 302.48s/it]

[done] Video written to output\graphs\traj_design_000000_trajectory\design_000000_trajectory.mp4




