# **Graph Neural Network (GNN) Training for Medical Predictions**

## **Introduction**
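This notebook walks through the end-to-end process of training a **Graph Neural Network (GNN)** to predict medical conditions (a multi-label task over six target diseases). Each patient becomes a node, relationships recorded in **RelatedPerson** resources become edges, and the patient's **condition history** becomes the node features, so the model can combine each patient's own records with those of connected relatives.

We use **PyTorch Geometric (PyG)** to build the graph and the model, and **TF-IDF** to turn condition descriptions into node feature vectors. The data consists of **FHIR-structured patient records** (read from a Synthea output folder), which we preprocess before training.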

---

## **List of Contents**
This notebook is structured into **8 key sections**, covering data extraction, preprocessing, graph construction, model training, and saving the artifacts needed for inference.

### [**1. Extracting FHIR JSON Data (FHIR Extraction)**](#1-extracting-fhir-json-data-fhir-extraction)
   **Keywords**: `load_json`, `extract_patients`, `extract_family_member_history`, `extract_related_person`, `extract_conditions`
   - Extract **patient details, family medical history, relationships, and conditions** from FHIR JSON files.
   - Convert the raw JSON data into **structured Pandas DataFrames** for further processing.

### [**2. Preprocessing and Training DataFrame Creation**](#2-preprocessing-and-training-dataframe-creation)
   **Keywords**: `df_training`, `df_patient`, `df_fmh`, `df_rp`, `df_condition`
   - Merge **patient data** with medical conditions.
   - Pivot **family member history and related-person conditions** into per-relationship columns.
   - Define **multi-label classification labels** from patient conditions.
   - Reduce **data leakage** by removing target disease names from the condition text for a random subset of patients (50% by default).

### [**3. Constructing Graph Data (Graph Preparation)**](#3-constructing-graph-data-graph-preparation)
   **Keywords**: `patient_id_map`, `edge_index`, `torch.tensor`, `Data(x=node_features, edge_index=edge_index)`
   - Represent **patients as nodes** in the GNN.
   - Use **TF-IDF** to generate **node features** from medical condition descriptions.
   - Construct **edges** based on patient relationships from **RelatedPerson** data.

### [**4. Train-Test Splitting for Model Training (Iterative Train-Test Split)**](#4-train-test-splitting-for-model-training-iterative-train-test-split)
   **Keywords**: `iterative_train_test_split`, `X_train`, `Y_train`, `X_val`, `Y_val`
   - Use **iterative stratification** (`skmultilearn.model_selection.iterative_train_test_split`) to keep the multi-label distribution balanced across the splits.
   - Split the data into **training (80%)** and **validation (20%)** sets.
   - Convert features and labels into **PyTorch tensors** for compatibility with GNN models.

### [**5. Graph Construction for Training and Validation**](#5-graph-construction-for-training-and-validation)
   **Keywords**: `train_graph`, `val_graph`, `graph.clone()`
   - Create **two separate graphs**:
     - **Training Graph (`train_graph`)** for model training.
     - **Validation Graph (`val_graph`)** for model evaluation.
   - Build each graph from the **nodes in its split** only, keeping the edges whose endpoints both fall in that split and renumbering them so that edges, **node features**, and labels stay aligned.

### [**6. GNN Model Definition**](#6-gnn-model-definition)
   **Keywords**: `class GNNModel`, `GCNConv`, `F.relu`
   - Define a **Graph Neural Network (GNN)** model using `torch_geometric.nn.GCNConv`.
   - Implement **two graph convolution layers**:
     - First layer transforms input features (`in_channels → hidden_channels`).
     - Second layer maps hidden representations to the target labels (`hidden_channels → out_channels`).
   - Apply **ReLU activation** for non-linearity.

### [**7. Model Training with Validation**](#7-model-training-with-validation)
   **Keywords**: `loss_fn`, `optimizer`, `loss_train`, `loss_val`
   - Train the model using:
     - **Adam optimizer** for weight updates.
     - **Binary Cross-Entropy Loss (`BCEWithLogitsLoss`)** for multi-label classification.
   - Perform **forward and backward propagation** on `train_graph`.
   - Evaluate model performance on `val_graph` after each epoch.
   - Print **training loss and validation loss** every 10 epochs.

### [**8. Model and Vectorizer Saving**](#8-model-and-vectorizer-saving)
   **Keywords**: `torch.save`, `pickle.dump`
   - Save the trained **GNN model weights** (`gnn_model_weights.pt`).
   - Save the **TF-IDF vectorizer** (`tfidf.pkl`) for use in inference.

---

## **Installation Requirements**
To run this notebook, install the required dependencies using:

```bash
pip install torch torch-geometric scikit-learn scikit-multilearn fastapi hypercorn numpy pandas
```


# 1. Extracting FHIR JSON Data (FHIR Extraction)

In [None]:
import os
import json
import pandas as pd

# ========================
# Helper Functions
# ========================

def extract_id(ref):
    """Extracts the ID from a reference string."""
    if ref.startswith("Patient/"):
        return ref.split("/")[-1]
    elif ref.startswith("RelatedPerson/"):
        return ref.split("/")[-1]
    elif ref.startswith("urn:uuid:"):
        return ref.split(":")[-1]
    else:
        return ref

def load_json(filepath):
    """Loads a JSON file and returns its content."""
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def save_as_csv(data, output_filepath):
    """Saves data as a CSV file."""
    df = pd.DataFrame(data)
    df.to_csv(output_filepath, index=False)
    print(f"CSV saved at {output_filepath}")

def save_as_json(data, output_filepath):
    """Saves data as a JSON file."""
    with open(output_filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"JSON saved at {output_filepath}")

# ========================
# Extraction Functions for Each Resource
# ========================

def extract_patients(data):
    """Extracts patient information including ID, birth date, gender, and name."""
    extracted = []
    for p in data:
        extracted.append({
            "patient_id": p.get("id", ""),
            "birthDate": p.get("birthDate", ""),
            "gender": p.get("gender", ""),
            "name": p.get("name", [{"family": "Unknown"}])[0].get("family", "Unknown")
        })
    return extracted

def extract_family_member_history(data):
    """Extracts family medical history, including patient ID, relationship type, and conditions."""
    extracted = []
    for rec in data:
        patient_ref = rec.get("patient", {}).get("reference", "")
        patient_id = extract_id(patient_ref)

        # Extract relationship type if available
        relationship = ""
        if "relationship" in rec and "coding" in rec["relationship"] and rec["relationship"]["coding"]:
            relationship = rec["relationship"]["coding"][0].get("display", "")

        # Extract medical conditions if available
        conditions = []
        if "condition" in rec:
            for cond in rec["condition"]:
                cond_text = cond.get("code", {}).get("text", "")
                if cond_text:
                    conditions.append(cond_text)

        extracted.append({
            "family_member_history_id": rec.get("id", ""),
            "patient_id": patient_id,
            "relationship": relationship,
            "conditions": "; ".join(conditions)  # Concatenates conditions if multiple exist
        })
    return extracted

def extract_related_person(data):
    """Extracts related person details including ID, relationship, name, gender, and birth date."""
    extracted = []
    for rec in data:
        patient_ref = rec.get("patient", {}).get("reference", "")
        patient_id = extract_id(patient_ref)

        # Extract and concatenate relationship descriptions if multiple exist
        relationship = ", ".join([r.get("text", "") for r in rec.get("relationship", [])])

        # Extract related person's name if available
        rp_name = ""
        if "name" in rec and isinstance(rec["name"], list) and len(rec["name"]) > 0:
            rp_name = rec["name"][0].get("family", "")

        extracted.append({
            "related_person_id": extract_id(rec.get("id", "")),
            "patient_id": patient_id,
            "relationship": relationship,
            "rp_name": rp_name,
            "gender": rec.get("gender", ""),
            "birthDate": rec.get("birthDate", "")
        })
    return extracted

def extract_conditions(data):
    """Extracts medical conditions associated with patients."""
    extracted = []
    for rec in data:
        subject_ref = rec.get("subject", {}).get("reference", "")
        patient_id = extract_id(subject_ref)
        disease = rec.get("code", {}).get("text", "")

        extracted.append({
            "condition_id": rec.get("id", ""),
            "patient_id": patient_id,
            "disease": disease
        })
    return extracted

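As a quick check of the expected input format, the standalone sketch below pushes a few hand-written, FHIR-shaped dictionaries through the same field paths the extraction functions read (`subject.reference`, `code.text`, `relationship.coding[0].display`, `condition[].code.text`). The resource IDs and values are invented for illustration, and `extract_id` is re-declared in simplified form so the snippet runs on its own.

```python
# Toy resources shaped like FHIR JSON (invented IDs, trimmed to the fields the helpers read)
conditions = [
    {"id": "cond-1", "subject": {"reference": "urn:uuid:patient-a"},
     "code": {"text": "Asthma (disorder)"}},
    {"id": "cond-2", "subject": {"reference": "Patient/patient-b"},
     "code": {"text": "Gingivitis (disorder)"}},
    {"id": "cond-3"},  # fields missing entirely
]
family_history = {
    "id": "fmh-1",
    "patient": {"reference": "Patient/patient-a"},
    "relationship": {"coding": [{"display": "Father"}]},
    "condition": [{"code": {"text": "Asthma"}}, {"code": {"text": "Hypertension"}}],
}

def extract_id(ref):
    """Strip the reference prefixes handled by the notebook's extract_id."""
    for prefix in ("Patient/", "RelatedPerson/", "urn:uuid:"):
        if ref.startswith(prefix):
            return ref[len(prefix):]
    return ref

# Same field paths as extract_conditions: id, subject.reference, code.text
condition_rows = [
    {
        "condition_id": rec.get("id", ""),
        "patient_id": extract_id(rec.get("subject", {}).get("reference", "")),
        "disease": rec.get("code", {}).get("text", ""),
    }
    for rec in conditions
]

# Same field paths as extract_family_member_history
coding = family_history.get("relationship", {}).get("coding", [])
fmh_row = {
    "patient_id": extract_id(family_history["patient"]["reference"]),
    "relationship": coding[0].get("display", "") if coding else "",
    "conditions": "; ".join(
        c["code"]["text"] for c in family_history.get("condition", [])
        if c.get("code", {}).get("text")
    ),
}

for row in condition_rows:
    print(row)
print(fmh_row)
```

Missing fields fall back to empty strings rather than raising, matching the `.get(..., "")` defaults in the helpers above.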

In [None]:
import os
import json
import pandas as pd

# Folder containing processed FHIR data
processed_folder = "synthea/output/processed/"

# Resources considered useful for predictive modeling
resources = ["Patient", "FamilyMemberHistory", "RelatedPerson", "Condition"]

# Mapping of resources to corresponding extraction functions
extraction_functions = {
    "Patient": extract_patients,
    "FamilyMemberHistory": extract_family_member_history,
    "RelatedPerson": extract_related_person,
    "Condition": extract_conditions
}

# Dictionary to store extracted DataFrames
dataframes = {}

for res in resources:
    filepath = os.path.join(processed_folder, f"{res}.json")
    
    if os.path.exists(filepath):
        print(f"Processing {res} from {filepath}")
        data = load_json(filepath)

        # Convert dictionary to list if necessary
        if isinstance(data, dict):
            data = [data]

        extracted_data = extraction_functions[res](data)
        df = pd.DataFrame(extracted_data)
        dataframes[res] = df
        print(f"DataFrame for {res} created with shape: {df.shape}")
    
    else:
        print(f"File {res}.json not found in {processed_folder}")

# Display sample rows from each DataFrame
for res, df in dataframes.items():
    print(f"\nSample DataFrame for {res}:")
    display(df.head())


Processing Patient from synthea/output/processed/Patient.json
DataFrame for Patient created with shape: (108, 4)
Processing FamilyMemberHistory from synthea/output/processed/FamilyMemberHistory.json
DataFrame for FamilyMemberHistory created with shape: (193, 4)
Processing RelatedPerson from synthea/output/processed/RelatedPerson.json
DataFrame for RelatedPerson created with shape: (261, 6)
Processing Condition from synthea/output/processed/Condition.json
DataFrame for Condition created with shape: (4093, 3)

Sample DataFrame for Patient:

| | patient_id | birthDate | gender | name |
|---|---|---|---|---|
| 0 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | 2010-05-07 | female | Runolfsdottir785 |
| 1 | d4f1d88b-aecc-493e-2977-44a72e0de2d9 | 2002-11-28 | female | Jerde200 |
| 2 | 9f7675c1-1f29-10ac-92e5-8aaf367f05c3 | 2007-06-07 | female | Sanford861 |
| 3 | 839e461d-9a4d-a110-1fe9-97bd16378bfd | 2008-05-28 | male | Ruecker817 |
| 4 | 7e101445-eafd-cd17-0e6b-57f85baa3f44 | 1985-10-07 | female | Kerluke267 |



Sample DataFrame for FamilyMemberHistory:

| | family_member_history_id | patient_id | relationship | conditions |
|---|---|---|---|---|
| 0 | family-3a644dcd-672c-9579-cdeb-65ce6783da97- | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Father | Asthma |
| 1 | family-8463087b-be64-1139-b779-97d09881e034- | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Sister | Hypertension; Heart Disease |
| 2 | family-00a4d481-551d-9741-dd8f-fa88fe29ab79- | d4f1d88b-aecc-493e-2977-44a72e0de2d9 | Father | Hypertension |
| 3 | family-8c97920a-fc41-8150-f54e-9dcfc1f48fef- | d4f1d88b-aecc-493e-2977-44a72e0de2d9 | Mother | Diabetes; Hypertension |
| 4 | family-2b27a9c6-3b32-83fe-c4eb-ff271de3536b- | 9f7675c1-1f29-10ac-92e5-8aaf367f05c3 | Father | Cancer |



Sample DataFrame for RelatedPerson:

| | related_person_id | patient_id | relationship | rp_name | gender | birthDate |
|---|---|---|---|---|---|---|
| 0 | 3a644dcd-672c-9579-cdeb-65ce6783da97 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Father | Barton704 | female | 1975-09-20 |
| 1 | 67ed8fab-19a2-40c5-e56c-3dfdab2c9805 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Mother | Schowalter414 | male | 1989-02-23 |
| 2 | 8463087b-be64-1139-b779-97d09881e034 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Sister | Boyle917 | male | 1965-09-03 |
| 3 | 00a4d481-551d-9741-dd8f-fa88fe29ab79 | d4f1d88b-aecc-493e-2977-44a72e0de2d9 | Father | Bernhard322 | female | 1968-08-26 |
| 4 | 8c97920a-fc41-8150-f54e-9dcfc1f48fef | d4f1d88b-aecc-493e-2977-44a72e0de2d9 | Mother | Jerde200 | female | 2003-10-06 |



Sample DataFrame for Condition:

| | condition_id | patient_id | disease |
|---|---|---|---|
| 0 | ded1426d-62e2-77ad-0c8b-5b34075c89a9 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Medication review due (situation) |
| 1 | 40f951b4-d966-312a-c6b2-2b9b89ca5f30 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Medication review due (situation) |
| 2 | aa6822b6-d4e2-bbe9-d4fa-574aac5c27ca | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Gingivitis (disorder) |
| 3 | 35aac9a9-f1fe-8563-79d0-d279208f9098 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Medication review due (situation) |
| 4 | 63b4c19c-4488-56fe-e4eb-f5dd262aa4b2 | 7da148be-b73e-73e3-ed5c-67d7c712a253 | Medication review due (situation) |


# 2. Preprocessing and Training DataFrame Creation

In [None]:
import pandas as pd
import numpy as np
import re
import torch
import random
from sklearn.feature_extraction.text import TfidfVectorizer
from torch_geometric.data import Data

# === 1. Load Data ===
df_patient = dataframes["Patient"]
df_fmh = dataframes["FamilyMemberHistory"]
df_rp = dataframes["RelatedPerson"]
df_condition = dataframes["Condition"]

# === 2. Group Condition by patient_id and concatenate disease names ===
df_condition_grouped = (
    df_condition.groupby("patient_id")["disease"]
    .apply(lambda x: " ".join(x))
    .reset_index()
    .rename(columns={"disease": "patient_conditions_text"})
)

# Merge df_patient with patient conditions
df_training = pd.merge(df_patient, df_condition_grouped, on="patient_id", how="left").fillna("")

# === 3. Process FamilyMemberHistory & RelatedPerson as additional features ===

# Process FamilyMemberHistory
df_fmh_grouped = (
    df_fmh.groupby(["patient_id", "relationship"])["conditions"]
    .apply(lambda x: "; ".join(x.dropna().unique()))
    .reset_index()
)

df_fmh_pivot = df_fmh_grouped.pivot(index="patient_id", columns="relationship", values="conditions").reset_index()
df_fmh_pivot = df_fmh_pivot.rename(columns=lambda x: x.lower() + "_condition" if x != "patient_id" else x)
df_training = pd.merge(df_training, df_fmh_pivot, on="patient_id", how="left").fillna("")

# Process RelatedPerson
df_rp_condition = pd.merge(df_rp, df_condition_grouped, left_on="related_person_id", right_on="patient_id", how="left")
df_rp_condition.drop(columns=["patient_id_y"], inplace=True)

df_rp_grouped = (
    df_rp_condition.groupby(["patient_id_x", "relationship"])["patient_conditions_text"]
    .apply(lambda x: "; ".join(x.dropna().unique()))
    .reset_index()
)

df_rp_pivot = df_rp_grouped.pivot(index="patient_id_x", columns="relationship", values="patient_conditions_text").reset_index()
df_rp_pivot = df_rp_pivot.rename(columns=lambda x: x.lower() + "_related_condition" if x != "patient_id_x" else x)
df_training = pd.merge(df_training, df_rp_pivot, left_on="patient_id", right_on="patient_id_x", how="left").fillna("")

# === 4. Create Multi-Label Targets for Each Patient ===
target_diseases = ["Diabetes", "Hypertension", "Cancer", "Heart Disease", "Alzheimer", "Asthma"]

for disease in target_diseases:
    df_training[disease] = df_training["patient_conditions_text"].apply(lambda x: 1 if disease.lower() in str(x).lower() else 0)

# === 5. Remove Disease Names from Text Features to Reduce Data Leakage ===
# Names are removed only for a random subset of patients, so the features of the
# remaining patients still contain their target labels (leakage is reduced, not eliminated).
random.seed(42)  # make the random subset reproducible across runs

def remove_target_diseases_partial(text, target_diseases, remove_ratio=0.5):
    """
    Removes target disease names from the text with probability `remove_ratio`
    (default: 50% of patients) to simulate records where this information is missing.
    Each match is widened to the whole surrounding word (e.g. "prediabetes" for "Diabetes").
    """
    if pd.isna(text):
        return ""

    # Keep the original text with probability 1 - remove_ratio
    if random.random() > remove_ratio:
        return text

    text = text.lower()
    for disease in target_diseases:
        # \w* on both sides removes the whole word that contains the disease name
        pattern = r"\b\w*" + re.escape(disease.lower()) + r"\w*\b"
        text = re.sub(pattern, "", text).strip()

    text = re.sub(r"\s+", " ", text)  # collapse whitespace left behind by the removals
    return text

# Apply the disease removal function to patient conditions
df_training["patient_conditions_text_cleaned"] = df_training["patient_conditions_text"].apply(
    lambda x: remove_target_diseases_partial(x, target_diseases, remove_ratio=0.5)
)

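# === 6. Create Node Representations with TF-IDF ===
# Only the cleaned patient text is used as input; the *_condition and
# *_related_condition columns built in step 3 are not included here.
text_columns = [
    "patient_conditions_text_cleaned",
]

# Concatenate all text features into one string per patient
df_training["all_text"] = df_training[text_columns].apply(lambda x: " ".join(x.dropna()), axis=1)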

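The labels come from a case-insensitive substring test on the concatenated condition text, and the removal step widens each match to the whole surrounding word. The standalone sketch below (toy text, same six target names) shows a side effect worth checking against the real vocabulary: a condition text such as `"Prediabetes (finding)"` contains `"diabetes"`, so it sets the Diabetes label and is also stripped by the leakage filter. `word_boundary_labels` is a hypothetical stricter rule shown only for comparison; the notebook does not use it.

```python
import re

target_diseases = ["Diabetes", "Hypertension", "Cancer", "Heart Disease", "Alzheimer", "Asthma"]

def substring_labels(text):
    """Notebook rule: label is 1 if the lower-cased name appears anywhere in the text."""
    return {d: int(d.lower() in text.lower()) for d in target_diseases}

def word_boundary_labels(text):
    """Hypothetical alternative: the name must start and end on word boundaries."""
    return {
        d: int(re.search(r"\b" + re.escape(d.lower()) + r"\b", text.lower()) is not None)
        for d in target_diseases
    }

def strip_target_words(text):
    """Removal branch of remove_target_diseases_partial, without the random 50% gate."""
    text = text.lower()
    for d in target_diseases:
        text = re.sub(r"\b\w*" + re.escape(d.lower()) + r"\w*\b", "", text).strip()
    return re.sub(r"\s+", " ", text)

text = "Prediabetes (finding) Essential hypertension (disorder) Gingivitis (disorder)"
print(substring_labels(text)["Diabetes"], word_boundary_labels(text)["Diabetes"])  # 1 0
print(strip_target_words(text))  # (finding) essential (disorder) gingivitis (disorder)
```

Whether prediabetes should count as Diabetes is a labeling decision; the point is that the substring rule makes that decision implicitly.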
# 3. Constructing Graph Data (Graph Preparation)

In [None]:

# Apply TF-IDF Vectorization
tfidf = TfidfVectorizer(max_features=1000)
X_tfidf = tfidf.fit_transform(df_training["all_text"])

# Map patient_id to a unique index
patient_ids = df_training["patient_id"].tolist()
patient_id_map = {pid: i for i, pid in enumerate(patient_ids)}

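# === 7. Construct Graph Edges from RelatedPerson Relationships ===
# One directed edge per RelatedPerson row (patient -> related person), kept only
# when both IDs belong to patients in df_training.
edge_index = []
for _, row in df_rp.iterrows():
    if row["patient_id"] in patient_id_map and row["related_person_id"] in patient_id_map:
        edge_index.append([patient_id_map[row["patient_id"]], patient_id_map[row["related_person_id"]]])

# PyTorch Geometric expects shape [2, num_edges]; the reshape keeps this valid even with no edges
edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()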

# === 8. Create PyTorch Geometric Graph ===
node_features = torch.tensor(X_tfidf.toarray(), dtype=torch.float32)
graph = Data(x=node_features, edge_index=edge_index)

# === 9. Convert Labels to Tensor ===
labels_tensor = torch.tensor(df_training[target_diseases].values, dtype=torch.float32)

print("Graph successfully created.")
print(graph)


Graph successfully created.
Data(x=[108, 304], edge_index=[2, 261])

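`TfidfVectorizer` with its default settings lower-cases the text, keeps tokens of two or more word characters, weights raw term counts by a smoothed IDF, idf(t) = ln((1 + n) / (1 + df(t))) + 1, and L2-normalizes each row. The stdlib sketch below reproduces those defaults on two toy condition strings; it does not reproduce the `max_features=1000` vocabulary cap used above. It shows why a suffix shared by every document, such as `(disorder)`, ends up with a lower weight than the term that distinguishes the document.

```python
import math
import re
from collections import Counter

TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")  # scikit-learn's default token_pattern

def tfidf_rows(docs):
    """TF-IDF with raw counts, smoothed IDF and L2-normalized rows (one sparse dict per document)."""
    tokenized = [TOKEN_RE.findall(doc.lower()) for doc in docs]
    n_docs = len(docs)
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    idf = {t: math.log((1 + n_docs) / (1 + df)) + 1.0 for t, df in doc_freq.items()}
    rows = []
    for tokens in tokenized:
        weights = {t: count * idf[t] for t, count in Counter(tokens).items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        rows.append({t: w / norm for t, w in weights.items()} if norm else {})
    return rows

rows = tfidf_rows(["Asthma (disorder)", "Gingivitis (disorder)"])
print({t: round(w, 4) for t, w in rows[0].items()})  # {'asthma': 0.8148, 'disorder': 0.5797}
```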

# 4. Train-Test Splitting for Model Training (Iterative Train-Test Split)

In [None]:
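import numpy as np
import torch
from skmultilearn.model_selection import iterative_train_test_split

# === Iterative Train-Test Split (Train-Validation) ===
# Split node *indices* instead of raw feature rows so every sample can be traced
# back to its node (and its edges) in the full graph.
X = np.arange(graph.num_nodes).reshape(-1, 1)  # one column holding each node's index
Y = labels_tensor.numpy()  # convert labels to a NumPy array

# 80% train, 20% validation using iterative stratification
X_train, Y_train, X_val, Y_val = iterative_train_test_split(X, Y, test_size=0.2)

# Node indices of each split, in the same order as Y_train / Y_val
train_idx = torch.from_numpy(X_train.ravel()).long()
val_idx = torch.from_numpy(X_val.ravel()).long()

# Convert back to tensors: gather node features in split order
X_train_tensor = node_features[train_idx]
Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32)
X_val_tensor = node_features[val_idx]
Y_val_tensor = torch.tensor(Y_val, dtype=torch.float32)

# 5. Graph Construction for Training and Validation

In [None]:
from torch_geometric.utils import subgraph

# === Build Graphs for Train & Validation ===
# Each graph keeps only the edges between nodes of its own split. relabel_nodes=True
# renumbers node train_idx[k] as k, so edge indices match the rows of X_train_tensor
# (indices beyond the number of rows in x would break message passing).
train_graph = graph.clone()
train_graph.x = X_train_tensor
train_graph.edge_index, _ = subgraph(train_idx, graph.edge_index, relabel_nodes=True, num_nodes=graph.num_nodes)

val_graph = graph.clone()
val_graph.x = X_val_tensor
val_graph.edge_index, _ = subgraph(val_idx, graph.edge_index, relabel_nodes=True, num_nodes=graph.num_nodes)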

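When the nodes are divided into train and validation sets, each split graph needs edge indices that point at rows of *its own* feature matrix. A minimal stdlib sketch of the induced-subgraph approach: keep only edges whose two endpoints are in the split, and renumber node `subset[k]` as `k`. `induced_subgraph` is a hypothetical helper written for illustration; PyTorch Geometric offers `torch_geometric.utils.subgraph(..., relabel_nodes=True)` for the tensor version.

```python
def induced_subgraph(subset, edges):
    """Keep edges with both endpoints in `subset`; renumber node subset[k] as k."""
    new_id = {old: k for k, old in enumerate(subset)}
    return [(new_id[s], new_id[t]) for s, t in edges if s in new_id and t in new_id]

edges = [(0, 1), (0, 2), (3, 4), (1, 4)]  # toy directed edges over 5 nodes
train_nodes = [4, 0, 1, 3]  # row k of the train features belongs to node train_nodes[k]
val_nodes = [2]

train_edges = induced_subgraph(train_nodes, edges)
val_edges = induced_subgraph(val_nodes, edges)
print(train_edges)  # [(1, 2), (3, 0), (2, 0)]
print(val_edges)    # []
```

Edges that cross the split, like `(0, 2)` here, are dropped, so validation nodes lose those neighbours. The common alternative is to keep the full graph and compute the loss only on train or validation node masks.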
# 6. GNN Model Definition

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from skmultilearn.model_selection import iterative_train_test_split
import numpy as np

# === Define the GNN Model ===
class GNNModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Initializes a Graph Convolutional Network (GCN) model.

        Args:
            input_dim (int): Number of input features per node.
            hidden_dim (int): Number of hidden layer neurons.
            output_dim (int): Number of output classes (multi-label classification).
        """
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        """
        Forward pass through the GNN model.

        Args:
            data (torch_geometric.data.Data): Graph data containing node features and edge indices.

        Returns:
            torch.Tensor: The output logits for each node.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)  # Activation function
        x = self.conv2(x, edge_index)
        return x

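`GCNConv` multiplies the node features by a learned weight matrix and then aggregates them with the symmetric normalization `D^-1/2 (A + I) D^-1/2`, where `D` is the degree matrix of `A + I`. Under PyTorch Geometric's default `flow="source_to_target"`, node i sums over edges j → i (plus its own self-loop), and degrees are counted on the incoming side. The stdlib sketch below reproduces only that normalized aggregation on a three-node toy graph, leaving out the weight matrix and bias; it illustrates the rule rather than PyG's implementation.

```python
import math

def gcn_aggregate(x, edges):
    """x'[i] = sum over edges j -> i (incl. self-loop) of x[j] / sqrt(deg[j] * deg[i]).

    x: list of feature lists, one per node; edges: (source, target) pairs without self-loops.
    deg counts incoming edges plus the added self-loop. No weight matrix or bias is applied.
    """
    num_nodes = len(x)
    all_edges = list(edges) + [(i, i) for i in range(num_nodes)]  # add self-loops
    deg = [0] * num_nodes
    for _, target in all_edges:
        deg[target] += 1
    out = [[0.0] * len(x[0]) for _ in range(num_nodes)]
    for source, target in all_edges:
        norm = 1.0 / math.sqrt(deg[source] * deg[target])
        for f, value in enumerate(x[source]):
            out[target][f] += norm * value
    return out

edges = [(0, 1), (0, 2)]  # node 0 -> nodes 1 and 2, directed
print(gcn_aggregate([[1.0], [0.0], [0.0]], edges))  # node 0's signal reaches nodes 1 and 2
print(gcn_aggregate([[0.0], [1.0], [0.0]], edges))  # node 1's signal does not reach node 0
```

The second call shows the effect of edge direction: node 0's output ignores node 1 because the only edge runs 0 → 1. Since the edge loop in Section 3 adds one patient → related-person edge per RelatedPerson row, a patient receives messages from a relative only if a reverse edge also exists (for example from the relative's own RelatedPerson record, or after `torch_geometric.utils.to_undirected`).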

In [None]:
# === Initialize the GNN Model ===
hidden_dim = 64  # Number of hidden layer neurons
output_dim = len(target_diseases)  # Number of output labels for multi-label classification

# Initialize the model with input features, hidden layer, and output labels
model = GNNModel(input_dim=node_features.shape[1], hidden_dim=hidden_dim, output_dim=output_dim)

# Define the optimizer (Adam optimizer with a learning rate of 0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Define the loss function (Binary Cross-Entropy with Logits for multi-label classification)
loss_fn = torch.nn.BCEWithLogitsLoss()

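`BCEWithLogitsLoss` applies the sigmoid internally and, with its default `reduction="mean"`, averages the binary cross-entropy over every (node, label) entry. A stdlib sketch of that computation, using the numerically stable form max(z, 0) - z·y + log(1 + e^(-|z|)), also explains why training tends to start near ln 2 ≈ 0.693: when the untrained model's logits are close to zero, each entry costs about ln 2 whatever its label, which is consistent with the first logged loss of about 0.71.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, using max(z, 0) - z*y + log(1 + exp(-|z|))."""
    losses = [
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)

# Flattened (node, label) entries: a zero logit costs ln 2 whatever the target
print(bce_with_logits([0.0, 0.0], [1.0, 0.0]))   # ≈ 0.6931 (ln 2)
print(bce_with_logits([2.0, -1.0], [1.0, 0.0]))  # ≈ 0.2201
```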

# 7. Model Training with Validation

In [None]:
# === Train the GNN Model with Validation ===
num_epochs = 75  # Number of training epochs

for epoch in range(num_epochs):
    # Set the model to training mode
    model.train()
    
    # Forward pass on training data
    logits_train = model(train_graph)  
    loss_train = loss_fn(logits_train, Y_train_tensor)  # Compute training loss
    
    # Backpropagation and optimization step
    optimizer.zero_grad()
    loss_train.backward()
    optimizer.step()
    
    # Validation step
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation for validation
        logits_val = model(val_graph)
        loss_val = loss_fn(logits_val, Y_val_tensor)  # Compute validation loss

    # Print training and validation loss every 10 epochs
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Train Loss: {loss_train.item():.4f}, Val Loss: {loss_val.item():.4f}")

print("GNN model training completed with validation set!")


Epoch 0, Loss: 0.7068
Epoch 10, Loss: 0.5781
Epoch 20, Loss: 0.5556
Epoch 30, Loss: 0.5335
Epoch 40, Loss: 0.5067
Epoch 50, Loss: 0.4792
Epoch 60, Loss: 0.4518
Epoch 70, Loss: 0.4245
GNN model training completed!

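Because the model outputs raw logits, turning them into per-condition predictions needs a sigmoid followed by a threshold. The sketch below uses 0.5, which is an assumption rather than a tuned value, and maps the positive outputs back to the names in `target_diseases`; the logits are made-up numbers standing in for one row of `model(val_graph)`.

```python
import math

target_diseases = ["Diabetes", "Hypertension", "Cancer", "Heart Disease", "Alzheimer", "Asthma"]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def decode(logit_row, threshold=0.5):
    """Return (disease, probability) for every label whose sigmoid probability reaches the threshold."""
    return [
        (name, round(sigmoid(z), 3))
        for name, z in zip(target_diseases, logit_row)
        if sigmoid(z) >= threshold
    ]

example_logits = [1.2, -0.3, -2.5, 0.0, -4.0, 0.8]  # hypothetical output row for one patient
print(decode(example_logits))  # several labels can be positive at once (multi-label)
```

Each label is thresholded independently, so a patient can receive zero, one, or several predicted conditions.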

# 8. Model and Vectorizer Saving

In [None]:
import torch
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer

# Save the trained GNN model weights
torch.save(model.state_dict(), "gnn_model_weights.pt")
print("GNN model weights saved to 'gnn_model_weights.pt'.")

# Save the TF-IDF vectorizer
with open("tfidf.pkl", "wb") as f:
    pickle.dump(tfidf, f)

print("TF-IDF vectorizer saved to 'tfidf.pkl'.")


GNN model weights saved to 'gnn_model_weights.pt'.
