# Module 1: SE(3) Equivariance and Geometric Deep Learning

**📍 Notebook 4 of 8**

## 💻 GPU Requirements
**✅ No GPU needed!** All examples run on CPU.

---

## 🎯 Learning Objectives

By the end of this notebook, you will:

1. Understand what symmetries are and why they matter
2. Know the difference between invariance and equivariance
3. Understand the SE(3) group (rotations + translations in 3D)
4. Learn how to build equivariant neural networks
5. Implement simple equivariant operations
6. Understand why SE(3) equivariance is crucial for protein design

## 🌟 Why Geometric Deep Learning?

**Problem**: Standard neural networks have no built-in notion of 3D geometry!

If we rotate a protein:
- ❌ Standard CNN: Different prediction
- ✅ SE(3) Equivariant: Rotated prediction (correct!)

**This is CRITICAL for proteins**: a structure is only defined up to rotation and translation.

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
np.set_printoptions(precision=3, suppress=True)
np.random.seed(42)

print("✅ Libraries loaded!")

## 🔄 Symmetries: Invariance vs. Equivariance

### Definitions:

**Invariance**: Output doesn't change when input is transformed
```
f(g·x) = f(x)
```
Example: Classifying protein type - rotation shouldn't change the class

**Equivariance**: Output transforms in the same way as input
```
f(g·x) = g·f(x)
```
Example: Predicting atom positions - if input is rotated, output should be rotated the same way

### For Proteins:

- **Invariant**: Energy, stability, binding affinity
- **Equivariant**: Atom positions, backbone coordinates, structural features

**RFDiffusion needs equivariance** - if we rotate input noise, output structure should also be rotated!

In [None]:
# Demonstrate invariance vs equivariance with simple 2D example
def simple_invariant_function(points):
    """Rotation-invariant: returns a scalar (mean distance of the points from the origin)."""
    return np.mean(np.linalg.norm(points, axis=1))

def simple_equivariant_function(points):
    """Equivariant: returns the centroid, which rotates along with the points."""
    return np.mean(points, axis=0)

# Create simple 2D points
points = np.array([[1, 1], [2, 1], [1.5, 2]])

# Rotation matrix (90 degrees)
theta = np.pi/2
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])

# Rotate points
points_rotated = points @ R.T

# Test invariance
inv_original = simple_invariant_function(points)
inv_rotated = simple_invariant_function(points_rotated)

# Test equivariance
eq_original = simple_equivariant_function(points)
eq_rotated = simple_equivariant_function(points_rotated)
eq_original_then_rotated = eq_original @ R.T

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original
ax1 = axes[0]
ax1.scatter(points[:, 0], points[:, 1], s=100, c='blue', label='Original points', zorder=3)
ax1.scatter(*eq_original, s=200, c='red', marker='*', label='Center (equivariant)', zorder=4)
ax1.set_xlim(-3, 3)
ax1.set_ylim(-3, 3)
ax1.set_aspect('equal')
ax1.grid(True, alpha=0.3)
ax1.legend()
ax1.set_title(f'Original\nInvariant value: {inv_original:.3f}', fontweight='bold')

# Rotated
ax2 = axes[1]
ax2.scatter(points_rotated[:, 0], points_rotated[:, 1], s=100, c='green', label='Rotated points', zorder=3)
ax2.scatter(*eq_rotated, s=200, c='red', marker='*', label='Center (equivariant)', zorder=4)
ax2.set_xlim(-3, 3)
ax2.set_ylim(-3, 3)
ax2.set_aspect('equal')
ax2.grid(True, alpha=0.3)
ax2.legend()
ax2.set_title(f'Rotated 90°\nInvariant value: {inv_rotated:.3f}', fontweight='bold')

# Comparison
ax3 = axes[2]
ax3.text(0.5, 0.8, 'Invariance Check:', ha='center', fontsize=13, fontweight='bold', transform=ax3.transAxes)
ax3.text(0.5, 0.65, f'Original: {inv_original:.3f}', ha='center', fontsize=11, transform=ax3.transAxes)
ax3.text(0.5, 0.55, f'Rotated:  {inv_rotated:.3f}', ha='center', fontsize=11, transform=ax3.transAxes)
ax3.text(0.5, 0.45, f'Same? {np.isclose(inv_original, inv_rotated)} ✓', ha='center', fontsize=11, 
         color='green', fontweight='bold', transform=ax3.transAxes)

ax3.text(0.5, 0.25, 'Equivariance Check:', ha='center', fontsize=13, fontweight='bold', transform=ax3.transAxes)
ax3.text(0.5, 0.10, 'f(R(x)) = R(f(x))?', ha='center', fontsize=11, transform=ax3.transAxes)
ax3.text(0.5, 0.0, f'{np.allclose(eq_rotated, eq_original_then_rotated)} ✓', ha='center', fontsize=11,
         color='green', fontweight='bold', transform=ax3.transAxes)
ax3.axis('off')

plt.tight_layout()
plt.show()

print("📌 Key Points:")
print("   - Invariant function: Same output after rotation")
print("   - Equivariant function: Output rotates with input")
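
The check above only rotates the points about the origin. As a minimal sketch (not part of the original notebook), the snippet below applies a rotation followed by a translation. It assumes two illustrative helpers, `mean_norm` (the same quantity as `simple_invariant_function`) and a hypothetical `mean_pairwise_distance`. The first is only rotation-invariant, while the second is invariant to the full rigid motion.

```python
import numpy as np

def mean_norm(points):
    """Mean distance from the origin: invariant to rotations about the origin only."""
    return np.mean(np.linalg.norm(points, axis=1))

def mean_pairwise_distance(points):
    """Mean distance over all ordered pairs: invariant to rotation and translation."""
    diffs = points[:, None, :] - points[None, :, :]   # [N, N, D] pairwise differences
    return np.mean(np.linalg.norm(diffs, axis=-1))

points = np.array([[1.0, 1.0], [2.0, 1.0], [1.5, 2.0]])
theta = np.pi / 2
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
t = np.array([3.0, -1.0])          # illustrative translation
moved = points @ R.T + t           # rotate, then translate

norm_is_invariant = bool(np.isclose(mean_norm(points), mean_norm(moved)))
pairwise_is_invariant = bool(np.isclose(mean_pairwise_distance(points),
                                        mean_pairwise_distance(moved)))
# The centroid is equivariant: it undergoes the same rotation and translation.
centroid_is_equivariant = bool(np.allclose(moved.mean(axis=0),
                                           points.mean(axis=0) @ R.T + t))

print("mean norm invariant under rotation + translation:", norm_is_invariant)
print("mean pairwise distance invariant:", pairwise_is_invariant)
print("centroid equivariant:", centroid_is_equivariant)
```

Features built from differences between points, rather than from absolute coordinates, are the ones that survive translation.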

## 🎲 The SE(3) Group

**SE(3)** = Special Euclidean group in 3D

### What is it?

The group of all **rigid body transformations** in 3D space:
- **Rotations** (SO(3) - Special Orthogonal group)
- **Translations** (ℝ³)

### Elements:

Any element g ∈ SE(3) can be written as:
```
g = (R, t)
```
Where:
- R: 3×3 rotation matrix (orthogonal, det(R)=1)
- t: 3×1 translation vector

### Group Operations:

**Composition**:
```
g₁ ∘ g₂ = (R₁R₂, R₁t₂ + t₁)
```

**Inverse**:
```
g⁻¹ = (R^T, -R^T t)
```

**Identity**:
```
e = (I, 0)
```

### Action on Points:

Given point x ∈ ℝ³:
```
g·x = Rx + t
```
```
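
One way to sanity-check these formulas is to pack (R, t) into a 4×4 homogeneous matrix, so that composition becomes an ordinary matrix product. The sketch below is illustrative only. `to_homogeneous` is a hypothetical helper, and the rotation angles and translations are arbitrary.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def to_homogeneous(R, t):
    """Pack a rotation matrix R (3x3) and translation t (3,) into a 4x4 matrix."""
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = t
    return T

R1 = Rotation.from_euler('xyz', [30, 45, 60], degrees=True).as_matrix()
t1 = np.array([2.0, 1.0, 3.0])
R2 = Rotation.from_euler('z', 90, degrees=True).as_matrix()
t2 = np.array([0.0, 0.0, 1.0])
T1, T2 = to_homogeneous(R1, t1), to_homogeneous(R2, t2)

# Composition: the matrix product reproduces (R1 R2, R1 t2 + t1).
T12 = T1 @ T2
composition_ok = bool(np.allclose(T12[:3, :3], R1 @ R2)
                      and np.allclose(T12[:3, 3], R1 @ t2 + t1))

# Inverse: (R^T, -R^T t) undoes the transformation.
T1_inv = to_homogeneous(R1.T, -R1.T @ t1)
inverse_ok = bool(np.allclose(T1 @ T1_inv, np.eye(4)))

# Action on a point: append 1 to x, multiply, drop the last entry.
x = np.array([1.0, 0.0, 0.0])
action_ok = bool(np.allclose((T1 @ np.append(x, 1.0))[:3], R1 @ x + t1))

print(composition_ok, inverse_ok, action_ok)
```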

In [None]:
# Implement SE(3) transformations
class SE3Transform:
    """SE(3) transformation: rotation + translation."""
    
    def __init__(self, R, t):
        """
        Args:
            R: 3x3 rotation matrix
            t: translation vector of shape (3,)
        """
        self.R = R
        self.t = t
    
    def apply(self, x):
        """Apply transformation to point(s)."""
        if x.ndim == 1:
            return self.R @ x + self.t
        else:  # Multiple points
            return (self.R @ x.T).T + self.t
    
    def compose(self, other):
        """Compose two transformations: self ∘ other (apply `other` first, then `self`)."""
        R_new = self.R @ other.R
        t_new = self.R @ other.t + self.t
        return SE3Transform(R_new, t_new)
    
    def inverse(self):
        """Inverse transformation."""
        R_inv = self.R.T
        t_inv = -R_inv @ self.t
        return SE3Transform(R_inv, t_inv)

# Create a simple 3D protein-like structure
def make_helix_3d(n_points=20):
    """Generate helix in 3D."""
    t = np.linspace(0, 4*np.pi, n_points)
    x = np.cos(t)
    y = np.sin(t)
    z = t / 2
    return np.stack([x, y, z], axis=1)

protein = make_helix_3d()

# Create an example SE(3) transformation (fixed angles and offset, not random)
rot = Rotation.from_euler('xyz', [30, 45, 60], degrees=True).as_matrix()
trans = np.array([2, 1, 3])
g = SE3Transform(rot, trans)

# Transform protein
protein_transformed = g.apply(protein)

# Visualize
fig = plt.figure(figsize=(14, 6))

# Original
ax1 = fig.add_subplot(121, projection='3d')
ax1.plot(protein[:, 0], protein[:, 1], protein[:, 2], 'o-', linewidth=2, markersize=6, label='Original')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
ax1.set_zlabel('Z')
ax1.set_title('Original Structure', fontsize=13, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Transformed
ax2 = fig.add_subplot(122, projection='3d')
ax2.plot(protein_transformed[:, 0], protein_transformed[:, 1], protein_transformed[:, 2], 
         'o-', linewidth=2, markersize=6, color='orange', label='Transformed')
ax2.set_xlabel('X')
ax2.set_ylabel('Y')
ax2.set_zlabel('Z')
ax2.set_title('After SE(3) Transform\n(Rotation + Translation)', fontsize=13, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Test group properties
print("🔬 Testing SE(3) Group Properties:")
print("\n1. Composition:")
g2 = SE3Transform(Rotation.from_euler('z', 90, degrees=True).as_matrix(), np.array([0, 0, 1]))
g_composed = g.compose(g2)
# Applying the composed transform must match applying g2 first, then g
composition_test = g_composed.apply(protein)
print(f"   (g ∘ g2)·x = g·(g2·x): {np.allclose(composition_test, g.apply(g2.apply(protein)))} ✓")

print("\n2. Inverse:")
g_inv = g.inverse()
identity_test = g.compose(g_inv).apply(protein)
print(f"   g ∘ g⁻¹ ≈ identity: {np.allclose(identity_test, protein)} ✓")

print("\n3. Identity:")
identity = SE3Transform(np.eye(3), np.zeros(3))
identity_test2 = identity.apply(protein)
print(f"   e·x = x: {np.allclose(identity_test2, protein)} ✓")
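
Two further group facts are worth checking numerically: composition is associative, but it is not commutative, so the order of rigid motions matters. The sketch below is self-contained and represents each transformation as a plain `(R, t)` tuple. The helpers `apply` and `compose` are illustrative stand-ins, not the notebook's `SE3Transform` class.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def apply(g, x):
    """Apply g = (R, t) to an [N, 3] array of points."""
    R, t = g
    return x @ R.T + t

def compose(g1, g2):
    """Return g1 ∘ g2, i.e. apply g2 first and then g1."""
    (R1, t1), (R2, t2) = g1, g2
    return (R1 @ R2, R1 @ t2 + t1)

g1 = (Rotation.from_euler('xyz', [30, 45, 60], degrees=True).as_matrix(), np.array([2.0, 1.0, 3.0]))
g2 = (Rotation.from_euler('z', 90, degrees=True).as_matrix(), np.array([0.0, 0.0, 1.0]))
g3 = (Rotation.from_euler('x', 15, degrees=True).as_matrix(), np.array([-1.0, 0.5, 0.0]))

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))   # arbitrary test points

# Associativity: (g1 ∘ g2) ∘ g3 and g1 ∘ (g2 ∘ g3) act identically.
left = compose(compose(g1, g2), g3)
right = compose(g1, compose(g2, g3))
associative = bool(np.allclose(apply(left, x), apply(right, x)))

# Non-commutativity: g1 ∘ g2 and g2 ∘ g1 generally differ.
commutes = bool(np.allclose(apply(compose(g1, g2), x), apply(compose(g2, g1), x)))

print("associative:", associative)
print("commutative:", commutes)
```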

## 🧠 Building Equivariant Operations

How do we build neural networks that are SE(3) equivariant?

### Key Principles:

1. **Use distances and angles** - these are invariant!
2. **Process in local frames** - transform to/from local coordinates
3. **Equivariant message passing** - aggregate information respecting geometry

### Simple Equivariant Operations:

**Invariant Features** (scalars):
- Distances: ||x_i - x_j||
- Angles: cos θ
- Dihedral angles

**Equivariant Features** (vectors that rotate with the input):
- Unit direction vectors: (x_i - x_j) / ||x_i - x_j||
- Cross products
- Transformed vectors
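
The features listed above can be checked numerically. The following is a minimal sketch with hypothetical helpers `bond_angle`, `dihedral`, and `unit`. The four points and the transformation are arbitrary, and the dihedral uses one common sign convention. Scalars should be unchanged by a rigid motion, while the unit direction should be rotated by the same R.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def bond_angle(a, b, c):
    """Angle (radians) at b formed by the points a-b-c."""
    u, v = a - b, c - b
    cos_theta = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians) about the p1-p2 axis."""
    b0, b1, b2 = p1 - p0, p2 - p1, p3 - p2
    n1, n2 = np.cross(b0, b1), np.cross(b1, b2)      # normals of the two planes
    m = np.cross(n1, b1 / np.linalg.norm(b1))
    return np.arctan2(np.dot(m, n2), np.dot(n1, n2))

def unit(v):
    """Normalize a vector to length 1."""
    return v / np.linalg.norm(v)

pts = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.0, 1.4, 0.0], [3.3, 1.6, 1.0]])
R = Rotation.from_euler('xyz', [20, 30, 40], degrees=True).as_matrix()
t = np.array([1.0, 2.0, 3.0])
moved = pts @ R.T + t

# Invariant scalars: identical before and after the rigid motion.
dist_ok = bool(np.isclose(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(moved[0] - moved[1])))
angle_ok = bool(np.isclose(bond_angle(*pts[:3]), bond_angle(*moved[:3])))
dihedral_ok = bool(np.isclose(dihedral(*pts), dihedral(*moved)))

# Equivariant vector: the direction rotates with R (the translation cancels).
direction_ok = bool(np.allclose(unit(moved[1] - moved[0]), R @ unit(pts[1] - pts[0])))

print(dist_ok, angle_ok, dihedral_ok, direction_ok)
```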

### Building Blocks:

```python
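# Equivariant operation example
import numpy as np

def equivariant_layer(x, edges, weight_function=lambda d: np.exp(-d)):
    # x: [N, 3] coordinates
    # edges: [E, 2] directed edge indices (node i receives a message from node j)
    # weight_function: any scalar function of distance (exp(-d) is only an illustrative default)
    out = np.zeros_like(x, dtype=float)
    for i, j in edges:
        # Invariant: distance
        dist = np.linalg.norm(x[i] - x[j])

        # Equivariant: unit direction from i to j (rotates with the input)
        direction = (x[j] - x[i]) / (dist + 1e-8)

        # Message: invariant scalar weight * equivariant direction
        out[i] += weight_function(dist) * direction

    # Summing messages per node keeps the output equivariant under rotation
    return out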
```

In [None]:
# Implement a simple SE(3) equivariant layer
def simple_equivariant_layer(coords, k=3):
    """
    Simple SE(3) equivariant layer using k-nearest neighbors.
    
    Args:
        coords: [N, 3] point coordinates
        k: number of nearest neighbors
    
    Returns:
        new_coords: [N, 3] updated coordinates (equivariant)
    """
    N = len(coords)
    new_coords = np.zeros_like(coords)
    
    for i in range(N):
        # Find k nearest neighbors
        distances = np.linalg.norm(coords - coords[i], axis=1)
        neighbors = np.argsort(distances)[1:k+1]  # Exclude self
        
        # Aggregate messages from neighbors
        for j in neighbors:
            # Invariant feature: distance
            dist = distances[j]
            
            # Equivariant feature: normalized direction
            direction = (coords[j] - coords[i]) / (dist + 1e-8)
            
            # Weight based on distance (closer = stronger)
            weight = np.exp(-dist)
            
            # Equivariant update
            new_coords[i] += weight * direction
    
    # Scale the aggregated update (to keep steps small) and add it to the input (residual connection)
    new_coords = new_coords * 0.1 + coords
    
    return new_coords

# Test equivariance
coords_original = make_helix_3d(10)
coords_updated = simple_equivariant_layer(coords_original)

# Transform and update
rot_test = Rotation.from_euler('xyz', [20, 30, 40], degrees=True).as_matrix()
trans_test = np.array([1, 2, 3])
g_test = SE3Transform(rot_test, trans_test)

coords_transformed = g_test.apply(coords_original)
coords_transformed_updated = simple_equivariant_layer(coords_transformed)

# Check: f(g·x) should equal g·f(x)
coords_updated_then_transformed = g_test.apply(coords_updated)

# Visualize
fig = plt.figure(figsize=(15, 5))

ax1 = fig.add_subplot(131, projection='3d')
ax1.plot(coords_original[:, 0], coords_original[:, 1], coords_original[:, 2], 'o-', label='Input')
ax1.plot(coords_updated[:, 0], coords_updated[:, 1], coords_updated[:, 2], 's-', label='f(x)', alpha=0.7)
ax1.legend()
ax1.set_title('Original Path', fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2 = fig.add_subplot(132, projection='3d')
ax2.plot(coords_transformed[:, 0], coords_transformed[:, 1], coords_transformed[:, 2], 'o-', label='g·x')
ax2.plot(coords_transformed_updated[:, 0], coords_transformed_updated[:, 1], coords_transformed_updated[:, 2], 
         's-', label='f(g·x)', alpha=0.7)
ax2.legend()
ax2.set_title('Transform Then Apply f', fontweight='bold')
ax2.grid(True, alpha=0.3)

ax3 = fig.add_subplot(133, projection='3d')
ax3.plot(coords_updated[:, 0], coords_updated[:, 1], coords_updated[:, 2], 'o-', label='f(x)', alpha=0.5)
ax3.plot(coords_updated_then_transformed[:, 0], coords_updated_then_transformed[:, 1], 
         coords_updated_then_transformed[:, 2], 's-', label='g·f(x)', alpha=0.7)
ax3.legend()
ax3.set_title('Apply f Then Transform', fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

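# Check equivariance
# Note: k-NN selection is a discrete choice. If two candidates tie for the k-th nearest
# neighbor, floating-point noise can select a different one after the transform,
# which shows up here as a larger error.
error = np.linalg.norm(coords_transformed_updated - coords_updated_then_transformed)
print("🔬 Equivariance Test:")
print(f"   ||f(g·x) - g·f(x)|| = {error:.2e}")
print(f"   Equivariant? {error < 1e-10} ✓" if error < 1e-10 else f"   Approximately equivariant (error: {error:.2e})")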
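
Because k-nearest-neighbor selection is a discrete choice, ties between equally distant neighbors (common on a regular helix) can make the test above sensitive to floating-point noise. A minimal sketch of one alternative, which is an assumption rather than part of the original notebook: a hypothetical `smooth_equivariant_layer` that sums over all pairs with the same `exp(-d)` weighting, so no neighbor selection is involved.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def smooth_equivariant_layer(coords, step=0.1):
    """All-pairs variant: every point receives a distance-weighted message from every other point."""
    diffs = coords[None, :, :] - coords[:, None, :]    # diffs[i, j] = x_j - x_i
    dists = np.linalg.norm(diffs, axis=-1)             # [N, N] invariant distances
    np.fill_diagonal(dists, np.inf)                    # removes self-messages (weight exp(-inf) = 0)
    weights = np.exp(-dists)                           # invariant scalar weights
    directions = diffs / dists[..., None]              # unit directions (rotate with the input)
    update = (weights[..., None] * directions).sum(axis=1)
    return coords + step * update                      # residual connection

# Helix test structure, same construction as make_helix_3d(10)
s = np.linspace(0, 4 * np.pi, 10)
coords = np.stack([np.cos(s), np.sin(s), s / 2], axis=1)

R = Rotation.from_euler('xyz', [20, 30, 40], degrees=True).as_matrix()
t = np.array([1.0, 2.0, 3.0])

f_of_gx = smooth_equivariant_layer(coords @ R.T + t)   # transform, then apply the layer
g_of_fx = smooth_equivariant_layer(coords) @ R.T + t   # apply the layer, then transform
max_error = np.abs(f_of_gx - g_of_fx).max()
print(f"max |f(g·x) - g·f(x)| = {max_error:.2e}")
```

The trade-off is cost: all-pairs messages scale quadratically with the number of points, which is why neighbor lists or cutoffs are used on larger structures.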

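## 🔧 RFDiffusion's Approach: Invariant Point Attention

RFDiffusion represents each residue as a rigid frame, the representation that **Invariant Point Attention (IPA)** from AlphaFold2 operates on. RFDiffusion's own network builds on RoseTTAFold, whose 3D track uses an SE(3)-Transformer rather than IPA, but the local-frame ideas below carry over.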

### Key Ideas:

1. **Work in local frames** - each residue has its own coordinate system
2. **Compute invariant features** - distances, angles in local frames
3. **Attention mechanism** - weight contributions from other residues
4. **Update frames** - predict rotation/translation updates

### Simplified IPA:

```
For each residue i:
    1. Transform all positions to local frame i
    2. Compute invariant features (distances in local frame)
    3. Attention weights = softmax(query · key / sqrt(d))
    4. Aggregate information: weighted sum of values
    5. Predict frame update (ΔR, Δt)
    6. Apply update: new_frame = old_frame ∘ update
```

### Why This Works:

✅ **SE(3) Equivariant**: Features computed in local frames are invariant, and frame updates transform with the input  
✅ **Efficient**: Attention over residue frames, not all atoms  
✅ **Powerful**: Can capture long-range interactions  
✅ **Stable**: Local coordinates do not depend on where the protein sits in space, which helps avoid large absolute values
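
Steps 1 and 2 of the simplified procedure can be illustrated with a short sketch. It rests on stated assumptions: the frame is built by Gram-Schmidt from three consecutive points of a toy helix (real models build frames from backbone N, Cα, C atoms), and `frame_from_three_points` and `to_local` are hypothetical helper names. The point is that coordinates expressed in a structure-derived local frame do not change when the whole structure is moved.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def frame_from_three_points(a, b, c):
    """Right-handed orthonormal frame centred at b, via Gram-Schmidt on (c - b) and (a - b)."""
    e1 = (c - b) / np.linalg.norm(c - b)
    u = (a - b) - np.dot(a - b, e1) * e1      # remove the component along e1
    e2 = u / np.linalg.norm(u)
    e3 = np.cross(e1, e2)
    return np.column_stack([e1, e2, e3]), b   # (rotation R, origin t)

def to_local(R, t, x):
    """Express global points x ([N, 3]) in the frame (R, t): each row becomes R^T (x - t)."""
    return (x - t) @ R

# Toy backbone: helix with 10 points
s = np.linspace(0, 4 * np.pi, 10)
coords = np.stack([np.cos(s), np.sin(s), s / 2], axis=1)

# Arbitrary global rigid motion
Rg = Rotation.from_euler('xyz', [20, 30, 40], degrees=True).as_matrix()
tg = np.array([1.0, 2.0, 3.0])
coords_moved = coords @ Rg.T + tg

i = 4  # residue whose local frame we use
R_before, t_before = frame_from_three_points(coords[i - 1], coords[i], coords[i + 1])
R_after, t_after = frame_from_three_points(coords_moved[i - 1], coords_moved[i], coords_moved[i + 1])

local_before = to_local(R_before, t_before, coords)
local_after = to_local(R_after, t_after, coords_moved)

local_coords_invariant = bool(np.allclose(local_before, local_after))
print("local coordinates unchanged by global motion:", local_coords_invariant)
```

Any function of these local coordinates (distances, attention logits, and so on) inherits the invariance, while the frame itself rotates and translates with the structure.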

In [None]:
# Visualize the concept of local frames
def visualize_local_frames(coords, n_show=5):
    """Visualize local coordinate frames."""
    fig = plt.figure(figsize=(12, 5))
    
    # Global view
    ax1 = fig.add_subplot(121, projection='3d')
    ax1.plot(coords[:, 0], coords[:, 1], coords[:, 2], 'o-', linewidth=2, markersize=8, 
             color='gray', alpha=0.5, label='Backbone')
    
    # Show local frames for a few points
    for i in range(0, len(coords), max(1, len(coords)//n_show)):
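        # Simple frame: z along backbone, x and y perpendicular.
        # Illustrative only: x is chosen using a fixed global reference vector,
        # so these frames do not rotate exactly with the structure.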
        if i < len(coords)-1:
            z_axis = coords[i+1] - coords[i]
            z_axis = z_axis / (np.linalg.norm(z_axis) + 1e-8)
            
            # Arbitrary x axis perpendicular to z
            x_axis = np.cross(z_axis, np.array([0, 0, 1]))
            if np.linalg.norm(x_axis) < 0.1:
                x_axis = np.cross(z_axis, np.array([0, 1, 0]))
            x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-8)
            
            # y axis
            y_axis = np.cross(z_axis, x_axis)
            
            # Draw axes
            origin = coords[i]
            scale = 1.0
            ax1.quiver(*origin, *x_axis, color='red', length=scale, arrow_length_ratio=0.3)
            ax1.quiver(*origin, *y_axis, color='green', length=scale, arrow_length_ratio=0.3)
            ax1.quiver(*origin, *z_axis, color='blue', length=scale, arrow_length_ratio=0.3)
    
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax1.set_title('Local Frames in Global Coordinates', fontweight='bold')
    ax1.legend()
    
    # Local view (as seen from frame 0)
    ax2 = fig.add_subplot(122, projection='3d')
    if len(coords) > 1:
        # Transform all points to frame 0
        z_axis = coords[1] - coords[0]
        z_axis = z_axis / (np.linalg.norm(z_axis) + 1e-8)
        x_axis = np.cross(z_axis, np.array([0, 0, 1]))
        if np.linalg.norm(x_axis) < 0.1:
            x_axis = np.cross(z_axis, np.array([0, 1, 0]))
        x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-8)
        y_axis = np.cross(z_axis, x_axis)
        
        R_frame = np.column_stack([x_axis, y_axis, z_axis])
        t_frame = coords[0]
        
        # Transform to local coordinates
        coords_local = (coords - t_frame) @ R_frame
        
        ax2.plot(coords_local[:, 0], coords_local[:, 1], coords_local[:, 2], 
                'o-', linewidth=2, markersize=8, color='orange', alpha=0.7)
        ax2.quiver(0, 0, 0, 1, 0, 0, color='red', length=1, arrow_length_ratio=0.3, label='Local X')
        ax2.quiver(0, 0, 0, 0, 1, 0, color='green', length=1, arrow_length_ratio=0.3, label='Local Y')
        ax2.quiver(0, 0, 0, 0, 0, 1, color='blue', length=1, arrow_length_ratio=0.3, label='Local Z')
    
    ax2.set_xlabel('Local X')
    ax2.set_ylabel('Local Y')
    ax2.set_zlabel('Local Z')
    ax2.set_title('View from Local Frame 0', fontweight='bold')
    ax2.legend()
    
    plt.tight_layout()
    plt.show()

visualize_local_frames(make_helix_3d(15), n_show=4)

print("📌 Understanding Local Frames:")
print("   - Each residue has its own coordinate system")
print("   - Coordinates expressed in a structure-derived local frame are invariant to global rotation/translation")
print("   - IPA computes attention in these local frames")

---

## 📝 Key Takeaways

### Why SE(3) Equivariance Matters for Proteins

1. **Proteins are 3D Geometric Objects**
   - Their function depends on 3D structure, not absolute position/orientation
   - A rotated or translated protein is still the same protein
   - Models should respect this geometric property

2. **SE(3) Equivariance Preserves Geometry**
   - If you transform input coordinates → output transforms consistently
   - Property: `f(g·x) = g·f(x)` for all SE(3) transformations g
   - Leads to better generalization and sample efficiency

3. **Local Frames Enable Efficient Computation**
   - Global equivariance is expensive (all-to-all comparisons)
   - Local frames reduce this to invariant scalar operations
   - IPA combines the best of both worlds

4. **Building Blocks are Simple**
   - Invariant: distances, angles, scalar products in local frames
   - Equivariant: weighted sums of vectors, coordinate updates
   - Complex architectures built from these primitives

### Connection to RFDiffusion

- RFDiffusion operates on **rigid body frames** (one per residue)
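- Uses **equivariant attention layers** to communicate between residues (an SE(3)-Transformer in its RoseTTAFold backbone; IPA is the AlphaFold2 counterpart)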
- Outputs **frame updates** (rotations + translations) at each diffusion step
- SE(3) equivariance is **built into the architecture**, not learned

### Why This Is Powerful

✅ **Correctness**: Model understands 3D geometry  
✅ **Efficiency**: No need to learn from rotated/translated examples  
✅ **Generalization**: Works for unseen orientations  
✅ **Interpretability**: Updates have clear geometric meaning

---

## ✅ Self-Check Questions

Test your understanding of SE(3) equivariance:

1. **Conceptual Understanding**
   - What is the difference between invariance and equivariance?
   - Why is SE(3) the natural symmetry group for proteins?
   - What does the equation `f(g·x) = g·f(x)` mean intuitively?

2. **Group Properties**
   - What are the four requirements for a mathematical group?
   - How do you compose two SE(3) transformations?
   - What is the inverse of a rotation + translation?

3. **Practical Design**
   - Give 3 examples of SE(3) invariant quantities
   - Give 3 examples of SE(3) equivariant operations
   - Why do we use k-nearest neighbors instead of all pairs?

4. **IPA Understanding**
   - What problem does IPA solve?
   - Why are local frames useful?
   - How does IPA maintain equivariance while doing attention?

5. **Application to RFDiffusion**
   - What does RFDiffusion predict at each diffusion step?
   - Why is equivariance important for protein diffusion models?
   - How would a non-equivariant model perform differently?

<details>
<summary>💡 Click for hints</summary>

- Invariance: output doesn't change with transformation
- Equivariance: output transforms the same way as input
- Local frames make distances/angles invariant by construction
- IPA = Invariant Point Attention, works in local coordinates
- RFDiffusion predicts frame updates (rotations + translations)

</details>

---

## 🎯 Practice Exercises

### Exercise 1: Test Invariance (Easy)
Create a function that takes 3D coordinates and returns a scalar that is SE(3) invariant.
```python
def my_invariant_feature(coords):
    # Your code here
    # Should return same value for rotated/translated coords
    pass
```

**Ideas**: Sum of pairwise distances, radius of gyration, max distance from center

### Exercise 2: Implement Distance Matrix (Medium)
Write a function that computes the all-pairs distance matrix. Verify it's invariant.
```python
def distance_matrix(coords):
    # Your code here
    # Return N×N matrix of pairwise distances
    pass

# Test
coords = make_helix_3d(10)
transform = SE3Transform(Rotation.random().as_matrix(), np.random.randn(3))  # random rotation + translation
coords_transformed = transform.apply(coords)

# These should be equal (up to numerical precision)
D1 = distance_matrix(coords)
D2 = distance_matrix(coords_transformed)
```

### Exercise 3: Frame-Based Features (Medium)
Implement a function that computes features in local frames.
```python
def local_frame_features(coords, i):
    """Compute features for residue i in its local frame."""
    # 1. Build local frame at residue i
    # 2. Transform neighbors to local coordinates
    # 3. Compute distances/angles in local frame
    # Your code here
    pass
```

### Exercise 4: Simple IPA Layer (Hard)
Implement a simplified IPA layer (without all the bells and whistles).
```python
def simple_ipa_layer(frames, features, k=5):
    """
    Args:
        frames: List of SE3Transform objects (one per residue)
        features: (N, d) array of scalar features per residue
        k: Number of nearest neighbors
    
    Returns:
        updated_features: (N, d) updated features
    """
    # 1. For each residue, find k nearest neighbors (in 3D space)
    # 2. Transform neighbor positions to local frame
    # 3. Compute attention weights based on distances
    # 4. Aggregate neighbor features with attention
    # Your code here
    pass
```

### Exercise 5: Verify Equivariance (Hard)
Test that your IPA layer is truly equivariant.
```python
# Original
output1 = simple_ipa_layer(frames, features)

# Transform input
transform = SE3Transform(Rotation.random().as_matrix(), np.random.randn(3))  # random rotation + translation
frames_transformed = [transform.compose(f) for f in frames]

# Apply layer to transformed input
output2 = simple_ipa_layer(frames_transformed, features)

# Check: output1 and output2 should be related by transform
# (This is trickier than it looks! What does it mean for features to transform?)
```

---

## 📚 Further Reading

### Papers on Geometric Deep Learning

1. **SE(3)-Transformers**  
   Fuchs et al., 2020 - "SE(3)-Transformers: 3D Rototranslation Equivariant Attention Networks"  
   [arXiv:2006.10503](https://arxiv.org/abs/2006.10503)

2. **AlphaFold 2 (IPA introduced here)**  
   Jumper et al., 2021 - "Highly accurate protein structure prediction with AlphaFold"  
   [Nature Paper](https://www.nature.com/articles/s41586-021-03819-2)  
   [Supplementary Methods](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf) - See section on IPA

3. **Geometric Deep Learning Book**  
   Bronstein et al., 2021  
   [Free online](https://geometricdeeplearning.com/)  
   Chapter 5: "Geometric graphs and sets"

4. **RFDiffusion Paper**  
   Watson et al., 2023 - "De novo design of protein structure and function with RFdiffusion"  
   [Nature Paper](https://www.nature.com/articles/s41586-023-06415-8)

### Tutorials and Resources

- **3Blue1Brown**: Visualizing quaternions and rotations (YouTube)
- **Geometric Deep Learning Proto-Book**: Chapters on symmetry and equivariance
- **PyTorch Geometric**: Library for geometric deep learning
- **e3nn**: Library for E(3) equivariant neural networks

### Related Notebooks in This Series

- **Notebook 03**: Protein Representation - Where we learned about frames
- **Notebook 05**: Unconditional Generation - Using IPA in practice (coming next!)
- **Notebook 06**: Motif Scaffolding - Conditional generation with equivariance

---

## 🚀 Next Steps

### You've Completed the Foundations! 🎉

You now understand:
- ✅ Diffusion models and how they work (Notebook 02)
- ✅ Protein structure representation (Notebook 03)  
- ✅ SE(3) equivariance and geometric deep learning (Notebook 04)

### Ready for Generation

The next notebooks will put this knowledge into practice:

**Notebook 05: Unconditional Generation** (⚠️ GPU Optional)
- Implement full RFDiffusion sampling loop
- Generate proteins from scratch
- Use real IPA layers
- Visualize generated structures

**Notebook 06: Motif Scaffolding** (⚠️ GPU Optional)
- Conditional generation with fixed motifs
- Design proteins around functional sites
- Control topology and secondary structure

**Notebook 07: Symmetric Design** (✅ GPU Recommended)
- Generate symmetric assemblies
- Understand symmetry constraints
- Design protein complexes

**Notebook 08: Evaluation & Analysis** (❌ No GPU)
- Metrics for generated proteins
- Quality assessment (pLDDT, pAE, etc.)
- Comparison with design objectives

### What Changed?

Starting with Notebook 05, we'll need **GPU compute** for:
- Running actual RFDiffusion model inference
- Generating structures (can take 5-15 minutes per protein)
- Testing different sampling strategies

**Reminder**: See [COLAB_SETUP.md](../../docs/COLAB_SETUP.md) for free GPU access via Google Colab!

---

**Continue to Notebook 05** ➡️ [05_unconditional_generation.ipynb](05_unconditional_generation.ipynb)