In [None]:
import pyarrow.parquet as pq

parquet_file = pq.ParquetFile('train.parquet')
print(parquet_file.schema.names)


In [None]:
import dask.dataframe as dd

# Read train and test datasets
train_df = dd.read_parquet('train.parquet')
test_df = dd.read_parquet('test.parquet')

# Total number of rows in train dataset
total_rows = train_df.map_partitions(len).compute().sum()
print(f"Total number of rows: {total_rows}")

# Number of positive bindings
num_positive_bindings = train_df['binds'].sum().compute()
print(f"Number of positive bindings: {num_positive_bindings}")

# Number of negative bindings
num_negative_bindings = total_rows - num_positive_bindings
print(f"Number of negative bindings: {num_negative_bindings}")

# Percentage calculations
percent_positive = (num_positive_bindings / total_rows) * 100
percent_negative = (num_negative_bindings / total_rows) * 100
print(f"Percentage of positive bindings: {percent_positive:.2f}%")
print(f"Percentage of negative bindings: {percent_negative:.2f}%")

# Total unique proteins in train dataset
unique_proteins_train = train_df['protein_name'].dropna().unique().compute()
total_unique_proteins_train = len(unique_proteins_train)
print(f"Total unique proteins in train dataset: {total_unique_proteins_train}")

# Total unique proteins in test dataset
unique_proteins_test = test_df['protein_name'].dropna().unique().compute()
total_unique_proteins_test = len(unique_proteins_test)
print(f"Total unique proteins in test dataset: {total_unique_proteins_test}")

# Total unique proteins in both datasets
unique_proteins_all = dd.concat([
    train_df['protein_name'],
    test_df['protein_name']
]).dropna().unique().compute()
total_unique_proteins_all = len(unique_proteins_all)
print(f"Total unique proteins in both datasets: {total_unique_proteins_all}")

# Concatenate building block columns from both datasets
train_building_blocks = dd.concat([
    train_df['buildingblock1_smiles'],
    train_df['buildingblock2_smiles'],
    train_df['buildingblock3_smiles']
])

test_building_blocks = dd.concat([
    test_df['buildingblock1_smiles'],
    test_df['buildingblock2_smiles'],
    test_df['buildingblock3_smiles']
])

all_building_blocks = dd.concat([train_building_blocks, test_building_blocks])

# Compute unique building blocks
unique_building_blocks = all_building_blocks.dropna().unique().compute()
total_unique_building_blocks = len(unique_building_blocks)
print(f"Total unique building blocks (train and test): {total_unique_building_blocks}")

# Compute unique small molecules from train and test
train_small_molecules = train_df['molecule_smiles'].dropna()
test_small_molecules = test_df['molecule_smiles'].dropna()
all_small_molecules = dd.concat([train_small_molecules, test_small_molecules])

unique_small_molecules = all_small_molecules.unique().compute()
total_unique_small_molecules = len(unique_small_molecules)
print(f"Total unique small molecules (train and test): {total_unique_small_molecules}")

In [None]:
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import os
from tqdm import tqdm
from collections import defaultdict

# Paths to input and output Parquet files
input_file = 'train.parquet'
output_file = 'filtered_train.parquet'

# Remove output file if it exists
if os.path.exists(output_file):
    os.remove(output_file)

# Open the Parquet file
pf = pq.ParquetFile(input_file)

# Get total number of row groups (batches)
total_row_groups = pf.num_row_groups

# First Pass: Build mapping of molecule_smiles to the set of proteins it binds to
print("First Pass: Building molecule to proteins mapping...")

# Initialize a dictionary to store mappings
molecule_binds = defaultdict(set)
all_proteins = set()

for rg in tqdm(range(total_row_groups), desc="Processing Batches"):
    # Read a row group with necessary columns
    batch = pf.read_row_group(rg, columns=['molecule_smiles', 'protein_name', 'binds'])
    df = batch.to_pandas()
    
    # Update all_proteins set
    all_proteins.update(df['protein_name'].unique())
    
    # Filter rows where 'binds' == 1
    df_binds_1 = df[df['binds'] == 1]
    
    # Update molecule_binds mapping
    for idx, row in df_binds_1.iterrows():
        molecule = row['molecule_smiles']
        protein = row['protein_name']
        molecule_binds[molecule].add(protein)
    
    # Clear variables to free memory
    del df, df_binds_1, batch

# Convert all_proteins to a list
all_proteins = list(all_proteins)

# Second Pass: Process data and write to output
print("Second Pass: Filtering dataset and writing to new Parquet file...")

# Initialize Parquet writer
writer = None

for rg in tqdm(range(total_row_groups), desc="Processing Batches"):
    # Read the row group
    batch = pf.read_row_group(rg)
    df = batch.to_pandas()
    
    # Filter molecules that have at least one binds == 1
    df = df[df['molecule_smiles'].isin(molecule_binds.keys())]
    
    if not df.empty:
        # Prepare to select rows to include
        rows_to_include = []
        
        # Process each molecule in the batch
        for molecule, group in df.groupby('molecule_smiles'):
            binds_1_proteins = molecule_binds[molecule]
            num_binds_1 = len(binds_1_proteins)
            
            if num_binds_1 > 1:
                # Include all rows for this molecule
                rows_to_include.append(group)
            elif num_binds_1 == 1:
                # Include the positive binding row
                positive_row = group[group['binds'] == 1]
                
                # Include one negative binding row
                unbound_proteins = set(all_proteins) - binds_1_proteins
                # Select one unbound protein
                unbound_protein = unbound_proteins.pop()
                negative_row = group[(group['protein_name'] == unbound_protein) & (group['binds'] == 0)]
                
                # If negative_row is empty, skp it
                if  not negative_row.empty:
                    # Append the positive and negative rows
                    rows_to_include.append(pd.concat([positive_row, negative_row]))

        
        if rows_to_include:
            filtered_df = pd.concat(rows_to_include)
            
            # Convert to PyArrow Table
            table = pa.Table.from_pandas(filtered_df)
            
            # Initialize the Parquet writer if not already done
            if writer is None:
                writer = pq.ParquetWriter(output_file, table.schema)
            
            # Write the table to the Parquet file
            writer.write_table(table)
            
            # Clear variables to free memory
            del table, filtered_df
    
    # Clear variables to free memory
    del df, batch

# Close the Parquet writer
if writer is not None:
    writer.close()

print("Filtering completed. Filtered dataset saved to 'filtered_train.parquet'.")

In [None]:
import dask.dataframe as dd

# Read train and test datasets
train_df = dd.read_parquet('filtered_train.parquet')
test_df = dd.read_parquet('test.parquet')

# Total number of rows in train dataset
total_rows = train_df.map_partitions(len).compute().sum()
print(f"Total number of rows: {total_rows}")

# Number of positive bindings
num_positive_bindings = train_df['binds'].sum().compute()
print(f"Number of positive bindings: {num_positive_bindings}")

# Number of negative bindings
num_negative_bindings = total_rows - num_positive_bindings
print(f"Number of negative bindings: {num_negative_bindings}")

# Percentage calculations
percent_positive = (num_positive_bindings / total_rows) * 100
percent_negative = (num_negative_bindings / total_rows) * 100
print(f"Percentage of positive bindings: {percent_positive:.2f}%")
print(f"Percentage of negative bindings: {percent_negative:.2f}%")

# Total unique proteins in train dataset
unique_proteins_train = train_df['protein_name'].dropna().unique().compute()
total_unique_proteins_train = len(unique_proteins_train)
print(f"Total unique proteins in train dataset: {total_unique_proteins_train}")

# Total unique proteins in test dataset
unique_proteins_test = test_df['protein_name'].dropna().unique().compute()
total_unique_proteins_test = len(unique_proteins_test)
print(f"Total unique proteins in test dataset: {total_unique_proteins_test}")

# Total unique proteins in both datasets
unique_proteins_all = dd.concat([
    train_df['protein_name'],
    test_df['protein_name']
]).dropna().unique().compute()
total_unique_proteins_all = len(unique_proteins_all)
print(f"Total unique proteins in both datasets: {total_unique_proteins_all}")
print(f"Unique proteins in both datasets: {unique_proteins_all.values}")

# Concatenate building block columns from both datasets
train_building_blocks = dd.concat([
    train_df['buildingblock1_smiles'],
    train_df['buildingblock2_smiles'],
    train_df['buildingblock3_smiles']
])

test_building_blocks = dd.concat([
    test_df['buildingblock1_smiles'],
    test_df['buildingblock2_smiles'],
    test_df['buildingblock3_smiles']
])

all_building_blocks = dd.concat([train_building_blocks, test_building_blocks])

# Compute unique building blocks
unique_building_blocks = all_building_blocks.dropna().unique().compute()
total_unique_building_blocks = len(unique_building_blocks)
print(f"Total unique building blocks (train and test): {total_unique_building_blocks}")

# Compute unique small molecules from train and test
train_small_molecules = train_df['molecule_smiles'].dropna()
test_small_molecules = test_df['molecule_smiles'].dropna()
all_small_molecules = dd.concat([train_small_molecules, test_small_molecules])

unique_small_molecules = all_small_molecules.unique().compute()
total_unique_small_molecules = len(unique_small_molecules)
print(f"Total unique small molecules (train and test): {total_unique_small_molecules}")

In [None]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Dataset, DataLoader, Data
from torch_geometric.utils import from_networkx
import networkx as nx
import os
from tqdm import tqdm
import numpy as np

# Set random seed for reproducibility
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


In [None]:
# Load the dataset
# Assuming the dataset is in a Parquet file named 'filtered_train.parquet'
df = pd.read_parquet('filtered_train.parquet')

# Check the distribution of 'binds'
print("Distribution of 'binds' in the dataset:")
print(df['binds'].value_counts())

# Stratified splitting
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df['binds'],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['binds'],
    random_state=42
)


In [28]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Dataset, HeteroData, DataLoader
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.Polypeptide import three_to_index, index_to_one
import os
from tqdm import tqdm
import numpy as np
import random
from torch_geometric.nn import SAGEConv, global_mean_pool
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score

# Set random seed and device
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [29]:
import os
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, GCNConv, Linear, global_mean_pool
from torch.utils.data import random_split
from rdkit import Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.Polypeptide import three_to_index, index_to_one
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from tqdm import tqdm

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Helper functions
def residue_name_to_idx(res_name_one):
    amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G',
                   'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S',
                   'T', 'W', 'Y', 'V']
    if res_name_one in amino_acids:
        return amino_acids.index(res_name_one)
    else:
        return len(amino_acids)  # Unknown amino acid

def process_protein(pdb_file, threshold=5.0):
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    amino_acids = []
    for model in structure:
        for chain in model:
            for residue in chain:
                if is_aa(residue):
                    amino_acids.append(residue)

    amino_acid_types = [index_to_one(three_to_index(residue.get_resname())) for residue in amino_acids]
    unique_amino_acids = list(set(amino_acid_types))

    data = HeteroData()

    node_features = {}
    node_positions = {}
    node_counter = 0

    # Initialize node features and positions
    for aa_type in unique_amino_acids:
        node_features[aa_type] = []
        node_positions[aa_type] = []

    for idx, (residue, aa_type) in enumerate(zip(amino_acids, amino_acid_types)):
        try:
            ca_atom = residue['CA']
            pos = ca_atom.get_coord()
        except KeyError:
            pos = [0.0, 0.0, 0.0]
        node_features[aa_type].append([residue_name_to_idx(aa_type)])
        node_positions[aa_type].append(pos)
        node_counter += 1

    for aa_type in unique_amino_acids:
        data[aa_type].x = torch.tensor(node_features[aa_type], dtype=torch.float)
        data[aa_type].pos = torch.tensor(np.array(node_positions[aa_type]), dtype=torch.float)

    # Build edges based on proximity
    contact_edge_index = {}
    edge_types = set()

    # Mapping from global index to local index within node type
    global_to_local_idx = {}
    current_idx = {aa_type: 0 for aa_type in unique_amino_acids}

    for aa_type in amino_acid_types:
        global_to_local_idx[aa_type] = {}

    for idx, aa_type in enumerate(amino_acid_types):
        global_idx = idx
        local_idx = current_idx[aa_type]
        global_to_local_idx[aa_type][global_idx] = local_idx
        current_idx[aa_type] += 1

    num_residues = len(amino_acids)
    for i in range(num_residues):
        residue_i = amino_acids[i]
        aa_i = amino_acid_types[i]
        try:
            ca_i = residue_i['CA']
            pos_i = ca_i.get_coord()
        except KeyError:
            continue
        for j in range(i + 1, num_residues):  # Ensure j > i to avoid duplicates
            residue_j = amino_acids[j]
            aa_j = amino_acid_types[j]
            try:
                ca_j = residue_j['CA']
                pos_j = ca_j.get_coord()
            except KeyError:
                continue

            distance = np.linalg.norm(pos_i - pos_j)
            if distance <= threshold:
                # Define edge type in consistent order
                if aa_i <= aa_j:
                    edge_type = (aa_i, 'contact', aa_j)
                    src_aa, tgt_aa = aa_i, aa_j
                    src_global, tgt_global = i, j
                else:
                    edge_type = (aa_j, 'contact', aa_i)
                    src_aa, tgt_aa = aa_j, aa_i
                    src_global, tgt_global = j, i

                # Initialize edge list if not present
                if edge_type not in contact_edge_index:
                    contact_edge_index[edge_type] = []

                # Get local indices within their respective node types
                src_local = global_to_local_idx[src_aa][src_global]
                tgt_local = global_to_local_idx[tgt_aa][tgt_global]

                # Append edge
                contact_edge_index[edge_type].append([src_local, tgt_local])
                edge_types.add(edge_type)

    # Assign edges to HeteroData
    for edge_type, edges in contact_edge_index.items():
        if len(edges) > 0:
            # Since the graph is undirected, add reverse edges
            reverse_edges = [[tgt, src] for src, tgt in edges]
            all_edges = edges + reverse_edges
            edge_tensor = torch.tensor(all_edges, dtype=torch.long).t().contiguous()
            data[edge_type].edge_index = edge_tensor

    data.node_types = set(unique_amino_acids)
    data.edge_types = edge_types

    return data

In [3]:
class MoleculeDataset(Dataset):
    def __init__(self, dataframe, transform=None, pre_transform=None):
        self.dataframe = dataframe.reset_index(drop=True)
        super(MoleculeDataset, self).__init__(None, transform, pre_transform)

    def len(self):
        return len(self.dataframe)

    def get(self, idx):
        row = self.dataframe.iloc[idx]
        smiles = row['molecule_smiles']
        binds = row['binds']
        protein_name = row['protein_name']

        # Convert SMILES to molecular graph
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None  # Skip invalid SMILES

        mol = Chem.AddHs(mol)
        AllChem.EmbedMolecule(mol, randomSeed=42)

        atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()]
        unique_atom_types = list(set(atom_types))

        data = HeteroData()

        atom_type_to_indices = {atype: [] for atype in unique_atom_types}
        atom_features = []
        atom_positions = []
        for i, atom in enumerate(mol.GetAtoms()):
            atype = atom.GetSymbol()
            atom_type_to_indices[atype].append(i)
            atom_features.append(self.get_atom_features(atom))
            pos = mol.GetConformer().GetAtomPosition(i)
            atom_positions.append(np.array([pos.x, pos.y, pos.z], dtype=np.float32))

        # Assign node features and positions per atom type
        for atype in unique_atom_types:
            idx = atom_type_to_indices[atype]
            x = torch.tensor([atom_features[i] for i in idx], dtype=torch.float)
            pos = torch.tensor(np.array([atom_positions[i] for i in idx]), dtype=torch.float)
            data[atype].x = x
            data[atype].pos = pos

        # Precompute mapping from global atom index to local index within atom type
        atom_type_to_local_idx = {
            atype: {global_idx: local_idx for local_idx, global_idx in enumerate(idxs)}
            for atype, idxs in atom_type_to_indices.items()
        }

        # Assign bond edges to specific edge types based on atom types
        bond_edges = {}
        edge_types = set()
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            atype_i = atom_types[i]
            atype_j = atom_types[j]
            
            # Define edge type in consistent order
            if atype_i <= atype_j:
                edge_type = (atype_i, 'bond', atype_j)
                src_atype, tgt_atype = atype_i, atype_j
                src_idx, tgt_idx = i, j
            else:
                edge_type = (atype_j, 'bond', atype_i)
                src_atype, tgt_atype = atype_j, atype_i
                src_idx, tgt_idx = j, i
            edge_types.add(edge_type)
            if edge_type not in bond_edges:
                bond_edges[edge_type] = {'edge_index': [], 'edge_attr': []}

            # Retrieve local indices using precomputed mapping
            src_local = atom_type_to_local_idx[src_atype][src_idx]
            tgt_local = atom_type_to_local_idx[tgt_atype][tgt_idx]

            # Append both directions for undirected bonds
            bond_edges[edge_type]['edge_index'].append([src_local, tgt_local])
            bond_edges[edge_type]['edge_index'].append([tgt_local, src_local])

            # Append bond features for both directions
            bond_feature = self.get_bond_features(bond)
            bond_edges[edge_type]['edge_attr'].append(bond_feature)
            bond_edges[edge_type]['edge_attr'].append(bond_feature)

        # Assign bond edges to HeteroData
        for edge_type, attrs in bond_edges.items():
            edge_index = torch.tensor(attrs['edge_index'], dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(attrs['edge_attr'], dtype=torch.float)
            data[edge_type].edge_index = edge_index
            data[edge_type].edge_attr = edge_attr

        # Add binding label and metadata
        data['smolecule'].y = torch.tensor([binds], dtype=torch.float)
        data['smolecule'].smiles = smiles
        data['smolecule'].protein_name = protein_name

        data.node_types = set(unique_atom_types)
        data.edge_types = set(bond_edges.keys())

        return data

    @staticmethod
    def get_atom_features(atom):
        return [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization().real,
            int(atom.GetIsAromatic())
        ]

    @staticmethod
    def get_bond_features(bond):
        bond_type = bond.GetBondType()
        bond_dict = {
            Chem.rdchem.BondType.SINGLE: 0,
            Chem.rdchem.BondType.DOUBLE: 1,
            Chem.rdchem.BondType.TRIPLE: 2,
            Chem.rdchem.BondType.AROMATIC: 3
        }

        return [
            bond_dict.get(bond_type, -1)
        ]


In [8]:
class CombinedDataset(Dataset):
    def __init__(self, dataframe, protein_graphs, transform=None, pre_transform=None, cache_dir='./processed'):
        self.dataframe = dataframe.reset_index(drop=True)
        self.protein_graphs = protein_graphs
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        super(CombinedDataset, self).__init__(None, transform, pre_transform)

    def len(self):
        return len(self.dataframe)

    def get(self, idx):
        processed_file = os.path.join(self.cache_dir, f'data_{idx}.pt')
        if os.path.exists(processed_file):
            molecule_data, protein_data = torch.load(processed_file)
        else:
            row = self.dataframe.iloc[idx]
            smiles = row['molecule_smiles']
            binds = row['binds']
            protein_name = row['protein_name']

            mol_dataset = MoleculeDataset(pd.DataFrame([row]))
            molecule_data = mol_dataset.get(0)
            if molecule_data is None:
                return None

            protein_data = self.protein_graphs.get(protein_name, HeteroData())

            molecule_data.y = torch.tensor([binds], dtype=torch.float)
            molecule_data.smiles = smiles
            molecule_data.protein_name = protein_name

            torch.save((molecule_data, protein_data), processed_file)

        return molecule_data, protein_data

In [9]:
from tqdm import tqdm

def collect_molecule_node_and_edge_types(df):
    molecule_node_types = set()
    molecule_edge_types = set()
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Collecting molecule types"):
        smiles = row['molecule_smiles']
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        mol = Chem.AddHs(mol)
        atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()]
        molecule_node_types.update(atom_types)
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            atype_i = atom_types[i]
            atype_j = atom_types[j]
            if atype_i <= atype_j:
                edge_type = (atype_i, 'bond', atype_j)
            else:
                edge_type = (atype_j, 'bond', atype_i)
            molecule_edge_types.add(edge_type)
    return sorted(molecule_node_types), sorted(molecule_edge_types)

def collect_protein_node_and_edge_types(protein_graphs):
    protein_node_types = set()
    protein_edge_types = set()
    for protein_data in protein_graphs.values():
        protein_node_types.update(protein_data.node_types)
        protein_edge_types.update(protein_data.edge_types)
    return sorted(protein_node_types), sorted(protein_edge_types)

In [30]:
class CrossAttentionLayer(torch.nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super(CrossAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert self.head_dim * num_heads == hidden_dim, "hidden_dim must be divisible by num_heads"

        self.W_Q = torch.nn.Linear(hidden_dim, hidden_dim)
        self.W_K = torch.nn.Linear(hidden_dim, hidden_dim)
        self.W_V = torch.nn.Linear(hidden_dim, hidden_dim)

        self.scale = self.head_dim ** 0.5

    def forward(self, query_nodes, key_nodes):
        # query_nodes: [N_q, hidden_dim]
        # key_nodes: [N_k, hidden_dim]

        N_q = query_nodes.size(0)
        N_k = key_nodes.size(0)

        Q = self.W_Q(query_nodes)  # [N_q, hidden_dim]
        K = self.W_K(key_nodes)    # [N_k, hidden_dim]
        V = self.W_V(key_nodes)    # [N_k, hidden_dim]

        # Reshape for multi-head attention
        Q = Q.view(N_q, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_q, head_dim]
        K = K.view(N_k, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_k, head_dim]
        V = V.view(N_k, self.num_heads, self.head_dim).permute(1, 0, 2)  # [num_heads, N_k, head_dim]

        # Compute attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [num_heads, N_q, N_k]
        attn_weights = torch.softmax(attn_scores, dim=-1)                # [num_heads, N_q, N_k]

        # Compute attended values
        out = torch.matmul(attn_weights, V)  # [num_heads, N_q, head_dim]
        out = out.permute(1, 0, 2).contiguous().view(N_q, -1)  # [N_q, hidden_dim]

        return out

class CrossGraphAttentionModel(torch.nn.Module):
    def __init__(self, hidden_dim=64, num_attention_heads=4):
        super(CrossGraphAttentionModel, self).__init__()

        # Molecule GNN Encoder
        self.mol_conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in molecule_edge_types
        }, aggr='mean')

        self.mol_conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in molecule_edge_types
        }, aggr='mean')

        # Protein GNN Encoder
        self.prot_conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in protein_edge_types
        }, aggr='mean')

        self.prot_conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in protein_edge_types
        }, aggr='mean')

        # Cross-Attention Layers
        self.cross_attn_mol_to_prot = CrossAttentionLayer(hidden_dim, num_attention_heads)
        self.cross_attn_prot_to_mol = CrossAttentionLayer(hidden_dim, num_attention_heads)

        # Fully Connected Layers
        self.fc1 = Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = Linear(hidden_dim, 1)

    def forward(self, mol_data, prot_data):
        # Molecule GNN Encoding
        x_mol_dict = mol_data.x_dict
        edge_index_mol_dict = mol_data.edge_index_dict

        x_mol_dict = self.mol_conv1(x_mol_dict, edge_index_mol_dict)
        x_mol_dict = {key: F.relu(x) for key, x in x_mol_dict.items()}

        x_mol_dict = self.mol_conv2(x_mol_dict, edge_index_mol_dict)
        x_mol_dict = {key: F.relu(x) for key, x in x_mol_dict.items()}

        # Concatenate molecule node embeddings
        mol_node_embeddings = []
        for nt in molecule_node_types:
            if nt in x_mol_dict:
                mol_node_embeddings.append(x_mol_dict[nt])
        H_mol = torch.cat(mol_node_embeddings, dim=0)

        # Protein GNN Encoding
        x_prot_dict = prot_data.x_dict
        edge_index_prot_dict = prot_data.edge_index_dict

        x_prot_dict = self.prot_conv1(x_prot_dict, edge_index_prot_dict)
        x_prot_dict = {key: F.relu(x) for key, x in x_prot_dict.items()}

        x_prot_dict = self.prot_conv2(x_prot_dict, edge_index_prot_dict)
        x_prot_dict = {key: F.relu(x) for key, x in x_prot_dict.items()}

        # Concatenate protein node embeddings
        prot_node_embeddings = []
        for nt in protein_node_types:
            if nt in x_prot_dict:
                prot_node_embeddings.append(x_prot_dict[nt])
        H_prot = torch.cat(prot_node_embeddings, dim=0)

        # Cross-Attention
        H_mol_attn = self.cross_attn_mol_to_prot(H_mol, H_prot)
        H_prot_attn = self.cross_attn_prot_to_mol(H_prot, H_mol)

        # Combine original and attended embeddings
        H_mol_combined = H_mol + H_mol_attn
        H_prot_combined = H_prot + H_prot_attn

        # Global Pooling
        mol_batch = mol_data.batch if hasattr(mol_data, 'batch') else torch.zeros(H_mol_combined.size(0), dtype=torch.long, device=H_mol_combined.device)
        prot_batch = prot_data.batch if hasattr(prot_data, 'batch') else torch.zeros(H_prot_combined.size(0), dtype=torch.long, device=H_prot_combined.device)

        z_mol = global_mean_pool(H_mol_combined, mol_batch)
        z_prot = global_mean_pool(H_prot_combined, prot_batch)

        # Joint Representation
        z_joint = torch.cat([z_mol, z_prot], dim=1)

        # Prediction
        x = F.relu(self.fc1(z_joint))
        out = torch.sigmoid(self.fc2(x))

        return out.squeeze()

In [None]:
# Load the dataset
df = pd.read_parquet('filtered_train.parquet')

# Stratified splitting
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df['binds'],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['binds'],
    random_state=42
)

print("\nSplit completed.")
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")

# Process and store protein graphs
protein_graphs = {}
protein_pdb_files = {
    'BRD4': './BRD4.pdb',
    'HSA': './ALB.pdb',
    'sEH': './EPH.pdb'
}

for protein_name, pdb_file in protein_pdb_files.items():
    if os.path.exists(pdb_file):
        protein_data = process_protein(pdb_file)
        protein_graphs[protein_name] = protein_data
    else:
        print(f"PDB file {pdb_file} for {protein_name} does not exist.")

# Collect node and edge types with progress tracking
print("Collecting molecule node and edge types...")
molecule_node_types, molecule_edge_types = collect_molecule_node_and_edge_types(df)

print("Collecting protein node and edge types...")
protein_node_types, protein_edge_types = collect_protein_node_and_edge_types(protein_graphs)




Split completed.
Training set size: 2169734
Validation set size: 464943
Test set size: 464943
Collecting molecule node and edge types...


Collecting molecule types:   0%|        | 796/3099620 [00:00<18:54, 2730.99it/s]

In [None]:
from datasets import CombinedDataset, MoleculeDataset

# Create datasets
train_dataset = CombinedDataset(train_df, protein_graphs)
val_dataset = CombinedDataset(val_df, protein_graphs)
test_dataset = CombinedDataset(test_df, protein_graphs)

In [None]:
from torch_geometric.loader import DataLoader as GeoDataLoader

# Custom collate function
def collate_fn(batch):
    mol_batch = [item[0] for item in batch if item is not None and item[0] is not None]
    prot_batch = [item[1] for item in batch if item is not None and item[0] is not None]

    mol_batch = Batch.from_data_list(mol_batch)
    prot_batch = Batch.from_data_list(prot_batch)

    return mol_batch, prot_batch

# Create data loaders
train_loader = GeoDataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, collate_fn=collate_fn)
val_loader = GeoDataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = GeoDataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, collate_fn=collate_fn)

In [None]:
# Instantiate the model
hidden_dim = 64
num_attention_heads = 4

model = CrossGraphAttentionModel(
    hidden_dim=hidden_dim,
    num_attention_heads=num_attention_heads
)
model = model.to(device)

# Loss function and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training and evaluation functions
def train_epoch():
    model.train()
    total_loss = 0
    for mol_data, prot_data in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        mol_data = mol_data.to(device)
        prot_data = prot_data.to(device)
        out = model(mol_data, prot_data)
        y = mol_data.y.to(device)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate(loader, mode='Validation'):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for mol_data, prot_data in tqdm(loader, desc=mode):
            mol_data = mol_data.to(device)
            prot_data = prot_data.to(device)
            out = model(mol_data, prot_data)
            y = mol_data.y.to(device)
            loss = criterion(out, y)
            total_loss += loss.item()
    return total_loss / len(loader)

# Training loop
num_epochs = 20

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss = evaluate(val_loader, mode='Validation')
    print(f'Epoch: {epoch:02d}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

# Save the model weights
torch.save(model.state_dict(), 'cross_graph_attention_model.pth')

# Prediction and evaluation on test data
def predict(loader):
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for mol_data, prot_data in tqdm(loader, desc="Testing"):
            mol_data = mol_data.to(device)
            prot_data = prot_data.to(device)
            out = model(mol_data, prot_data)
            predictions.extend(out.cpu().numpy())
            true_labels.extend(mol_data.y.cpu().numpy())
    return predictions, true_labels

test_predictions, test_true = predict(test_loader)

# Apply a threshold to obtain binary predictions
threshold = 0.5
test_pred_binary = [1 if p >= threshold else 0 for p in test_predictions]

# Evaluate performance
accuracy = accuracy_score(test_true, test_pred_binary)
roc_auc = roc_auc_score(test_true, test_predictions)
precision = precision_score(test_true, test_pred_binary)
recall = recall_score(test_true, test_pred_binary)
f1 = f1_score(test_true, test_pred_binary)

print(f"\nTest Accuracy: {accuracy:.4f}")
print(f"Test ROC-AUC: {roc_auc:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1-Score: {f1:.4f}")