# Algorithm 1: MSA Block Deletion

MSA Block Deletion is a data augmentation technique used during training. It randomly deletes contiguous blocks of sequences from the MSA to improve model robustness and prevent overfitting.

## Algorithm Pseudocode

![MSA Block Deletion](../imgs/algorithms/MSABlockDeletion.png)
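
The description above speaks of deleting contiguous blocks of sequences, so a minimal NumPy sketch of that idea may help fix intuition. It is an illustration under assumptions, not the reference implementation: the function name `delete_msa_blocks`, the parameters `num_blocks` and `fraction_per_block`, their default values, and the choice to draw block starts from rows 1 onward are all hypothetical choices made for this sketch.

```python
import numpy as np

def delete_msa_blocks(msa, num_blocks=5, fraction_per_block=0.3, rng=None):
    """Delete contiguous row blocks from an MSA, always keeping row 0 (the query).

    num_blocks and fraction_per_block are illustrative defaults (assumptions).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_seq = msa.shape[0]
    block_size = int(np.floor(n_seq * fraction_per_block))

    # Mark rows for deletion; block starts exclude row 0 so the query survives.
    delete = np.zeros(n_seq, dtype=bool)
    if n_seq > 1 and block_size > 0:
        starts = rng.integers(1, n_seq, size=num_blocks)
        for start in starts:
            delete[start:start + block_size] = True  # slice clips at n_seq

    keep_indices = np.flatnonzero(~delete)  # ascending, so row order is preserved
    return msa[keep_indices], keep_indices

rng = np.random.default_rng(0)
msa = rng.integers(0, 21, size=(100, 16))
kept, keep_indices = delete_msa_blocks(msa, rng=rng)
print(f"Kept {kept.shape[0]} of {msa.shape[0]} sequences")
```

Because blocks may overlap, the number of deleted rows varies from call to call, which is the source of the variation in MSA depth.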

## Source Code Location
- **File**: `AF2-source-code/model/tf/data_transforms.py`
- **Functions**: `sample_msa`, `make_msa_mask`
- **Lines**: 214-250

## Overview

The MSA Block Deletion algorithm serves several purposes:

1. **Data Augmentation**: Varies training data by randomly removing MSA sequences
2. **Robustness**: Forces the model to work with varying MSA depths
3. **Generalization**: Prevents overfitting to specific MSA patterns
4. **Efficiency**: Reduces computation by working with smaller MSAs during training

### Algorithm Steps

1. Always keep the first sequence (the query)
2. Keep each remaining sequence independently with probability `1 - deletion_mean`
3. If more than `max_seq` sequences survive, randomly subsample down to `max_seq`
4. Return the subsampled MSA (and its mask) in the original row order

In [None]:
import numpy as np

np.random.seed(42)

## NumPy Implementation

In [None]:
def msa_block_deletion(
    msa,
    msa_mask=None,
    max_seq=512,
    deletion_mean=0.1,
    keep_query=True
):
    """
    Simplified MSA subsampling in the spirit of MSA Block Deletion (Algorithm 1).
    
    Each non-query sequence is dropped independently with probability
    `deletion_mean` (not in contiguous blocks), and the result is then
    capped at `max_seq` rows.
    
    Args:
        msa: MSA array [N_seq, N_res] amino acid indices
        msa_mask: Optional mask [N_seq, N_res] for valid positions
        max_seq: Maximum number of sequences to keep (including the query)
        deletion_mean: Probability of dropping each non-query sequence
        keep_query: Whether to always keep the first sequence
    
    Returns:
        sampled_msa: Subsampled MSA [N_sampled, N_res]
        sampled_mask: Subsampled mask [N_sampled, N_res] (if msa_mask provided)
    """
    N_seq, N_res = msa.shape
    
    if msa_mask is None:
        msa_mask = np.ones((N_seq, N_res), dtype=np.float32)
    
    print("MSA Block Deletion")
    print("=" * 50)
    print(f"Input MSA: [{N_seq}, {N_res}]")
    print(f"Max sequences: {max_seq}")
    print(f"Deletion mean: {deletion_mean}")
    
    # Step 1: Always keep query (first sequence)
    if keep_query:
        query_msa = msa[:1]
        query_mask = msa_mask[:1]
        remaining_msa = msa[1:]
        remaining_mask = msa_mask[1:]
        max_extra = max(max_seq - 1, 0)  # never negative, even if max_seq < 1
    else:
        query_msa = np.empty((0, N_res), dtype=msa.dtype)
        query_mask = np.empty((0, N_res), dtype=msa_mask.dtype)
        remaining_msa = msa
        remaining_mask = msa_mask
        max_extra = max_seq
    
    N_remaining = remaining_msa.shape[0]
    
    # Step 2: Randomly sample which sequences to keep
    # Using Bernoulli sampling with given mean
    keep_prob = 1.0 - deletion_mean
    keep_mask = np.random.random(N_remaining) < keep_prob
    
    # Step 3: Apply sampling
    indices = np.where(keep_mask)[0]
    
    # Step 4: Limit to max_extra sequences
    if len(indices) > max_extra:
        # Randomly select max_extra from kept sequences
        selected_indices = np.random.choice(indices, size=max_extra, replace=False)
        selected_indices = np.sort(selected_indices)  # Keep original order
    else:
        selected_indices = indices
    
    # Extract selected sequences
    if len(selected_indices) > 0:
        selected_msa = remaining_msa[selected_indices]
        selected_mask = remaining_mask[selected_indices]
    else:
        selected_msa = np.empty((0, N_res), dtype=msa.dtype)
        selected_mask = np.empty((0, N_res), dtype=msa_mask.dtype)
    
    # Step 5: Concatenate query with selected sequences
    sampled_msa = np.concatenate([query_msa, selected_msa], axis=0)
    sampled_mask = np.concatenate([query_mask, selected_mask], axis=0)
    
    print("\nSampling Results:")
    print(f"  Kept {len(selected_indices)} / {N_remaining} extra sequences")
    print(f"  Output MSA: [{sampled_msa.shape[0]}, {sampled_msa.shape[1]}]")
    print(f"  Deletion rate: {1 - len(selected_indices)/max(N_remaining, 1):.2%}")
    
    return sampled_msa, sampled_mask

## Clustered MSA Sampling

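AlphaFold2 also splits the MSA into a set of cluster centers (the main MSA) and the remaining sequences (the extra MSA), which preserves diversity while reducing MSA size. The function below takes precomputed cluster assignments and treats the first sequence of each cluster as its center.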

In [None]:
def sample_msa_by_cluster(
    msa,
    cluster_ids,
    max_clusters=512,
    max_extra_msa=5120
):
    """
    Sample MSA using cluster-based approach.
    
    Keeps cluster centers and samples additional sequences.
    
    Args:
        msa: Full MSA [N_seq, N_res]
        cluster_ids: Cluster assignment for each sequence [N_seq]
        max_clusters: Maximum number of clusters for main MSA
        max_extra_msa: Maximum extra MSA sequences
    
    Returns:
        clustered_msa: Main MSA with cluster centers
        extra_msa: Additional MSA sequences
    """
    N_seq, N_res = msa.shape
    
    # Get unique clusters and their first occurrence (centers)
    unique_clusters, first_idx = np.unique(cluster_ids, return_index=True)
    
    # Sort by first occurrence to maintain order
    order = np.argsort(first_idx)
    center_indices = first_idx[order]
    
    print("Cluster-based MSA Sampling")
    print("=" * 50)
    print(f"Total sequences: {N_seq}")
    print(f"Unique clusters: {len(unique_clusters)}")
    
    # Limit clusters
    if len(center_indices) > max_clusters:
        center_indices = center_indices[:max_clusters]
    
    # Get cluster centers as main MSA
    clustered_msa = msa[center_indices]
    
    # Remaining sequences for extra MSA (sorted, so original row order is kept;
    # a Python set difference would not guarantee any particular order)
    extra_indices = np.setdiff1d(np.arange(N_seq), center_indices)
    
    # Sample extra MSA
    if len(extra_indices) > max_extra_msa:
        selected = np.random.choice(extra_indices, size=max_extra_msa, replace=False)
        extra_msa = msa[np.sort(selected)]
    else:
        extra_msa = msa[extra_indices] if len(extra_indices) > 0 else np.empty((0, N_res), dtype=msa.dtype)
    
    print("\nOutput:")
    print(f"  Clustered MSA: {clustered_msa.shape}")
    print(f"  Extra MSA: {extra_msa.shape}")
    
    return clustered_msa, extra_msa
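
`sample_msa_by_cluster` takes `cluster_ids` as given. One plausible way to produce such assignments, sketched here as an assumption and not as the method used by the source code, is to pick a set of center rows and assign every sequence to the center with the smallest Hamming distance. The helper name `assign_to_nearest_center` and the choice of the first 8 rows as centers are hypothetical.

```python
import numpy as np

def assign_to_nearest_center(msa, center_indices):
    """For every MSA row, return the position (into center_indices) of the
    center with the fewest mismatching residues (Hamming distance)."""
    centers = msa[center_indices]                        # [N_centers, N_res]
    # Pairwise mismatch counts, shape [N_seq, N_centers]
    hamming = (msa[:, None, :] != centers[None, :, :]).sum(axis=-1)
    return hamming.argmin(axis=1)                        # ties go to the earlier center

rng = np.random.default_rng(0)
msa = rng.integers(0, 21, size=(200, 32))
center_indices = np.arange(8)   # hypothetical choice: first 8 rows act as centers
cluster_ids = assign_to_nearest_center(msa, center_indices)
print("Cluster sizes:", np.bincount(cluster_ids, minlength=len(center_indices)))
```

The pairwise comparison allocates an `[N_seq, N_centers, N_res]` boolean array, so for deep MSAs it would need to be computed in chunks.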

## Test Examples

In [None]:
# Test 1: Basic MSA Block Deletion
print("Test 1: Basic MSA Block Deletion")
print("="*60)

# Create synthetic MSA
N_seq, N_res = 1000, 64
msa = np.random.randint(0, 21, size=(N_seq, N_res))

# Test with different deletion rates
for deletion_mean in [0.1, 0.3, 0.5, 0.7]:
    np.random.seed(42)  # Reset for reproducibility
    sampled_msa, _ = msa_block_deletion(
        msa, 
        max_seq=512, 
        deletion_mean=deletion_mean
    )
    print()

In [None]:
# Test 2: Verify query preservation
print("Test 2: Verify Query Preservation")
print("="*60)

np.random.seed(42)
N_seq, N_res = 100, 32
msa = np.random.randint(0, 21, size=(N_seq, N_res))

# Mark query with special pattern for identification
msa[0, :] = np.arange(N_res) % 21

sampled_msa, _ = msa_block_deletion(msa, max_seq=50, deletion_mean=0.5)

# Verify query is preserved
query_match = np.array_equal(sampled_msa[0], msa[0])
print(f"\nQuery preserved: {query_match}")

In [None]:
# Test 3: Cluster-based sampling
print("Test 3: Cluster-Based Sampling")
print("="*60)

np.random.seed(42)
N_seq, N_res = 10000, 64
msa = np.random.randint(0, 21, size=(N_seq, N_res))

# Create synthetic cluster assignments
n_clusters = 800
cluster_ids = np.random.randint(0, n_clusters, size=N_seq)

clustered_msa, extra_msa = sample_msa_by_cluster(
    msa, 
    cluster_ids, 
    max_clusters=512, 
    max_extra_msa=5120
)

In [None]:
# Test 4: Statistical verification
print("\nTest 4: Statistical Verification")
print("="*60)

np.random.seed(42)
N_seq, N_res = 1000, 32
msa = np.random.randint(0, 21, size=(N_seq, N_res))

# Run multiple trials
n_trials = 100
deletion_mean = 0.3
kept_counts = []

for _ in range(n_trials):
    sampled_msa, _ = msa_block_deletion(
        msa, 
        max_seq=1000,  # No limit 
        deletion_mean=deletion_mean
    )
    kept_counts.append(sampled_msa.shape[0])

# Remove query from count
extra_kept = np.array(kept_counts) - 1
expected_kept = (N_seq - 1) * (1 - deletion_mean)

print(f"\nStatistical Summary ({n_trials} trials):")
print(f"  Expected extra sequences: {expected_kept:.1f}")
print(f"  Observed mean: {extra_kept.mean():.1f}")
print(f"  Observed std: {extra_kept.std():.1f}")
print(f"  Error: {abs(extra_kept.mean() - expected_kept):.2f} ({abs(extra_kept.mean() - expected_kept)/expected_kept*100:.1f}%)")
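
When the `max_seq` cap never binds, as in Test 4, the number of extra sequences kept is a sum of independent Bernoulli draws and so follows a Binomial(n, p) distribution with n = `N_seq - 1` and p = `1 - deletion_mean`. The sketch below computes the theoretical mean n·p and standard deviation sqrt(n·p·(1 − p)) for the Test 4 settings, so the observed values can be compared against them. It uses its own random generator and does not call `msa_block_deletion`.

```python
import numpy as np

n_extra = 1000 - 1        # non-query sequences in Test 4
keep_prob = 1.0 - 0.3     # 1 - deletion_mean

expected_mean = n_extra * keep_prob
expected_std = np.sqrt(n_extra * keep_prob * (1.0 - keep_prob))

# Independent Monte Carlo check (separate generator, global np.random state untouched)
rng = np.random.default_rng(0)
simulated = rng.binomial(n_extra, keep_prob, size=100)

print(f"Theoretical mean: {expected_mean:.1f}, std: {expected_std:.2f}")
print(f"Simulated mean:   {simulated.mean():.1f}, std: {simulated.std():.2f}")
```

The theoretical values are about 699.3 and 14.5, so an observed standard deviation far from 14.5 in Test 4 would suggest the sampling is not behaving as independent Bernoulli draws.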

## Verification: Key Properties

In [None]:
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_seq, N_res = 500, 64
msa = np.random.randint(0, 21, size=(N_seq, N_res))

# Property 1: Output never exceeds max_seq
max_seq = 128
all_valid = True
for _ in range(50):
    sampled, _ = msa_block_deletion(msa, max_seq=max_seq, deletion_mean=0.1)
    if sampled.shape[0] > max_seq:
        all_valid = False
        break
print(f"Property 1 - Output ≤ max_seq: {all_valid}")

# Property 2: Query always first
msa[0, :] = 99  # Mark query
sampled, _ = msa_block_deletion(msa, max_seq=128, deletion_mean=0.5)
print(f"Property 2 - Query first: {np.all(sampled[0] == 99)}")

# Property 3: Output is subset of input
sampled, _ = msa_block_deletion(msa, max_seq=128, deletion_mean=0.5)
is_subset = all(
    any(np.array_equal(sampled[i], msa[j]) for j in range(N_seq))
    for i in range(sampled.shape[0])
)
print(f"Property 3 - Output ⊆ Input: {is_subset}")

# Property 4: Mask consistency
msa_mask = np.random.rand(N_seq, N_res) > 0.1
msa_mask = msa_mask.astype(np.float32)
sampled_msa, sampled_mask = msa_block_deletion(msa, msa_mask, max_seq=128)
print(f"Property 4 - Mask shape matches: {sampled_msa.shape == sampled_mask.shape}")

## Source Code Reference

```python
# From AF2-source-code/model/tf/data_transforms.py
import tensorflow as tf

def sample_msa(batch, max_seq, keep_extra, seed=None):
  """Sample MSA randomly, keeping max_seq sequences.
  
  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"
  
  This implements the MSA subsampling described in the supplementary.
  The first sequence (query) is always kept.
  """
  num_seq = batch['msa'].shape[0]
  
  # Keep query sequence
  shuffled = tf.random.shuffle(tf.range(1, num_seq), seed=seed)
  index_order = tf.concat([[0], shuffled], axis=0)
  
  # Keep at most max_seq
  num_sel = tf.minimum(max_seq, num_seq)
  sel_indices = index_order[:num_sel]
  
  # Apply to all MSA-related features
  for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
    if k in batch:
      batch[k] = tf.gather(batch[k], sel_indices)
  
  return batch
```
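
Note that the excerpt above shuffles the non-query rows and truncates to `max_seq`, which differs from the per-sequence Bernoulli sampling used in the NumPy implementation earlier in this notebook. A minimal NumPy mirror of the shuffle-and-truncate approach is sketched below. The name `sample_msa_np` is hypothetical, and returning the unselected rows as a second output is an assumption of this sketch (the excerpt accepts `keep_extra` but does not use it). At least one row (the query) is assumed to be present.

```python
import numpy as np

def sample_msa_np(msa, max_seq, rng=None):
    """Keep the query plus a random subset of rows.

    Returns (selected_rows, unselected_rows).
    """
    rng = np.random.default_rng() if rng is None else rng
    num_seq = msa.shape[0]

    # Query (row 0) goes first; the remaining rows follow in shuffled order.
    shuffled = rng.permutation(np.arange(1, num_seq))
    index_order = np.concatenate([[0], shuffled])

    num_sel = min(max_seq, num_seq)
    sel, not_sel = index_order[:num_sel], index_order[num_sel:]
    return msa[sel], msa[not_sel]

rng = np.random.default_rng(0)
msa = rng.integers(0, 21, size=(1000, 64))
main_msa, extra_msa = sample_msa_np(msa, max_seq=512, rng=rng)
print(main_msa.shape, extra_msa.shape)  # (512, 64) (488, 64)
```

Unlike the Bernoulli version, this variant always returns exactly `min(max_seq, num_seq)` rows, and the selected rows are not in their original order.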

## Key Insights

1. **Query Preservation**: The first sequence (query/target) is always kept as it represents the protein we're trying to predict.

2. **Random Sampling**: Remaining sequences are randomly sampled, providing data augmentation during training.

3. **Bounded Output**: The output MSA never has more than `max_seq` sequences, which caps memory use regardless of the input MSA depth.

4. **Cluster-Based Sampling**: AlphaFold2 uses clustering to maintain MSA diversity while reducing size, keeping cluster centers for the main MSA and additional sequences for the "extra MSA".

5. **Two MSA Tracks**: The main MSA (up to 512 sequences in the examples here) goes through the full Evoformer, while the extra MSA (up to 5,120 sequences) is processed with cheaper global attention.