# Graph Neural Network for Length of Stay Prediction

This notebook implements a Graph Neural Network (GNN) to predict patient length of stay using PyTorch Geometric.

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import plotly.graph_objects as go
import networkx as nx
import umap
import shap
from tqdm import tqdm
import optuna

## 1. Data Loading and Preprocessing

In [None]:
# Read the modelling table from CSV into a DataFrame.
df = pd.read_csv('model_df_12_02_24.csv')

# Column names grouped by category. Each entry must match a CSV header
# exactly (NOTE(review): 'brain_injury_mild/_moderate' looks unusual —
# confirm it matches the file's header verbatim).
demographic_features = [
    'admit_age',
    'admit_bmi',
]
clinical_features = [
    'brain_injury_mild/_moderate',
    'brain_injury_severe',
    'vent_at_admission',
    'trach_patient',
    'motor_score',
]
comorbidity_features = [
    'diabetes',
    'hypertension',
    'heart_disease',
    'neurological_disorder',
    'psychiatric_disorder',
]

# Feature selection using SHAP values will be implemented here
def select_features_shap(df, target_col):
    # Implementation of SHAP-based feature selection
    pass

# Create graph structure
def create_patient_graph(features, edge_threshold=0.5):
    """Build a graph connecting patients by feature similarity.

    Not implemented yet. ``edge_threshold`` is presumably the similarity
    cut-off above which two patients are linked — confirm when implementing.

    Args:
        features: Per-patient feature data.
        edge_threshold: Similarity threshold for creating an edge.

    Raises:
        NotImplementedError: Always, until graph construction is written.
    """
    # Fail loudly at the call site instead of silently returning None.
    raise NotImplementedError("create_patient_graph is not implemented yet")

## 2. Graph Neural Network Model

In [None]:
class GNNModel(torch.nn.Module):
    """Two GATv2 layers, mean pooling over each graph, then a 2-layer MLP head.

    ``forward`` returns one scalar per graph in ``batch``, shape
    ``(num_graphs, 1)``.
    """

    def __init__(self, input_dim, hidden_dim=64, num_heads=4, dropout=0.5):
        """Create the layers.

        Args:
            input_dim: Number of input features per node.
            hidden_dim: Output width of each attention head and of the MLP.
            num_heads: Attention heads in the first GATv2 layer. Their outputs
                are concatenated, hence the ``hidden_dim * num_heads`` input
                width of the second layer.
            dropout: Dropout probability applied after the first GATv2 layer
                while training. The default 0.5 matches the original fixed
                value; exposing it allows hyperparameter tuning.
        """
        super().__init__()
        # Plain float attribute: does not affect parameter registration or
        # state_dict keys, so existing checkpoints still load.
        self.dropout = dropout
        self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=num_heads)
        self.conv2 = GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=1)
        self.lin1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        """Return per-graph predictions of shape ``(num_graphs, 1)``.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both GATv2 layers.
            batch: Graph-membership index of each node, used for pooling.
        """
        h = F.elu(self.conv1(x, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.conv2(h, edge_index)
        # NOTE(review): pooling produces one prediction per graph. If each
        # patient is a node in a single similarity graph (cf.
        # create_patient_graph), per-patient length of stay would need
        # node-level outputs instead — confirm the intended setup.
        h = global_mean_pool(h, batch)
        h = F.elu(self.lin1(h))
        return self.lin2(h)

## 3. Training and Evaluation

In [None]:
def train_model(model, train_loader, optimizer, device):
    """Run one training epoch with an MSE objective.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        train_loader: Iterable of batch objects exposing ``.to(device)``,
            ``.x``, ``.edge_index``, ``.batch`` and ``.y`` (e.g. PyG batches).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch losses for this epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        # Flatten both sides: the model emits shape (N, 1) while targets are
        # commonly (N,); mse_loss would otherwise broadcast them to (N, N)
        # and silently compute the wrong loss. Casting avoids float/double
        # mismatches when targets come from float64 data.
        pred = out.reshape(-1)
        target = data.y.reshape(-1).to(pred.dtype)
        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

## 4. 3D Visualization

In [None]:
def visualize_graph_3d(graph, embeddings, labels, *, random_state=42, show=True):
    """Project embeddings to 3-D with UMAP and plot them as a Plotly scatter.

    Args:
        graph: Currently unused; kept for interface compatibility.
        embeddings: Input to ``umap.UMAP.fit_transform`` (presumably an
            (n_points, dim) array — confirm).
        labels: Per-point values used to colour the markers.
        random_state: Seed forwarded to UMAP so the layout is reproducible.
            Defaults to the original hard-coded 42.
        show: When True (default), display the figure and return None,
            exactly as before. When False, return the figure without
            displaying it, so callers can save or restyle it.

    Returns:
        The ``go.Figure`` if ``show`` is False, otherwise None.
    """
    reducer = umap.UMAP(n_components=3, random_state=random_state)
    embedding_3d = reducer.fit_transform(embeddings)

    # Create 3D visualization using plotly
    fig = go.Figure(data=[go.Scatter3d(
        x=embedding_3d[:, 0],
        y=embedding_3d[:, 1],
        z=embedding_3d[:, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=labels,
            colorscale='Viridis',
            opacity=0.8
        )
    )])

    fig.update_layout(title='3D Patient Graph Visualization')
    if not show:
        # Return instead of showing. Returning after show() would make
        # notebooks render the figure twice when the call ends a cell.
        return fig
    fig.show()
    return None

## 5. Model Explainability

In [None]:
def explain_predictions(model, data):
    """Compute SHAP values for ``model`` on ``data`` and draw a summary plot.

    Args:
        model: Model handed to ``shap.DeepExplainer``.
        data: Used as the explainer's background data, as the samples to
            explain, and as the feature values for the summary plot.

    Returns:
        None; results are only displayed via ``shap.summary_plot``.
    """
    # NOTE(review): GNNModel.forward takes (x, edge_index, batch), whereas
    # DeepExplainer evaluates the model on tensor input(s) derived from
    # `data`. If `data` is a PyG Data object this call likely fails —
    # confirm. A possible fix is a wrapper module that holds edge_index/batch
    # fixed, or PyG's own explainability tools.
    # Implementation of SHAP values for GNN
    explainer = shap.DeepExplainer(model, data)
    # NOTE(review): the same `data` serves as both background and explained
    # samples, so attributions are relative to its own expected output —
    # confirm this is intended.
    shap_values = explainer.shap_values(data)
    
    # Visualization of feature importance
    shap.summary_plot(shap_values, data)