<a href="https://colab.research.google.com/github/istender15/bme5900-p1/blob/main/binding_affinity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch


In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# This is used to convert protein sequences into a format suitable for the model.
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
# This is the ESM-2 model, specifically built for protein sequences.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# this will set the code to use GPUs if available
model = model.to(device)
# casting the model to the GPU


# Compute Binder Affinity

In [4]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "EKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSCRNVSRGRECVDKCNLLEGEPREFVENSECIQCHPECLPQAMNITCTGRGPDNCIQCAHYIDGPHCVKT"
potential_binders = [
    "KPQRKTYYGNMKGREDYEPEQSKEVYAKKFASKTEEELEEVIKEEKAEIEKKKKQLEEDIKAGKVTEYNPKVKVITPVYPSEYKEVEE", #my protein
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI", #ben protein
    "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW", #noah protein
    "GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC" # jackson swalley protein
]  # Add potential binding sequences here

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0

    for _ in range(iterations):
        # Concatenate protein sequences with a separator
        concatenated_sequence = protein + ":" + binder

        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")


Ranking of Potential Binders:
1. SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI - MLM Loss: 6.843159993489583
2. LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW - MLM Loss: 8.65021832784017
3. GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC - MLM Loss: 10.40660031636556
4. KPQRKTYYGNMKGREDYEPEQSKEVYAKKFASKTEEELEEVIKEEKAEIEKKKKQLEEDIKAGKVTEYNPKVKVITPVYPSEYKEVEE - MLM Loss: 10.430347124735514
