# Algorithm 15: Atom Attention Decoder (Boltz)

Decodes token-level predictions to atom coordinates.

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/encoders.py`
- **Class**: `AtomAttentionDecoder`

In [None]:
import numpy as np
# Seed NumPy's global RNG: every projection matrix in this notebook is drawn
# from np.random at call time, so seeding makes the whole run reproducible.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Normalise ``x`` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied. ``eps`` is added to the variance
    before the square root so constant rows do not divide by zero.
    """
    centered = x - np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return centered / np.sqrt(variance + eps)

def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This does not
    change the result but prevents overflow for large inputs.
    """
    shifted = np.exp(x - np.max(x, axis=axis, keepdims=True))
    normaliser = shifted.sum(axis=axis, keepdims=True)
    return shifted / normaliser

In [None]:
def atom_attention_decoder(token_features, atom_features, atom_to_token, num_heads=4, c=32):
    """
    Atom Attention Decoder.

    Broadcasts token-level features to atoms using cross-attention. Every atom
    issues one query per head and attends to the key/value of its parent token,
    i.e. the token given by ``atom_to_token``.

    All projection matrices (W_q, W_k, W_v, W_o) are drawn from the global
    ``np.random`` state on every call. Results therefore depend on that state.

    Args:
        token_features: Token features [N_tokens, c_token]
        atom_features: Current atom features [N_atoms, c_atom]
        atom_to_token: Atom-to-token mapping [N_atoms]. Integer index of each
            atom's parent token, each in ``[0, N_tokens)``.
        num_heads: Number of attention heads
        c: Head dimension

    Returns:
        Atom features [N_atoms, c_atom] from the attention output projection.
        No residual connection to ``atom_features`` is added.

    Raises:
        ValueError: If ``atom_to_token`` does not have exactly one entry per
            atom, is not integer-typed, or holds an index outside
            ``[0, N_tokens)``.
    """
    N_atoms, c_atom = atom_features.shape
    N_tokens, c_token = token_features.shape

    # Validate the mapping up front. Negative indices would otherwise wrap
    # silently to the wrong token instead of failing.
    atom_to_token = np.asarray(atom_to_token)
    if atom_to_token.shape != (N_atoms,):
        raise ValueError(
            f"atom_to_token must have shape ({N_atoms},), got {atom_to_token.shape}"
        )
    if atom_to_token.size:
        if not np.issubdtype(atom_to_token.dtype, np.integer):
            raise ValueError("atom_to_token must contain integer token indices")
        if atom_to_token.min() < 0 or atom_to_token.max() >= N_tokens:
            raise ValueError(f"atom_to_token indices must lie in [0, {N_tokens})")
    else:
        # An empty list becomes a float array; cast it so fancy indexing works.
        atom_to_token = atom_to_token.astype(np.intp)

    print("Atom Attention Decoder")
    print("=" * 50)
    print(f"Tokens: {N_tokens}, Atoms: {N_atoms}")

    atom_norm = layer_norm(atom_features)
    token_norm = layer_norm(token_features)

    # Queries from atoms
    W_q = np.random.randn(c_atom, num_heads, c) * (c_atom ** -0.5)
    q = np.einsum('ac,chd->ahd', atom_norm, W_q)

    # Keys and values from parent tokens
    W_k = np.random.randn(c_token, num_heads, c) * (c_token ** -0.5)
    W_v = np.random.randn(c_token, num_heads, c) * (c_token ** -0.5)

    k = np.einsum('tc,chd->thd', token_norm, W_k)
    v = np.einsum('tc,chd->thd', token_norm, W_v)

    # Each atom attends to exactly one key: its parent token. Keep an explicit
    # key axis (size 1) so the softmax normalises over keys, not across heads.
    parent_k = k[atom_to_token][:, None]  # [N_atoms, 1, H, c]
    parent_v = v[atom_to_token][:, None]  # [N_atoms, 1, H, c]

    logits = np.einsum('ahd,akhd->akh', q, parent_k) / np.sqrt(c)
    # With a single key per atom this yields weight 1.0 for every head.
    attn = softmax(logits, axis=1)

    output = np.einsum('akh,akhd->ahd', attn, parent_v)
    # Explicit width instead of -1 so N_atoms == 0 still reshapes cleanly.
    output = output.reshape(N_atoms, num_heads * c)

    # Project back to atom dimension
    W_o = np.random.randn(num_heads * c, c_atom) * ((num_heads * c) ** -0.5)
    output = output @ W_o

    print(f"Output: {output.shape}")
    return output

In [None]:
def predict_atom_positions(atom_features, ref_positions):
    """
    Turn atom features into 3D displacements and apply them to ``ref_positions``.

    A linear map (c_atom -> 3) is drawn from the global ``np.random`` state on
    each call. It is applied to the layer-normalised features, and the
    resulting offsets are added to the reference positions.
    """
    feature_dim = atom_features.shape[-1]
    projection = np.random.randn(feature_dim, 3) * (feature_dim ** -0.5)
    offsets = layer_norm(atom_features) @ projection
    return ref_positions + offsets

In [None]:
# Smoke test: random token/atom features with a fixed number of atoms per token.
print("Test: Atom Attention Decoder")
print("=" * 60)

N_tokens, atoms_per_token = 20, 5
N_atoms = N_tokens * atoms_per_token
c_token, c_atom = 128, 64

# Inputs are drawn in a fixed order so the seeded RNG stream stays reproducible.
token_features = np.random.randn(N_tokens, c_token)
atom_features = np.random.randn(N_atoms, c_atom)
# Token t owns the contiguous atoms [t * atoms_per_token, (t + 1) * atoms_per_token).
atom_to_token = np.arange(N_atoms) // atoms_per_token

output = atom_attention_decoder(token_features, atom_features, atom_to_token)

print("\nOutput shape: {}".format(output.shape))
print("Output finite: {}".format(np.isfinite(output).all()))

# Predict positions from the decoded atom features.
ref_pos = np.random.randn(N_atoms, 3)
new_pos = predict_atom_positions(output, ref_pos)
print(f"Position update norm: {np.linalg.norm(new_pos - ref_pos):.4f}")

## Key Insights

1. **Token→Atom**: Broadcasts token features to atoms
2. **Cross-Attention**: Atoms query their parent tokens
3. **Position Prediction**: A final linear layer predicts 3D coordinate updates, which are added to the reference positions
4. **Windowed**: The official Boltz module restricts atom attention to local windows for efficiency; this simplified version omits windowing