# Algorithm 13: Triangle Attention (Starting Node)

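Triangle Attention applies self-attention along the rows or columns of the pair representation. It adds a per-head bias computed from the pair representation itself. The "Starting Node" variant attends along rows: edge (i, j) attends to every edge (i, k) that shares its starting node i. The bias for that query/key pair comes from the third edge (j, k) of the triangle i-j-k.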

## Algorithm Pseudocode

![Triangle Attention Starting Node](../imgs/algorithms/TriangleAttentionStartingNode.png)
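The algorithm can also be written as equations, following the notation of the AF2 supplement. Here z_ij is a pair entry, h indexes heads, and c is the per-head channel count. The NumPy code below sets c to `head_dim = c_z // num_head`. Treat this as a transcription for reference, not a substitute for the image.

```latex
\begin{aligned}
z_{ij} &\leftarrow \mathrm{LayerNorm}(z_{ij}) \\
q^h_{ij},\, k^h_{ij},\, v^h_{ij} &= \mathrm{LinearNoBias}(z_{ij}), \qquad b^h_{ij} = \mathrm{LinearNoBias}(z_{ij}) \\
g^h_{ij} &= \mathrm{sigmoid}\big(\mathrm{Linear}(z_{ij})\big) \\
a^h_{ijk} &= \mathrm{softmax}_k\!\Big(\tfrac{1}{\sqrt{c}}\, {q^h_{ij}}^{\top} k^h_{ik} + b^h_{jk}\Big) \\
o^h_{ij} &= g^h_{ij} \odot \textstyle\sum_k a^h_{ijk}\, v^h_{ik} \\
\tilde{z}_{ij} &= \mathrm{Linear}\big(\mathrm{concat}_h(o^h_{ij})\big)
\end{aligned}
```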

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `TriangleAttention`
- **Lines**: 892-951

## Key Concepts

### Orientation
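- **Starting Node (Alg 13)**: `orientation='per_row'`. Each row `i` is an independent batch, and query `(i, j)` attends to keys `(i, k)` along axis 1.
- **Ending Node (Alg 14)**: `orientation='per_column'`. Each column `j` is an independent batch, and query `(i, j)` attends to keys `(k, j)` along axis 0. The implementation swaps axes 0 and 1, runs the per-row computation, and swaps back.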
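You can check the effect of the orientation switch on a toy tensor whose entries encode their own indices. In this sketch, `edges` and `keys_for_query` are hypothetical helpers for illustration. Entry `[i, j]` stores `10*i + j`, so the printed keys show which edges a query sees.

```python
import numpy as np

# Toy pair tensor: entry [i, j] holds the label 10*i + j
N = 4
edges = 10 * np.arange(N)[:, None] + np.arange(N)[None, :]


def keys_for_query(act, i, j):
    """Under per-row attention, query (i, j) attends over the whole row act[i, :]."""
    return act[i, :]


# Starting node (per_row): query (1, 2) sees edges (1, k) for all k
start_keys = keys_for_query(edges, 1, 2)

# Ending node (per_column): swap axes, then apply the same per-row rule.
# Original edge (i, j) now sits at position (j, i).
swapped = np.swapaxes(edges, 0, 1)
end_keys = keys_for_query(swapped, 2, 1)  # original query (1, 2)

print(start_keys)  # [10 11 12 13] -> edges (1,0), (1,1), (1,2), (1,3)
print(end_keys)    # [ 2 12 22 32] -> edges (0,2), (1,2), (2,2), (3,2)
```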

### Pair Bias
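Unlike standard self-attention, triangle attention adds a bias term computed from the pair representation itself. A bias-free linear projection maps each normalized pair entry z_jk to one scalar per head. That scalar is added to the logit of query (i, j) attending to key (i, k). The same bias matrix is shared across all rows i, so each triangle's third edge shapes the attention pattern.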
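The following sketch traces the bias indexing at toy sizes. The random `w_bias` stands in for the learned projection, which the source code calls `feat_2d_weights`. The sketch shows that the bias on query (i, j) and key (i, k) depends only on edge (j, k), not on the row i.

```python
import numpy as np

rng = np.random.default_rng(0)
N, c_z, num_head = 5, 8, 2

# Hypothetical normalized pair activations and bias projection weights
z = rng.standard_normal((N, N, c_z))
w_bias = rng.standard_normal((c_z, num_head)) / np.sqrt(c_z)

# One scalar per head for every edge: b[h, j, k] is computed from z[j, k]
b = np.einsum('qkc,ch->hqk', z, w_bias)  # [num_head, N, N]

# Broadcast over the row axis i, matching logits of shape
# [N (rows i), num_head, N (queries j), N (keys k)]
bias_in_logits = np.broadcast_to(b[None], (N, num_head, N, N))

# Row i, query (i, j), key (i, k): bias comes from edge (j, k) for any i
i, j, k, h = 3, 1, 4, 0
print(np.isclose(bias_in_logits[i, h, j, k], z[j, k] @ w_bias[:, h]))  # True
print(np.allclose(bias_in_logits[0], bias_in_logits[i]))               # True
```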

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
def layer_norm(x, axis=-1, eps=1e-5):
    # Normalization only: the learnable scale and offset of AF2's LayerNorm are omitted
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def sigmoid(x):
    # Clip to keep np.exp from overflowing
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))


def triangle_attention(pair_act, pair_mask, orientation='per_row', num_head=4):
    """
    Triangle Attention.

    Algorithm 13 (Starting Node) and Algorithm 14 (Ending Node).

    All weights are random placeholders for learned parameters and are drawn
    from the global NumPy RNG on every call. The residual update and dropout
    applied by the surrounding Evoformer block are not included.

    Args:
        pair_act: Pair activations [N_res, N_res, c_z]
        pair_mask: Pair mask [N_res, N_res] (1 = valid, 0 = padding)
        orientation: 'per_row' (Alg 13) or 'per_column' (Alg 14)
        num_head: Number of attention heads (must divide c_z)

    Returns:
        Update for the pair activations [N_res, N_res, c_z]
    """
    N_res, _, c_z = pair_act.shape
    assert c_z % num_head == 0, "c_z must be divisible by num_head"
    head_dim = c_z // num_head  # per-head channels (c in the paper)

    print(f"Triangle Attention ({orientation})")
    print(f"Input: {pair_act.shape}")

    # Step 1: For column attention, swap i and j so columns become rows
    if orientation == 'per_column':
        pair_act = np.swapaxes(pair_act, 0, 1)  # still [N_res, N_res, c_z]
        pair_mask = np.swapaxes(pair_mask, 0, 1)

    # Step 2: Mask bias, 0 for valid keys and -1e9 for padded keys
    # [N_res, 1, 1, N_res] broadcasts over heads and queries
    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]

    # Step 3: Layer normalization over channels
    pair_act_norm = layer_norm(pair_act, axis=-1)

    # Step 4: Pair bias, one scalar per head for every edge (b_jk in Alg 13)
    # Random stand-in for the learned 'feat_2d_weights'
    bias_w = np.random.randn(c_z, num_head) / np.sqrt(c_z)
    # [N_res, N_res, c_z] -> [num_head, N_res, N_res]
    nonbatched_bias = np.einsum('qkc,ch->hqk', pair_act_norm, bias_w)

    print(f"Pair bias: {nonbatched_bias.shape}")

    # Step 5: Q, K, V projections (no bias terms)
    q_w = np.random.randn(c_z, num_head, head_dim) * 0.01
    k_w = np.random.randn(c_z, num_head, head_dim) * 0.01
    v_w = np.random.randn(c_z, num_head, head_dim) * 0.01

    # [N_res, N_res, c_z] -> [N_res, N_res, num_head, head_dim]
    # Queries are pre-scaled by 1/sqrt(head_dim)
    q = np.einsum('bqc,chd->bqhd', pair_act_norm, q_w) * (head_dim ** -0.5)
    k = np.einsum('bkc,chd->bkhd', pair_act_norm, k_w)
    v = np.einsum('bkc,chd->bkhd', pair_act_norm, v_w)

    # Step 6: Attention logits within each row b
    # [N_res, N_res, num_head, head_dim] x [N_res, N_res, num_head, head_dim]
    # -> [N_res (rows), num_head, N_res (queries), N_res (keys)]
    logits = np.einsum('bqhd,bkhd->bhqk', q, k)

    # Add biases
    logits = logits + bias  # Mask bias
    logits = logits + nonbatched_bias[None, :, :, :]  # Pair bias, shared by all rows

    # Step 7: Softmax over keys
    weights = softmax(logits, axis=-1)

    # Step 8: Weighted sum of values
    # [N_res, num_head, N_res, N_res] x [N_res, N_res, num_head, head_dim]
    # -> [N_res, N_res, num_head, head_dim]
    output = np.einsum('bhqk,bkhd->bqhd', weights, v)

    # Step 9: Sigmoid gating; with small weights and a bias of 1,
    # gates start near sigmoid(1) ~ 0.73
    gate_w = np.random.randn(c_z, num_head, head_dim) * 0.01
    gate_b = np.ones((num_head, head_dim))
    gate = sigmoid(np.einsum('bqc,chd->bqhd', pair_act_norm, gate_w) + gate_b)
    output = output * gate

    # Step 10: Output projection, merging heads back into c_z channels
    o_w = np.random.randn(num_head, head_dim, c_z) * 0.01
    o_b = np.zeros(c_z)
    output = np.einsum('bqhd,hdc->bqc', output, o_w) + o_b

    # Step 11: Undo the swap for column attention
    if orientation == 'per_column':
        output = np.swapaxes(output, 0, 1)

    print(f"Output: {output.shape}")

    return output
```
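The mask bias from Step 2 can be isolated in a few lines. This sketch assumes a hypothetical padding pattern in which the last residue is padding. It checks that `1e9 * (pair_mask - 1.)`, broadcast with shape `[N_res, 1, 1, N_res]`, drives the softmax weight of masked keys to zero while each row of weights still sums to 1.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


N, num_head = 4, 2
rng = np.random.default_rng(1)

# Hypothetical padding: residue 3 is padding, so key column 3 is masked in every row
pair_mask = np.ones((N, N))
pair_mask[:, 3] = 0.0

# 0 for valid keys, -1e9 for masked ones; [rows, 1, 1, keys] broadcasts
# against logits of shape [rows, heads, queries, keys]
bias = (1e9 * (pair_mask - 1.0))[:, None, None, :]

logits = rng.standard_normal((N, num_head, N, N)) + bias
weights = softmax(logits, axis=-1)

print(weights[..., 3].max())              # 0.0 (exp underflows)
print(np.allclose(weights.sum(-1), 1.0))  # True
```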

## Test Example

```python
# Test parameters
N_res = 32
c_z = 64
num_head = 4

# Create test inputs (all residues valid)
pair_act = np.random.randn(N_res, N_res, c_z).astype(np.float32)
pair_mask = np.ones((N_res, N_res), dtype=np.float32)

print(f"Input: {pair_act.shape}")
print()
```

```python
# Algorithm 13: Starting Node (per_row)
print("="*50)
output_starting = triangle_attention(
    pair_act, pair_mask,
    orientation='per_row',
    num_head=num_head
)
print(f"Stats: mean={output_starting.mean():.6f}, std={output_starting.std():.6f}")
```

```python
# Algorithm 14: Ending Node (per_column)
# Note: this call draws a fresh set of random weights, so its statistics
# are not directly comparable to output_starting
print("\n" + "="*50)
output_ending = triangle_attention(
    pair_act, pair_mask,
    orientation='per_column',
    num_head=num_head
)
print(f"Stats: mean={output_ending.mean():.6f}, std={output_ending.std():.6f}")
```

## Interpretation

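### Starting Node (per_row)
For each row `i` (fixed starting node):
- Attention runs over columns `k`.
- Each edge (i, j) attends to all edges (i, k) that leave the same node i.
- The pair bias comes from edge (j, k), the third edge of triangle i-j-k.
- Captures: "What other residues does i connect to, and how do they relate to j?"

### Ending Node (per_column)
For each column `j` (fixed ending node):
- Attention runs over rows `k`.
- Each edge (i, j) attends to all edges (k, j) that enter the same node j.
- The pair bias comes from edge (k, i), the third edge of triangle i-j-k.
- Captures: "What other residues connect to j, and how do they relate to i?"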
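The index patterns above can be checked numerically. The sketch below compares an einsum formulation of gated per-row attention against a naive loop that evaluates the index formulas directly. For the Ending Node, the loop uses keys (k, j) with bias from edge (k, i), and the vectorized path uses the axis swap. To keep the sketch short, it omits LayerNorm, masking, and the output projection, and uses random weights. `vectorized_per_row`, `proj`, and `loop_reference` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
N, c_z, H, c = 5, 8, 2, 4  # toy sizes: residues, pair channels, heads, per-head dim

# Random stand-ins for learned weights (hypothetical values)
W = {name: rng.standard_normal((c_z, H, c)) * 0.3 for name in ('q', 'k', 'v', 'g')}
w_b = rng.standard_normal((c_z, H))
z = rng.standard_normal((N, N, c_z))


def vectorized_per_row(z):
    """Einsum formulation: gated per-row attention, output kept per head."""
    q = np.einsum('bqc,chd->bqhd', z, W['q']) / np.sqrt(c)
    k = np.einsum('bkc,chd->bkhd', z, W['k'])
    v = np.einsum('bkc,chd->bkhd', z, W['v'])
    b = np.einsum('qkc,ch->hqk', z, w_b)
    logits = np.einsum('bqhd,bkhd->bhqk', q, k) + b[None]
    a = np.exp(logits - logits.max(-1, keepdims=True))
    a /= a.sum(-1, keepdims=True)
    g = 1 / (1 + np.exp(-np.einsum('bqc,chd->bqhd', z, W['g'])))
    return g * np.einsum('bhqk,bkhd->bqhd', a, v)


def proj(x, name):
    """Project one edge's features x [c_z] with W[name] -> [H, c]."""
    return np.einsum('c,chd->hd', x, W[name])


def loop_reference(z, starting=True):
    """Evaluate the index formulas one (i, j, h) at a time."""
    out = np.zeros((N, N, H, c))
    for i in range(N):
        for j in range(N):
            for h in range(H):
                q = proj(z[i, j], 'q')[h] / np.sqrt(c)
                if starting:  # keys (i, k); bias from edge (j, k)
                    keys = [(i, kk) for kk in range(N)]
                    bias = np.array([z[j, kk] @ w_b[:, h] for kk in range(N)])
                else:  # keys (k, j); bias from edge (k, i)
                    keys = [(kk, j) for kk in range(N)]
                    bias = np.array([z[kk, i] @ w_b[:, h] for kk in range(N)])
                logits = np.array([q @ proj(z[e], 'k')[h] for e in keys]) + bias
                a = np.exp(logits - logits.max())
                a /= a.sum()
                vals = np.stack([proj(z[e], 'v')[h] for e in keys])
                gate = 1 / (1 + np.exp(-proj(z[i, j], 'g')[h]))
                out[i, j, h] = gate * (a @ vals)
    return out


start = vectorized_per_row(z)
end = np.swapaxes(vectorized_per_row(np.swapaxes(z, 0, 1)), 0, 1)
print(np.allclose(start, loop_reference(z, starting=True)))   # True
print(np.allclose(end, loop_reference(z, starting=False)))    # True
```

The loop version is far too slow for real use. It is useful only as a readable check that the einsum subscripts and the axis swap implement the intended triangle indexing.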

## Source Code Reference

```python
# From AF2-source-code/model/modules.py (abbreviated excerpt)

class TriangleAttention(hk.Module):
  """Triangle Attention.

  Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
  Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
  """

  def __call__(self, pair_act, pair_mask, is_training=False):
    c = self.config

    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)
      pair_mask = jnp.swapaxes(pair_mask, -1, -2)

    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]

    pair_act = hk.LayerNorm(axis=[-1], ...)(pair_act)

    # Compute pair bias
    weights = hk.get_parameter('feat_2d_weights', ...)
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

    attn_mod = Attention(c, self.global_config, pair_act.shape[-1])
    pair_act = attn_mod(pair_act, pair_act, bias, nonbatched_bias)

    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)

    return pair_act
```

## Key Insights

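1. **Pair Bias**: The bias b_jk, projected from the pair representation, lets the third edge of each triangle i-j-k influence how edge (i, j) attends to edge (i, k).

2. **Two Orientations**: The Starting Node variant attends over edges that share the first index, and the Ending Node variant over edges that share the second index. Together they cover both ways an edge takes part in triangles.

3. **Same Module**: Both algorithms use the same `TriangleAttention` class with different `orientation` config.

4. **Complexity**: There are N_res independent attention problems (one per row or column), each with N_res × N_res logits per head. The logits tensor therefore has num_head · N_res³ entries, so both time and memory grow as O(N_res³).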
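A quick calculation shows why the cubic term matters for memory. This sketch counts only the logits of one triangle attention call, assuming float32 storage, `num_head=4` as in the test cell, and no chunking. `logits_bytes` is a hypothetical helper.

```python
def logits_bytes(n_res, num_head=4, bytes_per_float=4):
    """Size of logits [N_res (rows), num_head, N_res (queries), N_res (keys)]."""
    return n_res * num_head * n_res * n_res * bytes_per_float


for n_res in (128, 256, 512):
    print(n_res, logits_bytes(n_res) / 2**20, "MiB")
# 128 32.0 MiB
# 256 256.0 MiB
# 512 2048.0 MiB
```

Each doubling of N_res multiplies this footprint by 8.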