In [2]:
import pandas as pd
import pyarrow.parquet as pq
import numpy as np

# Collect the distinct SMILES strings that occur in one chunk of the dataset.
def process_chunk(chunk, unique_building_blocks, unique_molecules):
    """Add every distinct SMILES in ``chunk`` to the running sets, in place.

    Args:
        chunk: DataFrame with ``buildingblock1_smiles``, ``buildingblock2_smiles``,
            ``buildingblock3_smiles`` and ``molecule_smiles`` columns.
        unique_building_blocks: set that receives the SMILES from all three
            building-block columns.
        unique_molecules: set that receives the ``molecule_smiles`` values.

    Returns:
        None; both sets are mutated.
    """
    building_block_columns = (
        'buildingblock1_smiles',
        'buildingblock2_smiles',
        'buildingblock3_smiles',
    )
    for column_name in building_block_columns:
        unique_building_blocks.update(chunk[column_name].unique())

    unique_molecules.update(chunk['molecule_smiles'].unique())

# Load the parquet file
file_path = './train.parquet'
batch_size = 100000  # rows per record batch streamed from disk
parquet_file = pq.ParquetFile(file_path)

# Initialize sets to keep track of unique building blocks and molecules
unique_building_blocks = set()
unique_molecules = set()

# Stream the file in record batches of `batch_size` rows. Row groups can be
# arbitrarily large, so reading whole row groups does not bound memory use.
for batch_index, record_batch in enumerate(parquet_file.iter_batches(batch_size=batch_size)):
    chunk = record_batch.to_pandas()

    if batch_index == 0:
        print("First few rows from the first batch:")
        print(chunk.head())

    # Fold this batch's SMILES into the running sets
    process_chunk(chunk, unique_building_blocks, unique_molecules)

# Output the total unique counts
print(f"Total number of unique building blocks: {len(unique_building_blocks)}")
print(f"Total number of unique molecules: {len(unique_molecules)}")

First few rows from the first row group:
   id                            buildingblock1_smiles buildingblock2_smiles  \
0   0  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
1   1  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
2   2  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
3   3  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   
4   4  C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21  C#CCOc1ccc(CN)cc1.Cl   

     buildingblock3_smiles                                    molecule_smiles  \
0  Br.Br.NCC1CCCN1c1cccnn1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...   
1  Br.Br.NCC1CCCN1c1cccnn1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...   
2  Br.Br.NCC1CCCN1c1cccnn1  C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H]...   
3        Br.NCc1cccc(Br)n1  C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC...   
4        Br.NCc1cccc(Br)n1  C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](C

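# Candidate Features for Small Molecules

Not every feature listed here is computed by the extractor below. For example, the graph Laplacian, electronegativity and electron affinity are not implemented.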

## Molecular Descriptors:
- Molecular weight
- Number of atoms
- Number of bonds
- Number of aromatic rings
- Number of rotatable bonds
- Topological polar surface area (TPSA)
- LogP (octanol-water partition coefficient)

## Atom-Level Features:
- Atom types (e.g., C, H, O, N, S)
- Hybridization states (sp, sp2, sp3)
- Formal charge
- Aromaticity
- Degree (number of bonds to the atom)
- Implicit and explicit hydrogen counts
- Chirality

## Bond-Level Features:
- Bond types (single, double, triple, aromatic)
- Conjugation
- Ring membership
- Stereo configuration (cis/trans)

## Graph-Based Features:
- Adjacency matrix
- Distance matrix
- Graph Laplacian

## Physicochemical Properties:
- Hydrogen bond donors and acceptors
- Molecular refractivity
- Molar volume
- Electronegativity
- Electron affinity

## Structural Fingerprints:
- MACCS keys
- Morgan fingerprints
- ECFP (Extended Connectivity Fingerprints); RDKit's Morgan fingerprint above is its ECFP-style implementation
- RDKit fingerprints


In [7]:
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, AllChem, rdmolops
from rdkit.DataStructs import ConvertToNumpyArray
import numpy as np

# One-hot vocabularies for the categorical atom/bond features. A value that is
# not in its vocabulary encodes as an all-zero vector (see one_hot_encode).
ATOM_TYPES = ['C', 'H', 'O', 'N', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'B']
HYBRIDIZATION_STATES = ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2']
CHIRAL_TAGS = ['CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW', 'CHI_OTHER']
BOND_TYPES = ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC']
STEREO_CONFIGURATIONS = ['STEREONONE', 'STEREOZ', 'STEREOE', 'STEREOCIS', 'STEREOTRANS']

class SmallMoleculeFeatureExtractor:
    """Compute descriptor, atom/bond, graph and fingerprint features for one SMILES.

    The SMILES is parsed once with RDKit in ``__init__``. Every ``get_*`` method
    reads the parsed ``self.mol``.
    """

    def __init__(self, smiles):
        """Parse ``smiles`` into an RDKit molecule.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``, or the parsed
                molecule has no atoms.
        """
        self.smiles = smiles
        self.mol = Chem.MolFromSmiles(smiles)
        # MolFromSmiles returns None rather than raising on invalid input; fail
        # here with the offending string instead of an AttributeError later.
        if self.mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
        # An atom-less molecule has MolWt 0, which would break the
        # molar_volume ratio in get_physicochemical_properties.
        if self.mol.GetNumAtoms() == 0:
            raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    def get_molecular_descriptors(self):
        """Return a dict of seven whole-molecule RDKit descriptors (insertion-ordered)."""
        descriptors = {
            'molecular_weight': Descriptors.MolWt(self.mol),
            'num_atoms': self.mol.GetNumAtoms(),
            'num_bonds': self.mol.GetNumBonds(),
            'num_aromatic_rings': rdMolDescriptors.CalcNumAromaticRings(self.mol),
            'num_rotatable_bonds': Descriptors.NumRotatableBonds(self.mol),
            'tpsa': Descriptors.TPSA(self.mol),
            'logp': Descriptors.MolLogP(self.mol)
        }
        return descriptors

    def one_hot_encode(self, value, categories):
        """Return a 0/1 list with a 1 at ``categories.index(value)``; all zeros if absent."""
        encoding = [0] * len(categories)
        if value in categories:
            encoding[categories.index(value)] = 1
        return encoding

    def get_atom_level_features(self):
        """Return one nested feature list per atom.

        Each entry is ``[atom-type one-hot, hybridization one-hot, formal charge,
        is-aromatic, degree, implicit valence, total H count, chiral-tag one-hot]``.
        The entries mix lists and scalars, so they are not directly a flat vector.
        """
        atom_features = []
        for atom in self.mol.GetAtoms():
            atom_features.append([
                self.one_hot_encode(atom.GetSymbol(), ATOM_TYPES),
                self.one_hot_encode(str(atom.GetHybridization()), HYBRIDIZATION_STATES),
                atom.GetFormalCharge(),
                atom.GetIsAromatic(),
                atom.GetDegree(),
                atom.GetImplicitValence(),
                atom.GetTotalNumHs(),
                self.one_hot_encode(str(atom.GetChiralTag()), CHIRAL_TAGS)
            ])
        return atom_features

    def get_bond_level_features(self):
        """Return ``[bond-type one-hot, is-conjugated, in-ring, stereo one-hot]`` per bond."""
        bond_features = []
        for bond in self.mol.GetBonds():
            bond_features.append([
                self.one_hot_encode(str(bond.GetBondType()), BOND_TYPES),
                bond.GetIsConjugated(),
                bond.IsInRing(),
                self.one_hot_encode(str(bond.GetStereo()), STEREO_CONFIGURATIONS)
            ])
        return bond_features

    def get_graph_based_features(self):
        """Return the atom adjacency and topological distance matrices (num_atoms x num_atoms)."""
        adj_matrix = rdmolops.GetAdjacencyMatrix(self.mol)
        dist_matrix = rdmolops.GetDistanceMatrix(self.mol)
        return {
            'adjacency_matrix': adj_matrix,
            'distance_matrix': dist_matrix,
        }

    def get_physicochemical_properties(self):
        """Return H-bond donor/acceptor counts, molar refractivity and a 'molar_volume' value."""
        properties = {
            'h_bond_donors': Descriptors.NumHDonors(self.mol),
            'h_bond_acceptors': Descriptors.NumHAcceptors(self.mol),
            'molecular_refractivity': Descriptors.MolMR(self.mol),
            # FIXME(review): logP / molecular weight is not a molar volume. A real
            # molar volume needs a 3D conformer (e.g. AllChem.ComputeMolVolume)
            # or an additive atomic scheme. The value is kept unchanged so
            # existing feature vectors stay comparable; replace it deliberately.
            'molar_volume': Descriptors.MolLogP(self.mol) / Descriptors.MolWt(self.mol)
        }
        return properties

    @staticmethod
    def _bitvect_to_numpy(bitvect):
        """Copy an RDKit bit vector into a float64 numpy array of 0/1 values."""
        array = np.zeros((bitvect.GetNumBits(),))
        ConvertToNumpyArray(bitvect, array)
        return array

    def get_structural_fingerprints(self):
        """Return MACCS, Morgan (radius 2) and RDKit fingerprints as numpy arrays."""
        maccs_keys = AllChem.GetMACCSKeysFingerprint(self.mol)
        morgan_fp = AllChem.GetMorganFingerprintAsBitVect(self.mol, 2)
        rdk_fp = Chem.RDKFingerprint(self.mol)

        return {
            'maccs_keys': self._bitvect_to_numpy(maccs_keys),
            'morgan_fp': self._bitvect_to_numpy(morgan_fp),
            'rdkit_fp': self._bitvect_to_numpy(rdk_fp)
        }

    def extract_features(self):
        """Return every feature group, unflattened, keyed by group name."""
        features = {
            'molecular_descriptors': self.get_molecular_descriptors(),
            'atom_level_features': self.get_atom_level_features(),
            'bond_level_features': self.get_bond_level_features(),
            'graph_based_features': self.get_graph_based_features(),
            'physicochemical_properties': self.get_physicochemical_properties(),
            'structural_fingerprints': self.get_structural_fingerprints()
        }
        return features

    def flatten_features(self):
        """Return the molecule-level features as numpy arrays.

        The descriptor, property and fingerprint groups are fixed-length vectors.
        The adjacency and distance matrices are returned as-is; their size
        depends on the atom count.
        """
        molecular_descriptors = self.get_molecular_descriptors()
        physicochemical_properties = self.get_physicochemical_properties()
        structural_fingerprints = self.get_structural_fingerprints()
        graph_based_features = self.get_graph_based_features()

        # MACCS, Morgan and RDKit bits concatenated in that order
        flattened_structural_fingerprints = np.concatenate([
            structural_fingerprints['maccs_keys'],
            structural_fingerprints['morgan_fp'],
            structural_fingerprints['rdkit_fp']
        ])

        # Dict values are insertion-ordered, so the array layout is stable
        molecular_descriptors_array = np.array(list(molecular_descriptors.values()))
        physicochemical_properties_array = np.array(list(physicochemical_properties.values()))

        return {
            'molecular_descriptors': molecular_descriptors_array,
            'physicochemical_properties': physicochemical_properties_array,
            'structural_fingerprints': flattened_structural_fingerprints,
            'adjacency_matrix': graph_based_features['adjacency_matrix'],
            'distance_matrix': graph_based_features['distance_matrix']
        }


In [8]:
# Smoke-test the extractor on the first building block seen in train.parquet.
smiles = "C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21"
extractor = SmallMoleculeFeatureExtractor(smiles)
features = extractor.flatten_features()
print("\n".join(f"{name}: {array}" for name, array in features.items()))

molecular_descriptors: [349.386   26.      28.       2.       6.      75.63     3.3917]
physicochemical_properties: [2.00000000e+00 3.00000000e+00 9.76955000e+01 9.70760133e-03]
structural_fingerprints: [0. 0. 0. ... 1. 1. 1.]
adjacency_matrix: [[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 

# Feature Extraction from a Protein Structure (PDB File)

## Structural Features

### Amino Acid Composition:
- Frequency of each amino acid type in the binding site.
- Frequency of amino acid types in the entire protein.

### Secondary Structure:
- Percentage of alpha-helices, beta-sheets, and random coils in the binding site.
- Secondary structure elements around the binding site.

### Tertiary Structure:
- 3D coordinates of the binding site.
- Distance between key residues in the binding site.

### Binding Site Characteristics:
- Volume and surface area of the binding site.
- Shape descriptors (e.g., sphericity, elongation).

## Physicochemical Properties

### Hydrophobicity:
- Hydrophobic and hydrophilic residue distribution in the binding site.
- Hydrophobic surface area.

### Charge Distribution:
- Number and type of charged residues (positive and negative).
- Electrostatic potential distribution.

### Polarity:
- Number of polar residues.
- Polar surface area.

### Solvent Accessibility:
- Solvent-accessible surface area (SASA) of residues in the binding site.

### Hydrogen Bonding:
- Number of potential hydrogen bond donors and acceptors.
- Hydrogen bond network in the binding site.

### Van der Waals Interactions:
- Van der Waals interaction potential of the binding site.

## Geometric Features

### Distance Metrics:
- Pairwise distances between all residues in the binding site.
- Distance to the nearest surface residue.

### Angles and Dihedrals:
- Angles and dihedral angles between residues in the binding site.

## Chemical Environment

### Residue Environment:
- Local chemical environment of each residue (e.g., neighboring residues within a certain radius).

### Ligand Interaction Sites:
- Specific interaction sites for known ligands (if available).

## Dynamic Properties

### Flexibility:
- B-factors or temperature factors indicating residue flexibility.

### Molecular Dynamics Simulations:
- Root mean square fluctuation (RMSF) of residues in the binding site.
- Conformational changes over time.

## Topological Features

### Graph-based Features:
- Protein structure represented as a graph with nodes (residues) and edges (interactions).
- Degree centrality, betweenness centrality, and clustering coefficient of residues in the binding site.

## Energy-based Features

### Binding Energy:
- Estimated binding free energy of known ligands.
- Energy components (van der Waals, electrostatic, solvation) from docking simulations.

## Protein-Ligand Interaction Features

### Docking Scores:
- Scores from molecular docking simulations with various ligands.

### Interaction Profiles:
- Interaction fingerprints summarizing the types and strengths of interactions with ligands.

## Evolutionary Features

### Conservation:
- Sequence conservation of residues in the binding site (e.g., from multiple sequence alignment).

### Mutational Impact:
- Predicted impact of mutations on binding site residues.

## Experimental Data

### Experimental Binding Data:
- Known binding affinities (e.g., Kd, Ki, IC50) for small molecules.

## Contextual Features

### Functional Annotations:
- Biological function and pathway involvement of the protein.
- Known protein-protein interactions.

## Integration and Representation

### Feature Scaling and Normalization:
- Standardize and normalize features for input into the deep learning model.


In [14]:
from Bio.PDB import PDBParser, is_aa, NeighborSearch
import numpy as np
import networkx as nx

class ProteinFeatureExtractor:
    """Extract composition, B-factor, Cα-distance and residue-graph features from a PDB file.

    Construction parses ``pdb_file``, records non-standard residue names, and
    builds a residue contact graph (``self.graph``).
    """

    # Standard amino acids in a fixed order. They are used to detect ligands and
    # to lay out the composition vector identically for every protein.
    STANDARD_AMINO_ACIDS = ('ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
                            'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL')
    WATER_RESIDUES = frozenset({'HOH'})

    def __init__(self, pdb_file):
        self.pdb_file = pdb_file
        self.structure = self.load_structure()
        self.ligand_resnames = self.detect_ligands()
        self.graph = self.construct_graph()

    def load_structure(self):
        """Parse ``self.pdb_file`` and return the Biopython Structure (id 'protein')."""
        parser = PDBParser()
        return parser.get_structure('protein', self.pdb_file)

    def detect_ligands(self):
        """Return, sorted, the residue names that are neither standard amino acids nor water.

        Any other residue name qualifies, so ions or modified residues are included too.
        """
        excluded = set(self.STANDARD_AMINO_ACIDS) | self.WATER_RESIDUES
        return sorted({residue.resname for residue in self.structure.get_residues()} - excluded)

    def get_amino_acid_composition(self):
        """Count residue names over residues with a blank hetero flag (``residue.id[0] == ' '``).

        Returns:
            dict mapping residue name -> count. Keys are sorted alphabetically,
            so iteration order is reproducible across runs.
        """
        from collections import Counter

        counts = Counter(
            residue.resname
            for residue in self.structure.get_residues()
            if residue.id[0] == ' '
        )
        return dict(sorted(counts.items()))

    def get_flexibility(self):
        """Return the mean B-factor over all atoms (NaN, with a numpy warning, if there are none)."""
        return np.mean([atom.bfactor for atom in self.structure.get_atoms()])

    def get_distance_metrics(self):
        """Return all pairwise Cα–Cα distances within each chain.

        For each chain, the residues that have a ``'CA'`` atom are taken in chain
        order. The distance for every pair (i, j) with i < j is appended, so the
        result is each chain's flattened upper-triangle distance matrix, with the
        chains concatenated. Inter-chain pairs are not included.
        """
        distances = []
        for chain in self.structure.get_chains():
            print(f"Processing chain: {chain.id}")
            # NOTE(review): this keys on the atom name 'CA', so a calcium ion
            # (atom also named CA) would be included too — confirm intended.
            ca_coords = np.array([
                residue.child_dict['CA'].coord
                for residue in chain
                if 'CA' in residue.child_dict
            ])
            for i in range(len(ca_coords) - 1):
                # Distances from residue i to every later residue, vectorized.
                distances.extend(np.linalg.norm(ca_coords[i + 1:] - ca_coords[i], axis=1))
        return distances

    @staticmethod
    def _residue_key(residue):
        """Graph node key for a residue: its full id (structure, model, chain, residue id).

        ``residue.id`` alone repeats across chains and would merge residues from
        different chains into a single node.
        """
        return residue.get_full_id()

    def construct_graph(self, cutoff=4.0):
        """Build a residue contact graph from heavy-atom distances.

        Nodes are amino-acid residues (``is_aa``), each storing its Residue under
        the ``residue`` attribute. Two residues are connected when any pair of
        their non-hydrogen atoms lies within ``cutoff`` (coordinate units, Å for
        PDB). The edge ``weight`` is the smallest such atom-atom distance.
        """
        G = nx.Graph()
        residues = [residue for residue in self.structure.get_residues() if is_aa(residue)]
        for residue in residues:
            G.add_node(self._residue_key(residue), residue=residue)

        # Search only amino-acid heavy atoms, so water and ligand residues never
        # enter the graph implicitly through an edge.
        heavy_atoms = [atom for residue in residues for atom in residue if atom.element != 'H']
        if not heavy_atoms:
            return G

        for atom_a, atom_b in NeighborSearch(heavy_atoms).search_all(cutoff, level='A'):
            res_a, res_b = atom_a.get_parent(), atom_b.get_parent()
            if res_a is res_b:
                continue
            key_a, key_b = self._residue_key(res_a), self._residue_key(res_b)
            distance = atom_a - atom_b
            # Keep the closest contact between the two residues as the weight.
            if not G.has_edge(key_a, key_b) or distance < G[key_a][key_b]['weight']:
                G.add_edge(key_a, key_b, weight=distance)
        return G

    def extract_graph_features(self):
        """Return adjacency/distance matrices and per-node centralities of ``self.graph``.

        Per-node arrays follow ``list(self.graph.nodes)`` order. The distance
        matrix is Floyd–Warshall over the ``weight`` edge attribute.
        """
        adjacency_matrix = nx.adjacency_matrix(self.graph).todense()
        distance_matrix = nx.floyd_warshall_numpy(self.graph)
        degree_centrality = nx.degree_centrality(self.graph)
        betweenness_centrality = nx.betweenness_centrality(self.graph)
        clustering_coefficient = nx.clustering(self.graph)

        # Ensure features are in a consistent (graph node) order
        nodes = list(self.graph.nodes)
        return {
            'adjacency_matrix': adjacency_matrix,
            'distance_matrix': distance_matrix,
            'degree_centrality': np.array([degree_centrality[node] for node in nodes]),
            'betweenness_centrality': np.array([betweenness_centrality[node] for node in nodes]),
            'clustering_coefficient': np.array([clustering_coefficient[node] for node in nodes]),
        }

    def _raw_features(self):
        """Compute (composition, flexibility, distance metrics, graph features) once."""
        return (
            self.get_amino_acid_composition(),
            self.get_flexibility(),
            self.get_distance_metrics(),
            self.extract_graph_features(),
        )

    def extract_features(self):
        """Return every feature group, unaggregated, keyed by group name."""
        amino_acid_composition, flexibility, distance_metrics, graph_features = self._raw_features()
        return {
            "amino_acid_composition": amino_acid_composition,
            "flexibility": flexibility,
            "distance_metrics": distance_metrics,
            "graph_features": graph_features
        }

    def extract_and_aggregate_features(self):
        """Return features with composition and flexibility merged into one fixed-length vector.

        ``protein_combined_features`` has 21 entries: counts for the 20
        ``STANDARD_AMINO_ACIDS`` in that order (0 when absent), then the mean
        B-factor. Residue names outside the standard 20 are not counted in it.
        """
        amino_acid_composition, flexibility, distance_metrics, graph_features = self._raw_features()

        # Fixed slot order so vectors from different proteins line up position by position.
        composition_vector = [amino_acid_composition.get(aa, 0) for aa in self.STANDARD_AMINO_ACIDS]
        combined_features = composition_vector + [flexibility]

        return {
            "protein_combined_features": np.array(combined_features),
            "distance_metrics": distance_metrics,
            "degree_centrality": graph_features["degree_centrality"],
            "betweenness_centrality": graph_features["betweenness_centrality"],
            "clustering_coefficient": graph_features["clustering_coefficient"],
        }


In [None]:
pdb_file = "./ALB.pdb"
extractor = ProteinFeatureExtractor(pdb_file)
aggregated_features = extractor.extract_and_aggregate_features()

# Summarize rather than dump every value: distance_metrics alone holds one
# entry per intra-chain Cα pair, which floods the cell output.
for name, value in aggregated_features.items():
    print(f"{name}: shape={np.shape(value)}")


In [16]:
import dgl
import torch
import pandas as pd
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split

# Import your feature extractors here
# from your_feature_extractor_module import SmallMoleculeFeatureExtractor, ProteinFeatureExtractor

import numpy as np

# Columns holding the three building-block SMILES of each library member.
BUILDING_BLOCK_COLUMNS = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles']


def _smiles_to_feature_tensor(smiles):
    """Return one fixed-length float32 feature tensor for a SMILES string.

    Concatenates the structural fingerprints, molecular descriptors and
    physicochemical properties from SmallMoleculeFeatureExtractor. The
    adjacency/distance matrices are left out because their size depends on the
    atom count, so they cannot be stacked into a node-feature tensor.
    """
    feats = SmallMoleculeFeatureExtractor(smiles).flatten_features()
    vector = np.concatenate([
        feats['structural_fingerprints'],
        feats['molecular_descriptors'],
        feats['physicochemical_properties'],
    ]).astype(np.float32)
    return torch.from_numpy(vector)


def _add_edge(edges, canonical_etype, src, dst, **edge_data):
    """Append one edge (plus optional per-edge scalars) to the edge accumulator."""
    entry = edges.setdefault(canonical_etype, {'src': [], 'dst': [], 'data': {}})
    entry['src'].append(src)
    entry['dst'].append(dst)
    for key, value in edge_data.items():
        entry['data'].setdefault(key, []).append(value)


def process_chunk(chunk, G, node_features, unique_building_blocks, unique_molecules,
                  protein_features, protein_dict, pdb_files=None):
    """Register the nodes and edges contributed by one chunk of screening rows.

    DGL heterographs have a fixed schema, so nodes and edges cannot be added one
    at a time to an empty graph. This function fills plain Python containers
    instead, and create_and_save_graph turns them into a DGL graph.

    Args:
        chunk: DataFrame with the three building-block SMILES columns plus
            ``molecule_smiles``, ``protein_name`` and ``binds``.
        G: edge accumulator, a dict mapping ``(src_type, etype, dst_type)`` to
            ``{'src': [...], 'dst': [...], 'data': {name: [...]}}``.
        node_features: dict mapping node type to a list of 1-D feature tensors;
            the list position is the node id within that type.
        unique_building_blocks: dict SMILES -> building_block node id (updated).
        unique_molecules: dict SMILES -> molecule node id (updated).
        protein_features: dict caching extract_and_aggregate_features() per protein.
        protein_dict: dict protein name -> protein node id (updated).
        pdb_files: optional dict protein name -> PDB path. Names not listed
            default to ``./{protein_name}.pdb``.
            NOTE(review): the ALB.pdb demo suggests a protein's file name may
            differ from its ``protein_name`` — confirm and pass a mapping if so.
    """
    pdb_files = pdb_files or {}
    columns = BUILDING_BLOCK_COLUMNS + ['molecule_smiles', 'protein_name', 'binds']

    for *bb_smiles, mol_smiles, protein_name, binds in chunk[columns].itertuples(index=False, name=None):
        # Featurize each building block only the first time it is seen.
        for smiles in bb_smiles:
            if smiles not in unique_building_blocks:
                unique_building_blocks[smiles] = len(unique_building_blocks)
                node_features['building_block'].append(_smiles_to_feature_tensor(smiles))

        # A molecule can appear in several rows (one per screened protein).
        # Featurize it and link it to its building blocks only once, so
        # "contains" edges are not duplicated.
        if mol_smiles not in unique_molecules:
            new_mol_id = len(unique_molecules)
            unique_molecules[mol_smiles] = new_mol_id
            node_features['molecule'].append(_smiles_to_feature_tensor(mol_smiles))
            for smiles in bb_smiles:
                _add_edge(G, ('molecule', 'contains', 'building_block'),
                          new_mol_id, unique_building_blocks[smiles])
        mol_id = unique_molecules[mol_smiles]

        if protein_name not in protein_dict:
            pdb_path = pdb_files.get(protein_name, f"./{protein_name}.pdb")
            protein_features[protein_name] = ProteinFeatureExtractor(pdb_path).extract_and_aggregate_features()
            protein_dict[protein_name] = len(protein_dict)
            # Only the fixed-length combined vector becomes the node feature. The
            # per-residue arrays (distance_metrics, centralities) vary in length
            # and stay available in protein_features.
            combined = np.asarray(protein_features[protein_name]['protein_combined_features'],
                                  dtype=np.float32)
            node_features['protein'].append(torch.from_numpy(combined))

        # One edge per screened (molecule, protein) pair. The measured outcome is
        # stored as edge data 'label', so non-binders are not passed off as binders.
        _add_edge(G, ('molecule', 'binds', 'protein'), mol_id, protein_dict[protein_name],
                  label=int(binds))


def create_and_save_graph(file_path, chunk_size=100000, sample_frac=0.02, random_state=42,
                          pdb_files=None):
    """Sample screening rows, build a molecule/building-block/protein heterograph, and save it.

    Rows are sampled one row group at a time, stratified on
    (protein_name, binds), so the whole file never has to be held in memory.
    ``sample_frac`` is the fraction kept from every stratum.

    Writes ``heterograph.dgl`` (node features under 'feature'; 'binds' edges
    carry 'label') plus the SMILES/name -> node-id tables
    ``building_blocks.parquet``, ``molecules.parquet`` and ``proteins.parquet``
    to the working directory.

    NOTE(review): every molecule node carries a ~4.3k-float feature vector, so a
    large sample can need many GB of RAM — try a small ``sample_frac`` first.

    Returns:
        The constructed DGL heterograph.

    Raises:
        ValueError: if no rows were sampled (nothing to build).
    """
    edges = {}
    node_features = {'building_block': [], 'molecule': [], 'protein': []}
    unique_building_blocks, unique_molecules = {}, {}
    protein_features, protein_dict = {}, {}

    parquet_file = pq.ParquetFile(file_path)
    for group_index in range(parquet_file.num_row_groups):
        row_group = parquet_file.read_row_group(group_index).to_pandas()
        strata = row_group['protein_name'].astype(str) + '_' + row_group['binds'].astype(str)
        # Offset the seed per row group so groups draw independent, reproducible samples.
        seed = None if random_state is None else random_state + group_index
        sample = row_group.groupby(strata).sample(frac=sample_frac, random_state=seed)

        for start in range(0, len(sample), chunk_size):
            process_chunk(sample.iloc[start:start + chunk_size], edges, node_features,
                          unique_building_blocks, unique_molecules, protein_features,
                          protein_dict, pdb_files=pdb_files)

    if not edges:
        raise ValueError(f"No rows were sampled from {file_path!r}; cannot build a graph.")

    data_dict = {
        etype: (torch.tensor(entry['src'], dtype=torch.int64),
                torch.tensor(entry['dst'], dtype=torch.int64))
        for etype, entry in edges.items()
    }
    num_nodes_dict = {ntype: len(feats) for ntype, feats in node_features.items()}
    G = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

    for ntype, feats in node_features.items():
        if feats:
            G.nodes[ntype].data['feature'] = torch.stack(feats)
    for etype, entry in edges.items():
        for key, values in entry['data'].items():
            G.edges[etype].data[key] = torch.tensor(values)

    # Node and edge data are serialized together with the graph itself.
    dgl.save_graphs('heterograph.dgl', [G])

    pd.DataFrame(list(unique_building_blocks.items()), columns=['smiles', 'id']) \
        .to_parquet('building_blocks.parquet', index=False)
    pd.DataFrame(list(unique_molecules.items()), columns=['smiles', 'id']) \
        .to_parquet('molecules.parquet', index=False)
    pd.DataFrame(list(protein_dict.items()), columns=['name', 'id']) \
        .to_parquet('proteins.parquet', index=False)
    return G


def load_graph_and_features(graph_path, bb_path, protein_path):
    """Load the saved heterograph and the building-block / protein id lookup tables."""
    graphs, _graph_labels = dgl.load_graphs(graph_path)
    G = graphs[0]

    bb_df = pd.read_parquet(bb_path)
    protein_df = pd.read_parquet(protein_path)

    return G, bb_df, protein_df


# Create and save the graph
create_and_save_graph('./train.parquet')

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'sklearn'