In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-$(python -c 'import torch; print(torch.__version__)').html -q

# Install other necessary libraries
!pip install datasets pandas numpy networkx scikit-learn transformers sentence-transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m64.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import networkx as nx # For initial graph understanding (optional)
import time
import sys
import warnings
from collections import defaultdict

# Machine Learning and NLP
import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ModuleList, ReLU
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, HeteroConv, global_mean_pool, to_hetero
from torch_geometric.loader import DataLoader # For potential batching if needed
from datasets import load_dataset
from transformers import AutoTokenizer # Potentially useful
from sentence_transformers import SentenceTransformer # For statement embeddings
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder

# Ignore common warnings
# Silence noisy third-party warnings for the whole session (filters are process-wide).
for _noisy_category in (FutureWarning, UserWarning, DeprecationWarning):
    warnings.filterwarnings('ignore', category=_noisy_category)
del _noisy_category

# --- Configuration ---
# Data
SAMPLE_SIZE = 1000            # statements sampled from the LIAR train split (tune to available resources)
TEST_SPLIT_RATIO = 0.2        # fraction of sampled statements held out for evaluation

# Features
EMBEDDING_DIM = 64                 # width of the learnable speaker/subject/party embeddings
SBERT_MODEL = 'all-MiniLM-L6-v2'   # SentenceTransformer used to embed statement text

# Model and optimisation
HIDDEN_CHANNELS = 128   # width of every GNN hidden representation
NUM_GNN_LAYERS = 2      # number of message-passing layers
NUM_EPOCHS = 50         # full-batch training epochs
LEARNING_RATE = 0.01
BATCH_SIZE = 128        # reserved for mini-batch training; the training loop in this notebook is full-batch

# Integer class index -> verdict name, in the order of the numeric 'label' column.
LABEL_MAP = dict(enumerate(
    ['false', 'half-true', 'mostly-true', 'true', 'barely-true', 'pants-on-fire']
))
# Verdict name -> integer class index.
INV_LABEL_MAP = dict(zip(LABEL_MAP.values(), LABEL_MAP))
# Verdict names ordered by class index (index i <-> CANDIDATE_LABELS[i]).
CANDIDATE_LABELS = [LABEL_MAP[idx] for idx in sorted(LABEL_MAP)]

# Node types of the heterogeneous graph.
NODE_TYPES = ['statement', 'speaker', 'subject', 'party']
# Forward relations (src_type, relation, dst_type) ...
EDGE_TYPES = [
    ('statement', 'spoken_by', 'speaker'),
    ('speaker', 'affiliated_with', 'party'),
    ('statement', 'about_subject', 'subject'),
]
# ... followed by their 'rev_*' mirrors, which carry messages back the other way.
EDGE_TYPES += [(dst, f'rev_{rel}', src) for src, rel, dst in EDGE_TYPES]

# ==============================================================================
# --- 1. Load and Prepare Data ---
# ==============================================================================
def load_and_prep_data(sample_size):
    """Load the LIAR train split, optionally subsample it, and add label columns.

    Args:
        sample_size: Number of rows to sample (``random_state=42``). When it is
            not smaller than the dataset, the whole train split is used.

    Returns:
        A DataFrame with the original columns plus ``label_text`` (verdict name)
        and ``label_idx`` (integer class index consistent with ``LABEL_MAP`` /
        ``CANDIDATE_LABELS``; ``-1`` marks a label that could not be encoded).
        Missing ``speaker``, ``subject`` and ``party_affiliation`` values are
        replaced by ``'Unknown'`` and those columns are cast to ``str``.

    Exits the process with status 1 if the dataset cannot be loaded or there is
    no ``label`` column at all.
    """
    print("--- 1. Loading and Preparing Data ---")
    start_time = time.time()
    try:
        dataset = load_dataset("liar")
        df_full = dataset['train'].to_pandas()
        print(f"Full dataset shape: {df_full.shape}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)  # non-zero status so the failure is visible to the caller/shell

    # Sample data
    if sample_size < len(df_full):
        df_sample = df_full.sample(n=sample_size, random_state=42).copy()
        print(f"Using a sample of {sample_size} rows.")
    else:
        df_sample = df_full.copy()
        print(f"Using the full dataset ({len(df_sample)} rows).")

    if 'label' in df_sample.columns and pd.api.types.is_numeric_dtype(df_sample['label']):
        # Numeric labels already follow LABEL_MAP's indexing: use them directly as targets.
        df_sample['label_text'] = df_sample['label'].map(LABEL_MAP)
        df_sample['label_idx'] = df_sample['label']
        print("Created 'label_text' and 'label_idx' columns.")
    else:
        print("Warning: Original numeric 'label' column not found or not numeric.")
        if 'label' not in df_sample.columns:
            print("Error: No usable label column found.")
            sys.exit(1)

        df_sample['label_text'] = df_sample['label'].astype(str)
        # Encode with LABEL_MAP's own indices. A LabelEncoder would sort the names
        # alphabetically, so its indices would not match CANDIDATE_LABELS, which is
        # used to name classes in the evaluation reports.
        normalized = (
            df_sample['label_text'].str.strip().str.lower()
            .replace({'pants-fire': 'pants-on-fire'})  # accept the short spelling as an alias
        )
        label_idx = normalized.map(INV_LABEL_MAP)
        if label_idx.isna().any():
            print("Error: Could not encode labels. Ensure labels match CANDIDATE_LABELS.")
            label_idx = label_idx.fillna(-1)  # -1 marks unencodable rows for downstream checks
        else:
            print("Created 'label_idx' from text labels using LABEL_MAP indices.")
        df_sample['label_idx'] = label_idx.astype(int)

    # Fill NaNs in the entity columns so every row maps to some graph node.
    for col in ['speaker', 'subject', 'party_affiliation']:
        if col in df_sample.columns:
            df_sample[col] = df_sample[col].fillna('Unknown').astype(str)

    print(f"Data loading & prep completed in {time.time() - start_time:.2f} seconds.")
    return df_sample

# ==============================================================================
# --- 2. Feature Engineering ---
# ==============================================================================
def create_statement_features(statements):
    """Embed every statement with the configured SentenceTransformer.

    Args:
        statements: Iterable of statement texts (e.g. a DataFrame column). Missing
            values (per ``pd.notna``) are embedded as ``""``; everything else is
            converted with ``str``.

    Returns:
        A 2-D tensor with one embedding row per statement. If loading the model
        or encoding fails, the error is printed and a zero tensor is returned
        instead, so the pipeline can keep running. Its width is 384 when the
        model name contains 'MiniLM' and 768 otherwise.
    """
    print("\n--- 2a. Generating Statement Features (Embeddings) ---")
    t0 = time.time()
    try:
        encoder = SentenceTransformer(SBERT_MODEL)
        texts = ["" if not pd.notna(item) else str(item) for item in statements]
        with torch.no_grad():
            vectors = encoder.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        print(f"Statement embeddings generated with shape: {vectors.shape}")
        print(f"Statement feature generation completed in {time.time() - t0:.2f} seconds.")
        return vectors
    except Exception as err:
        print(f"Error generating statement embeddings: {err}")
        # Guess the embedding width from the model name (MiniLM variants are 384-d).
        fallback_dim = 384 if 'MiniLM' in SBERT_MODEL else 768
        print(f"Returning fallback zero tensor with dimension {fallback_dim}")
        return torch.zeros((len(statements), fallback_dim))


# ==============================================================================
# --- 3. Graph Construction ---
# ==============================================================================
def _build_entity_maps(df):
    """Map each speaker/subject/party name to a node index (sorted-name order).

    ``party`` nodes come from the ``party_affiliation`` column. ``subject`` cells
    are split on commas so that each subject becomes its own node. Empty names
    and the string ``'nan'`` (any case) are excluded. A node type whose column is
    missing is skipped with a warning.

    Returns:
        ``{node_type: {entity_name: node_index}}``.
    """
    entity_maps = {}
    for node_type in ('speaker', 'subject', 'party'):
        col_name = 'party_affiliation' if node_type == 'party' else node_type
        if col_name not in df.columns:
            print(f"Warning: Column '{col_name}' not found for node type '{node_type}'. Skipping.")
            continue

        if node_type == 'subject':
            # One cell can list several comma-separated subjects.
            all_entities = df[col_name].astype(str).str.split(',').explode().str.strip().unique()
        else:
            all_entities = df[col_name].unique()

        # Drop NaN/None and values that were stringified to 'nan'.
        entity_list = sorted(
            str(e) for e in all_entities
            if pd.notna(e) and str(e) and str(e).lower() != 'nan'
        )
        entity_maps[node_type] = {name: i for i, name in enumerate(entity_list)}
        print(f"Found {len(entity_list)} unique entities for type '{node_type}'")
    return entity_maps


def _collect_edges(df, entity_maps):
    """Gather unique forward edges as ``{edge_type: [[src, dst], ...]}``.

    Statement node ``i`` is the ``i``-th row of ``df``. All three forward edge
    types are present in the result, possibly with empty lists. Pairs keep their
    first-seen order and are de-duplicated: otherwise a speaker with several
    sampled statements would get one identical speaker->party edge per statement.
    """
    spoken_by = ('statement', 'spoken_by', 'speaker')
    affiliated_with = ('speaker', 'affiliated_with', 'party')
    about_subject = ('statement', 'about_subject', 'subject')
    # Inner dicts serve as insertion-ordered sets of (src, dst) pairs.
    pairs = {spoken_by: {}, affiliated_with: {}, about_subject: {}}

    speaker_map = entity_maps.get('speaker', {})
    party_map = entity_maps.get('party', {})
    subject_map = entity_maps.get('subject', {})

    for stmt_idx, (_, row) in enumerate(df.iterrows()):
        speaker_idx = speaker_map.get(str(row.get('speaker', 'Unknown')))
        if speaker_idx is not None:
            pairs[spoken_by][(stmt_idx, speaker_idx)] = None
            # The party edge hangs off the speaker, so it needs a known speaker too.
            party_idx = party_map.get(str(row.get('party_affiliation', 'Unknown')))
            if party_idx is not None:
                pairs[affiliated_with][(speaker_idx, party_idx)] = None

        for subject_name in str(row.get('subject', 'Unknown')).split(','):
            # Empty names are never in subject_map, so they are skipped here as well.
            subject_idx = subject_map.get(subject_name.strip())
            if subject_idx is not None:
                pairs[about_subject][(stmt_idx, subject_idx)] = None

    return {edge_type: [list(pair) for pair in seen] for edge_type, seen in pairs.items()}


def _make_split_masks(labels, num_nodes):
    """Return boolean ``(train_mask, test_mask)`` tensors over ``num_nodes`` statements.

    Splits with ``TEST_SPLIT_RATIO`` and ``random_state=42``. The split is
    stratified on ``labels`` unless that is impossible. That happens when some
    class has fewer than two members, or when the train or test side would hold
    fewer rows than there are classes; ``train_test_split(stratify=...)`` raises
    in both cases. In those cases a plain random split is used instead.
    """
    indices = np.arange(num_nodes)
    _, counts = np.unique(labels, return_counts=True)
    n_classes = len(counts)
    # Same sizing rule train_test_split uses for a float test_size.
    n_test = int(np.ceil(TEST_SPLIT_RATIO * num_nodes))
    n_train = num_nodes - n_test
    min_count = int(counts.min())

    if min_count < 2:
        print(f"Warning: The least populated class has only {min_count} members, which is too few for stratified splitting with test_size={TEST_SPLIT_RATIO}. Performing non-stratified split.")
        stratify_labels = None
    elif min(n_train, n_test) < n_classes:
        print(f"Warning: A {n_train}/{n_test} train/test split cannot hold all {n_classes} classes for stratification. Performing non-stratified split.")
        stratify_labels = None
    else:
        stratify_labels = labels

    train_indices, test_indices = train_test_split(
        indices,
        test_size=TEST_SPLIT_RATIO,
        random_state=42,
        stratify=stratify_labels,
    )

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_indices] = True
    test_mask[test_indices] = True
    return train_mask, test_mask


def build_hetero_graph(df, statement_features):
    """Build the statement/speaker/subject/party ``HeteroData`` graph.

    Args:
        df: Prepared DataFrame, one row per statement, with a ``label_idx`` column
            and (when available) ``speaker``, ``subject`` and ``party_affiliation``.
        statement_features: Tensor with one feature row per row of ``df``. Row
            ``i`` becomes statement node ``i``.

    Returns:
        ``(data, entity_maps, valid_edge_types)``:

        - ``data`` holds statement features, labels and train/test masks,
          entity node counts, and every non-empty forward edge type plus its
          ``rev_*`` reverse.
        - ``entity_maps`` is ``{node_type: {name: index}}``.
        - ``valid_edge_types`` lists the edge types actually stored.

        On invalid labels, or when the feature row count does not match ``df``,
        an error is printed and ``(None, None, None)`` is returned. This keeps
        the same arity as the success value, so callers can always unpack.
    """
    print("\n--- 3. Building Heterogeneous Graph ---")
    start_time = time.time()

    # Fail fast: labels and masks cannot be assigned without valid class indices.
    if ('label_idx' not in df.columns
            or df['label_idx'].isna().any()
            or df['label_idx'].eq(-1).any()):
        print("Error: Invalid 'label_idx' column. Cannot assign labels or masks.")
        return None, None, None
    num_statements = len(df)
    if statement_features.shape[0] != num_statements:
        print(f"Error: Got {statement_features.shape[0]} statement feature rows for {num_statements} DataFrame rows.")
        return None, None, None

    data = HeteroData()
    data['statement'].x = statement_features

    # Entity nodes carry no features here: only their counts are stored, and the
    # model supplies learnable embeddings for them.
    entity_maps = _build_entity_maps(df)
    for node_type, entity_map in entity_maps.items():
        data[node_type].num_nodes = len(entity_map)

    valid_edge_types = []
    for edge_type, pairs in _collect_edges(df, entity_maps).items():
        if not pairs:
            print(f"Warning: No edges found for type {edge_type}. Skipping.")
            continue
        src_type, rel_type, dst_type = edge_type
        edges = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        data[src_type, rel_type, dst_type].edge_index = edges
        # Reverse edges let messages flow back toward the source type (e.g. to statements).
        rev_rel_type = f"rev_{rel_type}"
        data[dst_type, rev_rel_type, src_type].edge_index = edges[[1, 0]]
        valid_edge_types += [edge_type, (dst_type, rev_rel_type, src_type)]

    data['statement'].y = torch.tensor(df['label_idx'].values, dtype=torch.long)
    train_mask, test_mask = _make_split_masks(data['statement'].y.numpy(), num_statements)
    data['statement'].train_mask = train_mask
    data['statement'].test_mask = test_mask

    print(f"Graph construction completed in {time.time() - start_time:.2f} seconds.")
    print(f"\nGraph Summary:\n{data}")
    try:
        data.validate()
        print("Graph validation successful.")
    except Exception as e:
        # Reported but not fatal, matching the rest of the pipeline's error style.
        print(f"Graph validation failed: {e}")

    return data, entity_maps, valid_edge_types

# ==============================================================================
# --- 4. GNN Model Definition ---
# ==============================================================================
class HeteroGNN(torch.nn.Module):
    """Heterogeneous GraphSAGE classifier that produces logits for statement nodes.

    Inputs are encoded as follows:

    - Statement nodes: a linear projection (+ReLU) of their input features.
    - Speaker/subject/party nodes: one learnable ``Embedding`` table per type,
      followed by a single linear projection (+ReLU) shared by all entity types.

    Then ``num_layers`` ``HeteroConv`` layers run, each holding one ``SAGEConv``
    per edge type and summing the results per destination type, with ReLU after
    every layer. A final linear head scores the statement nodes.

    Args:
        hidden_channels: Width of every projected and convolved representation.
        out_channels: Number of output classes.
        num_layers: Number of ``HeteroConv`` layers.
        entity_maps: ``{node_type: {name: index}}``. Only the map sizes are used.
        embedding_dim: Size of each entity embedding vector.
        statement_feature_dim: Width of the statement input features.
        valid_edge_types: ``(src, rel, dst)`` edge types that get a convolution.
            Edges of any other type are ignored in ``forward``.
    """

    def __init__(self, hidden_channels, out_channels, num_layers, entity_maps, embedding_dim, statement_feature_dim, valid_edge_types):
        super().__init__()
        self.entity_embeddings = torch.nn.ModuleDict()
        self.num_entities = {}
        self.valid_edge_types = valid_edge_types

        # One embedding table per entity type; types without entities get none.
        for node_type, entity_map in entity_maps.items():
            count = len(entity_map)
            self.num_entities[node_type] = count
            if count == 0:
                print(f"Warning: No entities found for type {node_type}, embedding layer not created.")
                continue
            self.entity_embeddings[node_type] = Embedding(count, embedding_dim)

        self.convs = ModuleList()
        self.statement_lin = Linear(statement_feature_dim, hidden_channels)
        # The entity projection is shared across types and only needed if some table exists.
        self.entity_lin = Linear(embedding_dim, hidden_channels) if self.entity_embeddings else None

        # Inputs are already projected to hidden_channels, so every layer is hidden -> hidden.
        for _ in range(num_layers):
            per_edge_convs = {
                edge_type: SAGEConv((hidden_channels, hidden_channels), hidden_channels)
                for edge_type in self.valid_edge_types
            }
            self.convs.append(HeteroConv(per_edge_convs, aggr='sum'))

        self.classifier = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return one row of class logits per statement node.

        Args:
            x_dict: Node features keyed by node type; only ``'statement'`` is read.
            edge_index_dict: Edge indices keyed by ``(src, rel, dst)``; types not
                in ``valid_edge_types`` are dropped.

        Returns:
            ``classifier`` output for the statement nodes. If no statement
            representation survives message passing, a warning is printed and a
            zero tensor with one row per input statement is returned.
        """
        hidden = {}

        statement_x = x_dict['statement'] if 'statement' in x_dict else None
        if statement_x is not None:
            if statement_x.shape[0] == 0:
                hidden['statement'] = torch.empty(
                    (0, self.statement_lin.out_features),
                    device=self.statement_lin.weight.device,
                )
            else:
                hidden['statement'] = self.statement_lin(statement_x).relu()

        if self.entity_lin is not None:
            for node_type, table in self.entity_embeddings.items():
                count = self.num_entities[node_type]
                if count <= 0:
                    continue
                # Every entity node is looked up, in index order.
                ids = torch.arange(count, device=table.weight.device)
                hidden[node_type] = self.entity_lin(table(ids)).relu()

        kept_edges = {
            edge_type: edge_index
            for edge_type, edge_index in edge_index_dict.items()
            if edge_type in self.valid_edge_types
        }
        for conv in self.convs:
            hidden = {node_type: h.relu() for node_type, h in conv(hidden, kept_edges).items()}

        if 'statement' in hidden and hidden['statement'].shape[0] > 0:
            return self.classifier(hidden['statement'])

        print("Warning: 'statement' features not found or empty after GNN layers.")
        num_statement_nodes = x_dict.get('statement', torch.empty(0)).shape[0]
        return torch.zeros(
            (num_statement_nodes, self.classifier.out_features),
            device=next(self.parameters()).device,
        )


# ==============================================================================
# --- 5. Training and Evaluation ---
# ==============================================================================
def train(model, data, optimizer, criterion):
    """Run one full-batch optimisation step on the training statement nodes.

    The model runs on the whole graph. The loss is computed only over
    ``data['statement'].train_mask`` and then back-propagated.

    Args:
        model: Module called as ``model(data.x_dict, data.edge_index_dict)`` that
            returns one row of logits per statement node.
        data: Graph with ``x_dict`` and ``edge_index_dict``, whose ``'statement'``
            store carries ``train_mask`` and ``y``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, labels)``.

    Returns:
        The training loss as a float. If the step is skipped, ``float('nan')`` is
        returned instead, a warning is printed, and no parameter update happens.
        A step is skipped for missing inputs, empty or mis-sized output, an empty
        mask, or a NaN loss. NaN is used rather than 0.0 so a skipped step cannot
        be mistaken for a perfect fit in the training log.
    """
    skipped = float('nan')
    model.train()
    optimizer.zero_grad()

    if not hasattr(data, 'x_dict'):
        print("Error: data object missing 'x_dict'.")
        return skipped
    if not hasattr(data, 'edge_index_dict'):
        print("Error: data object missing 'edge_index_dict'.")
        return skipped

    out = model(data.x_dict, data.edge_index_dict)
    if out is None or out.shape[0] == 0:
        print("Warning: Model output is empty or invalid during training.")
        return skipped

    statements = data['statement']
    if not hasattr(statements, 'train_mask') or not hasattr(statements, 'y'):
        print("Error: Missing train_mask or labels ('y') on statement nodes.")
        return skipped

    train_mask = statements.train_mask
    if train_mask.sum() == 0:
        print("Warning: No training samples found in the mask.")
        return skipped

    # The mask indexes statement nodes, so it must align with the output rows.
    if out.shape[0] != train_mask.shape[0]:
        print(f"Warning: Output shape {out.shape} mismatch with mask shape {train_mask.shape}. Skipping loss calculation.")
        return skipped

    train_preds = out[train_mask]
    train_labels = statements.y[train_mask]
    if train_preds.shape[0] == 0:
        print("Warning: No training samples after applying mask.")
        return skipped

    loss = criterion(train_preds, train_labels)
    if torch.isnan(loss):
        # Do not back-propagate NaNs into the weights.
        print("Warning: NaN loss detected during training.")
        return skipped

    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test(model, data):
    """Score ``model`` on the training and test statement nodes without gradients.

    Args:
        model: Module called as ``model(data.x_dict, data.edge_index_dict)``.
        data: Graph whose ``'statement'`` store carries ``train_mask``,
            ``test_mask`` and ``y``.

    Returns:
        ``(accs, reports)``: two dicts keyed ``'Train'`` and ``'Test'``.
        ``accs`` holds the accuracy. ``reports`` holds a
        ``classification_report`` string restricted to the class indices that
        occur in the true or predicted labels, named via ``CANDIDATE_LABELS``.
        A split whose mask selects nothing gets accuracy ``0.0`` and a short
        explanatory string. Both dicts are empty when the inputs or the model
        output are unusable.
    """
    model.eval()
    has_inputs = hasattr(data, 'x_dict') and hasattr(data, 'edge_index_dict')
    if not has_inputs:
        print("Error: data object missing 'x_dict' or 'edge_index_dict' for testing.")
        return {}, {}

    logits = model(data.x_dict, data.edge_index_dict)
    if logits is None or logits.shape[0] == 0:
        print("Warning: Model output is empty or invalid during testing.")
        return {}, {}

    statements = data['statement']
    if not all(hasattr(statements, attr) for attr in ('train_mask', 'test_mask', 'y')):
        print("Error: Missing masks or labels ('y') on statement nodes for testing.")
        return {}, {}

    # One logit row per statement node is required for the masks to line up.
    if logits.shape[0] != statements.train_mask.shape[0]:
        print(f"Warning: Output shape {logits.shape} mismatch with mask shape {statements.train_mask.shape} during testing.")
        return {}, {}

    predicted = logits.argmax(dim=-1)
    accs, reports = {}, {}
    for split, mask in (('Train', statements.train_mask), ('Test', statements.test_mask)):
        if mask.sum() == 0:
            print(f"Warning: No samples found in {split} mask.")
            accs[split] = 0.0
            reports[split] = "No samples to evaluate."
            continue

        y_pred = predicted[mask].cpu().numpy()
        y_true = statements.y[mask].cpu().numpy()
        if len(y_true) == 0:
            print(f"Warning: No labels found for {split} mask after filtering.")
            accs[split] = 0.0
            reports[split] = "No samples to evaluate after filtering."
            continue

        # Report only the classes that actually occur in this split.
        observed = set(np.unique(np.concatenate((y_true, y_pred))).tolist())
        label_ids = [k for k in range(len(CANDIDATE_LABELS)) if k in observed]
        split_acc = accuracy_score(y_true, y_pred)
        reports[split] = classification_report(
            y_true,
            y_pred,
            labels=label_ids,
            target_names=[CANDIDATE_LABELS[k] for k in label_ids],
            zero_division=0,
        )
        accs[split] = split_acc

    return accs, reports


# ==============================================================================
# --- Main Execution ---
# ==============================================================================
if __name__ == "__main__":
    # End-to-end pipeline: load data -> embed statements -> build graph ->
    # train the heterogeneous GNN -> report test-set metrics.
    # Every early exit uses status 1 so a failed run is visible to the shell/CI.
    print("Starting LIAR Dataset GNN Classification Pipeline...")
    pipeline_start_time = time.time()
    # Seed torch so weight initialisation is repeatable, like the fixed
    # random_state already used for sampling and for the train/test split.
    torch.manual_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 1. Load and Prepare Data
    df = load_and_prep_data(SAMPLE_SIZE)
    if 'label_idx' not in df.columns or df['label_idx'].eq(-1).any():
        print("Exiting due to label errors during data preparation.")
        sys.exit(1)

    # 2. Feature Engineering (Statements)
    if 'statement' not in df.columns:
        print("Error: 'statement' column missing for feature engineering.")
        sys.exit(1)
    statement_features = create_statement_features(df['statement'])
    statement_feature_dim = statement_features.shape[1]

    # 3. Graph Construction
    # Check the first element before unpacking, so that a failure return of any
    # arity ends the run cleanly instead of raising ValueError on unpacking.
    graph_result = build_hetero_graph(df, statement_features)
    if graph_result is None or graph_result[0] is None:
        print("Exiting due to errors during graph construction.")
        sys.exit(1)
    data, entity_maps, valid_edge_types = graph_result
    data = data.to(device)  # Move graph data to the selected device

    # 4. Initialize Model
    num_classes = len(CANDIDATE_LABELS)
    model = HeteroGNN(
        hidden_channels=HIDDEN_CHANNELS,
        out_channels=num_classes,
        num_layers=NUM_GNN_LAYERS,
        entity_maps=entity_maps,
        embedding_dim=EMBEDDING_DIM,
        statement_feature_dim=statement_feature_dim,
        valid_edge_types=valid_edge_types,  # only edge types present in the graph
    ).to(device)
    print("\nModel Architecture:")
    print(model)

    # 5. Training Loop (full-batch)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print("\n--- Starting Training ---")
    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start_time = time.time()
        loss = train(model, data, optimizer, criterion)
        # Evaluate on the first and last epochs and every 5th epoch in between.
        if epoch % 5 == 0 or epoch == 1 or epoch == NUM_EPOCHS:
            accs, _ = test(model, data)
            # test() returns empty dicts when evaluation fails; show 0.0 then.
            train_acc_val = accs.get('Train', 0.0)
            test_acc_val = accs.get('Test', 0.0)
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
                  f'Train Acc: {train_acc_val:.4f}, Test Acc: {test_acc_val:.4f}, '
                  f'Time: {time.time() - epoch_start_time:.2f}s')

    print("--- Training Finished ---")

    # 6. Final Evaluation
    print("\n--- Final Evaluation ---")
    final_accs, final_reports = test(model, data)
    print("\n--- Test Set Performance ---")
    test_acc_final = final_accs.get('Test', 'N/A')
    test_report_final = final_reports.get('Test', 'Evaluation failed.')
    if isinstance(test_acc_final, float):
        print(f"Accuracy: {test_acc_final:.4f}")
    else:
        print(f"Accuracy: {test_acc_final}")
    print("Classification Report:")
    print(test_report_final)

    print(f"\nPipeline finished in {time.time() - pipeline_start_time:.2f} seconds.")
    print("--- End of Script ---")



Starting LIAR Dataset GNN Classification Pipeline...
Using device: cpu
--- 1. Loading and Preparing Data ---
Full dataset shape: (10269, 14)
Using a sample of 1000 rows.
Created 'label_text' and 'label_idx' columns.
Data loading & prep completed in 0.82 seconds.

--- 2a. Generating Statement Features (Embeddings) ---


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Statement embeddings generated with shape: torch.Size([1000, 384])
Statement feature generation completed in 17.47 seconds.

--- 3. Building Heterogeneous Graph ---
Found 557 unique entities for type 'speaker'
Found 135 unique entities for type 'subject'
Found 14 unique entities for type 'party'
Graph construction completed in 0.07 seconds.

Graph Summary:
HeteroData(
  statement={
    x=[1000, 384],
    y=[1000],
    train_mask=[1000],
    test_mask=[1000],
  },
  speaker={ num_nodes=557 },
  subject={ num_nodes=135 },
  party={ num_nodes=14 },
  (statement, spoken_by, speaker)={ edge_index=[2, 1000] },
  (speaker, rev_spoken_by, statement)={ edge_index=[2, 1000] },
  (speaker, affiliated_with, party)={ edge_index=[2, 1000] },
  (party, rev_affiliated_with, speaker)={ edge_index=[2, 1000] },
  (statement, about_subject, subject)={ edge_index=[2, 2175] },
  (subject, rev_about_subject, statement)={ edge_index=[2, 2175] }
)
Graph validation successful.

Model Architecture:
HeteroGNN(
  