[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/engelberger/tutorials-ai4pd-2025/blob/main/tutorial_alphafold2_i89_conformations.ipynb)

# Tutorial 1: Prediction of Protein Structures and Multiple Conformations using AlphaFold2

## Focus: The i89 Protein - A Case Study in Conformational Flexibility

**Duration:** 90 minutes  
**Instructor:** Felipe Engelberger  
**Date:** AI4PD Workshop 2025

---

## Learning Objectives

By the end of this tutorial, you will understand:

1. **MSA's role in conformation selection**: How evolutionary information biases AlphaFold2 predictions
2. **Recycling mechanics**: How iterative refinement affects structure quality and conformation
3. **Conformational sampling strategies**: Practical techniques using dropout and MSA subsampling
4. **Structure analysis tools**: RMSD calculations, visualization, and ensemble analysis
5. **Real-world applications**: When and how to apply these techniques to proteins of interest

## Tutorial Overview

We'll use the **i89 protein** as our model system. This 96-residue protein exhibits distinct conformational states that AlphaFold2 can capture through different prediction strategies:

- **State 1**: The conformation typically predicted with full MSA
- **State 2**: An alternative conformation accessible without MSA

We have experimental structures for both states (`state1.pdb` and `state2.pdb`) for validation.


## Section 1: Environment Setup and Dependencies

First, let's set up our environment with all necessary dependencies. This notebook uses ColabDesign (gamma branch) for AlphaFold2 predictions.


In [None]:
%%time
#@title Install Dependencies and Setup Environment
#@markdown This cell installs ColabDesign and other required packages

import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print(" Running in Google Colab - installing dependencies...")
else:
    print(" Running locally - installing dependencies...")

# Install ColabDesign from gamma branch
print(" Installing ColabDesign (gamma branch)...")
%pip install -q git+https://github.com/sokrypton/ColabDesign.git@gamma

# Install additional dependencies
print(" Installing additional packages...")
%pip install -q biopython matplotlib plotly py3Dmol tqdm

# Install LogMD for trajectory visualization
%pip install -q logmd

# Set up environment variables for JAX
if IN_COLAB:
    os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "4.0"

print(" Environment setup complete!")


In [None]:
%%time
#@title Download AlphaFold Parameters and Import Libraries

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from pathlib import Path
import json
from tqdm import tqdm

# ColabDesign imports
from colabdesign import mk_af_model, clear_mem
from colabdesign.af.contrib import predict
from colabdesign.shared.protein import _np_rmsd, _np_kabsch
from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap

# BioPython for structure analysis
from Bio import PDB
from Bio.PDB import PDBIO, Select

# JAX imports
import jax
import jax.numpy as jnp

# Download AlphaFold parameters if not already present
if not os.path.isdir("params"):
    print(" Downloading AlphaFold parameters...")
    os.system("mkdir params")
    os.system("apt-get install -qq aria2")
    os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
    os.system("tar -xf alphafold_params_2022-12-06.tar -C params")
    os.system("rm alphafold_params_2022-12-06.tar")
    print(" Parameters downloaded!")
else:
    print(" AlphaFold parameters already present")

print(" All libraries imported successfully!")


## Section 2: Helper Functions

Let's define utility functions for structure prediction, analysis, and visualization.
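
Most of the analysis in this tutorial rests on RMSD after optimal superposition. As a reference for what that computation involves, here is a minimal NumPy sketch of the Kabsch algorithm. The function name `kabsch_rmsd` is hypothetical and is not ColabDesign's `_np_rmsd`; it assumes both inputs are `(N, 3)` arrays with residues already matched by index.

```python
import numpy as np

def kabsch_rmsd(P, Q):
    """RMSD between two (N, 3) coordinate sets after optimal superposition.

    Illustrative helper (not part of ColabDesign); rows of P and Q are
    assumed to describe the same atoms in the same order.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    if P.shape != Q.shape:
        raise ValueError("coordinate arrays must have the same shape")
    # Remove translation by centering both sets on their centroids
    P0 = P - P.mean(axis=0)
    Q0 = Q - Q.mean(axis=0)
    # SVD of the 3x3 covariance matrix yields the optimal rotation
    U, _, Vt = np.linalg.svd(P0.T @ Q0)
    # Flip the last axis if needed so the result is a rotation, not a reflection
    d = 1.0 if np.linalg.det(U @ Vt) >= 0 else -1.0
    R = U @ np.diag([1.0, 1.0, d]) @ Vt
    diff = P0 @ R - Q0
    return float(np.sqrt((diff ** 2).sum() / len(P)))
```

Applied to the CA atoms of a prediction and a reference, this returns a single value in Å; a rigidly rotated or translated copy of the same structure gives an RMSD of essentially zero.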


In [None]:
#@title Helper Functions for Structure Prediction and Analysis

def setup_af2_model(sequence, use_templates=False, model_type="alphafold2_ptm"):
    """Initialize AlphaFold2 model using ColabDesign"""
    model = mk_af_model(
        protocol="fixbb",
        model_type=model_type,
        use_templates=use_templates
    )
    model.prep_inputs(sequence=sequence)
    return model

def predict_with_settings(model, msa_mode="mmseqs2", num_recycles=3, 
                          dropout=False, seed=0, num_msa=512):
    """
    Run AlphaFold2 prediction with specific settings
    
    Args:
        model: ColabDesign AF2 model
        msa_mode: "mmseqs2" (leave the model's current MSA unchanged),
            "single_sequence" (query only), or "custom" (cap depth at num_msa)
        num_recycles: Number of recycling iterations
        dropout: Enable dropout for sampling
        seed: Random seed
        num_msa: Maximum number of MSA sequences to use
    """
    # Set prediction parameters
    model.set_opt(num_recycles=num_recycles, 
                  use_dropout=dropout,
                  seed=seed)
    
    # Handle MSA mode
    if msa_mode == "single_sequence":
        # Use only the query sequence
        model.set_msa(msa=[[0] * len(model._wt_seq)])
    elif msa_mode == "custom":
        # For custom MSA depth
        model.set_opt(num_msa=num_msa)
    
    # Run prediction
    model.predict()
    
    return {
        'structure': model.aux['atom_positions'],
        'plddt': model.aux['plddt'],
        'pae': model.aux.get('pae', None),
        'seq': model._wt_seq
    }

def calculate_rmsd_to_references(pred_coords, ref1_path="state1.pdb", ref2_path="state2.pdb"):
    """Calculate RMSD to both reference states"""
    # Load reference structures
    parser = PDB.PDBParser(QUIET=True)
    
    ref1 = parser.get_structure("ref1", ref1_path)
    ref2 = parser.get_structure("ref2", ref2_path)
    
    # Extract CA coordinates
    def get_ca_coords(structure):
        coords = []
        for model in structure:
            for chain in model:
                for residue in chain:
                    if 'CA' in residue:
                        coords.append(residue['CA'].coord)
        return np.array(coords)
    
    ref1_coords = get_ca_coords(ref1)
    ref2_coords = get_ca_coords(ref2)
    
    # Get CA coordinates from prediction (assuming pred_coords has all atoms)
    # ColabDesign returns coordinates with shape (L, 37, 3) where index 1 is CA
    pred_ca = pred_coords[:, 1, :]  # CA is at index 1
    
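    # Calculate RMSDs using ColabDesign's function, over the residues that the
    # prediction and each reference have in common (matched by index)
    n1 = min(len(pred_ca), len(ref1_coords))
    n2 = min(len(pred_ca), len(ref2_coords))
    rmsd1 = _np_rmsd(pred_ca[:n1], ref1_coords[:n1], use_jax=False)
    rmsd2 = _np_rmsd(pred_ca[:n2], ref2_coords[:n2], use_jax=False)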
    
    return {'rmsd_state1': rmsd1, 'rmsd_state2': rmsd2}

def plot_structure_comparison(results_list, labels=None):
    """Plot RMSD comparison for multiple predictions"""
    if labels is None:
        labels = [f"Pred_{i}" for i in range(len(results_list))]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # RMSD to State 1
    rmsd1_values = [r['rmsd_state1'] for r in results_list]
    ax1.bar(labels, rmsd1_values, color='steelblue')
    ax1.set_ylabel('RMSD (Å)')
    ax1.set_title('RMSD to State 1')
    ax1.set_ylim(0, max(rmsd1_values + [r['rmsd_state2'] for r in results_list]) * 1.2)
    
    # RMSD to State 2
    rmsd2_values = [r['rmsd_state2'] for r in results_list]
    ax2.bar(labels, rmsd2_values, color='coral')
    ax2.set_ylabel('RMSD (Å)')
    ax2.set_title('RMSD to State 2')
    ax2.set_ylim(0, max(rmsd1_values + rmsd2_values) * 1.2)
    
    plt.tight_layout()
    return fig

def plot_recycling_convergence(recycle_data):
    """Plot RMSD convergence during recycling"""
    fig, ax = plt.subplots(figsize=(8, 5))
    
    recycles = list(recycle_data.keys())
    rmsd1_vals = [recycle_data[r]['rmsd_state1'] for r in recycles]
    rmsd2_vals = [recycle_data[r]['rmsd_state2'] for r in recycles]
    
    ax.plot(recycles, rmsd1_vals, 'o-', label='RMSD to State 1', color='steelblue', linewidth=2)
    ax.plot(recycles, rmsd2_vals, 's-', label='RMSD to State 2', color='coral', linewidth=2)
    
    ax.set_xlabel('Number of Recycles')
    ax.set_ylabel('RMSD (Å)')
    ax.set_title('Convergence During Recycling')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    return fig

def visualize_ensemble(structures, labels=None):
    """
    Create t-SNE visualization of structural ensemble.
    
    Note: This is useful when you DON'T have reference structures.
    For cases with known references, RMSD analysis is more interpretable.
    """
    from sklearn.manifold import TSNE
    
    # Flatten structures to feature vectors (using CA positions)
    features = []
    for s in structures:
        ca_coords = s[:, 1, :].flatten()  # CA coordinates
        features.append(ca_coords)
    
    features = np.array(features)
    
    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(structures)-1))
    embedded = tsne.fit_transform(features)
    
    # Plot
    fig, ax = plt.subplots(figsize=(8, 6))
    scatter = ax.scatter(embedded[:, 0], embedded[:, 1], 
                        c=range(len(structures)), 
                        cmap='viridis', s=100, alpha=0.7)
    
    if labels:
        for i, label in enumerate(labels):
            ax.annotate(label, (embedded[i, 0], embedded[i, 1]), 
                       fontsize=8, alpha=0.7)
    
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')
    ax.set_title('Conformational Ensemble Visualization (No Reference)')
    plt.colorbar(scatter, label='Structure Index')
    
    return fig

print(" Helper functions loaded!")


## Section 3: The i89 Protein - Our Model System

The i89 protein is a 96-residue protein that can adopt multiple conformational states. We'll use it to demonstrate how AlphaFold2's predictions can be influenced by MSA depth, recycling, and sampling parameters.


In [None]:
#@title Define i89 Sequence and Load Reference Structures

# i89 protein sequence (96 residues)
I89_SEQUENCE = "GSHMASMEDLQAEARAFLSEEMIAEFKAAFDMFDADGGGDISYKAVGTVFRMLGINPSKEVLDYLKEKIDVDGSGTIDFEEFLVLMVYIMKQDA"

print(f" i89 protein statistics:")
print(f"   Length: {len(I89_SEQUENCE)} residues")
print(f"   Sequence: {I89_SEQUENCE[:30]}...{I89_SEQUENCE[-20:]}")

# Check if reference structures exist
import os
if os.path.exists("state1.pdb") and os.path.exists("state2.pdb"):
    print("\n Reference structures found:")
    print("   - state1.pdb: Conformation typically predicted with MSA")
    print("   - state2.pdb: Alternative conformation accessible without MSA")
else:
    print("\n Reference structures not found. Downloading...")
    # Download reference structures if not present
    !wget -q https://raw.githubusercontent.com/engelberger/alphamask/refs/heads/colab/notebooks/state1.pdb
    !wget -q https://raw.githubusercontent.com/engelberger/alphamask/refs/heads/colab/notebooks/state2.pdb
    print(" Reference structures downloaded!")

# Quick visualization of reference structure properties
from Bio import PDB
parser = PDB.PDBParser(QUIET=True)
state1 = parser.get_structure("state1", "state1.pdb")
state2 = parser.get_structure("state2", "state2.pdb")

# Calculate RMSD between the two reference states
def get_ca_coords(structure):
    coords = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if 'CA' in residue:
                    coords.append(residue['CA'].coord)
    return np.array(coords)

state1_ca = get_ca_coords(state1)
state2_ca = get_ca_coords(state2)
ref_rmsd = _np_rmsd(state1_ca, state2_ca, use_jax=False)

print(f"\n📐 RMSD between reference states: {ref_rmsd:.2f} Å")
print("   This indicates a significant conformational difference!")


## Section 4: Basic Prediction with Full MSA

Let's start by predicting the i89 structure with a full MSA. This typically results in a conformation closer to State 1.


In [None]:
%%time
#@title Prediction with Full MSA (mmseqs2)
#@markdown This prediction uses evolutionary information from homologous sequences

print(" Setting up AlphaFold2 model...")
model_with_msa = setup_af2_model(I89_SEQUENCE)

print(" Running prediction with full MSA...")
result_with_msa = predict_with_settings(
    model_with_msa, 
    msa_mode="mmseqs2",  # Full MSA from MMseqs2
    num_recycles=3,       # Standard number of recycles
    dropout=False,        # No dropout for deterministic prediction
    seed=0
)

# Calculate RMSD to reference states
rmsd_results = calculate_rmsd_to_references(result_with_msa['structure'])

print("\n Results with Full MSA:")
print(f"   RMSD to State 1: {rmsd_results['rmsd_state1']:.2f} Å")
print(f"   RMSD to State 2: {rmsd_results['rmsd_state2']:.2f} Å")
print(f"   Mean pLDDT: {np.mean(result_with_msa['plddt']) * 100:.1f}%")

# Determine which state is closer
if rmsd_results['rmsd_state1'] < rmsd_results['rmsd_state2']:
    print("\n Prediction is closer to State 1 (as expected with MSA)")
else:
    print("\n Prediction is closer to State 2")

# Visualize pLDDT distribution
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(result_with_msa['plddt'] * 100, color='steelblue', linewidth=2)
plt.xlabel('Residue')
plt.ylabel('pLDDT (%)')
plt.title('Confidence per Residue (with MSA)')
plt.ylim(0, 100)
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.hist(result_with_msa['plddt'] * 100, bins=20, color='steelblue', alpha=0.7, edgecolor='black')
plt.xlabel('pLDDT (%)')
plt.ylabel('Count')
plt.title('pLDDT Distribution')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## Section 5: MSA Manipulation - Exploring Conformational Control

Now let's see how removing or reducing MSA information affects the predicted conformation. Without MSA, AlphaFold2 relies more on learned structural patterns.
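
Reducing MSA depth usually means drawing a random subset of homologs while keeping the query as the first row. The sketch below shows one way to do that on an integer-encoded MSA array. `subsample_msa` is a hypothetical helper and the `(N, L)` array layout is an assumption; it is not how ColabDesign's `num_msa` option is implemented internally.

```python
import numpy as np

def subsample_msa(msa, depth, seed=0):
    """Return `depth` rows of an (N, L) MSA array, always keeping the query (row 0).

    Illustrative helper; homologs are drawn uniformly without replacement.
    """
    msa = np.asarray(msa)
    if depth < 1:
        raise ValueError("depth must be at least 1")
    n_seqs = msa.shape[0]
    if depth >= n_seqs:
        # Nothing to drop: the MSA is already at or below the requested depth
        return msa.copy()
    rng = np.random.default_rng(seed)
    # Draw depth-1 homologs from rows 1..N-1, then restore the original order
    picked = rng.choice(np.arange(1, n_seqs), size=depth - 1, replace=False)
    rows = np.concatenate([[0], np.sort(picked)])
    return msa[rows]
```

Different seeds give different subsets at the same depth, which is one source of the prediction diversity explored in this section.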


In [None]:
%%time
#@title Prediction without MSA (Single Sequence)
#@markdown This prediction uses only the query sequence, no evolutionary information

print(" Setting up AlphaFold2 model...")
model_no_msa = setup_af2_model(I89_SEQUENCE)

print(" Running prediction WITHOUT MSA...")
result_no_msa = predict_with_settings(
    model_no_msa,
    msa_mode="single_sequence",  # Only query sequence
    num_recycles=3,
    dropout=False,
    seed=0
)

# Calculate RMSD to reference states
rmsd_results_no_msa = calculate_rmsd_to_references(result_no_msa['structure'])

print("\n Results WITHOUT MSA:")
print(f"   RMSD to State 1: {rmsd_results_no_msa['rmsd_state1']:.2f} Å")
print(f"   RMSD to State 2: {rmsd_results_no_msa['rmsd_state2']:.2f} Å")
print(f"   Mean pLDDT: {np.mean(result_no_msa['plddt']) * 100:.1f}%")

# Determine which state is closer
if rmsd_results_no_msa['rmsd_state2'] < rmsd_results_no_msa['rmsd_state1']:
    print("\n Prediction is closer to State 2 (as expected without MSA)")
else:
    print("\n Prediction is closer to State 1")

# Compare both predictions
comparison_results = [
    {'rmsd_state1': rmsd_results['rmsd_state1'], 
     'rmsd_state2': rmsd_results['rmsd_state2']},
    {'rmsd_state1': rmsd_results_no_msa['rmsd_state1'], 
     'rmsd_state2': rmsd_results_no_msa['rmsd_state2']}
]

fig = plot_structure_comparison(comparison_results, labels=['With MSA', 'Without MSA'])
plt.suptitle('MSA Effect on Conformational Preference', fontsize=14, y=1.02)
plt.show()

print("\n Key Finding:")
print("   MSA presence/absence can switch the predicted conformation!")
print(f"   Change in RMSD to State 1: {abs(rmsd_results['rmsd_state1'] - rmsd_results_no_msa['rmsd_state1']):.1f} Å")


In [None]:
#@title Exploring Intermediate MSA Depths
#@markdown Test different MSA depths to see the gradual conformational transition

print("🔬 Testing intermediate MSA depths...")
msa_depths = [1, 32, 128, 512]  # Different numbers of MSA sequences
msa_results = []

for depth in tqdm(msa_depths, desc="MSA depths"):
    model = setup_af2_model(I89_SEQUENCE)
    
    # For intermediate depths, we use custom MSA settings
    if depth == 1:
        result = predict_with_settings(model, msa_mode="single_sequence", num_recycles=3)
    else:
        result = predict_with_settings(model, msa_mode="custom", num_msa=depth, num_recycles=3)
    
    rmsd = calculate_rmsd_to_references(result['structure'])
    msa_results.append({
        'depth': depth,
        'rmsd_state1': rmsd['rmsd_state1'],
        'rmsd_state2': rmsd['rmsd_state2'],
        'plddt': np.mean(result['plddt']) * 100
    })

# Plot MSA depth effect
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

depths = [r['depth'] for r in msa_results]
rmsd1 = [r['rmsd_state1'] for r in msa_results]
rmsd2 = [r['rmsd_state2'] for r in msa_results]
plddt_vals = [r['plddt'] for r in msa_results]

# RMSD plot
ax1.plot(depths, rmsd1, 'o-', label='RMSD to State 1', color='steelblue', linewidth=2, markersize=8)
ax1.plot(depths, rmsd2, 's-', label='RMSD to State 2', color='coral', linewidth=2, markersize=8)
ax1.set_xscale('log')
ax1.set_xlabel('MSA Depth (# sequences)')
ax1.set_ylabel('RMSD (Å)')
ax1.set_title('Conformational Preference vs MSA Depth')
ax1.legend()
ax1.grid(True, alpha=0.3)

# pLDDT plot
ax2.plot(depths, plddt_vals, 'o-', color='green', linewidth=2, markersize=8)
ax2.set_xscale('log')
ax2.set_xlabel('MSA Depth (# sequences)')
ax2.set_ylabel('Mean pLDDT (%)')
ax2.set_title('Confidence vs MSA Depth')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n MSA Depth Analysis:")
for r in msa_results:
    closer_to = "State 1" if r['rmsd_state1'] < r['rmsd_state2'] else "State 2"
    print(f"   Depth {r['depth']:3d}: Closer to {closer_to} (pLDDT: {r['plddt']:.1f}%)")


## Section 6: Recycling for Conformational Refinement

Recycling is AlphaFold2's iterative refinement process. Let's explore how the number of recycles affects structure quality and conformational preference.
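
Because each recycle feeds the previous output back into the network, a common way to save compute is to stop once successive structures barely change. The loop below sketches that idea in isolation. `step` is a hypothetical stand-in for one recycling iteration, the 0.5 Å tolerance is only a placeholder, and the convergence measure (RMS displacement without superposition) is an assumption of this sketch rather than AlphaFold2's own criterion.

```python
import numpy as np

def recycle_until_converged(step, coords, max_recycles=12, tol=0.5):
    """Apply `step` repeatedly; stop once coordinates move less than `tol` Å (RMS).

    `step` maps an (L, 3) array to an updated (L, 3) array. Returns the final
    coordinates and the number of iterations actually run.
    """
    prev = np.asarray(coords, dtype=float)
    n_done = 0
    for n_done in range(1, max_recycles + 1):
        new = np.asarray(step(prev), dtype=float)
        # RMS displacement between consecutive iterations (no superposition)
        change = float(np.sqrt(((new - prev) ** 2).sum(axis=-1).mean()))
        prev = new
        if change < tol:
            break
    return prev, n_done
```

With a toy `step` that halves every coordinate, the loop stops as soon as one iteration moves the atoms by less than the tolerance instead of always running `max_recycles` times.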


In [None]:
%%time
#@title Effect of Recycling on Structure Convergence
#@markdown Test how recycling iterations affect the predicted structure

print(" Testing different numbers of recycles...")

# Test with and without MSA at different recycle numbers
recycle_numbers = [0, 1, 3, 6, 12]
recycle_results_with_msa = {}
recycle_results_no_msa = {}

# With MSA
print("\n With MSA:")
model_msa = setup_af2_model(I89_SEQUENCE)
for n_recycles in tqdm(recycle_numbers, desc="Recycles (MSA)"):
    result = predict_with_settings(model_msa, msa_mode="mmseqs2", num_recycles=n_recycles)
    rmsd = calculate_rmsd_to_references(result['structure'])
    recycle_results_with_msa[n_recycles] = rmsd
    print(f"   {n_recycles:2d} recycles: State1={rmsd['rmsd_state1']:.2f}Å, State2={rmsd['rmsd_state2']:.2f}Å")

# Without MSA
print("\n Without MSA:")
model_no_msa = setup_af2_model(I89_SEQUENCE)
for n_recycles in tqdm(recycle_numbers, desc="Recycles (no MSA)"):
    result = predict_with_settings(model_no_msa, msa_mode="single_sequence", num_recycles=n_recycles)
    rmsd = calculate_rmsd_to_references(result['structure'])
    recycle_results_no_msa[n_recycles] = rmsd
    print(f"   {n_recycles:2d} recycles: State1={rmsd['rmsd_state1']:.2f}Å, State2={rmsd['rmsd_state2']:.2f}Å")

# Plot convergence
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# With MSA
ax1.set_title('Recycling Convergence WITH MSA')
ax1.plot(recycle_numbers, [recycle_results_with_msa[r]['rmsd_state1'] for r in recycle_numbers], 
         'o-', label='RMSD to State 1', color='steelblue', linewidth=2)
ax1.plot(recycle_numbers, [recycle_results_with_msa[r]['rmsd_state2'] for r in recycle_numbers],
         's-', label='RMSD to State 2', color='coral', linewidth=2)
ax1.set_xlabel('Number of Recycles')
ax1.set_ylabel('RMSD (Å)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Without MSA
ax2.set_title('Recycling Convergence WITHOUT MSA')
ax2.plot(recycle_numbers, [recycle_results_no_msa[r]['rmsd_state1'] for r in recycle_numbers],
         'o-', label='RMSD to State 1', color='steelblue', linewidth=2)
ax2.plot(recycle_numbers, [recycle_results_no_msa[r]['rmsd_state2'] for r in recycle_numbers],
         's-', label='RMSD to State 2', color='coral', linewidth=2)
ax2.set_xlabel('Number of Recycles')
ax2.set_ylabel('RMSD (Å)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n Key Observations:")
print("   - With MSA: Structure converges toward State 1")
print("   - Without MSA: Structure converges toward State 2")
print("   - Most convergence happens in first 3-6 recycles")
print("   - Early stopping (tolerance ~0.5 Å) can save computation")


## Section 7: Sampling Multiple Conformations

Now let's explore techniques for sampling multiple conformations using dropout and different random seeds.
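
When reference structures are available, each sampled structure can be labelled by the reference it sits closer to. The sketch below adds an "ambiguous" label for structures whose two RMSDs are nearly equal. The function names and the 0.5 Å margin are hypothetical choices for illustration, and the numbers in the usage example are made up rather than taken from a real run.

```python
def assign_state(rmsd_state1, rmsd_state2, margin=0.5):
    """Label a structure by its closer reference; 'ambiguous' if within `margin` Å."""
    if abs(rmsd_state1 - rmsd_state2) < margin:
        return "ambiguous"
    return "State 1" if rmsd_state1 < rmsd_state2 else "State 2"

def summarize_ensemble(rmsds, margin=0.5):
    """Count labels over a list of {'rmsd_state1': ..., 'rmsd_state2': ...} dicts."""
    counts = {"State 1": 0, "State 2": 0, "ambiguous": 0}
    for r in rmsds:
        counts[assign_state(r["rmsd_state1"], r["rmsd_state2"], margin)] += 1
    return counts

# Illustrative values only (not results from this notebook)
demo = [
    {"rmsd_state1": 1.2, "rmsd_state2": 6.0},
    {"rmsd_state1": 5.5, "rmsd_state2": 1.8},
    {"rmsd_state1": 3.1, "rmsd_state2": 3.3},
]
print(summarize_ensemble(demo))
```

The dictionaries have the same keys as the output of `calculate_rmsd_to_references`, so a list of those results can be passed in directly.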


In [None]:
%%time
#@title Ensemble Generation with Dropout and Multiple Seeds
#@markdown Generate multiple structures to explore conformational diversity

print(" Generating conformational ensemble...")

# Parameters for ensemble generation
n_seeds = 5  # Number of different random seeds
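# Dropout settings. predict_with_settings only takes an on/off flag, so every
# non-zero value below simply enables dropout (0.15 and 0.3 behave the same).
dropout_rates = [0.0, 0.15, 0.3]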
ensemble_structures = []
ensemble_labels = []
ensemble_rmsds = []

# Generate ensemble
for dropout in dropout_rates:
    for seed in range(n_seeds):
        # With MSA + dropout
        model = setup_af2_model(I89_SEQUENCE)
        result = predict_with_settings(
            model,
            msa_mode="mmseqs2",
            num_recycles=3,
            dropout=(dropout > 0),
            seed=seed
        )
        
        rmsd = calculate_rmsd_to_references(result['structure'])
        ensemble_structures.append(result['structure'])
        ensemble_labels.append(f"D{dropout:.2f}_S{seed}")
        ensemble_rmsds.append(rmsd)
        
        # Without MSA + dropout (for diversity)
        if dropout > 0:
            model = setup_af2_model(I89_SEQUENCE)
            result = predict_with_settings(
                model,
                msa_mode="single_sequence",
                num_recycles=3,
                dropout=True,
                seed=seed
            )
            
            rmsd = calculate_rmsd_to_references(result['structure'])
            ensemble_structures.append(result['structure'])
            ensemble_labels.append(f"NoMSA_D{dropout:.2f}_S{seed}")
            ensemble_rmsds.append(rmsd)

print(f"\n Generated {len(ensemble_structures)} structures")

# Analyze ensemble
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: RMSD distribution
ax = axes[0, 0]
rmsd1_vals = [r['rmsd_state1'] for r in ensemble_rmsds]
rmsd2_vals = [r['rmsd_state2'] for r in ensemble_rmsds]
ax.scatter(rmsd1_vals, rmsd2_vals, alpha=0.6, s=50)
ax.set_xlabel('RMSD to State 1 (Å)')
ax.set_ylabel('RMSD to State 2 (Å)')
ax.set_title('Ensemble Distribution in RMSD Space')
ax.grid(True, alpha=0.3)

# Add reference points: State 1 sits at (0, ref_rmsd) and State 2 at (ref_rmsd, 0)
ax.scatter([0, ref_rmsd], [ref_rmsd, 0], marker='*', s=200, c='red', label='Reference states')
ax.legend()

# Plot 2: State preference histogram
ax = axes[0, 1]
state_preference = ['State 1' if r['rmsd_state1'] < r['rmsd_state2'] else 'State 2' 
                   for r in ensemble_rmsds]
state_counts = {s: state_preference.count(s) for s in ['State 1', 'State 2']}
ax.bar(state_counts.keys(), state_counts.values(), color=['steelblue', 'coral'])
ax.set_ylabel('Count')
ax.set_title('Conformational State Distribution')

# Plot 3: Effect of dropout
ax = axes[1, 0]
for dropout in dropout_rates:
    dropout_rmsd1 = [r['rmsd_state1'] for r, l in zip(ensemble_rmsds, ensemble_labels) 
                     if f"D{dropout:.2f}" in l and "NoMSA" not in l]
    dropout_rmsd2 = [r['rmsd_state2'] for r, l in zip(ensemble_rmsds, ensemble_labels) 
                     if f"D{dropout:.2f}" in l and "NoMSA" not in l]
    if dropout_rmsd1:
        ax.scatter(dropout_rmsd1, dropout_rmsd2, label=f'Dropout={dropout}', alpha=0.7, s=80)

ax.set_xlabel('RMSD to State 1 (Å)')
ax.set_ylabel('RMSD to State 2 (Å)')
ax.set_title('Dropout Effect on Conformation')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: RMSD-based conformational analysis
ax = axes[1, 1]
# Create RMSD difference plot (State1 - State2)
rmsd_diff = [r['rmsd_state1'] - r['rmsd_state2'] for r in ensemble_rmsds]
colors = ['steelblue' if diff < 0 else 'coral' for diff in rmsd_diff]

ax.bar(range(len(rmsd_diff)), rmsd_diff, color=colors, alpha=0.7)
ax.axhline(y=0, color='black', linestyle='--', linewidth=1)
ax.set_xlabel('Structure Index')
ax.set_ylabel('RMSD Difference, State 1 - State 2 (Å)')
ax.set_title('Conformational Preference per Structure')
ax.text(0.02, 0.98, 'Blue: Closer to State 1\nCoral: Closer to State 2', 
        transform=ax.transAxes, va='top', fontsize=9)

plt.tight_layout()
plt.show()

print("\n📊 Ensemble Statistics:")
print(f"   Total structures: {len(ensemble_structures)}")
print(f"   Prefer State 1: {state_counts.get('State 1', 0)}")
print(f"   Prefer State 2: {state_counts.get('State 2', 0)}")
print(f"   Mean RMSD to State 1: {np.mean(rmsd1_vals):.2f} ± {np.std(rmsd1_vals):.2f} Å")
print(f"   Mean RMSD to State 2: {np.mean(rmsd2_vals):.2f} ± {np.std(rmsd2_vals):.2f} Å")

print("\n💡 Analysis Note:")
print("   We use RMSD analysis here because we have reference structures.")
print("   For proteins WITHOUT known conformations, use dimensionality reduction:")
print("   - t-SNE or UMAP to visualize conformational landscape")
print("   - Clustering to identify distinct conformational families")
print("   - All-vs-all RMSD matrices to find structural relationships")


## Section 8: Summary and Key Takeaways

### What We've Learned

1. **MSA Controls Conformation**: With the full MSA, predictions favor State 1; in single-sequence mode, they favor State 2
2. **Recycling Refines Structure**: Most improvement in first 3-6 recycles
3. **Sampling Strategies**: Dropout and MSA subsampling explore conformational space
4. **Analysis Methods**: RMSD when references exist, t-SNE/UMAP when they don't

### Practical Guidelines

- **For single structure**: Use full MSA, 3-6 recycles
- **For conformational sampling**: Vary MSA depth, use dropout
- **For efficiency**: Implement early stopping (RMSD tolerance ~0.5 Å)
- **For validation**: Compare to known structures when available

### Choosing Analysis Methods

**When you HAVE reference structures:**
- Use RMSD analysis for direct comparison
- Calculate RMSD differences to identify conformational preferences
- Track convergence using RMSD to known states

**When you DON'T have reference structures:**
- Use t-SNE or UMAP for dimensionality reduction
- Apply clustering algorithms to identify conformational families
- Create all-vs-all RMSD matrices to find structural relationships
- Look for patterns in pLDDT to identify flexible regions
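
An all-vs-all comparison does not require superposing every pair of structures. As a sketch, the code below uses distance RMSD (dRMSD), which compares intra-structure CA-CA distance matrices and is therefore alignment-free. Note that this swaps dRMSD in for the superposition-based RMSD used elsewhere in this tutorial, and the function names are hypothetical.

```python
import numpy as np

def ca_distance_matrix(ca):
    """(L, L) matrix of pairwise distances for an (L, 3) array of CA coordinates."""
    ca = np.asarray(ca, dtype=float)
    return np.linalg.norm(ca[:, None, :] - ca[None, :, :], axis=-1)

def pairwise_drmsd(ca_list):
    """All-vs-all distance RMSD for a list of (L, 3) CA arrays of equal length."""
    dms = np.stack([ca_distance_matrix(ca) for ca in ca_list])
    n_res = dms.shape[1]
    iu = np.triu_indices(n_res, k=1)       # count each residue pair once
    flat = dms[:, iu[0], iu[1]]            # shape (n_structures, n_pairs)
    diff = flat[:, None, :] - flat[None, :, :]
    return np.sqrt((diff ** 2).mean(axis=-1))
```

For an ensemble stored as `(L, 37, 3)` arrays, the CA atoms can be passed as `[s[:, 1, :] for s in structures]`. The resulting symmetric matrix can then feed a clustering method or an embedding with a precomputed metric.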


## Section 9: Challenge Exercise - Explore Other Proteins

Now that you understand how to manipulate AlphaFold2 predictions, try these techniques on other proteins with known conformational changes!


In [None]:
#@title Challenge: RfaH Protein
#@markdown RfaH undergoes a dramatic conformational change. Try predicting both states!

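# RfaH sequence (PDB: 2OUG, 5OND)
# NOTE: confirm this sequence against the deposited entries before predicting;
# it starts with a His-tag and consists largely of low-complexity A/Q/K repeats.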
RFAH_SEQUENCE = "MGSSHHHHHHSSGLVPRGSHMTTQELKRIRELTAKLSGDTLSAIEAALEAAQAAAQALIQAQRAAQIAQ" \
                "AAKAAQAAKAAQAAKAARAAQTAQAAKAAQTAKAAQAAKAAQAAKAARAAQQAKAAQAAKAAQAAKAAR" \
                "AAQQAKAAQAAKAAQAAKAAQAARAAQAAKAAQAAKAAQAARAAQAAQAAKAARAAQAAQAARAAQAAQ"

print("🧬 RfaH Challenge:")
print(f"   Length: {len(RFAH_SEQUENCE)} residues")
print("   Known states: autoinhibited (2OUG) and active (5OND)")
print("\n📝 Your task:")
print("   1. Predict with and without MSA")
print("   2. Compare conformations")
print("   3. Try different recycling numbers")
print("   4. Generate an ensemble with dropout")
print("\nModify the helper functions above to work with RfaH!")


## Section 10: Advanced Visualization - MSA Analysis and LogMD Trajectories

Now let's explore advanced visualization techniques including MSA coevolution analysis and LogMD trajectory visualization.
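
A raw sequence count overstates how much independent information an MSA holds, because near-duplicate sequences add little. A common correction is the effective number of sequences (Neff), where each sequence is down-weighted by the number of sequences similar to it. The sketch below assumes an integer-encoded `(N, L)` MSA and an 80% identity cutoff; both the function name and the cutoff are illustrative choices, not values taken from this notebook.

```python
import numpy as np

def effective_sequences(msa, identity_cutoff=0.8):
    """Neff of an (N, L) integer-encoded MSA using identity-based reweighting.

    Each sequence gets weight 1 / (number of sequences, itself included, whose
    fractional identity to it is >= identity_cutoff). Memory use grows as
    N * N * L, so this is only meant for small alignments.
    """
    msa = np.asarray(msa)
    # Pairwise fractional identity between all rows: shape (N, N)
    identity = (msa[:, None, :] == msa[None, :, :]).mean(axis=-1)
    cluster_size = (identity >= identity_cutoff).sum(axis=1)
    return float((1.0 / cluster_size).sum())
```

Three identical sequences plus one unrelated sequence give a Neff of 2, not 4, which reflects the amount of independent signal available for coevolution analysis.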


In [None]:
#@title MSA Visualization and Coevolution Analysis
#@markdown Visualize MSA coverage and compute coevolution matrices

import matplotlib.pyplot as plt
import numpy as np
from colabdesign.af.contrib import predict

def compute_coevolution(msa_array):
    """
    Compute a coevolution matrix from an MSA using regularized inverse
    covariance (partial correlations) followed by average product correction.
    Adapted from AlphaMask implementation.
    """
    import jax.numpy as jnp
    
    # Convert to one-hot encoding
    Y = jnp.eye(22)[msa_array]  # 22 includes gaps and X
    N, L, A = Y.shape
    Y_flat = Y.reshape(N, -1)
    
    # Compute covariance
    c = jnp.cov(Y_flat.T)
    
    # Add a shrinkage term to the diagonal so the covariance matrix is invertible
    shrink = 4.5/jnp.sqrt(N) * jnp.eye(c.shape[0])
    ic = jnp.linalg.inv(c + shrink)
    
    # Compute partial correlation coefficient
    ic_diag = jnp.diag(ic)
    pcc = ic / jnp.sqrt(ic_diag[:, None] * ic_diag[None, :])
    
    # Reshape and compute Frobenius norm
    raw = jnp.sqrt(jnp.square(pcc.reshape(L, A, L, A)[:, :20, :, :20]).sum((1, 3)))
    
    # Zero out diagonal
    i = jnp.arange(L)
    raw = raw.at[i, i].set(0)
    
    # Apply average product correction (APC)
    ap = raw.sum(0, keepdims=True) * raw.sum(1, keepdims=True) / raw.sum()
    coev = (raw - ap).at[i, i].set(0)
    
    return np.array(coev)

def visualize_msa_and_coevolution(sequence, msa_mode="mmseqs2", jobname="i89_msa_analysis"):
    """
    Generate and visualize MSA with coevolution analysis.
    """
    import os
    
    # Create an output directory for the MSA file
    os.makedirs(jobname, exist_ok=True)
    
    if msa_mode == "mmseqs2":
        print("Building a synthetic demo MSA (MMseqs2 is not called here)...")
        # This would normally call MMseqs2, but for demo we'll use a simplified version
        # In real usage, you'd use: predict.get_msa([sequence], jobname)
        
        # For demonstration, create a synthetic MSA
        with open(f"{jobname}/msa.a3m", "w") as f:
            f.write(f">query\n{sequence}\n")
            # Add some homologous sequences (simplified)
            for i in range(10):
                # Create variations of the sequence
                varied_seq = list(sequence)
                for j in range(0, len(sequence), 20):
                    if np.random.random() > 0.7:
                        varied_seq[j] = np.random.choice(list("ACDEFGHIKLMNPQRSTVWY"))
                f.write(f">seq_{i}\n{''.join(varied_seq)}\n")
    else:
        # Single sequence mode
        with open(f"{jobname}/msa.a3m", "w") as f:
            f.write(f">query\n{sequence}\n")
    
    # Parse MSA
    # Pass the path by keyword so it is read as a file rather than as A3M text
    sequences, deletion_matrix = predict.parse_a3m(a3m_file=f"{jobname}/msa.a3m")
    msa_array = np.array(sequences)
    
    print(f"MSA Statistics:")
    print(f"   Number of sequences: {len(sequences)}")
    print(f"   Sequence length: {len(sequences[0]) if sequences else 0}")
    print(f"   Effective sequences (Neff, approximated by the raw count): {len(sequences)}")
    
    # Compute coevolution
    if len(sequences) > 1:
        coev_matrix = compute_coevolution(msa_array)
    else:
        coev_matrix = np.zeros((len(sequence), len(sequence)))
    
    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # Plot 1: MSA Coverage
    ax = axes[0, 0]
    # Assumes AlphaFold's 22-state encoding (X = 20, gap = 21), matching jnp.eye(22) above
    coverage = np.sum(msa_array != 21, axis=0) / len(msa_array)
    ax.plot(coverage * 100, color='steelblue', linewidth=2)
    ax.set_xlabel('Position')
    ax.set_ylabel('Coverage (%)')
    ax.set_title('MSA Coverage per Position')
    ax.set_ylim(0, 105)
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Sequence Identity Distribution
    ax = axes[0, 1]
    if len(sequences) > 1:
        # Calculate sequence identity to query
        query_seq = msa_array[0]
        identities = []
        for seq in msa_array[1:]:
            identity = np.mean(seq == query_seq) * 100
            identities.append(identity)
        ax.hist(identities, bins=20, color='coral', alpha=0.7, edgecolor='black')
        ax.set_xlabel('Sequence Identity to Query (%)')
        ax.set_ylabel('Count')
        ax.set_title('Sequence Identity Distribution')
    else:
        ax.text(0.5, 0.5, 'Single sequence\n(no homologs)', 
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title('Sequence Identity Distribution')
    
    # Plot 3: Coevolution Matrix
    ax = axes[1, 0]
    coev_lim = np.max(np.abs(coev_matrix)) or 1.0  # avoid vmin == vmax for an all-zero matrix
    im = ax.imshow(coev_matrix, cmap='RdBu_r', vmin=-coev_lim, vmax=coev_lim)
    ax.set_xlabel('Position')
    ax.set_ylabel('Position')
    ax.set_title('Coevolution Matrix')
    plt.colorbar(im, ax=ax, label='Coevolution Score')
    
    # Plot 4: Top Coevolving Pairs
    ax = axes[1, 1]
    # Get top coevolving pairs
    upper_tri = np.triu_indices_from(coev_matrix, k=6)  # At least 6 residues apart
    coev_values = coev_matrix[upper_tri]
    top_indices = np.argsort(coev_values)[-20:]  # Top 20 pairs
    
    top_pairs = [(upper_tri[0][i], upper_tri[1][i], coev_values[i]) 
                 for i in top_indices]
    top_pairs.sort(key=lambda x: x[2], reverse=True)
    
    if top_pairs and top_pairs[0][2] > 0:  # skip all-zero matrices (e.g. single-sequence mode)
        positions = list(range(len(top_pairs)))
        values = [p[2] for p in top_pairs]
        ax.barh(positions, values, color='green', alpha=0.7)
        ax.set_yticks(positions[:10])  # Show first 10
        ax.set_yticklabels([f"{p[0]+1}-{p[1]+1}" for p in top_pairs[:10]], fontsize=8)
        ax.set_xlabel('Coevolution Score')
        ax.set_title('Top Coevolving Residue Pairs')
        ax.invert_yaxis()
    else:
        ax.text(0.5, 0.5, 'No significant\ncoevolution', 
                ha='center', va='center', transform=ax.transAxes)
        ax.set_title('Top Coevolving Residue Pairs')
    
    plt.tight_layout()
    return fig, msa_array, coev_matrix

# Analyze MSA for i89
print("Analyzing MSA and coevolution for i89...")
fig, msa_data, coev = visualize_msa_and_coevolution(I89_SEQUENCE, msa_mode="mmseqs2")
plt.show()

print("\nKey Observations:")
print("   - High coevolution scores suggest residue pairs that are structurally or functionally coupled")
print("   - MSA depth affects coevolution signal strength")
print("   - Coverage gaps may indicate flexible/disordered regions")
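
The statistics printed above use the raw sequence count as Neff. A common convention, not used in this notebook, is to down-weight each sequence by the number of sequences that fall within some identity threshold of it. The sketch below assumes an integer-encoded MSA of shape (N, L) like `msa_array`; the function name `effective_sequence_count` and the 80% threshold are illustrative assumptions, not part of ColabDesign.

```python
import numpy as np

def effective_sequence_count(msa, identity_threshold=0.8):
    """Identity-reweighted Neff for an integer-encoded MSA of shape (N, L).

    Hypothetical helper for illustration only.
    """
    msa = np.asarray(msa)
    # Pairwise fractional identity between all sequences, shape (N, N).
    # Note: this builds an (N, N, L) boolean array, so it suits small MSAs only.
    identity = (msa[:, None, :] == msa[None, :, :]).mean(axis=-1)
    # Each sequence is down-weighted by the size of its similarity cluster
    cluster_sizes = (identity >= identity_threshold).sum(axis=1)
    weights = 1.0 / cluster_sizes
    return float(weights.sum())

# Toy alignment: rows 0 and 1 are identical, row 2 is unrelated
toy_msa = np.array([
    [0, 1, 2, 3, 4],
    [0, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
])
neff = effective_sequence_count(toy_msa)
print(f"Raw depth: {len(toy_msa)}, Neff: {neff:.1f}")  # Raw depth: 3, Neff: 2.0
```

Redundant sequences collapse toward a single effective sequence, so Neff is never larger than the raw depth and can be much smaller for MSAs dominated by close homologs.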


In [None]:
#@title LogMD Trajectory Visualization
#@markdown Create an interactive trajectory viewer for ensemble structures

def create_logmd_trajectory(structures, labels=None, reference_structure=None, 
                           project_name="", max_structures=20):
    """
    Create a LogMD trajectory from ensemble structures.
    Adapted from AlphaMask's create_superimposed_logmd_trajectory_sdk.
    
    Args:
        structures: List of structure arrays from predictions
        labels: Optional labels for each structure
        reference_structure: Optional reference for superposition
        project_name: LogMD project name (empty for anonymous)
        max_structures: Maximum number of structures to include
        
    Returns:
        LogMD instance or URL string
    """
    import tempfile

    # Only LogMD is an optional dependency here; the Biopython imports were unused
    try:
        import logmd
    except ImportError:
        print("LogMD not installed. Install with: pip install logmd")
        return None
    
    # Limit number of structures
    structures = structures[:max_structures]
    if labels:
        labels = labels[:max_structures]
    
    print(f"Creating LogMD trajectory with {len(structures)} structures...")
    
    # Initialize LogMD
    logmd_instance = logmd.LogMD(project_name=project_name)
    
    # If we have a reference, superimpose all structures onto it
    if reference_structure is not None:
        print("Superimposing structures to reference...")
        ref_ca = reference_structure[:, 1, :]  # CA atoms
        ref_center = ref_ca.mean(axis=0)
        
        aligned_structures = []
        for struct in structures:
            struct_ca = struct[:, 1, :]  # CA atoms
            struct_center = struct_ca.mean(axis=0)
            
            # Compute ONE rotation from the centered CA traces and apply it to
            # every atom. Assumes _np_kabsch returns a 3x3 rotation matrix for
            # centered inputs (as in ColabDesign), so mobile @ rotation fits the target.
            rotation = _np_kabsch(struct_ca - struct_center, ref_ca - ref_center,
                                  return_v=False, use_jax=False)
            aligned_structures.append((struct - struct_center) @ rotation + ref_center)
        
        structures = aligned_structures
    
    # Calculate RMSD values if reference provided
    rmsd_values = []
    if reference_structure is not None:
        ref_ca = reference_structure[:, 1, :]
        for struct in structures:
            struct_ca = struct[:, 1, :]
            rmsd = _np_rmsd(struct_ca, ref_ca, use_jax=False)
            rmsd_values.append(rmsd)
    
    # Sort by RMSD if available
    if rmsd_values:
        sorted_indices = np.argsort(rmsd_values)
        structures = [structures[i] for i in sorted_indices]
        if labels:
            labels = [labels[i] for i in sorted_indices]
        rmsd_values = [rmsd_values[i] for i in sorted_indices]
    
    # Convert structures to PDB format and add to LogMD
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, struct in enumerate(structures):
            # Create a simple PDB string
            pdb_lines = []
            pdb_lines.append("MODEL     1")
            
            for res_idx in range(struct.shape[0]):
                ca_coord = struct[res_idx, 1, :]  # CA atom
                pdb_lines.append(
                    f"ATOM  {res_idx+1:5d}  CA  ALA A{res_idx+1:4d}    "
                    f"{ca_coord[0]:8.3f}{ca_coord[1]:8.3f}{ca_coord[2]:8.3f}"
                    f"  1.00  0.00           C  "
                )
            pdb_lines.append("ENDMDL")
            
            pdb_content = "\n".join(pdb_lines)
            
            # Add to LogMD with metadata
            metadata = {
                "frame": i,
                "label": labels[i] if labels else f"Structure_{i}"
            }
            if rmsd_values:
                metadata["rmsd"] = f"{rmsd_values[i]:.2f}"
            
            logmd_instance.add_pdb_string(pdb_content, metadata=metadata)
    
    # Get the URL
    url = logmd_instance.url
    
    print(f"LogMD trajectory created!")
    print(f"URL: {url}")
    print(f"\nViewer Features:")
    print("   - Animation controls for navigating frames")
    print("   - Zoom, pan, and rotate in 3D")
    print("   - Frame metadata display")
    if rmsd_values:
        print("   - Structures sorted by RMSD to reference")
    
    return logmd_instance

# Create LogMD trajectory from our ensemble
if 'ensemble_structures' in globals() and len(ensemble_structures) > 0:
    print("Creating LogMD visualization of ensemble...")
    
    # Use State 1 as reference for alignment
    logmd_traj = create_logmd_trajectory(
        structures=ensemble_structures[:10],  # First 10 structures
        labels=ensemble_labels[:10] if 'ensemble_labels' in globals() else None,
        reference_structure=result_with_msa['structure'] if 'result_with_msa' in globals() else None,
        project_name="i89_ensemble",
        max_structures=10
    )
    
    if logmd_traj:
        # Display in notebook if possible
        try:
            from IPython.display import IFrame
            display(IFrame(src=logmd_traj.url, width=800, height=600))
        except Exception:
            print(f"Open this URL in a browser to view: {logmd_traj.url}")
else:
    print("Note: Run the ensemble generation section first to create structures for visualization")
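
The superposition step in the cell above relies on the Kabsch algorithm: find the single rotation that best maps one centered CA trace onto another, then apply that same transform to every atom. The following is a minimal NumPy-only sketch of that idea, independent of ColabDesign's `_np_kabsch`. The helper names `kabsch_rotation` and `superimpose_all_atoms` are hypothetical, and the `(L, atoms, 3)` layout with CA at atom index 1 follows the convention used in this notebook.

```python
import numpy as np

def kabsch_rotation(mobile, target):
    """Rotation R (3x3) such that mobile @ R best fits target.

    Both inputs must be centered and have shape (N, 3).
    """
    covariance = mobile.T @ target
    u, _, vt = np.linalg.svd(covariance)
    # Flip the last singular direction if needed to avoid an improper rotation
    u[:, -1] *= np.sign(np.linalg.det(u @ vt))
    return u @ vt

def superimpose_all_atoms(struct, reference, ca_index=1):
    """Align struct (L, A, 3) onto reference using one CA-derived transform."""
    mobile_ca = struct[:, ca_index, :]
    target_ca = reference[:, ca_index, :]
    mobile_center = mobile_ca.mean(axis=0)
    target_center = target_ca.mean(axis=0)
    rotation = kabsch_rotation(mobile_ca - mobile_center, target_ca - target_center)
    # The same rotation and translation are applied to every atom
    return (struct - mobile_center) @ rotation + target_center

# Synthetic check: rotate and translate a random structure, then recover it
rng = np.random.default_rng(0)
reference = rng.normal(size=(10, 3, 3))  # 10 residues, 3 atoms each
theta = np.deg2rad(40.0)
rot_z = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
moved = reference @ rot_z + np.array([5.0, -2.0, 1.0])
aligned = superimpose_all_atoms(moved, reference)
max_dev = float(np.abs(aligned - reference).max())
print(f"Max deviation after superposition: {max_dev:.2e}")
```

Fitting each atom type separately would give every atom type its own rotation and distort the backbone geometry, which is why the transform is derived once from the CA trace.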


In [None]:
#@title Comparative MSA Analysis: With vs Without MSA
#@markdown Compare coevolution patterns between different MSA conditions

def compare_msa_conditions():
    """
    Compare MSA and coevolution between with/without MSA conditions.
    Adapted from AlphaMask's visualize_experiment_conditions.
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    
    conditions = ["With Full MSA", "Without MSA (Single Seq)"]
    msa_modes = ["mmseqs2", "single_sequence"]
    
    coev_matrices = []
    
    for idx, (condition, mode) in enumerate(zip(conditions, msa_modes)):
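        # Generate MSA for this condition. The helper builds its own 2x2 figure;
        # close it so it is not displayed and does not become the current figure.
        helper_fig, msa_array, coev_matrix = visualize_msa_and_coevolution(
            I89_SEQUENCE, msa_mode=mode, jobname=f"temp_{mode}"
        )
        plt.close(helper_fig)
        coev_matrices.append(coev_matrix)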
        
        # Plot coevolution matrix
        ax = axes[idx, 0]
        im = ax.imshow(coev_matrix, cmap='RdBu_r', 
                       vmin=-0.5, vmax=0.5)
        ax.set_title(f'{condition}\nCoevolution Matrix')
        ax.set_xlabel('Position')
        ax.set_ylabel('Position')
        fig.colorbar(im, ax=ax)
        
        # Plot contact map (top coevolving pairs)
        ax = axes[idx, 1]
        contact_threshold = np.percentile(coev_matrix[np.triu_indices_from(coev_matrix, k=6)], 95)
        contacts = coev_matrix > contact_threshold
        ax.imshow(contacts, cmap='Greys', vmin=0, vmax=1)
        ax.set_title(f'Predicted Contacts\n(Top 5%)')
        ax.set_xlabel('Position')
        ax.set_ylabel('Position')
        
    # Plot difference
    ax = axes[0, 2]
    diff_matrix = coev_matrices[0] - coev_matrices[1]
    im = ax.imshow(diff_matrix, cmap='PiYG', 
                   vmin=-np.max(np.abs(diff_matrix)), 
                   vmax=np.max(np.abs(diff_matrix)))
    ax.set_title('Difference\n(With MSA - Without MSA)')
    ax.set_xlabel('Position')
    ax.set_ylabel('Position')
    fig.colorbar(im, ax=ax)
    
    # Summary statistics
    ax = axes[1, 2]
    ax.axis('off')
    
    stats_text = "Coevolution Statistics:\n\n"
    for idx, condition in enumerate(conditions):
        coev = coev_matrices[idx]
        upper_tri = np.triu_indices_from(coev, k=6)
        values = coev[upper_tri]
        
        stats_text += f"{condition}:\n"
        stats_text += f"  Mean: {np.mean(values):.4f}\n"
        stats_text += f"  Max: {np.max(values):.4f}\n"
        stats_text += f"  Std: {np.std(values):.4f}\n\n"
    
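    # Add difference statistics (same |i - j| >= 6 mask, computed explicitly
    # instead of reusing the loop variable)
    diff_values = diff_matrix[np.triu_indices_from(diff_matrix, k=6)]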
    stats_text += "Difference:\n"
    stats_text += f"  Mean: {np.mean(diff_values):.4f}\n"
    stats_text += f"  Max: {np.max(np.abs(diff_values)):.4f}\n"
    
    ax.text(0.1, 0.5, stats_text, transform=ax.transAxes, 
            fontsize=10, verticalalignment='center', family='monospace')
    
    # Act on the comparison figure explicitly rather than on the current figure
    fig.suptitle('MSA Effect on Coevolution Analysis', fontsize=14, y=1.02)
    fig.tight_layout()
    return fig

# Run comparative analysis
print("Comparing coevolution patterns with and without MSA...")
comparison_fig = compare_msa_conditions()
plt.show()

print("\nKey Insights:")
print("   - With MSA: Strong coevolution signals from evolutionary constraints")
print("   - Without MSA: No coevolution signal (a single sequence gives an all-zero matrix)")
print("   - Difference therefore equals the with-MSA signal and highlights evolutionarily coupled positions")
print("   - These patterns influence AlphaFold2's structure predictions")
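
One way to make the link between coevolution and structure concrete is to ask what fraction of the top-scoring pairs are actually close in space in a known structure. The sketch below is an illustration under stated assumptions: `contact_precision` is a hypothetical helper, the 8 Å CA-CA cutoff and the top-5% selection are illustrative choices, and the hairpin coordinates are synthetic rather than taken from i89.

```python
import numpy as np

def contact_precision(coev_matrix, ca_coords, top_fraction=0.05,
                      min_separation=6, distance_cutoff=8.0):
    """Fraction of top-scoring pairs whose CA atoms lie within distance_cutoff.

    Hypothetical helper for illustration only.
    """
    i_idx, j_idx = np.triu_indices_from(coev_matrix, k=min_separation)
    scores = coev_matrix[i_idx, j_idx]
    n_top = max(1, int(round(top_fraction * len(scores))))
    top = np.argsort(scores)[-n_top:]
    # CA-CA distances for the selected pairs only
    dists = np.linalg.norm(ca_coords[i_idx[top]] - ca_coords[j_idx[top]], axis=-1)
    return float(np.mean(dists < distance_cutoff))

# Synthetic 20-residue hairpin: two antiparallel strands 5 A apart
n_res = 20
ca_coords = np.zeros((n_res, 3))
for i in range(10):
    ca_coords[i] = (3.8 * i, 0.0, 0.0)
for i in range(10, n_res):
    ca_coords[i] = (3.8 * (19 - i), 5.0, 0.0)

dist_matrix = np.linalg.norm(ca_coords[:, None, :] - ca_coords[None, :, :], axis=-1)

# An idealized score that ranks close pairs highest, and its inverse
precision_good = contact_precision(-dist_matrix, ca_coords)
precision_bad = contact_precision(dist_matrix, ca_coords)
print(f"Informative scores:      precision = {precision_good:.2f}")  # 1.00
print(f"Anti-informative scores: precision = {precision_bad:.2f}")   # 0.00
```

With real data, `coev_matrix` would be the matrix returned by the coevolution analysis, and `ca_coords` the CA coordinates of a reference structure of the same length.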
