In [None]:
import pyarrow.parquet as pq

# Open the training Parquet file (metadata only) and list its column names.
parquet_file = pq.ParquetFile('train.parquet')
column_names = parquet_file.schema.names
print(column_names)


In [None]:
import dask.dataframe as dd

# Lazily open both splits; nothing is read until a .compute() call.
train_df = dd.read_parquet('train.parquet')
test_df = dd.read_parquet('test.parquet')

# Row count of the train split: per-partition lengths, summed.
partition_lengths = train_df.map_partitions(len).compute()
total_rows = partition_lengths.sum()
print(f"Total number of rows: {total_rows}")

# 'binds' is summed directly to count the positive rows.
binds_column = train_df['binds']
num_positive_bindings = binds_column.sum().compute()
print(f"Number of positive bindings: {num_positive_bindings}")

# Everything that is not positive is counted as negative.
num_negative_bindings = total_rows - num_positive_bindings
print(f"Number of negative bindings: {num_negative_bindings}")

# Class balance expressed as percentages of the train split.
percent_positive = (num_positive_bindings / total_rows) * 100
percent_negative = (num_negative_bindings / total_rows) * 100
print(f"Percentage of positive bindings: {percent_positive:.2f}%")
print(f"Percentage of negative bindings: {percent_negative:.2f}%")

# Distinct protein targets: train only, test only, then the union of both.
train_proteins = train_df['protein_name']
test_proteins = test_df['protein_name']

unique_proteins_train = train_proteins.dropna().unique().compute()
total_unique_proteins_train = len(unique_proteins_train)
print(f"Total unique proteins in train dataset: {total_unique_proteins_train}")

unique_proteins_test = test_proteins.dropna().unique().compute()
total_unique_proteins_test = len(unique_proteins_test)
print(f"Total unique proteins in test dataset: {total_unique_proteins_test}")

combined_proteins = dd.concat([train_proteins, test_proteins])
unique_proteins_all = combined_proteins.dropna().unique().compute()
total_unique_proteins_all = len(unique_proteins_all)
print(f"Total unique proteins in both datasets: {total_unique_proteins_all}")

# Stack the three building-block columns of each split into one long series.
building_block_columns = (
    'buildingblock1_smiles',
    'buildingblock2_smiles',
    'buildingblock3_smiles',
)
train_building_blocks = dd.concat([train_df[column] for column in building_block_columns])
test_building_blocks = dd.concat([test_df[column] for column in building_block_columns])
all_building_blocks = dd.concat([train_building_blocks, test_building_blocks])

# Distinct building blocks across both splits.
unique_building_blocks = all_building_blocks.dropna().unique().compute()
total_unique_building_blocks = len(unique_building_blocks)
print(f"Total unique building blocks (train and test): {total_unique_building_blocks}")

# Distinct full molecules across both splits (nulls dropped per split first).
train_small_molecules = train_df['molecule_smiles'].dropna()
test_small_molecules = test_df['molecule_smiles'].dropna()
all_small_molecules = dd.concat([train_small_molecules, test_small_molecules])

unique_small_molecules = all_small_molecules.unique().compute()
total_unique_small_molecules = len(unique_small_molecules)
print(f"Total unique small molecules (train and test): {total_unique_small_molecules}")

In [None]:
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
import os
from tqdm import tqdm
from collections import defaultdict

# Two-pass, row-group-wise filter of train.parquet:
#   pass 1 records, for every molecule, the set of proteins it binds to;
#   pass 2 keeps (a) all rows of molecules that bind more than one protein and
#          (b) for molecules that bind exactly one protein, the positive row
#              plus one negative row.
# Only one row group is held in memory at a time.
import zlib  # stdlib; used for a stable, run-independent hash of the SMILES

# Paths to input and output Parquet files
input_file = 'train.parquet'
output_file = 'filtered_train.parquet'

# Remove output file if it exists so stale results are never mixed in
if os.path.exists(output_file):
    os.remove(output_file)

# Open the Parquet file
pf = pq.ParquetFile(input_file)

# Get total number of row groups (batches)
total_row_groups = pf.num_row_groups

# First Pass: Build mapping of molecule_smiles to the set of proteins it binds to
print("First Pass: Building molecule to proteins mapping...")

molecule_binds = defaultdict(set)
all_proteins = set()

for rg in tqdm(range(total_row_groups), desc="Processing Batches"):
    # Read only the columns needed for the mapping
    batch = pf.read_row_group(rg, columns=['molecule_smiles', 'protein_name', 'binds'])
    df = batch.to_pandas()

    # Null protein names are ignored so the protein list stays sortable
    all_proteins.update(df['protein_name'].dropna().unique())

    # Rows where 'binds' == 1
    df_binds_1 = df[df['binds'] == 1]

    # Column-wise iteration avoids the per-row Series that iterrows() builds
    for molecule, protein in zip(df_binds_1['molecule_smiles'], df_binds_1['protein_name']):
        molecule_binds[molecule].add(protein)

    # Clear variables to free memory
    del df, df_binds_1, batch

# Sorted list: a fixed order makes the negative selection below reproducible
all_proteins = sorted(all_proteins)

# Second Pass: Process data and write to output
print("Second Pass: Filtering dataset and writing to new Parquet file...")

# Loop invariants, built once instead of per molecule / per row group
all_proteins_set = frozenset(all_proteins)
binding_molecules = list(molecule_binds)

# The writer is created lazily from the first non-empty table's schema
writer = None

try:
    for rg in tqdm(range(total_row_groups), desc="Processing Batches"):
        # Read the row group
        batch = pf.read_row_group(rg)
        df = batch.to_pandas()

        # Keep only molecules that have at least one binds == 1
        df = df[df['molecule_smiles'].isin(binding_molecules)]

        if not df.empty:
            rows_to_include = []

            # Process each molecule in the batch
            for molecule, group in df.groupby('molecule_smiles'):
                binds_1_proteins = molecule_binds[molecule]
                num_binds_1 = len(binds_1_proteins)

                if num_binds_1 > 1:
                    # Include all rows for this molecule
                    rows_to_include.append(group)
                elif num_binds_1 == 1:
                    # The positive binding row
                    positive_row = group[group['binds'] == 1]

                    # Candidate proteins for the single negative row
                    unbound_proteins = sorted(all_proteins_set - binds_1_proteins)
                    if not unbound_proteins:
                        # No other protein exists, so there is no negative to pair
                        continue

                    # set.pop() depends on string hash randomisation and so
                    # differs between runs. A CRC of the SMILES gives a pick that
                    # is stable across runs and across row groups, while still
                    # spreading the negatives over the candidate proteins.
                    pick = zlib.crc32(str(molecule).encode('utf-8')) % len(unbound_proteins)
                    unbound_protein = unbound_proteins[pick]
                    negative_row = group[
                        (group['protein_name'] == unbound_protein) & (group['binds'] == 0)
                    ]

                    # If negative_row is empty, skip this molecule
                    if not negative_row.empty:
                        rows_to_include.append(pd.concat([positive_row, negative_row]))

            if rows_to_include:
                filtered_df = pd.concat(rows_to_include)

                # Convert to PyArrow Table
                table = pa.Table.from_pandas(filtered_df)

                # Initialize the Parquet writer if not already done
                if writer is None:
                    writer = pq.ParquetWriter(output_file, table.schema)

                # Write the table to the Parquet file
                writer.write_table(table)

                # Clear variables to free memory
                del table, filtered_df

        # Clear variables to free memory
        del df, batch
finally:
    # Always close the writer so the Parquet footer is written even on failure
    if writer is not None:
        writer.close()

print("Filtering completed. Filtered dataset saved to 'filtered_train.parquet'.")

In [None]:
import dask.dataframe as dd

# Lazily open the filtered train split and the test split.
train_df = dd.read_parquet('filtered_train.parquet')
test_df = dd.read_parquet('test.parquet')

# Row count of the train split: per-partition lengths, summed.
partition_lengths = train_df.map_partitions(len).compute()
total_rows = partition_lengths.sum()
print(f"Total number of rows: {total_rows}")

# 'binds' is summed directly to count the positive rows.
binds_column = train_df['binds']
num_positive_bindings = binds_column.sum().compute()
print(f"Number of positive bindings: {num_positive_bindings}")

# Everything that is not positive is counted as negative.
num_negative_bindings = total_rows - num_positive_bindings
print(f"Number of negative bindings: {num_negative_bindings}")

# Class balance expressed as percentages of the train split.
percent_positive = (num_positive_bindings / total_rows) * 100
percent_negative = (num_negative_bindings / total_rows) * 100
print(f"Percentage of positive bindings: {percent_positive:.2f}%")
print(f"Percentage of negative bindings: {percent_negative:.2f}%")

# Distinct protein targets: train only, test only, then the union of both.
train_proteins = train_df['protein_name']
test_proteins = test_df['protein_name']

unique_proteins_train = train_proteins.dropna().unique().compute()
total_unique_proteins_train = len(unique_proteins_train)
print(f"Total unique proteins in train dataset: {total_unique_proteins_train}")

unique_proteins_test = test_proteins.dropna().unique().compute()
total_unique_proteins_test = len(unique_proteins_test)
print(f"Total unique proteins in test dataset: {total_unique_proteins_test}")

combined_proteins = dd.concat([train_proteins, test_proteins])
unique_proteins_all = combined_proteins.dropna().unique().compute()
total_unique_proteins_all = len(unique_proteins_all)
print(f"Total unique proteins in both datasets: {total_unique_proteins_all}")
print(f"Unique proteins in both datasets: {unique_proteins_all.values}")

# Stack the three building-block columns of each split into one long series.
building_block_columns = (
    'buildingblock1_smiles',
    'buildingblock2_smiles',
    'buildingblock3_smiles',
)
train_building_blocks = dd.concat([train_df[column] for column in building_block_columns])
test_building_blocks = dd.concat([test_df[column] for column in building_block_columns])
all_building_blocks = dd.concat([train_building_blocks, test_building_blocks])

# Distinct building blocks across both splits.
unique_building_blocks = all_building_blocks.dropna().unique().compute()
total_unique_building_blocks = len(unique_building_blocks)
print(f"Total unique building blocks (train and test): {total_unique_building_blocks}")

# Distinct full molecules across both splits (nulls dropped per split first).
train_small_molecules = train_df['molecule_smiles'].dropna()
test_small_molecules = test_df['molecule_smiles'].dropna()
all_small_molecules = dd.concat([train_small_molecules, test_small_molecules])

unique_small_molecules = all_small_molecules.unique().compute()
total_unique_small_molecules = len(unique_small_molecules)
print(f"Total unique small molecules (train and test): {total_unique_small_molecules}")

In [1]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Dataset, DataLoader, Data
from torch_geometric.utils import from_networkx
import networkx as nx
import os
from tqdm import tqdm
import numpy as np

# Seed Python's, NumPy's and PyTorch's random number generators with one
# shared value so repeated runs draw the same random numbers.
import random

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


<torch._C.Generator at 0x31b7426b0>

In [2]:
# Read the filtered training table ('filtered_train.parquet') into memory.
df = pd.read_parquet('filtered_train.parquet')

# Report the class balance of the 'binds' label.
print("Distribution of 'binds' in the dataset:")
label_counts = df['binds'].value_counts()
print(label_counts)

# Stratified 70 / 15 / 15 split: 30% is held out first and then halved into
# validation and test, keeping the 'binds' ratio in every part.
split_seed = 42
train_df, temp_df = train_test_split(
    df, test_size=0.3, stratify=df['binds'], random_state=split_seed
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df['binds'], random_state=split_seed
)

print("\nSplit completed.")
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")


Distribution of 'binds' in the dataset:
binds
1    1589903
0    1509717
Name: count, dtype: int64

Split completed.
Training set size: 2169734
Validation set size: 464943
Test set size: 464943


In [3]:
# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
from torch_geometric.data import Dataset, HeteroData, DataLoader
from Bio.PDB import PDBParser, is_aa
from Bio.PDB.Polypeptide import three_to_index, index_to_one
import os
from tqdm import tqdm
import numpy as np
import random
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score

# Seed all three random number generators with the same value, then pick the
# compute device: CUDA when PyTorch reports it as available, otherwise the CPU.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [4]:
def residue_name_to_idx(res_name_one):
    """Return the integer class index of a one-letter amino-acid code.

    The 20 codes listed below are numbered 0-19 in that fixed order. Any
    other value is mapped to 20, the "unknown amino acid" class.
    """
    known_codes = (
        'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
        'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V',
    )
    try:
        return known_codes.index(res_name_one)
    except ValueError:
        # Not one of the listed codes
        return len(known_codes)

def process_protein(pdb_file, threshold=5.0):
    """Build a heterogeneous residue-contact graph from a PDB file.

    Every residue accepted by ``is_aa`` becomes a node whose node type is its
    one-letter amino-acid code. Two residues are joined by a directed
    ``(src_type, 'contact', dst_type)`` edge, in both directions, when their
    C-alpha atoms are at most ``threshold`` apart (in the coordinate units of
    the PDB file).

    Args:
        pdb_file: Path of the PDB file to parse.
        threshold: Maximum C-alpha/C-alpha distance for a contact edge.

    Returns:
        A ``HeteroData`` object with, per amino-acid type, ``x`` (one column:
        the index from ``residue_name_to_idx``) and ``pos`` (three columns:
        the C-alpha coordinate, or zeros when the residue has no C-alpha), and
        an ``edge_index`` for every type pair that has at least one contact.

    Raises:
        KeyError: If ``three_to_index`` does not know a residue name.
            NOTE(review): ``is_aa`` is called without ``standard=True``, so
            non-standard residues may reach that call - confirm whether they
            should be skipped instead.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)

    amino_acids = [
        residue
        for model in structure
        for chain in model
        for residue in chain
        if is_aa(residue)
    ]
    amino_acid_types = [
        index_to_one(three_to_index(residue.get_resname())) for residue in amino_acids
    ]

    data = HeteroData()

    node_features = {}
    node_positions = {}
    # Nodes are stored per amino-acid type, so an edge must address a residue
    # by its position inside its own type's node list, not by its position in
    # `amino_acids`. local_index[i] is that per-type position for residue i.
    local_index = []
    # C-alpha coordinates are looked up once here instead of inside the
    # pairwise loop; residues without a C-alpha get no contact edges.
    ca_residues = []
    ca_coords = []

    for i, (residue, aa_type) in enumerate(zip(amino_acids, amino_acid_types)):
        features = node_features.setdefault(aa_type, [])
        positions = node_positions.setdefault(aa_type, [])
        local_index.append(len(features))
        try:
            pos = residue['CA'].get_coord()
        except KeyError:
            pos = [0.0, 0.0, 0.0]
        else:
            ca_residues.append(i)
            ca_coords.append(pos)
        features.append([residue_name_to_idx(aa_type)])
        positions.append(pos)

    # Sorted so that node types are always registered in the same order
    for aa_type in sorted(node_features):
        data[aa_type].x = torch.tensor(node_features[aa_type], dtype=torch.float)
        data[aa_type].pos = torch.tensor(
            np.asarray(node_positions[aa_type], dtype=np.float32), dtype=torch.float
        )

    # Build edges based on proximity: one vectorised distance row per residue
    contact_edge_index = {}
    if ca_coords:
        coords = np.asarray(ca_coords, dtype=np.float32)
        for row, i in enumerate(ca_residues):
            distances = np.linalg.norm(coords - coords[row], axis=1)
            for col in np.nonzero(distances <= threshold)[0]:
                if col == row:
                    continue  # no self loops
                j = ca_residues[col]
                edge_type = (amino_acid_types[i], 'contact', amino_acid_types[j])
                contact_edge_index.setdefault(edge_type, []).append(
                    [local_index[i], local_index[j]]
                )

    for edge_type in sorted(contact_edge_index):
        edges = contact_edge_index[edge_type]
        data[edge_type].edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return data

In [5]:
class MoleculeDataset(Dataset):
    """Dataset that turns each dataframe row into a heterogeneous molecule graph.

    Each row must provide ``molecule_smiles``, ``binds`` and ``protein_name``.
    Atoms become nodes whose node type is the element symbol; bonds become
    ``(symbol, 'bond', symbol)`` edges stored in both directions.
    """

    def __init__(self, dataframe, protein_graphs, transform=None, pre_transform=None):
        """Store the rows (re-indexed from 0) and the protein graph lookup.

        ``protein_graphs`` is only stored; this class does not read it.
        """
        self.dataframe = dataframe.reset_index(drop=True)
        self.protein_graphs = protein_graphs
        super(MoleculeDataset, self).__init__(None, transform, pre_transform)

    def len(self):
        """Return the number of rows."""
        return len(self.dataframe)

    def get(self, idx):
        """Build the molecule graph for row ``idx``.

        Returns:
            A ``HeteroData`` graph with per-element ``x``/``pos``, per edge type
            ``edge_index``/``edge_attr``, and a ``'smolecule'`` store holding
            the label ``y`` and the ``smiles`` / ``protein_name`` metadata.
            ``None`` is returned when RDKit cannot parse the SMILES.
        """
        row = self.dataframe.iloc[idx]
        smiles = row['molecule_smiles']
        binds = row['binds']
        protein_name = row['protein_name']

        # Convert SMILES to molecular graph
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None  # Invalid SMILES; callers must handle this

        mol = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no 3D conformer could be generated;
        # without a conformer GetConformer() raises, so fall back to 2D
        # coordinates in that case.
        if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
            AllChem.Compute2DCoords(mol)
        conformer = mol.GetConformer()

        atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()]

        data = HeteroData()

        # Nodes are stored per element, so every atom also needs its position
        # inside its own element's node list (local_index) for the edges.
        features_by_type = {}
        positions_by_type = {}
        local_index = []
        for i, atom in enumerate(mol.GetAtoms()):
            atype = atom_types[i]
            features = features_by_type.setdefault(atype, [])
            local_index.append(len(features))
            features.append(self.get_atom_features(atom))
            point = conformer.GetAtomPosition(i)
            positions_by_type.setdefault(atype, []).append([point.x, point.y, point.z])

        # Sorted so that node types are always registered in the same order
        for atype in sorted(features_by_type):
            data[atype].x = torch.tensor(features_by_type[atype], dtype=torch.float)
            data[atype].pos = torch.tensor(positions_by_type[atype], dtype=torch.float)

        # A bond between atoms i and j is stored twice: i -> j under
        # (type_i, 'bond', type_j) and j -> i under (type_j, 'bond', type_i).
        # Each edge_index row 0 then always indexes the source type's nodes and
        # row 1 the destination type's nodes, also for hetero-atom bonds.
        bond_edges = {}
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            bond_feature = self.get_bond_features(bond)
            for src, dst in ((i, j), (j, i)):
                edge_type = (atom_types[src], 'bond', atom_types[dst])
                entry = bond_edges.setdefault(edge_type, {'edge_index': [], 'edge_attr': []})
                entry['edge_index'].append([local_index[src], local_index[dst]])
                entry['edge_attr'].append(bond_feature)

        # Assign bond edges to HeteroData
        for edge_type, attrs in bond_edges.items():
            data[edge_type].edge_index = (
                torch.tensor(attrs['edge_index'], dtype=torch.long).t().contiguous()
            )
            data[edge_type].edge_attr = torch.tensor(attrs['edge_attr'], dtype=torch.float)

        # Add binding label and metadata
        data['smolecule'].y = torch.tensor([binds], dtype=torch.float)
        data['smolecule'].smiles = smiles
        data['smolecule'].protein_name = protein_name

        return data

    @staticmethod
    def get_atom_features(atom):
        """Return [atomic number, total degree, formal charge, hybridization, aromatic flag]."""
        return [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic())
        ]

    @staticmethod
    def get_bond_features(bond):
        """Return [bond type code, bond length] for an RDKit bond.

        The type code is 0/1/2/3 for single/double/triple/aromatic and -1 for
        anything else. The length is the distance between the two bonded atoms
        in the owning molecule's default conformer, or 0.0 when the molecule
        has no conformer. RDKit bonds have no length accessor of their own.
        """
        bond_dict = {
            Chem.rdchem.BondType.SINGLE: 0,
            Chem.rdchem.BondType.DOUBLE: 1,
            Chem.rdchem.BondType.TRIPLE: 2,
            Chem.rdchem.BondType.AROMATIC: 3
        }

        owner = bond.GetOwningMol()
        if owner.GetNumConformers() > 0:
            conformer = owner.GetConformer()
            begin = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
            end = conformer.GetAtomPosition(bond.GetEndAtomIdx())
            bond_length = begin.Distance(end)
        else:
            bond_length = 0.0

        return [
            bond_dict.get(bond.GetBondType(), -1),
            bond_length
        ]

In [6]:
class CombinedDataset(Dataset):
    """Dataset that merges a molecule graph and its target protein graph.

    Each row must provide ``molecule_smiles``, ``binds`` and ``protein_name``.
    The molecule graph comes from ``MoleculeDataset``; the protein graph is
    looked up by name in ``protein_graphs`` (an empty graph is used when the
    name is missing).

    NOTE(review): element symbols and one-letter residue codes share node-type
    names (for example 'C', 'N', 'S'). Such types are merged into one node
    store here, so a model cannot tell atoms from residues by node type alone.
    """

    def __init__(self, dataframe, protein_graphs, transform=None, pre_transform=None):
        # One molecule dataset is built up front and reused for every sample,
        # rather than a one-row DataFrame and dataset per get() call. Its
        # re-indexed dataframe is shared instead of being copied again.
        self._molecule_dataset = MoleculeDataset(dataframe, protein_graphs)
        self.dataframe = self._molecule_dataset.dataframe
        self.protein_graphs = protein_graphs
        super(CombinedDataset, self).__init__(None, transform, pre_transform)

    def len(self):
        """Return the number of rows."""
        return len(self.dataframe)

    @staticmethod
    def _pad_features(x, width):
        """Zero-pad the feature columns of ``x`` on the right up to ``width``."""
        missing = width - x.size(1)
        if missing <= 0:
            return x
        return torch.nn.functional.pad(x, (0, missing))

    def get(self, idx):
        """Build the combined molecule + protein graph for row ``idx``.

        Raises:
            ValueError: If no molecule graph can be built from the SMILES.
        """
        row = self.dataframe.iloc[idx]
        smiles = row['molecule_smiles']
        binds = row['binds']
        protein_name = row['protein_name']

        # Process molecule
        molecule_data = self._molecule_dataset.get(idx)
        if molecule_data is None:
            raise ValueError(
                f"Could not build a molecule graph for row {idx}: invalid SMILES {smiles!r}"
            )

        # Process protein
        protein_data = self.protein_graphs.get(protein_name, HeteroData())

        # Atom and residue features have different widths. Every node feature
        # matrix is padded to one common width so that colliding node types
        # can be concatenated and samples can be collated into batches.
        widths = [
            graph[node_type].x.size(1)
            for graph in (molecule_data, protein_data)
            for node_type in graph.node_types
            if 'x' in graph[node_type]
        ]
        feature_width = max(widths) if widths else 0

        data = HeteroData()

        # Add molecule node types. Label-only stores such as 'smolecule' have
        # no 'x' and are skipped here.
        molecule_counts = {}
        for node_type in molecule_data.node_types:
            store = molecule_data[node_type]
            if 'x' not in store:
                continue
            data[node_type].x = self._pad_features(store.x, feature_width)
            data[node_type].pos = store.pos
            molecule_counts[node_type] = store.x.size(0)

        # Add molecule bond edges
        for edge_type in molecule_data.edge_types:
            data[edge_type].edge_index = molecule_data[edge_type].edge_index
            data[edge_type].edge_attr = molecule_data[edge_type].edge_attr

        # Add protein node types. When a type already holds molecule nodes the
        # protein nodes are appended after them.
        for node_type in protein_data.node_types:
            store = protein_data[node_type]
            if 'x' not in store:
                continue
            x = self._pad_features(store.x, feature_width)
            if node_type in molecule_counts:
                data[node_type].x = torch.cat([data[node_type].x, x], dim=0)
                data[node_type].pos = torch.cat([data[node_type].pos, store.pos], dim=0)
            else:
                data[node_type].x = x
                data[node_type].pos = store.pos

        # Add protein contact edges. Indices are shifted by the number of
        # molecule nodes that precede the protein nodes of the same type. The
        # addition creates a new tensor, so the cached protein graph is never
        # modified.
        for edge_type in protein_data.edge_types:
            src_type, _, dst_type = edge_type
            edge_index = protein_data[edge_type].edge_index
            shift = torch.tensor(
                [[molecule_counts.get(src_type, 0)], [molecule_counts.get(dst_type, 0)]],
                dtype=edge_index.dtype,
            )
            data[edge_type].edge_index = edge_index + shift

        # Add binding label and metadata
        data['smolecule'].y = torch.tensor([binds], dtype=torch.float)
        data['smolecule'].smiles = smiles
        data['smolecule'].protein_name = protein_name

        return data

In [7]:
def collect_molecule_node_and_edge_types(df):
    """Collect every atom node type and bond edge type that occurs in ``df``.

    Args:
        df: DataFrame with a ``molecule_smiles`` column.

    Returns:
        ``(node_types, edge_types)``: a sorted list of element symbols
        (explicit hydrogens included) and a sorted list of
        ``(begin_symbol, 'bond', end_symbol)`` tuples in RDKit's begin/end
        atom order. SMILES that RDKit cannot parse are ignored.
    """
    molecule_node_types = set()
    molecule_edge_types = set()
    # The result is a set union, so each distinct SMILES only needs to be
    # parsed once; this also avoids the per-row overhead of iterrows().
    for smiles in df['molecule_smiles'].dropna().unique():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        mol = Chem.AddHs(mol)
        atom_types = [atom.GetSymbol() for atom in mol.GetAtoms()]
        molecule_node_types.update(atom_types)
        molecule_edge_types.update(
            (atom_types[bond.GetBeginAtomIdx()], 'bond', atom_types[bond.GetEndAtomIdx()])
            for bond in mol.GetBonds()
        )
    return sorted(molecule_node_types), sorted(molecule_edge_types)

def collect_protein_node_and_edge_types(protein_graphs):
    """Return the sorted union of node types and of edge types over all graphs.

    ``protein_graphs`` is a mapping whose values expose ``node_types`` and
    ``edge_types``. The result is a ``(node_types, edge_types)`` pair of
    sorted lists with duplicates removed.
    """
    graphs = list(protein_graphs.values())
    node_type_union = {node_type for graph in graphs for node_type in graph.node_types}
    edge_type_union = {edge_type for graph in graphs for edge_type in graph.edge_types}
    return sorted(node_type_union), sorted(edge_type_union)


In [None]:
import torch
from torch_geometric.nn import HeteroConv, GCNConv, Linear, global_mean_pool

class CrossGraphAttentionModel(torch.nn.Module):
    """Binding classifier over a combined molecule/protein heterogeneous graph.

    Two heterogeneous message-passing layers are followed by mean pooling of
    the molecule node types and of the protein node types; the two pooled
    vectors are concatenated and fed to a two-layer classifier with a sigmoid
    output. Despite the class name, no attention layer is used.

    Args:
        molecule_node_types: Node types pooled into the molecule embedding.
        protein_node_types: Node types pooled into the protein embedding.
        edge_types: ``(src, relation, dst)`` tuples; one convolution is
            created per edge type in each layer.
        hidden_dim: Width of the hidden node embeddings.
    """

    def __init__(self, molecule_node_types, protein_node_types, edge_types, hidden_dim=64):
        super(CrossGraphAttentionModel, self).__init__()
        # Imported here so the class does not depend on an import elsewhere.
        from torch_geometric.nn import SAGEConv

        self.molecule_node_types = molecule_node_types
        self.protein_node_types = protein_node_types
        self.edge_types = edge_types
        self.hidden_dim = hidden_dim

        # GCNConv takes a single int `in_channels` and cannot pass messages
        # between two different node types. SAGEConv accepts the lazy
        # (-1, -1) source/destination sizes and supports bipartite edges.
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in self.edge_types
        }, aggr='mean')

        self.conv2 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_dim)
            for edge_type in self.edge_types
        }, aggr='mean')

        # Classification layers
        self.fc1 = Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = Linear(hidden_dim, 1)

    def _pool_node_types(self, data, x_dict, node_types, num_graphs):
        """Return one embedding per graph, averaged over the node types it contains.

        Each node type is mean-pooled per graph, then the per-type means are
        averaged over the types that are actually present in that graph. A
        graph with none of the types gets a zero vector.
        """
        device = self.fc1.weight.device
        total = torch.zeros((num_graphs, self.hidden_dim), device=device)
        present_count = torch.zeros((num_graphs, 1), device=device)
        for node_type in node_types:
            x = x_dict.get(node_type)
            if x is None or x.size(0) == 0:
                continue
            store = data[node_type]
            if 'batch' in store:
                # Mini-batch: maps every node to the graph it belongs to
                batch = store['batch']
            else:
                # Single graph: all nodes belong to graph 0
                batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            total = total + global_mean_pool(x, batch, size=num_graphs)
            has_nodes = torch.bincount(batch, minlength=num_graphs) > 0
            present_count = present_count + has_nodes.to(total.dtype).unsqueeze(1)
        return total / present_count.clamp(min=1)

    def forward(self, data):
        """Return binding probabilities, one per graph, with shape ``[num_graphs]``."""
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict

        # First convolution layer
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: torch.relu(x) for key, x in x_dict.items()}

        # Second convolution layer
        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: torch.relu(x) for key, x in x_dict.items()}

        # Batched inputs expose num_graphs; a single un-batched graph does not
        num_graphs = getattr(data, 'num_graphs', None) or 1

        # Pool molecule and protein nodes separately, one row per graph
        mol_rep = self._pool_node_types(data, x_dict, self.molecule_node_types, num_graphs)
        prot_rep = self._pool_node_types(data, x_dict, self.protein_node_types, num_graphs)

        # Combine representations
        combined = torch.cat([mol_rep, prot_rep], dim=1)

        # Classification
        x = torch.relu(self.fc1(combined))
        out = torch.sigmoid(self.fc2(x))

        # Drop only the trailing size-1 dimension so a single graph still
        # yields shape [1], which matches the shape of its label tensor
        return out.squeeze(-1)


In [None]:
# End-to-end training script: load and split the filtered data, build the
# protein graphs, datasets and loaders, train the model, then evaluate it on
# the held-out test split.

# Load the dataset
df = pd.read_parquet('filtered_train.parquet')

# Stratified 70 / 15 / 15 split on the 'binds' label
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    stratify=df['binds'],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['binds'],
    random_state=42
)

# Process and store protein graphs
protein_graphs = {}
protein_pdb_files = {
    'BRD4': './BRD4.pdb',
    'HSA': './ALB.pdb',
    'sEH': './EPH.pdb'
}

for protein_name, pdb_file in protein_pdb_files.items():
    if os.path.exists(pdb_file):
        protein_data = process_protein(pdb_file)
        protein_graphs[protein_name] = protein_data
    else:
        print(f"PDB file {pdb_file} for {protein_name} does not exist.")

# Create datasets
train_dataset = CombinedDataset(train_df, protein_graphs)
val_dataset = CombinedDataset(val_df, protein_graphs)
test_dataset = CombinedDataset(test_df, protein_graphs)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# The model needs the node and edge types it will see. Molecule types are
# collected over the distinct SMILES only; duplicates cannot add new types.
molecule_node_types, molecule_edge_types = collect_molecule_node_and_edge_types(
    df[['molecule_smiles']].drop_duplicates()
)
protein_node_types, protein_edge_types = collect_protein_node_and_edge_types(protein_graphs)

# Bonds are undirected, so every bond relation is registered in both directions.
bond_edge_types = set(molecule_edge_types)
bond_edge_types.update((dst, rel, src) for src, rel, dst in molecule_edge_types)
edge_types = sorted(bond_edge_types) + list(protein_edge_types)

# Instantiate the model
hidden_dim = 64
model = CrossGraphAttentionModel(
    molecule_node_types,
    protein_node_types,
    edge_types,
    hidden_dim=hidden_dim
)
model = model.to(device)

# The convolutions infer their input sizes on first use, so one forward pass
# is run before the optimizer is created.
with torch.no_grad():
    model(next(iter(train_loader)).to(device))

# Loss function and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 20

def train_epoch():
    """Run one optimisation pass over train_loader; return the mean batch loss."""
    model.train()
    total_loss = 0
    for data in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        data = data.to(device)
        # Flatten both sides so prediction and target shapes always agree
        out = model(data).view(-1)
        y = data['smolecule'].y.view(-1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def evaluate(loader, mode='Validation'):
    """Return the mean batch loss over ``loader`` without updating the model."""
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in tqdm(loader, desc=mode):
            data = data.to(device)
            out = model(data).view(-1)
            y = data['smolecule'].y.view(-1)
            loss = criterion(out, y)
            total_loss += loss.item()
    return total_loss / len(loader)

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch()
    val_loss = evaluate(val_loader, mode='Validation')
    print(f'Epoch: {epoch:02d}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

# Prediction on test data
def predict(loader):
    """Return (predicted probabilities, true labels) as flat lists over ``loader``."""
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for data in tqdm(loader, desc="Testing"):
            data = data.to(device)
            out = model(data).view(-1)
            predictions.extend(out.cpu().numpy())
            true_labels.extend(data['smolecule'].y.view(-1).cpu().numpy())
    return predictions, true_labels

test_predictions, test_true = predict(test_loader)

# Apply a threshold to obtain binary predictions
threshold = 0.5
test_pred_binary = [1 if p >= threshold else 0 for p in test_predictions]

# Evaluate performance
accuracy = accuracy_score(test_true, test_pred_binary)
roc_auc = roc_auc_score(test_true, test_predictions)
precision = precision_score(test_true, test_pred_binary)
recall = recall_score(test_true, test_pred_binary)
f1 = f1_score(test_true, test_pred_binary)

print(f"\nTest Accuracy: {accuracy:.4f}")
print(f"Test ROC-AUC: {roc_auc:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1-Score: {f1:.4f}")

