In [1]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Dataset, HeteroData, DataLoader
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.Polypeptide import three_to_index, index_to_one
import os
from tqdm import tqdm
import numpy as np
import random
from torch_geometric.nn import SAGEConv, global_mean_pool
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score

# Set random seed and device
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import os
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, GCNConv, Linear, global_mean_pool
from torch.utils.data import random_split
from rdkit import Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.Polypeptide import three_to_index, index_to_one
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from tqdm import tqdm

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Helper functions
def residue_name_to_idx(res_name_one):
    amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G',
                   'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S',
                   'T', 'W', 'Y', 'V']
    if res_name_one in amino_acids:
        return amino_acids.index(res_name_one)
    else:
        return len(amino_acids)  # Unknown amino acid

def process_protein(pdb_file, threshold=5.0):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    amino_acids = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if is_aa(residue):
                    amino_acids.append(residue)

    amino_acid_types = [index_to_one(three_to_index(residue.get_resname())) for residue in amino_acids]
    unique_amino_acids = list(set(amino_acid_types))

    data = HeteroData()

    node_features = {}
    node_positions = {}
    node_counter = 0

    # Initialize node features and positions
    for aa_type in unique_amino_acids:
        node_features[aa_type] = []
        node_positions[aa_type] = []

    for idx, (residue, aa_type) in enumerate(zip(amino_acids, amino_acid_types)):
        try:
            ca_atom = residue['CA']
            pos = ca_atom.get_coord()
        except KeyError:
            pos = [0.0, 0.0, 0.0]
        node_features[aa_type].append([residue_name_to_idx(aa_type)])
        node_positions[aa_type].append(pos)
        node_counter += 1

    for aa_type in unique_amino_acids:
        data[aa_type].x = torch.tensor(node_features[aa_type], dtype=torch.float)
        data[aa_type].pos = torch.tensor(np.array(node_positions[aa_type]), dtype=torch.float)

    # Build edges based on proximity
    contact_edge_index = {}
    edge_types = set()
    reverse_edge_types = set()

    # Mapping from global index to local index within node type
    global_to_local_idx = {}
    current_idx = {aa_type: 0 for aa_type in unique_amino_acids}

    for aa_type in amino_acid_types:
        global_to_local_idx[aa_type] = {}

    for idx, aa_type in enumerate(amino_acid_types):
        global_idx = idx
        local_idx = current_idx[aa_type]
        global_to_local_idx[aa_type][global_idx] = local_idx
        current_idx[aa_type] += 1

    num_residues = len(amino_acids)
    for i in range(num_residues):
        residue_i = amino_acids[i]
        aa_i = amino_acid_types[i]
        try:
            ca_i = residue_i['CA']
            pos_i = ca_i.get_coord()
        except KeyError:
            continue
        for j in range(i + 1, num_residues):  # Ensure j > i to avoid duplicates
            residue_j = amino_acids[j]
            aa_j = amino_acid_types[j]
            try:
                ca_j = residue_j['CA']
                pos_j = ca_j.get_coord()
            except KeyError:
                continue

            distance = np.linalg.norm(pos_i - pos_j)
            if distance <= threshold:
                # Define edge type in consistent order
                if aa_i <= aa_j:
                    edge_type = (aa_i, 'contact', aa_j)
                    src_aa, tgt_aa = aa_i, aa_j
                    src_global, tgt_global = i, j
                else:
                    edge_type = (aa_j, 'contact', aa_i)
                    src_aa, tgt_aa = aa_j, aa_i
                    src_global, tgt_global = j, i

                # Initialize edge list if not present
                if edge_type not in contact_edge_index:
                    contact_edge_index[edge_type] = []

                # Get local indices within their respective node types
                src_local = global_to_local_idx[src_aa][src_global]
                tgt_local = global_to_local_idx[tgt_aa][tgt_global]

                # Append edge
                contact_edge_index[edge_type].append([src_local, tgt_local])
                edge_types.add(edge_type)

    # Assign edges to HeteroData
    for edge_type, edges in contact_edge_index.items():
        if len(edges) > 0:
            # Extract the original source and target types
            src_type, relation, tgt_type = edge_type
    
            # Create reverse edge type
            reverse_edge_type = (tgt_type, relation, src_type)
            reverse_edge_types.add(reverse_edge_type)
    
            # Convert original edges to tensor
            edge_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
            # Assign original edges to original edge type
            data[edge_type].edge_index = edge_tensor
    
            # Create reverse edges
            reverse_edges = [[tgt, src] for src, tgt in edges]
            reverse_edge_tensor = torch.tensor(reverse_edges, dtype=torch.long).t().contiguous()
    
            # Assign reverse edges to reverse edge type
            data[reverse_edge_type].edge_index = reverse_edge_tensor

    data.node_types = set(unique_amino_acids)
    data.edge_types = edge_types
    data.reverse_edge_types = reverse_edge_types

    return data

In [3]:
from tqdm import tqdm
from joblib import Parallel, delayed
import pandas as pd

def collect_protein_node_and_edge_types(protein_graphs):
    protein_node_types = set()
    protein_edge_types = set()
    for protein_data in protein_graphs.values():
        protein_node_types.update(protein_data.node_types)
        protein_edge_types.update(protein_data.edge_types)
        protein_edge_types.update(protein_data.reverse_edge_types)
    return sorted(protein_node_types), sorted(protein_edge_types)

In [4]:
# Load the dataset
df = pd.read_parquet('filtered_train.parquet')
import json

# Stratified splitting
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df['binds'],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['binds'],
    random_state=42
)

print("\nSplit completed.")
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")

# Process and store protein graphs
protein_graphs = {}
protein_pdb_files = {
    'BRD4': './BRD4.pdb',
    'HSA': './ALB.pdb',
    'sEH': './EPH.pdb'
}

for protein_name, pdb_file in protein_pdb_files.items():
    if os.path.exists(pdb_file):
        protein_data = process_protein(pdb_file)
        protein_graphs[protein_name] = protein_data
    else:
        print(f"PDB file {pdb_file} for {protein_name} does not exist.")

# Load the unique atom and edge types from the JSON file
with open('unique_atom_and_edge_types.json', 'r') as f:
    unique_types = json.load(f)

# Extract molecule node and edge types
molecule_node_types = unique_types['molecule_node_types']
molecule_edge_types = [tuple(edge) for edge in unique_types['molecule_edge_types']]

# Now molecule_node_types and molecule_edge_types can be used in your code
print("Collected molecule node and edge types successfully.")

print("Collecting protein node and edge types...")
protein_node_types, protein_edge_types = collect_protein_node_and_edge_types(protein_graphs)


Split completed.
Training set size: 2169734
Validation set size: 464943
Test set size: 464943
Collected molecule node and edge types successfully.
Collecting protein node and edge types...


In [5]:
import torch.nn.functional as F

class CrossAttentionLayer(torch.nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super(CrossAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        self.W_Q = torch.nn.Linear(hidden_dim, hidden_dim)
        self.W_K = torch.nn.Linear(hidden_dim, hidden_dim)
        self.W_V = torch.nn.Linear(hidden_dim, hidden_dim)

        self.scale = self.head_dim ** 0.5

    def forward(self, query_nodes, key_nodes):
        # query_nodes: [N_q, hidden_dim]
        # key_nodes: [N_k, hidden_dim]

        N_q = query_nodes.size(0)
        N_k = key_nodes.size(0)

        Q = self.W_Q(query_nodes)  # [N_q, hidden_dim]
        K = self.W_K(key_nodes)    # [N_k, hidden_dim]
        V = self.W_V(key_nodes)    # [N_k, hidden_dim]

        # Reshape for multi-head attention
        Q = Q.view(N_q, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_q, head_dim]
        K = K.view(N_k, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_k, head_dim]
        V = V.view(N_k, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_k, head_dim]

        # Compute attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [num_heads, N_q, N_k]
        attn_weights = torch.softmax(attn_scores, dim=-1)                # [num_heads, N_q, N_k]

        # Compute attended values
        out = torch.matmul(attn_weights, V)  # [num_heads, N_q, head_dim]
        out = out.permute(1, 0, 2).contiguous().view(N_q, -1)  # [N_q, hidden_dim]

        return out

class CrossGraphAttentionModel(torch.nn.Module):
    def __init__(self, hidden_dim=64, num_attention_heads=4):
        super(CrossGraphAttentionModel, self).__init__()

        # print(molecule_edge_types)

        # Molecule GNN Encoder
        self.mol_conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in molecule_edge_types
        }, aggr='mean')

        self.mol_conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in molecule_edge_types
        }, aggr='mean')

        # Protein GNN Encoder
        self.prot_conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in protein_edge_types
        }, aggr='mean')

        self.prot_conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in protein_edge_types
        }, aggr='mean')

        # Cross-Attention Layers
        self.cross_attn_mol_to_prot = CrossAttentionLayer(hidden_dim, num_attention_heads)
        self.cross_attn_prot_to_mol = CrossAttentionLayer(hidden_dim, num_attention_heads)

        # Fully Connected Layers
        self.fc1 = Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = Linear(hidden_dim, 1)

    def forward(self, mol_data, prot_data):
        # Molecule GNN Encoding
        x_mol_dict = mol_data.x_dict
        edge_index_mol_dict = mol_data.edge_index_dict

        x_mol_dict = self.mol_conv1(x_mol_dict, edge_index_mol_dict)
        x_mol_dict = {key: F.relu(x) for key, x in x_mol_dict.items()}

        x_mol_dict = self.mol_conv2(x_mol_dict, edge_index_mol_dict)
        x_mol_dict = {key: F.relu(x) for key, x in x_mol_dict.items()}

        # Concatenate molecule node embeddings
        mol_node_embeddings = []
        for nt in molecule_node_types:
            if nt in x_mol_dict:
                mol_node_embeddings.append(x_mol_dict[nt])
        H_mol = torch.cat(mol_node_embeddings, dim=0)

        # Protein GNN Encoding
        x_prot_dict = prot_data.x_dict
        edge_index_prot_dict = prot_data.edge_index_dict

        x_prot_dict = self.prot_conv1(x_prot_dict, edge_index_prot_dict)
        x_prot_dict = {key: F.relu(x) for key, x in x_prot_dict.items()}

        x_prot_dict = self.prot_conv2(x_prot_dict, edge_index_prot_dict)
        x_prot_dict = {key: F.relu(x) for key, x in x_prot_dict.items()}

        # Concatenate protein node embeddings
        prot_node_embeddings = []
        for nt in protein_node_types:
            if nt in x_prot_dict:
                prot_node_embeddings.append(x_prot_dict[nt])
        H_prot = torch.cat(prot_node_embeddings, dim=0)

        # Cross-Attention
        H_mol_attn = self.cross_attn_mol_to_prot(H_mol, H_prot)
        H_prot_attn = self.cross_attn_prot_to_mol(H_prot, H_mol)

        # Combine original and attended embeddings
        H_mol_combined = H_mol + H_mol_attn
        H_prot_combined = H_prot + H_prot_attn

        # # Global Pooling
        # mol_batch = mol_data.batch if hasattr(mol_data, 'batch') else torch.zeros(H_mol_combined.size(0), dtype=torch.long, device=H_mol_combined.device)
        # prot_batch = prot_data.batch if hasattr(prot_data, 'batch') else torch.zeros(H_prot_combined.size(0), dtype=torch.long, device=H_prot_combined.device)

        # Global Pooling
        # Use batch_dict to get batch information per node type
        mol_batches = torch.cat([mol_data.batch_dict[nt] for nt in molecule_node_types if nt in mol_data.batch_dict])
        prot_batches = torch.cat([prot_data.batch_dict[nt] for nt in protein_node_types if nt in prot_data.batch_dict])

        # z_mol = global_mean_pool(H_mol_combined, mol_batch)
        # z_prot = global_mean_pool(H_prot_combined, prot_batch)

        z_mol = global_mean_pool(H_mol_combined, mol_batches)
        z_prot = global_mean_pool(H_prot_combined, prot_batches)

        # Joint Representation
        z_joint = torch.cat([z_mol, z_prot], dim=1)

        # Prediction
        x = F.relu(self.fc1(z_joint))
        out = torch.sigmoid(self.fc2(x))

        return out.squeeze()

In [None]:
from datasets import CombinedDataset, MoleculeDataset

# Create datasets
train_dataset = CombinedDataset(train_df, protein_graphs)
val_dataset = CombinedDataset(val_df, protein_graphs)
test_dataset = CombinedDataset(test_df, protein_graphs)

In [None]:
from torch_geometric.loader import DataLoader as GeoDataLoader

# Custom collate function
def collate_fn(batch):
    mol_batch = [item[0] for item in batch if item is not None and item[0] is not None and item[0]['invalid'] is False]
    prot_batch = [item[1] for item in batch if item is not None and item[0] is not None and item[0]['invalid'] is False]

    mol_batch = Batch.from_data_list(mol_batch)
    prot_batch = Batch.from_data_list(prot_batch)

    return mol_batch, prot_batch

# Create data loaders
train_loader = GeoDataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=14)
val_loader = GeoDataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=14)
test_loader = GeoDataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=14)

In [None]:
def print_node_and_edge_info(data):
    print("\nNode and Edge Information:")
    
    # Iterate over all node types and print their counts
    for node_type in data.node_types:
        if node_type == 'smolecule':  # Ignore smolecule
            continue
        
        num_nodes = data[node_type].x.size(0)
        print(f"Node type: {node_type}, Number of nodes: {num_nodes}")
    
    # Iterate over all edge types and print relevant information
    for edge_type, edge_index in data.edge_index_dict.items():
        if 'smolecule' in edge_type:  # Ignore any edges involving smolecule
            continue
        
        src_type, _, tgt_type = edge_type
        max_src_idx = edge_index[0].max().item()
        max_tgt_idx = edge_index[1].max().item()

        num_src_nodes = data[src_type].x.size(0)
        num_tgt_nodes = data[tgt_type].x.size(0)

        print(f"Edge type: {edge_type}, Edge index shape: {edge_index.shape}")
        print(f"Max index in edge_index: src = {max_src_idx}, tgt = {max_tgt_idx}")
        print(f"Num src nodes ({src_type}): {num_src_nodes}, Num tgt nodes ({tgt_type}): {num_tgt_nodes}")

        # Validation check to identify if there are out-of-bound indices
        if max_src_idx >= num_src_nodes or max_tgt_idx >= num_tgt_nodes:
            print(f"Warning: Edge indices out of bounds for edge type {edge_type}")
            print(f"  Max src index: {max_src_idx} (Num src nodes: {num_src_nodes})")
            print(f"  Max tgt index: {max_tgt_idx} (Num tgt nodes: {num_tgt_nodes})")



In [None]:
# Instantiate the model
import warnings

warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`")

hidden_dim = 64
num_attention_heads = 4

model = CrossGraphAttentionModel(
    hidden_dim=hidden_dim,
    num_attention_heads=num_attention_heads
)
model = model.to(device)

# Loss function and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training and evaluation functions
def train_epoch():
    model.train()
    total_loss = 0
    for mol_data, prot_data in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        mol_data = mol_data.to(device)
        prot_data = prot_data.to(device)
        # print_node_and_edge_info(mol_data)
        out = model(mol_data, prot_data)
        y = mol_data['smolecule'].y.to(device)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate(loader, mode='Validation'):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for mol_data, prot_data in tqdm(loader, desc=mode):
            mol_data = mol_data.to(device)
            prot_data = prot_data.to(device)
            out = model(mol_data, prot_data)
            y = mol_data['smolecule'].y.to(device)
            loss = criterion(out, y)
            total_loss += loss.item()
    return total_loss / len(loader)

# Training loop
num_epochs = 5

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss = evaluate(val_loader, mode='Validation')
    print(f'Epoch: {epoch:02d}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

# Save the model weights
torch.save(model.state_dict(), 'cross_graph_attention_model.pth')

# Prediction and evaluation on test data
def predict(loader):
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for mol_data, prot_data in tqdm(loader, desc="Testing"):
            mol_data = mol_data.to(device)
            prot_data = prot_data.to(device)
            out = model(mol_data, prot_data)
            predictions.extend(out.cpu().numpy())
            true_labels.extend(mol_data['smolecule'].y.cpu().numpy())
    return predictions, true_labels

test_predictions, test_true = predict(test_loader)

# Apply a threshold to obtain binary predictions
threshold = 0.5
test_pred_binary = [1 if p >= threshold else 0 for p in test_predictions]

# Evaluate performance
accuracy = accuracy_score(test_true, test_pred_binary)
roc_auc = roc_auc_score(test_true, test_predictions)
precision = precision_score(test_true, test_pred_binary)
recall = recall_score(test_true, test_pred_binary)
f1 = f1_score(test_true, test_pred_binary)

print(f"\nTest Accuracy: {accuracy:.4f}")
print(f"Test ROC-AUC: {roc_auc:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1-Score: {f1:.4f}")

In [6]:
from torch_geometric.loader import DataLoader as GeoDataLoader
import warnings
from datasets import CombinedDataset, MoleculeDataset
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import csv

external_test_df = pd.read_parquet('cleaned_test.parquet')
# Specify the device
device = torch.device('cpu')  # Use 'cuda' if you have a GPU available

# Load the state dictionary
state_dict = torch.load('cross_graph_attention_model.pth', map_location=device, weights_only=True)
model = CrossGraphAttentionModel(hidden_dim=64, num_attention_heads=4)
state_dict = torch.load('cross_graph_attention_model.pth', map_location=device)
model.to(device)
model.load_state_dict(state_dict)
model.eval()

external_test_dataset = CombinedDataset(external_test_df, protein_graphs)

external_test_loader = GeoDataLoader(external_test_dataset, batch_size=16, shuffle=False, num_workers=14)

def external_predict(loader, output_csv_path):
    with open(output_csv_path, mode='w', newline='') as csv_file:
        fieldnames = ['id', 'binds']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        
        with torch.no_grad():
            for mol_data, prot_data in tqdm(loader, desc="Predicting"):
                node_types, _ = mol_data.metadata()
                if node_types[0] == 'dummy_data':
                    continue

                mol_data = mol_data.to(device)
                prot_data = prot_data.to(device)
                out = model(mol_data, prot_data)

                # Extract the 'id' from mol_data['smolecule']
                ids = mol_data['smolecule'].id.cpu().numpy()
                predictions = out.cpu().numpy()

                # Write the results directly to CSV
                for id_val, pred in zip(ids, predictions):
                    writer.writerow({'id': int(id_val), 'binds': float(pred)})

    print(f"Predictions have been written to {output_csv_path}")

external_predict(external_test_loader, output_csv_path='submissions.csv')

Predicting:   3%|█████▎                                                                                                                                                                | 3326/104681 [11:25<5:41:39,  4.94it/s]

Skipping index 51 due to error: Bad Conformer Id
Skipping index 51 due to error: Bad Conformer Id
Skipping index 51 due to error: Bad Conformer Id
Skipping index 54 due to error: Bad Conformer Id
Skipping index 54 due to error: Bad Conformer Id
Skipping index 54 due to error: Bad Conformer Id


Predicting:   3%|█████▎                                                                                                                                                                | 3326/104681 [11:26<5:48:46,  4.84it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 13 but got size 16 for tensor number 1 in the list.