In [1]:
## Attention:


In [1]:
import torch
%set_env TOKENIZERS_PARALLELISM=false
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


env: TOKENIZERS_PARALLELISM=false
Using device: cuda


In [2]:
import pytest
import torch
from jaxtyping import Float
from torch.testing import assert_close
import torch.nn as nn
from transformer_lens.components import Attention
from transformer_lens.components import LayerNorm
from transformer_lens.components import HookedESM3MLP, swiglu_correction_fn
from transformer_lens.components import HookedEsm3UnifiedTransformerBlock
from esm.layers.attention import MultiHeadAttention
from esm.layers.blocks import swiglu_ln_ffn, UnifiedTransformerBlock
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
import functools
import einops
from esm.utils.constants.esm3 import data_root
import math
from transformer_lens import HookedESM3,SupportedESM3Config
from esm.pretrained import (
    ESM3_sm_open_v0,
)
from esm.models.esm3 import ESM3
import random
import torch.nn.functional as F
from esm.tokenization import get_esm3_model_tokenizers
from esm.utils.structure.protein_chain import ProteinChain


In [47]:
config = SupportedESM3Config(
    use_attn_result=True,
    use_split_qkv_input=True,
    use_hook_mlp_in=True,
    use_attn_in=True,
    esm3_output_type="all",
    esm3_use_torch_layer_norm=True,
    esm3_use_torch_attention_calc=True,
    esm3_use_org_rotary=True
)
esm3_hooked1 = HookedESM3.from_pretrained(esm_cfg=config, device=device)
esm3_original1 = ESM3_sm_open_v0(device).to(device)


If using ESM3 for interpretability research, keep in mind that ESM3 has some significant architectural differences to Language transformers like GPT.


Moving model to device:  cuda
Loaded pretrained model esm3_sm_open_v1 into HookedESM3


In [48]:
def verify_identical_components(real,hooked):
    # Compare parameters
    for (name1, param1), (name2, param2) in zip(real.named_parameters(), hooked.named_parameters()):
        if name1 != name2:
            print(f"Mismatch in parameter names: {name1} != {name2}")
            return False
        assert torch.sum(param1 != param2) == 0

    print(f"verify_identical_components- All parameters match! {type(real)} {type(hooked)} ")
    return True

In [49]:
def compare_transformer_blocks(real_block, hooked_block, cfg):
    if cfg.esm3_use_torch_layer_norm:
        assert torch.sum(real_block.attn.layernorm_qkv[0].weight !=  hooked_block.ln1.weight) == 0
        assert torch.sum(real_block.attn.layernorm_qkv[0].bias !=  hooked_block.ln1.bias) == 0
    else:
        assert torch.sum(real_block.attn.layernorm_qkv[0].weight !=  hooked_block.ln1.w) == 0
        assert torch.sum(real_block.attn.layernorm_qkv[0].bias !=  hooked_block.ln1.b) == 0
        
    qkv_matrix = real_block.attn.layernorm_qkv[1].weight
    query_BLD, key_BLD, value_BLD = torch.chunk(qkv_matrix, 3, dim=-2)
    q = einops.rearrange(hooked_block.attn.W_Q, "n_head d_model d_head ->(n_head d_head) d_model", n_head=hooked_block.attn.W_Q.shape[0])
    v = einops.rearrange(hooked_block.attn.W_V, "n_head d_model d_head ->(n_head d_head) d_model", n_head=hooked_block.attn.W_V.shape[0])
    k = einops.rearrange(hooked_block.attn.W_K, "n_head d_model d_head ->(n_head d_head) d_model", n_head=hooked_block.attn.W_K.shape[0])
    assert torch.sum(query_BLD !=q) == 0
    assert torch.sum(key_BLD !=k) == 0
    assert torch.sum(value_BLD !=v) == 0
    assert(real_block.attn.layernorm_qkv[1].bias is None)
    assert torch.equal(hooked_block.attn.b_Q, torch.zeros_like(hooked_block.attn.b_Q)), "The tensor is not all zeros."
    assert torch.equal(hooked_block.attn.b_K, torch.zeros_like(hooked_block.attn.b_K)), "The tensor is not all zeros."
    assert torch.equal(hooked_block.attn.b_V, torch.zeros_like(hooked_block.attn.b_V)), "The tensor is not all zeros."
    
    if cfg.esm3_use_torch_layer_norm:
        assert torch.sum(real_block.attn.q_ln.weight !=  hooked_block.attn.q_ln.weight) == 0
        assert torch.sum(real_block.attn.k_ln.weight != hooked_block.attn.k_ln.weight) == 0
        assert real_block.attn.q_ln.bias is None
        assert hooked_block.attn.q_ln.bias is None
        assert real_block.attn.k_ln.bias is None
        assert hooked_block.attn.k_ln.bias is None
    else:
        assert torch.sum(real_block.attn.q_ln.weight !=  hooked_block.attn.q_ln.w) == 0
        assert torch.sum(real_block.attn.k_ln.weight != hooked_block.attn.k_ln.w) == 0
        assert real_block.attn.q_ln.bias is None
        assert torch.equal(hooked_block.attn.q_ln.b, torch.zeros_like(hooked_block.attn.q_ln.b)), "The tensor is not all zeros."
        assert real_block.attn.k_ln.bias is None
        assert torch.equal(hooked_block.attn.k_ln.b, torch.zeros_like(hooked_block.attn.k_ln.b)), "The tensor is not all zeros."

    
    out_proj = real_block.attn.out_proj.weight
    W_O= einops.rearrange(hooked_block.attn.W_O, "n_head d_head d_model -> d_model (n_head d_head)", n_head=hooked_block.attn.W_O.shape[0])
    assert torch.sum(W_O !=out_proj) == 0
    assert real_block.attn.out_proj.bias is None
    assert torch.equal(hooked_block.attn.b_O, torch.zeros_like(hooked_block.attn.b_O)), "The tensor is not all zeros."

    assert real_block.use_geom_attn == hooked_block.use_geom_attn
    if real_block.use_geom_attn:
        verify_identical_components(real_block.geom_attn, hooked_block.geom_attn)
    if cfg.esm3_use_torch_layer_norm:
        assert torch.sum(real_block.ffn[0].weight !=  hooked_block.ln2.weight) == 0
        assert torch.sum(real_block.ffn[0].bias !=  hooked_block.ln2.bias) == 0
    else:
        assert torch.sum(real_block.ffn[0].weight !=  hooked_block.ln2.w) == 0
        assert torch.sum(real_block.ffn[0].bias !=  hooked_block.ln2.b) == 0
    assert torch.sum(real_block.ffn[1].weight !=  hooked_block.mlp.l1.weight) == 0
    assert(real_block.ffn[1].bias is None)
    assert(hooked_block.mlp.l1.bias is None)
    assert torch.sum(real_block.ffn[3].weight !=  hooked_block.mlp.l2.weight) == 0
    assert(real_block.ffn[3].bias is None)
    assert(hooked_block.mlp.l2.bias is None)
    print("compare_transformer_blocks- all params match")

In [50]:
def test_loading(esm3_original, esm3_hooked, cfg):
    verify_identical_components(esm3_original.encoder ,esm3_hooked.embed.embed)
    for l in range(len(esm3_original.transformer.blocks)):
        real_block = esm3_original.transformer.blocks[l]
        hooked_block = esm3_hooked.blocks[l]
        compare_transformer_blocks(real_block, hooked_block, cfg)
    if cfg.esm3_use_torch_layer_norm:
        assert torch.sum(esm3_original.transformer.norm.weight !=  esm3_hooked.ln_final.weight) == 0
        assert esm3_hooked.ln_final.bias is None
        assert esm3_original.transformer.norm.bias is None
    else:
        assert torch.sum(esm3_original.transformer.norm.weight !=  esm3_hooked.ln_final.w) == 0
        assert torch.equal(esm3_hooked.ln_final.b, torch.zeros_like(esm3_hooked.ln_final.b)), "The tensor is not all zeros."
    verify_identical_components(esm3_original.output_heads , esm3_hooked.unembed.output_heads)

                    

In [51]:
test_loading(esm3_original1, esm3_hooked1, esm3_hooked1.cfg)

verify_identical_components- All parameters match! <class 'esm.models.esm3.EncodeInputs'> <class 'esm.models.esm3.EncodeInputs'> 
verify_identical_components- All parameters match! <class 'esm.layers.geom_attention.GeometricReasoningOriginalImpl'> <class 'esm.layers.geom_attention.GeometricReasoningOriginalImpl'> 
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_transformer_blocks- all params match
compare_t

In [52]:
tokenizers = get_esm3_model_tokenizers()
esm3_original1.eval()  # Switch to evaluation mode to save memory
esm3_hooked1.eval()
sequence = "MKSLLLLSILAALAVAALCYESHESLESYEINPFINRRNANSFISPQQRWRAKAQERIRELNKPQYELNREACDDFKLCERYAMVYGYNAAYDRYFRQRRGAK"
tokens = tokenizers.sequence.encode(sequence)
sequence_tokens = torch.tensor(tokens, dtype=torch.int64)
sequence_tokens = sequence_tokens.to(device).unsqueeze(0)

with torch.no_grad():
    output1 = esm3_original1.forward(
        sequence_tokens=sequence_tokens
    )


In [53]:
with torch.no_grad():
    output2 = esm3_hooked1.forward(
        sequence_tokens=sequence_tokens
    )


In [54]:
torch.max(torch.abs(output1.residue_logits-output2.residue_logits))

tensor(3.0518e-05, device='cuda:0')

In [55]:
assert torch.allclose(output1.residue_logits, output2.residue_logits,  rtol=1.3e-6, atol=4e-5)

In [33]:
output1.function_logits

tensor([[[-20.0799, -20.0326, -19.9857,  ..., -20.0251, -20.0564, -19.9789],
         [-17.5317, -17.4406, -17.4318,  ..., -17.4344, -17.5385, -17.4107],
         [-20.0069, -19.9976, -20.0968,  ..., -20.1231, -20.1056, -20.0137],
         ...,
         [-19.6266, -19.5679, -19.5718,  ..., -19.6159, -19.5687, -19.6079],
         [-18.6520, -18.6278, -18.6809,  ..., -18.6770, -18.6054, -18.6648],
         [-18.6440, -18.5629, -18.6397,  ..., -18.6421, -18.5863, -18.6621]]],
       device='cuda:0')

In [11]:
from torch.testing import assert_close
assert_close(output1.sequence_logits, output2.sequence_logits, rtol=1e-6, atol=2e-5)


In [26]:
sequence2 = "MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN"
tokens2 = tokenizers.sequence.encode(sequence2)
sequence_tokens2 = torch.tensor(tokens2, dtype=torch.int64)
sequence_tokens2 = sequence_tokens2.to(device).unsqueeze(0)

In [27]:
encoder_output = None
layer_norm_input=None
layer_norm_output=None
org_attn_out_proj_input=None
org_attn_out = None
org_geo_attn_out = None
org_mlp_in = None
org_mlp_out = None
org_residual_post = None
org_rot = None
org_rot_input=None
org_q_input = None
embed_output = None
hook_attn_in=None
hook_post_layer_norm=None
hook_attn_out=None
hook_geo_attn_out=None
hook_mlp_in = None
hook_mlp_out = None
hook_residual_post = None
q_ln_org=None
k_ln_org=None
hook_q = None
hook_q_ln=None
hook_k_ln=None
hook_z = None
hook_rot_q= None
hook_rot_k=None
# Define the hook function
def hook_fn1(activation, hook):
    global embed_output  # To store the output globally
    print("Hook triggered for embedding")
    embed_output = activation
    
# Define the hook function
def hook_fn2(activation, hook):
    global hook_attn_in  # To store the output globally
    print("Hook triggered for attn in")
    hook_attn_in = activation
    
def hook_fn3(activation, hook):
    global hook_attn_out  # To store the output globally
    print("Hook triggered for attn_out")
    hook_attn_out = activation

def hook_fn4(activation, hook):
    global hook_geo_attn_out  # To store the output globally
    print("Hook triggered for geo attn_out")
    hook_geo_attn_out = activation

def hook_fn5(activation, hook):
    global hook_mlp_in   # To store the output globally
    print("Hook triggered for hook_mlp_in")
    hook_mlp_in  = activation

def hook_fn6(activation, hook):
    global hook_mlp_out  # To store the output globally
    print("Hook triggered for hook_mlp_out")
    hook_mlp_out  = activation

def hook_fn7(activation, hook):
    global hook_residual_post   # To store the output globally
    print("Hook triggered for hook_residual_post ")
    hook_residual_post  = activation
    
def hook_fn8(module, input, output):
    global encoder_output  # Declare as global to modify the global variable
    print("Hook triggered for encoder")
    encoder_output = output 

def hook_fn9(module, input, output):
    global layer_norm_output  # Declare as global to modify the global variable
    print("Hook triggered for layer_norm_output")
    layer_norm_output = output 
def hook_fn10(module, input, output):
    global org_attn_out  # Declare as global to modify the global variable
    print("Hook triggered for org_attn_out")
    org_attn_out = output
def hook_fn11(module, input, output):
    global org_geo_attn_out # Declare as global to modify the global variable
    print("Hook triggered for org_geo_attn_out")
    org_geo_attn_out = output
def hook_fn12(module, input, output):
    global org_mlp_in # Declare as global to modify the global variable
    print("org_mlp_in")
    org_mlp_in = input
def hook_fn13(module, input, output):
    global org_mlp_out # Declare as global to modify the global variable
    print("org_mlp_out")
    org_mlp_out = output
def hook_fn14(module, input, output):
    global org_residual_post # Declare as global to modify the global variable
    print("org_residual_post")
    org_residual_post = output

def hook_fn15(activation, hook):
    global hook_q_ln   # To store the output globally
    print("Hook triggered for q_ln")
    hook_q_ln  = activation
    
def hook_fn16(activation, hook):
    global hook_k_ln   # To store the output globally
    print("Hook triggered for k_ln")
    hook_k_ln  = activation
def hook_fn17(module, input, output):
    global q_ln_org # Declare as global to modify the global variable
    global org_q_input
    print("after q_ln")
    q_ln_org = output
    org_q_input = input

def hook_fn18(module, input, output):
    global k_ln_org # Declare as global to modify the global variable
    print("after k_ln")
    k_ln_org = output

def hook_fn19(activation, hook):
    global hook_post_layer_norm  # To store the output globally
    print("Hook triggered for hook_post_layer_norm")
    hook_post_layer_norm = activation

def hook_fn20(module, input, output):
    global layer_norm_input  # Declare as global to modify the global variable
    print("Hook triggered for layer_norm_input")
    layer_norm_input = input
def hook_fn21(module, input, output):
    global org_attn_out_proj_input  # Declare as global to modify the global variable
    print("Hook triggered for org_attn_out_proj_input")
    org_attn_out_proj_input = input
def hook_fn22(activation, hook):
    global hook_z   # To store the output globally
    print("Hook triggered for hook_z ")
    hook_z  = activation
def rotary_hook(module, input, output):
    global org_rot, org_rot_input
    print("Hook triggered for rotary")
    org_rot = output
    org_rot_input = input

    org_rot = output
def hook_fn24(activation, hook):
    global hook_rot_q   # To store the output globally
    print("Hook triggered hook_rot_q")
    hook_rot_q  = activation
    
def hook_fn25(activation, hook):
    global hook_rot_k  # To store the output globally
    print("Hook triggered for hook_rot_k")
    hook_rot_k  = activation

def hook_q(activation, hook):
    global hook_q
    print("Hook triggered for q")
    hook_q = activation
    
from torch.testing import assert_close
esm3_hooked1.eval()
esm3_original1.eval()
with torch.no_grad():
    esm3_original1.encoder.register_forward_hook(hook_fn8)
    esm3_original1.transformer.blocks[0].attn.layernorm_qkv[0].register_forward_hook(hook_fn9)
    esm3_original1.transformer.blocks[0].attn.out_proj.register_forward_hook(hook_fn21)
    esm3_original1.transformer.blocks[0].attn.rotary.register_forward_hook(rotary_hook)
    esm3_original1.transformer.blocks[0].attn.layernorm_qkv[0].register_forward_hook(hook_fn9)
    esm3_original1.transformer.blocks[0].attn.q_ln.register_forward_hook(hook_fn17)
    esm3_original1.transformer.blocks[0].attn.k_ln.register_forward_hook(hook_fn18)
    esm3_original1.transformer.blocks[0].attn.register_forward_hook(hook_fn10)
    esm3_original1.transformer.blocks[0].geom_attn.register_forward_hook(hook_fn11)
    esm3_original1.transformer.blocks[0].ffn.register_forward_hook(hook_fn12)
    esm3_original1.transformer.blocks[0].ffn.register_forward_hook(hook_fn13)
    esm3_original1.transformer.blocks[0].register_forward_hook(hook_fn14)
    output3 = esm3_original1.forward(
        sequence_tokens=sequence_tokens2
    )

with torch.no_grad():
    esm3_hooked1.add_hook("hook_embed", hook_fn1)
    esm3_hooked1.add_hook("blocks.0.hook_attn_in", hook_fn2)
    esm3_hooked1.add_hook("blocks.0.hook_post_layer_norm", hook_fn19)
    esm3_hooked1.add_hook("blocks.0.hook_attn_out", hook_fn3)
    esm3_hooked1.add_hook("blocks.0.hook_geo_attn_out", hook_fn4)
    esm3_hooked1.add_hook("blocks.0.hook_mlp_in", hook_fn5)
    esm3_hooked1.add_hook("blocks.0.hook_mlp_out", hook_fn6)
    esm3_hooked1.add_hook("blocks.0.hook_resid_post", hook_fn7)
    esm3_hooked1.add_hook("blocks.0.attn.hook_ln_q", hook_fn15)
    esm3_hooked1.add_hook("blocks.0.attn.hook_ln_k", hook_fn16)
    esm3_hooked1.add_hook("blocks.0.attn.hook_z", hook_fn22)
    esm3_hooked1.add_hook("blocks.0.attn.hook_rot_k", hook_fn25)
    esm3_hooked1.add_hook("blocks.0.attn.hook_rot_q", hook_fn24)
    esm3_hooked1.add_hook("blocks.0.attn.hook_q", hook_q)
    output4 = esm3_hooked1.forward(
        sequence_tokens=sequence_tokens2
    )

assert torch.allclose(output3.sequence_logits, output4.sequence_logits, atol=1e-4, rtol=1e-4)

assert torch.allclose(output3.structure_logits, output4.structure_logits, atol=1e-4, rtol=1e-4)

assert torch.allclose(output3.sasa_logits, output4.sasa_logits, atol=1e-4, rtol=1e-4)

assert torch.allclose(output3.secondary_structure_logits, output4.secondary_structure_logits, atol=1e-4, rtol=1e-4)

assert torch.allclose(output3.function_logits, output4.function_logits, atol=1e-4, rtol=1e-4)

assert torch.allclose(output3.residue_logits, output4.residue_logits, atol=1e-4, rtol=1e-4)



Hook triggered for encoder
Hook triggered for layer_norm_output
Hook triggered for layer_norm_output
after q_ln
after k_ln
Hook triggered for rotary
Hook triggered for org_attn_out_proj_input
Hook triggered for org_attn_out
Hook triggered for org_geo_attn_out
org_mlp_in
org_mlp_out
org_residual_post
Hook triggered for embedding
Hook triggered for hook_post_layer_norm
Hook triggered for q
Hook triggered for q_ln
Hook triggered for k_ln
Hook triggered hook_rot_q
Hook triggered for hook_rot_k
Hook triggered for hook_z 
Hook triggered for attn_out
Hook triggered for geo attn_out
Hook triggered for hook_mlp_in
Hook triggered for hook_mlp_out
Hook triggered for hook_residual_post 


In [28]:
assert torch.equal(encoder_output,embed_output)
#assert torch.equal(layer_norm_output, hook_post_layer_norm[...,1,:])
q_flattened = einops.rearrange(hook_q_ln, "batch pos head_index d_head -> batch pos (head_index d_head)")
assert torch.allclose(q_ln_org, q_flattened, rtol=1.3e-6, atol=4e-5)
k_flattened = einops.rearrange(hook_k_ln, "batch pos head_index d_head -> batch pos (head_index d_head)")
assert torch.allclose(k_ln_org, k_flattened, rtol=1.3e-6, atol=4e-5)
assert torch.allclose(hook_rot_k, org_rot[1], rtol=1.3e-6, atol=4e-5)
assert torch.allclose(hook_rot_q, org_rot[0], rtol=1.3e-6, atol=4e-5)
assert torch.allclose(org_attn_out, hook_attn_out, rtol=1.3e-6, atol=4e-5)
assert torch.allclose(org_geo_attn_out,hook_geo_attn_out, rtol=1.3e-6, atol=4e-5)
assert torch.allclose(hook_mlp_in, org_mlp_in[0], rtol=1.3e-6, atol=4e-5)
assert torch.allclose(hook_mlp_out, org_mlp_out, rtol=3e-6, atol=4e-5)
assert torch.allclose(hook_residual_post, org_residual_post, rtol=1.3e-6, atol=4e-5)

In [31]:
esm3_hooked1.cfg.use_attn_result=False
output3 = esm3_original1.forward(
        sequence_tokens=sequence_tokens2
    )
output4 = esm3_hooked1.forward(
        sequence_tokens=sequence_tokens2
    )
torch.max(torch.abs(output3.sequence_logits-output4.sequence_logits))

Hook triggered for encoder
Hook triggered for layer_norm_output
Hook triggered for layer_norm_output
after q_ln
after k_ln
Hook triggered for rotary
Hook triggered for org_attn_out_proj_input
Hook triggered for org_attn_out
Hook triggered for org_geo_attn_out
org_mlp_in
org_mlp_out
org_residual_post
Hook triggered for embedding
Hook triggered for hook_post_layer_norm
Hook triggered for q
Hook triggered for q_ln
Hook triggered for k_ln
Hook triggered hook_rot_q
Hook triggered for hook_rot_k
Hook triggered for hook_z 
Hook triggered for attn_out
Hook triggered for geo attn_out
Hook triggered for hook_mlp_in
Hook triggered for hook_mlp_out
Hook triggered for hook_residual_post 


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 47.44 GiB of which 21.38 MiB is free. Including non-PyTorch memory, this process has 47.41 GiB memory in use. Of the allocated memory 47.05 GiB is allocated by PyTorch, and 48.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [31]:
res, cache = esm3_hooked1.run_with_cache(sequence_tokens=sequence_tokens2)

Hook triggered for embedding
Hook triggered for attn in
Hook triggered for hook_post_layer_norm
Hook triggered for q
Hook triggered for q_ln
Hook triggered for k_ln
Hook triggered hook_rot_q
Hook triggered for hook_rot_k
Hook triggered for hook_z 
Hook triggered for attn_out
Hook triggered for geo attn_out
Hook triggered for hook_mlp_in
Hook triggered for hook_mlp_out
Hook triggered for hook_residual_post 


OutOfMemoryError: CUDA out of memory. Tried to allocate 6.79 GiB. GPU 0 has a total capacity of 47.44 GiB of which 6.72 GiB is free. Including non-PyTorch memory, this process has 40.71 GiB memory in use. Of the allocated memory 40.35 GiB is allocated by PyTorch, and 48.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [32]:
del output1
del output2
torch.cuda.empty_cache()

In [33]:
print(torch.cuda.memory_summary())


|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 1         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  48172 MiB |  48183 MiB | 554607 MiB | 506435 MiB |
|       from large pool |  47982 MiB |  47993 MiB | 553608 MiB | 505626 MiB |
|       from small pool |    190 MiB |    190 MiB |    999 MiB |    808 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  48172 MiB |  48183 MiB | 554607 MiB | 506435 MiB |
|       from large pool |  47982 MiB |  47993 MiB | 553608 MiB | 505626 MiB |
|       from small pool |    190 MiB |    190 MiB |    999 MiB |    808 MiB |
|---------------------------------------------------------------

In [37]:
import torch

def model_memory_usage(model):
    """
    Calculate the total memory used by a model's parameters and buffers in bytes.

    Args:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        float: Memory usage in MB.
    """
    total_params = sum(p.numel() * p.element_size() for p in model.parameters())
    total_buffers = sum(b.numel() * b.element_size() for b in model.buffers())
    total_memory = total_params + total_buffers  # Total memory in bytes
    return total_memory / (1024 ** 2)  # Convert to MB

# Example usage
memory_in_mb = model_memory_usage(esm3_hooked1)
print(f"The model itself takes approximately {memory_in_mb:.2f} MB of memory.")


The model itself takes approximately 5540.33 MB of memory.


In [53]:
tokenizers = get_esm3_model_tokenizers()
sequence1 = "MPGWFKKAWYGLASLLSFSSFILIIVALVVPHWLSGKILCQTGVDLVNATDRELVKFIGDIYYGLFRGCKVRQCGLGGRQSQFTIFPHLVKELNAGLHVMILLLLFLALALALVSMGFAILNMIQVPYRAVSGPGGICLWNVLAGGVVALAIASFVAAVKFHDLTERIANFQEKLFQFVVVEEQYEESFWICVASASAHAANLVVVAISQIPLPEIKTKIEEATVTAEDILY"
sequence2= "MAAA<mask>"
arr = [sequence1, sequence2]
tokens = tokenizers.sequence(arr, return_tensors="pt", padding=True)


In [59]:
print(tokens.attention_mask[1])

tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [7]:
import pytest
import torch
from jaxtyping import Float
from torch.testing import assert_close
import torch.nn as nn
from transformer_lens.components import Attention
from transformer_lens.components import LayerNorm
from transformer_lens.components import HookedESM3MLP, swiglu_correction_fn
from transformer_lens.components import HookedEsm3UnifiedTransformerBlock
from esm.layers.attention import MultiHeadAttention
from esm.layers.blocks import swiglu_ln_ffn, UnifiedTransformerBlock
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
import functools
import einops
import math
from esm.pretrained import (
    ESM3_sm_open_v0,
)
from transformer_lens import HookedESM3,SupportedESM3Config
from esm.tokenization import get_esm3_model_tokenizers
import gc
def test_full_model(
    device="cuda",
    esm3_use_torch_attention_calc=True,
    use_attn_result=False,
    use_split_qkv_input=True,
    esm3_use_org_rotary=True,
    esm3_use_torch_layer_norm=True,
):
    esm3_original = ESM3_sm_open_v0(device).to(device)
    esm3_original.eval()
    tokenizer = tokenizers = get_esm3_model_tokenizers()
    sequence = "MDADKEKDLQKFLKNVDEISNLIQEMNSDDPVVQQKAVLETEKRLLLMEEDQEEDECRTTLNKTMISPPQTAMKSAEEINSEAFLASVEKDAKERAKRRRENKVLADALKEKGNEAFAEGNYETAILRYSEGLEKLKDMKVLYTNRAQAYMKLEDYEKALVDCEWALKCDEKCTKAYFHMGKANLALKNYSVSRECYKKILEINPKLQTQVKGYLNQVDLQEKADLQEKEAHELLDSGKNTAVTTKNLLETLSKPDQIPLFYAGGIEILTEMINECTEQTLFRMHNGFSIISDNEVIRRCFSTAGNDAVEEMVCVSVLKLWQAVCSRNEENQRVLVIHHDRARLLAALLSSKVLAIRQQSFALLLHLAQTESGRSLIINHLDLTRLLEALVSFLDFSDKEANTAMGLFTDLALEERFQVWFQANLPGVLPALTGVLKTDPKVSSSSALCQCIAIMGNLSAEPTTRRHMAACEEFGDGCLSLLARCEEDVDLFREVIYTLLGLMMNLCLQAPFVSEVWAVEVSRRCLSLLNSQDGGILTRAAGVLSRTLSSSLKIVEEALRAGVVKKMMKFLKTGGETASRYAIKILAICTNSYHEAREEVIRLDKKLSVMMKLLSSEDEVLVGNAALCLGNCMEVPNVASSLLKTDLLQVLLKLAGSDTQKTAVQVNAGIALGKLCTAEPRFAAQLRKLHGLEILNSTMKYISDS"
    tokens = tokenizers.sequence.encode(sequence)
    sequence_tokens = torch.tensor(tokens, dtype=torch.int64)
    sequence_tokens = sequence_tokens.to(device).unsqueeze(0)
    with torch.no_grad():
        output1 = esm3_original.forward(
            sequence_tokens=sequence_tokens
        )
    del esm3_original
    torch.cuda.empty_cache()
    gc.collect()

    config = SupportedESM3Config(
        use_attn_result=use_attn_result,
        use_split_qkv_input=use_split_qkv_input,
        use_hook_mlp_in=False,
        use_attn_in=False,
        esm3_output_type="all",
        esm3_use_torch_layer_norm=esm3_use_torch_layer_norm,
        esm3_use_torch_attention_calc=esm3_use_torch_attention_calc,
        esm3_use_org_rotary = esm3_use_org_rotary
    )
    esm3_hooked = HookedESM3.from_pretrained(esm_cfg=config, device=device)
    esm3_hooked.eval()
    with torch.no_grad():
        output2 = esm3_hooked.forward(
            sequence_tokens=sequence_tokens
        )
    print(output1.function_logits.shape)
    assert torch.allclose(output1.sequence_logits, output1.sequence_logits, rtol=1e-5, atol=4e-5)
    print(torch.max(torch.abs(output1.sequence_logits-output2.function_logits)))
    assert torch.allclose(output1.structure_logits, output2.structure_logits, rtol=1e-5, atol=4e-5)
    print(torch.max(torch.abs(output1.structure_logits-output2.structure_logits)))
    assert torch.allclose(output1.sasa_logits, output2.sasa_logits, rtol=1e-5, atol=4e-5)
    print(torch.max(torch.abs(output1.sasa_logits-output2.sasa_logits)))
    assert torch.allclose(output1.secondary_structure_logits, output2.secondary_structure_logits,  rtol=1e-5, atol=4e-5)
    print(torch.max(torch.abs(output1.secondary_structure_logits-output2.secondary_structure_logits)))
    
    assert torch.allclose(output1.function_logits, output2.function_logits,  rtol=1e-4, atol=2e-4)
    print(torch.max(torch.abs(output1.function_logits-output2.function_logits)))
    assert torch.allclose(output1.residue_logits, output2.residue_logits,  rtol=1e-5, atol=4e-5)
    print(torch.max(torch.abs(output1.residue_logits-output2.residue_logits)))
    del esm3_hooked
    torch.cuda.empty_cache()
    gc.collect()
    

In [8]:
test_full_model()

If using ESM3 for interpretability research, keep in mind that ESM3 has some significant architectural differences to Language transformers like GPT.


Moving model to device:  cuda
Loaded pretrained model esm3_sm_open_v1 into HookedESM3
torch.Size([1, 707, 8, 260])


RuntimeError: The size of tensor a (64) must match the size of tensor b (260) at non-singleton dimension 3