# 04 — Prepare R-GCN Link Prediction (disease–gene)
# 
#### This notebook:
#### - loads the heterogeneous graph from `data/processed/hetero_graph.pt`
#### - builds a homogeneous graph for R-GCN
#### - assigns relation IDs to each edge type
#### - selects the ('disease', 'assoc_gene', 'gene') edges as the target relation
#### - creates train/val/test splits for that relation
#### - saves everything to `data/processed/rgcn_linkpred_assoc_gene.pt`.


In [1]:
import os
from pathlib import Path

import torch
from torch_geometric.data import HeteroData

# Paths
# When the kernel's working directory is the `notebooks/` folder the project
# root is its parent; otherwise the working directory itself is used.
_cwd = Path.cwd()
if _cwd.name == "notebooks":
    PROJECT_ROOT = _cwd.parent
else:
    PROJECT_ROOT = _cwd
DATA_PROCESSED = PROJECT_ROOT / "data" / "processed"
GRAPH_PATH = DATA_PROCESSED / "hetero_graph.pt"

for label, path in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("DATA_PROCESSED:", DATA_PROCESSED),
    ("GRAPH_PATH:", GRAPH_PATH),
):
    print(label, path)

# NOTE(review): weights_only=False performs full unpickling, which can execute
# arbitrary code. Only load this file if it comes from a trusted source.
obj = torch.load(GRAPH_PATH, weights_only=False)
node_maps = obj["node_maps"]
data: HeteroData = obj["data"]

# Bare last expression so the notebook displays the graph summary.
data

PROJECT_ROOT: /Users/domenicschmidt/Documents/master thesis/lung-cancer-gnn
DATA_PROCESSED: /Users/domenicschmidt/Documents/master thesis/lung-cancer-gnn/data/processed
GRAPH_PATH: /Users/domenicschmidt/Documents/master thesis/lung-cancer-gnn/data/processed/hetero_graph.pt


HeteroData(
  disease={
    num_nodes=58,
    x=[58, 1],
  },
  gene={
    num_nodes=14779,
    x=[14779, 1],
  },
  variant={
    num_nodes=285747,
    x=[285747, 1],
  },
  gene_fusion={
    num_nodes=4266,
    x=[4266, 1],
  },
  chrom_rearr={
    num_nodes=1993,
    x=[1993, 1],
  },
  pathway={
    num_nodes=1488,
    x=[1488, 1],
  },
  biomarker={
    num_nodes=24,
    x=[24, 1],
  },
  chemical={
    num_nodes=160,
    x=[160, 1],
  },
  evidence={
    num_nodes=144,
    x=[144, 1],
  },
  city={
    num_nodes=368,
    x=[368, 1],
  },
  demographic_group={
    num_nodes=204,
    x=[204, 1],
  },
  (disease, assoc_gene, gene)={ edge_index=[2, 82461] },
  (disease, assoc_gene_fusion, gene_fusion)={ edge_index=[2, 4289] },
  (disease, assoc_chrom_rearr, chrom_rearr)={ edge_index=[2, 2309] },
  (disease, assoc_variant, variant)={ edge_index=[2, 500000] },
  (disease, assoc_pathway, pathway)={ edge_index=[2, 168] },
  (gene, participates_in, pathway)={ edge_index=[2, 50651] },
  (d

# Inspect edge types & build relation ID mapping


In [2]:
# Each edge type's relation ID is its position in `data.edge_types`.
# The mapping is filled in the same pass that prints the listing.
edge_type_to_id = {}
print("Heterogeneous edge types:")
for rel_id, rel in enumerate(data.edge_types):
    edge_type_to_id[rel] = rel_id
    print(f"  {rel_id}: {rel}")

# Bare last expression so the notebook displays the mapping.
edge_type_to_id

Heterogeneous edge types:
  0: ('disease', 'assoc_gene', 'gene')
  1: ('disease', 'assoc_gene_fusion', 'gene_fusion')
  2: ('disease', 'assoc_chrom_rearr', 'chrom_rearr')
  3: ('disease', 'assoc_variant', 'variant')
  4: ('disease', 'assoc_pathway', 'pathway')
  5: ('gene', 'participates_in', 'pathway')
  6: ('disease', 'assoc_biomarker', 'biomarker')
  7: ('chemical', 'has_evidence', 'evidence')
  8: ('chemical', 'measured_in', 'city')
  9: ('disease', 'has_demographic_stats', 'demographic_group')


{('disease', 'assoc_gene', 'gene'): 0,
 ('disease', 'assoc_gene_fusion', 'gene_fusion'): 1,
 ('disease', 'assoc_chrom_rearr', 'chrom_rearr'): 2,
 ('disease', 'assoc_variant', 'variant'): 3,
 ('disease', 'assoc_pathway', 'pathway'): 4,
 ('gene', 'participates_in', 'pathway'): 5,
 ('disease', 'assoc_biomarker', 'biomarker'): 6,
 ('chemical', 'has_evidence', 'evidence'): 7,
 ('chemical', 'measured_in', 'city'): 8,
 ('disease', 'has_demographic_stats', 'demographic_group'): 9}

# Select target relation: (disease, assoc_gene, gene)


In [3]:
# Edge type selected as the target relation for link prediction.
target_edge_type = ("disease", "assoc_gene", "gene")

# Integer relation ID of the target; raises KeyError if the edge type is
# absent from the mapping.
target_rel_id = edge_type_to_id[target_edge_type]

for label, value in (
    ("Target edge type:", target_edge_type),
    ("Target relation ID:", target_rel_id),
):
    print(label, value)


Target edge type: ('disease', 'assoc_gene', 'gene')
Target relation ID: 0


# Convert to homogeneous graph (for R-GCN)


In [4]:
# Collapse the heterogeneous graph into a single homogeneous graph.
# The two flags request a per-edge relation ID tensor (`edge_type`) and a
# per-node type ID tensor (`node_type`) on the result.
homo = data.to_homogeneous(add_node_type=True, add_edge_type=True)

print(homo)
print("Homogeneous num_nodes:", homo.num_nodes)
print("Homogeneous num_edges:", homo.edge_index.size(1))
print("Edge_type shape:", homo.edge_type.shape)


Data(edge_index=[2, 640677], x=[309231, 1], node_type=[309231], edge_type=[640677])
Homogeneous num_nodes: 309231
Homogeneous num_edges: 640677
Edge_type shape: torch.Size([640677])


# Extract indices of edges of the target relation


In [5]:
# Relation ID of every edge in the homogeneous graph.
edge_type = homo.edge_type  # [num_edges]

# Boolean selector for the edges that belong to the target relation.
is_target = edge_type.eq(target_rel_id)

# Positions (into the homogeneous edge list) of the target-relation edges,
# flattened to a 1-D index tensor.
target_edge_indices = torch.nonzero(is_target).flatten()
num_target_edges = target_edge_indices.numel()

print("Number of edges for target relation:", num_target_edges)


Number of edges for target relation: 82461


# Sanity check: matches hetero edge count?


In [6]:
# Number of target-relation edges stored in the original heterogeneous graph.
expected = data[target_edge_type].edge_index.size(1)
print("Expected edges from hetero:", expected)

# The homogeneous graph must contain exactly as many edges of this relation.
assert num_target_edges == expected, "Mismatch between hetero and homogeneous!"


Expected edges from hetero: 82461


# Create train/val/test split for target edges


In [7]:
import math

# --- Split configuration -----------------------------------------------------
SPLIT_SEED = 42
train_frac = 0.8
val_frac = 0.1
test_frac = 0.1

# The test split is "whatever is left" after train and val, so `test_frac` is
# only honoured if the three fractions describe a complete partition. Fail
# loudly if someone edits one fraction without adjusting the others.
assert min(train_frac, val_frac, test_frac) >= 0.0, "split fractions must be non-negative"
assert math.isclose(train_frac + val_frac + test_frac, 1.0), (
    "train_frac + val_frac + test_frac must sum to 1.0"
)

# --- Random permutation of the target edges ----------------------------------
torch.manual_seed(SPLIT_SEED)  # reproducibility

num = num_target_edges
perm = torch.randperm(num)

# --- Split sizes -------------------------------------------------------------
num_train = int(num * train_frac)
num_val = int(num * val_frac)
num_test = num - num_train - num_val  # remainder; absorbs rounding from int()

# An empty split would make training or evaluation meaningless downstream.
assert min(num_train, num_val, num_test) > 0, (
    f"empty split: train={num_train}, val={num_val}, test={num_test}"
)

# `perm` indexes into the list of target edges; `target_edge_indices` maps
# those positions back to edge positions in the homogeneous graph.
train_idx = target_edge_indices[perm[:num_train]]
val_idx   = target_edge_indices[perm[num_train:num_train + num_val]]
test_idx  = target_edge_indices[perm[num_train + num_val:]]

# The slices must realise exactly the sizes computed above.
assert train_idx.numel() == num_train
assert val_idx.numel() == num_val
assert test_idx.numel() == num_test

print("Train edges:", train_idx.numel())
print("Val edges:  ", val_idx.numel())
print("Test edges: ", test_idx.numel())
print("Sum:", train_idx.numel() + val_idx.numel() + test_idx.numel())


Train edges: 65968
Val edges:   8246
Test edges:  8247
Sum: 82461


# Build boolean masks on all edges


In [8]:
num_edges = homo.edge_index.size(1)

# One all-False boolean mask per split, covering every edge of the
# homogeneous graph (not just the target relation).
edge_train_mask = torch.zeros(num_edges, dtype=torch.bool)
edge_val_mask   = edge_train_mask.clone()
edge_test_mask  = edge_train_mask.clone()

# Switch on the edge positions that belong to each split.
for mask, idx in (
    (edge_train_mask, train_idx),
    (edge_val_mask, val_idx),
    (edge_test_mask, test_idx),
):
    mask[idx] = True

# Sanity checks: the three splits must be pairwise disjoint.
assert not torch.any(edge_train_mask & edge_val_mask)
assert not torch.any(edge_train_mask & edge_test_mask)
assert not torch.any(edge_val_mask & edge_test_mask)

# Every edge selected by a mask must belong to the target relation.
assert torch.all(edge_type[edge_train_mask] == target_rel_id)
assert torch.all(edge_type[edge_val_mask] == target_rel_id)
assert torch.all(edge_type[edge_test_mask] == target_rel_id)

# Attach the masks to the homogeneous graph so they are saved with it.
homo.edge_train_mask = edge_train_mask
homo.edge_val_mask   = edge_val_mask
homo.edge_test_mask  = edge_test_mask

print("Masks assigned.")
print(
    "Train/Val/Test sums:",
    edge_train_mask.sum().item(),
    edge_val_mask.sum().item(),
    edge_test_mask.sum().item(),
)


Masks assigned.
Train/Val/Test sums: 65968 8246 8247


# Package everything for R-GCN training


In [9]:
out_path = DATA_PROCESSED / "rgcn_linkpred_assoc_gene.pt"

# Bundle both graph views with the lookup tables and the target-relation
# identifiers into one object.
rgcn_obj = dict(
    hetero_data=data,
    homo_data=homo,
    node_maps=node_maps,
    edge_type_to_id=edge_type_to_id,
    target_edge_type=target_edge_type,
    target_rel_id=target_rel_id,
)
torch.save(rgcn_obj, out_path)

print("Saved R-GCN link prediction dataset to:")
print("   ", out_path)


Saved R-GCN link prediction dataset to:
    /Users/domenicschmidt/Documents/master thesis/lung-cancer-gnn/data/processed/rgcn_linkpred_assoc_gene.pt


# Reload test

In [10]:
# Read the file back from disk to confirm the masks were saved with the graph.
# NOTE(review): weights_only=False performs full unpickling; acceptable here
# because the file was written by the save step just above.
loaded = torch.load(out_path, weights_only=False)
homo2 = loaded["homo_data"]

print(homo2)
print("edge_type shape:", homo2.edge_type.shape)

# Print how many edges each reloaded split mask selects.
for label, mask_name in (
    ("edge_train_mask sum:", "edge_train_mask"),
    ("edge_val_mask   sum:", "edge_val_mask"),
    ("edge_test_mask  sum:", "edge_test_mask"),
):
    print(label, getattr(homo2, mask_name).sum().item())


Data(edge_index=[2, 640677], x=[309231, 1], node_type=[309231], edge_type=[640677], edge_train_mask=[640677], edge_val_mask=[640677], edge_test_mask=[640677])
edge_type shape: torch.Size([640677])
edge_train_mask sum: 65968
edge_val_mask   sum: 8246
edge_test_mask  sum: 8247
