# Testing out of box ESM sdk with antibodies using our antibody multimerization

In [1]:
import torch
from esm.pretrained import ESM3_sm_open_v0  # Loads ESM3 with appropriate parameters and weights
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.sdk.api import ESM3InferenceClient, GenerationConfig  # The SDK provides inference helpers

# Setup device and load model from local weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Here we use the helper function to load the small open model (ESM3_sm_open_v1)
model = ESM3_sm_open_v0(device)

# Load the state dictionary from the file
fpath = '/home/jupyter/DATA/model_weights/esm3_complete/esm3_sm_open_v1_state_dict.pt'
state_dict = torch.load(fpath, map_location=device)
model.load_state_dict(state_dict)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()



Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

ESM3(
  (encoder): EncodeInputs(
    (sequence_embed): Embedding(64, 1536)
    (plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_per_res_plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_tokens_embed): Embedding(4101, 1536)
    (ss8_embed): Embedding(11, 1536)
    (sasa_embed): Embedding(19, 1536)
    (function_embed): ModuleList(
      (0-7): 8 x Embedding(260, 192, padding_idx=0)
    )
    (residue_embed): EmbeddingBag(1478, 1536, mode='sum', padding_idx=0)
  )
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0): UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1536, out_features=4608, bias=False)
          )
          (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (q_ln): LayerNorm((1536,), eps=1e-05, el

In [2]:
# Wrap the model in the SDK's inference client.
# The SDK (Software Development Kit) provides helper functions to run inference
# (such as encoding sequences and generating predictions) without calling the model's
# forward() method directly.
client = ESM3InferenceClient(model=model, tokenizers=model.tokenizers, device=device)

TypeError: ESM3InferenceClient() takes no arguments

In [None]:
# -------------------------------
# 1. Sequence to Structure
# -------------------------------
seq_tokenizer = EsmSequenceTokenizer()
test_sequence = "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISWNSGNTLYLQMNSLRAEDTAVYYCAR"
# Use the SDK's encode method, which wraps the tokenizer and prepares the input for the model.
encoded_sequence = client.encode(test_sequence)
print("Input sequence:", test_sequence)
print("Encoded sequence tensor shape:", encoded_sequence.shape)

# Generate structure tokens from the sequence.
# Here we specify the generation configuration. Adjust num_steps and temperature as needed.
gen_config = GenerationConfig(track="structure", num_steps=10, temperature=0.1)
with torch.no_grad():
    output_seq2struct = client.generate(encoded_sequence, config=gen_config)

print("\n--- Sequence to Structure Output ---")
if hasattr(output_seq2struct, "structure") and output_seq2struct.structure is not None:
    print("Structure tokens shape:", output_seq2struct.structure.shape)
else:
    print("Full output from sequence-to-structure pass:")
    print(output_seq2struct)

In [2]:
# -------------------------------
# 1. Sequence to Structure
# -------------------------------
seq_tokenizer = EsmSequenceTokenizer()
test_sequence = "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISWNSGNTLYLQMNSLRAEDTAVYYCAR"
seq_tokens = seq_tokenizer.encode(test_sequence)
seq_tokens_tensor = torch.tensor(seq_tokens, dtype=torch.int64).unsqueeze(0).to(device)
print("Input sequence:", test_sequence)
print("Sequence tokens shape:", seq_tokens_tensor.shape)
print("Sequence tokens are:", seq_tokens)

Input sequence: EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISWNSGNTLYLQMNSLRAEDTAVYYCAR
Sequence tokens shape: torch.Size([1, 80])
Sequence tokens are: [0, 9, 7, 16, 4, 7, 9, 8, 6, 6, 6, 4, 7, 16, 14, 6, 6, 8, 4, 10, 4, 8, 23, 5, 5, 8, 6, 18, 11, 18, 8, 8, 19, 5, 20, 8, 22, 7, 10, 16, 5, 14, 6, 15, 6, 4, 9, 22, 7, 8, 5, 12, 8, 22, 17, 8, 6, 17, 11, 4, 19, 4, 16, 20, 17, 8, 4, 10, 5, 9, 13, 11, 5, 7, 19, 19, 23, 5, 10, 2]


In [19]:
with torch.no_grad():
    
    # Pass the dictionary to the model
    results = model.forward(sequence_tokens=seq_tokens_tensor)


# print("\n--- Sequence to Structure Output ---")
# if "structure_tokens" in result:
#     structure_tokens = result["structure_tokens"]
#     print("Structure tokens shape:", structure_tokens.shape)
# else:
#     print("Full output from sequence-to-structure pass:")
#     print(result())

ESMOutput(sequence_logits=tensor([[[-19.6439, -19.6183, -19.5486,  ..., -19.6119, -19.6217, -19.5716],
         [-18.8586, -18.7849, -18.7275,  ..., -18.7384, -18.8308, -18.8172],
         [-18.8029, -18.8535, -18.7761,  ..., -18.8522, -18.8633, -18.9367],
         ...,
         [-23.5875, -23.5623, -23.5294,  ..., -23.6264, -23.2895, -23.5783],
         [-21.1202, -21.2024, -21.2147,  ..., -21.2279, -21.1434, -21.3248],
         [-20.3476, -20.3085, -20.2701,  ..., -20.3249, -20.2758, -20.3743]]],
       device='cuda:0'), structure_logits=tensor([[[23.2422, 20.4796, 26.0145,  ..., 20.5472, 18.1796, 21.0636],
         [19.5295, 22.4368, 20.5198,  ..., 21.4520, 13.8315, 15.8104],
         [27.2014, 21.8847, 24.6682,  ..., 21.7292, 17.7618, 21.2891],
         ...,
         [25.4692, 13.9029, 28.3393,  ..., 13.9382, 13.2019, 18.0187],
         [24.2239, 19.6468, 27.0249,  ..., 18.8914, 15.2234, 20.6710],
         [23.8610, 22.1027, 24.8920,  ..., 21.5795, 17.7413, 19.7615]]],
       devic

In [None]:
# -------------------------------
# 2. Partially Masked Sequence to Infilled Sequence
# -------------------------------
# Test: mask a segment of the input sequence and let the model fill it in.
mask_token_id = seq_tokenizer.mask_token_id  # Mask token as defined by the tokenizer
masked_seq_tokens = seq_tokens.copy()
# For demonstration, mask tokens at positions 10 through 15.
for i in range(10, 16):
    masked_seq_tokens[i] = mask_token_id
masked_seq_tokens_tensor = torch.tensor(masked_seq_tokens, dtype=torch.int64).unsqueeze(0).to(device)
print("\n--- Partially Masked Sequence ---")
print("Masked sequence tokens:", masked_seq_tokens_tensor)

with torch.no_grad():
    # We assume the model can perform masked token prediction; this may be part of the forward pass
    # or a separate generation/inpainting routine.
    output_mask = model(masked_seq_tokens_tensor)
    # We assume that the output contains a key "sequence" with the predicted full sequence tokens.
    if "sequence" in output_mask:
        predicted_seq_tokens = output_mask["sequence"]
        infilled_sequence = seq_tokenizer.decode(predicted_seq_tokens.squeeze(0).tolist())
        print("Infilled sequence:")
        print(infilled_sequence)
    else:
        print("Masked sequence infilling output:")
        print(output_mask)

# -------------------------------
# 3. Inverse – Structure to Sequence
# -------------------------------
# Test: Given structure output, attempt to recover a sequence.
# This inverse operation is not standard in every model.
# Here we assume that our model has a method called invert_structure().
print("\n--- Inverse: Structure to Sequence ---")
if "structure_tokens" in output_seq2struct:
    with torch.no_grad():
        # Check if the model has an 'invert_structure' method.
        if hasattr(model, "invert_structure"):
            # Pass the structure tokens to the inverse method.
            output_struct2seq = model.invert_structure(structure_tokens)
            if "sequence" in output_struct2seq:
                recovered_seq_tokens = output_struct2seq["sequence"]
                recovered_sequence = seq_tokenizer.decode(recovered_seq_tokens.squeeze(0).tolist())
                print("Recovered sequence from structure:")
                print(recovered_sequence)
            else:
                print("Inverse structure-to-sequence output:")
                print(output_struct2seq)
        else:
            print("Model does not implement an invert_structure method. Inverse operation not supported.")
else:
    print("No structure tokens available from sequence-to-structure pass; cannot run inverse operation.")

In [2]:
# Setup device and load the model with local weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_weights_path = "/home/jupyter/DATA/model_weights/esm3_complete/esm3_sm_open_v1_state_dict.pt"
state_dict = torch.load(local_weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # Set model to evaluation mode

NameError: name 'ESM3' is not defined

In [None]:
state_dict = loaded_backbone_model.state_dict()
for key, tensor in state_dict.items():
    print(f"{key}: {tensor.shape}")
    
total_params = sum(p.numel() for p in loaded_backbone_model.parameters())
print("Total number of parameters:", total_params)

encoder.sequence_embed.weight: torch.Size([64, 1536])
encoder.plddt_projection.weight: torch.Size([1536, 16])
encoder.plddt_projection.bias: torch.Size([1536])
encoder.structure_per_res_plddt_projection.weight: torch.Size([1536, 16])
encoder.structure_per_res_plddt_projection.bias: torch.Size([1536])
encoder.structure_tokens_embed.weight: torch.Size([4101, 1536])
encoder.ss8_embed.weight: torch.Size([11, 1536])
encoder.sasa_embed.weight: torch.Size([19, 1536])
encoder.function_embed.0.weight: torch.Size([260, 192])
encoder.function_embed.1.weight: torch.Size([260, 192])
encoder.function_embed.2.weight: torch.Size([260, 192])
encoder.function_embed.3.weight: torch.Size([260, 192])
encoder.function_embed.4.weight: torch.Size([260, 192])
encoder.function_embed.5.weight: torch.Size([260, 192])
encoder.function_embed.6.weight: torch.Size([260, 192])
encoder.function_embed.7.weight: torch.Size([260, 192])
encoder.residue_embed.weight: torch.Size([1478, 1536])
transformer.blocks.0.attn.layern

# Explainer: Transformer Block with Geometric Attention in ESM3

In the ESM3 architecture, the **first transformer block** is unique because it integrates a specialized module for geometric reasoning, which is crucial for capturing structural relationships (e.g., in protein data).
```
TransformerStack(
  (blocks): ModuleList(
    (0): UnifiedTransformerBlock(
      (attn): MultiHeadAttention(
        (layernorm_qkv): Sequential(
          (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1536, out_features=4608, bias=False)
        )
        (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (q_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (k_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (rotary): RotaryEmbedding()
      )
      (geom_attn): GeometricReasoningOriginalImpl(
        (s_norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (proj): Linear(in_features=1536, out_features=3840, bias=False)
        (out_proj): Linear(in_features=768, out_features=1536, bias=False)
      )
      (ffn): Sequential(
        (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=1536, out_features=8192, bias=False)
        (2): SwiGLU()
        (3): Linear(in_features=4096, out_features=1536, bias=False)
      )
    )
    (1-47): 47 x UnifiedTransformerBlock(
      (attn): MultiHeadAttention(
        (layernorm_qkv): Sequential(
          (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1536, out_features=4608, bias=False)
        )
        (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (q_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (k_ln): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (rotary): RotaryEmbedding()
      )
      (ffn): Sequential(
        (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=1536, out_features=8192, bias=False)
        (2): SwiGLU()
        (3): Linear(in_features=4096, out_features=1536, bias=False)
      )
    )
  )
  (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
```

Below is a breakdown of its components:

---

## 1. Multi-Head Self-Attention (`attn`)

- **Purpose:**  
  Processes the input sequence by allowing each token to attend to every other token.

- **Key Components:**  
  - **Layer Normalization & Linear Projections:**  
    Prepares the queries, keys, and values for attention computation.
  - **Rotary Embeddings:**  
    Enhances the attention mechanism by incorporating relative positional information.
  - **Output Projection:**  
    Consolidates the attention outputs back into the original embedding space.

---

## 2. Geometric Attention (`geom_attn`)

- **Purpose:**  
  Enhances the transformer block by integrating geometric or structural information. This is particularly important for applications like protein structure prediction where spatial relationships are key.

- **Key Components:**
  - **Layer Normalization (`s_norm`):**  
    Normalizes the input prior to geometric processing.
  - **Projection (`proj`):**  
    Maps the 1536-dimensional input into a higher-dimensional space (3840 dimensions) to capture more complex geometric features.
  - **Output Projection (`out_proj`):**  
    Reduces the dimensionality from 768 (after geometric processing) back to 1536, ensuring compatibility with the rest of the network.

---

## 3. Feed-Forward Network (`ffn`)

- **Purpose:**  
  Applies further non-linear transformations to the output from the attention mechanisms.

- **Key Components:**
  - **Layer Normalization:**  
    Prepares the data for the feed-forward operations.
  - **Intermediate Linear Layer (Expansion to 8192 dimensions):**  
    Expands the feature space to allow for complex representations.
  - **SwiGLU Activation:**  
    A non-linear activation that enhances the model’s expressivity.
  - **Final Linear Projection:**  
    Compresses the expanded representation back to the original 1536 dimensions.

---

## Summary

The **first transformer block** in ESM3 is a **UnifiedTransformerBlock** that stands out due to the integration of **geometric attention**. This extra module (`geom_attn`) augments the standard self-attention mechanism by incorporating geometric reasoning, making the block adept at handling structural data alongside sequence information.

# esm encoder+transformer block unit testing

In [34]:
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer


In [35]:
# for ESM3
# Initialize sequence 
seq_tokenizer = EsmSequenceTokenizer()
test_sequence = "CSSDGSYGFGAMDYW"
seq_tokens = seq_tokenizer.encode(test_sequence)
seq_tokens_tensor = torch.tensor(seq_tokens, dtype=torch.int64).unsqueeze(0).to(device)
print("Sequence tokens shape:", seq_tokens_tensor.shape)

# Create required tensors
# dummy_average_plddt = torch.ones(seq_tokens_tensor.shape, dtype=torch.float32, device=device)
# dummy_per_res_plddt = torch.ones(seq_tokens_tensor.shape, dtype=torch.float32, device=device)
# dummy_structure_tokens = torch.zeros(seq_tokens_tensor.shape, dtype=torch.int64, device=device)

Sequence tokens shape: torch.Size([1, 17])


In [36]:
# TEST SEQUENCE EMBEDDINGS
with torch.no_grad():
    embeddings = loaded_backbone_model.encoder.sequence_embed(seq_tokens_tensor)
print("Embeddings shape:", embeddings.shape)
print(embeddings)

Embeddings shape: torch.Size([1, 17, 1536])
tensor([[[ 7.7637e-02, -1.7578e-02,  4.1016e-02,  ..., -2.9541e-02,
          -1.9653e-02, -6.3324e-04],
         [ 6.5308e-03,  2.5879e-02, -8.4473e-02,  ..., -3.1250e-02,
          -8.1543e-02,  1.2500e-01],
         [ 1.1035e-01, -8.0566e-02, -1.2061e-01,  ...,  6.2256e-02,
          -2.2339e-02, -5.4688e-02],
         ...,
         [ 2.0599e-04,  1.5625e-01,  9.5703e-02,  ...,  1.1914e-01,
          -4.3701e-02,  5.5664e-02],
         [-3.1006e-02, -1.2451e-02,  2.6367e-01,  ..., -1.8164e-01,
          -7.3242e-02,  2.2168e-01],
         [ 5.9814e-02,  2.9907e-02,  1.7334e-02,  ..., -3.1250e-02,
          -2.1729e-02, -1.6235e-02]]], device='cuda:0', dtype=torch.bfloat16)


## load inputs (sequence or pdb structure)

In [46]:
# load amino acid sequence string input
seq_tokenizer = EsmSequenceTokenizer()
test_sequence = "CSSDGSYGFGAMDYW"
seq_tokens = seq_tokenizer.encode(test_sequence)
seq_tokens_tensor = torch.tensor(seq_tokens, dtype=torch.int64).unsqueeze(0).to(device)

In [55]:
from esm.utils.structure.protein_chain import ProteinChain
fpath = '/home/jupyter/1BEY.pdb'
protein_chain = ProteinChain.from_pdb(path=fpath, chain_id='H')
atom37 = protein_chain.atom37_positions

## ** TODO: how do we pass a multimer structure into ESM3?**

In [57]:
protein_chain

ProteinChain(id='1BEY', sequence='QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLEWIGFIRDKAKGYTTEYNPSVKGRVTMLVDTSKNQFSLRLSSVTAADTAVYYCAREGHTAAPFDYWGQGSLVTVSSASTKGPSVFPLAPAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV', chain_id='H', entity_id=1, residue_index=array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
       105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
       118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 12

In [38]:
# We must pass zero vectors for input modalities with null

# Dummy structure tokens
dummy_structure_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)

# Dummy tokens for ss8 (secondary structure, 11 classes) and SASA (solvent accessibility, 19 classes)
dummy_ss8_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
dummy_sasa_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)

# Dummy pLDDT values (typically float values per residue)
dummy_average_plddt = torch.ones_like(seq_tokens_tensor, dtype=torch.float32, device=device)
dummy_per_res_plddt = torch.ones_like(seq_tokens_tensor, dtype=torch.float32, device=device)

batch_size, seq_len = seq_tokens_tensor.shape
dummy_rbf = torch.ones(batch_size, seq_len, 16,
                       dtype=loaded_backbone_model.encoder.plddt_projection.weight.dtype,
                       device=device)


In [39]:
# ---------------------------------------------------------------------------
# TEST SEQUENCE EMBEDDINGS
with torch.no_grad():
    embeddings = loaded_backbone_model.encoder.sequence_embed(seq_tokens_tensor)
print("Sequence embeddings shape:", embeddings.shape)
print(embeddings)

Sequence embeddings shape: torch.Size([1, 17, 1536])
tensor([[[ 7.7637e-02, -1.7578e-02,  4.1016e-02,  ..., -2.9541e-02,
          -1.9653e-02, -6.3324e-04],
         [ 6.5308e-03,  2.5879e-02, -8.4473e-02,  ..., -3.1250e-02,
          -8.1543e-02,  1.2500e-01],
         [ 1.1035e-01, -8.0566e-02, -1.2061e-01,  ...,  6.2256e-02,
          -2.2339e-02, -5.4688e-02],
         ...,
         [ 2.0599e-04,  1.5625e-01,  9.5703e-02,  ...,  1.1914e-01,
          -4.3701e-02,  5.5664e-02],
         [-3.1006e-02, -1.2451e-02,  2.6367e-01,  ..., -1.8164e-01,
          -7.3242e-02,  2.2168e-01],
         [ 5.9814e-02,  2.9907e-02,  1.7334e-02,  ..., -3.1250e-02,
          -2.1729e-02, -1.6235e-02]]], device='cuda:0', dtype=torch.bfloat16)


In [40]:
# ---------------------------------------------------------------------------
# TEST STRUCTURE TOKENS EMBEDDINGS
# Create dummy structure tokens (all zeros, matching the sequence shape)
dummy_structure_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    structure_embedding = loaded_backbone_model.encoder.structure_tokens_embed(dummy_structure_tokens)
print("Structure tokens embedding shape:", structure_embedding.shape)
print(structure_embedding)

Structure tokens embedding shape: torch.Size([1, 17, 1536])
tensor([[[ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         ...,
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530],
         [ 0.0454,  0.1699,  0.1035,  ..., -0.1133, -0.0625,  0.0530]]],
       device='cuda:0', dtype=torch.bfloat16)


In [41]:
# ---------------------------------------------------------------------------
# TEST PLDDT PROJECTION
# Create a dummy input for the pLDDT projection module (shape: batch_size x seq_len x 16)
batch_size, seq_len = seq_tokens_tensor.shape
dummy_rbf = torch.ones(
    batch_size,
    seq_len,
    16,
    dtype=loaded_backbone_model.encoder.plddt_projection.weight.dtype,
    device=device
)
with torch.no_grad():
    plddt_embedding = loaded_backbone_model.encoder.plddt_projection(dummy_rbf)
print("pLDDT projection output shape:", plddt_embedding.shape)
print(plddt_embedding)

pLDDT projection output shape: torch.Size([1, 17, 1536])
tensor([[[ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         ...,
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891],
         [ 1.0391,  0.9102, -0.0137,  ..., -0.2275, -1.1875,  1.2891]]],
       device='cuda:0', dtype=torch.bfloat16)


In [42]:
# ---------------------------------------------------------------------------
# TEST SS8 EMBEDDING
# Create dummy tokens for SS8 (secondary structure; 11 classes)
dummy_ss8_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    ss8_embedding = loaded_backbone_model.encoder.ss8_embed(dummy_ss8_tokens)
print("SS8 embedding shape:", ss8_embedding.shape)
print(ss8_embedding)

SS8 embedding shape: torch.Size([1, 17, 1536])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16)


In [43]:
# ---------------------------------------------------------------------------
# TEST SASA EMBEDDING
# Create dummy tokens for SASA (solvent accessibility; 19 classes)
dummy_sasa_tokens = torch.zeros_like(seq_tokens_tensor, dtype=torch.int64, device=device)
with torch.no_grad():
    sasa_embedding = loaded_backbone_model.encoder.sasa_embed(dummy_sasa_tokens)
print("SASA embedding shape:", sasa_embedding.shape)
print(sasa_embedding)

SASA embedding shape: torch.Size([1, 17, 1536])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.bfloat16)


In [44]:
# ---------------------------------------------------------------------------
# TEST FUNCTION EMBEDDINGS
# The function_embed module is a ModuleList; test each embedding layer.
for i, func_embed in enumerate(loaded_backbone_model.encoder.function_embed):
    with torch.no_grad():
        func_embedding = func_embed(seq_tokens_tensor)
    print(f"Function embedding {i} shape:", func_embedding.shape)
    print(func_embedding)

Function embedding 0 shape: torch.Size([1, 17, 192])
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.2812,  0.1147, -0.4746,  ...,  0.1230, -0.6133,  0.3711],
         [-0.9922, -0.3828,  0.4082,  ...,  0.4238, -0.2676,  0.6367],
         ...,
         [-0.3105,  0.6406, -0.7344,  ..., -0.4609,  0.2061,  0.7539],
         [-0.1162, -0.0476,  0.5820,  ..., -0.3125, -0.2734, -0.3555],
         [-0.2480,  0.2949, -0.6211,  ...,  0.3691,  0.1768, -0.2119]]],
       device='cuda:0', dtype=torch.bfloat16)
Function embedding 1 shape: torch.Size([1, 17, 192])
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.3535, -0.1108,  0.1235,  ...,  0.5859,  0.1719, -0.1631],
         [ 0.0452,  0.4883, -0.4570,  ...,  0.4160, -0.3770, -0.2305],
         ...,
         [ 0.9102,  0.6133,  0.3770,  ...,  0.2148,  0.0776, -1.0000],
         [-0.6211,  0.0542, -0.7031,  ..., -0.0571, -0.3672,  0.0087],
         [-0.1553, -0.3516,  0.2656, 

In [45]:
# ---------------------------------------------------------------------------
# TEST RESIDUE EMBEDDING
# The residue embedding is an EmbeddingBag, so we flatten the input and provide offsets.
dummy_residue_tokens = torch.zeros(seq_tokens_tensor.shape, dtype=torch.int64, device=device)
dummy_residue_tokens_flat = dummy_residue_tokens.view(-1)
offsets = torch.tensor([0], dtype=torch.int64, device=device)
with torch.no_grad():
    residue_embedding = loaded_backbone_model.encoder.residue_embed(dummy_residue_tokens_flat, offsets)
print("Residue embedding shape:", residue_embedding.shape)
print(residue_embedding)

Residue embedding shape: torch.Size([1, 1536])
tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)


structure_embedding = loaded_backbone_model.encoder.structure_tokens_embed(structure_tokens)
print("Structure tokens embedding shape:", structure_embedding.shape)
    

# objectives:
# 1. validate other embeddings work (may need antibody pdbs as input)
# 2. validate embedding dimensions / shape
# 3. Run multiple sequences in one pass?
# 4. Speed test 10, 100, 1000 sequences runtime
# 5. Training...OPT 3 guide a fine-tuning of this model (on 5 random antibody sequences)
# 6. LORA for training econo.y








