# Algorithm 14: Triangle Attention (Ending Node)

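Triangle attention around the ending node. Attention runs along the columns of the pair representation: the edge z[i,j] attends to every edge z[k,j] that shares its ending node j, and the third edge of the triangle, z[k,i], supplies the attention bias.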

## Algorithm Pseudocode

![TriangleAttentionEndingNode](../imgs/algorithms/TriangleAttentionEndingNode.png)
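For readers who cannot view the image, the per-head update can be sketched as the equations below. This follows the structure of Jumper et al. (2021) Suppl. Alg. 14 as understood here; dropout, masking, and the bias terms of the linear layers are left out, and $h$ indexes heads while $c$ is the head dimension.

$$
\begin{aligned}
z_{ij} &\leftarrow \mathrm{LayerNorm}(z_{ij}) \\
q^h_{ij},\ k^h_{ij},\ v^h_{ij} &= \mathrm{LinearNoBias}(z_{ij}) \\
b^h_{ij} &= \mathrm{LinearNoBias}(z_{ij}) \\
g^h_{ij} &= \mathrm{sigmoid}\big(\mathrm{Linear}(z_{ij})\big) \\
a^h_{ijk} &= \mathrm{softmax}_k\!\left(\tfrac{1}{\sqrt{c}}\, {q^h_{ij}}^{\top} k^h_{kj} + b^h_{ki}\right) \\
o^h_{ij} &= g^h_{ij} \odot \sum_k a^h_{ijk}\, v^h_{kj} \\
\tilde{z}_{ij} &= \mathrm{Linear}\big(\mathrm{concat}_h(o^h_{ij})\big)
\end{aligned}
$$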

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `TriangleAttention`
- **Lines**: 965-1045

## Comparison with Starting Node (Algorithm 13)

| Aspect | Starting Node (Alg 13) | Ending Node (Alg 14) |
|--------|------------------------|----------------------|
| Orientation | Row-wise | Column-wise |
| Attention pattern | z[i,j] attends to z[i,k] | z[i,j] attends to z[k,j] |
| Triangle edges | Edges starting from i | Edges ending at j |
| Bias (third edge) | b[j,k] | b[k,i] |
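The two orientations are related by a transpose: running the starting-node attention on the transposed pair representation and transposing the result back should give the ending-node attention. The sketch below checks this numerically against a direct loop form. It is a simplified illustration, not the reference implementation: it uses a single head with no layer norm, gating, or output projection, and the helper names (`starting_node`, `ending_node_via_transpose`, `ending_node_naive`) are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def starting_node(z, W_q, W_k, W_v, W_b):
    """Row-wise: z[i,j] attends to z[i,k] with bias b[j,k] (single head)."""
    q, k, v = z @ W_q, z @ W_k, z @ W_v          # each [N, N, c]
    b = z @ W_b                                   # [N, N]
    logits = np.einsum('ijd,ikd->ijk', q, k) / np.sqrt(q.shape[-1])
    logits = logits + b[None, :, :]               # b[j,k], shared across rows i
    return np.einsum('ijk,ikd->ijd', softmax(logits), v)

def ending_node_via_transpose(z, *weights):
    """Column-wise attention as transpose -> row-wise attention -> transpose."""
    z_t = z.transpose(1, 0, 2)
    return starting_node(z_t, *weights).transpose(1, 0, 2)

def ending_node_naive(z, W_q, W_k, W_v, W_b):
    """Direct loop form: z[i,j] attends to z[k,j] with bias b[k,i]."""
    N = z.shape[0]
    q, k, v = z @ W_q, z @ W_k, z @ W_v
    b = z @ W_b
    out = np.zeros_like(q)
    for i in range(N):
        for j in range(N):
            logits = k[:, j] @ q[i, j] / np.sqrt(q.shape[-1]) + b[:, i]
            out[i, j] = softmax(logits) @ v[:, j]
    return out

# Small random problem (sizes chosen only for the demo)
rng = np.random.default_rng(0)
N, c_z, c = 5, 8, 4
z = rng.standard_normal((N, N, c_z))
weights = [rng.standard_normal(s) * 0.1 for s in [(c_z, c)] * 3 + [(c_z,)]]

assert np.allclose(ending_node_via_transpose(z, *weights),
                   ending_node_naive(z, *weights))
```

Because the bias is computed from the transposed input, it picks up the transposed indices as well, which is how b[j,k] in the row-wise form becomes b[k,i] in the column-wise form.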

In [None]:
import numpy as np

np.random.seed(42)

In [None]:
def triangle_attention_ending_node(z, c=32, num_heads=4):
    """
    Triangle Attention (Ending Node) - Algorithm 14.
    
    Attention along columns: z[i,j] attends to z[k,j] for all k,
    with the attention bias taken from the third edge, b[k,i].
    
    Args:
        z: Pair representation [N_res, N_res, c_z]
        c: Head dimension
        num_heads: Number of attention heads
    
    Returns:
        Updated z [N_res, N_res, c_z]
    """
    N_res, _, c_z = z.shape
    
    # Layer normalization over channels (learned scale/offset omitted in this demo)
    z_norm = (z - z.mean(axis=-1, keepdims=True)) / np.sqrt(z.var(axis=-1, keepdims=True) + 1e-5)
    
    # Linear projections (random weights for demo)
    W_q = np.random.randn(c_z, num_heads, c) * 0.02
    W_k = np.random.randn(c_z, num_heads, c) * 0.02
    W_v = np.random.randn(c_z, num_heads, c) * 0.02
    W_b = np.random.randn(c_z, num_heads) * 0.02
    W_g = np.random.randn(c_z, num_heads, c) * 0.02
    W_o = np.random.randn(num_heads, c, c_z) * 0.02
    
    # Compute Q, K, V
    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)  # [N, N, H, c]
    k = np.einsum('ijc,chd->ijhd', z_norm, W_k)
    v = np.einsum('ijc,chd->ijhd', z_norm, W_v)
    
    # Bias from pair representation
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)  # [N, N, H]
    
    # Gating
    g = 1 / (1 + np.exp(-np.einsum('ijc,chd->ijhd', z_norm, W_g)))  # Sigmoid
    
    # ENDING NODE: swap the first two axes so the shared column j becomes
    # the leading (batch) axis and attention runs over the row index.
    q_t = q.transpose(1, 0, 2, 3)  # q_t[j, i] = q[i, j]
    k_t = k.transpose(1, 0, 2, 3)  # k_t[j, k] = k[k, j]
    v_t = v.transpose(1, 0, 2, 3)  # v_t[j, k] = v[k, j]
    
    # Attention logits within each column j:
    # attn_logits[j, i, h, k] = q[i, j] . k[k, j] / sqrt(c)
    attn_logits = np.einsum('jihd,jkhd->jihk', q_t, k_t) / np.sqrt(c)
    # Triangle bias from the third edge (k, i), shared across all columns j:
    # b[k, i, h] -> [1, i, H, k]
    attn_logits += b.transpose(1, 2, 0)[None, :, :, :]
    
    # Softmax over k dimension
    attn_logits_max = attn_logits.max(axis=-1, keepdims=True)
    attn_weights = np.exp(attn_logits - attn_logits_max)
    attn_weights /= attn_weights.sum(axis=-1, keepdims=True)
    
    # Apply attention (column-wise)
    attended = np.einsum('jihk,jkhd->jihd', attn_weights, v_t)
    
    # Transpose back to [i, j, H, c]
    attended = attended.transpose(1, 0, 2, 3)
    
    # Apply gating
    gated = g * attended
    
    # Output projection
    output = np.einsum('ijhd,hdc->ijc', gated, W_o)
    
    return output

In [None]:
# Test
N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z)

print("Test Triangle Attention (Ending Node)")
print("="*50)
print(f"Input shape: {z.shape}")

output = triangle_attention_ending_node(z, c=32, num_heads=4)
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {output.shape == z.shape}")

## Source Code Reference

```python
# From AF2-source-code/model/modules.py

class TriangleAttention(hk.Module):
  """Triangle Attention.

  Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
  Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
  """

  def __call__(self, pair_act, pair_mask, is_training=False):
    c = self.config
    
    # For ending node: transpose first
    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)
      pair_mask = jnp.swapaxes(pair_mask, -1, -2)
    
    # ... attention computation ...
    
    # Transpose back
    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)
```
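The snippet above swaps `pair_mask` together with `pair_act`, so that masking follows the same column-wise layout as the attention itself. A minimal NumPy sketch of one way such a mask could be applied is shown below: it turns the 0/1 mask into a large negative additive bias on the key axis. The helper name `mask_to_bias`, the constant `1e9`, and the `[j, i, h, k]` logit layout (borrowed from the NumPy demo in this notebook) are assumptions made for illustration.

```python
import numpy as np

def mask_to_bias(pair_mask, per_column=False):
    """Turn a [N, N] 0/1 pair mask into an additive bias of shape [N, 1, 1, N].

    Hypothetical helper: masked keys receive a large negative logit so that
    their softmax weight vanishes.
    """
    if per_column:
        # Same swap as applied to the pair activations
        pair_mask = np.swapaxes(pair_mask, -1, -2)
    return (1e9 * (pair_mask - 1.0))[:, None, None, :]

rng = np.random.default_rng(0)
N, H = 4, 2
seq_mask = np.array([1.0, 1.0, 1.0, 0.0])          # last residue is padding
pair_mask = seq_mask[:, None] * seq_mask[None, :]  # [N, N]

# Random logits in the [j, i, h, k] layout, plus the mask bias on the key axis
logits = rng.standard_normal((N, N, H, N))
logits = logits + mask_to_bias(pair_mask, per_column=True)

# Numerically stable softmax over k
logits -= logits.max(axis=-1, keepdims=True)
weights = np.exp(logits)
weights /= weights.sum(axis=-1, keepdims=True)

# In every unpadded column, the padded key k=3 receives zero weight
assert np.all(weights[:3, :, :, 3] == 0.0)
```

In the fully padded column every key is masked, so the softmax there stays finite and its output is simply ignored downstream.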