# Part 4: Model Training

This notebook loads the prepared graph data, splits it into training, validation, and test sets, and trains and evaluates the GNN model.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
from torch_geometric.loader import DataLoader
import plotly.graph_objects as go
from tqdm.notebook import tqdm

from src.models.gnn_model import GNNModel
from src.models.loss import WeightedMSELoss
from src.data.dataset import GraphDataset

In [None]:
def load_graph_data(path='data/graph_data.pt', map_location=None):
    """Load the prepared graph data saved by an earlier pipeline step.

    Args:
        path: File handed to ``torch.load``. Defaults to the previously
            hard-coded ``'data/graph_data.pt'``. A relative path is resolved
            against the current working directory, not against this notebook,
            so pass an explicit path when running from another directory.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'``
            loads tensors that were saved on a GPU onto a CPU-only machine.
            ``None`` keeps torch's default placement.

    Returns:
        The loaded object. The prints below require it to support
        ``data['features']``, ``data['edge_index']`` and ``data['los_values']``,
        each exposing ``.shape``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if one of the three entries above is missing.
    """
    data = torch.load(path, map_location=map_location)
    print("\nLoaded data shapes:")
    print(f"Features: {data['features'].shape}")
    print(f"Edge index: {data['edge_index'].shape}")
    print(f"LOS values: {data['los_values'].shape}")
    return data

In [None]:
def create_data_loaders(data, batch_size=32, train_frac=0.7, val_frac=0.15, seed=None):
    """Randomly split the samples into train/val/test sets and wrap each in a loader.

    Args:
        data: Mapping with ``'features'``, ``'edge_index'`` and ``'los_values'``
            entries (as produced by ``load_graph_data``). The number of samples
            is taken as ``len(data['features'])``.
        batch_size: Batch size used by all three loaders.
        train_frac: Fraction of samples assigned to the training set
            (default 0.7, as before).
        val_frac: Fraction of samples assigned to the validation set
            (default 0.15, as before). The test set receives every remaining
            sample.
        seed: Optional integer seed for the random permutation. ``None`` (the
            default) draws from torch's global RNG exactly as the original code
            did. Pass an int to get the same split on every run, which is
            needed to compare models or re-run evaluation on identical data.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles its batches.

    Raises:
        ValueError: if either fraction is negative or they sum to more than 1.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    n_samples = len(data['features'])

    # A dedicated generator makes the split reproducible without touching the
    # global RNG state that the rest of training may depend on.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    indices = torch.randperm(n_samples, generator=generator)

    # Sizes are truncated with int(); any leftover samples go to the test set.
    train_end = int(train_frac * n_samples)
    val_end = train_end + int(val_frac * n_samples)
    split_indices = (indices[:train_end], indices[train_end:val_end], indices[val_end:])

    # Every split shares the full feature/edge/label tensors; GraphDataset is
    # given only the subset indices that belong to that split.
    train_dataset, val_dataset, test_dataset = (
        GraphDataset(data['features'], data['edge_index'], data['los_values'], subset)
        for subset in split_indices
    )

    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(val_dataset, batch_size=batch_size),
        DataLoader(test_dataset, batch_size=batch_size),
    )

In [None]:
if __name__ == "__main__":
    # Driver cell: load the data, build the loaders, and instantiate the model.
    # NOTE(review): inside a Jupyter kernel __name__ is normally "__main__", so
    # this body runs whenever the cell is executed; it is skipped only when the
    # file is imported as a module. Confirm this is the intended behaviour.

    # Load the prepared graph data (load_graph_data reads 'data/graph_data.pt'
    # relative to the current working directory).
    data = load_graph_data()
    
    # Randomly split into train/val/test loaders with the defaults: 70% train,
    # 15% validation, remainder test, batch size 32. The split is unseeded, so
    # it differs on every execution.
    train_loader, val_loader, test_loader = create_data_loaders(data)
    
    # The model's input width is the second dimension of data['features'].
    input_dim = data['features'].shape[1]
    model = GNNModel(input_dim=input_dim)
    
    # NOTE(review): the training/evaluation loop is not part of this cell.
    # Every name bound above (data, the three loaders, input_dim, model) is a
    # module-level global that later cells may rely on, so do not rename them.