<a href="https://colab.research.google.com/github/katyachemistry/PLI_prediction/blob/main/Make_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies and example data

In [6]:
!pip install rdkit-pypi
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import h5py
from rdkit.Chem import AllChem
from rdkit import Chem
from IPython.display import clear_output

In [None]:
# Download the example ProtT5 protein embedding with gdown, using the file ID below.
# NOTE(review): presumably this produces ./ppa_prott5_embedding.h5, the default
# value of path_to_ProtT5_h5 in the prediction cell -- confirm.
!gdown 1qLNO7zeSA_0Hw5htwl3q85PSXP-cm2-W

# Make prediction

In [28]:
#@markdown - SMILES string of your molecule. ATP is used as an example.
SMILES = "C1=NC(=C2C(=N1)N(C=N2)C3C(C(C(O3)COP(=O)(O)OP(=O)(O)OP(=O)(O)O)O)O)N" #@param {type:"string"}
#@markdown - Path to the ProtT5 embeddings file (.h5). For details, see the README of [this repo](https://github.com/katyachemistry/PLI_prediction). The example uses the embedding of [this protein](https://www.uniprot.org/uniprotkb/P0A7A9/).
path_to_ProtT5_h5 = "./ppa_prott5_embedding.h5" #@param {type:"string"}
#@markdown - Your protein name, as it was written in the FASTA file you submitted to the ProtT5 notebook. It is used as the key to look up the embedding inside the .h5 file.
protein_name = "ppa" #@param {type:"string"}

mol = Chem.MolFromSmiles(SMILES)
fpts =  AllChem.GetMorganFingerprintAsBitVect(mol,2,1024)
mfpts = torch.tensor(fpts).to(torch.float32).view(1, 1024)

!gdown 1J0Ve8cw-DZBgTBs2CLVWFnoVSlgJdVay

class InteractionClassifier_ProtT5_based(nn.Module):
    '''
    Two-branch network that scores a protein/molecule pair as interacting or not.

    The protein vector and the molecule vector each pass through their own pair of
    Linear -> BatchNorm1d -> ReLU -> Dropout blocks. The two branch outputs are
    concatenated, sent through one more block of the same kind with 64 units, and
    finally projected to a single value. No sigmoid is applied inside the model,
    so the returned value is a raw score (logit).

    Args:
        input_size_protein (int): Length of the protein feature vector.
        input_size_molecule (int): Length of the molecule feature vector.
        fc1_layer_size_factor (int): Divisor applied to each input length to get the
            width of the first layer of the corresponding branch.
        fc2_layer_size_factor (int): Divisor applied to the first-layer width to get
            the width of the second layer of the corresponding branch.
        dropout_rate (float): Dropout probability used after every hidden block.
            Default is 0.

    Attributes:
        protein_fc1, protein_fc2 (nn.Linear): Protein branch layers.
        molecule_fc1, molecule_fc2 (nn.Linear): Molecule branch layers.
        relu (nn.ReLU): Activation shared by all hidden blocks.
        dropout (nn.Dropout): Dropout shared by all hidden blocks.
        fc1 (nn.Linear): Layer applied to the concatenated branch outputs (64 units).
        fc2 (nn.Linear): Output layer producing one value per sample.
        norm_prot1, norm_prot2 (nn.BatchNorm1d): Normalization for the protein branch.
        norm_mol1, norm_mol2 (nn.BatchNorm1d): Normalization for the molecule branch.
        norm_all (nn.BatchNorm1d): Normalization for the combined 64-unit layer.
    '''

    def __init__(self, input_size_protein, input_size_molecule, fc1_layer_size_factor, fc2_layer_size_factor, dropout_rate=0):
        super().__init__()

        # Layer widths: each stage divides the previous width by its factor and
        # truncates to an integer.
        prot_hidden = int(input_size_protein / fc1_layer_size_factor)
        prot_out = int(prot_hidden / fc2_layer_size_factor)
        mol_hidden = int(input_size_molecule / fc1_layer_size_factor)
        mol_out = int(mol_hidden / fc2_layer_size_factor)
        merged_hidden = 64

        # Registration order fixes both the state_dict key order and the order in
        # which parameters are randomly initialised, so it is kept deliberately.
        self.protein_fc1 = nn.Linear(input_size_protein, prot_hidden)
        self.protein_fc2 = nn.Linear(prot_hidden, prot_out)
        self.molecule_fc1 = nn.Linear(input_size_molecule, mol_hidden)
        self.molecule_fc2 = nn.Linear(mol_hidden, mol_out)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fc1 = nn.Linear(prot_out + mol_out, merged_hidden)
        self.fc2 = nn.Linear(merged_hidden, 1)

        self.norm_prot1 = nn.BatchNorm1d(prot_hidden)
        self.norm_prot2 = nn.BatchNorm1d(prot_out)
        self.norm_mol1 = nn.BatchNorm1d(mol_hidden)
        self.norm_mol2 = nn.BatchNorm1d(mol_out)
        self.norm_all = nn.BatchNorm1d(merged_hidden)

    def _hidden_block(self, linear, norm, features):
        '''Apply Linear -> BatchNorm -> ReLU -> Dropout to ``features``.'''
        activated = self.relu(norm(linear(features)))
        return self.dropout(activated)

    def forward(self, protein, molecule):
        '''
        Compute one raw score per sample.

        Args:
            protein (torch.Tensor): Protein features; fed directly to ``protein_fc1``.
            molecule (torch.Tensor): Molecule features; flattened to
                (batch, -1) and cast to float32 before use.

        Returns:
            torch.Tensor: Output of ``fc2``, one value per sample.
        '''
        batch_size = molecule.size(0)
        mol_features = molecule.view(batch_size, -1).to(torch.float32)

        # Protein branch first, then molecule branch.
        prot_features = self._hidden_block(self.protein_fc1, self.norm_prot1, protein)
        prot_features = self._hidden_block(self.protein_fc2, self.norm_prot2, prot_features)

        mol_features = self._hidden_block(self.molecule_fc1, self.norm_mol1, mol_features)
        mol_features = self._hidden_block(self.molecule_fc2, self.norm_mol2, mol_features)

        # Join the branches along the feature axis and score the pair.
        merged = torch.cat((prot_features, mol_features), dim=1)
        merged = self._hidden_block(self.fc1, self.norm_all, merged)
        return self.fc2(merged)


# Restore the trained weights.
# map_location='cpu' makes the checkpoint loadable on runtimes without a GPU;
# the model and both inputs in this cell are created on the CPU.
# SECURITY: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_path = 'ProtT5_Morgan7.ckpt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = InteractionClassifier_ProtT5_based(1024, 1024, 2, 2, 0)

# Rename checkpoint entries to the classifier's own parameter names by dropping
# the first 6 characters of every key, and skip the last entry.
# NOTE(review): presumably the prefix is "model." from a training wrapper and the
# last entry is not a classifier weight -- confirm against the checkpoint.
checkpoint_state = checkpoint['state_dict']
new_state_dict = {key[6:]: checkpoint_state[key] for key in list(checkpoint_state.keys())[:-1]}
model.load_state_dict(new_state_dict)
model.eval()  # evaluation mode for the BatchNorm and Dropout layers
clear_output()

# Read the embedding inside a context manager so the HDF5 file handle is closed
# as soon as the data has been copied into a tensor.
with h5py.File(path_to_ProtT5_h5, 'r') as embeddings_file:
  if protein_name not in embeddings_file:
    raise KeyError(
        f'Protein {protein_name!r} not found in {path_to_ProtT5_h5!r}; '
        f'available entries: {list(embeddings_file.keys())}'
    )
  protein = torch.tensor(embeddings_file[protein_name][:]).to(torch.float32).view(1, 1024)


# The model returns a raw score; sigmoid turns it into a probability.
with torch.no_grad():
  proba = float(nn.functional.sigmoid(model(protein, mfpts))[0][0])
print(f'The probability of this interaction is {proba:.3f}')


The probability of this interaction is 0.961
