# Protein Structure Diffusion Process Visualization

This notebook demonstrates how to visualize the diffusion process for protein structure prediction. We'll see how the model progressively denoises a random initialization into a protein structure.

In [2]:
import os
import sys
import torch
import numpy as np
from IPython.display import display
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Make the repository root importable so the local `src` and `scripts`
# packages resolve. This assumes the notebook's working directory sits one
# level below that root.
parent_dir = os.path.join(os.getcwd(), '..')
project_root = os.path.abspath(parent_dir)
if not any(entry == project_root for entry in sys.path):
    sys.path.append(project_root)

from src.data_utils import one_hot_encode_sequence, save_protein_structure
from src.model import ProteinDenoiser
from src.diffusion import ProteinDiffusion
from scripts.visualize_diffusion import visualize_denoising_process, create_denoising_animation

## 1. Set Up Model

First, we'll load a pre-trained diffusion model.

In [5]:
# Model parameters (must match the architecture the checkpoint was trained
# with, otherwise load_state_dict below will reject the weights)
n_timesteps = 1000
hidden_dim = 256
n_layers = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path to checkpoint
checkpoint_path = '../models/proteindiff_epoch_50.pt'  # Change this to your checkpoint path

# Create model
hidden_dims = [hidden_dim] * n_layers
model = ProteinDenoiser(
    hidden_dims=hidden_dims,
    diffusion_time_embedding_dim=hidden_dim,
    n_times=n_timesteps
).to(device)

protein_diffusion = ProteinDiffusion(
    model,
    n_times=n_timesteps,
    beta_minmax=[1e-4, 2e-2],
    device=device
).to(device)

# Load checkpoint.
# NOTE(review): torch.load unpickles the file (unless weights_only is in
# effect for your torch version) — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
protein_diffusion.load_state_dict(checkpoint['model_state_dict'])

# From here on we only sample, so switch to inference mode: any dropout /
# batch-norm layers in the denoiser then behave as they should at test time.
protein_diffusion.eval()

# The stored epoch appears to be 0-based (file "epoch_50" reports 50 after +1)
# — TODO confirm against the training script.
print(f"Loaded checkpoint from epoch {checkpoint['epoch']+1}")

Loaded checkpoint from epoch 50


## 2. Define a Protein Sequence

Let's define the protein sequence whose structure the model will generate; we then visualize its denoising trajectory.

In [8]:
# Example sequence - Crambin (a small, 46-residue protein)
full_sequence = "TTCCPSIVARSNFNVCRLPGTPEAICATYTGCIIIPGATCPGDYAN"

# A shorter sequence can make the denoising visualization easier to follow.
# Set e.g. max_residues = 25 to keep only the first 25 residues; None keeps all.
max_residues = None
sequence = full_sequence if max_residues is None else full_sequence[:max_residues]

print(f"Sequence length: {len(sequence)}")
print(f"Sequence: {sequence}")

Sequence length: 46
Sequence: TTCCPSIVARSNFNVCRLPGTPEAICATYTGCIIIPGATCPGDYAN


## 3. Visualize the Diffusion Process

Now we'll visualize the denoising process at various timesteps.

In [11]:
# Seed the global RNGs so the sampled trajectory (and every figure derived
# from it below) is repeatable on Restart & Run All — assuming the helper
# draws from the global torch/numpy generators; nondeterministic GPU kernels
# can still introduce tiny differences.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Choose specific timesteps to visualize.
# Sampling runs over n_timesteps steps ending at t=0; the recorded output
# shows the requested 1000 being reported as timestep 999.
timesteps_to_visualize = [1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 50, 25, 0]

# Visualize the denoising process
output_dir = "diffusion_visualizations"
structures, pdb_files, visualizations = visualize_denoising_process(
    protein_diffusion,
    sequence,
    device,
    timesteps_to_save=timesteps_to_visualize,
    output_dir=output_dir
)

Will visualize these timesteps: [1000, 900, 800, 700, 600, 500, 400, 300, 200, 100, 50, 25, 0]


Sampling: 100%|███████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 137.34it/s]



Showing denoising process visualizations:

Timestep: 999 (Progress: 0.0%)


None


Timestep: 900 (Progress: 9.9%)


None


Timestep: 800 (Progress: 19.9%)


None


Timestep: 700 (Progress: 29.9%)


None


Timestep: 600 (Progress: 39.9%)


None


Timestep: 500 (Progress: 49.9%)


None


Timestep: 400 (Progress: 60.0%)


None


Timestep: 300 (Progress: 70.0%)


None


Timestep: 200 (Progress: 80.0%)


None


Timestep: 100 (Progress: 90.0%)


None


Timestep: 50 (Progress: 95.0%)


None


Timestep: 25 (Progress: 97.5%)


None


Timestep: 0 (Progress: 100.0%)


None

## 4. Create Animation

Let's create an animation of the diffusion process.

In [13]:
# Render the saved structures into a GIF (2 frames per second) and show it inline.
# FIXME(review): the recorded run of this cell failed with
# "IndexError: list index out of range", most likely raised inside
# create_denoising_animation (its source is not visible here). Check how it
# indexes `structures`, which the RMSD section treats as a dict keyed by timestep.
from IPython.display import Image

animation_file = os.path.join(output_dir, "diffusion_animation.gif")
create_denoising_animation(structures, sequence, animation_file, fps=2)
display(Image(filename=animation_file))

IndexError: list index out of range

## 5. Visualize Individual Structures

Let's take a closer look at some specific timesteps.

In [None]:
import py3Dmol

# Compare the initial noise, the midpoint and the final structure.
# Valid sampling timesteps stop at n_timesteps - 1, so the "initial noise"
# request (n_timesteps) is clamped for the progress calculation; otherwise it
# would report a negative progress.
for timestep in [n_timesteps, n_timesteps // 2, 0]:
    effective_t = min(timestep, n_timesteps - 1)

    # The helper's file-naming scheme isn't visible here, so accept a file
    # named after either the requested or the clamped timestep.
    candidate_files = [os.path.join(output_dir, f"structure_timestep_{timestep}.pdb")]
    if effective_t != timestep:
        candidate_files.append(os.path.join(output_dir, f"structure_timestep_{effective_t}.pdb"))
    pdb_file = next((path for path in candidate_files if os.path.exists(path)), None)
    if pdb_file is None:
        print(f"\nNo saved structure for timestep {timestep}; looked for: {candidate_files}")
        continue

    with open(pdb_file) as f:
        pdb_data = f.read()

    view = py3Dmol.view(width=600, height=400)
    view.addModel(pdb_data, 'pdb')
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.zoomTo()

    progress_percent = (n_timesteps - 1 - effective_t) / (n_timesteps - 1) * 100
    print(f"\nTimestep: {effective_t} (Progress: {progress_percent:.1f}%)")
    # view.show() renders the viewer itself and returns None; wrapping it in
    # display() only adds a stray "None" to the output.
    view.show()

## 6. Compare the Full Denoising Sequence

We can also arrange all the saved structures in a grid (three per row, ordered from the noisiest timestep to the final structure) to see the full progression at a glance.

In [None]:
# Create a grid of visualizations, highest (noisiest) timestep first.
from ipywidgets import HBox, VBox, Label, Output

num_cols = 3  # Number of visualizations per row

cells = []
for t, vis in sorted(visualizations, key=lambda item: item[0], reverse=True):
    # Clamp so a requested t == n_timesteps cannot yield a negative progress.
    t_clamped = min(t, n_timesteps - 1)
    progress_percent = (n_timesteps - 1 - t_clamped) / (n_timesteps - 1) * 100
    label = Label(f"t={t} ({progress_percent:.1f}%)")

    # vis.show() renders into the current output area and returns None, so it
    # cannot be a widget child itself; capture its rendering in an Output widget.
    # (Assumes `vis` is a py3Dmol view, as in the comparison section above.)
    viewer_out = Output()
    with viewer_out:
        vis.show()
    cells.append(VBox([label, viewer_out]))

# Chunk the cells into rows of num_cols (the last row may be shorter).
rows = [HBox(cells[i:i + num_cols]) for i in range(0, len(cells), num_cols)]

# Display all rows
display(VBox(rows))

## 7. Plot RMSD Changes During Diffusion

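Let's analyze how much the structure changes over the diffusion process by calculating the C-alpha RMSD between each pair of consecutive saved frames. Note that the saved frames are not evenly spaced (100 steps apart early on, 25 steps apart near the end), so larger gaps naturally allow larger changes. The RMSD is computed on the raw coordinates, without structural superposition.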

In [None]:
from scipy.spatial.distance import cdist

def _coords_to_numpy(coords):
    """Return `coords` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if isinstance(coords, torch.Tensor):
        # Tensor.numpy() fails on CUDA tensors and on tensors that require grad.
        return coords.detach().cpu().numpy()
    return np.asarray(coords)


def calculate_rmsd(coords1, coords2, atom_index=1):
    """Calculate the RMSD between one atom per residue in two coordinate sets.

    No superposition/alignment is performed: this is the raw displacement
    between the two sets, which suits consecutive frames of a single sampling
    trajectory (presumably sharing one coordinate frame) but is not the aligned
    RMSD usually reported between independent structures.

    Args:
        coords1, coords2: Coordinates indexed as ``[residue, atom, xyz]``;
            torch tensors (any device, with or without grad) or array-likes.
        atom_index: Atom slot to compare. Defaults to 1, the CA atom in this
            notebook's representation.

    Returns:
        The root-mean-square of the per-residue Euclidean distances.

    Raises:
        ValueError: If the selected atom coordinates have different shapes
            (e.g. different residue counts), instead of silently broadcasting.
    """
    ca1 = _coords_to_numpy(coords1)[:, atom_index, :]
    ca2 = _coords_to_numpy(coords2)[:, atom_index, :]
    if ca1.shape != ca2.shape:
        raise ValueError(
            f"Coordinate sets differ in shape for atom {atom_index}: "
            f"{ca1.shape} vs {ca2.shape}"
        )

    # Squared Euclidean distance per residue, averaged, then square-rooted.
    squared_diff = np.sum((ca1 - ca2) ** 2, axis=1)
    return np.sqrt(np.mean(squared_diff))

# Calculate RMSD between consecutive saved frames. Frames are paired in
# ascending timestep order and each RMSD is plotted at the larger (noisier)
# timestep of its pair. Saved frames are not evenly spaced, so wider gaps
# naturally allow more displacement within a pair.
sorted_timesteps = sorted(structures.keys())
rmsds = []
timesteps = []
for t_prev, t_curr in zip(sorted_timesteps, sorted_timesteps[1:]):
    rmsds.append(calculate_rmsd(structures[t_prev], structures[t_curr]))
    timesteps.append(t_curr)

if not rmsds:
    print("Need at least two saved structures to plot RMSD changes.")
else:
    # Same data on a linear and a log x-axis; the log scale spreads out the
    # closely spaced low timesteps.
    for xscale, title_suffix, xlabel in [
        ('linear', '', 'Timestep'),
        ('log', ' (Log Scale)', 'Timestep (log scale)'),
    ]:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(timesteps, rmsds, 'o-')
        ax.set_xscale(xscale)
        ax.set_title(f'RMSD Changes During Diffusion Process{title_suffix}')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('RMSD (Å) between consecutive frames')
        ax.grid(True)
        plt.show()