# Algorithm 2: Atom Attention Encoder (Boltz)

Encodes atom-level features with local self-attention among the atoms of each token, then pools the attended atom vectors into one feature vector per token.

## Source Code Location
- **File**: `Boltz-Ref-src/boltz-official/src/boltz/model/modules/encoders.py`
- **Class**: `AtomAttentionEncoder`

In [None]:
import numpy as np
# Seed NumPy's global RNG so that both the random test inputs and the encoder's
# randomly drawn Q/K/V projection weights (np.random.randn) are reproducible.
np.random.seed(42)

def layer_norm(x, eps=1e-5):
    """Standardize ``x`` along its last axis (no learnable scale or shift).

    Each vector along the final axis is shifted to zero mean and divided by
    ``sqrt(variance + eps)``; ``eps`` avoids division by zero for constant
    vectors.
    """
    centre = np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + eps)
    return (x - centre) / spread

def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large inputs
    do not overflow; mathematically this leaves the result unchanged.
    """
    logits = x - np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(logits)
    normalizer = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / normalizer

In [None]:
def atom_attention_encoder(atom_features, atom_to_token, num_heads=4, c=32, *, rng=None):
    """
    Atom Attention Encoder - aggregates atom features to token level.

    Atoms are grouped strictly by token: the atoms of each token attend only to
    each other (multi-head scaled dot-product self-attention), and the attended
    atom vectors are mean-pooled into one vector per token. No fixed-size
    windowing is performed.

    The Q/K/V projection weights are random and drawn freshly on every call,
    so results are only reproducible under a fixed RNG state. The function
    also prints a short summary (atom/token counts and output shape).

    Args:
        atom_features: Per-atom features, shape [N_atoms, c_atom].
        atom_to_token: Integer token index for each atom, shape [N_atoms].
            Atoms with a negative index match no token and are ignored.
        num_heads: Number of attention heads.
        c: Per-head dimension.
        rng: Optional ``np.random.Generator`` used to draw the projection
            weights. When omitted, NumPy's global RNG (``np.random.randn``)
            is used, so ``np.random.seed`` controls reproducibility.

    Returns:
        Token features of shape [N_tokens, num_heads * c], where
        N_tokens = max(atom_to_token) + 1 (0 when there are no atoms).
        Tokens that own no atoms get an all-zero row.

    Raises:
        ValueError: if ``atom_to_token`` does not have exactly one entry per atom.
    """
    atom_features = np.asarray(atom_features)
    atom_to_token = np.asarray(atom_to_token)
    N_atoms, c_atom = atom_features.shape
    if atom_to_token.shape != (N_atoms,):
        raise ValueError(
            f"atom_to_token must have shape ({N_atoms},), got {atom_to_token.shape}"
        )
    # ndarray.max() raises on an empty array, so zero atoms means zero tokens.
    N_tokens = max(int(atom_to_token.max()) + 1, 0) if N_atoms else 0
    c_token = num_heads * c

    print("Atom Attention Encoder")
    print("=" * 50)
    print(f"Atoms: {N_atoms}, Tokens: {N_tokens}")

    def _randn(*shape):
        # Caller-supplied Generator if given; otherwise the global RNG, which
        # keeps the original draw sequence under np.random.seed.
        if rng is None:
            return np.random.randn(*shape)
        return rng.standard_normal(shape)

    # Normalize each atom's feature vector
    atom_norm = layer_norm(atom_features)

    # QKV projections [c_atom, H, c]; drawn in q, k, v order so seeded runs
    # reproduce the same weights as before.
    scale = c_atom ** -0.5
    W_q = _randn(c_atom, num_heads, c) * scale
    W_k = _randn(c_atom, num_heads, c) * scale
    W_v = _randn(c_atom, num_heads, c) * scale

    q = np.einsum('ac,chd->ahd', atom_norm, W_q)  # [N_atoms, H, c]
    k = np.einsum('ac,chd->ahd', atom_norm, W_k)
    v = np.einsum('ac,chd->ahd', atom_norm, W_v)

    # Aggregate per token using attention
    token_features = np.zeros((N_tokens, num_heads, c))

    for t in range(N_tokens):
        mask = atom_to_token == t
        if not mask.any():
            continue  # a token with no atoms keeps its all-zero row

        q_t, k_t, v_t = q[mask], k[mask], v[mask]  # each [n_atoms_t, H, c]

        # Attention logits between every (query a, key b) atom pair: [a, b, H]
        logits = np.einsum('ahd,bhd->abh', q_t, k_t) / np.sqrt(c)
        # Normalize over keys (axis 1) for each query atom and head.
        attn = softmax(logits, axis=1)

        attended = np.einsum('abh,bhd->ahd', attn, v_t)

        # Mean pool over the token's atoms to get the token representation
        token_features[t] = attended.mean(axis=0)

    # Flatten heads: [N_tokens, H, c] -> [N_tokens, H * c]. An explicit width is
    # used because reshape(0, -1) cannot infer -1 for an empty array.
    token_features = token_features.reshape(N_tokens, c_token)

    print(f"Output: {token_features.shape}")

    return token_features

In [None]:
# Test: run the encoder on a toy system (5 atoms per token) and check the output.
print("Test: Atom Attention Encoder")
print("="*60)

N_tokens = 20
atoms_per_token = 5
N_atoms = N_tokens * atoms_per_token
c_atom = 64

atom_features = np.random.randn(N_atoms, c_atom)
# Atoms 0-4 -> token 0, atoms 5-9 -> token 1, ...
atom_to_token = np.repeat(np.arange(N_tokens), atoms_per_token)

token_features = atom_attention_encoder(atom_features, atom_to_token, num_heads=4, c=32)

print(f"\nOutput shape: {token_features.shape}")
print(f"Output finite: {np.isfinite(token_features).all()}")

# Check the results instead of only printing them: one row per token,
# num_heads * c = 4 * 32 columns, and no NaN/inf values.
assert token_features.shape == (N_tokens, 4 * 32)
assert np.isfinite(token_features).all()

## Key Insights

1. **Atom-to-Token Aggregation**: Aggregates atom features to token level
2. **Local Attention**: Self-attention within each token's atoms
3. **Mean Pooling**: Final aggregation via mean pooling
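4. **Locality**: Restricting attention to the atoms of a single token keeps the attention cost quadratic only in the number of atoms per token rather than in the total atom count (this simplified version groups strictly by token and does not use fixed-size atom windows)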