# Imports

In [1]:
import pandas as pd
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, HeteroConv
from torch_geometric.transforms import RandomLinkSplit
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np
import torch.nn.functional as F
from torch_geometric.utils import negative_sampling
import networkx as nx
from torch.nn import BatchNorm1d
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from transformers import T5Tokenizer, T5Model
import requests

  from .autonotebook import tqdm as notebook_tqdm


# Step 1 -> Load Dataset

In [2]:
df = pd.read_csv('new_chembl_inhibit_drug_target_1.csv')  # Replace with your CSV path
df.drop_duplicates(inplace=True)

# Step 2 -> Construct the graph

In [3]:
# Set up the device (use CUDA if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# Initialize HeteroData object
data = HeteroData()

# Add compound nodes
compounds = df['compound_chembl_id'].unique()
compound_map = {compound: i for i, compound in enumerate(compounds)}

# Add target nodes
targets = df['target_uniprot_id'].unique()
target_map = {target: i for i, target in enumerate(targets)}

# Add edges (compound <-> target) ensuring undirected edges
compound_indices = np.array(df['compound_chembl_id'].map(compound_map).values)
target_indices = np.array(df['target_uniprot_id'].map(target_map).values)

# Convert to PyTorch tensors and move to GPU
edge_index = torch.tensor(np.vstack([compound_indices, target_indices]), dtype=torch.long).to(device)
edge_index_rev = torch.tensor(np.vstack([target_indices, compound_indices]), dtype=torch.long).to(device)

# Assign undirected edges to the graph
data['target', 'interacts', 'compound'].edge_index = edge_index_rev
data['compound', 'interacts', 'target'].edge_index = edge_index

# Load PPI data
cancer_ppi_data = pd.read_csv('cancer_ppi_combined.csv')

# Add Protein Nodes
existing_proteins = set(target_map.keys())

# Filter only valid PPI interactions
cancer_ppi_data_filtered = cancer_ppi_data[
    (cancer_ppi_data['node1_uniprot_id'].isin(existing_proteins)) | 
    (cancer_ppi_data['node2_uniprot_id'].isin(existing_proteins))
]

# Map PPI edges to indices
ppi_edges = []
len_target_map = len(target_map)

for _, row in cancer_ppi_data_filtered.iterrows():
    node1, node2 = row['node1_uniprot_id'], row['node2_uniprot_id']
    
    # Assign indices if not already mapped
    if node1 not in target_map:
        target_map[node1] = len_target_map
        len_target_map += 1
    if node2 not in target_map:
        target_map[node2] = len_target_map
        len_target_map += 1

    # Append undirected edges
    ppi_edges.append([target_map[node1], target_map[node2]])
    ppi_edges.append([target_map[node2], target_map[node1]])  # Add reverse edge

# Convert PPI edges to tensor
ppi_edge_index = torch.tensor(ppi_edges, dtype=torch.long).T.to(device)

# Assign undirected PPI edges to the graph
data['target', 'interacts', 'target'].edge_index = ppi_edge_index

data['compound'].x = torch.load('compound_features.pt')

import pickle

embeddings_dict_merged = pickle.load(open("embeddings_dict.pkl", "rb"))

# Add protein features
target_features = []

# Load target features in order of the target_map value (index)
for idx in range(len(target_map)):
    protein = [protein for protein, i in target_map.items() if i == idx][0]
    if protein in embeddings_dict_merged:
        target_features.append(embeddings_dict_merged[protein])
    else:
        target_features.append(np.zeros(1024))

data['target'].x = torch.tensor(target_features, dtype=torch.float).to(device)

  data['target'].x = torch.tensor(target_features, dtype=torch.float).to(device)


In [18]:
# Reconstruct targets from the target_map
targets = [protein for protein, _ in sorted(target_map.items(), key=lambda x: x[1])]

In [19]:
print(len(targets))

2123


In [5]:
print(len(embeddings_dict_merged))

2123


In [5]:
print(cancer_ppi_data_filtered)

      node1_uniprot_id node2_uniprot_id
113             Q6GPI1           P09683
114             Q6GPI1           P01298
201             P60709           Q99835
221             P60709           Q9NYK1
222             P60709           Q9NZQ7
...                ...              ...
39749           Q99759           P00533
39761           P46109           P00533
39788           Q9Y2R2           P00533
39799           Q15047           Q9Y6K1
39801           Q9Y6K1           Q15047

[3570 rows x 2 columns]


In [23]:
print(len(df['compound_chembl_id']))

70902


In [6]:
print(len(target_map))
print(len(compound_map))

2123
46660


In [26]:
print(len(compounds))

46660


In [23]:
print(target_map)

{'CHEMBL340': 0, 'CHEMBL4302': 1, 'CHEMBL2046258': 2, 'CHEMBL2835': 3, 'CHEMBL2363062': 4, 'CHEMBL3632452': 5, 'CHEMBL4523999': 6, 'CHEMBL4296661': 7, 'CHEMBL4005': 8, 'CHEMBL2842': 9, 'CHEMBL1163125': 10, 'CHEMBL2599': 11, 'CHEMBL4523988': 12, 'CHEMBL1937': 13, 'CHEMBL1829': 14, 'CHEMBL325': 15, 'CHEMBL333': 16, 'CHEMBL321': 17, 'CHEMBL2973': 18, 'CHEMBL3231': 19, 'CHEMBL4835': 20, 'CHEMBL4940': 21, 'CHEMBL2364162': 22, 'CHEMBL1947': 23, 'CHEMBL1163101': 24, 'CHEMBL5145': 25, 'CHEMBL2111432': 26, 'CHEMBL3130': 27, 'CHEMBL2292': 28, 'CHEMBL1865': 29, 'CHEMBL5432': 30, 'CHEMBL3399911': 31, 'CHEMBL267': 32, 'CHEMBL2002': 33, 'CHEMBL3024': 34, 'CHEMBL4523993': 35, 'CHEMBL6006': 36, 'CHEMBL3085620': 37, 'CHEMBL2362979': 38, 'CHEMBL1795184': 39, 'CHEMBL1973': 40, 'CHEMBL5608': 41, 'CHEMBL4898': 42, 'CHEMBL2815': 43, 'CHEMBL4072': 44, 'CHEMBL3105': 45, 'CHEMBL2111389': 46, 'CHEMBL3116': 47, 'CHEMBL4527': 48, 'CHEMBL1907611': 49, 'CHEMBL1871': 50, 'CHEMBL6166': 51, 'CHEMBL4895': 52, 'CHEMBL35

In [61]:
def fetch_smiles_set(chembl_ids):
    url = f"https://www.ebi.ac.uk/chembl/api/data/molecule/set/{';'.join(chembl_ids)}"
    
    # Add the 'Accept' header to explicitly request JSON format
    headers = {
        'Accept': 'application/json',
    }

    try:
        # Set a timeout of 10 seconds (you can adjust this as needed)
        response = requests.get(url, headers=headers, timeout=10)
        
        if response.status_code == 200:
            try:
                data = response.json()  # Parse JSON response
                smiles_dict = {}
                for molecule in data.get('molecules', []):
                    smiles = molecule.get('molecule_structures', {}).get('canonical_smiles', None)
                    chembl_id = molecule['molecule_chembl_id']
                    smiles_dict[chembl_id] = smiles
                return smiles_dict
            except:
                print("Error")
                return {}
        else:
            print(f"Error: Received status code {response.status_code}")
            return {}
    except requests.exceptions.ConnectTimeout:
        print("Connection timed out. Please check your network or try again later.")
        return {}
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return {}

def fetch_all_smiles(df, batch_size=100):
    unique_chembl_ids = df['compound_chembl_id'].unique()
    smiles_dict = {}

    # Fetch SMILES in batches using the /set/ endpoint
    for i in range(0, len(unique_chembl_ids), batch_size):
        batch_ids = unique_chembl_ids[i:i+batch_size]
        batch_smiles = fetch_smiles_set(batch_ids)
        smiles_dict.update(batch_smiles)
        print(f"Batch number: {i/batch_size}")

    return smiles_dict

# Example usage
smiles_dict = fetch_all_smiles(df, batch_size=250)

df['compound_smiles'] = df['compound_chembl_id'].map(smiles_dict)

Error
Batch number: 0.0
Batch number: 1.0
Batch number: 2.0
Batch number: 3.0
Batch number: 4.0
Batch number: 5.0
Batch number: 6.0
Batch number: 7.0
Batch number: 8.0
Error
Batch number: 9.0
Batch number: 10.0
Batch number: 11.0
Error
Batch number: 12.0
Batch number: 13.0
Batch number: 14.0
Batch number: 15.0
Batch number: 16.0
Batch number: 17.0
Batch number: 18.0
Batch number: 19.0
Batch number: 20.0
Batch number: 21.0
Batch number: 22.0
Batch number: 23.0
Batch number: 24.0
Error
Batch number: 25.0
Batch number: 26.0
Batch number: 27.0
Error
Batch number: 28.0
Batch number: 29.0
Batch number: 30.0
Error
Batch number: 31.0
Batch number: 32.0
Batch number: 33.0
Error
Batch number: 34.0
Batch number: 35.0
Batch number: 36.0
Error
Batch number: 37.0
Batch number: 38.0
Batch number: 39.0
Batch number: 40.0
Batch number: 41.0
Batch number: 42.0
Batch number: 43.0
Error
Batch number: 44.0
Batch number: 45.0
Batch number: 46.0
Batch number: 47.0
Batch number: 48.0
Error
Batch number: 49.0


In [62]:
df.to_csv('new_chembl_inhibit_drug_target_1.csv', index=False)

In [93]:
nan_smiles_df = df[df['compound_smiles'].isna()]
print(len(nan_smiles_df))

missing_smiles_dict = fetch_all_smiles(nan_smiles_df, batch_size=1)

75
Batch number: 0.0
Error
Batch number: 1.0
Error
Batch number: 2.0
Batch number: 3.0
Batch number: 4.0
Error
Batch number: 5.0
Error
Batch number: 6.0
Batch number: 7.0
Batch number: 8.0
Error
Batch number: 9.0
Error
Batch number: 10.0
Batch number: 11.0
Batch number: 12.0
Error
Batch number: 13.0
Batch number: 14.0
Error
Batch number: 15.0
Error
Batch number: 16.0
Batch number: 17.0
Error
Batch number: 18.0
Batch number: 19.0
Error
Batch number: 20.0
Batch number: 21.0
Error
Batch number: 22.0
Batch number: 23.0
Batch number: 24.0
Error
Batch number: 25.0
Batch number: 26.0
Error
Batch number: 27.0
Batch number: 28.0
Error
Batch number: 29.0
Error
Batch number: 30.0
Batch number: 31.0
Error
Batch number: 32.0
Batch number: 33.0
Error
Batch number: 34.0
Batch number: 35.0
Batch number: 36.0
Error
Batch number: 37.0
Error
Batch number: 38.0
Batch number: 39.0
Batch number: 40.0
Error
Batch number: 41.0
Error
Batch number: 42.0
Batch number: 43.0
Error
Batch number: 44.0
Error
Batch nu

In [94]:
def fill_missing_smiles(row):
    if pd.isna(row['compound_smiles']):  # Check if the current value is NaN
        chembl_id = row['compound_chembl_id']  # Get the ChEMBL ID from the row
        return missing_smiles_dict.get(chembl_id, row['compound_smiles'])  # Get the SMILES from the dict or keep NaN if not found
    else:
        return row['compound_smiles']  # Return the existing SMILES if it's not NaN

# Apply the function to fill missing SMILES
df['compound_smiles'] = df.apply(fill_missing_smiles, axis=1)

In [95]:
nan_smiles_df = df[df['compound_smiles'].isna()]
print(len(nan_smiles_df))

37


In [96]:
df.to_csv('new_chembl_inhibit_drug_target_1.csv', index=False)

In [4]:
import math

# Function to extract molecular features using RDKit
def extract_compound_features(smiles_list):
    features = []
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=1024)
    for smiles in smiles_list:
        if isinstance(smiles, float) and math.isnan(smiles):
            features.append([0] * 1024)
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            # Example: Using molecular weight and LogP as features
            fingerprint = generator.GetFingerprint(mol)
            fingerprint_list = list(fingerprint)  # Convert to list if needed
            features.append(fingerprint_list)
        else:
            features.append([0] * 1024)  # Handle missing molecule cases
    return torch.tensor(features, dtype=torch.float)

# Get the list of smiles from compound_map and df
smiles_list = [df[df['compound_chembl_id'] == compound].iloc[0]['compound_smiles'] for compound in compounds]

# Extract compound features
compound_features = extract_compound_features(smiles_list)

# Add compound features to the graph
data['compound'].x = compound_features.to(device)

KeyboardInterrupt: 

In [28]:
print(data['compound'].x.shape)

torch.Size([46660, 1024])


In [29]:
# Store data['compound'].x
torch.save(data['compound'].x, 'compound_features.pt')

In [4]:
# Load data['compound'].x
data['compound'].x = torch.load('compound_features.pt')

In [6]:
print(data['compound'].x)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]], device='cuda:0')


In [8]:
def fetch_uniprot_set(chembl_ids):
    url = f"https://www.ebi.ac.uk/chembl/api/data/target/set/{';'.join(chembl_ids)}"
    
    # Add the 'Accept' header to explicitly request JSON format
    headers = {
        'Accept': 'application/json',
    }

    try:
        # Set a timeout of 10 seconds (you can adjust this as needed)
        response = requests.get(url, headers=headers, timeout=10)
        
        if response.status_code == 200:
            try:
                data = response.json()  # Parse JSON response
                uniprot_dict = {}
                for target in data.get('targets', []):
                    target_components = target.get('target_components', [])
                    for component in target_components:
                        for xref in component.get('target_component_xrefs', []):
                            if xref.get('xref_src_db') == 'UniProt':
                                chembl_id = target['target_chembl_id']
                                uniprot_id = xref.get('xref_id')
                                uniprot_dict[chembl_id] = uniprot_id
                return uniprot_dict
            except Exception as e:
                print(f"Error processing response: {e}")
                return {}
        else:
            print(f"Error: Received status code {response.status_code}")
            return {}
    except requests.exceptions.ConnectTimeout:
        print("Connection timed out. Please check your network or try again later.")
        return {}
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return {}

# Function to fetch all UniProt IDs in batches, similar to fetch_all_smiles
def fetch_all_uniprots(df, batch_size=100):
    unique_chembl_ids = df['target_chembl_id'].unique()  # Assuming 'target_chembl_id' column exists
    uniprot_dict = {}

    # Fetch UniProt IDs in batches using the /target/set/ endpoint
    for i in range(0, len(unique_chembl_ids), batch_size):
        batch_ids = unique_chembl_ids[i:i+batch_size]
        batch_uniprots = fetch_uniprot_set(batch_ids)
        uniprot_dict.update(batch_uniprots)
        print(f"Processed batch number: {i // batch_size + 1}")

    return uniprot_dict

# Example usage
uniprot_dict = fetch_all_uniprots(df, batch_size=250)

# Map the fetched UniProt IDs back to the original DataFrame
df['target_uniprot_id'] = df['target_chembl_id'].map(uniprot_dict)

Processed batch number: 1
Processed batch number: 2
Processed batch number: 3
Processed batch number: 4
Processed batch number: 5


In [12]:
df.to_csv('new_chembl_inhibit_drug_target_1.csv', index=False)

In [4]:
# Load the Prot5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
print("Tokenizer loaded")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Tokenizer loaded


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
model = T5Model.from_pretrained("Rostlab/prot_t5_xl_uniref50").half().to(device)
print("Model loaded")

Model loaded


In [12]:
# Function to fetch protein sequence from UniProt
def fetch_uniprot_sequence(uniprot_id):
    url = f"https://www.uniprot.org/uniprot/{uniprot_id}.fasta"
    response = requests.get(url)
    if response.status_code == 200:
        sequence = ''.join(response.text.splitlines()[1:])  # Skip the FASTA header
        return sequence
    else:
        print(f"Error fetching sequence for {uniprot_id}")
        return None

# Function to generate embeddings for a list of protein sequences
def get_protein_embeddings(uniprot_ids):
    embeddings_dict = {}  # Dictionary to cache embeddings
    num = 0

    for uniprot_id in uniprot_ids:
        num += 1
        if num%20 == 0:
            print(f"Processing uniprot id : {uniprot_id} number {num}")

        try:
            sequence = fetch_uniprot_sequence(uniprot_id)  # Fetch the protein sequence
            if sequence:
                # Tokenize the input sequence
                inputs = tokenizer(sequence, return_tensors="pt", padding=True).to(device)

                # Add decoder_input_ids: initialize with the pad token id
                decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]]).to(device)

                # Forward pass with encoder input and decoder input
                with torch.no_grad():
                    outputs = model(input_ids=inputs['input_ids'], decoder_input_ids=decoder_input_ids)

                # Extract embeddings (e.g., from encoder output or mean pooling)
                embedding = outputs.last_hidden_state.mean(dim=1)  # Mean pooling for simplicity
                embeddings_dict[uniprot_id] = embedding.squeeze(0).cpu().numpy()  # Cache the embedding
            else:
                embeddings_dict[uniprot_id] = [0] * 1024  # Default zero embedding for missing sequences
        except:
            print(f"Error processing uniprot id : {uniprot_id} number {num}")

    return embeddings_dict


# Generate the embeddings for the unique UniProt IDs
unique_uniprot_ids = protein_map.keys()
print(len(unique_uniprot_ids))
embeddings_dict = get_protein_embeddings(unique_uniprot_ids)

2123
Processing uniprot id : Q59GZ4 number 20
Processing uniprot id : Q9BZX1 number 40
Processing uniprot id : Q96RG3 number 60
Processing uniprot id : Q8N7V3 number 80
Processing uniprot id : Q9UMG5 number 100
Processing uniprot id : Q9BVY6 number 120
Processing uniprot id : Q9Y618 number 140
Processing uniprot id : Q9NYL2 number 160
Processing uniprot id : Q5T489 number 180
Processing uniprot id : Q9UCC0 number 200
Processing uniprot id : Q9HAR0 number 220
Processing uniprot id : Q9UHD2 number 240
Processing uniprot id : Q9P0B8 number 260
Processing uniprot id : Q9Y265 number 280
Processing uniprot id : Q2M1P8 number 300
Processing uniprot id : Q9UQ96 number 320
Processing uniprot id : Q99957 number 340
Processing uniprot id : Q969U4 number 360
Processing uniprot id : Q5T7T8 number 380
Processing uniprot id : Q96KA8 number 400
Processing uniprot id : Q9HD26 number 420
Processing uniprot id : Q96CY8 number 440
Processing uniprot id : Q9Y616 number 460
Processing uniprot id : Q68DZ3 nu

In [13]:
missing_ids = [m_id for m_id in unique_uniprot_ids if m_id not in embeddings_dict]

In [14]:
print(missing_ids)

[]


In [21]:
new_dict = get_protein_embeddings(missing_ids)
print(new_dict)

{'Q9UQ95': array([ 0.1744,  0.2167, -0.1216, ...,  0.1583, -0.0855, -0.5703],
      dtype=float16), 'Q9NQ14': array([ 0.1744,  0.2167, -0.1216, ...,  0.1583, -0.0855, -0.5703],
      dtype=float16)}


In [32]:
# embeddings_dict_merged = {**embeddings_dict, **new_dict}
# print(len(embeddings_dict_merged))

NameError: name 'embeddings_dict' is not defined

In [15]:
import pickle

with open("embeddings_dict.pkl", "wb") as f:
    pickle.dump(embeddings_dict, f)

In [5]:
import pickle

embeddings_dict_merged = pickle.load(open("embeddings_dict.pkl", "rb"))

In [9]:
print(len(embeddings_dict_merged))

2123


In [6]:
# Create a list of embeddings that matches the size of the original dataset
target_features = [embeddings_dict_merged[uniprot_id] for uniprot_id in df['target_uniprot_id']]

data['target'].x = torch.tensor(target_features, dtype=torch.float).to(device)

In [7]:
# Check the new feature shapes
print("Compound node features shape", data['compound'].x.shape)
print("Target node features shape", data['target'].x.shape)

# Check shapes to verify
print("Reverse edge index shape:", data['compound', 'interacts', 'target'].edge_index.shape)
print("Edge index shape:", data['target', 'interacts', 'compound'].edge_index.shape)
print("PPI edge index shape:", data['target', 'interacts', 'target'].edge_index.shape)

Compound node features shape torch.Size([46660, 1024])
Target node features shape torch.Size([70902, 1024])
Reverse edge index shape: torch.Size([2, 70902])
Edge index shape: torch.Size([2, 70902])
PPI edge index shape: torch.Size([2, 7140])


In [8]:
# Apply RandomLinkSplit to generate train, validation, and test sets
transform = RandomLinkSplit(
    num_val=0.1,  # 10% validation edges
    num_test=0.2,  # 20% test edges
    is_undirected=False,  # Keep it directed, set to True for undirected
    add_negative_train_samples=True,  # Generate negative samples for training
    edge_types=('target', 'interacts', 'compound'),  # Specify the edge type
    rev_edge_types=('compound', 'interacts', 'target')  # Specify the reverse edge type
)

# Perform train, val, and test split and move the splits to the GPU
train_data, val_data, test_data = transform(data)

# Move split datasets to the same device
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)

In [22]:
print(len(train_data))
print(len(val_data))
print(len(test_data))

4
4
4


In [23]:
print(train_data)

HeteroData(
  compound={ x=[46660, 1024] },
  target={ x=[70902, 1024] },
  (target, interacts, compound)={
    edge_index=[2, 49632],
    edge_label=[99264],
    edge_label_index=[2, 99264],
  },
  (compound, interacts, target)={ edge_index=[2, 49632] },
  (target, interacts, target)={ edge_index=[2, 7140] }
)


# Define GNN Model

In [13]:
from torch_geometric.nn import GATv2Conv

In [10]:
# Define a simple MLP for link prediction
class LinkPredictor(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LinkPredictor, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)  # Output is a single scalar score
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))  # Output between 0 and 1 (probability)

class GATModel(torch.nn.Module):
    def __init__(self, hidden_channels, heads=1):
        super(GATModel, self).__init__()
        self.attn_weights = {}  # To store attention weights

        self.conv1 = HeteroConv({
            ('target', 'interacts', 'compound'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('compound', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('target', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
        }, aggr='mean')

        # Batch normalization after first layer
        self.batchnorm1_compound = BatchNorm1d(hidden_channels * heads)
        self.batchnorm1_target = BatchNorm1d(hidden_channels * heads)

        self.conv2 = HeteroConv({
            ('target', 'interacts', 'compound'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('compound', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('target', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
        }, aggr='mean')

        # Batch normalization after second layer
        self.batchnorm2_compound = BatchNorm1d(hidden_channels * heads)
        self.batchnorm2_target = BatchNorm1d(hidden_channels * heads)

        # Define the MLP for link prediction
        input_dim = hidden_channels * heads * 2  # Because we will concatenate embeddings of two nodes
        self.link_predictor = LinkPredictor(input_dim, hidden_dim=hidden_channels)

    def forward(self, x_dict, edge_index_dict, dir=True):
        self.attn_weights = {}  # Clear attention weights at each forward pass
        
        # Forward pass through the first HeteroConv layer
        x_dict, attn1 = self._apply_conv_and_extract_attention(self.conv1, x_dict, edge_index_dict)

        x_dict['compound'] = F.relu(self.batchnorm1_compound(x_dict['compound']))
        x_dict['target'] = F.relu(self.batchnorm1_target(x_dict['target']))

        # Forward pass through the second HeteroConv layer
        x_dict, attn2 = self._apply_conv_and_extract_attention(self.conv2, x_dict, edge_index_dict)

        x_dict['compound'] = self.batchnorm2_compound(x_dict['compound'])
        x_dict['target'] = self.batchnorm2_target(x_dict['target'])

        if dir:
            edge_index = edge_index_dict[('target', 'interacts', 'compound')]
            source_x = x_dict['target'][edge_index[0]]  # Source node embeddings (compounds)
            target_x = x_dict['compound'][edge_index[1]]    # Target node embeddings (targets)
        else:
            edge_index = edge_index_dict[('compound', 'interacts', 'target')]
            source_x = x_dict['compound'][edge_index[0]]
            target_x = x_dict['target'][edge_index[1]]

        # Concatenate the embeddings of source and target nodes
        x_concat = torch.cat([source_x, target_x], dim=-1).to(device)  # Concatenate along the feature dimension
        
        # Pass the concatenated embeddings through the neural network
        return self.link_predictor(x_concat)  # Output a probability or score for the link

    def _apply_conv_and_extract_attention(self, conv, x_dict, edge_index_dict):
        # Initialize attention storage for this convolution layer
        attn_weights = {}

        # Iterate over each edge type in the heterogeneous graph
        for edge_type, gat_conv in conv.convs.items():
            edge_index = edge_index_dict[edge_type]

            # Perform GATConv and retrieve attention weights
            x_dict[edge_type[2]], attn = gat_conv((x_dict[edge_type[0]], x_dict[edge_type[2]]), edge_index, return_attention_weights=True)
            attn_weights[edge_type] = attn[1]  # Store attention weights

        # Save attention weights for this layer
        self.attn_weights.update(attn_weights)
        return x_dict, attn_weights


# Train

In [11]:
model = GATModel(
    hidden_channels=64,
    heads=8
)

# Check for available device (GPU or CPU)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
model = model.to(device)

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

def train(data, device):
    model.train()

    optimizer.zero_grad()

    # Move the data to the GPU (if it's not already on the same device)
    data = data.to(device)

    # Positive edges (compound -> target)
    pos_link_pred = model(data.x_dict, data.edge_index_dict).squeeze()

    # Initialize the neg_edge_index_dict
    neg_edge_index_dict = {}

    # Iterate through edge types in edge_index_dict
    for edge_type, edge_index in data.edge_index_dict.items():
        source_type, _, target_type = edge_type

        # If edge type is 'target -> compound', generate negative samples
        if edge_type == ('target', 'interacts', 'compound'):
            neg_edge_index = negative_sampling(
                edge_index=edge_index.to(device),
                num_nodes=(data[source_type].num_nodes, data[target_type].num_nodes),
                num_neg_samples=edge_index.size(1)  # 1:1 ratio of positive to negative samples
            )
            neg_edge_index_dict[edge_type] = neg_edge_index

        # If edge type is 'target -> target', use the original edges (no sampling)
        elif edge_type == ('target', 'interacts', 'target'):
            neg_edge_index_dict[edge_type] = edge_index.to(device)
        
        else:
            neg_edge_index_rev = negative_sampling(
                edge_index=edge_index.to(device),
                num_nodes=(data[source_type].num_nodes, data[target_type].num_nodes),
                num_neg_samples=edge_index.size(1)  # 1:1 ratio of positive to negative samples
            )
            neg_edge_index_dict[edge_type] = neg_edge_index_rev

    neg_link_pred = model(data.x_dict, neg_edge_index_dict).squeeze()

    # Combine positive and negative samples for the loss
    link_preds = torch.cat([pos_link_pred, neg_link_pred], dim=0)
    link_labels = torch.cat([
        torch.ones(pos_link_pred.size(0)).to(device),  # Positive samples on the same device
        torch.zeros(neg_link_pred.size(0)).to(device)  # Negative samples on the same device
    ], dim=0)

    # Compute loss
    loss = criterion(link_preds, link_labels)
    loss.backward()
    optimizer.step()

    return loss.item()


# Test

In [13]:
def test(data, edge_label_index, edge_label, device):
    model.eval()
    
    # Move data to GPU if available
    edge_label_index = edge_label_index.to(device)
    edge_label = edge_label.to(device)

    with torch.no_grad():
        # Positive edges for test set
        pos_link_logits = model(data.x_dict, data.edge_index_dict).squeeze()

        neg_edge_index_dict = {}

        # Iterate through edge types in edge_index_dict
        for edge_type, edge_index in data.edge_index_dict.items():
            source_type, _, target_type = edge_type

            # If edge type is 'target -> compound', generate negative samples
            if edge_type == ('target', 'interacts', 'compound'):
                neg_edge_index = negative_sampling(
                    edge_index=edge_index.to(device),
                    num_nodes=(data[source_type].num_nodes, data[target_type].num_nodes),
                    num_neg_samples=edge_index.size(1)  # 1:1 ratio of positive to negative samples
                )
                neg_edge_index_dict[edge_type] = neg_edge_index

            # If edge type is 'target -> target', use the original edges (no sampling)
            elif edge_type == ('target', 'interacts', 'target'):
                neg_edge_index_dict[edge_type] = edge_index.to(device)
            
            else:
                neg_edge_index_rev = negative_sampling(
                    edge_index=edge_index.to(device),
                    num_nodes=(data[source_type].num_nodes, data[target_type].num_nodes),
                    num_neg_samples=edge_index.size(1)  # 1:1 ratio of positive to negative samples
                )
                neg_edge_index_dict[edge_type] = neg_edge_index_rev

        neg_link_logits = model(data.x_dict, neg_edge_index_dict, dir=False).squeeze()

        all_link_logits = torch.cat([pos_link_logits, neg_link_logits], dim=0)

        all_link_labels = torch.cat([
            torch.ones(pos_link_logits.size(0), device=device),
            torch.zeros(neg_link_logits.size(0), device=device)
        ], dim=0)

        # Compute probabilities and predictions
        link_probs = all_link_logits.cpu().numpy()
        link_labels = all_link_labels.cpu().numpy()
        
        # Compute AUC
        auc = roc_auc_score(link_labels, link_probs)
        
        # Convert probabilities to binary predictions
        link_pred = (link_probs > 0.5).astype(int)
        
        # Compute accuracy
        acc = accuracy_score(link_labels, link_pred)
    
    return auc, acc


# Driver Code

In [14]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [31]:
print(data.edge_index_dict)

{('target', 'interacts', 'compound'): tensor([[    0,     1,     1,  ...,   119,    53,    53],
        [    0,     1,     2,  ..., 46657, 46658, 46659]]), ('compound', 'interacts', 'target'): tensor([[    0,     1,     2,  ..., 46657, 46658, 46659],
        [    0,     1,     1,  ...,   119,    53,    53]]), ('target', 'interacts', 'target'): tensor([[ 912,  919,  912,  ...,  731,  731, 2122],
        [ 919,  912,  920,  ..., 2122, 2122,  731]])}


In [17]:
for edge_type, edge_index in data.edge_index_dict.items():
    max_source = edge_index[0].max().item()
    max_target = edge_index[1].max().item()
    print(f"Edge type: {edge_type}, Max Source Index: {max_source}, Max Target Index: {max_target}")

Edge type: ('target', 'interacts', 'compound'), Max Source Index: 1022, Max Target Index: 46659
Edge type: ('compound', 'interacts', 'target'), Max Source Index: 46659, Max Target Index: 1022
Edge type: ('target', 'interacts', 'target'), Max Source Index: 2122, Max Target Index: 2122


In [15]:
from sklearn.model_selection import KFold
import torch

# Define number of epochs and folds
epochs = 50
k_folds = 5

# Extract edges for prediction
edge_label_index = train_data['target', 'interacts', 'compound'].edge_label_index
edge_labels = train_data['target', 'interacts', 'compound'].edge_label

# Convert to NumPy for easier indexing
edge_indices = edge_label_index.t().cpu().numpy()  # Shape [N, 2]
edge_labels = edge_labels.cpu().numpy()  # Shape [N]

kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(edge_indices), 1):
    print(f'\nFold {fold}/{k_folds}')

    # Create train and validation subsets
    train_edges = torch.tensor(edge_indices[train_idx]).T.to(device)
    train_labels = torch.tensor(edge_labels[train_idx]).to(device)
    
    val_edges = torch.tensor(edge_indices[val_idx]).T.to(device)
    val_labels = torch.tensor(edge_labels[val_idx]).to(device)

    # Create a copy of train_data and update the edge_label sets
    train_subset = train_data.clone()
    train_subset['target', 'interacts', 'compound'].edge_label_index = train_edges
    train_subset['target', 'interacts', 'compound'].edge_label = train_labels

    val_subset = train_data.clone()
    val_subset['target', 'interacts', 'compound'].edge_label_index = val_edges
    val_subset['target', 'interacts', 'compound'].edge_label = val_labels

    # Training loop
    for epoch in range(1, epochs + 1):
        loss = train(train_subset, device)
        
        if epoch % 10 == 0:
            val_auc, val_acc = test(val_subset, val_subset['target', 'interacts', 'compound'].edge_label_index,
                                    val_subset['target', 'interacts', 'compound'].edge_label, device)
            print(f'Fold {fold}, Epoch {epoch}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val Acc: {val_acc:.4f}')

    # Evaluate on test data for this fold
    test_auc, test_acc = test(test_data, test_data['target', 'interacts', 'compound'].edge_label_index,
                              test_data['target', 'interacts', 'compound'].edge_label, device)
    
    print(f'Fold {fold} Test AUC: {test_auc:.4f}, Test Accuracy: {test_acc:.4f}')
    fold_results.append((test_auc, test_acc))

# Compute average results
avg_auc = sum([res[0] for res in fold_results]) / k_folds
avg_acc = sum([res[1] for res in fold_results]) / k_folds

print(f'\nAverage Test AUC: {avg_auc:.4f}, Average Test Accuracy: {avg_acc:.4f}')



Fold 1/5
Fold 1, Epoch 10, Loss: 0.2265, Val AUC: 0.1523, Val Acc: 0.2547
Fold 1, Epoch 20, Loss: 0.0203, Val AUC: 0.1520, Val Acc: 0.2496
Fold 1, Epoch 30, Loss: 0.1469, Val AUC: 0.0702, Val Acc: 0.2068
Fold 1, Epoch 40, Loss: 0.0115, Val AUC: 0.0701, Val Acc: 0.2089
Fold 1, Epoch 50, Loss: 0.0220, Val AUC: 0.0683, Val Acc: 0.2059
Fold 1 Test AUC: 0.0722, Test Accuracy: 0.1812

Fold 2/5
Fold 2, Epoch 10, Loss: 0.0092, Val AUC: 0.0665, Val Acc: 0.2054
Fold 2, Epoch 20, Loss: 0.0051, Val AUC: 0.0664, Val Acc: 0.2078
Fold 2, Epoch 30, Loss: 0.0185, Val AUC: 0.0548, Val Acc: 0.2002
Fold 2, Epoch 40, Loss: 0.0000, Val AUC: 0.0214, Val Acc: 0.1873
Fold 2, Epoch 50, Loss: 0.0031, Val AUC: 0.1058, Val Acc: 0.2304
Fold 2 Test AUC: 0.1135, Test Accuracy: 0.2097

Fold 3/5
Fold 3, Epoch 10, Loss: 0.0012, Val AUC: 0.1060, Val Acc: 0.2259
Fold 3, Epoch 20, Loss: 0.0020, Val AUC: 0.1035, Val Acc: 0.2265
Fold 3, Epoch 30, Loss: 0.0000, Val AUC: 0.1013, Val Acc: 0.2264
Fold 3, Epoch 40, Loss: 0.0000,

In [16]:
print("Target node count:", data['target'].num_nodes)
print("Compound node count:", data['compound'].num_nodes)

Target node count: 70902
Compound node count: 46660


In [20]:
# Retrieve attention weights for 'target' -> 'target' edges
edge_type = ('target', 'interacts', 'target')
attn_weights = model.attn_weights[edge_type]  # Shape: [num_edges, heads]
mean_attn_weights = attn_weights.mean(dim=-1)  # Average across attention heads

# Select a specific interaction to explain using CHEMBL IDs
drug_chembl_id = 'CHEMBL4752635'  # Example CHEMBL ID for drug
protein_uniprot_id = 'P14780'  # Example Protein ID for protein

# Get the node indices for the drug and protein
compound_index = compound_map[drug_chembl_id]
protein_idx = target_map[protein_uniprot_id]

# Get the edge index for the protein -> protein interaction
edge_index = data['target', 'interacts', 'target'].edge_index

# Get the attention weights and corresponding proteins for protein -> protein interactions involving the target protein
attn_weights, protein_idx = mean_attn_weights[edge_index[1] == protein_idx], edge_index[0][edge_index[1] == protein_idx]

# Get the top 5 proteins with the highest attention weights
top_k = 5
top_k_indices = protein_idx[attn_weights.argsort(descending=True)][:top_k]

# Map the indices back to the protein uniprot IDs
target_map_inv = {v: k for k, v in target_map.items()}
top_k_proteins = [target_map_inv[i.item()] for i in top_k_indices]

print(f"Top 5 proteins interacting with the protein-drug pair {protein_uniprot_id} <-> {drug_chembl_id} :")

for i, protein in enumerate(top_k_proteins):
    print(f"{i + 1}. {protein}")

Top 5 proteins interacting with the protein-drug pair Q6GPI1 <-> CHEMBL4112929 :
1. P15144
2. P06307
3. P01375
4. P01011
5. Q8TF68


In [21]:
print(len(targets))

2123


# Variational Graph Autoencoder

In [9]:
from torch_geometric.nn import VGAE
from torch_geometric.utils import train_test_split_edges

In [10]:
# Define Encoder for VGAE
class HeteroEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, heads=1):
        super(HeteroEncoder, self).__init__()

        self.conv1 = HeteroConv({
            ('target', 'interacts', 'compound'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('compound', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('target', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
        }, aggr='mean')

        self.conv_mu = HeteroConv({
            ('target', 'interacts', 'compound'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('compound', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('target', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
        }, aggr='mean')

        self.conv_logstd = HeteroConv({
            ('target', 'interacts', 'compound'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('compound', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False),
            ('target', 'interacts', 'target'): GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
        }, aggr='mean')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}  # Apply ReLU activation
        
        mu_dict = self.conv_mu(x_dict, edge_index_dict)  # Mean
        logstd_dict = self.conv_logstd(x_dict, edge_index_dict)  # Log Standard Deviation 

        return mu_dict, logstd_dict

class HeteroVGAE(VGAE):
    def encode(self, x_dict, edge_index_dict):
        mu_dict, logstd_dict = self.encoder(x_dict, edge_index_dict)

        # Concatenate features of all node types into single tensors
        mu = torch.cat([mu_dict[key] for key in mu_dict], dim=0)
        logstd = torch.cat([logstd_dict[key] for key in logstd_dict], dim=0)

        # Clamp logstd to prevent numerical instability
        logstd = logstd.clamp(min=-10, max=10)

        z = self.reparametrize(mu, logstd)

        return z, mu, logstd

In [14]:
# VGAE Model
model = HeteroVGAE(HeteroEncoder(hidden_channels=64, heads=8)).to(device)

# Define Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [14]:
# Training Function
def train(train_data):
    model.train()
    optimizer.zero_grad()
    z, mu, logstd = model.encode(train_data.x_dict, train_data.edge_index_dict)
    
    # Initialize loss
    loss = 0

    for edge_type, edge_index in train_data.edge_index_dict.items():
        loss += model.recon_loss(z, edge_index) + (1 / (train_data[edge_type[0]].num_nodes + train_data[edge_type[2]].num_nodes)) * model.kl_loss(mu, logstd)
        
    loss.backward()
    optimizer.step()
    return loss.item()

# K-Fold Cross Validation
k_folds = 5
epochs = 50

kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(kf.split(targets), 1):
    print(f'\nFold {fold}/{k_folds}')

    # Create train and validation subsets
    train_data_fold = train_data.clone()
    val_data_fold = val_data.clone()

    # Update edge_label_index and edge_label for train and val subsets
    for edge_type, edge_index in train_data_fold.edge_index_dict.items():
        train_data_fold[edge_type].edge_label_index = edge_index[:, train_idx]
        train_data_fold[edge_type].edge_label = edge_index[:, train_idx]

        val_data_fold[edge_type].edge_label_index = edge_index[:, val_idx]
        val_data_fold[edge_type].edge_label = edge_index[:, val_idx]

    for epoch in range(1, epochs + 1):
        loss = train()
        
        if epoch % 10 == 0:
            print(f'Fold {fold}, Epoch {epoch}, Loss: {loss:.4f}')

Epoch 0, Loss: 54.587955
Epoch 20, Loss: 31.094702
Epoch 40, Loss: 27.135923
Epoch 60, Loss: 23.926014
Epoch 80, Loss: 21.270927
Epoch 100, Loss: 20.115210
Epoch 120, Loss: 19.330423
Epoch 140, Loss: 18.822206
Epoch 160, Loss: 18.393713
Epoch 180, Loss: 18.253078
Epoch 200, Loss: 18.065180


In [31]:
# Save the model
torch.save(model.state_dict(), 'hetero_vgae_model.pth')

In [15]:
# Load the model
model.load_state_dict(torch.load('hetero_vgae_model.pth'))

<All keys matched successfully>

In [17]:
import gc

gc.collect()

torch.cuda.empty_cache()

In [26]:
# Create all possible pairs (cartesian product)

target_nodes = torch.arange(test_data['target'].num_nodes).to(device)
compound_nodes = torch.arange(test_data['compound'].num_nodes).to(device)

# possible_edges = torch.cartesian_prod(target_nodes, compound_nodes).t().contiguous().to(device)

model.eval()
with torch.no_grad():
    z, _, _ = model.encode(test_data.x_dict, test_data.edge_index_dict)
    print(z)
    new_edges = model.decode(z, test_data.edge_index_dict['target', 'interacts', 'compound'])

# Convert Predicted Edges to a List
print(len(test_data.edge_index_dict[('target', 'interacts', 'compound')][0]))
predicted_edges = new_edges.cpu().numpy()
print(predicted_edges, len(predicted_edges))
threshold = 0.5
new_interactions = (predicted_edges > threshold).astype(int)

# Compare Old and New Graph
print(f"Original Edges: {len(test_data['target', 'interacts', 'compound'].edge_label)}")
print(f"Newly Predicted Interactions: {sum(new_interactions)}")

tensor([[-0.2029,  0.3498,  0.2685,  ..., -0.2318, -0.2427, -0.3927],
        [-0.2453,  0.5501,  0.5172,  ..., -0.5384, -0.5338, -1.0964],
        [-0.2453,  0.5501,  0.5172,  ..., -0.5384, -0.5338, -1.0964],
        ...,
        [ 0.0220, -0.0260, -0.0344,  ...,  0.0141,  0.0285,  0.0295],
        [ 0.0220, -0.0260, -0.0344,  ...,  0.0141,  0.0285,  0.0295],
        [ 0.0220, -0.0260, -0.0344,  ...,  0.0141,  0.0285,  0.0295]],
       device='cuda:0')
56722
[1. 1. 1. ... 1. 1. 1.] 56722
Original Edges: 28360
Newly Predicted Interactions: 56722
