# Algorithm 3: Atom Transformer (AlphaFold3)

The Atom Transformer processes per-atom features and is a key innovation in AF3. Unlike AF2, which works primarily at the residue level, AF3 explicitly models individual atoms.

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/atom_cross_attention.py`

## Overview

### Per-Atom Representations

| Representation | Shape | Description |
|----------------|-------|-------------|
| `p` | [N_atoms, N_atoms, c_atom_pair] | Pairwise atom features (stored densely in this notebook) |
| `q` | [N_atoms, c_atom] | Per-atom features |
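
Because `p` is stored as a dense N_atoms × N_atoms tensor here, its size grows quadratically with the number of atoms. A back-of-the-envelope check, assuming float32 storage (4 bytes per value) and `c_atom_pair = 16` (the width used in the test cell); `dense_pair_bytes` is a hypothetical helper for illustration only:

```python
def dense_pair_bytes(n_atoms: int, c_atom_pair: int = 16, bytes_per_value: int = 4) -> int:
    """Bytes needed for a dense [n_atoms, n_atoms, c_atom_pair] tensor (float32 by default)."""
    return n_atoms * n_atoms * c_atom_pair * bytes_per_value

for n in (150, 1_000, 10_000):
    print(f"{n:>6} atoms -> {dense_pair_bytes(n) / 1e6:,.1f} MB")
```

Doubling the atom count quadruples the memory, which is why a fully dense atom-pair layout is only practical for small systems like the toy example below.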

### Key Innovation
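AF3 maintains explicit atom-level representations that are updated by the Atom Transformer. Because any chemical component can be described atom by atom rather than through a fixed residue vocabulary, this enables better modeling of ligands and non-standard residues.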
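
The atom-level and token-level views are linked by an atom-to-token index: each atom belongs to exactly one token (a standard residue or nucleotide, or a single atom of a ligand). The sketch below averages per-atom features into per-token features, assuming a hypothetical `atom_to_token` integer array. As we understand the AF3 supplement, its atom attention encoder applies a linear layer and ReLU before a similar mean; those steps are omitted here.

```python
import numpy as np

def atoms_to_tokens(q, atom_to_token, n_tokens):
    """Average per-atom features q [N_atoms, c] into per-token features [n_tokens, c].

    atom_to_token[l] is the (hypothetical) index of the token that atom l belongs to.
    """
    sums = np.zeros((n_tokens, q.shape[-1]))
    np.add.at(sums, atom_to_token, q)  # unbuffered scatter-add, correct for repeated indices
    counts = np.bincount(atom_to_token, minlength=n_tokens)[:, None]
    return sums / np.maximum(counts, 1)  # guard against tokens with no atoms

# Toy layout: a 5-atom residue (token 0) followed by two single-atom ligand tokens
atom_to_token = np.array([0, 0, 0, 0, 0, 1, 2])
q = np.arange(7, dtype=float)[:, None] * np.ones((1, 3))  # atom l has features [l, l, l]
a = atoms_to_tokens(q, atom_to_token, n_tokens=3)
print(a[:, 0])  # [2. 5. 6.]
```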

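```python
import numpy as np

np.random.seed(42)  # fixed seed so the random (untrained) weights below are reproducible

def layer_norm(x, eps=1e-5):
    """Layer normalization over the last axis (no learnable scale or offset)."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def swish(x):
    """Swish / SiLU: x * sigmoid(x), using sigmoid(x) = 0.5 * (1 + tanh(x / 2)) to avoid exp overflow."""
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))
```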
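
Why subtract the max inside `softmax`? Softmax is unchanged by shifting all logits by a constant, but `exp` overflows in float64 for arguments above roughly 709. The sketch below compares a naive softmax with the shifted one on large logits, and also checks numerically that swish can be written either as x / (1 + e^-x) or in the equivalent tanh form x · 0.5 · (1 + tanh(x/2)).

```python
import numpy as np

# Logits this large overflow exp() in float64
logits = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp overflows to inf, then inf / inf -> nan (warnings suppressed for the demo)
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.sum(np.exp(logits))

# Max-subtracted softmax: mathematically identical, numerically safe
shifted = np.exp(logits - logits.max())
stable = shifted / shifted.sum()

print(naive)   # [nan nan nan]
print(stable)  # approx. [0.090 0.245 0.665], i.e. softmax([-2, -1, 0])

# Swish two ways; the difference should be floating-point round-off only
x = np.linspace(-10.0, 10.0, 201)
max_diff = np.max(np.abs(x / (1 + np.exp(-x)) - x * 0.5 * (1 + np.tanh(0.5 * x))))
print(max_diff)
```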

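```python
def atom_transformer_block(q, p, num_heads=4, c=32):
    """
    Single Atom Transformer block (simplified sketch).

    Pre-LayerNorm multi-head self-attention with an additive pair bias,
    followed by a pre-LayerNorm feed-forward transition; each sub-layer
    has a residual connection. Weights are random, untrained stand-ins
    drawn from the global NumPy RNG on every call.

    Simplifications relative to AF3 (per our reading of its supplement):
    dense attention over all atom pairs instead of sequence-local windows,
    no AdaLN conditioning on a single representation, no sigmoid output
    gating, and a SiLU MLP instead of SwiGLU.

    Args:
        q: Per-atom features [N_atoms, c_atom]
        p: Atom pair features [N_atoms, N_atoms, c_pair]
        num_heads: Number of attention heads
        c: Per-head dimension

    Returns:
        Updated q, same shape as the input [N_atoms, c_atom]
    """
    N_atoms, c_atom = q.shape
    c_pair = p.shape[-1]

    # --- Self-attention with pair bias ---
    q_norm = layer_norm(q)
    p_norm = layer_norm(p)

    # Per-head query/key/value projections: c_atom -> (num_heads, c)
    W_q = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    W_k = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    W_v = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)

    queries = np.einsum('ic,chd->ihd', q_norm, W_q)  # [N_atoms, H, c]
    keys = np.einsum('jc,chd->jhd', q_norm, W_k)     # [N_atoms, H, c]
    values = np.einsum('jc,chd->jhd', q_norm, W_v)   # [N_atoms, H, c]

    # Pair bias: project each atom pair's features to one scalar per head
    W_b = np.random.randn(c_pair, num_heads) * (c_pair ** -0.5)
    bias = np.einsum('ijc,ch->ijh', p_norm, W_b)  # [N_atoms, N_atoms, H]

    # Scaled dot-product logits plus pair bias; softmax over keys j (axis=1)
    attn_logits = np.einsum('ihd,jhd->ijh', queries, keys) / np.sqrt(c)
    attn_logits = attn_logits + bias
    attn_weights = softmax(attn_logits, axis=1)

    # Weighted sum of values for each query atom and head
    attended = np.einsum('ijh,jhd->ihd', attn_weights, values)  # [N_atoms, H, c]

    # Merge heads and project back to c_atom
    W_o = np.random.randn(num_heads, c, c_atom) * ((num_heads * c) ** -0.5)
    output = np.einsum('ihd,hdc->ic', attended, W_o)

    # Residual connection
    q = q + output

    # --- Feed-forward transition (4x expansion, SiLU) ---
    q_norm = layer_norm(q)
    W_up = np.random.randn(c_atom, c_atom * 4) * (c_atom ** -0.5)
    W_down = np.random.randn(c_atom * 4, c_atom) * ((c_atom * 4) ** -0.5)

    hidden = swish(q_norm @ W_up)
    q = q + hidden @ W_down  # residual connection

    return q
```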
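
The block above lets every atom attend to every other atom. As we read the AF3 supplementary information, its Atom Transformer instead restricts attention to sequence-local windows: atoms are grouped into blocks of 32 queries, and each block attends to 128 keys centred on it. Such a restriction can be expressed as an additive bias, just like the pair bias: 0 where attention is allowed and a large negative number elsewhere. A minimal sketch; `local_attention_bias` is a hypothetical helper, and the window sizes are parameters rather than confirmed constants.

```python
import numpy as np

def local_attention_bias(n_atoms, n_queries=32, n_keys=128, neg=-1e10):
    """Additive [n_atoms, n_atoms] bias: 0 inside each query block's key window, `neg` outside.

    Query atom i belongs to block i // n_queries, whose centre is
    block * n_queries + (n_queries - 1) / 2; key atom j is allowed if it
    lies strictly within n_keys / 2 of that centre.
    """
    idx = np.arange(n_atoms)
    centres = (idx // n_queries) * n_queries + (n_queries - 1) / 2  # one centre per query atom
    allowed = np.abs(idx[None, :] - centres[:, None]) < n_keys / 2
    return np.where(allowed, 0.0, neg)

bias = local_attention_bias(150)
allowed = bias == 0
print(allowed.sum(axis=1)[[0, 40, 149]])  # number of keys visible to a few query atoms

# Adding the bias to logits drives softmax weight outside the window to exactly zero
logits = np.random.default_rng(0).normal(size=(150, 150)) + bias
w = np.exp(logits - logits.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)
print(w[0, 80:].max())  # 0.0 for atoms outside query 0's window
```

Inside `atom_transformer_block`, such a mask would be added to `attn_logits` alongside the pair bias, broadcast over heads as `bias[:, :, None]`. An efficient implementation would compute only the windowed blocks rather than masking a dense matrix.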

```python
def atom_transformer(q, p, num_blocks=3):
    """
    Full Atom Transformer stack: num_blocks blocks applied in sequence.

    Args:
        q: Per-atom features [N_atoms, c_atom]
        p: Atom pair features [N_atoms, N_atoms, c_pair] (not updated)
        num_blocks: Number of transformer blocks

    Returns:
        Updated q [N_atoms, c_atom]
    """
    print("Atom Transformer")
    print("=" * 50)
    print(f"Atoms: {q.shape[0]}, Atom dim: {q.shape[1]}")
    print(f"Pair dim: {p.shape[-1]}")

    for i in range(num_blocks):
        # Each call draws fresh random weights, so blocks do not share parameters
        q = atom_transformer_block(q, p)
        # Frobenius norm of q, tracked to check that activations stay bounded
        print(f"  Block {i+1}: ||q|| = {np.linalg.norm(q):.2f}")

    return q
```
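
`atom_transformer_block` draws new random weights on every call, so running the stack twice on the same input gives different outputs. A common alternative, sketched here with hypothetical `init_block_params` / `apply_block` / `apply_stack` helpers, is to create each block's parameters once and pass them in, making the forward pass a deterministic function of (params, q, p). The layer math mirrors the block above; this is a self-contained sketch, not the official AF3 code.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def init_block_params(rng, c_atom, c_pair, num_heads=4, c=32):
    """Random (untrained) weights for one block, drawn once and reused."""
    s_in, s_pair, s_out = c_atom ** -0.5, c_pair ** -0.5, (num_heads * c) ** -0.5
    return {
        "W_q": rng.normal(size=(c_atom, num_heads, c)) * s_in,
        "W_k": rng.normal(size=(c_atom, num_heads, c)) * s_in,
        "W_v": rng.normal(size=(c_atom, num_heads, c)) * s_in,
        "W_b": rng.normal(size=(c_pair, num_heads)) * s_pair,
        "W_o": rng.normal(size=(num_heads, c, c_atom)) * s_out,
        "W_up": rng.normal(size=(c_atom, 4 * c_atom)) * s_in,
        "W_down": rng.normal(size=(4 * c_atom, c_atom)) * (4 * c_atom) ** -0.5,
    }

def apply_block(params, q, p):
    """Pair-biased self-attention, then a SiLU MLP; each pre-norm with a residual."""
    h = layer_norm(q)
    queries = np.einsum("ic,chd->ihd", h, params["W_q"])
    keys = np.einsum("jc,chd->jhd", h, params["W_k"])
    values = np.einsum("jc,chd->jhd", h, params["W_v"])
    logits = np.einsum("ihd,jhd->ijh", queries, keys) / np.sqrt(queries.shape[-1])
    logits = logits + np.einsum("ijc,ch->ijh", layer_norm(p), params["W_b"])
    attn = softmax(logits, axis=1)  # normalise over keys j
    attended = np.einsum("ijh,jhd->ihd", attn, values)
    q = q + np.einsum("ihd,hdc->ic", attended, params["W_o"])

    u = layer_norm(q) @ params["W_up"]
    silu = u * 0.5 * (1.0 + np.tanh(0.5 * u))  # SiLU / swish: u * sigmoid(u)
    return q + silu @ params["W_down"]

def apply_stack(all_params, q, p):
    for params in all_params:
        q = apply_block(params, q, p)
    return q

rng = np.random.default_rng(0)
N, c_atom, c_pair = 40, 128, 16
all_params = [init_block_params(rng, c_atom, c_pair) for _ in range(3)]  # 3 blocks
q = rng.normal(size=(N, c_atom)) * 0.1
p = rng.normal(size=(N, N, c_pair)) * 0.1

out1 = apply_stack(all_params, q, p)
out2 = apply_stack(all_params, q, p)
print(out1.shape, np.allclose(out1, out2))  # (40, 128) True
```

Keeping parameters outside the forward function also makes them easy to save, load, or train, since the forward pass no longer touches the random number generator.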

```python
# Test on random inputs
print("Test: Atom Transformer")
print("=" * 60)

N_atoms = 150  # toy size; at roughly 8 heavy atoms per protein residue this is about 19 residues
c_atom = 128   # per-atom channel width
c_pair = 16    # atom-pair channel width

q = np.random.randn(N_atoms, c_atom) * 0.1
p = np.random.randn(N_atoms, N_atoms, c_pair) * 0.1

q_out = atom_transformer(q, p, num_blocks=3)

print(f"\nInput norm: {np.linalg.norm(q):.2f}")
print(f"Output norm: {np.linalg.norm(q_out):.2f}")
print(f"Output shape: {q_out.shape}")
print(f"Output finite: {np.isfinite(q_out).all()}")
```

## Key Insights

1. **Atom-Level**: Explicitly models individual atoms, not just residues
2. **Pair Bias**: Adds a per-head linear projection of the (layer-normed) pairwise atom features to the attention logits
3. **3 Blocks**: Default configuration uses 3 transformer blocks
4. **Ligand Support**: Enables modeling of ligands, ions, and non-standard residues