In [None]:
!pip install torchdata



In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torch.nn.functional as F

print(torch.__version__)
print(torch.version.cuda)

2.9.0+cu126
12.6


In [None]:
!pip install -q torch_geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
def simple_aggregate_visits(visit_features, output_dim, proj_layer=None):
    """
    Simple Aggregation (Baseline 1): mean-pool the visits, then project.

    Args:
        visit_features: Tensor of shape (sequence_length, input_dim).
        output_dim: Size of the returned embedding.
        proj_layer: Optional module mapping input_dim -> output_dim. Pass one
            shared (trainable) layer to obtain reproducible embeddings across
            calls. When omitted, a fresh randomly-initialised ``nn.Linear`` is
            built on every call (the original placeholder behaviour), so two
            calls on the same input generally return different vectors.

    Returns:
        Tensor of shape (output_dim,): ReLU of the projected mean visit, or an
        all-zero vector when the sequence contains no visits.
    """
    if visit_features.size(0) == 0:
        # Keep the zero vector on the same device as the (empty) input.
        return torch.zeros(output_dim, device=visit_features.device)

    # 1. Average all features across the time dimension (axis 0).
    mean_features = torch.mean(visit_features, dim=0)

    # 2. Project to output_dim.
    if proj_layer is None:
        # Placeholder projection: untrained weights, re-sampled per call.
        proj_layer = nn.Linear(mean_features.size(-1), output_dim).to(mean_features.device)
    return F.relu(proj_layer(mean_features))

class Chomp1d(nn.Module):
    """
    Trims the last ``chomp_size`` time steps from a Conv1d output.

    Used after a convolution that was padded on both sides so that the
    output keeps the input's sequence length. A ``chomp_size`` of 0 (e.g. a
    block built with kernel_size=1) leaves the tensor untouched; without this
    guard the slice ``[:-0]`` would produce an empty tensor.
    """
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x shape: (batch_size, channels, seq_len)
        if self.chomp_size == 0:
            # Nothing to remove; x[:, :, :-0] would wrongly drop every step.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """
    Residual TCN block.

    Two stages of Conv1d -> Chomp1d -> ReLU -> Dropout are chained in
    ``self.net``; the block output is ReLU(net(x) + skip), where the skip path
    is the input itself or a 1x1 convolution of it when the channel counts
    differ.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, dropout):
        super().__init__()
        # Each conv is padded by (kernel_size - 1) * dilation; the matching
        # Chomp1d then removes that many trailing time steps.
        pad = (kernel_size - 1) * dilation
        conv_kwargs = dict(stride=stride, padding=pad, dilation=dilation)

        # Stage 1
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.chomp1 = Chomp1d(pad)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, **conv_kwargs)
        self.chomp2 = Chomp1d(pad)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # Skip path: a 1x1 conv is only needed to change the channel count.
        if in_channels != out_channels:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Re-draw every convolution weight from N(0, 0.01); biases are left as created."""
        convs = [self.conv1, self.conv2]
        if self.downsample is not None:
            convs.append(self.downsample)
        for conv in convs:
            conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        if self.downsample is None:
            skip = x
        else:
            skip = self.downsample(x)
        return self.relu(self.net(x) + skip)


class TCNEncoder(nn.Module):
    """
    Encodes a sequence of patient visits with a stack of TemporalBlocks (Method 3).

    Block ``i`` uses dilation ``2 ** i``. The first block maps ``input_dim``
    channels to ``hidden_channels``; the remaining blocks keep
    ``hidden_channels``. The features at the last time step are projected to
    ``output_dim`` and passed through a ReLU.
    """

    def __init__(self, input_dim, hidden_channels, output_dim, num_layers, kernel_size, dropout):
        super().__init__()

        blocks = [
            TemporalBlock(
                input_dim if level == 0 else hidden_channels,
                hidden_channels,
                kernel_size,
                stride=1,
                dilation=2 ** level,
                dropout=dropout,
            )
            for level in range(num_layers)
        ]
        self.tcn = nn.Sequential(*blocks)

        # Projection to the embedding size expected downstream.
        self.fc = nn.Linear(hidden_channels, output_dim)
        self.relu = nn.ReLU()

    def forward(self, visit_sequences):
        # (batch, seq_len, input_dim) -> (batch, input_dim, seq_len) for Conv1d.
        channels_first = visit_sequences.transpose(1, 2)

        # Keep only the final time step of the TCN output: (batch, channels).
        last_step = self.tcn(channels_first)[:, :, -1]

        return self.relu(self.fc(last_step))
# ==============================================================================
# COMMON DATA TRANSFORMATION UTILITY
# ==============================================================================

def transform_raw_to_features(group: pd.DataFrame, concept_map: dict, value_dim: int):
    """
    Turn one patient's raw records into a per-visit feature matrix.

    Records sharing the same 'ReportDate' form one visit. Each visit becomes a
    zero-initialised vector of length ``len(concept_map) * value_dim``; for
    every test whose name is in ``concept_map`` the value divided by 100
    (placeholder normalisation) is written at position ``concept_map[name]``.
    Note that ``value_dim`` only affects the vector length, not the position
    written. Tests missing from ``concept_map`` are ignored.

    Args:
        group: Records of a single patient with 'ReportDate', 'TestName' and
            'TestValue' columns.
        concept_map: Map from TestName to its index.
        value_dim: Number of slots reserved per concept.

    Returns:
        torch.Tensor of shape (num_visits, len(concept_map) * value_dim);
        an empty (0, len(concept_map) * value_dim) tensor if there are no visits.
    """
    vector_len = len(concept_map) * value_dim

    # One row per distinct ReportDate, each cell holding that visit's values as a list.
    per_visit = group.groupby('ReportDate').agg(list)

    rows = []
    for _, record in per_visit.iterrows():
        vec = torch.zeros(vector_len)
        for test_name, raw_value in zip(record['TestName'], record['TestValue']):
            if test_name not in concept_map:
                continue
            # Placeholder normalisation: scale by 1/100.
            vec[concept_map[test_name]] = float(raw_value) / 100.0
        rows.append(vec)

    if len(rows) == 0:
        return torch.empty(0, vector_len)
    return torch.stack(rows)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
import pandas as pd
import numpy as np
from typing import Dict

print(torch.__version__)
def construct_patient_symptom_bipartite_graph(
    patient_ids: list,
    patient_features: Dict[int, torch.Tensor],
    patient_idx_map: dict,
    concept_map: dict,
    raw_data: pd.DataFrame
) -> HeteroData:
    """
    Constructs the Patient <-> Symptom bipartite graph from pre-computed
    patient embeddings.

    One edge is emitted per raw record whose PatientID is in
    ``patient_idx_map`` and whose TestName is in ``concept_map``; repeated
    records therefore produce repeated edges. The reverse relation
    ('symptom', 'is_related_to', 'patient') mirrors the forward one.

    Args:
        patient_ids: Ordered list of patient IDs; row i of the patient feature
            matrix is the embedding of ``patient_ids[i]``.
        patient_features: Dictionary of {PatientID: 1-D embedding tensor}.
        patient_idx_map: Map from PatientID to patient node index.
        concept_map: Map from TestName (symptom) to symptom node index.
        raw_data: Records with 'PatientID' and 'TestName' columns.

    Returns:
        HeteroData with 'patient' and 'symptom' nodes and both edge directions.
        If ``patient_ids`` is empty an error is printed and an empty HeteroData
        is returned.

    Raises:
        KeyError: if ``raw_data`` lacks 'PatientID' or 'TestName'.
    """
    data = HeteroData()

    if not patient_ids:
        print("[ERROR] No valid patients found. Cannot build graph.")
        return data

    gnn_embedding_dim = patient_features[patient_ids[0]].shape[0]

    # 1. Patient node features, stacked in patient_ids order.
    data['patient'].x = torch.stack([patient_features[pid] for pid in patient_ids])
    data['patient'].node_map = patient_idx_map

    # 2. Symptom node features: uniform-random, same width as patient embeddings.
    data['symptom'].x = torch.rand(len(concept_map), gnn_embedding_dim)
    data['symptom'].node_map = concept_map

    # --- Edge construction (Patient -> Symptom) ---
    # Iterate the two needed columns directly instead of building a Series per
    # row with iterrows(); records with unknown IDs/names are skipped.
    p_to_s_edges = [
        (patient_idx_map[p_id], concept_map[s_name])
        for p_id, s_name in zip(raw_data['PatientID'], raw_data['TestName'])
        if p_id in patient_idx_map and s_name in concept_map
    ]

    if p_to_s_edges:
        src, dst = zip(*p_to_s_edges)
        # Edge indices are stored as (2, num_edges) tensors.
        data['patient', 'has', 'symptom'].edge_index = torch.tensor([src, dst], dtype=torch.long)
        data['symptom', 'is_related_to', 'patient'].edge_index = torch.tensor([dst, src], dtype=torch.long)
        print(f"[SUCCESS] Built {len(p_to_s_edges)} P->S edges.")
    else:
        data['patient', 'has', 'symptom'].edge_index = torch.empty((2, 0), dtype=torch.long)
        data['symptom', 'is_related_to', 'patient'].edge_index = torch.empty((2, 0), dtype=torch.long)
        print("[WARN] No P->S edges created.")

    return data


def construct_symptom_organ_graph(
    concept_map: dict,
    organ_map: dict,
    raw_data: pd.DataFrame
) -> HeteroData:
    """
    Creates the Symptom -> Organ bipartite graph.

    One ('symptom', 'measures', 'organ') edge is emitted per raw record whose
    TestName is in ``concept_map`` and whose Target_Organ is in ``organ_map``.
    Only node counts are stored; no node features are set here.

    Args:
        concept_map: Map from TestName to symptom node index.
        organ_map: Map from organ name to organ node index.
        raw_data: Records with 'TestName' and 'Target_Organ' columns.

    Returns:
        HeteroData whose 'measures' relation always has an ``edge_index``
        (shape (2, 0) when no record matches), so callers can read it
        unconditionally.

    Raises:
        KeyError: if ``raw_data`` lacks 'TestName' or 'Target_Organ'.
    """
    data = HeteroData()

    # Placeholder node initialization: counts only.
    data['symptom'].num_nodes = len(concept_map)
    data['organ'].num_nodes = len(organ_map)

    s_to_o_edges = [
        (concept_map[s_name], organ_map[o_name])
        for s_name, o_name in zip(raw_data['TestName'], raw_data['Target_Organ'])
        if s_name in concept_map and o_name in organ_map
    ]

    if s_to_o_edges:
        src, dst = zip(*s_to_o_edges)
        data['symptom', 'measures', 'organ'].edge_index = torch.tensor([src, dst], dtype=torch.long)
        print(f"[SUCCESS] Symptom -> Organ Graph built with {len(s_to_o_edges)} edges.")
    else:
        # Mirror the P-S builder: store an explicit empty index rather than
        # leaving the attribute undefined.
        data['symptom', 'measures', 'organ'].edge_index = torch.empty((2, 0), dtype=torch.long)
        print("[WARN] No S->O edges created.")

    return data




def construct_disease_organ_graph(
    disease_map: dict,
    organ_map: dict,
    raw_data: pd.DataFrame
) -> HeteroData:
    """
    Creates the Organ -> Disease bipartite graph.

    One ('organ', 'is affected', 'disease') edge (source: organ, target:
    disease) is emitted per raw record whose Most_Relevant_Disease is in
    ``disease_map`` and whose Target_Organ is in ``organ_map``. Only node
    counts are stored; no node features are set here.

    Args:
        disease_map: Map from disease name to disease node index.
        organ_map: Map from organ name to organ node index.
        raw_data: Records with 'Most_Relevant_Disease' and 'Target_Organ' columns.

    Returns:
        HeteroData whose 'is affected' relation always has an ``edge_index``
        (shape (2, 0) when no record matches), so callers can read it
        unconditionally.

    Raises:
        KeyError: if ``raw_data`` lacks 'Most_Relevant_Disease' or 'Target_Organ'.
    """
    data = HeteroData()

    # Placeholder node initialization: counts only.
    data['disease'].num_nodes = len(disease_map)
    data['organ'].num_nodes = len(organ_map)

    o_to_d_edges = [
        (organ_map[o_name], disease_map[d_name])
        for d_name, o_name in zip(raw_data['Most_Relevant_Disease'], raw_data['Target_Organ'])
        if d_name in disease_map and o_name in organ_map
    ]

    if o_to_d_edges:
        src, dst = zip(*o_to_d_edges)
        data['organ', 'is affected', 'disease'].edge_index = torch.tensor([src, dst], dtype=torch.long)
        print(f"[SUCCESS] Disease -> Organ Graph built with {len(o_to_d_edges)} edges.")
    else:
        # Mirror the P-S builder: store an explicit empty index rather than
        # leaving the attribute undefined.
        data['organ', 'is affected', 'disease'].edge_index = torch.empty((2, 0), dtype=torch.long)
        print("[WARN] No O->D edges created.")

    return data


# ==============================================================================
# 2. CORE FUNCTION: Disease & Organ Graph Builder (The requested logic)
# ==============================================================================
from typing import Dict, List, Tuple

class PatientData:
    """
    Bundle of pre-processed patient features, index maps, edge lists and labels.

    Besides storing its arguments, the constructor stacks the embeddings of
    ``final_patient_ids`` (in that order) into ``X_patients`` and records the
    embedding width and the sizes of each map.
    """

    def __init__(
        self,
        patient_features_dict,
        final_patient_ids,
        final_patient_idx_map,
        concept_map,
        disease_map,
        organ_map,
        ext_ps,
        ext_so,
        ext_do,
        patient_labels,
    ):
        # Patients: features, ordering, index map and labels.
        self.patient_features = patient_features_dict
        self.final_patient_ids = final_patient_ids
        self.final_patient_idx_map = final_patient_idx_map
        self.patient_labels = patient_labels

        # Name -> node index maps for the other node types.
        self.concept_map = concept_map
        self.disease_map = disease_map
        self.organ_map = organ_map

        # Edge collections, stored as given.
        self.patient_symptom_edges = ext_ps
        self.symptom_organ_edges = ext_so
        self.organ_disease_edges = ext_do

        # Derived quantities.
        ordered_embeddings = [self.patient_features[pid] for pid in self.final_patient_ids]
        self.X_patients = torch.stack(ordered_embeddings)
        self.gnn_embedding_dim = self.X_patients.shape[1]
        self.num_patient_ids = len(final_patient_ids)
        self.num_symptoms = len(concept_map)
        self.num_diseases = len(disease_map)
        self.num_organs = len(organ_map)


def construct_final_unified_graph(
    ps_graph: HeteroData,
    so_graph: HeteroData,
    od_graph: HeteroData,
    data_obj: PatientData
) -> HeteroData:
    """
    Merge the P-S, S-O and O-D bipartite graphs into one HeteroData object.

    Patient and symptom features are taken from ``ps_graph``; disease and organ
    features are drawn uniformly at random with the patient embedding width.
    Edge indices are copied from the source graphs (Disease -> Symptom edges
    are not merged). Patient labels are stacked in ``final_patient_ids`` order
    and an identity matrix is attached to the disease nodes. A verification
    report comparing node/edge counts with ``data_obj`` is printed.

    Returns:
        The unified HeteroData graph.
    """
    unified = HeteroData()
    dim = data_obj.gnn_embedding_dim

    # --- Node features ---
    unified['patient'].x = ps_graph['patient'].x
    unified['symptom'].x = ps_graph['symptom'].x
    unified['disease'].x = torch.rand(data_obj.num_diseases, dim)
    unified['organ'].x = torch.rand(data_obj.num_organs, dim)

    # --- Edges: (relation, graph that provides it) ---
    edge_sources = (
        (('patient', 'has', 'symptom'), ps_graph),
        (('symptom', 'is_related_to', 'patient'), ps_graph),
        (('symptom', 'measures', 'organ'), so_graph),
        (('organ', 'is affected', 'disease'), od_graph),
    )
    for relation, graph in edge_sources:
        unified[relation].edge_index = graph[relation].edge_index

    # --- Labels and metadata ---
    ordered_labels = [data_obj.patient_labels[pid] for pid in data_obj.final_patient_ids]
    unified['patient'].y = torch.stack(ordered_labels)
    unified['disease'].y_identity = torch.eye(data_obj.num_diseases)

    # --- Verification report ---
    print("\n--- DEBUG: Graph Integration Verification ---")

    # Node counts: actual rows of x versus the size of the matching container.
    expected_nodes = {
        'patient': len(data_obj.patient_features),
        'symptom': len(data_obj.concept_map),
        'disease': len(data_obj.disease_map),
        'organ': len(data_obj.organ_map),
    }
    for node_type, expected in expected_nodes.items():
        actual = unified[node_type].x.shape[0]
        if actual == expected:
            status = "[OK]"
        else:
            status = f"[FAIL] Expected {expected}"
        print(f"  Node Type {node_type.upper()}: {actual} {status}")

    # Edge counts: actual columns of edge_index versus the stored edge lists.
    n_ps = len(data_obj.patient_symptom_edges)
    expected_edges = (
        (('patient', 'has', 'symptom'), n_ps),
        (('symptom', 'is_related_to', 'patient'), n_ps),
        (('symptom', 'measures', 'organ'), len(data_obj.symptom_organ_edges)),
        (('organ', 'is affected', 'disease'), len(data_obj.organ_disease_edges)),
    )
    total_edges = 0
    for edge_type, expected_count in expected_edges:
        actual_count = unified[edge_type].edge_index.shape[1]
        total_edges += actual_count
        if actual_count == expected_count:
            status = "[OK]"
        else:
            status = f"[FAIL] Expected {expected_count}"
        print(f"  Edge Type {edge_type[1].upper()}: {actual_count} {status}")

    print(f"  TOTAL Edges Unified: {total_edges}")
    print("---------------------------------------")

    return unified

2.9.0+cu126


In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [None]:
import matplotlib.pyplot as plt

def identify_patient_symptom_relations(df: pd.DataFrame):
    """
    Count how often each (PatientID, TestName) pair occurs in the raw records.

    Side effect: whitespace is stripped from ``df``'s column names in place.
    Summary statistics and the first few pairs are printed.

    Returns:
        DataFrame with columns PatientID, TestName and Frequency, one row per
        distinct pair.
    """
    # Normalise the headers on the caller's frame before any column lookup.
    df.columns = df.columns.str.strip()

    patient_ids = df['PatientID'].unique()
    symptom_names = df['TestName'].unique()

    print("\n--- Relationship Identification ---")
    print(f"Total Unique Patients (Nodes): {len(patient_ids)}")
    print(f"Total Unique Symptoms (Nodes): {len(symptom_names)}")

    # One row per distinct pair; Frequency is the number of raw records for it.
    pair_sizes = df.groupby(['PatientID', 'TestName']).size()
    relation_counts = pair_sizes.reset_index(name='Frequency')

    print("\n[Relationship 1: Patient <-> Symptom (P Graph)]")
    print(f"Total unique Patient-Symptom links (edges): {len(relation_counts)}")
    print("Example Links and Frequencies:")
    print(relation_counts.head())

    # Three patients with the most records; computed but currently not reported.
    records_per_patient = relation_counts.groupby('PatientID')['Frequency'].sum()
    _top_patients = records_per_patient.nlargest(3)

    return relation_counts


def analyze_symptom_connectivity(raw_relation_counts: pd.DataFrame):
    """
    Count the distinct patients linked to each symptom (TestName).

    Prints the five most and five least connected symptoms.

    Args:
        raw_relation_counts: Frame with at least 'PatientID' and 'TestName'
            columns.

    Returns:
        DataFrame with columns TestName and Unique_Patient_Count.
    """
    # Distinct PatientIDs per TestName.
    patients_per_symptom = raw_relation_counts.groupby('TestName')['PatientID'].nunique()
    symptom_connectivity = patients_per_symptom.reset_index(name='Unique_Patient_Count')

    print("\n--- Symptom Node Connectivity Analysis ---")

    descending = symptom_connectivity.sort_values(by='Unique_Patient_Count', ascending=False)
    ascending = symptom_connectivity.sort_values(by='Unique_Patient_Count', ascending=True)
    most_common = descending.head(5)
    least_common = ascending.head(5)

    print("\nTop 5 Most Connected Symptoms (Common/Hub Symptoms):")
    print(most_common)

    print("\nTop 5 Least Connected Symptoms (Potential Rare/Specific Symptoms):")
    print(least_common)

    return symptom_connectivity






def identify_organ_symptom_relations(df: pd.DataFrame):
    """
    Count how often each (Target_Organ, TestName) pair occurs in the raw records.

    Side effect: whitespace is stripped from ``df``'s column names in place.
    Summary statistics and the first few pairs are printed.

    Returns:
        DataFrame with columns Target_Organ, TestName and Frequency, one row
        per distinct organ/symptom pair.
    """

    # 1. Strip column names again for safety
    df.columns = df.columns.str.strip()

    # 2. Extract unique organ and symptom names
    unique_organ = df['Target_Organ'].unique()
    unique_symptoms = df['TestName'].unique()

    print("\n--- Relationship Identification ---")
    print(f"Total Unique Organs (Nodes): {len(unique_organ)}")
    print(f"Total Unique Symptoms (Nodes): {len(unique_symptoms)}")

    # 3. Number of raw records behind each distinct Organ/Symptom pair
    relation_counts = df.groupby(['Target_Organ', 'TestName']).size().reset_index(name='Frequency')

    print("\n[Relationship 1: Organ <-> Symptom (SO Graph)]")
    print(f"Total unique Organ-Symptom links (edges): {len(relation_counts)}")
    print("Example Links and Frequencies:")
    print(relation_counts.head())

    return relation_counts



def identify_organ_disease_relations(df: pd.DataFrame):
    """
    Count how often each (Target_Organ, Most_Relevant_Disease) pair occurs in
    the raw records.

    Side effect: whitespace is stripped from ``df``'s column names in place.
    Summary statistics, the first few pairs and the three organs with the most
    records are printed.

    Returns:
        DataFrame with columns Target_Organ, Most_Relevant_Disease and
        Frequency, one row per distinct organ/disease pair.
    """

    # 1. Strip column names again for safety
    df.columns = df.columns.str.strip()

    # 2. Extract unique organ and disease names
    unique_organ = df['Target_Organ'].unique()
    unique_disease = df['Most_Relevant_Disease'].unique()

    print("\n--- Relationship Identification ---")
    print(f"Total Unique Organs (Nodes): {len(unique_organ)}")
    print(f"Total Unique diseases (Nodes): {len(unique_disease)}")

    # 3. Number of raw records behind each distinct Organ/Disease pair
    relation_counts = df.groupby(['Target_Organ', 'Most_Relevant_Disease']).size().reset_index(name='Frequency')

    print("\n[Relationship 1: Organ <-> Disease (OD Graph)]")
    print(f"Total unique Organ-Disease links (edges): {len(relation_counts)}")
    print("Example Links and Frequencies:")
    print(relation_counts.head())

    # Organs with the highest total record count (potential hubs). The label
    # previously said "Patients ... symptom records", which did not match the
    # grouped data.
    organ_connectivity = relation_counts.groupby('Target_Organ')['Frequency'].sum().nlargest(3)
    print(f"\nTop 3 Organs by total disease records (Hubs): \n{organ_connectivity}")

    return relation_counts

In [None]:
from torch_geometric.nn import GATConv, HeteroConv

from torch_geometric.nn import HANConv

def extract_metapaths(metadata):
    """
    Enumerate every two-hop chain of edge types.

    Args:
        metadata: A (node_types, edge_types) pair; each edge type is a
            (source, relation, target) triple.

    Returns:
        A list of ``[first_edge, second_edge]`` pairs (as tuples) where the
        target of the first edge type equals the source of the second.
    """
    _, edge_types = metadata
    return [
        [(src_a, rel_a, dst_a), (src_b, rel_b, dst_b)]
        for src_a, rel_a, dst_a in edge_types
        for src_b, rel_b, dst_b in edge_types
        if dst_a == src_b
    ]


class GNNEncoderHAN(nn.Module):
    """
    Two-layer HANConv encoder followed by a shared linear projection.

    Each HANConv layer is followed by ReLU and dropout (p=0.5); the final
    projection (one ``nn.Linear`` applied to every node type) is followed by
    ReLU. ``metapaths`` is stored on the module but not read by ``forward``.
    """

    def __init__(self, metadata, metapaths, hidden_channels, out_channels, num_heads=4):
        super().__init__()
        node_types = metadata[0]

        self.metapaths = metapaths
        self.dropout = nn.Dropout(p=0.5)

        # First layer: input width left as -1 for every node type.
        self.han_conv = HANConv(
            in_channels=dict.fromkeys(node_types, -1),
            out_channels=hidden_channels,
            metadata=metadata,
            heads=num_heads,
        )

        # Second layer: hidden -> hidden for every node type.
        self.han_conv2 = HANConv(
            in_channels=dict.fromkeys(node_types, hidden_channels),
            out_channels=hidden_channels,
            metadata=metadata,
            heads=num_heads,
        )

        self.proj = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        # Both HANConv layers receive the same edge_index_dict.
        hidden = x_dict
        for layer in (self.han_conv, self.han_conv2):
            hidden = layer(hidden, edge_index_dict)
            hidden = {node_type: self.dropout(F.relu(h)) for node_type, h in hidden.items()}

        # Project every node type to the final embedding width.
        return {node_type: F.relu(self.proj(h)) for node_type, h in hidden.items()}



class GNNEncoderHGAT(nn.Module):
    """
    Heterogeneous GAT encoder: ``num_layers`` HeteroConv layers (one GATConv
    per edge type, summed per target node type), each followed by ReLU, then a
    shared linear projection followed by ReLU.

    Args:
        metadata: (node_types, edge_types) pair; one GATConv is built for each
            entry of ``metadata[1]``.
        hidden_channels: Width of every hidden representation.
        out_channels: Width of the returned embeddings.
        num_layers: Number of HeteroConv layers.
        num_heads: Attention heads per GATConv; each head outputs
            ``hidden_channels // num_heads`` features.

    Raises:
        ValueError: if layers are requested and ``hidden_channels`` is not
            divisible by ``num_heads``. In that case the concatenated head
            outputs would not be ``hidden_channels`` wide and the projection
            would fail with a shape error during the first forward pass.
    """
    def __init__(self, metadata, hidden_channels: int, out_channels: int, num_layers: int = 2, num_heads: int = 4):
        super(GNNEncoderHGAT, self).__init__()
        if num_layers > 0 and hidden_channels % num_heads != 0:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: GATConv(in_channels=(-1, -1), out_channels=hidden_channels // num_heads, heads=num_heads, add_self_loops=False)
                for edge_type in metadata[1]
            }, aggr='sum')
            self.convs.append(conv)
        self.proj = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict: Dict[str, torch.Tensor], edge_index_dict: Dict[Tuple, torch.Tensor]):
        for i in range(self.num_layers):
            x_dict = self.convs[i](x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        # Shared projection applied to every node type present in x_dict.
        z_dict = {}
        for node_type, x in x_dict.items():
            z_dict[node_type] = F.relu(self.proj(x))
        return z_dict

class GraphDecoder(nn.Module):
    """
    Maps patient and disease embeddings to per-disease scores in (0, 1).

    Holds two (embedding_dim, num_diseases) parameter matrices initialised
    uniformly at random: ``Q`` for patients and ``G`` for diseases.
    """

    def __init__(self, embedding_dim: int, num_diseases: int):
        super().__init__()
        weight_shape = (embedding_dim, num_diseases)
        self.Q = nn.Parameter(torch.rand(*weight_shape))
        self.G = nn.Parameter(torch.rand(*weight_shape))

    def forward(self, z_dict: Dict[str, torch.Tensor]):
        """Return ``(sigmoid(z_patient @ Q), sigmoid(z_disease @ G))``."""
        patient_scores = torch.matmul(z_dict['patient'], self.Q)
        disease_scores = torch.matmul(z_dict['disease'], self.G)
        return torch.sigmoid(patient_scores), torch.sigmoid(disease_scores)

class DiseasePredictionModelHGAT(nn.Module):
    """
    End-to-end model: a two-layer GNNEncoderHGAT followed by a GraphDecoder.

    The encoder's hidden width and output width are both ``embedding_dim``.
    """

    def __init__(self, data_metadata, embedding_dim: int, num_diseases: int):
        super().__init__()
        self.encoder = GNNEncoderHGAT(
            data_metadata,
            hidden_channels=embedding_dim,
            out_channels=embedding_dim,
            num_layers=2,
        )
        self.decoder = GraphDecoder(embedding_dim, num_diseases)

    def forward(self, x_dict: Dict[str, torch.Tensor], edge_index_dict: Dict[Tuple, torch.Tensor]):
        """Encode the graph, then return the decoder's (patient, disease) score pair."""
        embeddings = self.encoder(x_dict, edge_index_dict)
        patient_scores, disease_scores = self.decoder(embeddings)
        return patient_scores, disease_scores

class FocalLoss(nn.Module):
    """
    Focal loss on probabilities: ``alpha * (1 - p_t) ** gamma * BCE``.

    ``p_t`` is the predicted probability where the target equals 1 and one
    minus it elsewhere. ``reduction`` may be 'sum' or 'mean'; any other value
    returns the unreduced element-wise loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='sum'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        # Clamp probabilities before taking logs.
        probs = torch.clamp(inputs, min=1e-9, max=1.0 - 1e-9)
        bce = F.binary_cross_entropy(probs, targets, reduction='none')

        # Probability assigned to the true class of each element.
        p_true = torch.where(targets == 1, probs, 1 - probs)
        focal = self.alpha * (1 - p_true) ** self.gamma * bce

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal


class DiseasePredictionModelHAN(nn.Module):
    """
    End-to-end model: a GNNEncoderHAN (4 heads) followed by a GraphDecoder.

    The encoder's hidden width and output width are both ``embedding_dim``.
    """

    def __init__(self, metadata, metapaths, embedding_dim, num_diseases:int):
        super().__init__()
        self.encoder = GNNEncoderHAN(
            metadata=metadata,
            metapaths=metapaths,
            hidden_channels=embedding_dim,
            out_channels=embedding_dim,
            num_heads=4,
        )
        self.decoder = GraphDecoder(embedding_dim, num_diseases)

    def forward(self, x_dict, edge_index_dict):
        """Encode the graph, then return the decoder's (patient, disease) score pair."""
        embeddings = self.encoder(x_dict, edge_index_dict)
        patient_scores, disease_scores = self.decoder(embeddings)
        return patient_scores, disease_scores


def negative_log_likelihood_loss(c_hat, c):
    """Summed binary cross-entropy between clamped probabilities ``c_hat`` and targets ``c``."""
    bounded = c_hat.clamp(min=1e-9, max=1.0 - 1e-9)
    return F.binary_cross_entropy(bounded, c, reduction='sum')

def train_hgat_model(model: DiseasePredictionModelHGAT, full_graph: HeteroData, epochs=50, lr=0.001):
    """
    Full-batch training of the HGAT model with a summed focal loss.

    The patient loss uses only the first 80% of patient rows (by index); the
    disease loss uses every disease row against ``y_identity``. Gradients are
    clipped to a total norm of 2.0 before each Adam step, and the losses are
    printed every epoch.

    Returns:
        The trained model (same object as ``model``).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    focal_loss_fn = FocalLoss(alpha=1, gamma=2, reduction='sum')

    patient_targets = full_graph['patient'].y
    disease_targets = full_graph['disease'].y_identity

    # Leading 80% of patients form the training subset.
    num_train = int(patient_targets.shape[0] * 0.8)
    train_idx = torch.arange(num_train)
    train_targets = patient_targets[train_idx]

    print(f"\nStarting HGAT Training with {num_train} samples...")

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        patient_pred, disease_pred = model(full_graph.x_dict, full_graph.edge_index_dict)
        L_P = focal_loss_fn(patient_pred[train_idx], train_targets)
        L_M = focal_loss_fn(disease_pred, disease_targets)
        loss = L_P + L_M

        loss.backward()
        # Clip the global gradient norm to 2.0 before updating.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()

        print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}, L_P: {L_P.item():.4f}, L_M: {L_M.item():.4f}')

    print("\nHGAT Model Training Complete. Ready for prediction.")
    return model


def train(model, data, epochs=30, lr=1e-3, device="cpu"):
    """
    Full-batch cross-entropy training loop.

    Args:
        model: Module called as ``model(features, meta_neighbors)`` and
            returning a 3-tuple whose first element is the logits.
        data: Mapping with 'features' (dict of tensors), 'meta_neighbors'
            (passed to the model unchanged) and 'labels' (tensor).
        epochs: Number of optimisation steps.
        lr: Adam learning rate.
        device: Device the model, features and labels are moved to.

    Returns:
        The model after training (the result of ``model.to(device)``).
    """
    model = model.to(device)

    # Tensors go to the target device; meta_neighbors is used as-is.
    feature_dict = {name: tensor.to(device) for name, tensor in data['features'].items()}
    meta_neighbors = data['meta_neighbors']
    labels = data['labels'].to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        opt.zero_grad()

        # Only the logits are needed for the loss; the other two outputs are unused.
        logits, _z_final, _beta = model(feature_dict, meta_neighbors)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        opt.step()

        print(f"Epoch {epoch+1}/{epochs} | Loss = {loss.item():.4f}")

    return model


def train_hgan_model(model: DiseasePredictionModelHAN, full_graph: HeteroData, epochs=50, lr=0.001):
    """
    Full-batch training of the HAN model with a summed binary cross-entropy loss.

    The patient loss uses only the first 80% of patient rows (by index); the
    disease loss uses every disease row against ``y_identity``. Gradients are
    clipped to a total norm of 5.0 before each Adam step, and the losses are
    printed every epoch.

    Returns:
        The trained model (same object as ``model``).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    c_p = full_graph['patient'].y
    c_m = full_graph['disease'].y_identity
    num_train = int(c_p.shape[0] * 0.8)
    train_idx = torch.arange(num_train)

    # This function trains the HAN model; the banner previously said "HGAT".
    print(f"\nStarting HAN Training with {num_train} samples...")

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        c_p_hat, c_m_hat = model(full_graph.x_dict, full_graph.edge_index_dict)
        L_P = negative_log_likelihood_loss(c_p_hat[train_idx], c_p[train_idx])
        L_M = negative_log_likelihood_loss(c_m_hat, c_m)
        loss = L_P + L_M
        loss.backward()
        # Clip the global gradient norm to 5.0 before updating.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        print(f'Epoch: {epoch:03d}, Loss: {loss.item():.4f}, L_P: {L_P.item():.4f}, L_M: {L_M.item():.4f}')
    print("\nHAN Model Training Complete. Ready for prediction.")
    return model

In [None]:
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    hamming_loss, coverage_error, label_ranking_loss
)
import torch.nn.functional as F

def evaluate_model(model, full_graph, k_list=(1, 3, 5)):
    """Compute multi-label patient metrics and disease-embedding metrics.

    The model is run once in eval mode on the whole graph. Patient predictions
    are compared with ``full_graph['patient'].y`` over ALL patient rows (no
    train/test split is applied here), and disease predictions are compared
    with ``full_graph['disease'].y_identity``.

    Args:
        model: Called as ``model(x_dict, edge_index_dict)``; must return a
            ``(patient_predictions, disease_predictions)`` pair.
        full_graph: Graph providing ``x_dict``, ``edge_index_dict``,
            ``['patient'].y`` and ``['disease'].y_identity``.
        k_list: Iterable of k values for top-k accuracy. Each k is clamped to
            the number of label columns so an oversized k cannot crash
            ``torch.topk``; the metric key still uses the requested k.

    Returns:
        dict mapping metric name to value. ``micro_auc`` / ``macro_auc`` are
        ``None`` when ROC-AUC is undefined for the data.
    """
    print("\n========== MODEL EVALUATION ==========")

    model.eval()
    with torch.no_grad():
        c_p_hat, c_m_hat = model(full_graph.x_dict, full_graph.edge_index_dict)

    # Ground truth and scores on CPU so sklearn can consume them.
    y_true = full_graph['patient'].y.cpu()
    y_pred = c_p_hat.cpu()

    # Binarize at 0.5.
    # NOTE(review): this assumes the model emits probabilities in [0, 1];
    # confirm it does not return logits or log-probabilities.
    y_bin = (y_pred >= 0.5).int()

    metrics = {}

    # 1. Micro & Macro F1
    metrics['micro_f1'] = f1_score(y_true, y_bin, average='micro', zero_division=0)
    metrics['macro_f1'] = f1_score(y_true, y_bin, average='macro', zero_division=0)

    # 2. AUPRC (micro)
    metrics['micro_auprc'] = average_precision_score(y_true, y_pred, average='micro')

    # 3. ROC-AUC: sklearn raises ValueError when a label column contains a
    #    single class. Only that case is tolerated; anything else propagates.
    try:
        metrics['micro_auc'] = roc_auc_score(y_true, y_pred, average='micro')
        metrics['macro_auc'] = roc_auc_score(y_true, y_pred, average='macro')
    except ValueError as err:
        print(f"[WARN] ROC-AUC undefined for this data: {err}")
        metrics['micro_auc'] = None
        metrics['macro_auc'] = None

    # 4. Hamming Loss
    metrics['hamming_loss'] = hamming_loss(y_true, y_bin)

    # 5. Ranking Loss
    metrics['ranking_loss'] = label_ranking_loss(y_true, y_pred)

    # 6. Coverage Error
    metrics['coverage_error'] = coverage_error(y_true, y_pred)

    # 7. Top-k accuracy: a row counts as a hit when at least one of its k
    #    highest-scored labels is a true label.
    num_rows, num_labels = y_true.shape[0], y_pred.shape[1]
    topk_results = {}
    for k in k_list:
        topk_idx = torch.topk(y_pred, k=min(k, num_labels), dim=1).indices
        hit_rows = (y_true.gather(1, topk_idx) == 1).any(dim=1)
        topk_results[f"top_{k}_accuracy"] = int(hit_rows.sum().item()) / num_rows
    metrics.update(topk_results)

    # -------------------------
    # DISEASE EMBEDDING METRICS
    # -------------------------
    y_d_true = full_graph['disease'].y_identity.cpu().float()
    y_d_pred = c_m_hat.cpu().float()

    metrics["disease_mse"] = F.mse_loss(y_d_pred, y_d_true).item()

    # Row-wise cosine similarity, averaged over disease rows.
    cos_sim = F.cosine_similarity(y_d_pred, y_d_true).mean().item()
    metrics["disease_cosine_similarity"] = cos_sim

    print("\n======= Evaluation Results =======")
    for name, value in metrics.items():
        print(f"{name:25s}: {value}")

    return metrics


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# ======================================================================
# HELPER: Node-Level Attention (Eq. 2–5)
# ======================================================================
class NodeLevelAttention(nn.Module):
    """Multi-head node-level attention over one meta-path (Eq. 2-5).

    For every node ``i`` and head ``k`` the layer scores each meta-path
    neighbour ``j`` with ``LeakyReLU(a_k^T [h_i || h_j])``, normalises the
    scores with a softmax over the neighbours, and returns the
    attention-weighted sum of the neighbour features. Nodes without
    neighbours keep their own projected feature. Head outputs are
    concatenated along the feature axis.

    Args:
        in_dim: Accepted for interface symmetry; not used by this layer.
        out_dim: Feature size of the already-projected node features.
        num_heads: Number of attention heads ``K``.
    """

    def __init__(self, in_dim, out_dim, num_heads):
        super().__init__()
        self.K = num_heads
        self.out_dim = out_dim

        # One attention vector a_k of size 2*out_dim per head (Eq. 3).
        self.att = nn.Parameter(torch.randn(num_heads, 2 * out_dim))

        # Non-linearity applied to the raw attention scores.
        self.act = nn.LeakyReLU(0.2)

    def forward(self, h_proj, neighbor_dict, device=None):
        """Aggregate neighbour features for every node and head.

        Args:
            h_proj: Tensor of shape ``(N, out_dim)`` with projected features.
            neighbor_dict: Indexable by node id ``0..N-1``; each entry is the
                list of neighbour indices of that node for this meta-path.
            device: Unused. Optional so callers may invoke the layer as
                ``layer(h_proj, neighbor_dict)``; all work happens on
                ``h_proj``'s own device.

        Returns:
            Tensor of shape ``(N, K * out_dim)``.
        """
        N = h_proj.size(0)
        all_heads = []

        for k in range(self.K):
            a_k = self.att[k].unsqueeze(1)  # (2*out_dim, 1)
            h_head = []

            for i in range(N):
                neigh = neighbor_dict[i]
                if len(neigh) == 0:
                    # Isolated node for this meta-path: fall back to its own feature.
                    h_head.append(h_proj[i])
                    continue

                h_i = h_proj[i].repeat(len(neigh), 1)
                h_j = h_proj[neigh]

                # [h_i || h_j] -> one raw score per neighbour.
                concat = torch.cat([h_i, h_j], dim=1)

                # squeeze(-1) keeps the neighbour axis even for a single
                # neighbour; a bare squeeze() would yield a 0-dim tensor and
                # break the unsqueeze(1) below.
                e_ij = self.act(torch.matmul(concat, a_k)).squeeze(-1)

                # Softmax over this node's neighbours (Eq. 3).
                alpha = torch.softmax(e_ij, dim=0)

                # Attention-weighted neighbour sum (Eq. 4).
                z_i = torch.sum(alpha.unsqueeze(1) * h_j, dim=0)
                h_head.append(z_i)

            all_heads.append(torch.stack(h_head))

        # Eq. 5: concatenate the heads -> (N, K*out_dim)
        z_phi = torch.cat(all_heads, dim=1)
        return z_phi


# ======================================================================
# HELPER: Semantic-Level Attention (Eq. 6–8)
# ======================================================================
class SemanticAttention(nn.Module):
    """Semantic-level attention that fuses per-meta-path embeddings (Eq. 6-8).

    Args:
        in_dim: Feature size of each meta-path embedding.
        att_dim: Size of the hidden attention space.
    """

    def __init__(self, in_dim, att_dim=128):
        super().__init__()

        self.W = nn.Linear(in_dim, att_dim)
        self.q = nn.Parameter(torch.randn(att_dim))

    def forward(self, Z_list):
        """Weight and sum the meta-path embeddings.

        Args:
            Z_list: List of embedding tensors, one per meta-path, all of the
                same shape with ``in_dim`` as the last dimension.

        Returns:
            ``(Z_final, beta)``: the fused embedding and the meta-path
            weights, the latter detached and moved to CPU.
        """
        # Eq. 7: importance of a meta-path = mean over nodes of q^T tanh(W z_i).
        importance = [
            torch.matmul(torch.tanh(self.W(Z)), self.q).mean()
            for Z in Z_list
        ]

        # Eq. 8: normalise the importances across meta-paths.
        beta = torch.softmax(torch.stack(importance), dim=0)

        # Algorithm 1, step 14: weighted sum of the meta-path embeddings.
        Z_final = sum(weight * Z for weight, Z in zip(beta, Z_list))

        return Z_final, beta.detach().cpu()


# ======================================================================
# HAN MODEL (Algorithm 1)
# ======================================================================
class HAN(nn.Module):
    """Heterogeneous graph attention network (Algorithm 1).

    Pipeline: type-specific linear projection of the input features (Eq. 1),
    multi-head node-level attention per meta-path, semantic-level fusion of
    the meta-path embeddings, and a linear 2-class classifier.

    Args:
        node_types: Iterable of node-type names.
        in_dims: dict mapping node type -> input feature dimension.
        hidden_dim: Size of the shared projection space.
        out_dim: Accepted for interface compatibility; not used by the model.
        num_heads: Number of node-level attention heads.
        meta_paths: Sequence of meta-paths; one node-level attention module is
            created per entry. ``None`` (default) means no meta-paths.
    """

    def __init__(self,
                 node_types,
                 in_dims,             # dict: node_type -> input dimension
                 hidden_dim=64,
                 out_dim=64,
                 num_heads=8,
                 meta_paths=None):
        super().__init__()

        # Avoid a shared mutable default argument.
        meta_paths = [] if meta_paths is None else meta_paths

        self.node_types = node_types
        self.meta_paths = meta_paths
        self.K = num_heads
        self.hidden_dim = hidden_dim

        # M_phi: type-specific projection (Eq. 1)
        self.type_transforms = nn.ModuleDict({
            ntype: nn.Linear(in_dims[ntype], hidden_dim)
            for ntype in node_types
        })

        # One node-level attention module per meta-path
        self.node_level = nn.ModuleList([
            NodeLevelAttention(hidden_dim, hidden_dim, num_heads)
            for _ in meta_paths
        ])

        # Semantic-level attention over the concatenated head outputs
        self.semantic_att = SemanticAttention(hidden_dim * num_heads)

        # Final classifier head (2 classes)
        self.classifier = nn.Linear(hidden_dim * num_heads, 2)

    def forward(self, features, meta_neighbors):
        """Run the model.

        Args:
            features: dict mapping node type -> feature tensor. The projected
                features of all types are summed element-wise, so every tensor
                must have the same number of rows.
                NOTE(review): confirm this holds for the graphs this model is
                used with; differently sized node types cannot be summed.
            meta_neighbors: Sequence with one neighbour structure per
                meta-path, each indexable by node id.

        Returns:
            ``(logits, Z_final, beta)``: class logits, the fused node
            embeddings and the semantic attention weights.

        Raises:
            ValueError: If ``features`` is empty, or if more neighbour
                structures are supplied than meta-paths were configured.
        """
        if not features:
            raise ValueError(
                "HAN.forward received an empty `features` dict; expected "
                "{node_type: feature tensor}."
            )

        # ---------------------------------------------------------------
        # Step 1: Type-specific projection (Eq. 1)
        # ---------------------------------------------------------------
        num_nodes = next(iter(features.values())).size(0)
        h_proj = torch.zeros((num_nodes, self.hidden_dim),
                             device=next(self.parameters()).device)

        for ntype, feat in features.items():
            h_proj += self.type_transforms[ntype](feat)

        # ---------------------------------------------------------------
        # Step 2: Node-level attention for each meta-path
        # ---------------------------------------------------------------
        Z_meta = []
        for mp_idx, neigh_dict in enumerate(meta_neighbors):
            if mp_idx >= len(self.node_level):
                raise ValueError(
                    f"Received more neighbour structures than the "
                    f"{len(self.node_level)} meta-path(s) the model was built with."
                )
            # The attention layer takes the device as its third argument.
            Z_meta.append(self.node_level[mp_idx](h_proj, neigh_dict, h_proj.device))

        # ---------------------------------------------------------------
        # Step 3: Semantic-level fusion (Eq. 7-8)
        # ---------------------------------------------------------------
        Z_final, beta = self.semantic_att(Z_meta)

        # ---------------------------------------------------------------
        # Step 4: Final classification
        # ---------------------------------------------------------------
        logits = self.classifier(Z_final)
        return logits, Z_final, beta


In [None]:
FILE_PATH = '/content/patient_reports.csv'
FILE_PATH_2 = '/content/enhanced_symptom_connectivity_analysis(Sheet1).csv'

# ---------------------------------------
# Patient lab reports (comma-separated; ReportDate parsed as a date).
# ---------------------------------------
try:
    df = pd.read_csv(
        FILE_PATH,
        sep=',',
        parse_dates=["ReportDate"],
        skipinitialspace=True
    )
    # Column headers may carry stray whitespace; strip it so lookups by name work.
    df.columns = df.columns.str.strip()

    print(f"[INFO] Successfully loaded data from {FILE_PATH}. Total rows: {len(df)}")
except FileNotFoundError:
    print(f"[ERROR] The file {FILE_PATH} was not found. Please check the file path.")
    # Fall back to an empty frame with the expected columns so later code does not crash.
    df = pd.DataFrame(columns=['PatientID', 'ReportDate', 'TestName', 'TestValue'])

# ---------------------------------------
# Symptom connectivity table (comma-separated, latin-1 encoded).
# ---------------------------------------
try:
    df2 = pd.read_csv(
        FILE_PATH_2,
        sep=',',
        encoding='latin-1',
        skipinitialspace=True
    )
    df2.columns = df2.columns.str.strip()

    # Coerce the count column to a nullable integer; unparseable entries become <NA>.
    if 'Unique_Patient_Count' in df2.columns:
        df2['Unique_Patient_Count'] = pd.to_numeric(df2['Unique_Patient_Count'], errors='coerce').astype('Int64')

    print(f"[INFO] Successfully loaded data from {FILE_PATH_2}. Total rows: {len(df2)}")
except FileNotFoundError:
    print(f"[ERROR] The file {FILE_PATH_2} was not found. Please check the file path.")
    # The fallback must populate df2: assigning it to df would overwrite the
    # patient frame and leave df2 undefined for the lookups below.
    df2 = pd.DataFrame(columns=['TestName', 'Unique_Patient_Count', 'Most_Relevant_Disease', 'Target_Organ'])

# ----------------------------------------
# Vocabularies: map each distinct value to a contiguous integer index.
# ----------------------------------------
unique_patients = df['PatientID'].unique()
unique_concepts = df['TestName'].unique()
unique_organs = df2['Target_Organ'].unique()
unique_diseases = df2['Most_Relevant_Disease'].unique()
ORGAN_MAP = {name: i for i, name in enumerate(unique_organs)}
NUM_ORGANS = len(ORGAN_MAP)
CONCEPT_MAP = {name: i for i, name in enumerate(unique_concepts)}
NUM_CONCEPTS = len(CONCEPT_MAP)
DISEASE_MAP = {name: i for i, name in enumerate(unique_diseases)}
NUM_DISEASES = len(DISEASE_MAP)
PATIENT_IDS = list(unique_patients)
PATIENT_IDX_MAP = {pid: i for i, pid in enumerate(PATIENT_IDS)}
VALUE_DIM = 1  # One value feature per concept (the normalized value)

# Encoder dimensions and hyper-parameters
INPUT_DIM = NUM_CONCEPTS * VALUE_DIM
print(INPUT_DIM)
OUTPUT_DIM = 128  # Size of the initial patient embedding h_patient^0 fed to the GNN
HIDDEN_DIM = 64
LSTM_LAYERS = 1
TCN_LAYERS = 2
TCN_KERNEL = 2
TCN_DROPOUT = 0.2




def generate_patient_embeddings(patient_data: pd.DataFrame):
    """Encode each patient's visit history into an initial embedding.

    Records are grouped by ``PatientID`` and sorted by ``ReportDate``, then
    turned into a per-visit feature sequence with
    ``transform_raw_to_features``. Sequences with at least ``TCN_KERNEL``
    visits are encoded by the TCN; shorter ones fall back to
    ``simple_aggregate_visits``.

    NOTE(review): both encoders are randomly initialised and untrained here,
    and ``simple_aggregate_visits`` builds a fresh random projection on every
    call, so fallback embeddings of different patients do not share a
    projection.

    Args:
        patient_data: Frame with at least ``PatientID`` and ``ReportDate``
            columns, plus whatever ``transform_raw_to_features`` reads.

    Returns:
        ``(patient_features_dict, final_patient_ids, final_patient_idx_map)``:
        embedding per patient id, the encoded patient ids in insertion order,
        and a map from patient id to its position in that list.
    """
    tcn_encoder = TCNEncoder(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, TCN_LAYERS, TCN_KERNEL, TCN_DROPOUT)
    # Inference only: switch off dropout so the embeddings are not randomly perturbed.
    tcn_encoder.eval()

    patient_features_dict = {}

    with torch.no_grad():
        for patient_id, group in patient_data.groupby('PatientID'):
            # Chronological order matters for the temporal encoder.
            group = group.sort_values('ReportDate')

            # Sparse EMR rows -> dense (num_visits, feature_dim) sequence tensor.
            visit_features_sequence = transform_raw_to_features(group, CONCEPT_MAP, VALUE_DIM)

            if visit_features_sequence.size(0) >= TCN_KERNEL:
                # Add a batch dimension for the encoder, drop it from the output.
                ts_input = visit_features_sequence.unsqueeze(0)
                embedding = tcn_encoder(ts_input).squeeze(0)
            else:
                # Too few visits for the convolution kernel: mean-pool baseline.
                embedding = simple_aggregate_visits(visit_features_sequence, OUTPUT_DIM)

            patient_features_dict[patient_id] = embedding

    final_patient_ids = list(patient_features_dict.keys())
    final_patient_idx_map = {pid: i for i, pid in enumerate(final_patient_ids)}

    print(f"\n--- Stage 1: Feature Generation Complete ---")
    print(f"Total Patients Encoded: {len(final_patient_ids)}")
    print(f"Final Patient Feature Dimension (h_p^0): {OUTPUT_DIM}")

    print("\n--- First 5 items in patient_feature_dict ---")
    for i, (k, v) in enumerate(patient_features_dict.items()):
        if i == 5:
            break
        print(f"Patient ID: {k} -> Features: {v[:10]} ...")

    # Only show a sample when there is one; an empty input frame yields no patients.
    if patient_features_dict:
        sample_key = next(iter(patient_features_dict))
        print("\nSample patient key:", sample_key)
        print("Feature vector:", patient_features_dict[sample_key])
        print("Length:", len(patient_features_dict[sample_key]))

    return patient_features_dict, final_patient_ids, final_patient_idx_map


def build_and_get_unified_graph():
    """Merge the three bipartite graphs into the final unified graph.

    Reads the module-level names ``ps_graph``, ``so_graph``, ``od_graph``,
    ``data_obj`` and ``final_patient_ids``; they must be defined before this
    function is called.

    Returns:
        ``(final_unified_graph, final_patient_ids)``.
    """
    return (
        construct_final_unified_graph(ps_graph, so_graph, od_graph, data_obj),
        final_patient_ids,
    )



if __name__ == '__main__':
    print(f"Total Unique Medical Concepts (Input Dim): {NUM_CONCEPTS} (x{VALUE_DIM})\n")

    # --- Stage 1: initial patient embeddings ---
    patient_features_dict, final_patient_ids, final_patient_idx_map = generate_patient_embeddings(df)

    print("\n" + "="*50)
    print("COMPARISON OF GENERATED PATIENT EMBEDDINGS (h_patient^0)")
    print("="*50)

    # --- Stage 2: relation extraction and per-relation graphs ---
    patient_symptom_edges = identify_patient_symptom_relations(df)
    symptom_organ_edges = identify_organ_symptom_relations(df2)
    organ_disease_edges = identify_organ_disease_relations(df2)

    ps_graph = construct_patient_symptom_bipartite_graph(
        final_patient_ids,
        patient_features_dict,
        final_patient_idx_map,
        CONCEPT_MAP,
        patient_symptom_edges
    )
    print("\n--- Stage 2: Patient-Symptom Graph Constructed ---")

    so_graph = construct_symptom_organ_graph(
        CONCEPT_MAP,
        ORGAN_MAP,
        symptom_organ_edges
    )
    print("\n--- Stage 2: Symptom-organ Graph Constructed ---")
    print(so_graph)

    od_graph = construct_disease_organ_graph(
        DISEASE_MAP,
        ORGAN_MAP,
        organ_disease_edges
    )
    print("\n--- Stage 2: Disease-organ Graph Constructed ---")
    print(od_graph)

    # Placeholder labels: random multi-hot vectors, NOT real diagnoses. A
    # dedicated seeded generator keeps them (and the metrics computed from
    # them) reproducible without touching the global RNG state.
    LABEL_SEED = 0
    label_rng = torch.Generator().manual_seed(LABEL_SEED)
    PATIENT_LABELS = {
        pid: torch.randint(0, 2, (NUM_DISEASES,), generator=label_rng).float()
        for pid in final_patient_ids
    }

    data_obj = PatientData(patient_features_dict, final_patient_ids, final_patient_idx_map, CONCEPT_MAP, DISEASE_MAP, ORGAN_MAP, patient_symptom_edges, symptom_organ_edges, organ_disease_edges, PATIENT_LABELS)

    # --- Integration layer: merge all graphs ---
    # build_and_get_unified_graph reads ps_graph, so_graph, od_graph and data_obj
    # from module scope, so they must all be defined by this point.
    final_unified_graph, _ = build_and_get_unified_graph()

    print("\n" + "="*70)
    print("GNN INPUT READY")
    print(final_unified_graph)
    print(f"Total Edges for GNN: {final_unified_graph.num_edges}")
    print("="*70)

    metadata = final_unified_graph.metadata()
    print("\n--- DEBUG: Graph Metadata Schema Check ---")
    print(f"Node Types Passed to Model: {metadata[0]}")
    print(f"Edge Types Passed to Model: {metadata[1]}")

    metapaths = extract_metapaths(metadata)

    # Node types and their input feature sizes, read off the graph itself.
    node_types = metadata[0]
    in_dims = {ntype: final_unified_graph[ntype].x.shape[1] for ntype in node_types}

    print("Detected metapaths:")
    for m in metapaths:
        print(m)

    num_diseases_final = final_unified_graph['disease'].x.shape[0]
    print(f"Number of Disease Nodes in Final Graph: {num_diseases_final} (Expected: {len(DISEASE_MAP)})")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Hand-written HAN. It is instantiated but not trained here: `train`
    # expects a mapping with 'features' / 'meta_neighbors' / 'labels', which
    # this graph object does not provide, and `evaluate_model` expects a model
    # returning a (patient, disease) prediction pair.
    # NOTE(review): build per-meta-path neighbour lists and node labels before
    # wiring this model into `train`.
    model = HAN(
        node_types=node_types,
        in_dims=in_dims,
        hidden_dim=64,
        out_dim=OUTPUT_DIM,
        num_heads=8,
        meta_paths=metapaths
    ).to(device)

    hgat_model = DiseasePredictionModelHGAT(
        data_metadata=metadata,
        embedding_dim=OUTPUT_DIM,
        num_diseases=num_diseases_final
    ).to(device)

    han_model = DiseasePredictionModelHAN(
        metadata=metadata,
        metapaths=metapaths,
        embedding_dim=OUTPUT_DIM,
        num_diseases=num_diseases_final
    ).to(device)

    # Move every tensor stored in the graph (features, edge indices, labels)
    # to the device in one step. Assigning into `x_dict` / `edge_index_dict`
    # does not modify the graph because those dicts are built on access.
    final_unified_graph = final_unified_graph.to(device)

    # Train and evaluate the model whose call signature and outputs match
    # train_hgan_model / evaluate_model.
    print("\n HAN model  training..........")
    trained_model_HAN = train_hgan_model(han_model, final_unified_graph, epochs=50, lr=0.0005)

    # HGAT comparison run (left disabled, as in the original notebook):
    #trained_model = train_hgat_model(hgat_model, final_unified_graph, epochs=50, lr=0.0005)
    #hgat_results = evaluate_model(trained_model, final_unified_graph)

    print("\nEvaluating HAN Model...")
    han_results = evaluate_model(trained_model_HAN, final_unified_graph)

[INFO] Successfully loaded data from /content/patient_reports.csv. Total rows: 160942
[INFO] Successfully loaded data from /content/enhanced_symptom_connectivity_analysis(Sheet1).csv. Total rows: 176
177
Total Unique Medical Concepts (Input Dim): 177 (x1)


--- Stage 1: Feature Generation Complete ---
Total Patients Encoded: 24352
Final Patient Feature Dimension (h_p^0): 128

--- First 5 items in patient_feature_dict ---
Patient ID: 139760 -> Features: tensor([0.0367, 0.0000, 0.0000, 0.0000, 0.0606, 0.0000, 0.0711, 0.0000, 0.0239,
        0.0000]) ...
Patient ID: 200041 -> Features: tensor([0.0187, 0.0000, 0.0000, 0.0000, 0.0275, 0.0000, 0.0765, 0.0000, 0.0285,
        0.0294]) ...
Patient ID: 201519 -> Features: tensor([0.0000, 0.0000, 0.0131, 0.0000, 0.0000, 0.0602, 0.0000, 0.0000, 0.0000,
        0.0000]) ...
Patient ID: 201605 -> Features: tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0474, 0.0000, 0.0778, 0.0000, 0.0000,
        0.0588]) ...
Patient ID: 201839 -> Features: tensor([0.0

IndexError: list index out of range