In [None]:
# Cell 1: Setup and Dependencies

# Install necessary libraries
!pip install pandas numpy torch torch_geometric tqdm

# --- Create required directories ---
# The scripts assume these folders exist for input files.
import os
os.makedirs("vocab", exist_ok=True)
os.makedirs("SnomedCT/Snapshot/Terminology", exist_ok=True)

print("Environment setup complete. Folders created.")

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0
Environment setup complete. Folders created.


In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# --- User Configuration: Set your base path for data on Google Drive ---
# Make sure your files are organized relative to this path, matching the original structure.
# Example: If 'mimic-iv-3.1' is in 'MyDrive/KnowledgeGraphData/', set path below.
BASE_DRIVE_PATH = '/content/drive/MyDrive/knowledgegraphdata'

# --- Placeholder creation removed ---
# You should place your actual data files in your Google Drive
# under the structure:
# {BASE_DRIVE_PATH}/mimic-iv-3.1/hosp/diagnoses_icd.csv.gz
# {BASE_DRIVE_PATH}/vocab/CONCEPT.zip
# {BASE_DRIVE_PATH}/vocab/CONCEPT_RELATIONSHIP.zip
# {BASE_DRIVE_PATH}/SnomedCT/Snapshot/Terminology/sct2_Relationship_Snapshot_US1000124_20250901.txt
# {BASE_DRIVE_PATH}/SnomedCT/Snapshot/Terminology/sct2_Description_Snapshot-en_US1000124_20250901.txt

print(f"Google Drive mounted. Please ensure your data files are in '{BASE_DRIVE_PATH}' and update relevant paths in subsequent cells.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted. Please ensure your data files are in '/content/drive/MyDrive/knowledgegraphdata' and update relevant paths in subsequent cells.


In [None]:
import pandas as pd

# === File paths ===
# Use BASE_DRIVE_PATH for Google Drive files
rels_path = os.path.join(BASE_DRIVE_PATH, "/content/drive/MyDrive/knowledgegraphdata/sct2_Relationship_Snapshot_US1000124_20250901.txt")
desc_path = os.path.join(BASE_DRIVE_PATH, "/content/drive/MyDrive/knowledgegraphdata/sct2_Description_Snapshot-en_US1000124_20250901.txt")

print("\n Loading SNOMED CT raw files...")
snomed_rels = pd.read_csv(rels_path, sep="\t", dtype=str)
snomed_descs = pd.read_csv(desc_path, sep="\t", dtype=str)
print(f"\u2192 Relationships: {snomed_rels.shape}, Descriptions: {snomed_descs.shape}")

# ===  Keep only active relationships ===
before_active = len(snomed_rels)
snomed_rels = snomed_rels[snomed_rels["active"] == "1"]
after_active = len(snomed_rels)
print(f" Active filter: {before_active:,} \u2192 {after_active:,} (removed {before_active - after_active:,})")

# ===  Extract preferred terms ===
before_pref = len(snomed_descs)
pref_terms = snomed_descs[snomed_descs["typeId"] == "900000000000013009"][["conceptId", "term"]]
pref_terms.dropna(subset=["term"], inplace=True)
after_pref = len(pref_terms)
print(f" Preferred terms: {before_pref:,} \u2192 {after_pref:,} (removed {before_pref - after_pref:,})")

# ===  Build full relation map (typeId \u2192 term) ===
relation_ids = snomed_rels["typeId"].unique()
relation_map_df = pref_terms[pref_terms["conceptId"].isin(relation_ids)]
relation_map = dict(zip(relation_map_df["conceptId"], relation_map_df["term"]))
print(f" Found {len(relation_map)} unique SNOMED relation types")

# === Add human-readable relation names ===
snomed_rels["relation_term"] = snomed_rels["typeId"].map(relation_map).fillna("other_relation")

# ===  Merge source & destination names ===
print("\n Merging concept names...")
rels_before_merge = snomed_rels.shape

snomed_rels = snomed_rels.merge(
    pref_terms, left_on="sourceId", right_on="conceptId", how="left"
).rename(columns={"term": "source_term"}).drop(columns=["conceptId"])

snomed_rels = snomed_rels.merge(
    pref_terms, left_on="destinationId", right_on="conceptId", how="left"
).rename(columns={"term": "destination_term"}).drop(columns=["conceptId"])

rels_after_merge = snomed_rels.shape
print(f"\u2192 Merge completed: {rels_before_merge} \u2192 {rels_after_merge}")

print("\n Cleaning data...")
before_clean = len(snomed_rels)

# Drop empty fields
snomed_rels = snomed_rels.dropna(subset=["source_term", "destination_term", "relation_term"])
after_dropna = len(snomed_rels)
print(f"Removed empty rows: {before_clean - after_dropna:,}")

# Drop rows with blank strings
mask_blank = ((snomed_rels["source_term"].str.strip() == "") |  (snomed_rels["destination_term"].str.strip() == ""))
removed_blank = mask_blank.sum()
snomed_rels = snomed_rels[~mask_blank]
print(f"Removed blank-term rows: {removed_blank:,}")


# Reorder columns
snomed_rels = snomed_rels[["sourceId", "source_term", "relation_term", "destinationId", "destination_term"]]

#Save final dataset
out_path = "snomed_relations_full.tsv"
snomed_rels.to_csv(out_path, sep="\t", index=False)

#Summary report
print("\n====== SUMMARY REPORT ======")
print(f"Final dataset shape: {snomed_rels.shape}")
print(f"Saved to: {out_path}")
print(snomed_rels.sample(min(5, len(snomed_rels)), random_state=42).to_string(index=False))



 Loading SNOMED CT raw files...
→ Relationships: (3581598, 10), Descriptions: (1696627, 9)
 Active filter: 3,581,598 → 1,336,381 (removed 2,245,217)
 Preferred terms: 1,696,627 → 994,002 (removed 702,625)
 Found 101 unique SNOMED relation types

 Merging concept names...
→ Merge completed: (1336381, 11) → (7161453, 13)

 Cleaning data...
Removed empty rows: 1
Removed blank-term rows: 0

Final dataset shape: (7161452, 5)
Saved to: snomed_relations_full.tsv
  sourceId                                                               source_term   relation_term destinationId      destination_term
1231766002                                     Hemorrhage after thrombolytic therapy            Is a     131148009           Haemorrhage
  27520001                                 Pustular psoriasis of the palms and soles      Morphology      48055004         Pustular rash
 766940004                                                                      Role            Is a     362981000       Qualifi

In [None]:
import pandas as pd
from itertools import combinations
import zipfile # Import zipfile
import os      # Import os
from tqdm import tqdm # For progress bar

# 1. Load and Clean MIMIC-IV Diagnoses
mimic_path = os.path.join(BASE_DRIVE_PATH, "/content/drive/MyDrive/knowledgegraphdata/diagnoses_icd.csv.gz")
print("Loading MIMIC-IV diagnoses data...")
diagnoses = pd.read_csv(mimic_path, dtype=str)
print(f"Loaded: {diagnoses.shape[0]:,} rows \u00d7 {diagnoses.shape[1]} columns\n")

cols_to_keep = ["subject_id", "hadm_id", "seq_num", "icd_code", "icd_version"]
diagnoses = diagnoses[cols_to_keep]

before_filter = len(diagnoses)
diagnoses = diagnoses[diagnoses["icd_version"] == "10"]
after_filter = len(diagnoses)
print(f"ICD-10 filter applied: {before_filter:,} \u2192 {after_filter:,} rows")

before_clean = len(diagnoses)
diagnoses = diagnoses.dropna(subset=["icd_code"])
diagnoses = diagnoses[diagnoses["icd_code"].str.strip() != ""]
after_clean = len(diagnoses)
print(f"Cleaned empty ICD codes: {before_clean:,} \u2192 {after_clean:,} rows")

before_dedup = len(diagnoses)
diagnoses.drop_duplicates(subset=["subject_id", "hadm_id", "icd_code"], inplace=True)
after_dedup = len(diagnoses)
print(f"Removed duplicates: {before_dedup - after_dedup:,}")

diagnoses.reset_index(drop=True, inplace=True)

out_path = "mimic_icd10_clean.csv"
diagnoses.to_csv(out_path, index=False)
print(f"\nCleaned ICD-10 data saved \u2192 {out_path}")

print("\nSample of cleaned dataset:")
print(diagnoses.sample(min(10, len(diagnoses)), random_state=42).to_string(index=False))

# 2. Extract ICD \u2192 SNOMED CT Mapping from Athena
print("\nLoading OHDSI Athena vocabularies...")

# --- Load CSVs directly from ZIPs into memory ---
zip_file_concept = os.path.join(BASE_DRIVE_PATH, "/content/drive/MyDrive/knowledgegraphdata/CONCEPT.zip")
with zipfile.ZipFile(zip_file_concept, 'r') as zf:
    with zf.open('CONCEPT.csv') as f:
        concepts = pd.read_csv(f, sep="\t", dtype=str, low_memory=False)
print(f"Loaded CONCEPT.csv from {zip_file_concept}")

zip_file_rel = os.path.join(BASE_DRIVE_PATH, "/content/drive/MyDrive/knowledgegraphdata/CONCEPT_RELATIONSHIP.zip")
with zipfile.ZipFile(zip_file_rel, 'r') as zf:
    with zf.open('CONCEPT_RELATIONSHIP.csv') as f:
        rels = pd.read_csv(f, sep="\t", dtype=str, low_memory=False)
print(f"Loaded CONCEPT_RELATIONSHIP.csv from {zip_file_rel}")
# ---------------------------------------------------

# Ensure vocabulary_id is clean right after loading
concepts['vocabulary_id'] = concepts['vocabulary_id'].str.strip()

# Ensure relationship_id is clean right after loading
rels['relationship_id'] = rels['relationship_id'].str.strip()

print(f"Concepts: {len(concepts):,} rows, Relationships: {len(rels):,} rows")
print("\n--- Debug: concepts DataFrame head ---")
display(concepts.head())
print("\n--- Debug: rels DataFrame head ---")
display(rels.head())

mapping = rels[rels["relationship_id"].isin(["Maps to", "Maps to value"])]

concepts_small = concepts[["concept_id", "concept_code", "concept_name", "vocabulary_id"]]
merged = (
    mapping
    .merge(concepts_small, left_on="concept_id_1", right_on="concept_id")
    .merge(concepts_small, left_on="concept_id_2", right_on="concept_id", suffixes=('_source', '_destination'))
)
print("\n--- Debug: merged DataFrame head (relevant columns) ---")
display(merged[["concept_code_source", "vocabulary_id_source", "concept_id_2", "concept_name_destination", "vocabulary_id_destination"]].head())

# 3. Create MIMIC-to-SNOMED mapping
print("\nCreating MIMIC-to-SNOMED mapping...")
# Filter mapping for ICD10CM source and SNOMED destination
icd_snomed_map_df = merged[
    (merged["vocabulary_id_source"] == "ICD10CM") &
    (merged["vocabulary_id_destination"] == "SNOMED")
]
# Select and rename columns for clarity
icd_snomed_map_df = icd_snomed_map_df[[
    "concept_code_source", "concept_id_2", "concept_name_destination"
]].rename(columns={
    "concept_code_source": "icd_code",
    "concept_id_2": "snomed_concept_id",
    "concept_name_destination": "snomed_concept_name"
}).drop_duplicates(subset=["icd_code", "snomed_concept_id"])
print(f"ICD10CM to SNOMED mapping entries: {len(icd_snomed_map_df):,}")

# --- Debugging print statements ---
print(f"ICD codes in diagnoses: {diagnoses['icd_code'].unique()}")
print(f"Is 'A000' in icd_snomed_map_df? {'A000' in icd_snomed_map_df['icd_code'].values}")
if 'A000' in icd_snomed_map_df['icd_code'].values:
    print("Found 'A000' in icd_snomed_map_df. Sample entry:")
    print(icd_snomed_map_df[icd_snomed_map_df['icd_code'] == 'A000'].head().to_string(index=False))
else:
    print(" 'A000' not found in icd_snomed_map_df.")
# -----------------------------------

# Merge MIMIC diagnoses with SNOMED concepts
diagnoses_snomed = diagnoses.merge(
    icd_snomed_map_df,
    on="icd_code",
    how="inner" # Keep only diagnoses that can be mapped to SNOMED
)
print(f"MIMIC diagnoses after SNOMED mapping: {len(diagnoses_snomed):,} rows")
diagnoses_snomed.dropna(subset=["snomed_concept_id"], inplace=True)
diagnoses_snomed.drop_duplicates(subset=["subject_id", "hadm_id", "snomed_concept_id"], inplace=True)
print(f"MIMIC diagnoses (unique SNOMED per admission): {len(diagnoses_snomed):,} rows")


# 4. Extract Co-occurrence Pairs
print("\nExtracting co-occurrence pairs...")
co_occurrence_list = []
# Group by admission and iterate
for (subject_id, hadm_id), group in tqdm(diagnoses_snomed.groupby(["subject_id", "hadm_id"]), desc="Processing admissions"):
    snomed_codes = sorted(group["snomed_concept_id"].unique().tolist())
    snomed_names = {row["snomed_concept_id"]: row["snomed_concept_name"] for idx, row in group.iterrows()}

    # Generate all unique pairs of SNOMED codes within this admission
    for code1, code2 in combinations(snomed_codes, 2):
        # Ensure consistent order for pairs
        if code1 < code2:
            co_occurrence_list.append({
                "sourceId": code1,
                "source_term": snomed_names.get(code1, "UNKNOWN"),
                "relation_term": "co_occurs_with",
                "destinationId": code2,
                "destination_term": snomed_names.get(code2, "UNKNOWN")
            })
        else:
             co_occurrence_list.append({
                "sourceId": code2,
                "source_term": snomed_names.get(code2, "UNKNOWN"),
                "relation_term": "co_occurs_with",
                "destinationId": code1,
                "destination_term": snomed_names.get(code1, "UNKNOWN")
            })


# Create DataFrame from co-occurrence list
mimic_snomed_pairs = pd.DataFrame(co_occurrence_list)

# If no pairs are generated (e.g., due to limited placeholder data), create an empty df with correct columns
if mimic_snomed_pairs.empty:
    mimic_snomed_pairs = pd.DataFrame(columns=[
        "sourceId", "source_term", "relation_term", "destinationId", "destination_term"
    ])


print(f"Generated {len(mimic_snomed_pairs):,} co-occurrence pairs.")

# 5. Save the co-occurrence pairs
mimic_snomed_pairs_out_path = "mimic_snomed_pairs.tsv"
mimic_snomed_pairs.to_csv(mimic_snomed_pairs_out_path, sep="\t", index=False)
print(f"MIMIC co-occurrence pairs saved to \u2192 {mimic_snomed_pairs_out_path}")

Loading MIMIC-IV diagnoses data...
Loaded: 6,364,488 rows × 5 columns

ICD-10 filter applied: 6,364,488 → 3,455,747 rows
Cleaned empty ICD codes: 3,455,747 → 3,455,747 rows
Removed duplicates: 58

Cleaned ICD-10 data saved → mimic_icd10_clean.csv

Sample of cleaned dataset:
subject_id  hadm_id seq_num icd_code icd_version
  11979534 27415150       5     N185          10
  11130556 28389443       8   F10239          10
  16475227 21680259      14     M353          10
  19645775 21606532       8    J9811          10
  12448853 22437295       2    K5289          10
  13847788 27562694      10  T17590A          10
  19368574 29840211       7     M109          10
  12833439 22586352      12   Z87891          10
  19655369 27157385       8    Z9884          10
  15221512 29743103      11    I5021          10

Loading OHDSI Athena vocabularies...
Loaded CONCEPT.csv from /content/drive/MyDrive/knowledgegraphdata/CONCEPT.zip
Loaded CONCEPT_RELATIONSHIP.csv from /content/drive/MyDrive/knowledgeg

Unnamed: 0,concept_id,concept_name,domain_id,vocabulary_id,concept_class_id,standard_concept,concept_code,valid_start_date,valid_end_date,invalid_reason
0,45756805,Pediatric Cardiology,Provider,ABMS,Physician Specialty,S,OMOP4821938,19700101,20991231,
1,45756804,Pediatric Anesthesiology,Provider,ABMS,Physician Specialty,S,OMOP4821939,19700101,20991231,
2,45756803,Pathology-Anatomic / Pathology-Clinical,Provider,ABMS,Physician Specialty,S,OMOP4821940,19700101,20991231,
3,45756802,Pathology - Pediatric,Provider,ABMS,Physician Specialty,S,OMOP4821941,19700101,20991231,
4,45756801,Pathology - Molecular Genetic,Provider,ABMS,Physician Specialty,S,OMOP4821942,19700101,20991231,



--- Debug: rels DataFrame head ---


Unnamed: 0,concept_id_1,concept_id_2,relationship_id,valid_start_date,valid_end_date,invalid_reason
0,19082573,36935620,RxNorm dose form of,20230522,20991231,
1,40703384,36962214,Has marketed form,20230522,20991231,
2,19082573,36939166,RxNorm dose form of,20230522,20991231,
3,19001949,36946096,RxNorm dose form of,20230522,20991231,
4,19082573,36942470,RxNorm dose form of,20230522,20991231,



--- Debug: merged DataFrame head (relevant columns) ---


Unnamed: 0,concept_code_source,vocabulary_id_source,concept_id_2,concept_name_destination,vocabulary_id_destination
0,OMOP5199592,RxNorm Extension,36919314,Tozinameran 0.05 MG/ML,RxNorm Extension
1,OMOP5212221,RxNorm Extension,36919391,100 ML amoxicillin 25 MG/ML / clavulanate 6.25...,RxNorm Extension
2,OMOP5208019,RxNorm Extension,36919422,0.8 ML adalimumab 50 MG/ML Injectable Solution...,RxNorm Extension
3,OMOP5185964,RxNorm Extension,36919439,ondansetron 8 MG Disintegrating Oral Tablet by...,RxNorm Extension
4,OMOP5191160,RxNorm Extension,36919440,tadalafil 10 MG Delayed Release Oral Tablet Bo...,RxNorm Extension



Creating MIMIC-to-SNOMED mapping...
ICD10CM to SNOMED mapping entries: 129,908
ICD codes in diagnoses: ['G3183' 'F0280' 'R441' ... 'O30093' 'V835XXA' 'O359XX2']
Is 'A000' in icd_snomed_map_df? False
 'A000' not found in icd_snomed_map_df.
MIMIC diagnoses after SNOMED mapping: 218,769 rows
MIMIC diagnoses (unique SNOMED per admission): 218,768 rows

Extracting co-occurrence pairs...


Processing admissions: 100%|██████████| 148363/148363 [00:24<00:00, 5998.57it/s]


Generated 97,033 co-occurrence pairs.
MIMIC co-occurrence pairs saved to → mimic_snomed_pairs.tsv


In [None]:
# Cell 5: Merge SNOMED and MIMIC (from merge_snomed_and_mimic.py)

import pandas as pd

SNOMED_PATH = "snomed_relations_full.tsv"
MIMIC_PATH  = "mimic_snomed_pairs.tsv"
OUT_PATH    = "merged_relations.tsv"

print("Loading input files...")
snomed_df = pd.read_csv(SNOMED_PATH, sep="\t", dtype=str)
mimic_df  = pd.read_csv(MIMIC_PATH, sep="\t", dtype=str)
print(f"SNOMED relations: {len(snomed_df):,}")
print(f"MIMIC co-occurrences: {len(mimic_df):,}\n")

# === Normalize columns for consistency ===
required_cols = ["sourceId", "source_term", "relation_term", "destinationId", "destination_term"]

def ensure_cols(df, name):
    missing = set(required_cols) - set(df.columns)
    if missing:
        raise ValueError(f"{name} missing columns: {missing}")
    return df[required_cols].copy()

snomed_df = ensure_cols(snomed_df, "SNOMED")
mimic_df  = ensure_cols(mimic_df, "MIMIC")

# === Add dataset source label (optional) ===
snomed_df["source"] = "SNOMED_CT"
mimic_df["source"]  = "MIMIC_IV"

# === Merge ===
print("Merging datasets...")
before_merge_snomed = len(snomed_df)
before_merge_mimic  = len(mimic_df)
merged = pd.concat([snomed_df, mimic_df], ignore_index=True)

after_merge = len(merged)
print(f"Before merge: SNOMED={before_merge_snomed:,}, MIMIC={before_merge_mimic:,}")
print(f"After merge (combined): {after_merge:,}")

# === Remove duplicates ===
before_clean = len(merged)
merged.drop_duplicates(subset=["sourceId", "relation_term", "destinationId"], inplace=True)
after_clean = len(merged)
print(f"Duplicate removal: {before_clean:,} → {after_clean:,} (removed {before_clean - after_clean:,})")

# === Summary ===
print("\nSaving merged graph → {OUT_PATH}")
merged.to_csv(OUT_PATH, sep="\t", index=False)
print(f"Merged relational graph saved: {len(merged):,} rows")

Loading input files...
SNOMED relations: 7,161,452
MIMIC co-occurrences: 97,033

Merging datasets...
Before merge: SNOMED=7,161,452, MIMIC=97,033
After merge (combined): 7,258,485
Duplicate removal: 7,258,485 → 1,299,286 (removed 5,959,199)

Saving merged graph → {OUT_PATH}
Merged relational graph saved: 1,299,286 rows


In [None]:
# Cell 6: Graph Tensor Creation and Splitting (from Dataset_creation.py)

import pandas as pd
import numpy as np
import torch
from collections import defaultdict

IN_PATH = "merged_relations.tsv"
OUT_GRAPH = "graph_data.pt"
OUT_NODEMAP = "node_id_map.csv"
OUT_RELMAP = "rel_id_map.csv"

SPLIT_RATIO = (0.8, 0.1, 0.1)   # train/val/test
RANDOM_SEED = 42

print(f"Loading merged dataset: {IN_PATH}")
df = pd.read_csv(IN_PATH, sep="\t", dtype=str)
print(f"Loaded: {len(df):,} edges, {df['relation_term'].nunique():,} relation types")

required_cols = ["sourceId", "relation_term", "destinationId"]
for c in required_cols:
    if c not in df.columns:
        raise ValueError(f"Missing column: {c}")

# 1. Clean and validate (redundant due to merge_snomed_and_mimic.py but kept for robustness)
df = df.dropna(subset=required_cols)
df = df[(df["sourceId"].str.strip() != "") & (df["destinationId"].str.strip() != "")]
df.reset_index(drop=True, inplace=True)

# 2. Create node index mapping
nodes = pd.Index(pd.unique(pd.concat([df["sourceId"], df["destinationId"]], ignore_index=True)))
node2id = {nid: i for i, nid in enumerate(nodes)}
print(f"Unique nodes: {len(node2id):,}")

# 3. Create relation index mapping
base_relations = sorted(df["relation_term"].unique().tolist())
rel2id = {r: i for i, r in enumerate(base_relations)}
print(f"Base relation types: {len(rel2id)}")

# Encode edges
src = df["sourceId"].map(node2id).astype(np.int64).values
dst = df["destinationId"].map(node2id).astype(np.int64).values
rel = df["relation_term"].map(rel2id).astype(np.int64).values

# 4. Add inverse relations
inv_relations = [r + "_inv" for r in base_relations]
inv_offset = len(rel2id)
for i, r in enumerate(inv_relations):
    rel2id[r] = inv_offset + i

src_inv = dst.copy()
dst_inv = src.copy()
rel_inv = rel + inv_offset

# Concatenate original + inverse
src_all = np.concatenate([src, src_inv])
dst_all = np.concatenate([dst, dst_inv])
rel_all = np.concatenate([rel, rel_inv])

num_nodes = len(node2id)
num_relations = len(rel2id)
print(f"Total relations (including inverses): {num_relations}")
print(f"Total edges after adding inverses: {len(src_all):,}")

# 5. Build PyTorch tensors
edge_index = torch.tensor(np.vstack([src_all, dst_all]), dtype=torch.long)
edge_type = torch.tensor(rel_all, dtype=torch.long)

# 6. Train/Val/Test Split
rng = np.random.default_rng(RANDOM_SEED)
E = len(src)  # original (non-inverse) edges
perm = rng.permutation(E)

n_train = int(SPLIT_RATIO[0] * E)
n_val = int(SPLIT_RATIO[1] * E)

train_idx_base = perm[:n_train]
val_idx_base = perm[n_train:n_train + n_val]
test_idx_base = perm[n_train + n_val:]

# Duplicate indices for inverse edges
train_idx = np.concatenate([train_idx_base, train_idx_base + E])
val_idx = np.concatenate([val_idx_base, val_idx_base + E])
test_idx = np.concatenate([test_idx_base, test_idx_base + E])

train_idx = torch.tensor(train_idx, dtype=torch.long)
val_idx = torch.tensor(val_idx, dtype=torch.long)
test_idx = torch.tensor(test_idx, dtype=torch.long)

# 7. Save all outputs
graph_data = {
    "num_nodes": num_nodes,
    "num_relations": num_relations,
    "edge_index": edge_index,
    "edge_type": edge_type,
    "train_idx": train_idx,
    "val_idx": val_idx,
    "test_idx": test_idx,
    "node_ids": list(nodes),
    "rel2id": rel2id,
    "base_relations": base_relations,
}

torch.save(graph_data, OUT_GRAPH)
print(f"Saved graph tensors → {OUT_GRAPH}")

# Save mappings for interpretability
pd.DataFrame({"node_index": np.arange(num_nodes), "concept_id": list(nodes)}).to_csv(OUT_NODEMAP, index=False)
pd.DataFrame({"relation": list(rel2id.keys()), "rel_id": list(rel2id.values())}).to_csv(OUT_RELMAP, index=False)
print(f"Saved node map → {OUT_NODEMAP}")
print(f"Saved relation map → {OUT_RELMAP}")

print("\nGraph construction complete.")
print(f"Nodes: {num_nodes:,} | Relations: {num_relations:,} | Edges: {len(src_all):,}")
print(f"Train edges: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

Loading merged dataset: merged_relations.tsv
Loaded: 1,299,286 edges, 102 relation types
Unique nodes: 382,356
Base relation types: 102
Total relations (including inverses): 204
Total edges after adding inverses: 2,598,572
Saved graph tensors → graph_data.pt
Saved node map → node_id_map.csv
Saved relation map → rel_id_map.csv

Graph construction complete.
Nodes: 382,356 | Relations: 204 | Edges: 2,598,572
Train edges: 2078856, Val: 259856, Test: 259860


In [None]:
import torch
import pandas as pd

# Load graph_data.pt
# This file contains the graph structure, node count, relation count, and node IDs.
graph_data = torch.load('graph_data.pt')
print("Loaded graph_data.pt")
print(f"Number of nodes: {graph_data['num_nodes']}")
print(f"Number of relations: {graph_data['num_relations']}")
print(f"Sample node IDs from graph_data: {graph_data['node_ids'][:5]}")

# Load node_id_map.csv
# This file maps the generated node indices back to the original concept IDs.
node_id_map_df = pd.read_csv('node_id_map.csv')
print("\nLoaded node_id_map.csv")
display(node_id_map_df.head())

Loaded graph_data.pt
Number of nodes: 382356
Number of relations: 204
Sample node IDs from graph_data: ['10000006', '134035007', '134136005', '10002003', '135161004']

Loaded node_id_map.csv


Unnamed: 0,node_index,concept_id
0,0,10000006
1,1,134035007
2,2,134136005
3,3,10002003
4,4,135161004


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import TransE
from tqdm import tqdm
import os
import pandas as pd
import numpy as np

# --- 1. Custom Classifier Module ---
class AdaptedTransE(nn.Module):
    def __init__(self, num_nodes, num_relations, embedding_dim, dropout_rate=0.3):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.embedding_dim = embedding_dim

        # 1. Embedding Layer (TransE from PyG, for learning embeddings h and t)
        # Note: We use this layer purely for its learned entity embeddings and batching utility.
        # The TransE relation embeddings (e_r) are not used in the final classifier.
        self.transe = TransE(
            num_nodes=num_nodes,
            num_relations=num_relations,
            hidden_channels=embedding_dim,
            margin=1.0,
            p_norm=1.0,
        )

        # 2. Relation Classifier (Multi-Head Classifier as described in the paper)
        # It takes the difference vector (h - t), which is a tensor of size [batch_size, embedding_dim].
        # It maps this vector to the number of relation types.
        self.classifier = nn.Sequential(
            # First, normalize or use dropout on the difference vector (h - t)
            nn.Dropout(p=dropout_rate),
            # Linear layer to map the difference vector to logits for each relation type
            nn.Linear(embedding_dim, num_relations)
        )

        # Initialize the classifier's weights (using Xavier for stability, as suggested for embeddings)
        nn.init.xavier_uniform_(self.classifier[1].weight)

    def forward(self, head_index, tail_index):
        # 1. Get Normalized Entity Embeddings from the TransE layer
        # The TransE model internally maintains normalized entity embeddings.
        h = self.transe.node_emb(head_index)
        t = self.transe.node_emb(tail_index)

        # 2. Calculate the difference vector (h - t) which approximates the relation r
        diff_vector = h - t

        # 3. Pass the difference vector to the classifier head
        logits = self.classifier(diff_vector)

        return logits

    # We will use the standard PyG TransE loss for pre-training the embeddings
    def transe_loss(self, head_index, rel_type, tail_index):
        return self.transe.loss(head_index, rel_type, tail_index)

# --- 2. Training Function for the Classifier ---
def train_classifier(model, optimizer, data_loader, edge_index, edge_type, base_relations, num_epochs):
    model.train()

    # We need to map the full edge_type tensor to only the base relations (non-inverse)
    # to match the classifier's output size.
    # We assume base relations are the first half of the total relations (as per Dataset_creation.py)
    num_base_relations = len(base_relations)

    # We only train on the forward (non-inverse) relationships.
    forward_mask = edge_type < num_base_relations
    forward_indices = torch.where(forward_mask)[0]

    # Use only the training set indices for the forward relations
    # `graph_data['train_idx']` is now guaranteed to be on the correct device
    train_indices_forward = forward_indices[torch.isin(forward_indices, graph_data['train_idx'])]

    # Create a new loader that iterates over the forward-only triples
    classifier_loader = model.transe.loader(
        head_index=edge_index[0, train_indices_forward],
        rel_type=edge_type[train_indices_forward],
        tail_index=edge_index[1, train_indices_forward],
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    print("\nStarting Adapted TransE Classifier Training...")

    for epoch in range(1, num_epochs + 1):
        total_loss = 0
        correct_predictions = 0
        total_samples = 0

        for head_index, rel_type, tail_index in classifier_loader:
            optimizer.zero_grad()

            # Predict the relation type using the classifier
            logits = model(head_index, tail_index)

            # Use Cross-Entropy Loss for multi-class classification
            loss = F.cross_entropy(logits, rel_type)

            loss.backward()
            optimizer.step()

            total_loss += loss.item() * head_index.size(0)
            total_samples += head_index.size(0)

            # Calculate accuracy
            _, predicted = torch.max(logits, 1)
            correct_predictions += (predicted == rel_type).sum().item()

        avg_loss = total_loss / total_samples
        accuracy = correct_predictions / total_samples

        print(f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.4f}")

        if epoch % 5 == 0 or epoch == num_epochs:
            val_accuracy = evaluate_classifier(model, edge_index, edge_type, graph_data['val_idx'], num_base_relations)
            print(f"  >>> Validation Accuracy: {val_accuracy:.4f}")

    # Final Test
    test_accuracy = evaluate_classifier(model, edge_index, edge_type, graph_data['test_idx'], num_base_relations)
    print("\n--- FINAL CLASSIFIER TEST RESULTS ---")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print("-------------------------------------")


@torch.no_grad()
def evaluate_classifier(model, edge_index, edge_type, eval_idx, num_base_relations):
    model.eval()

    # Select only the forward (non-inverse) relations in the evaluation set
    eval_mask = edge_type < num_base_relations
    # `eval_idx` is now guaranteed to be on the correct device
    eval_indices_forward = eval_idx[torch.isin(eval_idx, torch.where(eval_mask)[0])]

    if eval_indices_forward.numel() == 0:
        return 0.0

    head_index = edge_index[0, eval_indices_forward]
    tail_index = edge_index[1, eval_indices_forward]
    rel_type = edge_type[eval_indices_forward]

    logits = model(head_index, tail_index)
    _, predicted = torch.max(logits, 1)

    correct_predictions = (predicted == rel_type).sum().item()
    total_samples = rel_type.size(0)

    return correct_predictions / total_samples

# --- 3. Main Execution (TransE Pre-training + Classifier Training) ---
# --- Configuration (using paper's optimal values for the classifier part) ---
EMBEDDING_DIM = 200      # Based on best run
CLASSIFIER_LR = 0.001    # Learning rate for Adam
CLASSIFIER_EPOCHS = 10   # Based on paper's finding
PRE_TRAIN_EPOCHS = 50  # Increased to improve link prediction
BATCH_SIZE = 4096         # Based on best run

GRAPH_DATA_PATH = "graph_data.pt"
if not os.path.exists(GRAPH_DATA_PATH):
    print("WARNING: graph_data.pt not found. Running placeholder data preparation now...")
    # This block requires running the data prep cells (1-6) first.
    # Since they were provided in the prompt, we assume they have run or will run.
    raise FileNotFoundError("Please ensure you run all cells prior to Cell 7, including the data preparation.")

graph_data = torch.load(GRAPH_DATA_PATH)

# --- Prepare Tensors and Device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

num_nodes = graph_data["num_nodes"]
num_relations = graph_data["num_relations"]
base_relations = graph_data["base_relations"]

# Move all relevant tensors to the device
graph_data["edge_index"] = graph_data["edge_index"].to(device)
graph_data["edge_type"] = graph_data["edge_type"].to(device)
graph_data["train_idx"] = graph_data["train_idx"].to(device)
graph_data["val_idx"] = graph_data["val_idx"].to(device)
graph_data["test_idx"] = graph_data["test_idx"].to(device)

edge_index = graph_data["edge_index"]
edge_type = graph_data["edge_type"]

# --- Initialize Model ---
model = AdaptedTransE(num_nodes, num_relations, EMBEDDING_DIM).to(device)

# --- PHASE 1: TransE Pre-training (Optimizing embeddings for link prediction) ---
print("\n--- PHASE 1: TransE Embedding Pre-training ---")
pre_train_optimizer = optim.Adam(model.transe.parameters(), lr=0.01) # Using standard KGE learning rate
transe_loader = model.transe.loader(
    head_index=edge_index[0],
    rel_type=edge_type,
    tail_index=edge_index[1],
    batch_size=BATCH_SIZE,
    shuffle=True,
)

for epoch in tqdm(range(1, PRE_TRAIN_EPOCHS + 1), desc="TransE Pre-training"):
    total_loss = 0
    for head_index, rel_type, tail_index in transe_loader:
        pre_train_optimizer.zero_grad()
        loss = model.transe_loss(head_index, rel_type, tail_index)
        loss.backward()
        pre_train_optimizer.step()
        total_loss += float(loss) * head_index.numel()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d} Loss: {total_loss / edge_index.size(1):.4f}")

# --- PHASE 2: Relation Classification (Optimizing classifier on fixed embeddings) ---
print("\n--- PHASE 2: Relation Classification Training ---")

# Fix TransE embeddings and optimize only the classifier weights
for param in model.transe.parameters():
    param.requires_grad = False

classifier_optimizer = optim.Adam(model.classifier.parameters(), lr=CLASSIFIER_LR)

# Train the classifier using the forward-only data subset
train_classifier(model, classifier_optimizer, None, edge_index, edge_type, base_relations, CLASSIFIER_EPOCHS)


Using device: cuda

--- PHASE 1: TransE Embedding Pre-training ---


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  total_loss += float(loss) * head_index.numel()
TransE Pre-training:  20%|██        | 10/50 [01:15<05:00,  7.50s/it]

Epoch 010 Loss: 0.4723


TransE Pre-training:  22%|██▏       | 11/50 [01:26<05:05,  7.83s/it]


KeyboardInterrupt: 

In [None]:
from tqdm import tqdm

def evaluate_link_prediction(model, head_index, rel_type, tail_index, k_values=[1, 3, 10, 50, 100], batch_size=4096, device='cpu'):
    model.eval() # Set model to evaluation mode

    ranks = []
    hits_at_k = {k: 0 for k in k_values}
    num_samples = head_index.size(0)

    # Create a data loader for evaluation to process in batches
    eval_loader = model.transe.loader(
        head_index=head_index,
        rel_type=rel_type,
        tail_index=tail_index,
        batch_size=batch_size,
        shuffle=False,
    )

    with torch.no_grad(): # Disable gradient calculation for inference
        for head, rel, tail in tqdm(eval_loader, desc="Evaluating Link Prediction"):
            head, rel, tail = head.to(device), rel.to(device), tail.to(device)

            # For each (head, relation, tail) triple, we'll evaluate two tasks:
            # 1. Predicting the tail entity given (head, relation)
            # 2. Predicting the head entity given (relation, tail)

            # --- Tail Prediction --- (corrupting tail)
            # Generate corrupted triples: (head, relation, all_possible_tails)
            all_nodes = torch.arange(model.num_nodes, device=device)
            # Expand head and rel for broadcasting against all_nodes
            b_head = head.unsqueeze(1).expand(-1, model.num_nodes)
            b_rel = rel.unsqueeze(1).expand(-1, model.num_nodes)
            b_tail_corrupted = all_nodes.unsqueeze(0).expand(head.size(0), -1)

            # Calculate scores for all corrupted triples
            # The score function is defined by the TransE model itself (distance)
            # Lower distance is better, so we sort in ascending order
            scores_tail_pred = model.transe.forward(b_head, b_rel, b_tail_corrupted).squeeze()

            # Get the score of the true tail for each batch item
            true_tail_scores = scores_tail_pred.gather(1, tail.unsqueeze(1)).squeeze()

            # Calculate rank: count how many corrupted tails have a better (lower) score than the true tail
            # Add 1 to convert from 0-indexed count to 1-indexed rank
            ranks_tail_pred = (scores_tail_pred < true_tail_scores.unsqueeze(1)).sum(dim=1) + 1
            ranks.extend(ranks_tail_pred.cpu().tolist())

            # Calculate Hits@k for tail prediction
            for k in k_values:
                hits_at_k[k] += (ranks_tail_pred <= k).sum().item()

            # --- Head Prediction --- (corrupting head)
            # Generate corrupted triples: (all_possible_heads, relation, tail)
            b_head_corrupted = all_nodes.unsqueeze(0).expand(head.size(0), -1)
            b_rel = rel.unsqueeze(1).expand(-1, model.num_nodes)
            b_tail = tail.unsqueeze(1).expand(-1, model.num_nodes)

            # Calculate scores for all corrupted triples
            scores_head_pred = model.transe.forward(b_head_corrupted, b_rel, b_tail).squeeze()

            # Get the score of the true head for each batch item
            true_head_scores = scores_head_pred.gather(1, head.unsqueeze(1)).squeeze()

            # Calculate rank
            ranks_head_pred = (scores_head_pred < true_head_scores.unsqueeze(1)).sum(dim=1) + 1
            ranks.extend(ranks_head_pred.cpu().tolist())

            # Calculate Hits@k for head prediction
            for k in k_values:
                hits_at_k[k] += (ranks_head_pred <= k).sum().item()

    # Aggregate results
    mean_rank = sum(ranks) / len(ranks)
    for k in k_values:
        hits_at_k[k] /= (num_samples * 2) # Multiply by 2 because we did head and tail prediction for each sample

    results = {"mean_rank": mean_rank}
    results.update({f"hits_at_{k}": hits_at_k[k] for k in k_values})

    return results

print("Defined evaluate_link_prediction function.")

Defined evaluate_link_prediction function.


In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np # For random sampling

@torch.no_grad()
def evaluate_link_prediction(transe_model, edge_index, edge_type, eval_indices, k_values=[1, 3, 10], num_negatives=100, batch_size=4096):
    """
    Evaluates link prediction performance using Hits@k and Mean Rank by manually
    corrupting entities with negative sampling and ranking scores.

    Args:
        transe_model: The TransE model (e.g., model.transe) with trained embeddings.
        edge_index: Tensor of shape [2, num_edges] representing head and tail indices.
        edge_type: Tensor of shape [num_edges] representing relation types.
        eval_indices: Indices for the evaluation set (e.g., test_idx).
        k_values: List of k values for Hits@k.
        num_negatives: Number of negative samples to generate for each positive triple.
        batch_size: Batch size for processing evaluation triples.
    Returns:
        A dictionary containing Hits@k and Mean Rank metrics.
    """
    transe_model.eval()
    device = transe_model.node_emb.weight.device # Get the device from the model's embeddings

    # Extract head, relation, tail for the evaluation triples
    head_indices = edge_index[0, eval_indices]
    rel_types = edge_type[eval_indices]
    tail_indices = edge_index[1, eval_indices]

    ranks = []
    hits_at_k = {k: 0 for k in k_values}

    num_eval_triples = head_indices.size(0)

    # DataLoader for evaluation triples
    eval_dataset = torch.utils.data.TensorDataset(head_indices, rel_types, tail_indices)
    eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    print(f"\nPerforming link prediction evaluation on {num_eval_triples} triples with {num_negatives} negative samples per triple...")

    for h_batch, r_batch, t_batch in tqdm(eval_loader, desc="Evaluating Link Prediction"):
        h_batch, r_batch, t_batch = h_batch.to(device), r_batch.to(device), t_batch.to(device)

        batch_size_actual = h_batch.size(0)
        num_nodes = transe_model.num_nodes

        # --- Tail Prediction ---
        # Generate negative tails for each triple in the batch
        # Ensure negative tails are not the true tail for each corresponding batch item
        neg_tails = torch.randint(0, num_nodes, (batch_size_actual, num_negatives), device=device)
        # For simplicity, if a negative happens to be the true tail, we keep it for now.
        # A more rigorous implementation would re-sample until all are distinct from the true tail.

        # Concatenate true tails and negative tails for ranking
        # candidate_tails: [batch_size, 1 + num_negatives]
        candidate_tails = torch.cat([t_batch.unsqueeze(1), neg_tails], dim=1)

        # Expand h_batch and r_batch to match the number of candidates
        # h_expanded: [batch_size, 1 + num_negatives]
        # r_expanded: [batch_size, 1 + num_negatives]
        h_expanded = h_batch.unsqueeze(1).expand(-1, 1 + num_negatives)
        r_expanded = r_batch.unsqueeze(1).expand(-1, 1 + num_negatives)

        # Score all candidate triples (h, r, candidate_t)
        # scores: [batch_size, 1 + num_negatives]
        scores = transe_model.forward(
            h_expanded.reshape(-1),
            r_expanded.reshape(-1),
            candidate_tails.reshape(-1)
        ).reshape(batch_size_actual, 1 + num_negatives)

        # Rank the true tail (which is at index 0 in candidate_tails)
        # Scores are distances, so lower is better. Sort in ascending order.
        sorted_indices = torch.argsort(scores, dim=1) # indices of sorted scores

        # Find rank of the true tail (which was the first candidate)
        # The true rank is the position of index 0 in sorted_indices, plus 1 (1-based rank)
        true_tail_rank = (sorted_indices == 0).nonzero(as_tuple=True)[1] + 1
        ranks.extend(true_tail_rank.cpu().tolist()) # Collect ranks for Mean Rank

        # Update Hits@k
        for k in k_values:
            hits_at_k[k] += (true_tail_rank <= k).sum().item()


        # --- Head Prediction --- (symmetric logic)
        # Generate negative heads for each triple in the batch
        neg_heads = torch.randint(0, num_nodes, (batch_size_actual, num_negatives), device=device)

        # Concatenate true heads and negative heads for ranking
        candidate_heads = torch.cat([h_batch.unsqueeze(1), neg_heads], dim=1)

        # Expand r_batch and t_batch to match the number of candidates
        r_expanded = r_batch.unsqueeze(1).expand(-1, 1 + num_negatives)
        t_expanded = t_batch.unsqueeze(1).expand(-1, 1 + num_negatives)

        # Score all candidate triples (candidate_h, r, t)
        scores = transe_model.forward(
            candidate_heads.reshape(-1),
            r_expanded.reshape(-1),
            t_expanded.reshape(-1)
        ).reshape(batch_size_actual, 1 + num_negatives)

        # Rank the true head (which is at index 0 in candidate_heads)
        sorted_indices = torch.argsort(scores, dim=1)
        true_head_rank = (sorted_indices == 0).nonzero(as_tuple=True)[1] + 1
        ranks.extend(true_head_rank.cpu().tolist()) # Collect ranks for Mean Rank

        # Update Hits@k
        for k in k_values:
            hits_at_k[k] += (true_head_rank <= k).sum().item()

    # Aggregate results
    mean_rank = sum(ranks) / len(ranks) if len(ranks) > 0 else 0.0
    total_evaluated_predictions = num_eval_triples * 2 # Each original sample leads to head and tail prediction

    final_hits = {}
    for k in k_values:
        final_hits[f"Hits@{k}"] = hits_at_k[k] / total_evaluated_predictions

    results = {"Mean Rank": mean_rank}
    results.update(final_hits)

    return results

print("\n--- Evaluating Link Prediction on Test Set ---")
link_pred_metrics = evaluate_link_prediction(
    transe_model=model.transe, # Pass model.transe directly as it's the TransE instance
    edge_index=graph_data["edge_index"],
    edge_type=graph_data["edge_type"],
    eval_indices=graph_data["test_idx"], # Evaluate on the test set
    k_values=[1, 3, 10, 50, 100], # Expanded k values
    num_negatives=100, # Use num_negatives for sampling
    batch_size=BATCH_SIZE # Ensure BATCH_SIZE is accessible (it is in kernel state)
)

print("-------------------------------------------------------")
print("Final Link Prediction Metrics (Test Set):")
for metric, value in link_pred_metrics.items():
    print(f"  {metric}: {value:.4f}")
print("-------------------------------------------------------")



--- Evaluating Link Prediction on Test Set ---

Performing link prediction evaluation on 259860 triples with 100 negative samples per triple...


Evaluating Link Prediction: 100%|██████████| 64/64 [00:06<00:00, 10.63it/s]

-------------------------------------------------------
Final Link Prediction Metrics (Test Set):
  Mean Rank: 86.5505
  Hits@1: 0.0026
  Hits@3: 0.0074
  Hits@10: 0.0236
  Hits@50: 0.1129
  Hits@100: 0.6366
-------------------------------------------------------



