In [16]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

In [2]:
# Checkpoint identifier shared by the tokenizer and the masked-LM model so the
# two always come from the same pretrained release.
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    EsmForMaskedLM.from_pretrained(model_name),
)

In [3]:
# Amino-acid sequence of the protein under study, written with one-letter
# residue codes (108 residues). This is the wild-type input to the model.
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

In [55]:
# The tokenizer wraps the sequence in special tokens (CLS and EOS), so the
# encoded tensor has two MORE positions than the protein has residues.
input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")

# Number of actual residues, i.e. the encoded length without CLS/EOS.
sequence_length = input_ids.shape[1] - 2

# The 20 standard amino acids, by one-letter code; these are the candidate
# substitutions scored later in the notebook.
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

# Independent copy of the token ids that can be masked without touching the
# unmasked `input_ids`.
masked_input_ids = input_ids.clone()

In [25]:
# Residue count of the raw sequence, for comparison with `sequence_length`
# (the encoded length minus the two special tokens).
len(protein_sequence)

108

In [61]:
# Rebuild the masked copy from the unmasked ids every time this cell runs.
# Mutating a tensor created in another cell would let masks accumulate if the
# position below is changed and the cell is re-run.
masked_input_ids = input_ids.clone()

# Token index 56 is the site being scored. Because of the leading special
# token, it corresponds to residue 56 of the protein when counting from 1.
masked_input_ids[0, 56] = tokenizer.mask_token_id

# Inference only: no gradients are needed for scoring.
with torch.no_grad():
    # Logits with the site hidden from the model (used for mutation scoring).
    output = model(masked_input_ids).logits
    # Logits with the full, unmasked sequence (kept for comparison).
    newoutput = model(input_ids).logits

In [67]:
# Logits at token index 56 from the UNMASKED forward pass
# (one value per vocabulary token).
newoutput[0,56]

tensor([-12.1878, -22.0185, -11.5768, -22.0190,   2.1428,   0.3948,  -1.7579,
          1.1307,  -1.7544,   2.1297,  -0.8656,  -1.4050,   1.5409,  -2.1582,
         -2.3874,  -0.6448,  -1.7216,  -2.0923,   0.6902,   0.4454,  -0.1989,
         -1.9539,  -0.7098,  -2.4093,  -7.8731, -12.2290, -12.4519, -12.8954,
        -16.0544, -16.4845, -16.5221, -16.6085, -22.0090])

In [66]:
# Logits at token index 56 from the MASKED forward pass
# (one value per vocabulary token); these feed the mutation scores.
output[0, 56]

tensor([-13.1540, -23.7290, -12.1718, -23.7257,   2.1987,   1.1187,  -0.3541,
          1.6377,  -0.7187,  -0.2680,  -0.0645,  -0.5286,   1.7129,  -1.2313,
         -1.3465,   0.1527,  -0.5898,  -0.9066,   1.1585,   1.1271,   0.0304,
         -0.8827,  -0.3168,  -1.5503,  -7.9917, -11.9826, -12.2444, -12.7819,
        -16.1148, -16.4314, -16.4897, -16.5185, -23.7213])

In [57]:
# Logits over the vocabulary at the masked site (token index 56).
site_logits = output[0, 56]

# Probability distribution over vocabulary tokens at that site.
probabilities = torch.nn.functional.softmax(site_logits, dim=0)

# Compute log-probabilities directly with log_softmax instead of
# log(softmax(x)): the two-step form can underflow to -inf for very unlikely
# tokens, whereas log_softmax is numerically stable.
log_probabilities = torch.nn.functional.log_softmax(site_logits, dim=0)

In [58]:
# Token id of the wild-type residue, read from the unmasked ids at the scored
# site (token index 56).
wt_residue = int(input_ids[0, 56])

# Log-probability the model assigns to the wild-type residue at that site.
log_prob_wt = float(log_probabilities[wt_residue])

In [59]:
# Score every candidate substitution at the masked site as a log-odds ratio
# against the wild-type residue: u > 0 means the model finds the mutant more
# likely than the wild type, u < 0 less likely (the wild type itself gives 0).
for amino_acid in amino_acids:
    mutant_token_id = tokenizer.convert_tokens_to_ids(amino_acid)
    log_prob_mt = log_probabilities[mutant_token_id].item()
    u = log_prob_mt - log_prob_wt
    print(f"{amino_acid}: {u:.2f}")

A: 1.39
C: -1.28
D: -0.96
E: 0.00
F: 1.43
G: -0.09
H: -0.61
I: 1.98
K: 0.42
L: 2.47
M: 0.30
N: -0.64
P: -1.08
Q: -0.32
R: 0.20
S: -0.45
T: -0.26
V: 1.91
W: -0.05
Y: 1.40
