In [None]:
import pandas as pd
import pyarrow.parquet as pq
import numpy as np

# Read the file in chunks
def process_chunk(chunk, unique_building_blocks, unique_molecules):
    # Update the unique building blocks set
    unique_building_blocks.update(chunk['buildingblock1_smiles'].unique())
    unique_building_blocks.update(chunk['buildingblock2_smiles'].unique())
    unique_building_blocks.update(chunk['buildingblock3_smiles'].unique())
    
    # Update the unique molecules set
    unique_molecules.update(chunk['molecule_smiles'].unique())

# Load the parquet file
file_path = './train.parquet'
batch_size = 100000
parquet_file = pq.ParquetFile(file_path)

# Initialize sets to keep track of unique building blocks and molecules
unique_building_blocks = set()
unique_molecules = set()
# Set pandas display option to avoid truncation of long strings
pd.set_option('display.max_colwidth', None)

# Iterate over the parquet file in batches
num_row_groups = parquet_file.num_row_groups

# Initialize variables to store the total number of rows and rows with binds = 1
total_rows = 0
binds_1_count = 0

for i in range(num_row_groups):
    # Read a batch of rows
    row_group = parquet_file.read_row_group(i).to_pandas()

    if i == 0:
        print("First few rows from the first row group:")
        print(row_group.head())

        # Print the molecule smiles column
        print(f"Row group {i} molecule smiles:")
        print(row_group['molecule_smiles'])

        # Update the total number of rows and the count of binds = 1
    total_rows += len(row_group)
    binds_1_count += row_group['binds'].sum()
    
    # Process the current chunk
    process_chunk(row_group, unique_building_blocks, unique_molecules)

# Calculate the percentage of rows with binds = 1 for the entire dataset
percentage_binds_1 = (binds_1_count / total_rows) * 100

# Print the percentage of rows with binds = 1 for the entire dataset
print(f"Percentage of rows with binds = 1 in the entire dataset: {percentage_binds_1:.2f}%")

# Output the total unique counts
print(f"Total number of unique building blocks: {len(unique_building_blocks)}")
print(f"Total number of unique molecules: {len(unique_molecules)}")

Given that only 0.54% of the data actually, contain bindings, filter the training data down to only select molecule_smiles that bind to at least one protein but not all 3. The idea is to use a model with contrastive loss function.

In [None]:
import pandas as pd
import pyarrow.parquet as pq
import numpy as np

# Read the file in chunks
def process_chunk(chunk, unique_building_blocks, unique_molecules):
    # Update the unique building blocks set
    unique_building_blocks.update(chunk['buildingblock1_smiles'].unique())
    unique_building_blocks.update(chunk['buildingblock2_smiles'].unique())
    unique_building_blocks.update(chunk['buildingblock3_smiles'].unique())
    
    # Update the unique molecules set
    unique_molecules.update(chunk['molecule_smiles'].unique())

# Load the parquet file
file_path = './train.parquet'
parquet_file = pq.ParquetFile(file_path)

# Set pandas display option to avoid truncation of long strings
pd.set_option('display.max_colwidth', None)

# Iterate over the parquet file in batches
num_row_groups = parquet_file.num_row_groups

# Initialize variables to store the total number of rows and rows with binds = 1
total_rows = 0
binds_1_count = 0

# Initialize a dictionary to keep track of binding counts for each molecule
molecule_binding_counts = {}

# First pass: Identify molecules that bind with at least one protein and track binding counts
for i in range(num_row_groups):
    # Read a batch of rows
    row_group = parquet_file.read_row_group(i).to_pandas()
    
    # Identify molecules with binds = 1 and count their bindings
    binding_molecules = row_group[row_group['binds'] == 1]['molecule_smiles'].value_counts()
    for molecule, count in binding_molecules.items():
        if molecule not in molecule_binding_counts:
            molecule_binding_counts[molecule] = count
        else:
            molecule_binding_counts[molecule] += count

# Filter molecules that bind to at least one but not all three proteins
filtered_molecules = {molecule for molecule, count in molecule_binding_counts.items() if 1 <= count < 3}

# Second pass: Filter rows with molecules that meet the criteria
filtered_data = []

for i in range(num_row_groups):
    # Read a batch of rows
    row_group = parquet_file.read_row_group(i).to_pandas()
    
    # Filter rows where molecule_smiles is in filtered_molecules
    filtered_chunk = row_group[row_group['molecule_smiles'].isin(filtered_molecules)]
    filtered_data.append(filtered_chunk)
    
    # Update the total number of rows and the count of binds = 1 for the filtered data
    total_rows += len(filtered_chunk)
    binds_1_count += filtered_chunk['binds'].sum()

# Concatenate all filtered data
filtered_data = pd.concat(filtered_data, ignore_index=True)

# Calculate the percentage of rows with binds = 1 for the filtered data
percentage_binds_1 = (binds_1_count / total_rows) * 100

# Print the percentage of rows with binds = 1 for the filtered data
print(f"Percentage of rows with binds = 1 in the filtered dataset: {percentage_binds_1:.2f}%")

# Output the total unique counts
print(f"Total number of rows in the dataset: {len(filtered_data)}")

# Save the filtered data
filtered_data.to_parquet('filtered_train.parquet')


In [3]:
import pandas as pd
import numpy as np
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split

# Load the filtered data
filtered_data = pd.read_parquet('filtered_train.parquet')

# Get unique molecule_smiles
unique_molecule_smiles = filtered_data['molecule_smiles'].unique()

# Shuffle the unique molecule_smiles
np.random.shuffle(unique_molecule_smiles)

# Calculate the number of molecules for each split
num_molecules = len(unique_molecule_smiles)
train_size = int(0.8 * num_molecules)
test_size = int(0.15 * num_molecules)
val_size = num_molecules - train_size - test_size

# Split the unique molecule_smiles
train_smiles = unique_molecule_smiles[:train_size]
test_smiles = unique_molecule_smiles[train_size:train_size + test_size]
val_smiles = unique_molecule_smiles[train_size + test_size:]

# Filter the original data based on the splits
train_data = filtered_data[filtered_data['molecule_smiles'].isin(train_smiles)]
test_data = filtered_data[filtered_data['molecule_smiles'].isin(test_smiles)]
val_data = filtered_data[filtered_data['molecule_smiles'].isin(val_smiles)]

# Save the splits to separate Parquet files
train_data.to_parquet('train_data_temp.parquet')
test_data.to_parquet('test_data_temp.parquet')
val_data.to_parquet('val_data_temp.parquet')

print(f"Train data: {len(train_data)} rows")
print(f"Test data: {len(test_data)} rows")
print(f"Validation data: {len(val_data)} rows")


Train data: 3623319 rows
Test data: 679371 rows
Validation data: 226461 rows


Update the above code to make it easier to create both positive and negative graphs.

In [2]:
import pandas as pd
import numpy as np
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split

# Load the filtered data
filtered_data = pd.read_parquet('filtered_train.parquet')

# Initialize a list to store transformed rows
transformed_rows = []

# Iterate through each group
for molecule_smiles, group in filtered_data.groupby('molecule_smiles'):
    binding_rows = group[group['binds'] == 1]
    non_binding_rows = group[group['binds'] == 0]

    if len(binding_rows) == 1 and len(non_binding_rows) == 2:
        # Duplicate the single binding row for each non-binding protein
        for _, non_bind_row in non_binding_rows.iterrows():
            bind_row = binding_rows.iloc[0].copy()
            new_row = {
                'molecule_smiles': bind_row['molecule_smiles'],
                'buildingblock1_smiles': bind_row['buildingblock1_smiles'],
                'buildingblock2_smiles': bind_row['buildingblock2_smiles'],
                'buildingblock3_smiles': bind_row['buildingblock3_smiles'],
                'binds': bind_row['protein_name'],
                'not_binds': non_bind_row['protein_name']
            }
            transformed_rows.append(new_row)
    elif len(binding_rows) == 2 and len(non_binding_rows) == 1:
        # Duplicate the single non-binding row for each binding protein
        for _, bind_row in binding_rows.iterrows():
            non_bind_row = non_binding_rows.iloc[0].copy()
            new_row = {
                'molecule_smiles': non_bind_row['molecule_smiles'],
                'buildingblock1_smiles': non_bind_row['buildingblock1_smiles'],
                'buildingblock2_smiles': non_bind_row['buildingblock2_smiles'],
                'buildingblock3_smiles': non_bind_row['buildingblock3_smiles'],
                'binds': non_bind_row['protein_name'],
                'not_binds': bind_row['protein_name']
            }
            transformed_rows.append(new_row)
    else:
        # If the group does not meet the above criteria, keep it unchanged
        for _, row in group.iterrows():
            new_row = {
                'molecule_smiles': row['molecule_smiles'],
                'buildingblock1_smiles': row['buildingblock1_smiles'],
                'buildingblock2_smiles': row['buildingblock2_smiles'],
                'buildingblock3_smiles': row['buildingblock3_smiles'],
                'binds': row['protein_name'],
                'not_binds': None
            }
            transformed_rows.append(new_row)

# Create a DataFrame from the transformed rows
transformed_data = pd.DataFrame(transformed_rows)

# Get unique molecule_smiles
unique_molecule_smiles = transformed_data['molecule_smiles'].unique()

# Shuffle the unique molecule_smiles
np.random.shuffle(unique_molecule_smiles)

# Calculate the number of molecules for each split
num_molecules = len(unique_molecule_smiles)
train_size = int(0.8 * num_molecules)
test_size = int(0.15 * num_molecules)
val_size = num_molecules - train_size - test_size

# Split the unique molecule_smiles
train_smiles = unique_molecule_smiles[:train_size]
test_smiles = unique_molecule_smiles[train_size:train_size + test_size]
val_smiles = unique_molecule_smiles[train_size + test_size:]

# Filter the transformed data based on the splits
train_data = transformed_data[transformed_data['molecule_smiles'].isin(train_smiles)]
test_data = transformed_data[transformed_data['molecule_smiles'].isin(test_smiles)]
val_data = transformed_data[transformed_data['molecule_smiles'].isin(val_smiles)]

# Save the splits to separate Parquet files
train_data.to_parquet('train_data.parquet')
test_data.to_parquet('test_data.parquet')
val_data.to_parquet('val_data.parquet')

print(f"Train data: {len(train_data)} rows")
print(f"Test data: {len(test_data)} rows")
print(f"Validation data: {len(val_data)} rows")


Train data: 2415546 rows
Test data: 452914 rows
Validation data: 150974 rows


Do some verification of the generated data in the above files

In [6]:
import pandas as pd
import pyarrow.parquet as pq
import numpy as np

def verify_data(file_path, original_file_path, num_samples=10, chunk_size=100000):
    # Load the split data
    data = pd.read_parquet(file_path)

    # Get unique molecule_smiles
    unique_molecule_smiles = data['molecule_smiles'].unique()

    # Randomly sample 10 molecule_smiles
    sampled_smiles = np.random.choice(unique_molecule_smiles, num_samples, replace=False)

    print(f"Sampled molecule_smiles from {file_path}:")
    print(sampled_smiles)

    # Print rows corresponding to sampled molecule_smiles from the split file
    print(f"\nRows from {file_path}:")
    sampled_data = data[data['molecule_smiles'].isin(sampled_smiles)]
    print(sampled_data[['molecule_smiles', 'binds', 'not_binds']])

    # Read the original file in chunks
    print(f"\nCorresponding rows from the original file:")
    parquet_file = pq.ParquetFile(original_file_path)
    num_row_groups = parquet_file.num_row_groups

    for i in range(num_row_groups):
        # Read a batch of rows
        row_group = parquet_file.read_row_group(i).to_pandas()

        # Filter rows corresponding to sampled molecule_smiles
        original_sampled_data = row_group[row_group['molecule_smiles'].isin(sampled_smiles)]
        if not original_sampled_data.empty:
            print(original_sampled_data[['molecule_smiles', 'protein_name', 'binds']])

# File paths
train_file_path = 'train_data.parquet'
test_file_path = 'test_data.parquet'
val_file_path = 'val_data.parquet'
original_file_path = 'filtered_train.parquet'

# Verify train data
print("Verifying train data:")
verify_data(train_file_path, original_file_path)

# Verify test data
print("\nVerifying test data:")
verify_data(test_file_path, original_file_path)

# Verify validation data
print("\nVerifying validation data:")
verify_data(val_file_path, original_file_path)


Verifying train data:
Sampled molecule_smiles from train_data.parquet:
['O=C(N[Dy])C1CCC(CNc2nc(NCCOc3ccc(F)c(F)c3)nc(NCC(F)(F)C(F)(F)F)n2)CC1'
 'CCSCCNc1nc(NCC2CCC(C(=O)N[Dy])CC2)nc(Nc2c(F)cccc2F)n1'
 'Cn1cncc1C(CNc1nc(NCC=Cc2cccnc2)nc(Nc2cc(F)c(Br)cc2C(=O)N[Dy])n1)N1CCCC1'
 'CC(C)(CCC#N)CNc1nc(NCCC2CCOC2)nc(N[C@@H](CC(=O)N[Dy])Cc2ccc(Br)cc2)n1'
 'CCOC(=O)c1cnc(Nc2nc(NCC3CCCCC(F)(F)C3)nc(NC(CC(C)C)C(=O)N[Dy])n2)cn1'
 'O=C(N[Dy])C1CCC(CNc2nc(Nc3ccc(F)c([N+](=O)[O-])c3)nc(Nc3nc4ccccc4o3)n2)CC1'
 'CN(c1nc(NCc2ccc[n+]([O-])c2)nc(NCC(C)(C)CCC#N)n1)[C@@H](CC1CCCCC1)C(=O)N[Dy]'
 'O=C1CCC(Nc2nc(NCC3CCC(C(=O)N[Dy])CC3)nc(Nc3ccc(-c4ncc[nH]4)cc3)n2)CC1'
 'O=C(N[Dy])c1ccc(Nc2nc(NCc3ccc(CN4CCCC4=O)cc3)nc(NCC3(Cc4ccccc4)CC3)n2)cc1'
 'CNC(=O)c1cc(Oc2ccc(Nc3nc(NCc4nc5c(s4)CCC5)nc(N[C@@H](CC(=O)N[Dy])c4cccc(Cl)c4Cl)n3)cc2)ccn1']

Rows from train_data.parquet:
                                                                                     molecule_smiles  \
199660                         CC(C)(CCC

# Exhaustive List of Features for Small Molecules

## Molecular Descriptors:
- Molecular weight
- Number of atoms
- Number of bonds
- Number of aromatic rings
- Number of rotatable bonds
- Topological polar surface area (TPSA)
- LogP (octanol-water partition coefficient)

## Atom-Level Features:
- Atom types (e.g., C, H, O, N, S)
- Hybridization states (sp, sp2, sp3)
- Formal charge
- Aromaticity
- Degree (number of bonds to the atom)
- Implicit and explicit hydrogen counts
- Chirality

## Bond-Level Features:
- Bond types (single, double, triple, aromatic)
- Conjugation
- Ring membership
- Stereo configuration (cis/trans)

## Graph-Based Features:
- Adjacency matrix
- Distance matrix
- Graph Laplacian

## Physicochemical Properties:
- Hydrogen bond donors and acceptors
- Molecular refractivity
- Molar volume
- Electronegativity
- Electron affinity

## Structural Fingerprints:
- MACCS keys
- Morgan fingerprints
- ECFP (Extended Connectivity Fingerprints)
- RDKIT fingerprints


In [9]:
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, AllChem, rdmolops
from rdkit.DataStructs import ConvertToNumpyArray
import numpy as np

# Define encoding schemes outside the class
ATOM_TYPES = ['C', 'H', 'O', 'N', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'B']
HYBRIDIZATION_STATES = ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2']
CHIRAL_TAGS = ['CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW', 'CHI_OTHER']
BOND_TYPES = ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC']
STEREO_CONFIGURATIONS = ['STEREONONE', 'STEREOZ', 'STEREOE', 'STEREOCIS', 'STEREOTRANS']

class SmallMoleculeFeatureExtractor:
    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = Chem.MolFromSmiles(smiles)

    def get_molecular_descriptors(self):
        descriptors = {
            'molecular_weight': Descriptors.MolWt(self.mol),
            'num_atoms': self.mol.GetNumAtoms(),
            'num_bonds': self.mol.GetNumBonds(),
            'num_aromatic_rings': rdMolDescriptors.CalcNumAromaticRings(self.mol),
            'num_rotatable_bonds': Descriptors.NumRotatableBonds(self.mol),
            'tpsa': Descriptors.TPSA(self.mol),
            'logp': Descriptors.MolLogP(self.mol)
        }
        return descriptors

    def one_hot_encode(self, value, categories):
        encoding = [0] * len(categories)
        if value in categories:
            encoding[categories.index(value)] = 1
        return encoding

    def get_atom_level_features(self):
        atom_features = []
        for atom in self.mol.GetAtoms():
            atom_features.append([
                self.one_hot_encode(atom.GetSymbol(), ATOM_TYPES),
                self.one_hot_encode(str(atom.GetHybridization()), HYBRIDIZATION_STATES),
                atom.GetFormalCharge(),
                atom.GetIsAromatic(),
                atom.GetDegree(),
                atom.GetImplicitValence(),
                atom.GetTotalNumHs(),
                self.one_hot_encode(str(atom.GetChiralTag()), CHIRAL_TAGS)
            ])
        return atom_features

    def get_bond_level_features(self):
        bond_features = []
        for bond in self.mol.GetBonds():
            bond_features.append([
                self.one_hot_encode(str(bond.GetBondType()), BOND_TYPES),
                bond.GetIsConjugated(),
                bond.IsInRing(),
                self.one_hot_encode(str(bond.GetStereo()), STEREO_CONFIGURATIONS)
            ])
        return bond_features

    def get_graph_based_features(self):
        adj_matrix = rdmolops.GetAdjacencyMatrix(self.mol)
        dist_matrix = rdmolops.GetDistanceMatrix(self.mol)
        return {
            'adjacency_matrix': adj_matrix,
            'distance_matrix': dist_matrix,
        }

    def get_physicochemical_properties(self):
        properties = {
            'h_bond_donors': Descriptors.NumHDonors(self.mol),
            'h_bond_acceptors': Descriptors.NumHAcceptors(self.mol),
            'molecular_refractivity': Descriptors.MolMR(self.mol),
            'molar_volume': Descriptors.MolLogP(self.mol) / Descriptors.MolWt(self.mol)
        }
        return properties

    def get_structural_fingerprints(self):
        maccs_keys = AllChem.GetMACCSKeysFingerprint(self.mol)
        morgan_fp = AllChem.GetMorganFingerprintAsBitVect(self.mol, 2)
        rdk_fp = Chem.RDKFingerprint(self.mol)

        maccs_keys_np = np.zeros((1,))
        ConvertToNumpyArray(maccs_keys, maccs_keys_np)

        morgan_fp_np = np.zeros((1,))
        ConvertToNumpyArray(morgan_fp, morgan_fp_np)

        rdk_fp_np = np.zeros((1,))
        ConvertToNumpyArray(rdk_fp, rdk_fp_np)
        
        return {
            'maccs_keys': maccs_keys_np,
            'morgan_fp': morgan_fp_np,
            'rdkit_fp': rdk_fp_np
        }

    def extract_features(self):
        features = {
            'molecular_descriptors': self.get_molecular_descriptors(),
            'atom_level_features': self.get_atom_level_features(),
            'bond_level_features': self.get_bond_level_features(),
            'graph_based_features': self.get_graph_based_features(),
            'physicochemical_properties': self.get_physicochemical_properties(),
            'structural_fingerprints': self.get_structural_fingerprints()
        }
        return features

    def flatten_features(self):
        # Extract individual features
        molecular_descriptors = self.get_molecular_descriptors()
        physicochemical_properties = self.get_physicochemical_properties()
        structural_fingerprints = self.get_structural_fingerprints()
        graph_based_features = self.get_graph_based_features()
    
        # Flatten the structural fingerprints
        flattened_structural_fingerprints = np.concatenate([
            structural_fingerprints['maccs_keys'],
            structural_fingerprints['morgan_fp'],
            structural_fingerprints['rdkit_fp']
        ])
    
        # Convert molecular descriptors and physicochemical properties to arrays
        molecular_descriptors_array = np.array(list(molecular_descriptors.values()))
        physicochemical_properties_array = np.array(list(physicochemical_properties.values()))

        # Extract adjacency and distance matrices
        adjacency_matrix = graph_based_features['adjacency_matrix']
        distance_matrix = graph_based_features['distance_matrix']
    
        return {
            'molecular_descriptors': molecular_descriptors_array,
            'physicochemical_properties': physicochemical_properties_array,
            'structural_fingerprints': flattened_structural_fingerprints,
            'adjacency_matrix': adjacency_matrix,
            'distance_matrix': distance_matrix
        }


In [24]:
smiles = "C#CC[C@@H](CC(=O)O)NC(=O)OCC1c2ccccc2-c2ccccc21"
extractor = SmallMoleculeFeatureExtractor(smiles)
features = extractor.flatten_features()
for feature, value in features.items():
    print(f"{feature}: {value}")

molecular_descriptors: [349.386   26.      28.       2.       6.      75.63     3.3917]
physicochemical_properties: [2.00000000e+00 3.00000000e+00 9.76955000e+01 9.70760133e-03]
structural_fingerprints: [0. 0. 0. ... 1. 1. 1.]
adjacency_matrix: [[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 

# Feature Extraction from Protein Structure PDB File

## Structural Features

### Amino Acid Composition:
- Frequency of each amino acid type in the binding site.
- Frequency of amino acid types in the entire protein.

### Secondary Structure:
- Percentage of alpha-helices, beta-sheets, and random coils in the binding site.
- Secondary structure elements around the binding site.

### Tertiary Structure:
- 3D coordinates of the binding site.
- Distance between key residues in the binding site.

### Binding Site Characteristics:
- Volume and surface area of the binding site.
- Shape descriptors (e.g., sphericity, elongation).

## Physicochemical Properties

### Hydrophobicity:
- Hydrophobic and hydrophilic residue distribution in the binding site.
- Hydrophobic surface area.

### Charge Distribution:
- Number and type of charged residues (positive and negative).
- Electrostatic potential distribution.

### Polarity:
- Number of polar residues.
- Polar surface area.

### Solvent Accessibility:
- Solvent-accessible surface area (SASA) of residues in the binding site.

### Hydrogen Bonding:
- Number of potential hydrogen bond donors and acceptors.
- Hydrogen bond network in the binding site.

### Van der Waals Interactions:
- Van der Waals interaction potential of the binding site.

## Geometric Features

### Distance Metrics:
- Pairwise distances between all residues in the binding site.
- Distance to the nearest surface residue.

### Angles and Dihedrals:
- Angles and dihedral angles between residues in the binding site.

## Chemical Environment

### Residue Environment:
- Local chemical environment of each residue (e.g., neighboring residues within a certain radius).

### Ligand Interaction Sites:
- Specific interaction sites for known ligands (if available).

## Dynamic Properties

### Flexibility:
- B-factors or temperature factors indicating residue flexibility.

### Molecular Dynamics Simulations:
- Root mean square fluctuation (RMSF) of residues in the binding site.
- Conformational changes over time.

## Topological Features

### Graph-based Features:
- Protein structure represented as a graph with nodes (residues) and edges (interactions).
- Degree centrality, betweenness centrality, and clustering coefficient of residues in the binding site.

## Energy-based Features

### Binding Energy:
- Estimated binding free energy of known ligands.
- Energy components (van der Waals, electrostatic, solvation) from docking simulations.

## Protein-Ligand Interaction Features

### Docking Scores:
- Scores from molecular docking simulations with various ligands.

### Interaction Profiles:
- Interaction fingerprints summarizing the types and strengths of interactions with ligands.

## Evolutionary Features

### Conservation:
- Sequence conservation of residues in the binding site (e.g., from multiple sequence alignment).

### Mutational Impact:
- Predicted impact of mutations on binding site residues.

## Experimental Data

### Experimental Binding Data:
- Known binding affinities (e.g., Kd, Ki, IC50) for small molecules.

## Contextual Features

### Functional Annotations:
- Biological function and pathway involvement of the protein.
- Known protein-protein interactions.

## Integration and Representation

### Feature Scaling and Normalization:
- Standardize and normalize features for input into the deep learning model.


In [34]:
from Bio.PDB import PDBParser, is_aa, NeighborSearch
import numpy as np
import networkx as nx

class ProteinFeatureExtractor:
    def __init__(self, pdb_file):
        self.pdb_file = pdb_file
        self.structure = self.load_structure()
        self.ligand_resnames = self.detect_ligands()
        self.graph = self.construct_graph()

    def load_structure(self):
        # Load the PDB structure
        parser = PDBParser()
        structure = parser.get_structure('protein', self.pdb_file)
        return structure

    def detect_ligands(self):
        # Detect ligand residue names by excluding standard amino acids and water
        ligands = set()
        standard_amino_acids = {'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 
                                'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'}
        water_residues = {'HOH'}
        for residue in self.structure.get_residues():
            resname = residue.resname
            if resname not in standard_amino_acids and resname not in water_residues:
                ligands.add(resname)
        return list(ligands)

    def get_amino_acid_composition(self):
        # Get the composition of amino acids in the protein
        amino_acids = [residue.resname for residue in self.structure.get_residues() if residue.id[0] == ' ']
        aa_counts = {aa: amino_acids.count(aa) for aa in set(amino_acids)}
        return aa_counts


    def get_flexibility(self):
        # Calculate the flexibility of the protein based on B-factors
        flexibility = []
        for atom in self.structure.get_atoms():
            flexibility.append(atom.bfactor)
        return np.mean(flexibility)

    def get_distance_metrics(self):
        # Calculate distance metrics between residues in the protein
        distances = []
        for chain in self.structure.get_chains():
            print(f"Processing chain: {chain.id}")
            residues = [res for res in chain if 'CA' in res.child_dict]  # Filter residues with 'CA' atom
            for i, res1 in enumerate(residues):
                ca1 = res1.child_dict.get('CA')
                if ca1 is None:
                    print(f"Residue {res1} does not have a CA atom.")
                    continue
                for j, res2 in enumerate(residues):
                    if i < j:
                        ca2 = res2.child_dict.get('CA')
                        if ca2 is None:
                            print(f"Residue {res2} does not have a CA atom.")
                            continue
                        try:
                            distance = ca1 - ca2
                            distances.append(distance)
                        except KeyError as e:
                            print(f"Error calculating distance: {e}")
        return distances

    def construct_graph(self, cutoff=4.0):
        # Initialize an undirected graph
        G = nx.Graph()

        # Add nodes for each residue
        for chain in self.structure.get_chains():
            for residue in chain:
                if is_aa(residue):
                    G.add_node(residue.id, residue=residue)

        # Add edges based on distance cutoff
        atoms = list(self.structure.get_atoms())
        ns = NeighborSearch(atoms)
        for atom in atoms:
            if atom.element == 'H':  # Skip hydrogen atoms
                continue
            neighbors = ns.search(atom.coord, cutoff)
            for neighbor in neighbors:
                if neighbor.element == 'H':  # Skip hydrogen atoms
                    continue
                res1 = atom.get_parent()
                res2 = neighbor.get_parent()
                if res1 != res2:
                    G.add_edge(res1.id, res2.id, weight=atom - neighbor)

        return G

    def extract_graph_features(self):
        # Adjacency matrix
        adjacency_matrix = nx.adjacency_matrix(self.graph).todense()

        # Distance matrix (Floyd-Warshall algorithm)
        distance_matrix = nx.floyd_warshall_numpy(self.graph)

        # Degree centrality
        degree_centrality = nx.degree_centrality(self.graph)

        # Betweenness centrality
        betweenness_centrality = nx.betweenness_centrality(self.graph)

        # Clustering coefficient
        clustering_coefficient = nx.clustering(self.graph)

        # Ensure features are in a consistent order
        nodes = list(self.graph.nodes)
        degree_centrality = np.array([degree_centrality[node] for node in nodes])
        betweenness_centrality = np.array([betweenness_centrality[node] for node in nodes])
        clustering_coefficient = np.array([clustering_coefficient[node] for node in nodes])

        # Aggregate features into a dictionary
        features = {
            'adjacency_matrix': adjacency_matrix,
            'distance_matrix': distance_matrix,
            'degree_centrality': degree_centrality,
            'betweenness_centrality': betweenness_centrality,
            'clustering_coefficient': clustering_coefficient
        }

        return features

    def extract_features(self):
        # Extract various features from the protein structure
        amino_acid_composition = self.get_amino_acid_composition()
        flexibility = self.get_flexibility()
        distance_metrics = self.get_distance_metrics()
        graph_features = self.extract_graph_features()

        features = {
            "amino_acid_composition": amino_acid_composition,
            "flexibility": flexibility,
            "distance_metrics": distance_metrics,
            "graph_features": graph_features
        }
        return features

    def extract_and_aggregate_features(self):
        # Extract various features from the protein structure
        amino_acid_composition = self.get_amino_acid_composition()
        flexibility = self.get_flexibility()
        distance_metrics = self.get_distance_metrics()
        graph_features = self.extract_graph_features()
    
        # Combine amino acid composition and flexibility into a single array
        amino_acid_comp_values = list(amino_acid_composition.values())
        combined_features = amino_acid_comp_values + [flexibility]
    
        features = {
            "protein_combined_features": np.array(combined_features),
            "distance_metrics": distance_metrics,
            "distance_matrix": graph_features["distance_matrix"],
            "adjacency_matrix": graph_features["adjacency_matrix"],
            "degree_centrality": graph_features["degree_centrality"],
            "betweenness_centrality": graph_features["betweenness_centrality"],
            "clustering_coefficient": graph_features["clustering_coefficient"],
        }
        return features


In [35]:
pdb_file = "./ALB.pdb"
extractor = ProteinFeatureExtractor(pdb_file)
aggregated_features = extractor.extract_and_aggregate_features()
print(aggregated_features["adjacency_matrix"])




Processing chain: A
Processing chain: B
[[0.        3.6402905 3.8100557 ... 0.        0.        0.       ]
 [3.6402905 0.        3.7282476 ... 0.        0.        0.       ]
 [3.8100557 3.7282476 0.        ... 0.        0.        0.       ]
 ...
 [0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]
 [0.        0.        0.        ... 0.        0.        0.       ]]


In [39]:
import dgl
import torch
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import time
from tqdm import tqdm
from torch.nn.functional import pad

# Assume SmallMoleculeFeatureExtractor and ProteinFeatureExtractor classes are already defined.

def create_heterogeneous_graphs(parquet_file_path, protein_files, variant):
    # Load parquet file and initialize unique sets
    parquet_file = pq.ParquetFile(parquet_file_path)
    num_row_groups = parquet_file.num_row_groups

    all_data = []
    print('Loading all data')
    for i in range(num_row_groups):
        row_group = parquet_file.read_row_group(i).to_pandas()
        all_data.append(row_group)
    all_data = pd.concat(all_data, ignore_index=True)
    print('Completed loading all data')

    print(f'Final sample size {len(all_data)}')

    # Initialize node and edge lists for positive and negative graphs
    node_data = {
        'building_block': [],
        'small_molecule': [],
        'molecular_descriptor': [],
        'physicochemical_properties': [],
        'adjacency_matrix': [],
        'distance_matrix': [],
        'protein': [],
        'distance_metrics': [],
        'degree_centrality': [],
        'betweenness_centrality': [],
        'clustering_coefficient': []
    }

    edge_data_pos = {
        ('building_block', 'has', 'molecular_descriptor'): [],
        ('building_block', 'has', 'physicochemical_properties'): [],
        ('building_block', 'has', 'adjacency_matrix'): [],
        ('building_block', 'has', 'distance_matrix'): [],
        ('small_molecule', 'has', 'molecular_descriptor'): [],
        ('small_molecule', 'has', 'physicochemical_properties'): [],
        ('small_molecule', 'has', 'adjacency_matrix'): [],
        ('small_molecule', 'has', 'distance_matrix'): [],
        ('small_molecule', 'contains', 'building_block'): [],
        ('small_molecule', 'binds', 'protein'): [],
        ('protein', 'has', 'distance_metrics'): [],
        ('protein', 'has', 'degree_centrality'): [],
        ('protein', 'has', 'betweenness_centrality'): [],
        ('protein', 'has', 'clustering_coefficient'): []
    }

    edge_data_neg = {
        ('building_block', 'has', 'molecular_descriptor'): [],
        ('building_block', 'has', 'physicochemical_properties'): [],
        ('building_block', 'has', 'adjacency_matrix'): [],
        ('building_block', 'has', 'distance_matrix'): [],
        ('small_molecule', 'has', 'molecular_descriptor'): [],
        ('small_molecule', 'has', 'physicochemical_properties'): [],
        ('small_molecule', 'has', 'adjacency_matrix'): [],
        ('small_molecule', 'has', 'distance_matrix'): [],
        ('small_molecule', 'contains', 'building_block'): [],
        ('small_molecule', 'binds', 'protein'): [],
        ('protein', 'has', 'distance_metrics'): [],
        ('protein', 'has', 'degree_centrality'): [],
        ('protein', 'has', 'betweenness_centrality'): [],
        ('protein', 'has', 'clustering_coefficient'): []
    }

    # Dictionaries to keep track of indices and for saving later
    node_indices = {ntype: 0 for ntype in node_data.keys()}
    building_block_index_map = {}
    protein_index_map = {}
    
    print('Starting protein node creation')
    # Load protein data
    for protein_name, pdb_file in protein_files.items():
        extractor = ProteinFeatureExtractor(pdb_file)
        features = extractor.extract_and_aggregate_features()
        protein_idx = node_indices['protein']
        node_data['protein'].append((protein_idx, {'feature': torch.tensor(features['protein_combined_features'])}))
        node_data['distance_metrics'].append((protein_idx, {'feature': torch.tensor(features['distance_metrics'])}))
        node_data['degree_centrality'].append((protein_idx, {'feature': torch.tensor(features['degree_centrality'])}))
        node_data['betweenness_centrality'].append((protein_idx, {'feature': torch.tensor(features['betweenness_centrality'])}))
        node_data['clustering_coefficient'].append((protein_idx, {'feature': torch.tensor(features['clustering_coefficient'])}))

        #Add edges for protein features
        edge_data_pos[('protein', 'has', 'distance_metrics')].append((protein_idx, protein_idx))
        edge_data_pos[('protein', 'has', 'degree_centrality')].append((protein_idx, protein_idx))
        edge_data_pos[('protein', 'has', 'betweenness_centrality')].append((protein_idx, protein_idx))
        edge_data_pos[('protein', 'has', 'clustering_coefficient')].append((protein_idx, protein_idx))

        edge_data_neg[('protein', 'has', 'distance_metrics')].append((protein_idx, protein_idx))
        edge_data_neg[('protein', 'has', 'degree_centrality')].append((protein_idx, protein_idx))
        edge_data_neg[('protein', 'has', 'betweenness_centrality')].append((protein_idx, protein_idx))
        edge_data_neg[('protein', 'has', 'clustering_coefficient')].append((protein_idx, protein_idx))
        
        protein_index_map[protein_name] = protein_idx
        node_indices['protein'] += 1

    print('Completed protein node creation')

    # Process the sample
    print('Starting processing samples')
    start_time = time.time()
    log_interval = 120
    total_rows = all_data.shape[0]

    for sampleIndex, row in tqdm(all_data.iterrows(), total=len(all_data)):
        # Check elapsed time
        elapsed_time = time.time() - start_time
        if elapsed_time >= log_interval:
            print(f"Processing index: {sampleIndex} of {total_rows}")
            start_time = time.time()  # Reset start time

        building_blocks = [
            row['buildingblock1_smiles'],
            row['buildingblock2_smiles'],
            row['buildingblock3_smiles']
        ]
        molecule_smiles = row['molecule_smiles']
        binds_protein = row['binds']
        not_binds_protein = row['not_binds']

        # Process small molecule
        mol_extractor = SmallMoleculeFeatureExtractor(molecule_smiles)
        mol_features = mol_extractor.flatten_features()
        mol_idx = node_indices['small_molecule']
        mol_desc_index = node_indices['molecular_descriptor']
        phys_index = node_indices['physicochemical_properties']
        adj_index = node_indices['adjacency_matrix']
        dist_index = node_indices['distance_matrix']

        node_data['small_molecule'].append((mol_idx, {'feature': torch.tensor(mol_features['structural_fingerprints'])}))
        node_data['molecular_descriptor'].append((mol_desc_index, {'feature': torch.tensor(mol_features['molecular_descriptors'])}))
        node_data['physicochemical_properties'].append((phys_index, {'feature': torch.tensor(mol_features['physicochemical_properties'])}))
        node_data['adjacency_matrix'].append((adj_index, {'feature': torch.tensor(mol_features['adjacency_matrix'])}))
        node_data['distance_matrix'].append((dist_index, {'feature': torch.tensor(mol_features['distance_matrix'])}))

        node_indices['small_molecule'] += 1
        node_indices['molecular_descriptor'] += 1
        node_indices['physicochemical_properties'] += 1
        node_indices['adjacency_matrix'] += 1
        node_indices['distance_matrix'] += 1

        # Create 'has' edges for small molecule
        edge_data_pos[('small_molecule', 'has', 'molecular_descriptor')].append((mol_idx, mol_desc_index))
        edge_data_pos[('small_molecule', 'has', 'physicochemical_properties')].append((mol_idx, phys_index))
        edge_data_pos[('small_molecule', 'has', 'adjacency_matrix')].append((mol_idx, adj_index))
        edge_data_pos[('small_molecule', 'has', 'distance_matrix')].append((mol_idx, dist_index))

        edge_data_neg[('small_molecule', 'has', 'molecular_descriptor')].append((mol_idx, mol_desc_index))
        edge_data_neg[('small_molecule', 'has', 'physicochemical_properties')].append((mol_idx, phys_index))
        edge_data_neg[('small_molecule', 'has', 'adjacency_matrix')].append((mol_idx, adj_index))
        edge_data_neg[('small_molecule', 'has', 'distance_matrix')].append((mol_idx, dist_index))

        # Process building blocks
        for bb_smiles in building_blocks:
            if bb_smiles not in building_block_index_map:
                bb_extractor = SmallMoleculeFeatureExtractor(bb_smiles)
                bb_features = bb_extractor.flatten_features()
                bb_idx = node_indices['building_block']
                mol_desc_index = node_indices['molecular_descriptor']
                phys_index = node_indices['physicochemical_properties']
                adj_index = node_indices['adjacency_matrix']
                dist_index = node_indices['distance_matrix']
    
                node_data['building_block'].append((bb_idx, {'feature': torch.tensor(bb_features['structural_fingerprints'])}))
                node_data['molecular_descriptor'].append((mol_desc_index, {'feature': torch.tensor(bb_features['molecular_descriptors'])}))
                node_data['physicochemical_properties'].append((phys_index, {'feature': torch.tensor(bb_features['physicochemical_properties'])}))
                node_data['adjacency_matrix'].append((adj_index, {'feature': torch.tensor(bb_features['adjacency_matrix'])}))
                node_data['distance_matrix'].append((dist_index, {'feature': torch.tensor(bb_features['distance_matrix'])}))

                node_indices['building_block'] += 1
                node_indices['molecular_descriptor'] += 1
                node_indices['physicochemical_properties'] += 1
                node_indices['adjacency_matrix'] += 1
                node_indices['distance_matrix'] += 1

                # Create 'has' edges for building block
                edge_data_pos[('building_block', 'has', 'molecular_descriptor')].append((bb_idx, mol_desc_index))
                edge_data_pos[('building_block', 'has', 'physicochemical_properties')].append((bb_idx, phys_index))
                edge_data_pos[('building_block', 'has', 'adjacency_matrix')].append((bb_idx, adj_index))
                edge_data_pos[('building_block', 'has', 'distance_matrix')].append((bb_idx, dist_index))

                edge_data_neg[('building_block', 'has', 'molecular_descriptor')].append((bb_idx, mol_desc_index))
                edge_data_neg[('building_block', 'has', 'physicochemical_properties')].append((bb_idx, phys_index))
                edge_data_neg[('building_block', 'has', 'adjacency_matrix')].append((bb_idx, adj_index))
                edge_data_neg[('building_block', 'has', 'distance_matrix')].append((bb_idx, dist_index))

                building_block_index_map[bb_smiles] = bb_idx
            else:
                bb_idx = building_block_index_map[bb_smiles]

            # Add 'contains' edge from small molecule to building block
            edge_data_pos[('small_molecule', 'contains', 'building_block')].append((mol_idx, bb_idx))
            edge_data_neg[('small_molecule', 'contains', 'building_block')].append((mol_idx, bb_idx))

        # Add 'binds' edge from small molecule to protein
        protein_idx = protein_index_map[binds_protein]
        edge_data_pos[('small_molecule', 'binds', 'protein')].append((mol_idx, protein_idx))

        protein_idx = protein_index_map[not_binds_protein]
        edge_data_neg[('small_molecule', 'binds', 'protein')].append((mol_idx, protein_idx))

    print('Completed processing samples')

    # Create graphs
    g_pos = dgl.heterograph(edge_data_pos)
    g_neg = dgl.heterograph(edge_data_neg)

    # Find the maximum feature shape for each node type
    max_feature_shapes = {}
    for ntype, features in tqdm(node_data.items(), total=len(node_data.items())):
        max_shape = torch.Size([0])
        for _, feat in features:
            if feat['feature'].shape > max_shape:
                max_shape = feat['feature'].shape
        max_feature_shapes[ntype] = max_shape
    
    # Function to pad features to the maximum shape
    def pad_feature(feature, max_shape):
        padding = [0] * (2 * len(max_shape))
        for i in range(len(max_shape)):
            padding[2 * i + 1] = max_shape[i] - feature.shape[i]
        return pad(feature, padding)

    # Assign padded features to nodes
    for ntype, features in tqdm(node_data.items(), total=len(node_data.items())):
        indices, feats = zip(*features)
        max_shape = max_feature_shapes[ntype]
        padded_feats = [pad_feature(feat['feature'], max_shape) for feat in feats]
        g_pos.nodes[ntype].data['feature'] = torch.stack(padded_feats)
        g_neg.nodes[ntype].data['feature'] = torch.stack(padded_feats)

    # Save the graphs
    dgl.save_graphs(f"./heterogeneous_graph_pos_{variant}.dgl", [g_pos])
    dgl.save_graphs(f"./heterogeneous_graph_neg_{variant}.dgl", [g_neg])

    # Save building block and protein indices to Parquet
    building_block_df = pd.DataFrame.from_dict(building_block_index_map, orient='index', columns=['smiles'])
    protein_df = pd.DataFrame.from_dict(protein_index_map, orient='index', columns=['protein_name'])

    building_block_df.to_parquet(f"./building_block_indices_{variant}.parquet")
    protein_df.to_parquet(f"./protein_indices_{variant}.parquet")

    return g_pos, g_neg

def load_heterogeneous_graphs():
    g_pos, _ = dgl.load_graphs("./heterogeneous_graph_pos.dgl")
    g_neg, _ = dgl.load_graphs("./heterogeneous_graph_neg.dgl")

    # Load building block and protein indices
    building_block_df = pd.read_parquet("./building_block_indices.parquet")
    protein_df = pd.read_parquet("./protein_indices.parquet")

    building_block_index_map = building_block_df.to_dict(orient='index')
    protein_index_map = protein_df.to_dict(orient='index')

    return g_pos[0], g_neg[0], building_block_index_map, protein_index_map

# Usage:
# File paths
train_file_path = 'train_data.parquet'
test_file_path = 'test_data.parquet'
val_file_path = 'val_data.parquet'
protein_files = {
    'BRD4': './BRD4.pdb',
    'HSA': './ALB.pdb',
    'sEH': './EPH.pdb'
}

# train_pos_g, train_neg_g = create_heterogeneous_graphs(train_file_path, protein_files, 'train')
# test_pos_g, test_neg_g = create_heterogeneous_graphs(test_file_path, protein_files, 'test')
val_pos_g, val_neg_g = create_heterogeneous_graphs(val_file_path, protein_files, 'val')


Loading all data
Completed loading all data
Final sample size 150974
Starting protein node creation
Processing chain: A




Processing chain: A
Processing chain: B
Processing chain: A
Completed protein node creation
Starting processing samples


 35%|████████████▎                      | 53103/150974 [01:36<02:57, 549.95it/s]


KeyboardInterrupt: 

In [None]:
import math

import dgl
import dgl.function as fn

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.functional import edge_softmax


class HGTLayer(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        node_dict,
        edge_dict,
        n_heads,
        dropout=0.2,
        use_norm=False,
    ):
        super(HGTLayer, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.num_types = len(node_dict)
        self.num_relations = len(edge_dict)
        self.total_rel = self.num_types * self.num_relations * self.num_types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads
        self.sqrt_dk = math.sqrt(self.d_k)
        self.att = None

        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm

        for t in range(self.num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))

        self.relation_pri = nn.Parameter(
            torch.ones(self.num_relations, self.n_heads)
        )
        self.relation_att = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.relation_msg = nn.Parameter(
            torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k)
        )
        self.skip = nn.Parameter(torch.ones(self.num_types))
        self.drop = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)

    def forward(self, G, h):
        with G.local_scope():
            node_dict, edge_dict = self.node_dict, self.edge_dict
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]

                k_linear = self.k_linears[node_dict[srctype]]
                v_linear = self.v_linears[node_dict[srctype]]
                q_linear = self.q_linears[node_dict[dsttype]]

                k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
                q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

                e_id = self.edge_dict[etype]

                relation_att = self.relation_att[e_id]
                relation_pri = self.relation_pri[e_id]
                relation_msg = self.relation_msg[e_id]

                k = torch.einsum("bij,ijk->bik", k, relation_att)
                v = torch.einsum("bij,ijk->bik", v, relation_msg)

                sub_graph.srcdata["k"] = k
                sub_graph.dstdata["q"] = q
                sub_graph.srcdata["v_%d" % e_id] = v

                sub_graph.apply_edges(fn.v_dot_u("q", "k", "t"))
                attn_score = (
                    sub_graph.edata.pop("t").sum(-1)
                    * relation_pri
                    / self.sqrt_dk
                )
                attn_score = edge_softmax(sub_graph, attn_score, norm_by="dst")

                sub_graph.edata["t"] = attn_score.unsqueeze(-1)

            G.multi_update_all(
                {
                    etype: (
                        fn.u_mul_e("v_%d" % e_id, "t", "m"),
                        fn.sum("m", "t"),
                    )
                    for etype, e_id in edge_dict.items()
                },
                cross_reducer="mean",
            )

            new_h = {}
            for ntype in G.ntypes:
                """
                Step 3: Target-specific Aggregation
                x = norm( W[node_type] * gelu( Agg(x) ) + x )
                """
                n_id = node_dict[ntype]
                alpha = torch.sigmoid(self.skip[n_id])
                t = G.nodes[ntype].data["t"].view(-1, self.out_dim)
                trans_out = self.drop(self.a_linears[n_id](t))
                trans_out = trans_out * alpha + h[ntype] * (1 - alpha)
                if self.use_norm:
                    new_h[ntype] = self.norms[n_id](trans_out)
                else:
                    new_h[ntype] = trans_out
            return new_h

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        with graph.local_scope():
            for ntype in graph.ntypes:
                graph[ntype].data['h'] = h[ntype]

            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']


class HGT(nn.Module):
    def __init__(
        self,
        node_dict,
        edge_dict,
        in_dim_dict,
        n_hid,
        n_out,
        n_layers,
        n_heads,
        use_norm=True,
    ):
        super(HGT, self).__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.gcs = nn.ModuleList()
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers

        self.adapt_ws = nn.ModuleList()
        for ntype, in_dim in in_dim_dict.items():
            if len(in_dim) == 1:
                self.adapt_ws[ntype] = nn.Linear(in_dim[0], n_hid)
            elif len(in_dim) == 2:
                self.adapt_ws[ntype] = nn.Sequential(
                    nn.Conv1d(in_dim[0], n_hid, kernel_size=1),
                    nn.Conv1d(n_hid, n_hid, kernel_size=1),
                    nn.Flatten(),
                    nn.Linear(n_hid * in_dim[1], n_hid)
                )
            elif len(in_dim) == 3:
                self.adapt_ws[ntype] = nn.Sequential(
                    nn.Conv2d(in_dim[0], n_hid, kernel_size=1),
                    nn.Conv2d(n_hid, n_hid, kernel_size=1),
                    nn.Flatten(),
                    nn.Linear(n_hid * in_dim[1] * in_dim[2], n_hid)
                )
            elif len(in_dim) == 4:
                self.adapt_ws[ntype] = nn.Sequential(
                    nn.Conv3d(in_dim[0], n_hid, kernel_size=1),
                    nn.Conv3d(n_hid, n_hid, kernel_size=1),
                    nn.Flatten(),
                    nn.Linear(n_hid * in_dim[1] * in_dim[2] * in_dim[3], n_hid)
                )

        for _ in range(n_layers):
            self.gcs.append(
                HGTLayer(
                    n_hid,
                    n_hid,
                    node_dict,
                    edge_dict,
                    n_heads,
                    use_norm=use_norm,
                )
            )
        self.pred = HeteroDotProductPredictor()

    def forward(self, G, neg_g, etype):
        h = {}
        for ntype in G.ntypes:
            n_id = self.node_dict[ntype]
            h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data["feature"]))
        for i in range(self.n_layers):
            h = self.gcs[i](G, h)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

In [None]:
from dgl.dataloading import EdgeDataLoader
from dgl.sampling import global_uniform_negative_sampling

def compute_loss(pos_score, neg_score):
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

def train(model, pos_g, neg_g, edge_type, optimizer, scheduler, num_epochs):
    for epoch in range(num_epochs):
        pos_score, neg_score = model(pos_g, neg_g, edge_type)
        loss = compute_loss(pos_score, neg_score)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(dataloader)}')

# Initialize the model and optimizer
node_dict = {}
edge_dict = {}
in_dim_dict = {ntype: pos_g.nodes[ntype].data['feature'].shape[1] for ntype in pos_g.ntypes}
for ntype in pos_g.ntypes:
    node_dict[ntype] = len(node_dict)
for etype in pos_g.etypes:
    edge_dict[etype] = len(edge_dict)
    pos_g.edges[etype].data["id"] = (
        torch.ones(pos_g.num_edges(etype), dtype=torch.long) * edge_dict[etype]
    )

model = HGT(
    node_dict=node_dict,
    edge_dict=edge_dict,
    in_dim_dict = in_dim_dict,
    n_hid=128,
    n_out=32,
    n_layers=2,
    n_heads=4,
    use_norm=True
)
epochs = 100
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps = epochs, max_lr = 1e-3
)

# Train the model
train(model, pos_g, neg_g, 'binds', optimizer, scheduler, epochs)
