**Mini Project 2 - Comparing Protein Designs**

Author: Cole Richardson

Code modified from Amelie Schreiber as the protein sequences were changed to match binding sequences created by BME 5990 Applied Deep Learning class
https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2



Protein design made through RF Diffusion notebook: http://github.com/sokrypton/ColabDesign


Protein Design Paramaters:

250 contigs
5WB7 pdb

My protein design can be found at ....


In [4]:
# Importing Necessary Libraries
import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

In [5]:
# Initializing the Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

In [6]:
# Setting up the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [7]:
# Defining the Protein Sequences
# Protein sequences made from fellow classmates
all_proteins = ["CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV","SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI","LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW","SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL","AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA","GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE","PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA","GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC","KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQM","SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA","LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL","AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA]"]

In [9]:
# Defining the MLM Loss Function
BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15

# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)

        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get the mask token ID
        mask_token_id = tokenizer.mask_token_id

        # Clone input IDs for labels
        labels = inputs["input_ids"].clone()

        # Randomly mask 15% of the residues for each sequence in the batch
        for idx in range(inputs["input_ids"].shape[0]):
            mask_indices = np.random.choice(inputs["input_ids"].shape[1], size=int(P_MASK * inputs["input_ids"].shape[1]), replace=False)
            inputs["input_ids"][idx, mask_indices] = mask_token_id
            labels[idx, [i for i in range(inputs["input_ids"].shape[1]) if i not in mask_indices]] = -100

        # Compute the MLM loss
        outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())

    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS

In [10]:
# Constructing the Loss Matrix
# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))

for i in range(len(all_proteins)):
    for j in range(i+1, len(all_proteins), BATCH_SIZE):  # to avoid self-pairing and use batches
        pairs = [all_proteins[i] + all_proteins[k] for k in range(j, min(j+BATCH_SIZE, len(all_proteins)))]
        batch_loss = compute_mlm_loss_batch(pairs)
        for k in range(len(pairs)):
            loss_matrix[i, j+k] = batch_loss
            loss_matrix[j+k, i] = batch_loss  # the matrix is symmetric

# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)

In [11]:
# Finding Optimal Pairs
# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))

print(optimal_pairs)

[(0, 11), (1, 3), (2, 5), (3, 9), (4, 10), (5, 4), (6, 2), (7, 8), (8, 1), (9, 7), (10, 6), (11, 0)]


In [13]:
# Next Steps
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQMDVNPEGKYSFGATCVKKCPRNYVVTDHGSCVRACGADSYEMEEDGVRKCKKCEGPCRKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCVSHHHHHH"
# Add potential binding sequences here
potential_binders = ["CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV","SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI","LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW","SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL","AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA","GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE","PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA","GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC","KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQM","SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA","LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL","AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA]"]

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0

    for _ in range(iterations):
        # Concatenate protein sequences with a separator
        concatenated_sequence = protein + ":" + binder

        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")

Ranking of Potential Binders:
1. AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA] - MLM Loss: 6.426949183146159
2. KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQM - MLM Loss: 7.616223017374675
3. SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFE

In [14]:
# Building a PPI Network
import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets

# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"  # You can change this to your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Send the model to the device (GPU or CPU)
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Define Protein Sequences (Replace with your list)
all_proteins = ["CPPLENIDISGVDGDSATISFEPCREPVDYVVLHYGRAGDPGDWKTYFLPPGDTSFTLTGLEPGGWYRVELWCWRPGRCCEPQTEYFEV","SEEEEERERALKEIIEETRRELKAAKAKHGKVVVVLIMASSTLEPEFILELSKALIKEMKSLFPNVVLIIVVVGLAPASLLARIRDVSLELAKYAKSLGIKVIVIVGNENEAVFVPAFEALGVEVIVDRTIIEIAAEELGLSEEEVLARFAAAAELLDELFAADPSLRERYARLDVAGATELLLERLRELFGAKVERHERLITVEVERVLTPDERRRVTAILLTPEAAREVVERLVDLVVDLILEKIAEGHNVLVLVFTPTIALAREVAALFEERRPLLEEAGAAVIIRLVARDPDTFLI","LEEKKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCW","SNKTQLGSSG/ELEELLAKKEELLKKLYKELLKKGNVLVDTEYLKTLTEEELKEISKAYISEEEGMIILEFKGTYNGYLVIKHKDVETSEEVREEQKKLAEELKKKLEALGAEVREIEVKVKEEVKTEKEGNITKTTLTLEVEIDGEKVTLKLTEVEVEL","AATAAALEHLEAAAAALKELAALVATEAADAAALKAKAEELAAKVREHLRAARAATGDTSLTDEDIDAFIQRILDAVDDAEAVKALYEELEAAIAAFRAAQEAAA","GKLNIKVTFLSSGKEEKLAALKAHVDALVASIDTKASGAPPLKVEVKESESKETREIDGKTYEYGFTTVTYSFEGTNDILNQLANDIVTHISNTLKDLLIEIDIAATSDGDLNLTINITVNGVDTVILLNVSLTAGTNVNLTINITVTGATVTVHIIVSLTTTSAGSATVTINATAGAGATLNITLMGVFTNTAVKDVTVNVTTTATSGTVTVTLGPVTQASAAEMAAGVAAAREAAREEALREVARLTE","PTLEELGAELRVEEIDGVTVVEIRDRKHGNKLVRVVAQGVSPERVEALAKELFLKYVIEKGYDGVVAAISRSGAVFVLRRNLKTGEVEFEVLEIPDEDDILSIELETYPDGHALAIVKTKEETIVLSILHNNGIIVANAVTVRTERELEVVLLLLEAVRVASKVADIKEFLVVRSPSGLWHLRLRLDFPDGETRFEYEVTVPEEEFVAAARAVVAGLLALLAEAAKEDPAVAPLAEEMAAILARLEALAA","GRPPKNVKVSGVDGNSATISFDLADENDKYILIMIGPANDPNSWTSWWLPHETSYLSISNLPPGAEYQVTFMMVRPGKMSPPITQDFKC","KKVCQGTSNKLTQLGTFEDHFLSLQRMFNNCEVVLGNLEITYVQRNYDLSFLKTIQEVAGYVLIALNTVERIPLENLQIIRGNMYYENSYALAVLSNYDANKTGLKELPMRNLQEILHGAVRFSNNPALCNVESIQWRDIVSSDFLSNMSMDFQNHLGSCQKCDPSCPNGSCWGAGEENCQKLTKIICAQQCSGRCRGKSPSDCCHNQCAAGCTGPRESDCLVCRKFRDEATCKDTCPPLMLYNPTTYQM","SAAAAAQARLDAALAALREWLAARAREAIERYRDAKERVVEEEAITRDFHGVLTLEAVRIEVTPTTVAISARLRHASGQTVYLSILAPHDPAALEAALAIAELATRLALEAGYDLFVAVAFEPPGPVTPERWEEFAAFLELVAEDLRALLADAAAKGRPLLVVIVIVVNDDLAAHLPLESHTDPEAAAAAVATYVAEVEAKTGRKLTLPAEIAAALAAGASVVLVVARREDIAGVPARVEAALRAALAAA","LSLKLLEKALSKELADKIITFHLLGLVLEVSKDHPEKPIFDELRERLEELEEELEEHLDLPEEEFNELVDKRLSEFIAEAFSHPAVVDAFLDLLVTLKAMADIRSARAEKADEERRANDPSGKEEEVPDSEELVTLKKASAALDRALDTLLKDPRVREMVERYLRARGVRIPEEALSLPRAEQLRLAFQRLVAREAARMTPAGKTAAEVTDEELAASFAASPNPFARAFAHRFPELAKELAEMQDLHDKL","AAAARLAELEALLAEAEALAKKIRKGEDLKVLEKLRATVEAIAALLAASGVDGPVAELVARLKEIAALLAEAIAKRKEAEKIKKEVEEKEELLKTYKEYNELLKKMKELKKKIAELKKKLEEKRKELKKAIEKEKLPAEVVAKIEKFLEKMEELLKRTEEEAKAWREELDKKEKKFIEDFVKDLSKEMKESKEEEKKEILKKKKEEVKKFVEEWEKENKPKLEEFRKKVEERREELKKLAEEAAKLPNAEVAKLLQEALKLASELIEAIAEYTELYYRRSALGTLVLTEKMFRAVEEGKDPEEKAKEFLELEKKGEVDEELLEKLEKEAEKTKEKRAKAEKLEKEAEKLLKEALALLEEALAAVRALLAAWAAAA]"]


def compute_average_mlm_loss(protein1, protein2, iterations=10):
    total_loss = 0.0
    connector = "G" * 25  # Connector sequence of G's
    for _ in range(iterations):
        concatenated_sequence = protein1 + connector + protein2
        inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)

        mask_prob = 0.55
        mask_indices = torch.rand(inputs["input_ids"].shape, device=device) < mask_prob

        # Locate the positions of the connector 'G's and set their mask indices to False
        connector_indices = tokenizer.encode(connector, add_special_tokens=False)
        connector_length = len(connector_indices)
        start_connector = len(tokenizer.encode(protein1, add_special_tokens=False))
        end_connector = start_connector + connector_length

        # Avoid masking the connector 'G's
        mask_indices[0, start_connector:end_connector] = False

        # Apply the mask to the input IDs
        inputs["input_ids"][mask_indices] = tokenizer.mask_token_id
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])

        loss = outputs.loss
        total_loss += loss.item()

    return total_loss / iterations

# Compute all average losses to determine the maximum threshold for the slider
all_losses = []
for i, protein1 in enumerate(all_proteins):
    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
        avg_loss = compute_average_mlm_loss(protein1, protein2)
        all_losses.append(avg_loss)

# Set the maximum threshold to the maximum loss computed
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")

def plot_graph(threshold):
    G = nx.Graph()

    # Add all protein nodes to the graph
    for i, protein in enumerate(all_proteins):
        G.add_node(f"protein {i+1}")

    # Loop through all pairs of proteins and calculate average MLM loss
    loss_idx = 0  # Index to keep track of the position in the all_losses list
    for i, protein1 in enumerate(all_proteins):
        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
            avg_loss = all_losses[loss_idx]
            loss_idx += 1

            # Add an edge if the loss is below the threshold
            if avg_loss < threshold:
                G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))

    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)

    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))

    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(node)

    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)

    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

# Create an interactive slider for the threshold value with a default of 8.50
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))

Maximum loss (maximum threshold for slider): 9.211536598205566


interactive(children=(FloatSlider(value=8.25, description='threshold', max=9.211536598205566, step=0.05), Outp…