In [1]:
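# 1. Import the standard-library and PyTorch modules used throughout the notebook.
# NOTE: this cell does not set a custom torch.hub cache directory; unless TORCH_HOME is
# configured outside the notebook, torch.hub uses its default location.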
import torch
import os
import copy

In [2]:
# 2. Import and load ESM3 from the native codebase.
# This uses the native ESM3 API from the EvolutionaryScale/esm repository, which must
# already be cloned/installed (see the README in that repo).
from esm.models.esm3 import ESM3
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

# NOTE(review): `_get_torch_home` is a private torch.hub helper, and the message mentions a
# Hugging Face download — confirm the weights really land in the directory printed here.
print(f"making a call to huggingface to pull model weights and storing in Torch Hub cache directory: {torch.hub._get_torch_home()}")
# The model identifier should match one available in the repository.
model_id = "esm3_sm_open_v1"
# Pick the load device from what is actually available instead of hard-coding CUDA,
# so this cell also runs on machines without a GPU.
load_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ESM3.from_pretrained(model_id, device=load_device)

making a call to huggingface to pull model weights and storing in Torch Hub cache directory: /home/jupyter/.cache/torch


Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

In [3]:
# Choose the compute device: CUDA when torch reports it as available, CPU otherwise,
# then place the model on it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

# Report how many scalar parameters the model holds in total.
total_params = sum(weight.numel() for weight in model.parameters())
print("Total number of parameters:", total_params)

# Print every state-dict entry as "<key>: <shape>".
state_dict = model.state_dict()
for key, entry in state_dict.items():
    print(f"{key}: {entry.shape}")

# Switch to evaluation mode; as the last expression, the returned module is
# also rendered as this cell's output.
model.eval()

Total number of parameters: 1401735748
encoder.sequence_embed.weight: torch.Size([64, 1536])
encoder.plddt_projection.weight: torch.Size([1536, 16])
encoder.plddt_projection.bias: torch.Size([1536])
encoder.structure_per_res_plddt_projection.weight: torch.Size([1536, 16])
encoder.structure_per_res_plddt_projection.bias: torch.Size([1536])
encoder.structure_tokens_embed.weight: torch.Size([4101, 1536])
encoder.ss8_embed.weight: torch.Size([11, 1536])
encoder.sasa_embed.weight: torch.Size([19, 1536])
encoder.function_embed.0.weight: torch.Size([260, 192])
encoder.function_embed.1.weight: torch.Size([260, 192])
encoder.function_embed.2.weight: torch.Size([260, 192])
encoder.function_embed.3.weight: torch.Size([260, 192])
encoder.function_embed.4.weight: torch.Size([260, 192])
encoder.function_embed.5.weight: torch.Size([260, 192])
encoder.function_embed.6.weight: torch.Size([260, 192])
encoder.function_embed.7.weight: torch.Size([260, 192])
encoder.residue_embed.weight: torch.Size([1478, 

ESM3(
  (encoder): EncodeInputs(
    (sequence_embed): Embedding(64, 1536)
    (plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_per_res_plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_tokens_embed): Embedding(4101, 1536)
    (ss8_embed): Embedding(11, 1536)
    (sasa_embed): Embedding(19, 1536)
    (function_embed): ModuleList(
      (0-7): 8 x Embedding(260, 192, padding_idx=0)
    )
    (residue_embed): EmbeddingBag(1478, 1536, mode='sum', padding_idx=0)
  )
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0): UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1536, out_features=4608, bias=False)
          )
          (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (q_ln): LayerNorm((1536,), eps=1e-05, el

In [4]:
# Work on an independent deep copy so the loaded `model` is left untouched.
backbone_model = copy.deepcopy(model)
# Strip the output heads from the copy: when it exposes an `output_heads`
# attribute, overwrite that attribute with None.
has_output_heads = hasattr(backbone_model, "output_heads")
if has_output_heads:
    backbone_model.output_heads = None

In [5]:
# Show the copy's module tree as the cell output so the result of the previous cell can be inspected.
backbone_model

ESM3(
  (encoder): EncodeInputs(
    (sequence_embed): Embedding(64, 1536)
    (plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_per_res_plddt_projection): Linear(in_features=16, out_features=1536, bias=True)
    (structure_tokens_embed): Embedding(4101, 1536)
    (ss8_embed): Embedding(11, 1536)
    (sasa_embed): Embedding(19, 1536)
    (function_embed): ModuleList(
      (0-7): 8 x Embedding(260, 192, padding_idx=0)
    )
    (residue_embed): EmbeddingBag(1478, 1536, mode='sum', padding_idx=0)
  )
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0): UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1536, out_features=4608, bias=False)
          )
          (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (q_ln): LayerNorm((1536,), eps=1e-05, el

In [31]:
# Collect the backbone-only weights: every state-dict entry of `model` whose key
# does not start with "output_heads".
backbone_state_dict = {
    key: value for key, value in model.state_dict().items()
    if not key.startswith("output_heads")
}

# List the remaining keys so the head removal can be checked by eye.
print("Backbone keys:")
for key in sorted(backbone_state_dict):
    print(key)

# Save the backbone model object. The target directory defaults to the original
# location and can be overridden through the ESM3_BACKBONE_DIR environment variable,
# so the notebook is not tied to one machine's filesystem layout.
# NOTE(review): this pickles the whole `backbone_model` module, not `backbone_state_dict`
# (which is only printed above) — confirm which artifact downstream code expects.
backbone_save_dir = os.environ.get(
    "ESM3_BACKBONE_DIR",
    "/home/jupyter/DATA/evqlv-dev/model-weights/esm3_backbone",
)
os.makedirs(backbone_save_dir, exist_ok=True)
backbone_save_path = os.path.join(backbone_save_dir, "esm3_backbone_model.pt")
torch.save(backbone_model, backbone_save_path)
print("Saved backbone-only model to:", backbone_save_path)

Backbone keys:
encoder.function_embed.0.weight
encoder.function_embed.1.weight
encoder.function_embed.2.weight
encoder.function_embed.3.weight
encoder.function_embed.4.weight
encoder.function_embed.5.weight
encoder.function_embed.6.weight
encoder.function_embed.7.weight
encoder.plddt_projection.bias
encoder.plddt_projection.weight
encoder.residue_embed.weight
encoder.sasa_embed.weight
encoder.sequence_embed.weight
encoder.ss8_embed.weight
encoder.structure_per_res_plddt_projection.bias
encoder.structure_per_res_plddt_projection.weight
encoder.structure_tokens_embed.weight
transformer.blocks.0.attn.k_ln.weight
transformer.blocks.0.attn.layernorm_qkv.0.bias
transformer.blocks.0.attn.layernorm_qkv.0.weight
transformer.blocks.0.attn.layernorm_qkv.1.weight
transformer.blocks.0.attn.out_proj.weight
transformer.blocks.0.attn.q_ln.weight
transformer.blocks.0.ffn.0.bias
transformer.blocks.0.ffn.0.weight
transformer.blocks.0.ffn.1.weight
transformer.blocks.0.ffn.3.weight
transformer.blocks.0.geo