# from: https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2

In [2]:
import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch


In [3]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [5]:
all_proteins = [
    "MLEELLEELEEEEEKEELLELLEEMTKRFKEGIEEIEKELEELEKYLEESGVDEEYAEEIKEEIEEIREELKRLEEKLEEAKKLIEEGELEKALKILKEMTKEAKEAYKEAKERYKEAKKKYKEAAAELRAAAAALGDEKERKAKLAEIDAELEKRTKELEEKEKEMKELEKELEKKIEEVKGLDPLKQAILQFDAYFKMYKAATLLYILAKKYTAYMQAKLKELEAANAKAAAEAAAAAA",
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI",
    "PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA",
    "CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV"
]  # A list of protein sequences


In [6]:
BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15

# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)

        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get the mask token ID
        mask_token_id = tokenizer.mask_token_id

        # Clone input IDs for labels
        labels = inputs["input_ids"].clone()

        # Randomly mask 15% of the residues for each sequence in the batch
        for idx in range(inputs["input_ids"].shape[0]):
            mask_indices = np.random.choice(inputs["input_ids"].shape[1], size=int(P_MASK * inputs["input_ids"].shape[1]), replace=False)
            inputs["input_ids"][idx, mask_indices] = mask_token_id
            labels[idx, [i for i in range(inputs["input_ids"].shape[1]) if i not in mask_indices]] = -100

        # Compute the MLM loss
        outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())

    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS


In [7]:
# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))

for i in range(len(all_proteins)):
    for j in range(i+1, len(all_proteins), BATCH_SIZE):  # to avoid self-pairing and use batches
        pairs = [all_proteins[i] + all_proteins[k] for k in range(j, min(j+BATCH_SIZE, len(all_proteins)))]
        batch_loss = compute_mlm_loss_batch(pairs)
        for k in range(len(pairs)):
            loss_matrix[i, j+k] = batch_loss
            loss_matrix[j+k, i] = batch_loss  # the matrix is symmetric

# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)


In [8]:
# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))

print(optimal_pairs)


[(0, 1), (1, 0), (2, 3), (3, 2)]


In [1]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSCRNVSRGRECVDKCNLLEGEPREFVENSECIQCHPECLPQAMNITCTGRGPDNCIQCAHYIDGPHCVKTCPAGVMGENNTLVWKYADAGHVCHLCHPNCTYGCTGPGLEGCPTNGPKIPS"
potential_binders = [
    "AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA",
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI",
    "PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA",
    "CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV",
    'AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA','SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA',
    'LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW',
    'SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL',
    'AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA',
    'GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE',
    'GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC',
    'KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQM',
    'SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA',
    'LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL'
]  # Add potential binding sequences here

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0

    for _ in range(iterations):
        # Concatenate protein sequences with a separator
        concatenated_sequence = protein + ":" + binder

        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



Ranking of Potential Binders:
1. AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA - MLM Loss: 4.531987031300862
2. SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI - MLM Loss: 5.972407341003418
3. KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVC