# Algorithm 18: Atom Cross Attention (AlphaFold3)

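This notebook is a simplified NumPy sketch of cross-attention between atom-level and token-level representations; all projection weights are drawn at random for illustration rather than learned.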

## Source Code Location
- **File**: `AF3-Ref-src/alphafold3-official/src/alphafold3/model/network/atom_cross_attention.py`

## Overview

### Purpose
- Maps between atom-level and token-level representations
- Enables information flow between different granularities
- Critical for handling ligands and modified residues

### Flow
```
Token → Atom:  Broadcast token info to constituent atoms
Atom → Token:  Aggregate atom info back to tokens
```

In [None]:
import numpy as np
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """
    Normalise the last axis of x to zero mean and (approximately) unit variance.

    No learnable scale or offset is applied.

    Args:
        x: Array whose last axis holds the features to normalise
        eps: Constant added to the variance before the square root so that
            constant feature vectors do not divide by zero

    Returns:
        Array with the same shape as x
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - mu) / sigma

def softmax(x, axis=-1):
    """
    Softmax along the given axis.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow; this shift does not change the result.

    Args:
        x: Array of logits
        axis: Axis along which the outputs sum to 1

    Returns:
        Array with the same shape as x
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    normaliser = np.sum(weights, axis=axis, keepdims=True)
    return weights / normaliser

In [None]:
def create_atom_token_mapping(N_tokens, atoms_per_token=5):
    """
    Build a uniform atom/token layout in which every token owns the same
    number of consecutively numbered atoms.

    Args:
        N_tokens: Number of tokens
        atoms_per_token: Number of atoms assigned to each token

    Returns:
        atom_to_token: [N_tokens * atoms_per_token] array giving the token
            index of each atom
        token_to_atoms: List (one entry per token) of lists of atom indices
    """
    # Atom -> token: token index i is repeated atoms_per_token times in a row.
    atom_to_token = np.repeat(np.arange(N_tokens), atoms_per_token)

    # Token -> atoms: token i owns the contiguous range starting at
    # i * atoms_per_token.
    token_to_atoms = []
    for token in range(N_tokens):
        first_atom = token * atoms_per_token
        token_to_atoms.append(list(range(first_atom, first_atom + atoms_per_token)))

    return atom_to_token, token_to_atoms

In [None]:
def atom_cross_attention_to_atoms(q, s, atom_to_token, num_heads=4, c=32):
    """
    Cross-attention from tokens to atoms (Token → Atom).

    Each atom queries its parent token's representation. This is a simplified
    form of cross-attention: an atom sees exactly one key/value pair (its
    parent token) rather than every token.

    The attention weights are normalised over the key axis. Because that axis
    has length 1 here, every weight is exactly 1 and the update for an atom is
    the output projection of its parent token's values. Normalising over the
    head axis instead would make independent heads compete with each other,
    which is not what attention computes.

    All projection matrices are freshly drawn from np.random on every call,
    in the order W_q, W_k, W_v, W_o.

    Args:
        q: Atom features [N_atoms, c_atom]
        s: Token features [N_tokens, c_s]
        atom_to_token: Mapping [N_atoms] giving the parent token index of
            each atom
        num_heads: Number of attention heads
        c: Head dimension

    Returns:
        Updated atom features [N_atoms, c_atom]
    """
    N_atoms, c_atom = q.shape
    N_tokens, c_s = s.shape

    print("Atom Cross Attention (Token → Atom)")
    print("=" * 50)
    print(f"Atoms: {N_atoms}, Tokens: {N_tokens}")

    q_norm = layer_norm(q)
    s_norm = layer_norm(s)

    # Queries from atoms
    W_q = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    queries = np.einsum('ac,chd->ahd', q_norm, W_q)  # [N_atoms, H, c]

    # Keys and values from tokens
    W_k = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    W_v = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    keys = np.einsum('tc,chd->thd', s_norm, W_k)  # [N_tokens, H, c]
    values = np.einsum('tc,chd->thd', s_norm, W_v)

    # Gather each atom's parent token and add an explicit key axis (k) of
    # length 1, so the softmax below has a well-defined axis to normalise over.
    parent_keys = keys[atom_to_token][:, None]      # [N_atoms, 1, H, c]
    parent_values = values[atom_to_token][:, None]  # [N_atoms, 1, H, c]

    # Scaled dot-product logits, normalised over keys (axis=1), not heads.
    logits = np.einsum('ahd,akhd->akh', queries, parent_keys) / np.sqrt(c)
    attn = softmax(logits, axis=1)  # [N_atoms, 1, H]

    # Weighted sum of values over the key axis
    output = np.einsum('akh,akhd->ahd', attn, parent_values)  # [N_atoms, H, c]

    # Output projection: merge heads back to the atom channel width
    W_o = np.random.randn(num_heads, c, c_atom) * ((num_heads * c) ** -0.5)
    output = np.einsum('ahd,hdc->ac', output, W_o)

    print(f"Output: {output.shape}")

    return output

In [None]:
def atom_cross_attention_to_tokens(q, s, atom_to_token, num_heads=4, c=32):
    """
    Cross-attention from atoms to tokens (Atom → Token).

    Every token forms a query and attends over the keys/values of the atoms
    assigned to it; the softmax runs over that token's atoms. Tokens with no
    atoms keep an all-zero pooled vector before the output projection.

    All projection matrices are freshly drawn from np.random on every call,
    in the order W_q, W_k, W_v, W_o.

    Args:
        q: Atom features [N_atoms, c_atom]
        s: Token features [N_tokens, c_s]
        atom_to_token: Mapping [N_atoms] giving the token index of each atom
        num_heads: Number of attention heads
        c: Head dimension

    Returns:
        Updated token features [N_tokens, c_s]
    """
    _, c_atom = q.shape
    N_tokens, c_s = s.shape

    print(f"Atom Cross Attention (Atom → Token)")
    print(f"="*50)

    atom_feats = layer_norm(q)
    token_feats = layer_norm(s)

    # Token-side query projection
    W_q = np.random.randn(c_s, num_heads, c) * (c_s ** -0.5)
    queries = np.einsum('tc,chd->thd', token_feats, W_q)  # [N_tokens, H, c]

    # Atom-side key and value projections
    W_k = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    W_v = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    keys = np.einsum('ac,chd->ahd', atom_feats, W_k)  # [N_atoms, H, c]
    values = np.einsum('ac,chd->ahd', atom_feats, W_v)

    # Per-token pooling over that token's own atoms
    pooled = np.zeros((N_tokens, num_heads, c))
    logit_scale = np.sqrt(c)

    for t in range(N_tokens):
        members = (atom_to_token == t)
        if not members.any():
            # No atoms for this token: leave its row at zero
            continue

        k_t = keys[members]    # [n_atoms_t, H, c]
        v_t = values[members]  # [n_atoms_t, H, c]

        # Softmax over the atom axis, independently for each head
        logits = np.einsum('hd,ahd->ah', queries[t], k_t) / logit_scale
        weights = softmax(logits, axis=0)

        pooled[t] = np.einsum('ah,ahd->hd', weights, v_t)

    # Merge heads back to the token channel width
    W_o = np.random.randn(num_heads, c, c_s) * ((num_heads * c) ** -0.5)
    projected = np.einsum('thd,hdc->tc', pooled, W_o)

    print(f"Output: {projected.shape}")

    return projected

In [None]:
# Smoke test: run both attention directions on random features and check
# that the results contain only finite values.
print("Test: Atom Cross Attention")
print("="*60)

# Problem size and channel widths
N_tokens, atoms_per_token = 20, 5
c_atom, c_s = 64, 128

# Only the atom → token index array is needed; the per-token lists are unused.
atom_to_token, _ = create_atom_token_mapping(N_tokens, atoms_per_token)
N_atoms = atom_to_token.shape[0]

# Random atom features first, then token features (order fixes the RNG stream)
q = np.random.randn(N_atoms, c_atom)
s = np.random.randn(N_tokens, c_s)

# Direction 1: tokens inform atoms
atom_update = atom_cross_attention_to_atoms(q, s, atom_to_token)
print(f"Atom update finite: {np.isfinite(atom_update).all()}")

print()

# Direction 2: atoms are pooled back into tokens
token_update = atom_cross_attention_to_tokens(q, s, atom_to_token)
print(f"Token update finite: {np.isfinite(token_update).all()}")

## Key Insights

1. **Bidirectional**: Information flows both ways between atoms and tokens
2. **Sparse Attention**: Each atom only attends to its parent token (efficient)
3. **Aggregation**: Tokens aggregate over variable numbers of atoms
4. **Ligand Support**: Essential for modeling small molecules with many atoms per token