# **Mini Project 2 - Comparative Assessment of Designed Binder**



### Novel Protein Design
Methods:

I used the RFdiffusion notebook from https://github.com/sokrypton/ColabDesign

The backbone parameters were set to:

    contigs: 250
    pdb: 5WB7

The PDB file of my designed protein is available here: https://github.com/hannahcarrolll/BME5590_MiniProject2.git




### Comparative Assessment
The code below is adapted from a notebook by Amelie Schreiber. The modifications consist of replacing the protein and binder sequences with data from UVM's BME 5990 Applied Deep Learning in BME 2024 course.

Link to blog post with notebook: https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2



In [1]:
import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

### Initializing Model & Tokenizer
Tokenizer: This is used to convert protein sequences into a format suitable for the model.

Model: This is the ESM-2 model, specifically built for protein sequences.

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### Defining Protein Sequences
The following protein sequences were designed by my peers using a variety of methods. The goal was to design a protein that is likely to bind the epidermal growth factor receptor (EGFR), the target in PDB entry 5WB7.

In [15]:
all_proteins = [
    "SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA",
    "CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV",
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI",
    "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW",
    "SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL",
    "AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA",
    "GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE",
    "PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA",
    "GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC",
    "LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL"
]  # A list of protein sequences

### Defining MLM Loss Function
(From blog post) Inside the function:

We tokenize each protein pair.

15% of the residues in each sequence are randomly masked. This can be adjusted per your preferences, and different percentages should be tested.

We compute the MLM loss using the ESM-2 model.

This process is repeated NUM_MASKS times for each pair, and the average loss is computed.



In [16]:
BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15

# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)

        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get the mask token ID
        mask_token_id = tokenizer.mask_token_id

        # Labels default to -100 (ignored by the loss); only masked positions keep their true token
        labels = torch.full_like(inputs["input_ids"], -100)

        # Randomly mask a fraction P_MASK of the token positions in each sequence.
        # Note: positions are drawn from the full padded length, so special and padding tokens can be selected too.
        seq_len = inputs["input_ids"].shape[1]
        for idx in range(inputs["input_ids"].shape[0]):
            mask_indices = np.random.choice(seq_len, size=int(P_MASK * seq_len), replace=False)
            labels[idx, mask_indices] = inputs["input_ids"][idx, mask_indices]
            inputs["input_ids"][idx, mask_indices] = mask_token_id

        # Compute the MLM loss (no gradients are needed for scoring)
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())

    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS
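
The function above draws mask positions from the full padded length of each tokenized pair. One alternative, sketched here under the assumption that only residue positions should be masked, is to exclude special and padding tokens before sampling. The sketch works on plain lists of token IDs so it runs without downloading a model; `build_masked_example`, its arguments, and the ID values in the toy example are illustrative assumptions rather than part of the original notebook.

```python
import random

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_masked_example(token_ids, special_ids, mask_id, p_mask=0.15, rng=None):
    """Mask a fraction of the residue positions and build matching MLM labels.

    token_ids   -- list of integer token IDs for one tokenized sequence
    special_ids -- set of IDs that must never be masked (e.g. cls, eos, pad)
    mask_id     -- ID of the mask token
    Returns (masked_ids, labels); labels hold the original ID at masked
    positions and IGNORE_INDEX everywhere else.
    """
    rng = rng or random.Random()
    # Only positions holding a real residue are candidates for masking
    candidates = [i for i, tok in enumerate(token_ids) if tok not in special_ids]
    n_mask = int(p_mask * len(candidates))
    chosen = set(rng.sample(candidates, n_mask))

    masked_ids = [mask_id if i in chosen else tok for i, tok in enumerate(token_ids)]
    labels = [tok if i in chosen else IGNORE_INDEX for i, tok in enumerate(token_ids)]
    return masked_ids, labels


# Toy example with illustrative IDs: 0 = cls, 2 = eos, 1 = pad, 32 = mask
toy_ids = [0] + [5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + [2, 1, 1]
masked, labels = build_masked_example(toy_ids, {0, 1, 2}, mask_id=32,
                                      p_mask=0.5, rng=random.Random(0))
```

With real data, `token_ids` would be one row of `inputs["input_ids"]` converted to a list, and `special_ids` could be taken from `tokenizer.all_special_ids`.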

### Constructing Loss Matrix

In [17]:
# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))

for i in range(len(all_proteins)):
    for j in range(i+1, len(all_proteins), BATCH_SIZE):  # to avoid self-pairing and use batches
        pairs = [all_proteins[i] + all_proteins[k] for k in range(j, min(j+BATCH_SIZE, len(all_proteins)))]
        batch_loss = compute_mlm_loss_batch(pairs)
        for k in range(len(pairs)):
            loss_matrix[i, j+k] = batch_loss
            loss_matrix[j+k, i] = batch_loss  # the matrix is symmetric

# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)
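
In the loop above, every pair in a batch receives the same batch-averaged loss. A minimal sketch of filling the matrix one pair at a time is shown below. `score_pair` is a hypothetical callable standing in for a single-pair MLM loss, and the toy scorer exists only so the block runs on its own.

```python
from itertools import combinations


def build_loss_matrix(sequences, score_pair):
    """Return a symmetric matrix (list of lists) of pairwise scores.

    score_pair(a, b) is assumed to return a numeric loss for the pair.
    The diagonal stays at infinity so no sequence is paired with itself.
    """
    n = len(sequences)
    matrix = [[float("inf")] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):  # visit each unordered pair once
        loss = score_pair(sequences[i], sequences[j])
        matrix[i][j] = loss
        matrix[j][i] = loss  # mirror the value to keep the matrix symmetric
    return matrix


# Toy scorer (NOT a real MLM loss): absolute difference in sequence length
toy_sequences = ["ACD", "ACDEFG", "AC"]
toy_matrix = build_loss_matrix(toy_sequences, lambda a, b: abs(len(a) - len(b)))
```

The nested list can be converted with `np.array(...)` before it is passed to `linear_sum_assignment`.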

### Finding Optimal Pairs
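`linear_sum_assignment` solves the linear assignment problem: it assigns each protein exactly one partner so that the total MLM loss is minimized. The assignment is not required to be reciprocal. In the output below, for example, protein 1 is assigned to protein 6, but protein 6 is assigned to protein 5.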


In [18]:
# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))

print(optimal_pairs)

[(0, 3), (1, 6), (2, 4), (3, 0), (4, 2), (5, 7), (6, 5), (7, 1), (8, 9), (9, 8)]
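
Because the assignment gives each row one column without requiring reciprocity, it can help to separate mutual pairs from one-directional ones. The helper below is a small sketch under the assumption that only mutual pairs are of interest; `mutual_pairs` is a hypothetical name and not part of SciPy.

```python
def mutual_pairs(pairs):
    """Keep only reciprocal pairs (i, j) for which (j, i) is also present."""
    pair_set = set(pairs)
    # Requiring i < j reports each reciprocal pair once
    return sorted((i, j) for i, j in pair_set if i < j and (j, i) in pair_set)


# Assignment printed by the cell above
printed_pairs = [(0, 3), (1, 6), (2, 4), (3, 0), (4, 2),
                 (5, 7), (6, 5), (7, 1), (8, 9), (9, 8)]
print(mutual_pairs(printed_pairs))  # [(0, 3), (2, 4), (8, 9)]
```

The remaining indices (1, 6, 5, 7) form a cycle rather than reciprocal pairs.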


### Ranking of Potential Binders

In [20]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSHHHHHH"
potential_binders = [
    "SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA",
    "CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV",
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI",
    "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW",
    "SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL",
    "AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA",
    "GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE",
    "PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA",
    "GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC",
    "LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL"
]  # Add potential binding sequences here

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0

    for _ in range(iterations):
        # Concatenate protein sequences with a separator
        concatenated_sequence = protein + ":" + binder

        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

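        # Compute the MLM loss.
        # Note: the labels are the already-masked, padded input IDs, so the loss covers every position
        # (padding included) and masked positions are scored against <mask>, not the original residue.
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss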

        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")

Ranking of Potential Binders:
1. SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI - MLM Loss: 7.770889759063721
2. GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE - MLM Loss: 8.43543815612793
3. SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA - MLM Loss: 8.538489977518717
4. PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVV

### Analysis
This section uses a masked language model (MLM) to score each candidate binder against the protein of interest. Under this heuristic, a lower MLM loss for the concatenated pair is taken to indicate a more likely binder.

The protein of interest is the sequence from PDB entry 5WB7, sourced from https://www.rcsb.org/sequence/5WB7.

My designed protein ranked third in the class with an MLM loss of 8.54.

Improvement: Had I specified a binding site of interest on the target when designing my protein with RFdiffusion (link above), the design would likely have been more specific to that site and might have ranked higher.
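
One practical check on the ranking code: it tokenizes the target, a separator, and the binder with `truncation=True` and `max_length=1024`, so any residues beyond that budget are dropped without warning. A minimal sketch for checking the budget beforehand follows. It assumes one token per character plus two special tokens added by the tokenizer, which should be confirmed against the actual tokenizer; `fits_in_context` is a hypothetical helper.

```python
def fits_in_context(target, binder, separator=":", max_length=1024, n_special=2):
    """Return (fits, n_tokens) for the concatenated target + separator + binder.

    Assumes one token per character and n_special tokens added by the
    tokenizer; both are assumptions to verify against the real tokenizer.
    """
    n_tokens = len(target) + len(separator) + len(binder) + n_special
    return n_tokens <= max_length, n_tokens


# Toy lengths for illustration only (not the sequences used above)
ok, n = fits_in_context("A" * 600, "G" * 300)
too_long, m = fits_in_context("A" * 800, "G" * 300)
print(ok, n)        # True 903
print(too_long, m)  # False 1103
```

Calling this with `protein_of_interest` and each entry of `potential_binders` would flag any pair whose binder is partly cut off before scoring.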

### Building a PPI Network

In [23]:
import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets

# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"  # You can change this to your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Send the model to the device (GPU or CPU)
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Define Protein Sequences (Replace with your list)
all_proteins = [
    "SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA",
    "CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV",
    "SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI",
    "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW",
    "SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL",
    "AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA",
    "GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE",
    "PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA",
    "GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC",
    "LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL"
]

def compute_average_mlm_loss(protein1, protein2, iterations=10):
    total_loss = 0.0
    connector = "G" * 25  # Connector sequence of G's
    for _ in range(iterations):
        concatenated_sequence = protein1 + connector + protein2
        inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device before building the mask

        mask_prob = 0.55
        mask_indices = torch.rand(inputs["input_ids"].shape, device=device) < mask_prob

        # Locate the positions of the connector 'G's and set their mask indices to False
        connector_indices = tokenizer.encode(connector, add_special_tokens=False)
        connector_length = len(connector_indices)
        # +1 accounts for the <cls> token that the tokenizer prepends to the sequence
        start_connector = 1 + len(tokenizer.encode(protein1, add_special_tokens=False))
        end_connector = start_connector + connector_length

        # Avoid masking the connector 'G's
        mask_indices[0, start_connector:end_connector] = False

        # Apply the mask to the input IDs
        inputs["input_ids"][mask_indices] = tokenizer.mask_token_id

        with torch.no_grad():
            # Note: the labels are the already-masked IDs, so masked positions are scored
            # against the <mask> token rather than the original residue
            outputs = model(**inputs, labels=inputs["input_ids"])

        loss = outputs.loss
        total_loss += loss.item()

    return total_loss / iterations

# Compute all average losses to determine the maximum threshold for the slider
all_losses = []
for i, protein1 in enumerate(all_proteins):
    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
        avg_loss = compute_average_mlm_loss(protein1, protein2)
        all_losses.append(avg_loss)

# Set the maximum threshold to the maximum loss computed
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")

def plot_graph(threshold):
    G = nx.Graph()

    # Add all protein nodes to the graph
    for i, protein in enumerate(all_proteins):
        G.add_node(f"protein {i+1}")

    # Loop through all pairs of proteins and calculate average MLM loss
    loss_idx = 0  # Index to keep track of the position in the all_losses list
    for i, protein1 in enumerate(all_proteins):
        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
            avg_loss = all_losses[loss_idx]
            loss_idx += 1

            # Add an edge if the loss is below the threshold
            if avg_loss < threshold:
                G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))

    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)

    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))

    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(node)

    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)

    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

# Create an interactive slider for the threshold value with a default of 8.25
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))





Maximum loss (maximum threshold for slider): 9.071200180053712
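
`plot_graph` relies on `all_losses` being stored in the same pair order as the nested loops that produced it. A minimal sketch that makes this pairing explicit by keeping the indices alongside each loss is shown below. `edges_below_threshold` and the toy loss values are illustrative assumptions, not values from this notebook.

```python
from itertools import combinations


def edges_below_threshold(n_proteins, losses, threshold):
    """Return (i, j, loss) for every protein pair whose loss is below threshold.

    losses must be ordered like itertools.combinations(range(n_proteins), 2),
    which matches the i < j nested loops used to compute the pairwise losses.
    """
    pairs = list(combinations(range(n_proteins), 2))
    if len(pairs) != len(losses):
        raise ValueError("expected one loss per unordered pair")
    return [(i, j, loss) for (i, j), loss in zip(pairs, losses) if loss < threshold]


# Toy example with 3 proteins: pairs are (0, 1), (0, 2), (1, 2)
toy_losses = [8.1, 8.9, 8.3]
print(edges_below_threshold(3, toy_losses, threshold=8.5))
# [(0, 1, 8.1), (1, 2, 8.3)]
```

Each returned tuple maps directly onto a `G.add_edge(...)` call, with the loss as the edge weight.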

