In [15]:
import torch
from torch.serialization import safe_globals

from esm.models.esm3 import ESM3
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

In [16]:
# Register the ESM3 class as safe for unpickling.
torch.serialization.add_safe_globals([ESM3])

In [17]:
backbone_save_path = "/home/jupyter/DATA/model_weights/esm3_backbone/esm3_backbone_model.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
with safe_globals([("esm.models.esm3.ESM3", ESM3)]):
    loaded_backbone_model = torch.load(backbone_save_path, map_location=device, weights_only=False)
    loaded_backbone_model.eval()

cuda


In [18]:
state_dict = loaded_backbone_model.state_dict()
for key, tensor in state_dict.items():
    print(f"{key}: {tensor.shape}")
    
total_params = sum(p.numel() for p in loaded_backbone_model.parameters())
print("Total number of parameters:", total_params)

encoder.sequence_embed.weight: torch.Size([64, 1536])
encoder.plddt_projection.weight: torch.Size([1536, 16])
encoder.plddt_projection.bias: torch.Size([1536])
encoder.structure_per_res_plddt_projection.weight: torch.Size([1536, 16])
encoder.structure_per_res_plddt_projection.bias: torch.Size([1536])
encoder.structure_tokens_embed.weight: torch.Size([4101, 1536])
encoder.ss8_embed.weight: torch.Size([11, 1536])
encoder.sasa_embed.weight: torch.Size([19, 1536])
encoder.function_embed.0.weight: torch.Size([260, 192])
encoder.function_embed.1.weight: torch.Size([260, 192])
encoder.function_embed.2.weight: torch.Size([260, 192])
encoder.function_embed.3.weight: torch.Size([260, 192])
encoder.function_embed.4.weight: torch.Size([260, 192])
encoder.function_embed.5.weight: torch.Size([260, 192])
encoder.function_embed.6.weight: torch.Size([260, 192])
encoder.function_embed.7.weight: torch.Size([260, 192])
encoder.residue_embed.weight: torch.Size([1478, 1536])
transformer.blocks.0.attn.layern

# Explainer: Transformer Block with Geometric Attention in ESM3

In the ESM3 architecture, the **first transformer block** is unique because it integrates a specialized module for geometric reasoning, which is crucial for capturing structural relationships (e.g., in protein data).
```
TransformerStack(
  (blocks): ModuleList(
    (0): UnifiedTransformerBlock(
      (attn): MultiHeadAttention(
        (layernorm_qkv): Sequential(
          (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1536, out_features=4608, bias=False)
        )
        (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (q_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (k_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (rotary): RotaryEmbedding()
      )
      (geom_attn): GeometricReasoningOriginalImpl(
        (s_norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (proj): Linear(in_features=1536, out_features=3840, bias=False)
        (out_proj): Linear(in_features=768, out_features=1536, bias=False)
      )
      (ffn): Sequential(
        (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=1536, out_features=8192, bias=False)
        (2): SwiGLU()
        (3): Linear(in_features=4096, out_features=1536, bias=False)
      )
    )
    (1-47): 47 x UnifiedTransformerBlock(
      (attn): MultiHeadAttention(
        (layernorm_qkv): Sequential(
          (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1536, out_features=4608, bias=False)
        )
        (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (q_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (k_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (rotary): RotaryEmbedding()
      )
      (ffn): Sequential(
        (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=1536, out_features=8192, bias=False)
        (2): SwiGLU()
        (3): Linear(in_features=4096, out_features=1536, bias=False)
      )
    )
  )
  (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
```

Below is a breakdown of its components:

---

## 1. Multi-Head Self-Attention (`attn`)

- **Purpose:**  
  Processes the input sequence by allowing each token to attend to every other token.

- **Key Components:**  
  - **Layer Normalization & Linear Projections:**  
    Prepares the queries, keys, and values for attention computation.
  - **Rotary Embeddings:**  
    Enhances the attention mechanism by incorporating relative positional information.
  - **Output Projection:**  
    Consolidates the attention outputs back into the original embedding space.

---

## 2. Geometric Attention (`geom_attn`)

- **Purpose:**  
  Enhances the transformer block by integrating geometric or structural information. This is particularly important for applications like protein structure prediction where spatial relationships are key.

- **Key Components:**
  - **Layer Normalization (`s_norm`):**  
    Normalizes the input prior to geometric processing.
  - **Projection (`proj`):**  
    Maps the 1536-dimensional input into a higher-dimensional space (3840 dimensions) to capture more complex geometric features.
  - **Output Projection (`out_proj`):**  
    Reduces the dimensionality from 768 (after geometric processing) back to 1536, ensuring compatibility with the rest of the network.

---

## 3. Feed-Forward Network (`ffn`)

- **Purpose:**  
  Applies further non-linear transformations to the output from the attention mechanisms.

- **Key Components:**
  - **Layer Normalization:**  
    Prepares the data for the feed-forward operations.
  - **Intermediate Linear Layer (Expansion to 8192 dimensions):**  
    Expands the feature space to allow for complex representations.
  - **SwiGLU Activation:**  
    A non-linear activation that enhances the model’s expressivity.
  - **Final Linear Projection:**  
    Compresses the expanded representation back to the original 1536 dimensions.

---

## Summary

The **first transformer block** in ESM3 is a **UnifiedTransformerBlock** that stands out due to the integration of **geometric attention**. This extra module (`geom_attn`) augments the standard self-attention mechanism by incorporating geometric reasoning, making the block adept at handling structural data alongside sequence information.

# esm encoder+transformer block unit testing

In [5]:
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer


In [6]:
from esm.sdk.api import ESMProtein, ProteinComplex
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

In [7]:
fpath = '/home/jupyter/1BEY.pdb'
protein_chain = ProteinChain.from_pdb(fpath, chain_id='H')
monomer_protein = ESMProtein.from_protein_chain(protein_chain)

In [8]:
complex = ProteinComplex.from_pdb(fpath)
multimer_protein = ESMProtein.from_protein_complex(complex)
protein = multimer_protein # for convenience

In [38]:
# We must pass zero vectors for input modalities with null

# Dummy structure tokens
dummy_structure_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)

# Dummy tokens for ss8 (secondary structure, 11 classes) and SASA (solvent accessibility, 19 classes)
dummy_ss8_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
dummy_sasa_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)

# Dummy pLDDT values (typically float values per residue)
dummy_average_plddt = torch.ones_like(seq_tokens_tensor, dtype=torch.float32, device=device)
dummy_per_res_plddt = torch.ones_like(seq_tokens_tensor, dtype=torch.float32, device=device)

batch_size, seq_len = seq_tokens_tensor.shape
dummy_rbf = torch.ones(batch_size, seq_len, 16,
                       dtype=loaded_backbone_model.encoder.plddt_projection.weight.dtype,
                       device=device)


In [70]:
# ---------------------------------------------------------------------------
# TEST SEQUENCE EMBEDDINGS
with torch.no_grad():
    embeddings = loaded_backbone_model.encoder.sequence_embed(seq_tokens_tensor)
print("Sequence embeddings shape:", embeddings.shape)
print(embeddings)

Sequence embeddings shape: torch.Size([1, 427, 1536])
tensor([[[ 0.0776, -0.0176,  0.0410,  ..., -0.0295, -0.0197, -0.0006],
         [ 0.0859, -0.1328, -0.0781,  ..., -0.0317,  0.0208, -0.0364],
         [-0.0033, -0.0562,  0.0261,  ..., -0.0188, -0.0027, -0.0928],
         ...,
         [ 0.1367, -0.0083,  0.0413,  ...,  0.0200,  0.1426, -0.0144],
         [ 0.0132, -0.0060,  0.1436,  ..., -0.0309, -0.1348, -0.0116],
         [ 0.0598,  0.0299,  0.0173,  ..., -0.0312, -0.0217, -0.0162]]],
       device='cuda:0', dtype=torch.bfloat16)


In [40]:
# ---------------------------------------------------------------------------
# TEST STRUCTURE TOKENS EMBEDDINGS
# Create dummy structure tokens (all zeros, matching the sequence shape)
dummy_structure_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    structure_embedding = loaded_backbone_model.encoder.structure_tokens_embed(dummy_structure_tokens)
print("Structure tokens embedding shape:", structure_embedding.shape)
print(structure_embedding)

Structure tokens embedding shape: torch.Size([1, 17, 1536])
tensor([[[ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         ...,
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530]]],
       device='cuda:0', dtype=torch.bfloat16)


In [41]:
# ---------------------------------------------------------------------------
# TEST PLDDT PROJECTION
# Create a dummy input for the pLDDT projection module (shape: batch_size x seq_len x 16)
batch_size, seq_len = seq_tokens_tensor.shape
dummy_rbf = torch.ones(
    batch_size,
    seq_len,
    16,
    dtype=loaded_backbone_model.encoder.plddt_projection.weight.dtype,
    device=device
)
with torch.no_grad():
    plddt_embedding = loaded_backbone_model.encoder.plddt_projection(dummy_rbf)
print("pLDDT projection output shape:", plddt_embedding.shape)
print(plddt_embedding)

pLDDT projection output shape: torch.Size([1, 17, 1536])
tensor([[[ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         ...,
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891]]],
       device='cuda:0', dtype=torch.bfloat16)


In [42]:
# ---------------------------------------------------------------------------
# TEST SS8 EMBEDDING
# Create dummy tokens for SS8 (secondary structure; 11 classes)
dummy_ss8_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    ss8_embedding = loaded_backbone_model.encoder.ss8_embed(dummy_ss8_tokens)
print("SS8 embedding shape:", ss8_embedding.shape)
print(ss8_embedding)

SS8 embedding shape: torch.Size([1, 17, 1536])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16)


In [43]:
# ---------------------------------------------------------------------------
# TEST SASA EMBEDDING
# Create dummy tokens for SASA (solvent accessibility; 19 classes)
dummy_sasa_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    sasa_embedding = loaded_backbone_model.encoder.sasa_embed(dummy_sasa_tokens)
print("SASA embedding shape:", sasa_embedding.shape)
print(sasa_embedding)

SASA embedding shape: torch.Size([1, 17, 1536])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16)


In [44]:
# ---------------------------------------------------------------------------
# TEST FUNCTION EMBEDDINGS
# The function_embed module is a ModuleList; test each embedding layer.
for i, func_embed in enumerate(loaded_backbone_model.encoder.function_embed):
    with torch.no_grad():
        func_embedding = func_embed(seq_tokens_tensor)
    print(f"Function embedding {i} shape:", func_embedding.shape)
    print(func_embedding)

Function embedding 0 shape: torch.Size([1, 17, 192])
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.2812,  0.1147, -0.4746,  ...,  0.1230, -0.6133,  0.3711],
         [-0.9922, -0.3828,  0.4082,  ...,  0.4238, -0.2676,  0.6367],
         ...,
         [-0.3105,  0.6406, -0.7344,  ..., -0.4609,  0.2061,  0.7539],
         [-0.1162, -0.0476,  0.5820,  ..., -0.3125, -0.2734, -0.3555],
         [-0.2480,  0.2949, -0.6211,  ...,  0.3691,  0.1768, -0.2119]]],
       device='cuda:0', dtype=torch.bfloat16)
Function embedding 1 shape: torch.Size([1, 17, 192])
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3535, -0.1108,  0.1235,  ...,  0.5859,  0.1719, -0.1631],
         [ 0.0452,  0.4883, -0.4570,  ...,  0.4160, -0.3770, -0.2305],
         ...,
         [ 0.9102,  0.6133,  0.3770,  ...,  0.2148,  0.0776, -1.0000],
         [-0.6211,  0.0542, -0.7031,  ..., -0.0571, -0.3672,  0.0087],
         [-0.1553, -0.3516,  0.2656, 

In [45]:
# ---------------------------------------------------------------------------
# TEST RESIDUE EMBEDDING
# The residue embedding is an EmbeddingBag, so we flatten the input and provide offsets.
dummy_residue_tokens = torch.zeros(seq_tokens_tensor.shape, dtype=torch.int64, device=device)
dummy_residue_tokens_flat = dummy_residue_tokens.view(-1)
offsets = torch.tensor([0], dtype=torch.int64, device=device)
with torch.no_grad():
    residue_embedding = loaded_backbone_model.encoder.residue_embed(dummy_residue_tokens_flat, offsets)
print("Residue embedding shape:", residue_embedding.shape)
print(residue_embedding)

Residue embedding shape: torch.Size([1, 1536])
tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)


structure_embedding = loaded_backbone_model.encoder.structure_tokens_embed(structure_tokens)
print("Structure tokens embedding shape:", structure_embedding.shape)
    

# objectives:
# 1. validate other embeddings work (may need antibody pdbs as input)
# 2. validate embedding dimensions / shape
# 3. Run multiple sequences in one pass?
# 4. Speed test 10, 100, 1000 sequences runtime
# 5. Training...OPT 3 guide a fine-tuning of this model (on 5 random antibody sequences)
# 6. LORA for training econo.y








