In [None]:
import pandas as pd
import pyarrow.parquet as pq
import numpy as np

# Read the file in chunks
def process_chunk(chunk, unique_building_blocks, unique_molecules):
    """Fold the distinct SMILES found in ``chunk`` into the running sets.

    Both sets are mutated in place and nothing is returned. Building blocks are pooled
    across the three ``buildingblock{1,2,3}_smiles`` columns; full molecules come from
    ``molecule_smiles``.
    """
    for column in ('buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles'):
        unique_building_blocks.update(chunk[column].unique())
    unique_molecules.update(chunk['molecule_smiles'].unique())

# Load the parquet file
file_path = './train.parquet'
batch_size = 100000
parquet_file = pq.ParquetFile(file_path)

# Only these columns are needed to count unique SMILES; skipping the others keeps
# every streamed batch small.
SMILES_COLUMNS = ['buildingblock1_smiles', 'buildingblock2_smiles',
                  'buildingblock3_smiles', 'molecule_smiles']

# Initialize sets to keep track of unique building blocks and molecules
unique_building_blocks = set()
unique_molecules = set()

num_row_groups = parquet_file.num_row_groups

# Preview a handful of rows (all columns) without materialising a whole row group.
preview_batch = next(parquet_file.iter_batches(batch_size=5), None)
if preview_batch is not None:
    print("First few rows from the first row group:")
    print(preview_batch.to_pandas().head())

# Stream record batches of at most `batch_size` rows so peak memory is bounded by
# `batch_size`, independent of how the file was split into row groups.
for batch in parquet_file.iter_batches(batch_size=batch_size, columns=SMILES_COLUMNS):
    process_chunk(batch.to_pandas(), unique_building_blocks, unique_molecules)

# Output the total unique counts
print(f"Total number of unique building blocks: {len(unique_building_blocks)}")
print(f"Total number of unique molecules: {len(unique_molecules)}")

# Candidate Features for Small Molecules

## Molecular Descriptors:
- Molecular weight
- Number of atoms
- Number of bonds
- Number of aromatic rings
- Number of rotatable bonds
- Topological polar surface area (TPSA)
- LogP (octanol-water partition coefficient)

## Atom-Level Features:
- Atom types (e.g., C, H, O, N, S)
- Hybridization states (sp, sp2, sp3)
- Formal charge
- Aromaticity
- Degree (number of bonds to the atom)
- Implicit and explicit hydrogen counts
- Chirality

## Bond-Level Features:
- Bond types (single, double, triple, aromatic)
- Conjugation
- Ring membership
- Stereo configuration (cis/trans)

## Graph-Based Features:
- Adjacency matrix
- Distance matrix
- Graph Laplacian

## Physicochemical Properties:
- Hydrogen bond donors and acceptors
- Molecular refractivity
- Molar volume
- Electronegativity
- Electron affinity

## Structural Fingerprints:
- MACCS keys
- Morgan fingerprints
- ECFP (Extended Connectivity Fingerprints)
- RDKit fingerprints


In [1]:
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, AllChem, rdmolops
from rdkit.DataStructs import ConvertToNumpyArray
import numpy as np

# One-hot vocabularies used by SmallMoleculeFeatureExtractor. Entries must match the
# strings the extractor compares against: atom symbols, and str() of the RDKit
# hybridization, chiral-tag, bond-type and bond-stereo values. A value that is not
# listed gets no "1" in its one-hot vector.
ATOM_TYPES = 'C H O N S F Cl Br I P B'.split()
HYBRIDIZATION_STATES = 'SP SP2 SP3 SP3D SP3D2'.split()
CHIRAL_TAGS = 'CHI_UNSPECIFIED CHI_TETRAHEDRAL_CW CHI_TETRAHEDRAL_CCW CHI_OTHER'.split()
BOND_TYPES = 'SINGLE DOUBLE TRIPLE AROMATIC'.split()
STEREO_CONFIGURATIONS = 'STEREONONE STEREOZ STEREOE STEREOCIS STEREOTRANS'.split()

class SmallMoleculeFeatureExtractor:
    """Compute RDKit-based features for a single molecule given as a SMILES string.

    Feature groups: whole-molecule descriptors, per-atom and per-bond encodings,
    adjacency / topological-distance matrices, a few physicochemical properties and
    MACCS / Morgan / RDKit bit fingerprints. Categorical atom and bond attributes are
    one-hot encoded against the module-level vocabularies (ATOM_TYPES, BOND_TYPES, ...).
    """

    def __init__(self, smiles):
        """Parse ``smiles`` into an RDKit molecule.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``. ``Chem.MolFromSmiles`` returns
                None in that case, which would otherwise surface later as an opaque error
                from whichever descriptor happens to be computed first.
        """
        self.smiles = smiles
        self.mol = Chem.MolFromSmiles(smiles)
        if self.mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

    def get_molecular_descriptors(self):
        """Return seven scalar whole-molecule descriptors.

        The dict's insertion order defines the column order of the array built by
        ``flatten_features``.
        """
        return {
            'molecular_weight': Descriptors.MolWt(self.mol),
            'num_atoms': self.mol.GetNumAtoms(),
            'num_bonds': self.mol.GetNumBonds(),
            'num_aromatic_rings': rdMolDescriptors.CalcNumAromaticRings(self.mol),
            'num_rotatable_bonds': Descriptors.NumRotatableBonds(self.mol),
            'tpsa': Descriptors.TPSA(self.mol),
            'logp': Descriptors.MolLogP(self.mol)
        }

    def one_hot_encode(self, value, categories):
        """Return a 0/1 list of ``len(categories)`` with a 1 at ``value``'s position.

        A value not present in ``categories`` encodes as all zeros (there is no
        "unknown" slot).
        """
        encoding = [0] * len(categories)
        if value in categories:
            encoding[categories.index(value)] = 1
        return encoding

    def get_atom_level_features(self):
        """Return one feature row per atom.

        Each row is a ragged list: [element one-hot, hybridization one-hot, formal charge,
        is-aromatic flag, degree, implicit valence, total H count, chiral-tag one-hot].
        """
        atom_features = []
        for atom in self.mol.GetAtoms():
            atom_features.append([
                self.one_hot_encode(atom.GetSymbol(), ATOM_TYPES),
                self.one_hot_encode(str(atom.GetHybridization()), HYBRIDIZATION_STATES),
                atom.GetFormalCharge(),
                atom.GetIsAromatic(),
                atom.GetDegree(),
                atom.GetImplicitValence(),
                atom.GetTotalNumHs(),
                self.one_hot_encode(str(atom.GetChiralTag()), CHIRAL_TAGS)
            ])
        return atom_features

    def get_bond_level_features(self):
        """Return one feature row per bond.

        Each row is a ragged list: [bond-type one-hot, is-conjugated flag, is-in-ring flag,
        stereo one-hot].
        """
        bond_features = []
        for bond in self.mol.GetBonds():
            bond_features.append([
                self.one_hot_encode(str(bond.GetBondType()), BOND_TYPES),
                bond.GetIsConjugated(),
                bond.IsInRing(),
                self.one_hot_encode(str(bond.GetStereo()), STEREO_CONFIGURATIONS)
            ])
        return bond_features

    def get_graph_based_features(self):
        """Return the RDKit adjacency matrix and topological distance matrix of the molecule."""
        return {
            'adjacency_matrix': rdmolops.GetAdjacencyMatrix(self.mol),
            'distance_matrix': rdmolops.GetDistanceMatrix(self.mol),
        }

    def get_physicochemical_properties(self):
        """Return H-bond donor/acceptor counts, molar refractivity and 'molar_volume'.

        The dict's insertion order defines the column order used by ``flatten_features``.
        """
        return {
            'h_bond_donors': Descriptors.NumHDonors(self.mol),
            'h_bond_acceptors': Descriptors.NumHAcceptors(self.mol),
            'molecular_refractivity': Descriptors.MolMR(self.mol),
            # NOTE(review): despite its key, this is logP / molecular weight, NOT a molar
            # volume. It is kept unchanged so existing feature vectors stay comparable;
            # replace it with a real volume estimate (and retrain) if the physical meaning
            # matters.
            'molar_volume': Descriptors.MolLogP(self.mol) / Descriptors.MolWt(self.mol)
        }

    @staticmethod
    def _bitvect_to_numpy(fingerprint):
        """Copy an RDKit bit vector into a float64 numpy array with one 0/1 entry per bit."""
        # Pre-size the buffer instead of relying on ConvertToNumpyArray to resize it.
        bits = np.zeros((fingerprint.GetNumBits(),))
        ConvertToNumpyArray(fingerprint, bits)
        return bits

    def get_structural_fingerprints(self):
        """Return MACCS keys, Morgan (radius 2) and RDKit fingerprints as numpy arrays."""
        return {
            'maccs_keys': self._bitvect_to_numpy(AllChem.GetMACCSKeysFingerprint(self.mol)),
            'morgan_fp': self._bitvect_to_numpy(AllChem.GetMorganFingerprintAsBitVect(self.mol, 2)),
            'rdkit_fp': self._bitvect_to_numpy(Chem.RDKFingerprint(self.mol))
        }

    def extract_features(self):
        """Return every feature group, unflattened, keyed by group name."""
        return {
            'molecular_descriptors': self.get_molecular_descriptors(),
            'atom_level_features': self.get_atom_level_features(),
            'bond_level_features': self.get_bond_level_features(),
            'graph_based_features': self.get_graph_based_features(),
            'physicochemical_properties': self.get_physicochemical_properties(),
            'structural_fingerprints': self.get_structural_fingerprints()
        }

    def flatten_features(self):
        """Return the model-ready subset of features as numpy arrays.

        Keys:
            molecular_descriptors, physicochemical_properties: 1-D arrays, in the
                insertion order of the corresponding dicts.
            structural_fingerprints: MACCS, Morgan and RDKit bits concatenated in that order.
            adjacency_matrix, distance_matrix: per-molecule matrices as returned by RDKit.
        Atom- and bond-level features are not included.
        """
        fingerprints = self.get_structural_fingerprints()
        graph_features = self.get_graph_based_features()

        return {
            'molecular_descriptors': np.array(list(self.get_molecular_descriptors().values())),
            'physicochemical_properties': np.array(list(self.get_physicochemical_properties().values())),
            'structural_fingerprints': np.concatenate([
                fingerprints['maccs_keys'],
                fingerprints['morgan_fp'],
                fingerprints['rdkit_fp']
            ]),
            'adjacency_matrix': graph_features['adjacency_matrix'],
            'distance_matrix': graph_features['distance_matrix']
        }


In [None]:
# Sanity check on one molecule: show every flattened feature group as "name: value".
smiles = "C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21"
extractor = SmallMoleculeFeatureExtractor(smiles)
features = extractor.flatten_features()
for feature, value in features.items():
    print(feature, value, sep=": ")

# Feature Extraction from Protein Structure PDB File

## Structural Features

### Amino Acid Composition:
- Frequency of each amino acid type in the binding site.
- Frequency of amino acid types in the entire protein.

### Secondary Structure:
- Percentage of alpha-helices, beta-sheets, and random coils in the binding site.
- Secondary structure elements around the binding site.

### Tertiary Structure:
- 3D coordinates of the binding site.
- Distance between key residues in the binding site.

### Binding Site Characteristics:
- Volume and surface area of the binding site.
- Shape descriptors (e.g., sphericity, elongation).

## Physicochemical Properties

### Hydrophobicity:
- Hydrophobic and hydrophilic residue distribution in the binding site.
- Hydrophobic surface area.

### Charge Distribution:
- Number and type of charged residues (positive and negative).
- Electrostatic potential distribution.

### Polarity:
- Number of polar residues.
- Polar surface area.

### Solvent Accessibility:
- Solvent-accessible surface area (SASA) of residues in the binding site.

### Hydrogen Bonding:
- Number of potential hydrogen bond donors and acceptors.
- Hydrogen bond network in the binding site.

### Van der Waals Interactions:
- Van der Waals interaction potential of the binding site.

## Geometric Features

### Distance Metrics:
- Pairwise distances between all residues in the binding site.
- Distance to the nearest surface residue.

### Angles and Dihedrals:
- Angles and dihedral angles between residues in the binding site.

## Chemical Environment

### Residue Environment:
- Local chemical environment of each residue (e.g., neighboring residues within a certain radius).

### Ligand Interaction Sites:
- Specific interaction sites for known ligands (if available).

## Dynamic Properties

### Flexibility:
- B-factors or temperature factors indicating residue flexibility.

### Molecular Dynamics Simulations:
- Root mean square fluctuation (RMSF) of residues in the binding site.
- Conformational changes over time.

## Topological Features

### Graph-based Features:
- Protein structure represented as a graph with nodes (residues) and edges (interactions).
- Degree centrality, betweenness centrality, and clustering coefficient of residues in the binding site.

## Energy-based Features

### Binding Energy:
- Estimated binding free energy of known ligands.
- Energy components (van der Waals, electrostatic, solvation) from docking simulations.

## Protein-Ligand Interaction Features

### Docking Scores:
- Scores from molecular docking simulations with various ligands.

### Interaction Profiles:
- Interaction fingerprints summarizing the types and strengths of interactions with ligands.

## Evolutionary Features

### Conservation:
- Sequence conservation of residues in the binding site (e.g., from multiple sequence alignment).

### Mutational Impact:
- Predicted impact of mutations on binding site residues.

## Experimental Data

### Experimental Binding Data:
- Known binding affinities (e.g., Kd, Ki, IC50) for small molecules.

## Contextual Features

### Functional Annotations:
- Biological function and pathway involvement of the protein.
- Known protein-protein interactions.

## Integration and Representation

### Feature Scaling and Normalization:
- Standardize and normalize features for input into the deep learning model.


In [2]:
from Bio.PDB import PDBParser, is_aa, NeighborSearch
import numpy as np
import networkx as nx

class ProteinFeatureExtractor:
    """Extract composition, flexibility, distance and residue-contact-graph features
    from a protein structure loaded from a PDB file."""

    # The 20 standard amino-acid residue names in a fixed (alphabetical) order, so that
    # composition vectors have the same length and column meaning for every protein.
    STANDARD_AMINO_ACIDS = ('ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
                            'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL')

    def __init__(self, pdb_file):
        """Load ``pdb_file``, detect ligand residue names and build the residue contact graph."""
        self.pdb_file = pdb_file
        self.structure = self.load_structure()
        self.ligand_resnames = self.detect_ligands()
        self.graph = self.construct_graph()

    def load_structure(self):
        """Parse the PDB file with Biopython's PDBParser and return the structure."""
        parser = PDBParser()
        return parser.get_structure('protein', self.pdb_file)

    def detect_ligands(self):
        """Return the sorted residue names that are neither standard amino acids nor water (HOH)."""
        excluded = set(self.STANDARD_AMINO_ACIDS) | {'HOH'}
        return sorted({residue.resname for residue in self.structure.get_residues()
                       if residue.resname not in excluded})

    def get_amino_acid_composition(self):
        """Count residues per residue name, over residues whose hetero flag is blank.

        Biopython gives ATOM-record residues a blank hetero field (``residue.id[0]``);
        HETATM residues and water are therefore excluded. Keys are returned sorted so the
        dict order is deterministic (the previous set-based order varied between runs).
        """
        counts = {}
        for residue in self.structure.get_residues():
            if residue.id[0] == ' ':
                counts[residue.resname] = counts.get(residue.resname, 0) + 1
        return dict(sorted(counts.items()))

    def get_flexibility(self):
        """Return the mean B-factor over all atoms (NaN, with a numpy warning, if there are none)."""
        return np.mean([atom.bfactor for atom in self.structure.get_atoms()])

    def get_distance_metrics(self):
        """Return all pairwise CA-CA distances between amino-acid residues within each chain.

        Pairs are emitted chain by chain, as (i, j) with i < j in residue order. Residues
        are filtered with ``is_aa`` so that non-amino-acid residues that happen to own an
        atom named "CA" (e.g. a calcium ion) are not mistaken for alpha carbons.
        """
        distances = []
        for chain in self.structure.get_chains():
            print(f"Processing chain: {chain.id}")
            ca_atoms = [res.child_dict['CA'] for res in chain
                        if is_aa(res) and 'CA' in res.child_dict]
            for i, ca1 in enumerate(ca_atoms):
                for ca2 in ca_atoms[i + 1:]:
                    distances.append(ca1 - ca2)
        return distances

    @staticmethod
    def _residue_key(residue):
        """Graph node key ``(model id, chain id, residue id)``.

        Unlike ``residue.id`` alone, this is unique across chains and models.
        """
        chain = residue.get_parent()
        return (chain.get_parent().id, chain.id, residue.id)

    def construct_graph(self, cutoff=4.0):
        """Build an undirected residue contact graph over amino-acid residues.

        Two residues are connected when any pair of their non-hydrogen atoms lies within
        ``cutoff`` (structure coordinate units). The edge ``weight`` is the smallest such
        atom-atom distance. Non-amino-acid residues (water, ligands, ions) are not added
        as nodes, and no edges are created to them.
        """
        G = nx.Graph()
        for residue in self.structure.get_residues():
            if is_aa(residue):
                G.add_node(self._residue_key(residue), residue=residue)

        heavy_atoms = [atom for atom in self.structure.get_atoms() if atom.element != 'H']
        if not heavy_atoms:
            return G
        ns = NeighborSearch(heavy_atoms)
        for atom in heavy_atoms:
            key1 = self._residue_key(atom.get_parent())
            if key1 not in G:
                continue
            for neighbor in ns.search(atom.coord, cutoff):
                key2 = self._residue_key(neighbor.get_parent())
                if key2 == key1 or key2 not in G:
                    continue
                distance = float(atom - neighbor)
                # Keep the closest contact; each pair is visited from both sides.
                if not G.has_edge(key1, key2) or distance < G[key1][key2]['weight']:
                    G.add_edge(key1, key2, weight=distance)
        return G

    def _centrality_features(self):
        """Return per-node degree/betweenness centrality and clustering as numpy arrays,
        ordered like ``self.graph.nodes``."""
        nodes = list(self.graph.nodes)
        degree = nx.degree_centrality(self.graph)
        betweenness = nx.betweenness_centrality(self.graph)
        clustering = nx.clustering(self.graph)
        return {
            'degree_centrality': np.array([degree[node] for node in nodes]),
            'betweenness_centrality': np.array([betweenness[node] for node in nodes]),
            'clustering_coefficient': np.array([clustering[node] for node in nodes])
        }

    def extract_graph_features(self):
        """Return the weighted adjacency matrix, the all-pairs shortest-path matrix
        (Floyd-Warshall over edge weights) and the per-node centrality arrays."""
        features = {
            'adjacency_matrix': nx.adjacency_matrix(self.graph).todense(),
            'distance_matrix': nx.floyd_warshall_numpy(self.graph),
        }
        features.update(self._centrality_features())
        return features

    def extract_features(self):
        """Return composition, flexibility, CA distances and full graph features, unaggregated."""
        return {
            "amino_acid_composition": self.get_amino_acid_composition(),
            "flexibility": self.get_flexibility(),
            "distance_metrics": self.get_distance_metrics(),
            "graph_features": self.extract_graph_features()
        }

    def extract_and_aggregate_features(self):
        """Return model-ready protein features.

        ``protein_combined_features`` has length 21: one residue count for each entry of
        ``STANDARD_AMINO_ACIDS`` (0 if absent; other residue names are dropped), followed
        by the mean B-factor. Its fixed layout lets vectors from different proteins be
        stacked and compared column by column. The dense adjacency and Floyd-Warshall
        matrices are not needed here, so they are not computed.
        """
        composition = self.get_amino_acid_composition()
        flexibility = self.get_flexibility()
        distance_metrics = self.get_distance_metrics()
        centralities = self._centrality_features()

        composition_vector = [composition.get(aa, 0) for aa in self.STANDARD_AMINO_ACIDS]
        return {
            "protein_combined_features": np.array(composition_vector + [flexibility]),
            "distance_metrics": distance_metrics,
            **centralities,
        }


In [None]:
# Smoke-test the protein extractor on a single structure and dump its aggregated features.
pdb_file = "./ALB.pdb"
extractor = ProteinFeatureExtractor(pdb_file)
aggregated_features = extractor.extract_and_aggregate_features()
print(str(aggregated_features))


In [3]:
import dgl
import torch
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
from sklearn.model_selection import train_test_split

# Assume SmallMoleculeFeatureExtractor and ProteinFeatureExtractor classes are already defined.

def _stack_padded(tensors):
    """Stack same-rank tensors into one ``(len(tensors), ...)`` tensor.

    Every trailing dimension is zero-padded to the largest size among ``tensors``, so
    per-molecule matrices with different atom counts (and per-protein vectors with
    different residue counts) can share one DGL node-feature tensor. When all shapes
    already match, this is equivalent to ``torch.stack``. The dtype is taken from the
    first tensor.
    """
    max_shape = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]
    stacked = tensors[0].new_zeros((len(tensors), *max_shape))
    for i, t in enumerate(tensors):
        stacked[(i, *(slice(0, size) for size in t.shape))] = t
    return stacked


def _edge_list_to_tensors(pairs):
    """Convert ``[(src, dst), ...]`` into the ``(src_ids, dst_ids)`` int64 tensor pair
    accepted by ``dgl.heterograph``; an empty list yields two empty tensors."""
    src = torch.tensor([s for s, _ in pairs], dtype=torch.int64)
    dst = torch.tensor([d for _, d in pairs], dtype=torch.int64)
    return src, dst


def create_heterogeneous_graph(parquet_file_path, protein_files, chunk_size=100000, test_size=0.99,
                               random_state=None):
    """Build, save and return a DGL heterograph linking molecules, building blocks and proteins.

    Args:
        parquet_file_path: parquet file with ``buildingblock{1,2,3}_smiles``,
            ``molecule_smiles``, ``protein_name`` and ``binds`` columns.
        protein_files: mapping ``protein_name -> PDB path``. Keys must match the
            ``protein_name`` values in the parquet file.
        chunk_size: rows per record batch when reading the parquet file.
        test_size: fraction of rows DISCARDED by the stratified split; the graph is built
            from the remaining ``1 - test_size`` fraction.
        random_state: seed forwarded to ``train_test_split`` (None = unseeded, as before).

    Graph layout:
        * One ``small_molecule`` node per distinct sampled ``molecule_smiles`` (feature:
          concatenated fingerprints). It has 'contains' edges to its three
          ``building_block`` nodes, which are deduplicated by SMILES.
        * Each molecule or building block gets one node of each descriptor type
          (molecular_descriptor, physicochemical_properties, adjacency_matrix,
          distance_matrix), linked by a 'has' edge.
        * One ``protein`` node per entry of ``protein_files``, with 'has' edges to its
          distance_metrics / centrality / clustering nodes.
        * ``('small_molecule', 'binds', 'protein')`` edges exist for sampled rows with a
          truthy ``binds``, pointing at that row's own protein.
        Variable-size features are zero-padded within each node type.

    Side effects: writes ``heterogeneous_graph.dgl``, ``building_block_indices.parquet``
    (columns ``smiles``, ``node_id``) and ``protein_indices.parquet`` (columns
    ``protein_name``, ``node_id``) to the working directory.

    Raises:
        KeyError: if a sampled row names a protein that is not in ``protein_files``.
    """
    columns = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles',
               'molecule_smiles', 'protein_name', 'binds']
    parquet_file = pq.ParquetFile(parquet_file_path)

    print('Loading all data')
    all_data = pd.concat(
        [batch.to_pandas()
         for batch in parquet_file.iter_batches(batch_size=chunk_size, columns=columns)],
        ignore_index=True,
    )
    print('Completed loading all data')

    # Stratify on (protein, label) so the kept sample preserves each protein's binder ratio.
    print('Setting up sampling')
    strata = all_data['protein_name'].astype(str) + '_' + all_data['binds'].astype(str)
    stratified_sample, _ = train_test_split(
        all_data, test_size=test_size, stratify=strata, random_state=random_state)
    print('Finished sampling data')

    # Descriptor node type -> key in SmallMoleculeFeatureExtractor.flatten_features().
    molecule_feature_types = {
        'molecular_descriptor': 'molecular_descriptors',
        'physicochemical_properties': 'physicochemical_properties',
        'adjacency_matrix': 'adjacency_matrix',
        'distance_matrix': 'distance_matrix',
    }
    # Node types whose name is also the key in extract_and_aggregate_features().
    protein_feature_types = ('distance_metrics', 'degree_centrality',
                             'betweenness_centrality', 'clustering_coefficient')

    node_data = {ntype: [] for ntype in ('building_block', 'small_molecule',
                                         *molecule_feature_types,
                                         'protein', *protein_feature_types)}
    edge_data = {
        ('building_block', 'has', 'molecular_descriptor'): [],
        ('building_block', 'has', 'physicochemical_properties'): [],
        ('building_block', 'has', 'adjacency_matrix'): [],
        ('building_block', 'has', 'distance_matrix'): [],
        ('small_molecule', 'has', 'molecular_descriptor'): [],
        ('small_molecule', 'has', 'physicochemical_properties'): [],
        ('small_molecule', 'has', 'adjacency_matrix'): [],
        ('small_molecule', 'has', 'distance_matrix'): [],
        ('small_molecule', 'contains', 'building_block'): [],
        ('small_molecule', 'binds', 'protein'): [],
        ('protein', 'has', 'distance_metrics'): [],
        ('protein', 'has', 'degree_centrality'): [],
        ('protein', 'has', 'betweenness_centrality'): [],
        ('protein', 'has', 'clustering_coefficient'): []
    }

    node_indices = {ntype: 0 for ntype in node_data}
    building_block_index_map = {}
    small_molecule_index_map = {}
    protein_index_map = {}

    def add_node(ntype, feature):
        """Append a node of ``ntype`` carrying ``feature`` (copied to a tensor); return its id."""
        idx = node_indices[ntype]
        node_data[ntype].append((idx, {'feature': torch.tensor(feature)}))
        node_indices[ntype] += 1
        return idx

    def add_molecule(owner_type, smiles):
        """Add an ``owner_type`` node for ``smiles``, plus its descriptor nodes and 'has' edges."""
        feats = SmallMoleculeFeatureExtractor(smiles).flatten_features()
        owner_idx = add_node(owner_type, feats['structural_fingerprints'])
        for ntype, key in molecule_feature_types.items():
            edge_data[(owner_type, 'has', ntype)].append((owner_idx, add_node(ntype, feats[key])))
        return owner_idx

    print('Starting protein node creation')
    for protein_name, pdb_file in protein_files.items():
        features = ProteinFeatureExtractor(pdb_file).extract_and_aggregate_features()
        protein_idx = add_node('protein', features['protein_combined_features'])
        for ntype in protein_feature_types:
            edge_data[('protein', 'has', ntype)].append((protein_idx, add_node(ntype, features[ntype])))
        protein_index_map[protein_name] = protein_idx
    print('Completed protein node creation')

    print('Starting processing samples')
    for row in stratified_sample.itertuples(index=False):
        mol_idx = small_molecule_index_map.get(row.molecule_smiles)
        if mol_idx is None:
            mol_idx = add_molecule('small_molecule', row.molecule_smiles)
            small_molecule_index_map[row.molecule_smiles] = mol_idx
            for bb_smiles in (row.buildingblock1_smiles, row.buildingblock2_smiles,
                              row.buildingblock3_smiles):
                bb_idx = building_block_index_map.get(bb_smiles)
                if bb_idx is None:
                    bb_idx = add_molecule('building_block', bb_smiles)
                    building_block_index_map[bb_smiles] = bb_idx
                edge_data[('small_molecule', 'contains', 'building_block')].append((mol_idx, bb_idx))

        # Link to THIS row's protein (previously a stale loop variable pointed every
        # edge at the last protein loaded).
        if row.binds:
            edge_data[('small_molecule', 'binds', 'protein')].append(
                (mol_idx, protein_index_map[row.protein_name]))
    print('Completed processing samples')

    g = dgl.heterograph(
        {etype: _edge_list_to_tensors(pairs) for etype, pairs in edge_data.items()},
        # Explicit counts: DGL would otherwise infer them from the largest id used in an
        # edge, dropping nodes (e.g. a protein no sampled molecule binds).
        num_nodes_dict={ntype: len(nodes) for ntype, nodes in node_data.items()},
    )

    # Node ids were assigned sequentially, so list order equals node-id order.
    for ntype, nodes in node_data.items():
        if nodes:
            g.nodes[ntype].data['feature'] = _stack_padded([attrs['feature'] for _, attrs in nodes])

    dgl.save_graphs("heterogeneous_graph.dgl", [g])

    # Explicit (name, node_id) columns so load_heterogeneous_graph can rebuild the maps exactly.
    pd.DataFrame({'smiles': list(building_block_index_map),
                  'node_id': list(building_block_index_map.values())}
                 ).to_parquet("building_block_indices.parquet", index=False)
    pd.DataFrame({'protein_name': list(protein_index_map),
                  'node_id': list(protein_index_map.values())}
                 ).to_parquet("protein_indices.parquet", index=False)

    return g


def load_heterogeneous_graph():
    """Load the graph and index maps written by ``create_heterogeneous_graph``.

    Returns:
        ``(graph, building_block_index_map, protein_index_map)``, where the maps are
        ``{smiles: node_id}`` and ``{protein_name: node_id}``.
    """
    graphs, _ = dgl.load_graphs("heterogeneous_graph.dgl")

    building_block_df = pd.read_parquet("building_block_indices.parquet")
    protein_df = pd.read_parquet("protein_indices.parquet")

    building_block_index_map = dict(zip(building_block_df['smiles'].tolist(),
                                        building_block_df['node_id'].tolist()))
    protein_index_map = dict(zip(protein_df['protein_name'].tolist(),
                                 protein_df['node_id'].tolist()))

    return graphs[0], building_block_index_map, protein_index_map

# Usage example: build the graph from the training parquet and the three target
# structures, then reload it (and its index maps) from disk.
parquet_file_path = './train.parquet'
protein_files = dict(
    BRD4='./BRD4.pdb',
    HSA='./ALB.pdb',
    sEH='./EPH.pdb',
)

g = create_heterogeneous_graph(parquet_file_path, protein_files)
loaded_g, building_block_index_map, protein_index_map = load_heterogeneous_graph()


  from .autonotebook import tqdm as notebook_tqdm


Loading all data
Completed loading all data
Setting up sampling
Finished sampling data
Starting protein node creation
Processing chain: A




Processing chain: A
Processing chain: B
Processing chain: A
Completed protein node creation
Starting processing samples


TypeError: 'int' object is not iterable