<a href="https://colab.research.google.com/github/katyachemistry/PLI_prediction/blob/main/Make_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Acknowledgements


To create protein embeddings, this notebook uses code from the following source:

ProtTrans. Available at: [https://github.com/agemagician/ProtTrans](https://github.com/agemagician/ProtTrans).

A. Elnaggar et al., "ProtTrans: Toward Understanding the Language of Life Through Self-Supervised Learning," *IEEE Transactions on Pattern Analysis and Machine Intelligence*, 2022, doi: 10.1109/TPAMI.2021.3095381.



# Load dependencies and ProtT5 model

In [None]:
!pip install rdkit-pypi
import pandas as pd
import sentencepiece
import numpy as np
import transformers
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import h5py
from rdkit.Chem import AllChem
from rdkit import Chem
from IPython.display import clear_output

Collecting rdkit-pypi
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit-pypi
Successfully installed rdkit-pypi-2022.9.5


In [None]:
# -p creates missing parents and does not fail when the directory already
# exists, so this cell can be re-run without "File exists" errors.
!mkdir -p protT5 # root directory for storing checkpoints, results etc
!mkdir -p protT5/protT5_checkpoint # directory holding the ProtT5 checkpoint
!mkdir -p protT5/output # directory for storing your embeddings & predictions

In [None]:
from transformers import T5EncoderModel, T5Tokenizer
import time
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Forwarded to get_embeddings; this notebook stores no per-residue embeddings.
per_residue = False
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = True

# Forwarded to get_embeddings, which does not read it.
sec_struct = False

In [None]:
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50)
def get_T5_model():
    """Fetch the ProtT5 encoder checkpoint together with its tokenizer.

    The encoder is moved onto the notebook-level ``device`` and switched to
    evaluation mode before it is handed back.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    checkpoint_name = "Rostlab/prot_t5_xl_half_uniref50-enc"
    encoder = T5EncoderModel.from_pretrained(checkpoint_name)
    encoder = encoder.to(device).eval()
    vocab = T5Tokenizer.from_pretrained(checkpoint_name, do_lower_case=False)
    return encoder, vocab

prott5_model, tokenizer = get_T5_model()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


# Make prediction

In [None]:
#@markdown - SMILES string of your molecule. Citric acid SMILES is given as an example:
SMILES = "C(C(=O)O)C(CC(=O)O)(C(=O)O)O" #@param {type:"string"}
#@markdown - Amino acid sequence of your protein (only a single chain is supported)
sequence = 'MSALLRLLRTGAPAAACLRLGTSAGTGSRRAMALYHTEERGQPCSQNYRLFFKNVTGHYISPFHDIPLKVNSKEENGIPMKKARNDEYENLFNMIVEIPRWTNAKMEIATKEPMNPIKQYVKDGKLRYVANIFPYKGYIWNYGTLPQTWEDPHEKDKSTNCFGDNDPIDVCEIGSKILSCGEVIHVKILGILALIDEGETDWKLIAINANDPEASKFHDIDDVKKFKPGYLEATLNWFRLYKVPDGKPENQFAFNGEFKNKAFALEVIKSTHQCWKALLMKKCNGGAINCTNVQISDSPFRCTQEEARSLVESVSSSPNKESNEEEQVWHFLGK' #@param {type:"string"}
#@markdown - Protein name
protein_name = "ppa2" #@param {type:"string"}

# get_embeddings takes a mapping of identifier -> sequence; the protein name is
# the key later used to look the embedding up again.
protein_seq_dict = {protein_name: sequence}


# Generate embeddings via batch-processing.
# per_protein: return one embedding per protein (average-pooling over its residues).
# per_residue: accepted as part of the call signature; the function below only
#              stores per-protein embeddings, not one embedding per residue.
# max_residues: upper limit on the number of residues within one batch.
# max_seq_len: a sequence longer than this causes the batch containing it to be processed right away.
# max_batch: upper limit on the number of sequences per batch.
def get_embeddings( model, tokenizer, seqs, per_residue, per_protein, sec_struct,
                   max_residues=4000, max_seq_len=1000, max_batch=100 ):
    """Embed sequences in mini-batches and mean-pool each one into a vector.

    Args:
        model: ``torch.nn.Module`` called as
            ``model(input_ids, attention_mask=...)``; its output must expose
            ``last_hidden_state``.
        tokenizer: Object providing ``batch_encode_plus`` that returns
            ``input_ids`` and ``attention_mask``.
        seqs (dict): Maps an identifier to its sequence string.
        per_residue: Accepted for call compatibility; not used, because no
            per-residue embeddings are stored.
        per_protein: When truthy, the mean over the residue positions of each
            sequence is stored under its identifier.
        sec_struct: Accepted for call compatibility; not used.
        max_residues (int): Residue budget that triggers processing of a batch.
        max_seq_len (int): A longer sequence triggers immediate processing of
            the batch that contains it.
        max_batch (int): Number of sequences that triggers processing.

    Returns:
        dict: ``{"protein_embs": {identifier: numpy.ndarray}}``. Identifiers
        whose batch raised ``RuntimeError`` are absent; the mapping is empty
        when ``seqs`` is empty or ``per_protein`` is falsy.
    """
    results = {"protein_embs" : dict()}

    # Token tensors must live on the same device as the model weights, so the
    # device is taken from the model instead of from a notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    ordered = sorted( seqs.items(), key=lambda kv: len(kv[1]), reverse=True )
    n_seqs = len(ordered)

    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(ordered, 1):
        seq_len = len(seq)
        # the tokenizer is fed residues separated by single spaces
        batch.append((pdb_id, ' '.join(seq), seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed
        n_res_batch = sum(s_len for _, _, s_len in batch) + seq_len
        batch_is_due = (len(batch) >= max_batch or n_res_batch >= max_residues
                        or seq_idx == n_seqs or seq_len > max_seq_len)
        if not batch_is_due:
            continue

        # A distinct name keeps the ``seqs`` argument from being shadowed.
        pdb_ids, batch_seqs, seq_lens = zip(*batch)
        batch = list()

        # add_special_tokens adds extra token at the end of each sequence
        token_encoding = tokenizer.batch_encode_plus(batch_seqs, add_special_tokens=True, padding="longest")
        input_ids      = torch.tensor(token_encoding['input_ids']).to(model_device)
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(model_device)

        try:
            with torch.no_grad():
                # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                embedding_repr = model(input_ids, attention_mask=attention_mask)
        except RuntimeError:
            # Best effort: report the failing batch and carry on with the rest.
            print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
            continue

        for batch_idx, identifier in enumerate(pdb_ids): # for each protein in the current mini-batch
            s_len = seq_lens[batch_idx]
            # slice off padding and the trailing special token --> seq_len x embedding_dim
            emb = embedding_repr.last_hidden_state[batch_idx,:s_len]

            if per_protein: # apply average-pooling to derive per-protein embeddings
                protein_emb = emb.mean(dim=0)
                results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()

    # The former timing statistic was never used and crashed on a missing
    # "residue_embs" key or on an empty result, so it is no longer computed.
    return results




!gdown 1J0Ve8cw-DZBgTBs2CLVWFnoVSlgJdVay

class InteractionClassifier_ProtT5_based(nn.Module):
    '''
    Two-branch classifier that scores a (protein, molecule) pair.

    Each input goes through its own pair of Linear -> BatchNorm1d -> ReLU ->
    Dropout stages; the two branch outputs are concatenated, passed through
    one more such stage of width 64, and reduced to a single raw score per
    sample (no sigmoid is applied inside the module).

    Args:
        input_size_protein (int): Width of the protein feature vector.
        input_size_molecule (int): Width of the molecule feature vector.
        fc1_layer_size_factor (int): Divisor applied to each input width to get
            the width of the first layer of its branch.
        fc2_layer_size_factor (int): Divisor applied to the first-layer width to
            get the width of the second layer of the branch.
        dropout_rate (float): Probability used by the shared dropout layer.
            Default is 0.

    Attributes:
        protein_fc1, protein_fc2 (nn.Linear): Protein branch layers.
        molecule_fc1, molecule_fc2 (nn.Linear): Molecule branch layers.
        relu (nn.ReLU): Activation shared by every stage.
        dropout (nn.Dropout): Dropout shared by every stage.
        fc1 (nn.Linear): Layer applied to the concatenated branch outputs.
        fc2 (nn.Linear): Output layer producing one value per sample.
        norm_prot1, norm_prot2 (nn.BatchNorm1d): Norms of the protein branch.
        norm_mol1, norm_mol2 (nn.BatchNorm1d): Norms of the molecule branch.
        norm_all (nn.BatchNorm1d): Norm applied after ``fc1``.
    '''

    def __init__(self, input_size_protein, input_size_molecule, fc1_layer_size_factor, fc2_layer_size_factor, dropout_rate=0):
        super().__init__()

        def shrink(width, factor):
            # Truncating division, exactly as int(width / factor).
            return int(width / factor)

        prot_hidden = shrink(input_size_protein, fc1_layer_size_factor)
        prot_out = shrink(prot_hidden, fc2_layer_size_factor)
        mol_hidden = shrink(input_size_molecule, fc1_layer_size_factor)
        mol_out = shrink(mol_hidden, fc2_layer_size_factor)
        joint_hidden = 64

        # Layers are registered in a fixed order: linear layers of both
        # branches, activation and dropout, the joint head, then the norms.
        self.protein_fc1 = nn.Linear(input_size_protein, prot_hidden)
        self.protein_fc2 = nn.Linear(prot_hidden, prot_out)
        self.molecule_fc1 = nn.Linear(input_size_molecule, mol_hidden)
        self.molecule_fc2 = nn.Linear(mol_hidden, mol_out)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fc1 = nn.Linear(prot_out + mol_out, joint_hidden)
        self.fc2 = nn.Linear(joint_hidden, 1)

        self.norm_prot1 = nn.BatchNorm1d(prot_hidden)
        self.norm_prot2 = nn.BatchNorm1d(prot_out)
        self.norm_mol1 = nn.BatchNorm1d(mol_hidden)
        self.norm_mol2 = nn.BatchNorm1d(mol_out)
        self.norm_all = nn.BatchNorm1d(joint_hidden)

    def _stage(self, features, linear, norm):
        '''Apply one Linear -> BatchNorm1d -> ReLU -> Dropout stage.'''
        activated = self.relu(norm(linear(features)))
        return self.dropout(activated)

    def forward(self, protein, molecule):
        '''Return one raw score per sample, shaped (batch, 1).'''
        # Flatten everything after the batch axis and cast to float32.
        mol_feats = molecule.view(molecule.size(0), -1).to(torch.float32)

        prot_feats = self._stage(protein, self.protein_fc1, self.norm_prot1)
        prot_feats = self._stage(prot_feats, self.protein_fc2, self.norm_prot2)

        mol_feats = self._stage(mol_feats, self.molecule_fc1, self.norm_mol1)
        mol_feats = self._stage(mol_feats, self.molecule_fc2, self.norm_mol2)

        joint = torch.cat((prot_feats, mol_feats), dim=1)
        hidden = self._stage(joint, self.fc1, self.norm_all)

        return self.fc2(hidden)


# Restore the trained classifier weights from the checkpoint file.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed PyTorch version; only load checkpoints from a trusted source.
checkpoint_path = 'ProtT5_Morgan7.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
model = InteractionClassifier_ProtT5_based(1024, 1024, 2, 2, 0)

saved_weights = checkpoint['state_dict']
# Every entry except the last one is copied, with the first six characters of
# its key removed (presumably a "model." prefix -- confirm against the
# checkpoint).
new_state_dict = {
    key[6:]: saved_weights[key]
    for key in list(saved_weights.keys())[:-1]
}

model.load_state_dict(new_state_dict)
model.eval()
clear_output()


# Embed the protein; the banner is printed first so it appears before the
# slow step rather than after it.
print('\n############# Generating embedding #############')
results = get_embeddings( prott5_model, tokenizer, protein_seq_dict,
                         per_residue, per_protein, sec_struct)
# get_embeddings omits proteins whose batch failed, so check before indexing.
if protein_name not in results["protein_embs"]:
    raise RuntimeError(f'No embedding was produced for protein "{protein_name}".')
protein = torch.tensor(results["protein_embs"][protein_name][:]).to(torch.float32).view(1, 1024)

print('\n############# Generating Morgan Fingerprint #############')
mol = Chem.MolFromSmiles(SMILES)
# RDKit returns None instead of raising when the SMILES cannot be parsed.
if mol is None:
    raise ValueError(f'RDKit could not parse the SMILES string: {SMILES!r}')
# Morgan fingerprint with radius 2 folded into 1024 bits.
fpts =  AllChem.GetMorganFingerprintAsBitVect(mol,2,1024)
mfpts = torch.tensor(fpts).to(torch.float32).view(1, 1024)

print('\n############# Making prediction #############')

# The classifier outputs a raw score; the sigmoid turns it into a probability.
with torch.no_grad():
  proba = float(torch.sigmoid(model(protein, mfpts))[0][0])
print(f'The probability of this interaction is {proba:.3f}')



############# Generating embedding #############

############# Generating Morgan Fingerprint #############

############# Making prediction #############
The probability of this interaction is 0.912
