# Algorithm 14: Triangle Attention (Ending Node)

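Triangle attention around the ending node attends along columns. Each edge z[i,j] attends over the edges that share its ending node j, gathering information from z[k,j] for all k. Its attention logits also receive a bias from z[k,i], the third edge of the triangle formed by residues i, j and k.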

## Algorithm Pseudocode

![TriangleAttentionEndingNode](../imgs/algorithms/TriangleAttentionEndingNode.png)

## Source Code Location
- **File**: `AF2-source-code/model/modules.py`
- **Class**: `TriangleAttention`
- **Lines**: 965-1045

## Overview

### Comparison with Starting Node (Algorithm 13)

| Aspect | Starting Node (Alg 13) | Ending Node (Alg 14) |
|--------|------------------------|----------------------|
| Orientation | Row-wise (per_row) | Column-wise (per_column) |
| Attention pattern | z[i,j] attends to z[i,k] | z[i,j] attends to z[k,j] |
| Fixed node | Starting node i | Ending node j |
| Triangle edges | Edges sharing source i | Edges sharing target j |
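For reference, here is the attention step of each algorithm, restated in the notation of the supplementary pseudocode. Here h indexes heads, c is the head dimension, and q, k, v, b, g are per-edge projections of the layer-normalized z:

$$
\text{Alg. 13:}\quad \alpha^h_{ijk} = \operatorname{softmax}_k\!\left(\frac{1}{\sqrt{c}}\,{\mathbf{q}^h_{ij}}^{\top}\mathbf{k}^h_{ik} + b^h_{jk}\right),\qquad \mathbf{o}^h_{ij} = \mathbf{g}^h_{ij}\odot\sum_k \alpha^h_{ijk}\,\mathbf{v}^h_{ik}
$$

$$
\text{Alg. 14:}\quad \alpha^h_{ijk} = \operatorname{softmax}_k\!\left(\frac{1}{\sqrt{c}}\,{\mathbf{q}^h_{ij}}^{\top}\mathbf{k}^h_{kj} + b^h_{ki}\right),\qquad \mathbf{o}^h_{ij} = \mathbf{g}^h_{ij}\odot\sum_k \alpha^h_{ijk}\,\mathbf{v}^h_{kj}
$$

In both cases the output is $\tilde{\mathbf{z}}_{ij} = \operatorname{Linear}(\operatorname{concat}_h \mathbf{o}^h_{ij})$.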

### Algorithm Steps

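1. Layer-normalize the pair representation
2. Project to Q, K, V, a per-head pair bias, and sigmoid gates
3. **Transpose** the two residue axes for column-wise attention
4. Compute attention with the pair bias and mask
5. **Transpose back** to the original orientation
6. Apply the gates and the output projection

Layer normalization and all projections act on each edge independently. The first transpose can therefore equally be applied to the raw input before step 1, which is where the AF2 source places it.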
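The two transposes are what turn a row-wise routine into column-wise attention. The sketch below uses plain dot-product self-attention (no projections, bias, or gating) as a stand-in for the full layer; `row_self_attention` and the other helper names are hypothetical. It shows that swapping the residue axes, attending within rows, and swapping back gives the same result as attending within each column directly.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def row_self_attention(x):
    """Plain dot-product self-attention within each row a of x [N, N, C]:
    out[a, i] = sum_k softmax_k(x[a, i] . x[a, k]) * x[a, k]."""
    logits = np.einsum('aic,akc->aik', x, x)
    return np.einsum('aik,akc->aic', softmax(logits, axis=-1), x)

def column_attention_via_transpose(x):
    # Swap the two residue axes, attend along rows, swap back
    return row_self_attention(x.transpose(1, 0, 2)).transpose(1, 0, 2)

def column_attention_direct(x):
    # Reference: attend within each column j explicitly
    out = np.empty_like(x)
    for j in range(x.shape[1]):
        col = x[:, j, :]                          # edges (k, j) for all k: [N, C]
        weights = softmax(col @ col.T, axis=-1)   # [i, k]
        out[:, j, :] = weights @ col
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5, 3))
print(np.allclose(column_attention_via_transpose(x), column_attention_direct(x)))  # True
```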

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
def layer_norm(x, axis=-1, eps=1e-5):
    """Layer normalization (no learned scale or offset)."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)


def sigmoid(x):
    """Sigmoid activation."""
    return 1.0 / (1.0 + np.exp(-x))


def triangle_attention_ending_node(
    z,
    z_mask=None,
    c=32,
    num_heads=4
):
    """
    Triangle Attention (Ending Node) - Algorithm 14.

    Attention along columns: z[i,j] attends to z[k,j] for all k,
    with a pair bias taken from z[k,i] (the third edge of triangle ijk).

    Weights are drawn from np.random on every call as stand-ins for
    learned parameters; seed the RNG before each call for repeatable output.

    Args:
        z: Pair representation [N_res, N_res, c_z]
        z_mask: Pair mask [N_res, N_res] (optional)
        c: Head dimension
        num_heads: Number of attention heads

    Returns:
        Updated z [N_res, N_res, c_z]
    """
    N_res, _, c_z = z.shape

    print("Triangle Attention (Ending Node)")
    print("=" * 50)
    print(f"Input: [{N_res}, {N_res}, {c_z}]")
    print(f"Heads: {num_heads}, Head dim: {c}")

    if z_mask is None:
        z_mask = np.ones((N_res, N_res), dtype=np.float32)

    # Step 1: Layer normalization (Alg. 14, line 1)
    z_norm = layer_norm(z, axis=-1)
    print(f"\nStep 1 - Layer norm: {z_norm.shape}")

    # Step 2: Random projection weights (stand-ins for learned parameters)
    W_q = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_k = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_v = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    W_g = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_o = np.random.randn(num_heads, c, c_z) * ((num_heads * c) ** -0.5)
    b_g = np.zeros((num_heads, c))

    # Step 3: Compute Q, K, V (Alg. 14, line 2)
    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)  # [N, N, H, c]
    k = np.einsum('ijc,chd->ijhd', z_norm, W_k)
    v = np.einsum('ijc,chd->ijhd', z_norm, W_v)

    print(f"Step 3 - Q,K,V: {q.shape}")

    # Step 4: Per-head pair bias (Alg. 14, line 3)
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)  # [N, N, H]

    # Step 5: Gating values (Alg. 14, line 4)
    g_logits = np.einsum('ijc,chd->ijhd', z_norm, W_g) + b_g
    g = sigmoid(g_logits)  # [N, N, H, c]

    print(f"Step 5 - Gating: {g.shape}")

    # Step 6: ENDING NODE - swap the two residue axes so columns become rows
    # (an implementation detail, not a line of the pseudocode).
    # After the swap, x_t[j, i] = x[i, j].
    q_t = q.transpose(1, 0, 2, 3)  # [j, i, H, c]
    k_t = k.transpose(1, 0, 2, 3)  # [j, k, H, c]
    v_t = v.transpose(1, 0, 2, 3)  # [j, k, H, c]
    b_t = b.transpose(1, 0, 2)     # b_t[i, k, h] = b[k, i, h]
    mask_t = z_mask.T              # mask_t[j, k] = z_mask[k, j]

    print("Step 6 - Transposed for column attention")

    # Step 7: Attention weights (Alg. 14, line 5)
    # attn[j,i,h,k] = softmax_k(q[i,j,h] . k[k,j,h] / sqrt(c) + b[k,i,h])

    # Attention logits: [j, i, H, k]
    attn_logits = np.einsum('jihd,jkhd->jihk', q_t, k_t) / np.sqrt(c)

    # Add bias b[k, i, h]: it depends on (i, k) only and is shared by every
    # column j, so reshape b_t [i, k, H] -> [1, i, H, k]
    attn_logits = attn_logits + b_t.transpose(0, 2, 1)[None, :, :, :]

    # Mask keys: masked positions get a large negative logit (-1e9)
    mask_t_expanded = mask_t[:, None, None, :]  # [j, 1, 1, k]
    attn_logits = np.where(mask_t_expanded > 0, attn_logits, -1e9)

    # Softmax over k (last axis)
    attn_weights = softmax(attn_logits, axis=-1)

    print(f"Step 7 - Attention weights: {attn_weights.shape}")

    # Step 8: Weighted sum of values (Alg. 14, line 6)
    # attended[j,i,h,d] = sum_k attn[j,i,h,k] * v[k,j,h,d]
    attended = np.einsum('jihk,jkhd->jihd', attn_weights, v_t)

    # Step 9: Transpose back to original orientation
    attended = attended.transpose(1, 0, 2, 3)  # [i, j, H, c]

    print(f"Step 9 - Attended (transposed back): {attended.shape}")

    # Step 10: Apply gating (Alg. 14, line 6)
    gated = g * attended

    # Step 11: Output projection over concatenated heads (Alg. 14, line 7)
    output = np.einsum('ijhd,hdc->ijc', gated, W_o)

    print(f"Step 11 - Output: {output.shape}")

    return output
```
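The transposes make the bias indexing easy to get wrong. It helps to check a vectorized formula against the equations written out as loops. The minimal sketch below makes these assumptions: weights are passed in explicitly, the input `zn` is already layer-normalized, and the mask and gate bias are omitted. The names `alg14_vectorized` and `alg14_loops` are hypothetical. The vectorized form uses the index pattern q_ij · k_kj + b_ki directly, without transposes.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def alg14_vectorized(zn, W_q, W_k, W_v, W_b, W_g, W_o):
    """zn: layer-normalized pair representation [N, N, c_z]."""
    c = W_q.shape[-1]
    q = np.einsum('ijc,chd->ijhd', zn, W_q)
    k = np.einsum('ijc,chd->ijhd', zn, W_k)
    v = np.einsum('ijc,chd->ijhd', zn, W_v)
    b = np.einsum('ijc,ch->ijh', zn, W_b)
    g = 1.0 / (1.0 + np.exp(-np.einsum('ijc,chd->ijhd', zn, W_g)))
    logits = np.einsum('ijhd,kjhd->ijhk', q, k) / np.sqrt(c)   # q_ij . k_kj
    logits = logits + b.transpose(1, 2, 0)[:, None, :, :]      # + b_ki as [i, 1, h, k]
    attn = softmax(logits, axis=-1)
    o = g * np.einsum('ijhk,kjhd->ijhd', attn, v)              # sum_k alpha_ijk v_kj
    return np.einsum('ijhd,hdc->ijc', o, W_o)

def alg14_loops(zn, W_q, W_k, W_v, W_b, W_g, W_o):
    """Same equations, one edge (i, j) and one head h at a time."""
    N = zn.shape[0]
    H, c = W_q.shape[1], W_q.shape[2]
    out = np.zeros((N, N, W_o.shape[-1]))
    for i in range(N):
        for j in range(N):
            o = np.zeros((H, c))
            for h in range(H):
                q_ij = zn[i, j] @ W_q[:, h]
                logits = np.array([
                    q_ij @ (zn[kk, j] @ W_k[:, h]) / np.sqrt(c) + zn[kk, i] @ W_b[:, h]
                    for kk in range(N)])
                a = softmax(logits)
                v_col = np.stack([zn[kk, j] @ W_v[:, h] for kk in range(N)])  # [N, c]
                g_ij = 1.0 / (1.0 + np.exp(-(zn[i, j] @ W_g[:, h])))
                o[h] = g_ij * (a @ v_col)
            out[i, j] = np.einsum('hd,hdc->c', o, W_o)
    return out

rng = np.random.default_rng(0)
N, c_z, H, c = 5, 8, 2, 4
zn = rng.standard_normal((N, N, c_z))
W = [rng.standard_normal(s) * 0.3 for s in
     [(c_z, H, c), (c_z, H, c), (c_z, H, c), (c_z, H), (c_z, H, c), (H, c, c_z)]]
print(np.allclose(alg14_vectorized(zn, *W), alg14_loops(zn, *W)))  # True
```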

## Test Examples

```python
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("=" * 60)

N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

output = triangle_attention_ending_node(z, c=32, num_heads=4)

print(f"\nInput shape: {z.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {output.shape == z.shape}")
```

```python
# Test 2: With masking
print("\nTest 2: With Masking")
print("=" * 60)

N_res, c_z = 20, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# Mask the last 4 residues
z_mask = np.ones((N_res, N_res), dtype=np.float32)
z_mask[:, -4:] = 0  # Mask columns 16-19
z_mask[-4:, :] = 0  # Mask rows 16-19

output = triangle_attention_ending_node(z, z_mask=z_mask, c=32, num_heads=4)

print(f"\nMasked positions: {(z_mask == 0).sum()}")
print(f"Output shape: {output.shape}")
```
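In the transposed frame, the mask acts on the keys k of each column j through mask_t[j, k] = z_mask[k, j]. The small self-contained sketch below uses a hypothetical helper, `masked_softmax`, that applies the same -1e9 replacement as the implementation above. It shows two consequences for the mask in Test 2. First, a masked key receives zero weight. Second, a query whose keys are all masked (here, every query in columns 16-19) falls back to uniform weights rather than producing NaN.

```python
import numpy as np

def masked_softmax(logits, key_mask):
    """Softmax over the last axis; keys with key_mask == 0 get logit -1e9."""
    logits = np.where(key_mask > 0, logits, -1e9)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5, 3.0],
                   [2.0, 1.0, 0.5, 3.0]])
key_mask = np.array([[1, 1, 1, 0],    # query 0: last key masked
                     [0, 0, 0, 0]])   # query 1: every key masked (like columns 16-19)

w = masked_softmax(logits, key_mask)
print(w[0, 3])   # 0.0: the masked key gets no weight
print(w[1])      # [0.25 0.25 0.25 0.25]: uniform fallback, no NaN
```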

```python
# Test 3: Compare with starting node attention pattern
print("\nTest 3: Attention Pattern Analysis")
print("=" * 60)

def triangle_attention_starting_node(z, c=32, num_heads=4):
    """Starting node version (Algorithm 13) for comparison; no mask.

    Draws its random weights in the same order as
    triangle_attention_ending_node, so both use identical weights
    when the RNG is reseeded before each call.
    """
    N_res, _, c_z = z.shape

    z_norm = layer_norm(z, axis=-1)

    W_q = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_k = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_v = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_b = np.random.randn(c_z, num_heads) * (c_z ** -0.5)
    W_g = np.random.randn(c_z, num_heads, c) * (c_z ** -0.5)
    W_o = np.random.randn(num_heads, c, c_z) * ((num_heads * c) ** -0.5)

    q = np.einsum('ijc,chd->ijhd', z_norm, W_q)
    k = np.einsum('ijc,chd->ijhd', z_norm, W_k)
    v = np.einsum('ijc,chd->ijhd', z_norm, W_v)
    b = np.einsum('ijc,ch->ijh', z_norm, W_b)
    g = sigmoid(np.einsum('ijc,chd->ijhd', z_norm, W_g))

    # Starting node: row-wise attention (no transpose)
    # attn[i,j,h,k] = softmax_k(q[i,j,h] . k[i,k,h] / sqrt(c) + b[j,k,h])
    attn_logits = np.einsum('ijhd,ikhd->ijhk', q, k) / np.sqrt(c)
    # Bias b[j, k, h] depends on (j, k) only and is shared by every row i:
    # reshape [j, k, H] -> [1, j, H, k]
    attn_logits = attn_logits + b.transpose(0, 2, 1)[None, :, :, :]
    attn_weights = softmax(attn_logits, axis=-1)

    attended = np.einsum('ijhk,ikhd->ijhd', attn_weights, v)
    gated = g * attended
    output = np.einsum('ijhd,hdc->ijc', gated, W_o)

    return output

# Same input, and (via reseeding) same weights, for both
np.random.seed(42)
N_res, c_z = 8, 32
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

np.random.seed(42)
out_starting = triangle_attention_starting_node(z, c=16, num_heads=2)

np.random.seed(42)
out_ending = triangle_attention_ending_node(z, c=16, num_heads=2)

print(f"Starting node output shape: {out_starting.shape}")
print(f"Ending node output shape: {out_ending.shape}")
print(f"\nOutputs are different (as expected): {not np.allclose(out_starting, out_ending)}")
```
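The two outputs differ, but they are tightly related. With shared weights, ending-node attention on z equals starting-node attention on the transposed pair representation, transposed back: Alg14(z) = Alg13(zᵀ)ᵀ. This is why a single class with an orientation switch can serve both. The self-contained sketch below assumes explicit fixed weights and an already normalized input, and omits the mask. The names `starting_node`, `ending_node` and `project` are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def project(zn, W_q, W_k, W_v, W_b, W_g):
    # Per-edge projections shared by both orientations
    q, k, v, g = (np.einsum('ijc,chd->ijhd', zn, M) for M in (W_q, W_k, W_v, W_g))
    b = np.einsum('ijc,ch->ijh', zn, W_b)
    return q, k, v, b, 1.0 / (1.0 + np.exp(-g))

def starting_node(zn, W_q, W_k, W_v, W_b, W_g, W_o):
    """Alg. 13: softmax_k(q_ij . k_ik / sqrt(c) + b_jk), values v_ik."""
    q, k, v, b, g = project(zn, W_q, W_k, W_v, W_b, W_g)
    logits = np.einsum('ijhd,ikhd->ijhk', q, k) / np.sqrt(q.shape[-1])
    logits = logits + b.transpose(0, 2, 1)[None]          # b[j, k, h] -> [1, j, h, k]
    o = g * np.einsum('ijhk,ikhd->ijhd', softmax(logits), v)
    return np.einsum('ijhd,hdc->ijc', o, W_o)

def ending_node(zn, W_q, W_k, W_v, W_b, W_g, W_o):
    """Alg. 14: softmax_k(q_ij . k_kj / sqrt(c) + b_ki), values v_kj."""
    q, k, v, b, g = project(zn, W_q, W_k, W_v, W_b, W_g)
    logits = np.einsum('ijhd,kjhd->ijhk', q, k) / np.sqrt(q.shape[-1])
    logits = logits + b.transpose(1, 2, 0)[:, None]       # b[k, i, h] -> [i, 1, h, k]
    o = g * np.einsum('ijhk,kjhd->ijhd', softmax(logits), v)
    return np.einsum('ijhd,hdc->ijc', o, W_o)

rng = np.random.default_rng(1)
N, c_z, H, c = 6, 8, 2, 4
zn = rng.standard_normal((N, N, c_z))
W = [rng.standard_normal(s) * 0.3 for s in
     [(c_z, H, c), (c_z, H, c), (c_z, H, c), (c_z, H), (c_z, H, c), (H, c, c_z)]]

via_transpose = starting_node(zn.transpose(1, 0, 2), *W).transpose(1, 0, 2)
print(np.allclose(ending_node(zn, *W), via_transpose))          # True
print(np.allclose(ending_node(zn, *W), starting_node(zn, *W)))  # False
```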

```python
# Test 4: Attention direction (indexing illustration)
print("\nTest 4: Verify Attention Direction")
print("=" * 60)

# Structured input that makes the indexing visible.
# Note: every channel of an edge holds the same value, so layer norm would map
# each edge to zero; this array illustrates indexing only and is not fed to the layer.
N_res = 4
c_z = 8

# z[k, j] holds j*10 + k: the tens digit tags column j, the units digit row k
z = np.zeros((N_res, N_res, c_z))
for j in range(N_res):
    for k in range(N_res):
        z[k, j, :] = j * 10 + k

print("Input structure (first channel):")
print(z[:, :, 0])
print("\nEntry [k, j] holds column_index * 10 + row_index")
print("Column 0: [0, 1, 2, 3], Column 1: [10, 11, 12, 13], etc.")
print("\nEnding node attention at position [i, j] attends to all [k, j] (same column)")
```
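To check the direction numerically, one can perturb a single edge z[p, r] and see which outputs move. For the ending node, the perturbation should affect only column r (through keys and values) and row r (through the bias b_ki with k = p, i = r). For the starting node, it should affect only row p and column p. The sketch below assumes fixed explicit weights and no mask, and the names `triangle_attention` and `changed_positions` are hypothetical. It perturbs one channel only, because adding the same amount to every channel of an edge is removed by layer normalization.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def triangle_attention(z, W, ending):
    """Alg. 13 (ending=False) or Alg. 14 (ending=True) with fixed weights W."""
    W_q, W_k, W_v, W_b, W_g, W_o = W
    zn = layer_norm(z)
    q, k, v, g = (np.einsum('ijc,chd->ijhd', zn, M) for M in (W_q, W_k, W_v, W_g))
    b = np.einsum('ijc,ch->ijh', zn, W_b)
    scale = 1.0 / np.sqrt(q.shape[-1])
    if ending:   # q_ij . k_kj + b_ki, values v_kj
        logits = np.einsum('ijhd,kjhd->ijhk', q, k) * scale + b.transpose(1, 2, 0)[:, None]
        o = np.einsum('ijhk,kjhd->ijhd', softmax(logits), v)
    else:        # q_ij . k_ik + b_jk, values v_ik
        logits = np.einsum('ijhd,ikhd->ijhk', q, k) * scale + b.transpose(0, 2, 1)[None]
        o = np.einsum('ijhk,ikhd->ijhd', softmax(logits), v)
    o = o / (1.0 + np.exp(-g))   # sigmoid gate
    return np.einsum('ijhd,hdc->ijc', o, W_o)

def changed_positions(z, W, ending, p, r):
    """Boolean [N, N] map of outputs that change when edge (p, r) is perturbed."""
    z2 = z.copy()
    z2[p, r, 0] += 0.5           # perturb a single channel
    diff = np.abs(triangle_attention(z2, W, ending) - triangle_attention(z, W, ending))
    return diff.max(axis=-1) > 1e-9

rng = np.random.default_rng(0)
N, c_z, H, c = 6, 8, 2, 4
z = rng.standard_normal((N, N, c_z))
W = [rng.standard_normal(s) * 0.3 for s in
     [(c_z, H, c), (c_z, H, c), (c_z, H, c), (c_z, H), (c_z, H, c), (H, c, c_z)]]

p, r = 1, 4
rows, cols = np.indices((N, N))
ending_map = changed_positions(z, W, True, p, r)
starting_map = changed_positions(z, W, False, p, r)
print(np.array_equal(ending_map, (cols == r) | (rows == r)))    # True: column r and row r
print(np.array_equal(starting_map, (rows == p) | (cols == p)))  # True: row p and column p
```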

## Verification: Key Properties

```python
print("Verification: Key Properties")
print("=" * 60)

np.random.seed(42)
N_res, c_z = 16, 64
z = np.random.randn(N_res, N_res, c_z).astype(np.float32)

# The layer draws random weights on every call; reseed before each call so
# that the comparisons below use identical weights.
np.random.seed(0)
output = triangle_attention_ending_node(z, c=32, num_heads=4)

# Property 1: Shape preserved
shape_preserved = output.shape == z.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Finite output
output_finite = np.isfinite(output).all()
print(f"Property 2 - Output finite: {output_finite}")

# Property 3: Not identity (actual computation happened)
not_identity = not np.allclose(output, z)
print(f"Property 3 - Not identity: {not_identity}")

# Property 4: Reasonable magnitude
magnitude_ok = np.abs(output).mean() < 100
print(f"Property 4 - Reasonable magnitude: {magnitude_ok} (mean={np.abs(output).mean():.4f})")

# Property 5: Output responds to the input (same weights, perturbed input).
# Perturb a single channel: adding the same amount to every channel of an
# edge would be removed by layer normalization.
z_perturbed = z.copy()
z_perturbed[5, 10, 0] += 0.1
np.random.seed(0)
output_perturbed = triangle_attention_ending_node(z_perturbed, c=32, num_heads=4)
input_sensitive = not np.allclose(output, output_perturbed)
print(f"Property 5 - Output responds to perturbation: {input_sensitive}")
```
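Another property worth checking is permutation equivariance. Relabeling the residues by a permutation applied to both residue axes of z should relabel the output the same way, because every operation is either per-edge or a sum over all k. The self-contained sketch below uses fixed weights, omits the mask, and defines a hypothetical `ending_node` helper.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def ending_node(z, W_q, W_k, W_v, W_b, W_g, W_o):
    """Alg. 14 with fixed weights (no mask): softmax_k(q_ij . k_kj / sqrt(c) + b_ki)."""
    zn = layer_norm(z)
    q, k, v, g = (np.einsum('ijc,chd->ijhd', zn, M) for M in (W_q, W_k, W_v, W_g))
    b = np.einsum('ijc,ch->ijh', zn, W_b)
    logits = np.einsum('ijhd,kjhd->ijhk', q, k) / np.sqrt(q.shape[-1])
    logits = logits + b.transpose(1, 2, 0)[:, None]       # b[k, i, h] -> [i, 1, h, k]
    o = np.einsum('ijhk,kjhd->ijhd', softmax(logits), v)
    o = o / (1.0 + np.exp(-g))                            # sigmoid gate
    return np.einsum('ijhd,hdc->ijc', o, W_o)

rng = np.random.default_rng(2)
N, c_z, H, c = 7, 8, 2, 4
z = rng.standard_normal((N, N, c_z))
W = [rng.standard_normal(s) * 0.3 for s in
     [(c_z, H, c), (c_z, H, c), (c_z, H, c), (c_z, H), (c_z, H, c), (H, c, c_z)]]
perm = rng.permutation(N)

out = ending_node(z, *W)
out_of_permuted = ending_node(z[perm][:, perm], *W)   # permute both residue axes
print(np.allclose(out_of_permuted, out[perm][:, perm]))  # True
```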

## Source Code Reference

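```python
# Condensed paraphrase of TriangleAttention in AF2-source-code/model/modules.py
# (not verbatim: head reshapes, sub-batching, parameter creation and the
# separate Attention module are folded in or elided).

class TriangleAttention(hk.Module):
  """Triangle Attention.

  Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode"
  Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode"
  """

  def __call__(self, pair_act, pair_mask, is_training=False):
    c = self.config  # module config (not the head dimension)

    # For ending node: swap the residue axes first so columns become rows
    if c.orientation == 'per_column':
      pair_act = jnp.swapaxes(pair_act, -2, -3)
      pair_mask = jnp.swapaxes(pair_mask, -1, -2)

    # Mask bias: large negative logit for masked keys
    bias = (1e9 * (pair_mask - 1.))[:, None, None, :]

    pair_act = hk.LayerNorm(...)(pair_act)

    # Per-head pair bias from the (swapped) pair representation, shared
    # across the batch axis; feat_2d_weights is a learned [c_z, num_head] matrix
    nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, feat_2d_weights)

    # Standard attention, batched over the first axis;
    # q, k, v are reshaped to [batch, query/key, num_head, head_dim]
    q = Linear(num_head * key_dim, name='query_projection')(pair_act)
    k = Linear(num_head * key_dim, name='key_projection')(pair_act)
    v = Linear(num_head * value_dim, name='value_projection')(pair_act)

    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) / jnp.sqrt(key_dim)
    logits += bias + nonbatched_bias
    weights = jax.nn.softmax(logits)
    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)

    # Gating, then output projection over the concatenated heads
    gate = jax.nn.sigmoid(Linear(num_head * value_dim)(pair_act))
    output = Linear(c.pair_channel)(gate * weighted_avg)

    # Transpose the output back for ending node
    if c.orientation == 'per_column':
      output = jnp.swapaxes(output, -2, -3)

    return output
```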
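The subtle part of this excerpt is the frame bookkeeping. The per-head pair bias (`nonbatched_bias`) is computed after the swap, from the swapped tensor, with an einsum of the form `'qkc,ch->hqk'`. The NumPy sketch below checks that, read back in the original frame, this bias is b_ki, as Alg. 14 requires, and that it does not depend on the batch axis (the ending node j). Shapes and weights are hypothetical; `W_bias` stands in for the learned bias weights. Layer normalization is skipped because it acts on each edge independently and does not change the indexing.

```python
import numpy as np

rng = np.random.default_rng(3)
N, c_z, H = 5, 8, 2
z = rng.standard_normal((N, N, c_z))
W_bias = rng.standard_normal((c_z, H))      # hypothetical learned bias weights

b = np.einsum('ijc,ch->ijh', z, W_bias)     # b[x, y, h] in the original frame

z_swapped = np.swapaxes(z, -2, -3)          # z_swapped[j, i] = z[i, j]
nonbatched_bias = np.einsum('qkc,ch->hqk', z_swapped, W_bias)  # [H, q, k]

# In the swapped frame the query axis is i and the key axis is k, and there is
# no batch (j) axis at all: nonbatched_bias[h, i, k] should equal b[k, i, h].
print(np.allclose(nonbatched_bias, b.transpose(2, 1, 0)))  # True: the bias is b_ki
```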

## Key Insights

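1. **Column-wise Attention**: Transposing the two residue axes before attention, and transposing the result back afterwards, turns the row-wise (starting node) computation into column-wise (ending node) attention.

2. **Triangle Rule**: For edge ij, the keys and values come from the edges kj and the pair bias from the edge ki. Each update therefore draws on all three edges of the triangle formed by residues i, j and k. The motivation is geometric consistency: a pair representation that describes a single 3D structure must satisfy constraints such as the triangle inequality on distances.

3. **Complementary Algorithms**: Algorithms 13 (starting) and 14 (ending) are used together in each Evoformer block, so every edge aggregates information from the edges that share either of its endpoints.

4. **Gating Mechanism**: A sigmoid gate computed from the normalized input scales each head's output element-wise before the output projection. This lets the network control information flow, much like LSTM/GRU gates.

5. **Shared Implementation**: Both starting and ending node attention use the same underlying `TriangleAttention` class, distinguished only by the `orientation` config parameter.

6. **Pair Bias**: The bias b_ki is a per-head linear projection of the normalized pair representation. It does not depend on j, so every query z[i,j] in row i receives the same bias over k. Through this bias, the third edge of each triangle steers which edges kj are attended to.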