In [26]:
# Environment setup for the Kaggle kernel; none of these installs pin a version.
!pip install --upgrade --no-cache-dir biopython
# NOTE(review): `rdkit-pypi` looks like the older PyPI name for RDKit (now published
# as `rdkit`) — confirm which package/version this notebook needs and pin it.
!pip install rdkit-pypi
# The find-links index below is specific to torch 2.2.0 built for CUDA 11.8 (see URL);
# it should match the torch build actually installed in the kernel.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
!pip install -q torch-geometric
# fair-esm provides the `esm` package (pretrained ESM protein language models) imported below.
!pip install fair-esm


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 3.9 MB/s eta 0:00:00
Installing collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [None]:
import numpy as np
import pandas as pd 
from Bio.Align import substitution_matrices
from rdkit import Chem
from tqdm import tqdm
from joblib import Parallel, delayed
import pickle
import esm
import torch




In [None]:
def one_hot_encode(value, valid_values):
    """Return a boolean one-hot list encoding ``value`` over ``valid_values``.

    A value that is not present in ``valid_values`` is encoded as the last
    entry, which therefore acts as the catch-all "other" bucket.
    """
    target = value if value in valid_values else valid_values[-1]
    return list(map(lambda candidate: target == candidate, valid_values))


def get_atom_features(atom):
    """Encode an RDKit atom as a flat boolean numpy feature vector (78 entries).

    Layout, in order:
      * one-hot element symbol over a fixed 44-entry vocabulary ('X' is the
        catch-all for any other element),
      * one-hot degree (0-10),
      * one-hot total hydrogen count (0-10),
      * one-hot implicit valence (0-10),
      * a single aromaticity flag.
    Values outside a vocabulary fall into its last bucket (``one_hot_encode``).
    """
    symbol_vocab = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
        'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn',
        'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au',
        'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'X'
    ]
    # Degree, hydrogen count and implicit valence all share the 0..10 vocabulary.
    small_ints = list(range(11))

    # NOTE(review): newer RDKit releases deprecate GetImplicitValence — confirm
    # against the RDKit version this notebook pins.
    categorical = (
        (atom.GetSymbol(), symbol_vocab),
        (atom.GetDegree(), small_ints),
        (atom.GetTotalNumHs(), small_ints),
        (atom.GetImplicitValence(), small_ints),
    )

    encoded = []
    for raw_value, vocab in categorical:
        encoded.extend(one_hot_encode(raw_value, vocab))
    encoded.append(atom.GetIsAromatic())

    return np.array(encoded)


In [None]:
def smile_graph(smile):
    """Convert a SMILES string into a molecular graph.

    Args:
        smile: SMILES string, parsed with ``Chem.MolFromSmiles``.

    Returns:
        Tuple ``(mol_size, nodes, edges, edges_type)``:
          * mol_size: number of atoms in the parsed molecule.
          * nodes: one ``get_atom_features`` array per atom, in the order
            ``mol.GetAtoms()`` yields them.
          * edges: ``[src, dst]`` atom-index pairs; every bond is stored in both
            directions.
          * edges_type: bond order (``GetBondTypeAsDouble``) for each entry of
            ``edges``.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None instead of raising;
        # fail here with the offending input rather than with an AttributeError below.
        raise ValueError(f"RDKit could not parse SMILES: {smile!r}")

    nodes = [get_atom_features(atom) for atom in mol.GetAtoms()]

    edges = []
    edges_type = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_type = bond.GetBondTypeAsDouble()
        # Molecular graphs are undirected: add each bond in both directions.
        edges.extend(([start, end], [end, start]))
        edges_type.extend((bond_type, bond_type))

    return mol.GetNumAtoms(), nodes, edges, edges_type
    
   
    

In [None]:
def esm_model(model, alphabet, seq):
    """Predict the residue-residue contact map of ``seq`` with an ESM model.

    Args:
        model: ESM model; called as ``model(tokens, return_contacts=True)`` and
            expected to return a dict containing ``"contacts"``.
        alphabet: the model's alphabet, used for its batch converter.
        seq: amino-acid sequence.

    Returns:
        ``results["contacts"]`` moved to CPU, shape [1, L, L] (presumably
        L == len(seq) because the ESM contact head drops BOS/EOS — confirm for
        other model families).
    """
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein", seq)])

    # Put the tokens on the same device as the model so that moving the model
    # to a GPU does not cause a device-mismatch error.
    batch_tokens = batch_tokens.to(next(model.parameters()).device)

    with torch.no_grad():
        # Only the contact head's output is used, so no hidden-layer
        # representations are requested.
        results = model(batch_tokens, return_contacts=True)

    return results["contacts"].cpu()

In [None]:
def split_sequence(seq, window_size=1000, stride=500):
    """Cut ``seq`` into windows of at most ``window_size`` residues every ``stride`` positions.

    Returns a list of ``(offset, fragment)`` tuples, where ``offset`` is the
    start index of ``fragment`` in ``seq``. Generation stops as soon as a window
    reaches the end of the sequence, or when a fragment would be shorter than
    two residues.
    """
    def _windows():
        n = len(seq)
        for offset in range(0, n, stride):
            stop = min(offset + window_size, n)
            if stop - offset < 2:  # fragment too short to be useful
                return
            yield offset, seq[offset:stop]
            if stop == n:  # this window already covers the tail
                return

    return list(_windows())

In [None]:
def protein_graph(model, alphabet, seq, threshold=0.5, window_size=1000, stride=500):
    """Build a residue-level graph for ``seq`` from ESM contact predictions.

    The sequence is split into (possibly overlapping) windows by
    ``split_sequence``. Each window is scored with ``esm_model``. Every residue
    pair whose contact probability exceeds ``threshold`` becomes a directed edge.

    Args:
        model, alphabet: ESM model and alphabet, passed through to ``esm_model``.
        seq: amino-acid sequence (one-letter codes).
        threshold: edges are kept when probability > threshold.
        window_size, stride: windowing parameters for ``split_sequence``.

    Returns:
        node_features: float tensor [len(seq), 20], one-hot over
            "ACDEFGHIKLMNPQRSTVWY".
        edge_index: list of ``[i, j]`` global residue-index pairs, each pair
            appearing once.
        edge_attr: contact probability for each entry of ``edge_index``.
    """
    aa_dict = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}

    # NOTE: characters outside the 20 standard letters (e.g. X, B, Z, U or
    # lowercase) map to index 0 ('A'); this keeps the feature width at 20.
    aa_index = torch.tensor([aa_dict.get(aa, 0) for aa in seq], dtype=torch.long)
    node_features = torch.eye(20)[aa_index]  # [L, 20]

    # With stride < window_size the windows overlap, so the same residue pair can
    # be predicted by several windows. Keep one edge per pair, with the highest
    # probability any window assigned to it (insertion order = first sighting).
    merged = {}
    for start_idx, subseq in split_sequence(seq, window_size, stride):
        L_win = len(subseq)
        contact_map = esm_model(model, alphabet, subseq)[0][:L_win, :L_win]

        # Vectorised thresholding (row-major order) instead of L_win**2 .item() calls.
        pairs = torch.nonzero(contact_map > threshold)
        probs = contact_map[pairs[:, 0], pairs[:, 1]].tolist()

        for (i, j), prob in zip(pairs.tolist(), probs):
            key = (start_idx + i, start_idx + j)
            if key not in merged or prob > merged[key]:
                merged[key] = prob

    edge_index = [[i, j] for i, j in merged]
    edge_attr = list(merged.values())
    return node_features, edge_index, edge_attr

In [None]:
# Read the CSV file
# Load the virus–drug interaction table and drop its first column
# (presumably a saved row index — TODO confirm against the CSV header).
CSV_PATH = "/kaggle/input/virus-drug/virus_drug_interactions.csv"
df = pd.read_csv(CSV_PATH).pipe(lambda frame: frame.drop(columns=frame.columns[0]))
# Preview the first rows.
print(df.head())

                                    Protein_Sequence  \
0  PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGK...   
1  MTMDEQQSQAVAPVYVGGFLARYDQSPDEAELLLPRDVVEHWLHAQ...   
2  PQVTLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...   
3  PQVTLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...   
4  PQVTLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKM...   

                                              SMILES     pIC50  
0                       S=C(NCN1CCOCC1)Nc1ccc(Br)cn1  5.000000  
1                  CC(=O)O[C@@H]1CC(=O)N1C(=O)NC(C)C  4.000000  
2  CCC(C)[C@H](NC(=O)[C@@H]1CCCN1[P@@](=O)(OC)[C@...  7.522879  
3  CCC(C)[C@H](NC(=O)[C@@H]1CCCN1[P@@](=O)(OC)[C@...  7.031517  
4  COC(=O)N[C@H](C(=O)N[C@@H](Cc1ccccc1)C(O)CN(Cc...  7.376751  


In [None]:
drugs = df['SMILES']
# Iterate the Series itself (not enumerate(...)): tqdm can then read its length
# and show a percentage/ETA. The enumerate index was never used.
drug_graphs = [smile_graph(smiles) for smiles in tqdm(drugs, desc="Processing drugs")]
# Sanity check: number of graphs and tuple arity of one graph.
print(len(drug_graphs), len(drug_graphs[0]))



19451it [00:15, 1279.32it/s]

19451 4





In [None]:
# Pretrained ESM-1b (650M parameters) together with its contact-regression head.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
model.eval()  # inference mode (disables dropout)
protein_sequences = df['Protein_Sequence']
# Deduplicate in first-appearance order. The order of list(set(...)) changes
# between interpreter runs because str hashing is randomised, so
# drop_duplicates keeps processing reproducible.
unique_sequences = protein_sequences.drop_duplicates().tolist()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt" to /root/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S-contact-regression.pt


In [None]:
def compute_graph(protein):
    """Return ``protein_graph`` for one sequence using the globally loaded ESM model/alphabet."""
    return protein_graph(model, alphabet, protein)


# Run sequentially. PyTorch typically parallelises each ESM forward pass across
# CPU cores on its own. joblib's default process-based workers would each need
# their own copy of the 650M-parameter model. Also, wrapping the input generator
# in tqdm only tracked task dispatch, not completion.
unique_graphs = [
    compute_graph(protein)
    for protein in tqdm(unique_sequences, desc="Processing proteins")
]



Processing proteins:   0%|          | 0/72 [00:00<?, ?it/s][A[A

Processing proteins:  11%|█         | 8/72 [00:18<02:26,  2.28s/it][A[A

In [None]:
# Map each distinct sequence to its graph, then expand back to one graph per dataframe row.
sequence_to_graph = {sequence: graph for sequence, graph in zip(unique_sequences, unique_graphs)}
protein_graphs = list(map(sequence_to_graph.__getitem__, protein_sequences))
print(len(protein_graphs), len(protein_graphs[0]))

In [None]:
# Serialise both graph lists to the working directory with pickle
# (only unpickle these files from trusted sources).
for out_path, payload in (("drug_graphs.pkl", drug_graphs), ("protein_graphs.pkl", protein_graphs)):
    with open(out_path, "wb") as handle:
        pickle.dump(payload, handle)