# Algorithm 26: Rename Symmetric Ground Truth Atoms

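Some amino acids have chemically equivalent atoms whose names can be swapped without changing the chemistry (e.g., the two carboxyl oxygens of aspartate). Their labeling in experimental structures is therefore arbitrary. For each such residue, this algorithm picks the ground-truth naming that best matches the current prediction. The renamed ground truth is then used for all subsequent losses.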

## Algorithm Pseudocode

![renameSymmetricGroundTruthAtoms](../imgs/algorithms/renameSymmetricGroundTruthAtoms.png)

## Source Code Location
- **File**: `AF2-source-code/model/folding.py`
- **Function**: `compute_renamed_ground_truth`
- **Lines**: 428-480

## Overview

### Symmetric Amino Acids

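| Amino Acid | Symmetric Atoms (atom14 index) | Description |
|------------|--------------------------------|-------------|
| Asp (D) | OD1 (6) ↔ OD2 (7) | Carboxyl oxygens |
| Glu (E) | OE1 (7) ↔ OE2 (8) | Carboxyl oxygens |
| Phe (F) | CD1 (6) ↔ CD2 (7), CE1 (8) ↔ CE2 (9) | Ring carbons |
| Tyr (Y) | CD1 (6) ↔ CD2 (7), CE1 (8) ↔ CE2 (9) | Ring carbons |

Arg NH1/NH2 are also chemically equivalent. However, AF2's renaming table (`residue_constants.residue_atom_renaming_swaps`) covers only Asp, Glu, Phe and Tyr, so Arg is not renamed.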

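Each residue's atom14 name order determines the atom14 indices of these atoms. The index pairs can be derived from that order rather than typed by hand.

This is a minimal sketch. It assumes the atom order of AF2's `residue_constants.restype_name_to_atom14_names`, copied here for the four renamed residues with padding slots omitted. It also assumes the name-level swaps of `residue_atom_renaming_swaps`. `ATOM14_NAMES`, `RENAMING_SWAPS` and `swap_pairs_atom14` are illustrative names.

```python
# atom14 name order for the renamed residues (padding slots omitted),
# assumed to match AF2's residue_constants.restype_name_to_atom14_names.
ATOM14_NAMES = {
    'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
    'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
    'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'],
}

# Name-level swaps, mirroring AF2's residue_atom_renaming_swaps.
RENAMING_SWAPS = {
    'ASP': {'OD1': 'OD2'},
    'GLU': {'OE1': 'OE2'},
    'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
    'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
}


def swap_pairs_atom14(resname):
    """Translate a residue's name-level swaps into atom14 index pairs."""
    names = ATOM14_NAMES[resname]
    return [(names.index(a), names.index(b))
            for a, b in RENAMING_SWAPS[resname].items()]


for resname in RENAMING_SWAPS:
    print(resname, swap_pairs_atom14(resname))
```

This prints `[(6, 7)]` for ASP, `[(7, 8)]` for GLU, and `[(6, 7), (8, 9)]` for both PHE and TYR.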
### Algorithm

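For each residue with symmetric atoms:
1. Compute an error between the prediction and the ground truth using the original atom naming
2. Compute the same error with the symmetric atoms swapped
3. Keep the naming that gives the lower error

In AF2 this error is distance-based:
- For each ambiguous atom, its distances to every non-ambiguous atom in the structure are computed, once in the prediction and once in the ground truth.
- The absolute differences between the two are summed over the residue.

Because only intra-structure distances are used, the decision does not depend on how the two structures are superimposed.

The NumPy implementation below uses a simpler proxy: the squared distance between predicted and ground-truth positions of the ambiguous atoms. This proxy is only meaningful when both structures are expressed in the same frame.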
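AF2 decides the naming from inter-atomic distances rather than raw positions. The NumPy sketch below illustrates that idea. It is modelled on `find_optimal_renaming` in AF2's `model/all_atom.py`, not copied from it. The function name, the argument names, and the use of a plain absolute difference (AF2 uses `sqrt(1e-10 + squared difference)`) are assumptions. `alt_gt` is the ground truth with every residue's swaps already applied.

```python
import numpy as np


def find_alt_naming_is_better(pred, gt, alt_gt, is_ambiguous, gt_exists):
    """Per residue, True where the swapped ground-truth naming fits `pred` better.

    pred, gt, alt_gt: (N, 14, 3) atom14 positions; alt_gt has the swaps applied.
    is_ambiguous:     (N, 14) 1.0 for atoms touched by a renaming swap.
    gt_exists:        (N, 14) 1.0 for atoms present in the ground truth.
    """
    def pairwise_dists(x):
        # (N, 14, 3) -> (N, N, 14, 14): distance from atom a of residue i
        # to atom b of residue j, stored at [i, j, a, b].
        diff = x[:, None, :, None, :] - x[None, :, None, :, :]
        return np.sqrt(np.sum(diff ** 2, axis=-1))

    d_pred = pairwise_dists(pred)
    err = np.abs(d_pred - pairwise_dists(gt))
    alt_err = np.abs(d_pred - pairwise_dists(alt_gt))

    # Rows: existing ambiguous atoms; columns: existing non-ambiguous atoms.
    rows = (gt_exists * is_ambiguous)[:, None, :, None]
    cols = (gt_exists * (1.0 - is_ambiguous))[None, :, None, :]
    mask = rows * cols

    per_res = np.sum(mask * err, axis=(1, 2, 3))
    alt_per_res = np.sum(mask * alt_err, axis=(1, 2, 3))
    return alt_per_res < per_res


# Usage: three ASP-like residues (8 real atoms, OD1/OD2 at atom14 slots 6/7).
rng = np.random.default_rng(0)
gt = rng.normal(scale=3.0, size=(3, 14, 3))
exists = np.zeros((3, 14))
exists[:, :8] = 1.0
ambiguous = np.zeros((3, 14))
ambiguous[:, 6:8] = 1.0
alt_gt = gt.copy()
alt_gt[:, [6, 7]] = gt[:, [7, 6]]

# Prediction: residue 0 uses the other naming, and the whole structure is
# rotated and translated, which leaves all inter-atomic distances unchanged.
pred = gt.copy()
pred[0, [6, 7]] = gt[0, [7, 6]]
rot, _ = np.linalg.qr(rng.normal(size=(3, 3)))
pred = pred @ rot.T + np.array([5.0, -2.0, 1.0])

result = find_alt_naming_is_better(pred, gt, alt_gt, ambiguous, exists)
print(result)  # expected: True for residue 0, False for residues 1 and 2
```

The rigid motion applied to `pred` does not change the outcome. This frame independence is what the position-based proxy in the notebook lacks.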

```python
import numpy as np

np.random.seed(42)
```

## NumPy Implementation

```python
# Amino acid indices
AA_NAMES = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY',
            'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER',
            'THR', 'TRP', 'TYR', 'VAL']

# Symmetric atom swap pairs (atom indices in atom14 format)
# Format: {aa_index: [(atom1_idx, atom2_idx), ...]}
# atom14 order starts N(0), CA(1), C(2), O(3), CB(4), CG(5), then the
# residue-specific side-chain atoms. ARG is not renamed in AF2.
SYMMETRIC_ATOMS = {
    3: [(6, 7)],              # ASP: OD1(6) ↔ OD2(7)
    6: [(7, 8)],              # GLU: OE1(7) ↔ OE2(8)
    13: [(6, 7), (8, 9)],     # PHE: CD1(6) ↔ CD2(7), CE1(8) ↔ CE2(9)
    18: [(6, 7), (8, 9)],     # TYR: CD1(6) ↔ CD2(7), CE1(8) ↔ CE2(9)
}
```
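The swaps depend only on residue type. They can therefore be applied to a whole batch at once, using one index permutation per residue type. AF2's input pipeline similarly provides a precomputed `atom14_alt_gt_positions`.

The sketch below makes three assumptions:
- AF2's atom14 atom order (OD1/OD2 at slots 6/7 for Asp, and so on).
- 21 residue types (20 plus unknown).
- Hypothetical names: `SWAPS`, `build_renaming_permutation` and `alt_ground_truth`.

```python
import numpy as np

# atom14 swap pairs per residue-type index (ASP=3, GLU=6, PHE=13, TYR=18).
SWAPS = {3: [(6, 7)], 6: [(7, 8)], 13: [(6, 7), (8, 9)], 18: [(6, 7), (8, 9)]}


def build_renaming_permutation(swaps, n_restypes=21, n_atoms=14):
    """perm[r, k] = source slot whose ground-truth atom lands in slot k
    when the alternative naming is used for residue type r."""
    perm = np.tile(np.arange(n_atoms), (n_restypes, 1))
    for restype, pairs in swaps.items():
        for a1, a2 in pairs:
            perm[restype, a1], perm[restype, a2] = a2, a1
    return perm


def alt_ground_truth(gt_pos, gt_exists, aatype, perm):
    """Apply each residue's swaps to positions (N, 14, 3) and existence mask (N, 14)."""
    idx = perm[aatype]                          # (N, 14) source slot per target slot
    rows = np.arange(gt_pos.shape[0])[:, None]  # (N, 1), broadcasts against idx
    return gt_pos[rows, idx], gt_exists[rows, idx]


perm = build_renaming_permutation(SWAPS)
gt = np.random.default_rng(0).normal(size=(2, 14, 3))
exists = np.ones((2, 14))
alt, alt_exists = alt_ground_truth(gt, exists, np.array([0, 3]), perm)  # ALA, ASP
```

This version permutes the existence mask together with the positions. That matters when one atom of a pair is missing from the ground truth. The loop-based function in this notebook instead skips such pairs.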

```python
def compute_squared_distance(pos1, pos2):
    """Compute squared distance between two positions."""
    diff = pos1 - pos2
    return np.sum(diff ** 2, axis=-1)


def rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype, atom_mask=None):
    """
    Rename Symmetric Ground Truth Atoms - Algorithm 26 (simplified).

    For each residue with symmetric atoms, keeps whichever ground-truth naming
    (original or swapped) has the lower squared position error against the
    prediction. This is a frame-dependent proxy: AF2 compares inter-atomic
    distances instead, so here pred_pos and gt_pos must share one frame.

    Args:
        pred_pos: Predicted atom positions [N_res, N_atoms, 3]
        gt_pos: Ground truth atom positions [N_res, N_atoms, 3]
        aatype: Amino acid types [N_res]
        atom_mask: Valid atom mask [N_res, N_atoms] (optional)

    Returns:
        gt_renamed: Renamed ground truth positions [N_res, N_atoms, 3]
        swap_mask: Boolean mask [N_res] indicating which residues were swapped
    """
    N_res, N_atoms, _ = gt_pos.shape

    print("Rename Symmetric Ground Truth Atoms")
    print("=" * 50)
    print(f"Residues: {N_res}")
    print(f"Atoms per residue: {N_atoms}")

    if atom_mask is None:
        atom_mask = np.ones((N_res, N_atoms), dtype=np.float32)

    gt_renamed = gt_pos.copy()
    swap_mask = np.zeros(N_res, dtype=bool)

    n_symmetric = 0
    n_swapped = 0

    for res_idx in range(N_res):
        aa = int(aatype[res_idx])

        # Check if this amino acid has symmetric atoms
        if aa not in SYMMETRIC_ATOMS:
            continue

        n_symmetric += 1
        swap_pairs = SYMMETRIC_ATOMS[aa]

        # Compute loss for original naming (pairs with a missing atom are skipped)
        loss_original = 0.0
        for a1, a2 in swap_pairs:
            if atom_mask[res_idx, a1] > 0 and atom_mask[res_idx, a2] > 0:
                loss_original += compute_squared_distance(
                    pred_pos[res_idx, a1], gt_pos[res_idx, a1])
                loss_original += compute_squared_distance(
                    pred_pos[res_idx, a2], gt_pos[res_idx, a2])

        # Compute loss for swapped naming
        loss_swapped = 0.0
        for a1, a2 in swap_pairs:
            if atom_mask[res_idx, a1] > 0 and atom_mask[res_idx, a2] > 0:
                loss_swapped += compute_squared_distance(
                    pred_pos[res_idx, a1], gt_pos[res_idx, a2])  # Swapped!
                loss_swapped += compute_squared_distance(
                    pred_pos[res_idx, a2], gt_pos[res_idx, a1])  # Swapped!

        # Swap if it reduces loss; all pairs of a residue are swapped together
        if loss_swapped < loss_original:
            for a1, a2 in swap_pairs:
                gt_renamed[res_idx, a1] = gt_pos[res_idx, a2].copy()
                gt_renamed[res_idx, a2] = gt_pos[res_idx, a1].copy()
            swap_mask[res_idx] = True
            n_swapped += 1

    print("\nResults:")
    print(f"  Symmetric residues: {n_symmetric}")
    print(f"  Swapped: {n_swapped}")
    print(f"  Kept original: {n_symmetric - n_swapped}")

    return gt_renamed, swap_mask
```
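The per-residue loop above can also be written without Python loops. This requires ground-truth positions with the swaps already applied (`alt_gt_pos`).

This is a minimal sketch:
- `rename_vectorized` is a hypothetical name.
- The criterion is the same simplified squared position error, except that atoms are masked individually rather than per pair.
- The final blend mirrors the `(1 - w) * gt + w * alt_gt` form used in AF2's `compute_renamed_ground_truth`.

```python
import numpy as np


def rename_vectorized(pred_pos, gt_pos, alt_gt_pos, gt_exists, alt_gt_exists):
    """Loop-free version of the simplified, position-based renaming.

    pred_pos, gt_pos, alt_gt_pos: (N, 14, 3); gt_exists, alt_gt_exists: (N, 14).
    Atoms untouched by a swap contribute equally to both sums, so the
    comparison is decided by the swapped atoms only.
    """
    err = np.sum((pred_pos - gt_pos) ** 2, axis=-1) * gt_exists
    alt_err = np.sum((pred_pos - alt_gt_pos) ** 2, axis=-1) * alt_gt_exists
    alt_is_better = (alt_err.sum(axis=-1) < err.sum(axis=-1)).astype(gt_pos.dtype)

    # Blend: w = 1 selects the alternative naming for that residue.
    w = alt_is_better[:, None]  # (N, 1)
    renamed_pos = (1.0 - w[..., None]) * gt_pos + w[..., None] * alt_gt_pos
    renamed_exists = (1.0 - w) * gt_exists + w * alt_gt_exists
    return alt_is_better, renamed_pos, renamed_exists


# Usage: two ASP-like residues; residue 0's ground truth labels OD1/OD2
# (slots 6/7) the other way round.
rng = np.random.default_rng(1)
pred = rng.normal(size=(2, 14, 3))
gt = pred.copy()
gt[0, [6, 7]] = pred[0, [7, 6]]
alt_gt = gt.copy()
alt_gt[:, [6, 7]] = gt[:, [7, 6]]
exists = np.ones((2, 14))

better, renamed, renamed_exists = rename_vectorized(pred, gt, alt_gt, exists, exists)
print(better)  # [1. 0.]
```

Using a float weight instead of a boolean `where` keeps the whole computation as array arithmetic. This is how the AF2 snippet in the reference below combines the two namings.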

## Test Examples

```python
# Test 1: Basic functionality
print("Test 1: Basic Functionality")
print("="*60)

N_res, N_atoms = 10, 14

# Predicted positions
pred_pos = np.random.randn(N_res, N_atoms, 3)

# Ground truth - copy of predictions with some noise
gt_pos = pred_pos.copy() + np.random.randn(N_res, N_atoms, 3) * 0.5

# Amino acid types (include some symmetric ones)
aatype = np.array([0, 3, 1, 6, 2, 13, 3, 18, 0, 1])  # ALA, ASP, ARG, GLU, ASN, PHE, ASP, TYR, ALA, ARG

gt_renamed, swap_mask = rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype)

print(f"\nSwap mask: {swap_mask}")
```

```python
# Test 2: Force swapping
print("\nTest 2: Force Swapping (GT atoms are swapped relative to prediction)")
print("="*60)

N_res = 5
N_atoms = 14

# Predicted positions
pred_pos = np.random.randn(N_res, N_atoms, 3)

# Ground truth - same as predictions
gt_pos = pred_pos.copy()

# All residues are ASP
aatype = np.array([3, 3, 3, 3, 3])

# Swap OD1 and OD2 in the ground truth (indices taken from SYMMETRIC_ATOMS);
# the algorithm should swap them back
a1, a2 = SYMMETRIC_ATOMS[3][0]
for i in range(N_res):
    gt_pos[i, a1], gt_pos[i, a2] = gt_pos[i, a2].copy(), gt_pos[i, a1].copy()

# Compute loss before renaming
loss_before = np.sum((pred_pos - gt_pos) ** 2)

gt_renamed, swap_mask = rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype)

# Compute loss after renaming (should be zero)
loss_after = np.sum((pred_pos - gt_renamed) ** 2)

print(f"\nLoss before renaming: {loss_before:.4f}")
print(f"Loss after renaming: {loss_after:.4f}")
print(f"All swapped: {swap_mask.all()}")
```

```python
# Test 3: Mixed residues
print("\nTest 3: Mixed Residues (some symmetric, some not)")
print("="*60)

N_res = 20
N_atoms = 14

pred_pos = np.random.randn(N_res, N_atoms, 3)
gt_pos = pred_pos.copy() + np.random.randn(N_res, N_atoms, 3) * 0.1

# Random amino acid types
aatype = np.random.randint(0, 20, size=N_res)

# Count symmetric residues
n_symmetric_total = sum(1 for aa in aatype if aa in SYMMETRIC_ATOMS)
print(f"Amino acids: {[AA_NAMES[aa] for aa in aatype]}")
print(f"Symmetric residues: {n_symmetric_total}")

gt_renamed, swap_mask = rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype)
```

```python
# Test 4: PHE ring symmetry
print("\nTest 4: Phenylalanine Ring Symmetry")
print("="*60)

N_res = 4
N_atoms = 14

pred_pos = np.random.randn(N_res, N_atoms, 3)
gt_pos = pred_pos.copy()

# All PHE residues
aatype = np.array([13, 13, 13, 13])

# Swap ring atoms in the ground truth for residues 0 and 2
# (CD1 ↔ CD2 and CE1 ↔ CE2, indices taken from SYMMETRIC_ATOMS)
for i in [0, 2]:
    for a1, a2 in SYMMETRIC_ATOMS[13]:
        gt_pos[i, a1], gt_pos[i, a2] = gt_pos[i, a2].copy(), gt_pos[i, a1].copy()

gt_renamed, swap_mask = rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype)

print(f"\nExpected swaps at residues 0, 2: {swap_mask}")
```

## Verification: Key Properties

```python
print("Verification: Key Properties")
print("="*60)

np.random.seed(42)
N_res, N_atoms = 50, 14

pred_pos = np.random.randn(N_res, N_atoms, 3)
gt_pos = pred_pos.copy() + np.random.randn(N_res, N_atoms, 3) * 0.3
aatype = np.random.randint(0, 20, size=N_res)

gt_renamed, swap_mask = rename_symmetric_ground_truth_atoms(pred_pos, gt_pos, aatype)

# Property 1: Shape preserved
shape_preserved = gt_renamed.shape == gt_pos.shape
print(f"Property 1 - Shape preserved: {shape_preserved}")

# Property 2: Loss is reduced or unchanged
loss_original = np.sum((pred_pos - gt_pos) ** 2)
loss_renamed = np.sum((pred_pos - gt_renamed) ** 2)
loss_improved = loss_renamed <= loss_original + 1e-6  # Small tolerance
print(f"Property 2 - Loss improved: {loss_improved} ({loss_original:.4f} -> {loss_renamed:.4f})")

# Property 3: Non-symmetric residues unchanged
non_symmetric_mask = np.array([aa not in SYMMETRIC_ATOMS for aa in aatype])
non_sym_unchanged = np.allclose(gt_renamed[non_symmetric_mask], gt_pos[non_symmetric_mask])
print(f"Property 3 - Non-symmetric unchanged: {non_sym_unchanged}")

# Property 4: Only swap pairs are modified
for res_idx in range(N_res):
    if not swap_mask[res_idx]:
        continue
    aa = int(aatype[res_idx])
    swap_indices = set()
    for a1, a2 in SYMMETRIC_ATOMS[aa]:
        swap_indices.add(a1)
        swap_indices.add(a2)
    for atom_idx in range(N_atoms):
        if atom_idx not in swap_indices:
            assert np.allclose(gt_renamed[res_idx, atom_idx], gt_pos[res_idx, atom_idx])

print(f"Property 4 - Only swap pairs modified: True")
```

## Source Code Reference

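```python
# From AF2-source-code/model/folding.py (abridged; comments added)

def compute_renamed_ground_truth(batch, atom14_pred_positions):
  """Find optimal renaming of ground truth based on the predicted positions.

  Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"

  This renamed ground truth is then used for all losses,
  such that each loss moves the atoms in the same direction.
  """
  # find_optimal_renaming (model/all_atom.py) compares, per residue, the
  # distances from ambiguous atoms to all non-ambiguous atoms in the
  # prediction against those in the original and the alternative ground truth.
  alt_naming_is_better = all_atom.find_optimal_renaming(
      atom14_gt_positions=batch['atom14_gt_positions'],
      atom14_alt_gt_positions=batch['atom14_alt_gt_positions'],
      atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
      atom14_gt_exists=batch['atom14_gt_exists'],
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'])

  # alt_naming_is_better is 1.0 or 0.0 per residue: blend the two namings.
  renamed_atom14_gt_positions = (
      (1. - alt_naming_is_better[:, None, None])
      * batch['atom14_gt_positions']
      + alt_naming_is_better[:, None, None]
      * batch['atom14_alt_gt_positions'])

  renamed_atom14_gt_mask = (
      (1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists']
      + alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists'])

  return {
      'alt_naming_is_better': alt_naming_is_better,
      'renamed_atom14_gt_positions': renamed_atom14_gt_positions,
      'renamed_atom14_gt_exists': renamed_atom14_gt_mask,
  }
```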

## Key Insights

1. **Chemical Equivalence**: The paired atoms are related by a 180° rotation of the side chain's terminal group. Which one is labeled OD1 rather than OD2 (and so on) in an experimental structure is therefore arbitrary.

2. **Training Supervision**: Without renaming, the model might be penalized for correct predictions with "wrong" atom naming.

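3. **Distance Based**: The AF2 implementation compares inter-atomic distances (each ambiguous atom against all non-ambiguous atoms) rather than raw positions or chi angles. The decision is therefore independent of how the prediction and ground truth are superimposed. Chi angles are handled separately: swapping the atoms shifts the affected χ angle by π rather than changing its sign. AF2's torsion-angle loss accounts for this by also comparing against the π-shifted true angle.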

4. **Per-Residue Decision**: Each symmetric residue is handled independently.

5. **Consistent Swapping**: For residues like PHE/TYR with multiple swap pairs, all pairs are swapped together to maintain ring geometry.

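6. **Loss Reduction**: By construction, the chosen naming never increases the criterion used to make the decision (here, the squared position error of the ambiguous atoms; Property 2 above checks this). In AF2 that criterion is a distance-based proxy, so individual downstream losses are not strictly guaranteed to decrease.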
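Swapping OD1 and OD2 shifts χ2 by π rather than flipping its sign, and this can be checked numerically.

The sketch below uses toy coordinates for a planar Asp carboxylate. They are illustrative only, not taken from a real structure, and do not use ideal bond lengths. The dihedral is computed with a standard formula. `pi_periodic_sq_error` is a hypothetical helper. It shows how a loss on (sin χ, cos χ) can ignore the ambiguity by also comparing against the negated target, which is the idea behind AF2's alternative torsion angles.

```python
import numpy as np


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle p0-p1-p2-p3 in radians, in (-pi, pi]."""
    b0 = p0 - p1
    b1 = (p2 - p1) / np.linalg.norm(p2 - p1)
    b2 = p3 - p2
    # Project the outer bonds onto the plane perpendicular to the central bond.
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1
    return np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w))


def pi_periodic_sq_error(pred_chi, true_chi):
    """Squared error between (sin, cos) vectors, ignoring a pi shift of true_chi."""
    a_pred = np.array([np.sin(pred_chi), np.cos(pred_chi)])
    a_true = np.array([np.sin(true_chi), np.cos(true_chi)])
    # (sin, cos) of true_chi + pi is simply -a_true.
    return min(np.sum((a_pred - a_true) ** 2), np.sum((a_pred + a_true) ** 2))


# Toy Asp side chain (illustrative coordinates):
# CB-CG along x, planar carboxylate with both oxygens at 120 degrees to CB.
CA = np.array([-2.0, 0.5, 1.0])
CB = np.array([-1.5, 0.0, 0.0])
CG = np.array([0.0, 0.0, 0.0])
OD1 = np.array([0.625, 1.0825, 0.0])
OD2 = np.array([0.625, -1.0825, 0.0])

chi2 = dihedral(CA, CB, CG, OD1)          # chi2 with the original naming
chi2_swapped = dihedral(CA, CB, CG, OD2)  # chi2 after swapping OD1/OD2
print(np.degrees(chi2_swapped - chi2))    # 180 degrees apart
print(pi_periodic_sq_error(chi2, chi2_swapped))  # effectively zero
```

A plain (sin, cos) error between `chi2` and `chi2_swapped` would be maximal, about 4. The π-periodic version treats the two namings as equivalent.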