# Generate Protein Structures

Load a trained diffusion model and generate protein backbone structures.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.model import DiffusionTransformer
from src.diffusion import DiffusionSchedule
from src.sampler import DiffusionSampler
from src.geom import ca_bond_lengths, radius_of_gyration

## Configuration

In [None]:
# Run directory holding the trained checkpoint (model.pt); the figures and PDB
# files produced below are written into it as well.
RUN_DIR = Path("../runs/20251231_121024")

# Generation parameters
NUM_SAMPLES = 10      # how many backbones to draw
SEQ_LENGTH = 64       # residues per generated chain
SCALE_FACTOR = 10.0   # coordinate scale factor; must equal the value used in training

# Device preference: CUDA first, then Apple MPS, otherwise CPU.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {DEVICE}")

## Load Model

In [None]:
# Read the checkpoint onto the CPU; the model is moved to DEVICE after it is rebuilt.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so only point RUN_DIR at checkpoints from a trusted source.
checkpoint_file = RUN_DIR / "model.pt"
checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=False)

# Hyperparameters stored alongside the weights, echoed one per line.
args = checkpoint["args"]
hparam_lines = [f"  {name}: {value}" for name, value in args.items()]
print("Model hyperparameters:")
for line in hparam_lines:
    print(line)

In [None]:
# Rebuild the network from the hyperparameters recorded in the checkpoint,
# falling back to these defaults for any key that is missing.
# NOTE(review): only d_model / num_layers / num_heads are forwarded. Any other
# constructor hyperparameters (args may also contain e.g. max_len) use
# DiffusionTransformer's own defaults; confirm those match training.
arch_defaults = {"d_model": 128, "num_layers": 4, "num_heads": 4}
model = DiffusionTransformer(
    **{key: args.get(key, fallback) for key, fallback in arch_defaults.items()}
)
model.load_state_dict(checkpoint["model_state_dict"])

# nn.Module.to / .eval return the module itself, so this moves the weights and
# switches to inference mode in one step.
model = model.to(DEVICE).eval()

num_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_params:,}")

In [None]:
# Create diffusion schedule and sampler
# NOTE(review): T=1000 is hard-coded rather than read from `args`; confirm it
# matches the number of diffusion steps used during training.
schedule = DiffusionSchedule(T=1000).to(DEVICE)
sampler = DiffusionSampler(model, schedule)

## Generate Structures

In [None]:
# Draw NUM_SAMPLES backbones of SEQ_LENGTH residues from the sampler
# (3 coordinates per residue, no gradient tracking).
print(f"Generating {NUM_SAMPLES} structures of length {SEQ_LENGTH}...")

sample_shape = (NUM_SAMPLES, SEQ_LENGTH, 3)
with torch.no_grad():
    samples = sampler.sample(shape=sample_shape, device=DEVICE, verbose=True, use_self_cond=True)

# Move to host memory as NumPy and multiply by SCALE_FACTOR to undo the
# coordinate scaling used in training.
samples_np = SCALE_FACTOR * samples.cpu().numpy()
print(f"Generated samples shape: {samples_np.shape}")

## Analyze Generated Structures

In [None]:
# Per-sample geometry table: mean/std of the distances returned by
# ca_bond_lengths, the percentage of them strictly inside (3.6, 4.0) Å, and
# the radius of gyration. A final row averages each column over all samples.
print("\nStructure metrics:")
print(f"{'Sample':>6} {'Bond Mean':>10} {'Bond Std':>10} {'Valid %':>10} {'Rg':>10}")
print("-" * 50)

all_bond_means, all_bond_stds, all_valid_pcts, all_rgs = [], [], [], []
metric_lists = (all_bond_means, all_bond_stds, all_valid_pcts, all_rgs)

for sample_idx, sample_coords in enumerate(samples_np):
    dists = ca_bond_lengths(sample_coords)
    in_window = (dists > 3.6) & (dists < 4.0)
    row = (dists.mean(), dists.std(), in_window.mean() * 100, radius_of_gyration(sample_coords))

    for target, value in zip(metric_lists, row):
        target.append(value)

    print(f"{sample_idx:>6} {row[0]:>10.2f} {row[1]:>10.2f} {row[2]:>10.1f} {row[3]:>10.2f}")

print("-" * 50)
print(f"{'Mean':>6} {np.mean(all_bond_means):>10.2f} {np.mean(all_bond_stds):>10.2f} {np.mean(all_valid_pcts):>10.1f} {np.mean(all_rgs):>10.2f}")

## Visualize Structures

In [None]:
def plot_structure_3d(coords, ax, title="", color="blue"):
    """Draw a backbone trace on a 3-D matplotlib axis.

    Consecutive rows of ``coords`` are joined by a line. The first point is
    marked green and labelled 'N-term'; the last is marked red and labelled
    'C-term'. No legend is drawn here; callers may add one.

    Args:
        coords: Array indexed as ``coords[i, axis]``; columns 0, 1, 2 are
            plotted as X, Y, Z.
        ax: A 3-D axis (``set_zlabel`` is called on it), e.g. one created
            with ``projection='3d'``.
        title: Title for the axis.
        color: Colour of the connecting line.
    """
    xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]
    ax.plot(xs, ys, zs, '-', color=color, linewidth=1.5, alpha=0.8)

    # Highlight the two chain termini.
    for end, end_color, end_label in ((0, 'green', 'N-term'), (-1, 'red', 'C-term')):
        ax.scatter(xs[end], ys[end], zs[end], c=end_color, s=50, label=end_label)

    ax.set_title(title)
    for set_label, axis_name in ((ax.set_xlabel, 'X'), (ax.set_ylabel, 'Y'), (ax.set_zlabel, 'Z')):
        set_label(f'{axis_name} (Å)')

In [None]:
# One 3-D panel per generated sample, n_cols panels per row.
n_cols = 5
n_rows = -(-NUM_SAMPLES // n_cols)  # ceiling division: enough rows for every sample

fig = plt.figure(figsize=(4 * n_cols, 4 * n_rows))

for panel_idx, backbone in enumerate(samples_np, start=1):
    panel_ax = fig.add_subplot(n_rows, n_cols, panel_idx, projection='3d')
    plot_structure_3d(backbone, panel_ax, title=f"Sample {panel_idx - 1}")

plt.tight_layout()
plt.savefig(RUN_DIR / "generated_structures.png", dpi=150)
plt.show()

## Save Structures as PDB

In [None]:
def save_ca_pdb(coords: np.ndarray, path: Path, chain_id: str = "A"):
    """Write a CA-only trace to ``path`` as fixed-column PDB ``ATOM`` records.

    Each row ``(x, y, z)`` of ``coords`` becomes one ``CA`` atom of an ``ALA``
    residue on chain ``chain_id``. Atom serial and residue number both start
    at 1 and are equal for every record. Occupancy is 1.00, the B-factor is
    0.00 and the element is C. The file ends with a single ``END`` record and
    is overwritten if it already exists.

    Args:
        coords: Iterable of 3-element rows (x, y, z), written with 3 decimals.
        path: Destination file.
        chain_id: Single-character chain identifier (PDB column 22).
    """
    with open(path, "w") as handle:
        handle.writelines(
            f"ATOM  {serial:5d}  CA  ALA {chain_id}{serial:4d}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C\n"
            for serial, (x, y, z) in enumerate(coords, start=1)
        )
        handle.write("END\n")

In [None]:
# Export every generated sample as <RUN_DIR>/generated_pdbs/sample_NN.pdb.
output_dir = RUN_DIR / "generated_pdbs"
output_dir.mkdir(exist_ok=True)

for idx, backbone in enumerate(samples_np):
    target = output_dir.joinpath(f"sample_{idx:02d}.pdb")
    save_ca_pdb(backbone, target)
    print(f"Saved: {target}")

print(f"\nAll structures saved to: {output_dir}")

## Bond Length Distribution

In [None]:
# Pool the ca_bond_lengths output of every generated sample into one flat array.
all_bonds = np.array([dist for sample_coords in samples_np for dist in ca_bond_lengths(sample_coords)])

# Histogram of the pooled distances against the ideal 3.8 Å value and the
# accepted 3.6-4.0 Å window.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(all_bonds, bins=50, density=True, alpha=0.7, color='blue', label='Generated')
ax.axvline(3.8, color='red', linestyle='--', linewidth=2, label='Ideal (3.8 Å)')
ax.axvspan(3.6, 4.0, alpha=0.2, color='green', label='Valid range (3.6-4.0 Å)')
ax.set_xlabel('CA-CA Bond Length (Å)')
ax.set_ylabel('Density')
ax.set_title('Bond Length Distribution of Generated Structures')
ax.legend()
plt.tight_layout()
plt.savefig(RUN_DIR / "bond_distribution.png", dpi=150)
plt.show()

# Summary statistics; "valid" means strictly inside (3.6, 4.0) Å.
in_window_pct = ((all_bonds > 3.6) & (all_bonds < 4.0)).mean() * 100
print("Bond length stats:")
print(f"  Mean: {all_bonds.mean():.2f} Å")
print(f"  Std:  {all_bonds.std():.2f} Å")
print(f"  Valid %: {in_window_pct:.1f}%")

## Length Sweep

Test how well the model, trained on fixed-length windows (`max_len` in the checkpoint args, 128 residues by default), generalizes to shorter and longer chains.

In [None]:
# Chain lengths to sample, on both sides of the default training length of 128.
SWEEP_LENGTHS = [32, 64, 96, 128, 192, 256]
SAMPLES_PER_LENGTH = 10  # independent samples drawn at each length

# max_len recorded in the checkpoint args (128 when absent); used as the
# "trained" reference in the summary table and plots.
train_max_len = args.get("max_len", 128)
for line in (f"Model was trained with max_len={train_max_len}", f"Sweeping lengths: {SWEEP_LENGTHS}"):
    print(line)

In [None]:
# Generate samples at each length and collect one metrics record per sample.
#
# This cell deliberately uses its own variable names (length_samples,
# length_samples_np, sample_idx, sample_coords, sample_bonds). Reusing
# `samples` / `samples_np` / `coords` / `bonds` here would overwrite the main
# generation results in the notebook namespace. Re-running the metrics,
# plotting, PDB-export or histogram cells afterwards would then silently
# operate on the last sweep batch instead of the original samples.
sweep_results = []

for length in SWEEP_LENGTHS:
    print(f"\nGenerating {SAMPLES_PER_LENGTH} samples at length {length}...")

    with torch.no_grad():
        length_samples = sampler.sample(
            shape=(SAMPLES_PER_LENGTH, length, 3),
            device=DEVICE,
            verbose=True,
            use_self_cond=True,
        )

    # Undo the training-time coordinate scaling, as in the main generation cell.
    length_samples_np = length_samples.cpu().numpy() * SCALE_FACTOR

    # Same metrics as the per-sample table: bond stats, % of bonds strictly
    # inside (3.6, 4.0) Å, and radius of gyration.
    for sample_idx, sample_coords in enumerate(length_samples_np):
        sample_bonds = ca_bond_lengths(sample_coords)
        sweep_results.append({
            "length": length,
            "sample": sample_idx,
            "bond_mean": sample_bonds.mean(),
            "bond_std": sample_bonds.std(),
            "valid_pct": ((sample_bonds > 3.6) & (sample_bonds < 4.0)).mean() * 100,
            "rg": radius_of_gyration(sample_coords),
        })

print(f"\nGenerated {len(sweep_results)} total samples across {len(SWEEP_LENGTHS)} lengths")

In [None]:
# Aggregate the sweep per chain length: the mean of every metric, plus the
# population std (np.std, ddof=0) across samples for three of them, which the
# sweep plots use as error bars.
print("\nLength Sweep Summary:")
print(f"{'Length':>8} {'Bond Mean':>12} {'Bond Std':>12} {'Valid %':>12} {'Rg':>12}")
print("-" * 60)

metric_keys = ("bond_mean", "bond_std", "valid_pct", "rg")
error_keys = ("bond_mean", "valid_pct", "rg")

length_stats = {}
for length in SWEEP_LENGTHS:
    per_metric = {
        key: [row[key] for row in sweep_results if row["length"] == length]
        for key in metric_keys
    }
    stats = {key: np.mean(per_metric[key]) for key in metric_keys}
    stats.update({f"{key}_err": np.std(per_metric[key]) for key in error_keys})
    length_stats[length] = stats

    marker = " <-- trained" if length == train_max_len else ""
    print(
        f"{length:>8} {stats['bond_mean']:>12.2f} {stats['bond_std']:>12.2f} "
        f"{stats['valid_pct']:>12.1f} {stats['rg']:>12.2f}{marker}"
    )

In [None]:
# Three panels (bond length, % valid bonds, radius of gyration) against chain
# length, each with across-sample std error bars and a dotted line at the
# training length.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

lengths = list(length_stats)
bond_means = [length_stats[n]["bond_mean"] for n in lengths]
bond_errs = [length_stats[n]["bond_mean_err"] for n in lengths]
valid_pcts = [length_stats[n]["valid_pct"] for n in lengths]
valid_errs = [length_stats[n]["valid_pct_err"] for n in lengths]
rgs = [length_stats[n]["rg"] for n in lengths]
rg_errs = [length_stats[n]["rg_err"] for n in lengths]

# Metric-specific artists for each panel.
errorbar_style = dict(capsize=5, linewidth=2, markersize=8)
axes[0].errorbar(lengths, bond_means, yerr=bond_errs, marker='o', **errorbar_style)
axes[0].axhline(3.8, color='red', linestyle='--', label='Ideal (3.8 Å)')
axes[0].axhspan(3.6, 4.0, alpha=0.2, color='green', label='Valid range')
axes[1].errorbar(lengths, valid_pcts, yerr=valid_errs, marker='s', color='green', **errorbar_style)
axes[2].errorbar(lengths, rgs, yerr=rg_errs, marker='^', color='purple', **errorbar_style)

# Decoration shared by all panels: (y label, title, optional y-limits).
panel_text = (
    ('Mean Bond Length (Å)', 'Bond Length vs Chain Length', None),
    ('Valid Bonds (%)', 'Bond Quality vs Chain Length', (0, 105)),
    ('Radius of Gyration (Å)', 'Compactness vs Chain Length', None),
)
for ax, (ylabel, title, ylim) in zip(axes, panel_text):
    ax.axvline(train_max_len, color='gray', linestyle=':', alpha=0.7, label=f'Trained ({train_max_len})')
    ax.set_xlabel('Chain Length (residues)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(RUN_DIR / "length_sweep.png", dpi=150)
plt.show()

print(f"\nSweep results saved to: {RUN_DIR / 'length_sweep.png'}")

### Interpreting the Sweep

**What to look for:**

1. **Bond Length**: Should stay close to 3.8 Å across all lengths. If it drifts at longer/shorter lengths, the model struggles to generalize.

2. **Valid Bond %**: The percentage of CA–CA distances strictly between 3.6 and 4.0 Å. Higher is better (target: >90%); a sharp drop at certain lengths indicates that the model breaks down there.

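3. **Radius of Gyration**: Should increase with chain length, since longer chains form larger structures. For compact globular proteins, $R_g \approx 2.2\,N^{0.38}$ Å (roughly 11 Å for $N = 64$ and 14 Å for $N = 128$). If Rg stays flat or decreases at longer lengths, the model may be "collapsing" structures. Values far above this trend suggest extended, unfolded chains.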

**Typical failure modes:**

- **Short chains (< trained length)**: Usually work well, because every position in the chain was already seen during training.
- **Long chains (> trained length)**: Often break down. Positions beyond the training length were never seen, so positional encodings extrapolate poorly and structures may collapse or have poor geometry.