# Algorithm 6: Evoformer Stack

The **Evoformer** is the central neural network component of AlphaFold2, responsible for processing and refining the Multiple Sequence Alignment (MSA) representation and the pair representation. This iterative refinement process enables the model to learn complex evolutionary and structural relationships between residues.

## Overview

The Evoformer stack consists of multiple blocks, each containing several attention and update mechanisms:
- **MSA Row Attention with Pair Bias** (Algorithm 7)
- **MSA Column Attention** (Algorithm 8)
- **MSA Transition** (Algorithm 9)
- **Outer Product Mean** (Algorithm 10)
- **Triangle Multiplication** (Algorithms 11, 12)
- **Triangle Attention** (Algorithms 13, 14)
- **Pair Transition** (Algorithm 15)

## Algorithm Pseudocode

![Evoformer Stack Algorithm](../imgs/algorithms/EvoformerStack.png)

## Key Concepts

### 1. MSA Representation
The MSA representation `m` has shape `[N_seq, N_res, c_m]` where:
- `N_seq`: Number of sequences in the MSA
- `N_res`: Number of residues (sequence length)
- `c_m`: MSA channel dimension (typically 256)

### 2. Pair Representation
The pair representation `z` has shape `[N_res, N_res, c_z]` where:
- `c_z`: Pair channel dimension (typically 128)

This represents pairwise relationships between all residue pairs.
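As a quick shape check, the two representations can be allocated as plain arrays. This is only a sketch using NumPy rather than the JAX arrays used by the model; the toy values of `N_seq` and `N_res` are arbitrary, while `c_m` and `c_z` follow the typical values quoted above.

```python
import numpy as np

# Toy sizes for illustration; N_seq and N_res are arbitrary choices.
N_seq, N_res, c_m, c_z = 8, 16, 256, 128

m = np.zeros((N_seq, N_res, c_m), dtype=np.float32)  # MSA representation
z = np.zeros((N_res, N_res, c_z), dtype=np.float32)  # pair representation

# m[s, i] is the feature vector of residue i in sequence s.
# z[i, j] is the feature vector describing the residue pair (i, j).
print(m.shape)  # (8, 16, 256)
print(z.shape)  # (16, 16, 128)
```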

### 3. Information Flow
The Evoformer enables bidirectional information flow:
- **MSA → Pair**: Via Outer Product Mean, evolutionary information from MSA updates pair representation
- **Pair → MSA**: Via MSA Row Attention with Pair Bias, structural information biases MSA attention

## Source Code Implementation

The Evoformer is implemented in `AF2-source-code/model/modules.py`. Let's examine the key components:

### EvoformerIteration Class

Each iteration of the Evoformer stack is defined in the `EvoformerIteration` class:

```python
class EvoformerIteration(hk.Module):
  """Single iteration (block) of Evoformer stack.

  Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10
  """

  def __init__(self, config, global_config, is_extra_msa,
               name='evoformer_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.is_extra_msa = is_extra_msa
```

The `is_extra_msa` flag determines whether to use regular MSA Column Attention or the more efficient Global Column Attention for processing the extra MSA sequences.

### Main Forward Pass

The `__call__` method implements the forward pass following the algorithm pseudocode:

```python
def __call__(self, activations, masks, is_training=True, safe_key=None):
    c = self.config
    gc = self.global_config

    msa_act, pair_act = activations['msa'], activations['pair']
    msa_mask, pair_mask = masks['msa'], masks['pair']

    dropout_wrapper_fn = functools.partial(
        dropout_wrapper,
        is_training=is_training,
        global_config=gc)

    safe_key, *sub_keys = safe_key.split(10)
    sub_keys = iter(sub_keys)
```

### Step 1: MSA Row Attention with Pair Bias (Line 2)

```python
    # Algorithm 7: MSA row-wise gated self-attention with pair bias
    msa_act = dropout_wrapper_fn(
        MSARowAttentionWithPairBias(
            c.msa_row_attention_with_pair_bias, gc,
            name='msa_row_attention_with_pair_bias'),
        msa_act,
        msa_mask,
        safe_key=next(sub_keys),
        pair_act=pair_act)
```

This step applies self-attention along each row of the MSA, allowing each position to attend to all other positions within the same sequence. The pair representation provides additional bias to guide the attention.
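To make the role of the pair bias concrete, here is a minimal single-head sketch in NumPy. It is not the `MSARowAttentionWithPairBias` module: layer normalisation, gating, multiple heads and masking are all left out, and the function name and weight arguments (`w_q`, `w_k`, `w_v`, `w_b`) are hypothetical names chosen for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def row_attention_with_pair_bias(m, z, w_q, w_k, w_v, w_b):
    """Single-head sketch. m: [N_seq, N_res, c_m], z: [N_res, N_res, c_z]."""
    q, k, v = m @ w_q, m @ w_k, m @ w_v       # each [N_seq, N_res, c]
    bias = (z @ w_b)[..., 0]                  # [N_res, N_res], one scalar per pair (i, j)
    scale = np.sqrt(q.shape[-1])
    # Within each sequence s, residue i attends to residue j; the same
    # pair bias is added for every sequence.
    logits = np.einsum('sic,sjc->sij', q, k) / scale + bias[None]
    attn = softmax(logits, axis=-1)           # normalise over key positions j
    return np.einsum('sij,sjc->sic', attn, v), attn

# Toy usage with random weights (all sizes are arbitrary).
rng = np.random.default_rng(0)
N_seq, N_res, c_m, c_z, c = 4, 6, 8, 5, 3
m = rng.normal(size=(N_seq, N_res, c_m))
z = rng.normal(size=(N_res, N_res, c_z))
out, attn = row_attention_with_pair_bias(
    m, z,
    rng.normal(size=(c_m, c)), rng.normal(size=(c_m, c)),
    rng.normal(size=(c_m, c)), rng.normal(size=(c_z, 1)))
```

Because the bias depends only on the residue pair, every row of the MSA is nudged towards the same attention pattern, which is how pair information reaches the MSA stream in this sketch.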

### Step 2: MSA Column Attention (Line 3)

```python
    # Algorithm 8: MSA column-wise gated self-attention
    if not self.is_extra_msa:
      attn_mod = MSAColumnAttention(
          c.msa_column_attention, gc, name='msa_column_attention')
    else:
      # Algorithm 19: MSA column-wise global attention (for extra MSA)
      attn_mod = MSAColumnGlobalAttention(
          c.msa_column_attention, gc, name='msa_column_global_attention')
    msa_act = dropout_wrapper_fn(
        attn_mod,
        msa_act,
        msa_mask,
        safe_key=next(sub_keys))
```

Column attention allows each residue position to attend across all sequences in the MSA, enabling the model to learn evolutionary patterns.
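A rough way to picture column attention is to swap the sequence and residue axes and then attend over sequences. The NumPy sketch below assumes a single head and omits gating, layer normalisation and masking; `column_attention` and its weight arguments are hypothetical names, not the source's `MSAColumnAttention`.

```python
import numpy as np

def column_attention(m, w_q, w_k, w_v):
    """Single-head sketch. m: [N_seq, N_res, c_m] -> [N_seq, N_res, c]."""
    m_t = np.swapaxes(m, 0, 1)                  # [N_res, N_seq, c_m]
    q, k, v = m_t @ w_q, m_t @ w_k, m_t @ w_v
    # For each residue column i, sequence s attends to sequence t.
    logits = np.einsum('isc,itc->ist', q, k) / np.sqrt(q.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    attn = np.exp(logits)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    out = np.einsum('ist,itc->isc', attn, v)
    return np.swapaxes(out, 0, 1)               # back to [N_seq, N_res, c]

rng = np.random.default_rng(0)
N_seq, N_res, c_m, c = 4, 6, 8, 3
m = rng.normal(size=(N_seq, N_res, c_m))
w_q, w_k, w_v = (rng.normal(size=(c_m, c)) for _ in range(3))
out = column_attention(m, w_q, w_k, w_v)
```

In this sketch each column is processed independently, so changing the features of one residue position leaves the output at every other position untouched.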

### Step 3: MSA Transition (Line 4)

```python
    # Algorithm 9: MSA Transition
    msa_act = dropout_wrapper_fn(
        Transition(c.msa_transition, gc, name='msa_transition'),
        msa_act,
        msa_mask,
        safe_key=next(sub_keys))
```

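The MSA transition is a position-wise feed-forward layer: it is applied independently to every (sequence, residue) entry of the MSA representation and does not mix information across positions.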
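The transition can be sketched as a two-layer MLP over the channel axis. The version below is an illustration only: it skips the layer normalisation applied in the real module, and the widening `factor` of 4, the function name and the weight names are assumptions of this example.

```python
import numpy as np

def transition(x, w1, b1, w2, b2):
    """Position-wise two-layer MLP over the last (channel) axis."""
    hidden = np.maximum(x @ w1 + b1, 0.0)  # widen channels, then ReLU
    return hidden @ w2 + b2                # project back to the input width

rng = np.random.default_rng(0)
N_seq, N_res, c_m, factor = 4, 6, 8, 4    # factor is an assumed widening ratio
m = rng.normal(size=(N_seq, N_res, c_m))
w1 = rng.normal(size=(c_m, factor * c_m))
b1 = np.zeros(factor * c_m)
w2 = rng.normal(size=(factor * c_m, c_m))
b2 = np.zeros(c_m)

update = transition(m, w1, b1, w2, b2)
m_new = m + update  # residual update, as dropout_wrapper does with dropout disabled
```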

### Step 4: Outer Product Mean (Line 5)

```python
    # Algorithm 10: Outer Product Mean
    # Updates pair representation using information from MSA
    pair_act = dropout_wrapper_fn(
        OuterProductMean(
            config=c.outer_product_mean,
            global_config=self.global_config,
            num_output_channel=int(pair_act.shape[-1]),
            name='outer_product_mean'),
        msa_act,
        msa_mask,
        safe_key=next(sub_keys),
        output_act=pair_act)
```

This crucial step computes the outer product of MSA features and averages across sequences to update the pair representation. This is how evolutionary covariance information is transferred to the pair representation.
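The core computation can be sketched with a single `einsum`. This NumPy version is a simplified illustration: it leaves out the layer normalisation and masking, and the function name and the projection weights `w_a`, `w_b`, `w_out` are hypothetical.

```python
import numpy as np

def outer_product_mean(m, w_a, w_b, w_out):
    """m: [N_seq, N_res, c_m] -> pair update of shape [N_res, N_res, c_z]."""
    a = m @ w_a                                   # [N_seq, N_res, c]
    b = m @ w_b                                   # [N_seq, N_res, c]
    n_seq, n_res = m.shape[0], m.shape[1]
    # Outer product of the features at residues i and j, averaged over sequences s.
    outer = np.einsum('sic,sjd->ijcd', a, b) / n_seq
    # Flatten the c x c outer product and project it to the pair channels.
    return outer.reshape(n_res, n_res, -1) @ w_out

rng = np.random.default_rng(0)
N_seq, N_res, c_m, c, c_z = 4, 6, 8, 3, 5
m = rng.normal(size=(N_seq, N_res, c_m))
w_a = rng.normal(size=(c_m, c))
w_b = rng.normal(size=(c_m, c))
w_out = rng.normal(size=(c * c, c_z))

pair_update = outer_product_mean(m, w_a, w_b, w_out)
```

Since the average runs over sequences, entry `(i, j)` of the update summarises how the features at columns `i` and `j` co-vary across the alignment.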

### Step 5: Triangle Multiplications (Lines 6-7)

```python
    # Algorithm 11: Triangle Multiplication (Outgoing)
    pair_act = dropout_wrapper_fn(
        TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                               name='triangle_multiplication_outgoing'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
    
    # Algorithm 12: Triangle Multiplication (Incoming)
    pair_act = dropout_wrapper_fn(
        TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                               name='triangle_multiplication_incoming'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
```

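Triangle multiplication updates edge (i,j) by aggregating, over every third residue k, the other two edges of the triangle (i,j,k). The "outgoing" variant combines edges (i,k) and (j,k), which both leave i and j towards k; the "incoming" variant combines edges (k,i) and (k,j), which both arrive at i and j from k.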
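The difference between the two variants comes down to which index is summed over. The sketch below shows only that contraction in NumPy, assuming the projected edge features `a` and `b` are already given; the gating, layer normalisation and output projection of the real module are omitted, and `triangle_multiply` is a hypothetical name.

```python
import numpy as np

def triangle_multiply(a, b, outgoing=True):
    """a, b: [N_res, N_res, c] projected edge features -> [N_res, N_res, c]."""
    if outgoing:
        # update[i, j] = sum_k a[i, k] * b[j, k]
        return np.einsum('ikc,jkc->ijc', a, b)
    # update[i, j] = sum_k a[k, i] * b[k, j]
    return np.einsum('kic,kjc->ijc', a, b)

rng = np.random.default_rng(0)
N_res, c = 5, 3
a = rng.normal(size=(N_res, N_res, c))
b = rng.normal(size=(N_res, N_res, c))

out_update = triangle_multiply(a, b, outgoing=True)
in_update = triangle_multiply(a, b, outgoing=False)
```

In this formulation the incoming update is the outgoing update applied to the transposed edge features.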

### Step 6: Triangle Attention (Lines 8-9)

```python
    # Algorithm 13: Triangle Attention (Starting Node)
    pair_act = dropout_wrapper_fn(
        TriangleAttention(c.triangle_attention_starting_node, gc,
                          name='triangle_attention_starting_node'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
    
    # Algorithm 14: Triangle Attention (Ending Node)
    pair_act = dropout_wrapper_fn(
        TriangleAttention(c.triangle_attention_ending_node, gc,
                          name='triangle_attention_ending_node'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))
```

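Triangle attention applies gated self-attention over the pair representation, with an additional bias computed from the pair representation itself. The starting-node variant attends along rows (edges that share the start node i), while the ending-node variant attends along columns (edges that share the end node j).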
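The two variants can be sketched as one attention routine plus a transpose. This is a single-head NumPy illustration without gating, layer normalisation or masking; the function and weight names are hypothetical and the weights are random.

```python
import numpy as np

def _softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def triangle_attention_starting_node(z, w_q, w_k, w_v, w_b):
    """Edge (i, j) attends to edges (i, k); the bias comes from edge (j, k)."""
    q, k, v = z @ w_q, z @ w_k, z @ w_v        # each [N_res, N_res, c]
    bias = (z @ w_b)[..., 0]                   # bias[j, k], one scalar per edge
    logits = np.einsum('ijc,ikc->ijk', q, k) / np.sqrt(q.shape[-1]) + bias[None]
    return np.einsum('ijk,ikc->ijc', _softmax(logits), v)

def triangle_attention_ending_node(z, w_q, w_k, w_v, w_b):
    """Same routine on the transposed pair representation, transposed back."""
    z_t = np.swapaxes(z, 0, 1)
    out_t = triangle_attention_starting_node(z_t, w_q, w_k, w_v, w_b)
    return np.swapaxes(out_t, 0, 1)

rng = np.random.default_rng(0)
N_res, c_z, c = 5, 4, 3
z = rng.normal(size=(N_res, N_res, c_z))
w_q, w_k, w_v = (rng.normal(size=(c_z, c)) for _ in range(3))
w_b = rng.normal(size=(c_z, 1))

start = triangle_attention_starting_node(z, w_q, w_k, w_v, w_b)
end = triangle_attention_ending_node(z, w_q, w_k, w_v, w_b)
```

Written out, the ending-node output for edge (i, j) attends over edges (k, j) with a bias taken from edge (k, i), which is the column-wise counterpart of the starting-node update.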

### Step 7: Pair Transition (Line 10)

```python
    # Algorithm 15: Pair Transition
    pair_act = dropout_wrapper_fn(
        Transition(c.pair_transition, gc, name='pair_transition'),
        pair_act,
        pair_mask,
        safe_key=next(sub_keys))

    return {'msa': msa_act, 'pair': pair_act}
```

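The pair transition is a position-wise feed-forward layer applied to every residue pair. It is the last update to the pair representation in this iteration, after which the block returns the updated MSA and pair activations.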

## Stacking Evoformer Blocks

The full Evoformer stack is created in `EmbeddingsAndEvoformer` using `layer_stack`:

```python
# Main trunk of the network
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18
evoformer_iteration = EvoformerIteration(
    c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

def evoformer_fn(x):
  act, safe_key = x
  safe_key, safe_subkey = safe_key.split()
  evoformer_output = evoformer_iteration(
      activations=act,
      masks=evoformer_masks,
      is_training=is_training,
      safe_key=safe_subkey)
  return (evoformer_output, safe_key)

if gc.use_remat:
  evoformer_fn = hk.remat(evoformer_fn)

evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
    evoformer_fn)
evoformer_output, safe_key = evoformer_stack(
    (evoformer_input, safe_key))
```

The default configuration uses **48 Evoformer blocks** for the main MSA and **4 blocks** for the extra MSA stack.
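Conceptually, stacking amounts to threading the activations through the same block function a fixed number of times. The plain-Python sketch below shows only that control flow; it does not reproduce how `layer_stack` creates separate parameters for each block or how rematerialisation is applied, and `run_stack` and `toy_block` are hypothetical names.

```python
def run_stack(block_fn, carry, num_block):
    """Apply block_fn num_block times, feeding each output into the next call."""
    for _ in range(num_block):
        carry = block_fn(carry)
    return carry

def toy_block(activations):
    # Stand-in for one EvoformerIteration: it just counts how often it ran.
    return {'msa': activations['msa'] + 1, 'pair': activations['pair'] + 1}

output = run_stack(toy_block, {'msa': 0, 'pair': 0}, num_block=48)
print(output)  # {'msa': 48, 'pair': 48}
```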

## Dropout Wrapper

The `dropout_wrapper` function implements residual connections with dropout:

```python
def dropout_wrapper(module,
                    input_act,
                    mask,
                    safe_key,
                    global_config,
                    output_act=None,
                    is_training=True,
                    **kwargs):
  """Applies module + dropout + residual update."""
  if output_act is None:
    output_act = input_act

  gc = global_config
  residual = module(input_act, mask, is_training=is_training, **kwargs)
  dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate

  if module.config.shared_dropout:
    if module.config.orientation == 'per_row':
      broadcast_dim = 0
    else:
      broadcast_dim = 1
  else:
    broadcast_dim = None

  residual = apply_dropout(tensor=residual,
                           safe_key=safe_key,
                           rate=dropout_rate,
                           is_training=is_training,
                           broadcast_dim=broadcast_dim)

  new_act = output_act + residual
  return new_act
```
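The effect of `broadcast_dim` is easiest to see on a small array. The following NumPy sketch mimics the idea of sampling a mask with one axis of size 1 and broadcasting it; it is an illustration under that assumption, not the `apply_dropout` function from the source, and `shared_dropout` is a hypothetical name.

```python
import numpy as np

def shared_dropout(tensor, rate, rng, broadcast_dim=None):
    """Dropout whose mask is shared along broadcast_dim (if given)."""
    shape = list(tensor.shape)
    if broadcast_dim is not None:
        shape[broadcast_dim] = 1             # one mask entry reused along this axis
    keep_rate = 1.0 - rate
    keep = rng.random(tuple(shape)) < keep_rate
    return keep * tensor / keep_rate         # rescale so the expected value is unchanged

rng = np.random.default_rng(0)
x = np.ones((4, 6, 3))
dropped = shared_dropout(x, rate=0.25, rng=rng, broadcast_dim=0)
```

With `broadcast_dim=0` the mask has shape `(1, 6, 3)`, so every slice `dropped[s]` has zeros in exactly the same places.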

## Key Design Principles

1. **Bidirectional Information Flow**: MSA and pair representations mutually inform each other

2. **Triangle Updates**: Encourage geometric consistency in the pair representation, in the spirit of the triangle inequality (if i is close to j and j is close to k, then i cannot be far from k)

3. **Residual Connections**: Every sub-module uses residual connections for stable training

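4. **Shared Dropout**: When `shared_dropout` is set, a single dropout mask is broadcast along one axis (`broadcast_dim` 0 for `per_row`, 1 otherwise), so the same entries are dropped along that axis instead of being sampled independently per element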

5. **Gradient Checkpointing**: `hk.remat` is used to reduce memory consumption during training

## Summary

The Evoformer stack is the core innovation of AlphaFold2, enabling:

1. **Evolutionary information extraction** from MSA via column attention
2. **Structural constraint propagation** through triangle updates
3. **Coevolution signal extraction** via outer product mean
4. **Deep feature refinement** through 48 stacked blocks

The output representations from the Evoformer are then used by the Structure Module to generate the final 3D coordinates.