# Create the graph for DC GNN

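In this notebook, I build the heterogeneous graph for the DC (drug combination) GNN model. It has drug, DC and disease nodes, connected by drug → DC (`interacts`) and DC → disease (`treats`) edges.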

In [1]:
import os
from pathlib import Path

import dotenv
import pandas as pd
from IPython.display import display
from sqlalchemy import create_engine, text
from sqlalchemy.engine import URL

# Read the MySQL connection settings from the environment (values in a local
# .env file are loaded into the environment first, if present).
dotenv.load_dotenv()

user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
host = os.getenv("DB_HOST", "127.0.0.1")
port = os.getenv("DB_PORT", "3306")

# Build the URL with URL.create instead of an f-string: it escapes special
# characters in the credentials (e.g. "@", ":", "/" in a password) that would
# otherwise corrupt the connection string.
connection_url = URL.create(
    drivername="mysql+pymysql",
    username=user,
    password=password,
    host=host,
    port=int(port),
    database=database,
)
engine = create_engine(connection_url)

In [2]:
# Inspect the column layout of every table the graph is built from.
for table_name in ("cell_line", "disease", "drug", "gnn_experiments", "drug_comb_drug"):
    display(pd.read_sql(f"SHOW COLUMNS FROM {table_name}", con=engine))

Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,cell_line_id,char(9),NO,PRI,,
1,cell_line_name,varchar(100),NO,UNI,,
2,source_id,int,YES,MUL,,
3,tissue,varchar(100),YES,,,
4,disease_id,varchar(50),YES,MUL,,


Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,disease_id,varchar(15),NO,PRI,,
1,disease_name,varchar(350),NO,,,


Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,drug_id,varchar(25),NO,PRI,,
1,source_id,int,NO,MUL,,
2,drug_name,varchar(300),YES,,,
3,molecular_type,varchar(50),YES,,,
4,chemical_structure,varchar(5000),YES,,,
5,inchi_key,varchar(250),YES,,,


Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,dc_id,int,NO,PRI,,
1,cell_line_id,char(9),NO,PRI,,
2,score_value,float,NO,,,


Unnamed: 0,Field,Type,Null,Key,Default,Extra
0,dc_id,int,NO,PRI,,
1,drug_id,varchar(25),NO,PRI,,


In [3]:
# All queries below share one population: experiments whose cell line has a
# disease that exists in the `disease` table (the inner JOIN also excludes NULL
# disease_id). Using the same filter everywhere means every drug / DC / disease
# ID referenced by an edge also has a node, and no drug node is pulled in only
# by experiments that were filtered out.

# (DC - Disease) rows: one per (drug combination, cell line) experiment
experiments = pd.read_sql(
    """
    SELECT e.dc_id, c.disease_id, e.score_value
    FROM gnn_experiments e
    JOIN cell_line c ON e.cell_line_id = c.cell_line_id
    JOIN disease dis ON c.disease_id = dis.disease_id
    """, con=engine)

# Drug nodes: drugs that take part in at least one retained experiment
drugs = pd.read_sql(
    """
    SELECT DISTINCT d.drug_id, d.chemical_structure
    FROM drug d
    JOIN drug_comb_drug dd ON d.drug_id = dd.drug_id
    JOIN gnn_experiments e ON dd.dc_id = e.dc_id
    JOIN cell_line c ON e.cell_line_id = c.cell_line_id
    JOIN disease dis ON c.disease_id = dis.disease_id
    """, con=engine
)

# Drug -> DC edges (joined to `drug` so every edge endpoint has a drug node)
drug_comb_drug = pd.read_sql(
    """
    SELECT DISTINCT dd.dc_id, dd.drug_id
    FROM drug_comb_drug dd
    JOIN drug d ON dd.drug_id = d.drug_id
    JOIN gnn_experiments e ON dd.dc_id = e.dc_id
    JOIN cell_line c ON e.cell_line_id = c.cell_line_id
    JOIN disease dis ON c.disease_id = dis.disease_id
    """, con=engine
)

# Disease nodes: diseases of the cell lines used in the experiments
diseases = pd.read_sql(
    """
    SELECT DISTINCT d.disease_id
    FROM disease d
    JOIN cell_line c ON d.disease_id = c.disease_id
    JOIN gnn_experiments e ON c.cell_line_id = e.cell_line_id
    """, con=engine
)

In [4]:
# Number of drug nodes, then number of experiment rows (one value per line).
print(len(drugs), len(experiments), sep="\n")

32
6771


## Node creation
Create the mappings from the database table IDs to contiguous, 0-based node indices (one mapping per node type).

In [5]:
import torch
from torch_geometric.data import HeteroData

def _to_index(ids):
    """Map each distinct value of `ids` (in first-seen order) to a 0-based position."""
    distinct = ids.unique()
    return dict(zip(distinct, range(len(distinct))))


# Contiguous node indices per node type, keyed by the original database ID.
drug_mapping = _to_index(drugs['drug_id'])
dc_mapping = _to_index(experiments['dc_id'])
disease_mapping = _to_index(diseases['disease_id'])

# Empty heterogeneous graph; node features and edges are attached below.
data = HeteroData()

### Node features
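Create the drug features: a 512-bit Morgan fingerprint (radius 2) computed from each drug's SMILES string (`chemical_structure`). A SMILES string that cannot be parsed gets an all-zero vector.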

In [6]:
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
import numpy as np

# Morgan fingerprint generator (radius 2) shared by all drugs. FP_SIZE is the
# default fingerprint length and therefore the drug/DC feature dimension.
FP_SIZE = 512
mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=FP_SIZE)


def smiles_to_fp(smiles, nBits=FP_SIZE):
    """Convert a SMILES string into a Morgan fingerprint vector.

    Args:
        smiles: SMILES string of the molecule. Missing values (None/NaN, e.g.
            from a NULL ``chemical_structure``) and empty strings are treated
            like unparsable SMILES.
        nBits: Length of the returned vector. The shared ``mfpgen`` is used when
            it equals ``FP_SIZE``; otherwise a generator of the requested size is
            built, so the output length always equals ``nBits``.

    Returns:
        A 1-D numpy array of length ``nBits``. It holds the fingerprint bits, or
        all zeros when the input cannot be parsed.
    """
    # Chem.MolFromSmiles raises on non-string input such as None or NaN.
    if not isinstance(smiles, str) or not smiles:
        return np.zeros(nBits)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.zeros(nBits)  # Return zero vector for invalid SMILES
    generator = (
        mfpgen
        if nBits == FP_SIZE
        else rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=nBits)
    )
    return generator.GetFingerprintAsNumPy(mol)

# One fingerprint row per drug, in the same order as `drugs` (and drug_mapping).
fps = np.stack(drugs['chemical_structure'].apply(smiles_to_fp))
x_drugs = torch.tensor(fps, dtype=torch.float)

# Rows that sum to zero are the zero vectors smiles_to_fp returns for
# unparsable input; report how many there are.
zero_rows = x_drugs.sum(dim=1) == 0
empty_vectors = zero_rows.sum().item()
print(f"Empty vectors: {empty_vectors} of {len(x_drugs)}")

Empty vectors: 0 of 32


Create each DC's features as the element-wise sum of the fingerprints of the drugs in that combination.

In [7]:
# Each DC's feature vector is the element-wise sum of the fingerprints of its
# member drugs, accumulated with a scatter-add over the drug -> DC pairs.
num_dcs, num_features = len(dc_mapping), x_drugs.shape[1]
x_dc = torch.zeros(num_dcs, num_features, dtype=torch.float)

# Translate the database IDs of every (dc_id, drug_id) pair into node indices.
mapped_drugs = list(map(drug_mapping.__getitem__, drug_comb_drug['drug_id']))
mapped_dcs = list(map(dc_mapping.__getitem__, drug_comb_drug['dc_id']))
idx_drugs, idx_dcs = (
    torch.tensor(indices, dtype=torch.long) for indices in (mapped_drugs, mapped_dcs)
)

# Row idx_dcs[k] of x_dc receives row idx_drugs[k] of x_drugs (in place).
x_dc.index_add_(0, idx_dcs, x_drugs[idx_drugs])

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 0.,  ..., 0., 0., 0.],
        [2., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

Create the disease features as random embeddings, with the same dimension as the drug and DC features.

In [8]:
# Diseases get random (standard normal) embeddings with the same width as the
# drug/DC features. A dedicated, seeded generator makes the saved graph
# reproducible across re-runs without touching the global RNG state.
DISEASE_EMBEDDING_SEED = 42
disease_rng = torch.Generator().manual_seed(DISEASE_EMBEDDING_SEED)
x_disease = torch.randn(
    (len(disease_mapping), num_features), dtype=torch.float, generator=disease_rng
)

Add the features to the graph.

In [9]:
# Attach each feature matrix to the node store of its node type.
for node_type, features in (("drug", x_drugs), ("dc", x_dc), ("disease", x_disease)):
    data[node_type].x = features

## Edge creation

### Drug -> DC

In [10]:
# Map database IDs to node indices. An ID missing from a mapping becomes NaN,
# which would silently turn into a garbage index after the cast to int64.
src_drug = drug_comb_drug['drug_id'].map(drug_mapping).values
dst_dc = drug_comb_drug['dc_id'].map(dc_mapping).values

unmapped = int(pd.isna(src_drug).sum() + pd.isna(dst_dc).sum())
if unmapped:
    raise ValueError(
        f"{unmapped} drug->dc endpoint(s) have no matching node; "
        "check that the node queries use the same filters as the edge queries"
    )

# edge_index has shape (2, num_edges): row 0 = drug index, row 1 = DC index.
edge_index = torch.from_numpy(
    np.vstack((src_drug, dst_dc)).astype(np.int64)
)
data["drug", "interacts", "dc"].edge_index = edge_index

### DC -> Disease

In [11]:
# A (DC, disease) pair gets a "treats" edge when at least one experiment on a
# cell line of that disease scores above SYNERGY_THRESHOLD.
# NOTE(review): the scale of score_value and the rationale for a cut-off of 1
# are not visible here; confirm them against the score's definition.
SYNERGY_THRESHOLD = 1
positive_experiments = experiments[experiments['score_value'] > SYNERGY_THRESHOLD]

# gnn_experiments is keyed by (dc_id, cell_line_id), and several cell lines can
# share one disease, so the same (DC, disease) pair can occur several times.
# Keep one row per pair; otherwise each repeat becomes a parallel edge.
positive_pairs = positive_experiments[['dc_id', 'disease_id']].drop_duplicates()

src_dc = positive_pairs['dc_id'].map(dc_mapping).values
dst_disease = positive_pairs['disease_id'].map(disease_mapping).values

# An ID missing from a mapping becomes NaN and would be cast to a garbage index.
unmapped = int(pd.isna(src_dc).sum() + pd.isna(dst_disease).sum())
if unmapped:
    raise ValueError(
        f"{unmapped} dc->disease endpoint(s) have no matching node; "
        "check that the node queries use the same filters as the edge queries"
    )

# edge_index has shape (2, num_edges): row 0 = DC index, row 1 = disease index.
edge_index = torch.from_numpy(
    np.vstack((src_dc, dst_disease)).astype(np.int64)
)
data["dc", "treats", "disease"].edge_index = edge_index

In [13]:
# ==========================================
# GRAPH INTEGRITY TEST (SANITY CHECKS)
# ==========================================

print("--- 1. GENERAL STRUCTURE ---")
# This prints the graph schema: node types, feature dimensions,
# and the shape of edge_index matrices.
print(data)
print("\n")

print("--- 2. NATIVE PyG VALIDATION ---")
# PyG's validator checks, among other things, whether any index in edge_index
# exceeds the number of nodes of its type. A failure is reported and then
# re-raised, so the notebook stops before a broken graph is saved to disk.
try:
    data.validate(raise_on_error=True)
    print("✅ The graph passed PyG strict validation (No out-of-range indices).")
except Exception as e:
    print(f"❌ Error detected by PyG: {e}")
    raise
print("\n")

print("--- 3. CHECK FOR NaNs OR INFINITIES ---")
# NaN or inf in the features propagates through the network and breaks training.
all_clean = True
for node_type in data.node_types:
    if not torch.isfinite(data[node_type].x).all().item():
        print(f"❌ Warning! Features of node type '{node_type}' contain NaNs or infinities.")
        all_clean = False

# edge_index is built as an int64 tensor, so it cannot hold NaN. The failure
# mode worth checking is a negative index (e.g. from a bad cast).
for edge_type in data.edge_types:
    if (data[edge_type].edge_index < 0).any().item():
        print(f"❌ Warning! edge_index of {edge_type} contains negative indices.")
        all_clean = False

if all_clean:
    print("✅ No NaNs/infinities in node features and no negative edge indices.")
print("\n")

print("--- 4. DATASET STATISTICS ---")
print(f"'drug' nodes:      {data['drug'].num_nodes} (Feature dimension: {data['drug'].x.shape[1]})")
print(f"'dc' nodes:        {data['dc'].num_nodes} (Feature dimension: {data['dc'].x.shape[1]})")
print(f"'disease' nodes:   {data['disease'].num_nodes} (Feature dimension: {data['disease'].x.shape[1]})")
print(f"'interacts' edges (drug -> dc):    {data['drug', 'interacts', 'dc'].num_edges}")
print(f"'treats' edges    (dc -> disease): {data['dc', 'treats', 'disease'].num_edges}")

# Biological check: are there enough synergies to learn from?
# Density = observed (DC, disease) edges as a percentage of all possible pairs
# (sparsity would be 100% minus this value).
possible_pairs = data['dc'].num_nodes * data['disease'].num_nodes
density = (
    data['dc', 'treats', 'disease'].num_edges / possible_pairs * 100
    if possible_pairs
    else 0.0
)
print(f"\nSynergy edge density: {density:.4f}%")

--- 1. GENERAL STRUCTURE ---
HeteroData(
  drug={ x=[32, 512] },
  dc={ x=[255, 512] },
  disease={ x=[37, 512] },
  (drug, interacts, dc)={ edge_index=[2, 510] },
  (dc, treats, disease)={ edge_index=[2, 2029] }
)


--- 2. NATIVE PyG VALIDATION ---
✅ The graph passed PyG strict validation (No out-of-range indices).


--- 3. CHECK FOR NaNs OR INFINITIES ---
✅ No NaNs found in node features or edges.


--- 4. DATASET STATISTICS ---
'drug' nodes:      32 (Feature dimension: 512)
'dc' nodes:        255 (Feature dimension: 512)
'disease' nodes:   37 (Feature dimension: 512)
'interacts' edges (drug -> dc):   510
'treats' edges     (dc -> disease): 2029

Synergy density (Sparsity): 21.5050%


In [15]:
# Persist the graph. The target directory is created if it does not exist yet,
# so a fresh checkout does not fail here.
# NOTE: torch.save pickles the HeteroData object. Recent PyTorch releases default
# torch.load to weights_only=True, so loading it back may require
# torch.load(GRAPH_PATH, weights_only=False). Only do that for trusted files.
GRAPH_PATH = Path("../data/graph.pt")
GRAPH_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(data, GRAPH_PATH)