In [42]:
import torch, re, numpy as np, pandas as pd
from transformers import AutoTokenizer, AutoModel, EsmModel
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from typing import List, Dict

In [43]:
# Checkpoint name, as far as it can be read off here: "t33" matches the
# 33 EsmLayer blocks listed by print(model); "650M" is presumably the parameter
# count and "UR50D" presumably a training-set tag — TODO confirm against the
# model card.
model_name = 'facebook/esm2_t33_650M_UR50D'

# Choose the compute device in this cell so it runs on a fresh kernel
# (Restart & Run All); `.to(device)` below would otherwise hit an undefined name.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer converts a string into a form the model (ESM) can handle.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the model with the pretrained model weights (as opposed to just the architecture).
# output_hidden_states=True makes every forward pass also return the per-layer
# hidden states; .eval() switches the module to inference mode.
# The "newly initialized pooler.dense.*" warning printed on load means the
# pooler weights are not from the checkpoint, so pooler_output is not a
# pretrained representation.
model = EsmModel.from_pretrained(model_name, output_hidden_states=True).to(device).eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [44]:
# Print the module tree to inspect the architecture (layer count, hidden size).
print(model)

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 1280, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-32): 33 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=1280, out_features=5120, bias=True)
        )
        (output): EsmOut

In [45]:
# Load a pretrained autoencoder ("sae") for one layer of the protein language model.
# NOTE(review): `load_sae_from_hf` is neither imported nor defined in the visible
# cells, and `device` is only assigned in a later cell — TODO confirm both exist
# when running from a fresh kernel.
plm_model = "esm2-650m"
plm_layer = 24  # presumably the ESM layer whose activations the autoencoder consumes — TODO confirm
sae = load_sae_from_hf(plm_model=plm_model, plm_layer=plm_layer).to(device).eval()


In [46]:
# Print the autoencoder's layers to check its input/latent dimensions.
print(sae)

AutoEncoder(
  (encoder): Linear(in_features=1280, out_features=10240, bias=True)
  (decoder): Linear(in_features=10240, out_features=1280, bias=False)
)


In [47]:
# Small hand-picked panel of toy sequences, keyed by a short label.
# NOTE(review): `seqs` is not referenced by any later cell visible here.
seqs = {
    # presumably antibody heavy-/light-chain variable regions (going by the keys) — TODO confirm source
    "Ab_H": "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPGKGLEWVSYISSGSSSYIYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCARGLGGFGDYWGQGTLVTVSS",
    "Ab_L": "DIQMTQSPSSLSASVGDRVTITCRASQGISNYLAWYQQKPGKAPKLLIYDASTRATGIPDRFSGSGSGTDFTLTISSVQAEDLAVYYCQQYNTYPFTFGQGTKVEIK",
    # Gly-Pro-Pro repeats (collagen-like) and a histidine-rich toy sequence:
    "Collagen_like": "MGPPGPPGPPGPPGPPGPPGPP",
    "His_rich": "MKKRHHHHHHGSGSGSGHHHHEE",
    # has N-glyc motifs (N[^P][ST]): NAT, NVT and NST all occur in this string
    "NGlyc": "MATRNATSNEKSTNVTQLLNNST",
    # Cys-pair toy; note it contains "XX" (non-standard residue symbol) —
    # TODO confirm how the tokenizer maps X
    "CysPair": "MAGRCCGGTTCCGGAAACCXXC"
}


## Compare residue embeddings with vs. without special tokens (CLS/EOS)

In [65]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): this rebinds `tokenizer` and `model`, which an earlier cell
# already loaded from the same checkpoint name.
esm_name = "facebook/esm2_t33_650M_UR50D"  # any ESM2 is fine
tokenizer = AutoTokenizer.from_pretrained(esm_name, do_lower_case=False)

# Load the weights, move them to the chosen device, then switch to inference
# mode; output_hidden_states=True exposes every layer's activations.
model = AutoModel.from_pretrained(esm_name, output_hidden_states=True)
model = model.to(device)
model = model.eval()


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [80]:

# Short toy sequence used for the with/without-special-tokens comparison.
seq = "ASDF"
# Indices into the model's hidden_states tuple (0 = embeddings, 1..N = layer
# outputs); the loaded model returns 34 entries, so 33 is the last one.
layers_to_check = [6, 12, 18, 24, 30, 33]  # pick any


In [83]:
@torch.no_grad()
def get_residue_reps(sequence: str, layer: int, add_special_tokens: bool):
    """
    Embed one sequence and return the residue-level hidden states of one layer.

    Uses the notebook-level `tokenizer`, `model` and `device`.

    Args:
        sequence: sequence string handed to the tokenizer.
        layer: index into the model's `hidden_states` tuple
            (0 = embeddings, 1..N = layer outputs).
        add_special_tokens: forwarded to the tokenizer. When True, the first
            and last positions (CLS/BOS and EOS) are sliced off afterwards so
            only residue positions remain.

    Returns:
        Tensor of shape (L, H) with one row per residue.
    """
    encoded = tokenizer(
        sequence,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    encoded = encoded.to(device)

    # Batch holds a single sequence, so take element 0: (L', H),
    # where L' = L + 2 with specials and L without.
    layer_states = model(**encoded).hidden_states[layer]
    residue_reps = layer_states[0]

    # Strip the two special positions only if the tokenizer added them.
    return residue_reps[1:-1] if add_special_tokens else residue_reps

def compare_one_layer(layer: int, sequence=None):
    """
    Per-residue cosine similarity between embeddings computed with and without
    special tokens, at a single layer.

    Args:
        layer: index into the model's `hidden_states` tuple.
        sequence: sequence to embed. When omitted, the notebook-level `seq`
            is read at call time (the original behaviour). Pass it explicitly
            to avoid depending on whichever cell last rebound `seq`.

    Returns:
        1-D CPU tensor with one cosine similarity per residue.

    Raises:
        AssertionError: if the two embeddings do not have the same shape.
    """
    if sequence is None:
        sequence = seq  # fall back to the notebook global for existing callers

    h_with = get_residue_reps(sequence, layer, add_special_tokens=True)
    h_without = get_residue_reps(sequence, layer, add_special_tokens=False)
    assert h_with.shape == h_without.shape, f"shape mismatch at layer {layer}: {h_with.shape} vs {h_without.shape}"

    # Compare along the hidden dimension so each residue gets one value.
    cos = F.cosine_similarity(h_with, h_without, dim=1)
    return cos.cpu()


In [84]:
def _summarize_cosines(cos_values):
    """Collapse a 1-D tensor of per-residue cosines into summary statistics."""
    return {
        "mean": float(cos_values.mean()),
        "std": float(cos_values.std()),
        "min": float(cos_values.min()),
        "max": float(cos_values.max()),
        "per_residue": cos_values.tolist(),
    }


# One summary dict per requested layer, keyed by the layer index.
results = {
    layer_idx: _summarize_cosines(compare_one_layer(layer_idx))
    for layer_idx in layers_to_check
}

# Report in the order the layers were requested.
for layer_idx in layers_to_check:
    stats = results[layer_idx]
    print(f"Layer {layer_idx:>2} | cos mean={stats['mean']:.4f}  std={stats['std']:.4f}  min={stats['min']:.4f}  max={stats['max']:.4f}")

Layer  6 | cos mean=0.6528  std=0.1122  min=0.4870  max=0.7265
Layer 12 | cos mean=0.7652  std=0.0449  min=0.7046  max=0.8081
Layer 18 | cos mean=0.6700  std=0.0213  min=0.6545  max=0.7003
Layer 24 | cos mean=0.5940  std=0.0245  min=0.5751  max=0.6300
Layer 30 | cos mean=0.6050  std=0.0318  min=0.5712  max=0.6364
Layer 33 | cos mean=0.4669  std=0.0449  min=0.4047  max=0.5040


In [61]:
# Inspect the type of the model's return value.
# NOTE(review): `out` is not assigned at notebook level in the visible cells
# (only as a local inside get_residue_reps); this relies on leftover kernel
# state — TODO bind it explicitly, e.g. from a model(**tokens) call.
print(type(out))

<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>


In [62]:
# List the fields available on the model output.
# NOTE(review): relies on an `out` left over from earlier kernel state.
print(out.keys())

odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])


In [64]:
# Tokenize a 2-residue sequence without special tokens to see the raw ids.
t = tokenizer("MK", return_tensors="pt", add_special_tokens=False)
print(t["input_ids"].shape)  # (1, L): no CLS/EOS are added when add_special_tokens=False
print(t["input_ids"])

torch.Size([1, 2])
tensor([[20, 15]])


In [58]:
# Show the string forms of the tokenizer's CLS and EOS special tokens.
print(tokenizer.cls_token, tokenizer.eos_token)


<cls> <eos>


In [59]:
# Check that tokenizing with special tokens adds exactly two positions.
# NOTE(review): the trailing "..." is literally part of this string (it counts
# towards len(seq)), and this rebinds the notebook-level `seq` that
# compare_one_layer reads — re-running that comparison after this cell uses
# this string instead of "ASDF".
seq = "MKWVTFISLL..."
toks = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
print("len(seq):", len(seq))
print("tokenized length:", toks["input_ids"].shape[1])  # should be L+2


len(seq): 13
tokenized length: 15


In [55]:
# Shape of the pooled output.
# NOTE(review): the load warning reports pooler.dense.* as newly initialized,
# so these values do not come from pretrained weights. Also relies on an `out`
# left over from earlier kernel state.
out.pooler_output.shape

torch.Size([1, 1280])

In [51]:
# Display the tuple of per-layer hidden states.
# NOTE(review): this prints every tensor in full; showing only the shapes
# would keep the output readable. Relies on an `out` from earlier kernel state.
out.hidden_states

(tensor([[[ 0.0472, -0.0513, -0.1144,  ..., -0.1953,  0.1039, -0.0581],
          [ 0.0111, -0.0479, -0.0466,  ...,  0.0748, -0.0232, -0.0177],
          [-0.0382, -0.0023, -0.0250,  ..., -0.0425,  0.1069,  0.0341],
          ...,
          [-0.0424,  0.0859,  0.1297,  ...,  0.0050, -0.0570, -0.0317],
          [-0.0424,  0.0859,  0.1297,  ...,  0.0050, -0.0570, -0.0317],
          [-0.0788, -0.0394, -0.0506,  ..., -0.0515,  0.0573, -0.0959]]],
        device='cuda:0'),
 tensor([[[ 1.1615, -1.0889,  1.0283,  ..., -0.6624, -0.0925, -0.0342],
          [ 1.0284, -1.2134,  1.4535,  ..., -0.4915, -0.0523, -0.0882],
          [ 1.1642, -1.2191,  1.3524,  ..., -0.5972, -0.4859,  0.2135],
          ...,
          [ 1.6823, -0.4731,  1.2003,  ..., -0.8387, -0.1662,  0.0967],
          [ 1.6832, -0.4767,  1.2098,  ..., -0.8331, -0.1589,  0.0829],
          [ 1.0229, -1.1698,  0.9816,  ..., -0.9636, -0.1165,  0.1218]]],
        device='cuda:0'),
 tensor([[[ 2.2887, -0.4810,  2.3785,  ..., -1.378

In [49]:
# Shape of entry 0 of hidden_states (the embedding output): (batch, tokens, hidden).
# NOTE(review): relies on an `out` left over from earlier kernel state.
out.hidden_states[0].shape

torch.Size([1, 120, 1280])

In [33]:
# Shape of entry 1 of hidden_states (first layer output); same shape as entry 0.
# NOTE(review): relies on an `out` left over from earlier kernel state.
out.hidden_states[1].shape

torch.Size([1, 120, 1280])

In [35]:
# Number of hidden-state entries: the embedding output plus one per encoder layer.
# NOTE(review): relies on an `out` left over from earlier kernel state.
len(out.hidden_states)

34

In [39]:
# First field of the model output, accessed by position.
# NOTE(review): positional access depends on which fields the output object
# carries; attribute access (e.g. out.last_hidden_state) is less ambiguous.
# Relies on an `out` left over from earlier kernel state.
out[0]

tensor([[[ 13.4608,  -5.5221,   1.1816,  ..., -14.0422, -14.1963,  -5.5962],
         [ -7.9331, -13.8691,  -9.0875,  ..., -15.4542, -15.0812, -13.7574],
         [-11.5409, -15.5602, -10.1995,  ..., -14.9057, -15.2386, -15.5302],
         ...,
         [ -9.9504, -20.1673, -10.8031,  ..., -14.8612, -14.2043, -20.0579],
         [-11.2605, -16.6821,  -8.8577,  ..., -15.3363, -14.2295, -16.6268],
         [  1.9429,  -5.1557,  21.6996,  ..., -13.5158, -13.5677,  -5.2176]]],
       device='cuda:0')

In [40]:
# Second field of the model output, accessed by position.
# NOTE(review): the tuple displayed here matches the out.hidden_states dump,
# whereas the keys printed for the EsmModel output list pooler_output second —
# so this output likely came from a different `out` object. Prefer attribute
# access (out.hidden_states) to avoid the ambiguity.
out[1]

(tensor([[[ 0.0472, -0.0513, -0.1144,  ..., -0.1953,  0.1039, -0.0581],
          [ 0.0111, -0.0479, -0.0466,  ...,  0.0748, -0.0232, -0.0177],
          [-0.0382, -0.0023, -0.0250,  ..., -0.0425,  0.1069,  0.0341],
          ...,
          [-0.0424,  0.0859,  0.1297,  ...,  0.0050, -0.0570, -0.0317],
          [-0.0424,  0.0859,  0.1297,  ...,  0.0050, -0.0570, -0.0317],
          [-0.0788, -0.0394, -0.0506,  ..., -0.0515,  0.0573, -0.0959]]],
        device='cuda:0'),
 tensor([[[ 1.1615, -1.0889,  1.0283,  ..., -0.6624, -0.0925, -0.0342],
          [ 1.0284, -1.2134,  1.4535,  ..., -0.4915, -0.0523, -0.0882],
          [ 1.1642, -1.2191,  1.3524,  ..., -0.5972, -0.4859,  0.2135],
          ...,
          [ 1.6823, -0.4731,  1.2003,  ..., -0.8387, -0.1662,  0.0967],
          [ 1.6832, -0.4767,  1.2098,  ..., -0.8331, -0.1589,  0.0829],
          [ 1.0229, -1.1698,  0.9816,  ..., -0.9636, -0.1165,  0.1218]]],
        device='cuda:0'),
 tensor([[[ 2.2887, -0.4810,  2.3785,  ..., -1.378