# Chai-1 Embedding Analysis

Visualize and analyze the trunk embeddings extracted from Chai-1.

**Prerequisites:** Run `scripts/02_extract_embeddings.py` on a GPU machine first.

In [None]:
import sys
sys.path.insert(0, '..')

import json
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from collections import defaultdict

from src.embeddings.chai_extractor import ChaiEmbeddings

## Load Embeddings

In [None]:
EMBEDDING_DIR = Path('../data/embeddings/chai_trunk')

# Discover embedding files. Path.glob yields entries in arbitrary filesystem
# order; sorting makes the protein order (and therefore every downstream
# statistic and plot) reproducible across machines and runs.
embedding_files = sorted(EMBEDDING_DIR.glob('*.pt'))
print(f"Found {len(embedding_files)} embedding files")

if not embedding_files:
    print("\nNo embeddings found!")
    print("Run scripts/02_extract_embeddings.py on a GPU machine first.")
else:
    print(f"\nSample files: {[f.name for f in embedding_files[:5]]}")

In [None]:
# Load all embeddings, keyed by protein name (insertion order follows embedding_files).
embeddings = {}
for path in embedding_files:
    emb = ChaiEmbeddings.load(path)
    # Two files resolving to the same protein_name would otherwise overwrite
    # each other silently and skew every downstream statistic.
    if emb.protein_name in embeddings:
        print(f"Warning: duplicate protein_name {emb.protein_name!r} in {path.name}; keeping the later file")
    embeddings[emb.protein_name] = emb

print(f"Loaded {len(embeddings)} protein embeddings")

## Embedding Dimensions

In [None]:
# Inspect tensor shapes on the first loaded protein. Channel widths are read
# from this one protein only; assumes all proteins share them.
sample_emb = next(iter(embeddings.values()))
single_shape = sample_emb.single.shape
pair_shape = sample_emb.pair.shape

print(f"Sample protein: {sample_emb.protein_name}")
print(f"Sequence length: {len(sample_emb.sequence)}")
print(f"Single shape: {single_shape}")
print(f"Pair shape: {pair_shape}")
print()

# Channel width is the last axis: single is indexed [residue, channel],
# pair is indexed [i, j, channel] (as used by the plots below).
D_single = single_shape[1]
D_pair = pair_shape[2]
print(f"D_single = {D_single}")
print(f"D_pair = {D_pair}")

In [None]:
# Per-protein summary statistics, gathered in `embeddings` iteration order.
per_protein = list(embeddings.values())
lengths = [len(e.sequence) for e in per_protein]
single_means = [e.single.mean().item() for e in per_protein]
single_stds = [e.single.std().item() for e in per_protein]
pair_means = [e.pair.mean().item() for e in per_protein]
pair_stds = [e.pair.std().item() for e in per_protein]

print(f"Sequence lengths: {min(lengths)}-{max(lengths)} (mean {np.mean(lengths):.1f})")

# Report each statistic as its across-protein mean ± spread.
summary_rows = (
    ("Single embedding mean", single_means),
    ("Single embedding std", single_stds),
    ("Pair embedding mean", pair_means),
    ("Pair embedding std", pair_stds),
)
for label, values in summary_rows:
    print(f"{label}: {np.mean(values):.4f} ± {np.std(values):.4f}")

## Single Embedding Analysis

In [None]:
# Use the first protein as the running example for the per-protein plots.
sample_name = next(iter(embeddings))
sample = embeddings[sample_name]
single_np = sample.single.numpy()

fig, (ax_heat, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

# Left: residues (rows) x embedding channels (columns) heatmap
im = ax_heat.imshow(single_np, aspect='auto', cmap='RdBu_r')
ax_heat.set_xlabel('Embedding dimension')
ax_heat.set_ylabel('Residue position')
ax_heat.set_title(f'Single embedding: {sample_name}\n({len(sample.sequence)} aa)')
plt.colorbar(im, ax=ax_heat)

# Right: histogram of every value in the single embedding
ax_hist.hist(single_np.flatten(), bins=100, edgecolor='black', alpha=0.7)
ax_hist.set_xlabel('Embedding value')
ax_hist.set_ylabel('Count')
ax_hist.set_title('Distribution of single embedding values')

plt.tight_layout()
plt.show()

In [None]:
# Per-residue magnitude: L2 norm of each row of the single embedding.
norms = sample.single.norm(dim=1).numpy()
positions = range(len(norms))

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(positions, norms, alpha=0.7)
ax.set_xlabel('Residue position')
ax.set_ylabel('Embedding L2 norm')
ax.set_title(f'Per-residue embedding magnitude: {sample_name}')

# Label each bar with its one-letter residue code, just above the bar top.
for pos, residue in enumerate(sample.sequence):
    bar_top = norms[pos]
    ax.annotate(residue, (pos, bar_top), ha='center', va='bottom', fontsize=6)

plt.tight_layout()
plt.show()

## Pair Embedding Analysis

In [None]:
# Collapse the pair embedding's channel axis two ways to get L x L summary maps:
# the channel mean and the channel L2 norm.
pair_mean = sample.pair.mean(dim=-1).numpy()  # [L, L]
pair_norm = torch.norm(sample.pair, dim=-1).numpy()  # [L, L]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# (matrix, colormap, panel title) for each subplot, left to right
panels = (
    (pair_mean, 'RdBu_r', 'Pair embedding mean'),
    (pair_norm, 'viridis', 'Pair embedding L2 norm'),
)
for ax, (matrix, cmap, title) in zip(axes, panels):
    image = ax.imshow(matrix, cmap=cmap)
    ax.set_xlabel('Residue j')
    ax.set_ylabel('Residue i')
    ax.set_title(title)
    plt.colorbar(image, ax=ax)

plt.suptitle(f'Pair embeddings: {sample_name} ({len(sample.sequence)} aa)')
plt.tight_layout()
plt.show()

In [None]:
# Show the first (up to) 16 individual pair-embedding channels in a 4x4 grid.
n_channels = min(16, D_pair)
fig, axes = plt.subplots(4, 4, figsize=(12, 12))

for channel, ax in enumerate(axes.flat):
    if channel >= n_channels:
        # Fewer channels than panels: hide the leftover axes.
        ax.axis('off')
        continue
    im = ax.imshow(sample.pair[:, :, channel].numpy(), cmap='RdBu_r')
    ax.set_title(f'Channel {channel}')
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle(f'Pair embedding channels: {sample_name}')
plt.tight_layout()
plt.show()

## Sequence Distance vs Pair Embedding

In [None]:
# Does pair embedding magnitude correlate with sequence separation |i - j|?
L = len(sample.sequence)
seq_dist = np.abs(np.arange(L)[:, None] - np.arange(L)[None, :])  # |i - j|

# Flatten and plot every (i, j) pair
seq_dist_flat = seq_dist.flatten()
pair_norm_flat = pair_norm.flatten()

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(seq_dist_flat, pair_norm_flat, alpha=0.3, s=10)
ax.set_xlabel('Sequence separation |i - j|')
ax.set_ylabel('Pair embedding norm')
ax.set_title('Pair embedding magnitude vs sequence distance')

# Mean norm per separation. np.bincount groups all L^2 entries in a single
# pass, instead of building a full L^2 boolean mask for each separation
# (O(L^3)) and issuing one scatter call per separation.
sums = np.bincount(seq_dist_flat, weights=pair_norm_flat)
counts = np.bincount(seq_dist_flat)
present = counts > 0  # separations that actually occur
separations = np.nonzero(present)[0]
ax.plot(separations, sums[present] / counts[present], color='red', marker='o',
        markersize=5, zorder=5, label='Mean per separation')
ax.legend()

plt.tight_layout()
plt.show()

## Amino Acid Embedding Comparison

In [None]:
# Group per-residue single embeddings by one-letter amino acid code, pooling
# over all proteins. Residue codes outside the 20 standard amino acids are skipped.
AA_VOCAB = 'ACDEFGHIKLMNPQRSTVWY'
grouped = defaultdict(list)

for protein, emb in embeddings.items():
    per_residue = emb.single
    for pos, residue in enumerate(emb.sequence):
        if residue not in AA_VOCAB:
            continue
        grouped[residue].append(per_residue[pos].numpy())

# One [n_occurrences, D_single] array per amino acid
aa_embeddings = {residue: np.stack(rows) for residue, rows in grouped.items()}

print("Embeddings per amino acid:")
for residue in (r for r in AA_VOCAB if r in aa_embeddings):
    print(f"  {residue}: {len(aa_embeddings[residue])} samples")

In [None]:
# Average embedding vector for each amino acid type
aa_means = {aa: vecs.mean(axis=0) for aa, vecs in aa_embeddings.items()}

# Pairwise cosine similarity between the mean vectors, rows/cols in sorted order
aas = sorted(aa_means.keys())
n_aa = len(aas)
similarity_matrix = np.zeros((n_aa, n_aa))

for row, col in np.ndindex(n_aa, n_aa):
    u = aa_means[aas[row]]
    v = aa_means[aas[col]]
    similarity_matrix[row, col] = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(similarity_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
tick_positions = range(n_aa)
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(aas)
ax.set_yticklabels(aas)
ax.set_title('Cosine similarity between amino acid mean embeddings')
plt.colorbar(im)
plt.tight_layout()
plt.show()

In [None]:
# PCA of amino acid mean embeddings
from sklearn.decomposition import PCA

# One row per amino acid, in the same sorted order as `aas`
aa_matrix = np.stack([aa_means[aa] for aa in aas])

pca = PCA(n_components=2)
aa_pca = pca.fit_transform(aa_matrix)

# Physico-chemical classes used for colouring: class name -> (members, colour).
# Glycine (no side chain) gets its own labelled class, so every point's colour
# is explained by the legend.
AA_CLASSES = {
    'hydrophobic': ('AVILMFWP', 'orange'),
    'polar': ('STNQYC', 'green'),
    'positive': ('KRH', 'blue'),
    'negative': ('DE', 'red'),
    'special (G)': ('G', 'gray'),
}
aa_class = {aa: cls for cls, (members, _) in AA_CLASSES.items() for aa in members}

fig, ax = plt.subplots(figsize=(10, 8))

# One scatter call per class so each colour gets a legend entry.
for cls, (_, color) in AA_CLASSES.items():
    idx = [i for i, aa in enumerate(aas) if aa_class.get(aa) == cls]
    if idx:
        ax.scatter(aa_pca[idx, 0], aa_pca[idx, 1], color=color, s=200, alpha=0.7, label=cls)

# Residues missing from AA_CLASSES (not expected for the 20 standard codes)
other = [i for i, aa in enumerate(aas) if aa not in aa_class]
if other:
    ax.scatter(aa_pca[other, 0], aa_pca[other, 1], color='black', s=200, alpha=0.7, label='other')

for i, aa in enumerate(aas):
    ax.annotate(aa, (aa_pca[i, 0], aa_pca[i, 1]), fontsize=14, ha='center', va='center')

ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
ax.set_title('PCA of amino acid mean embeddings')
ax.legend(title='Residue class')
plt.tight_layout()
plt.show()

## Protein-Level Analysis

In [None]:
# PCA of protein-level embeddings (single embedding mean-pooled over residues)
# Imported here as well so this cell does not depend on the amino-acid PCA cell having run.
from sklearn.decomposition import PCA

protein_names = list(embeddings.keys())
protein_single_means = np.stack([embeddings[n].single.mean(dim=0).numpy() for n in protein_names])
# Sequence lengths aligned row-for-row with protein_single_means. This avoids
# relying on the `lengths` list built in another cell staying in the same order.
protein_lengths = [len(embeddings[n].sequence) for n in protein_names]

print(f"Protein embedding matrix: {protein_single_means.shape}")

# PCA
pca_proteins = PCA(n_components=2)
protein_pca = pca_proteins.fit_transform(protein_single_means)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(protein_pca[:, 0], protein_pca[:, 1],
                     c=protein_lengths, cmap='viridis', s=50, alpha=0.7)
ax.set_xlabel(f'PC1 ({pca_proteins.explained_variance_ratio_[0]*100:.1f}%)')
ax.set_ylabel(f'PC2 ({pca_proteins.explained_variance_ratio_[1]*100:.1f}%)')
ax.set_title('PCA of protein mean embeddings (colored by length)')
plt.colorbar(scatter, label='Sequence length')
plt.tight_layout()
plt.show()

In [None]:
# Try UMAP if available. Only the import is guarded, so an ImportError raised
# while fitting or plotting is not misreported as "UMAP not installed".
try:
    from umap import UMAP
except ImportError:
    print("UMAP not installed. Run: uv add umap-learn")
else:
    # Lengths aligned row-for-row with protein_single_means (built from protein_names)
    umap_lengths = [len(embeddings[n].sequence) for n in protein_names]

    umap_model = UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
    protein_umap = umap_model.fit_transform(protein_single_means)

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(protein_umap[:, 0], protein_umap[:, 1],
                         c=umap_lengths, cmap='viridis', s=50, alpha=0.7)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title('UMAP of protein mean embeddings (colored by length)')
    plt.colorbar(scatter, label='Sequence length')
    plt.tight_layout()
    plt.show()

## Storage Statistics

In [None]:
# In-memory storage of the loaded tensors. Bytes use each tensor's actual
# dtype width (element_size) rather than assuming float32, so the figures
# stay correct for half-precision or float64 exports.
total_single_elements = sum(emb.single.numel() for emb in embeddings.values())
total_pair_elements = sum(emb.pair.numel() for emb in embeddings.values())

single_bytes = sum(emb.single.numel() * emb.single.element_size() for emb in embeddings.values())
pair_bytes = sum(emb.pair.numel() * emb.pair.element_size() for emb in embeddings.values())

print("Storage requirements:")
print(f"  Single embeddings: {single_bytes / 1e6:.1f} MB")
print(f"  Pair embeddings: {pair_bytes / 1e6:.1f} MB")
print(f"  Total: {(single_bytes + pair_bytes) / 1e6:.1f} MB")
print()
print(f"  Single elements: {total_single_elements:,}")
print(f"  Pair elements: {total_pair_elements:,}")

In [None]:
# On-disk footprint of the discovered .pt files, for comparison with the
# in-memory estimate above.
file_sizes = [path.stat().st_size for path in embedding_files]
total_disk = sum(file_sizes)

print("Actual disk usage:")
print(f"  Total: {total_disk / 1e6:.1f} MB")
print(f"  Per protein: {np.mean(file_sizes) / 1e3:.1f} KB (mean)")
smallest, largest = min(file_sizes), max(file_sizes)
print(f"  Range: {smallest / 1e3:.1f} - {largest / 1e3:.1f} KB")

## Summary

Key findings from embedding analysis:
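- Single embeddings have shape `[L, D_single]` (per-residue features; `D_single` is printed in the Embedding Dimensions section)
- Pair embeddings have shape `[L, L, D_pair]` (pairwise interactions; `D_pair` channels per residue pair)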
- Pair embeddings capture sequence proximity patterns
- Amino acid embeddings cluster by biochemical properties
- Storage is dominated by pair embeddings (L² scaling)