In [3]:
import torch
import torch.nn as nn
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, LeakyReLU, Tanh
import torch.nn.functional as F
from torch_geometric.data import Data, Batch, DataLoader
from torch_geometric.nn import EdgeConv, global_mean_pool
import pandas as pd
import os
import logging
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error # Import for evaluation metrics
import random

# --- Set random seeds ---
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Configure Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

TARGET_NAME = 'FEV1_FVC'  
TARGET_COLS = ['FEV1_FVC']  

# --- Control variable for saving/splitting SUBJID ---
USE_SUBJID = True  # Set to True to save and split SUBJID, False otherwise

# --- Manually Input Best Parameters Here ---
best_lr = 0.001
best_hidden_channels = 128
best_embedding_dim = 32
best_num_conv_layers = 3
best_mlp1_layers_dims = [32, 32, 32]
best_mlp2_layers_dims = [64, 64, 64]
best_edgeconv_aggr = 'max'
best_use_batchnorm = True
best_activation_name = 'relu'
best_epochs = 500
BATCH_SIZE = 32

activation_map = {'relu': ReLU(), 'leaky_relu': LeakyReLU(), 'tanh': Tanh()}
best_mlp_activation = activation_map[best_activation_name]


# --- Define the Adjusted EdgeGNN Model ---
class AirwayEdgeGNN(torch.nn.Module):
    def __init__(self, num_node_features, output_channels,
                 num_conv_layers=3, mlp1_layers=[64, 64], mlp2_layers=[64, 64],
                 mlp_activation=ReLU(), edgeconv_aggr='max', use_batchnorm=False):
        super(AirwayEdgeGNN, self).__init__()
        self.convs = torch.nn.ModuleList()

        mlp1 = []
        in_channels = 2 * num_node_features
        for h_dim in mlp1_layers:
            mlp1.append(Linear(in_channels, h_dim))
            if use_batchnorm:
                mlp1.append(BatchNorm1d(h_dim))
            mlp1.append(mlp_activation)
            in_channels = h_dim
        self.convs.append(EdgeConv(nn=Sequential(*mlp1), aggr=edgeconv_aggr))
        last_out_channels = mlp1_layers[-1] if mlp1_layers else 2 * num_node_features

        for _ in range(num_conv_layers - 1):
            mlp_intermediate = []
            in_channels = 2 * last_out_channels
            for h_dim in mlp2_layers:
                mlp_intermediate.append(Linear(in_channels, h_dim))
                if use_batchnorm:
                    mlp_intermediate.append(BatchNorm1d(h_dim))
                mlp_intermediate.append(mlp_activation)
                in_channels = h_dim
            self.convs.append(EdgeConv(nn=Sequential(*mlp_intermediate), aggr=edgeconv_aggr))
            last_out_channels = mlp2_layers[-1] if mlp2_layers else 2 * last_out_channels

        self.out = Linear(last_out_channels, output_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.out(x)
        return x


# --- Data Loading Functions (same as before) ---
def load_graph_from_excel(filepath):
    nodes_df = pd.read_excel(filepath, sheet_name='Nodes')
    edges_df = pd.read_excel(filepath, sheet_name='Edges')
    nodes = nodes_df['node_id'].tolist()
    edges = list(zip(edges_df['bp0'], edges_df['bp1']))
    node_features = torch.tensor(nodes_df[['x', 'y', 'z']].values, dtype=torch.float)
    edge_features = torch.tensor(edges_df[['generation', 'length', 'diameter', 'InArea', 'OutArea', 'InPeri', 'OutPeri', 'WT', 'WA', 'Din', 'Dout', 'Cr']].values, dtype=torch.float)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)
    return data

def load_all_graphs_from_folder(folder_path):
    graph_files = [f for f in os.listdir(folder_path) if f.endswith('.xlsx')]
    airway_trees = [load_graph_from_excel(os.path.join(folder_path, f)) for f in graph_files]
    return airway_trees

def load_feature_names_from_excel(filepath, column_name='FeatureName', sheet_name='Sheet1'):
    """
    Loads feature names from a specified column in an Excel file.
    """
    try:
        df = pd.read_excel(filepath, sheet_name=sheet_name)
        if column_name in df.columns:
            return df[column_name].tolist()
        else:
            logger.error(f"Column '{column_name}' not found in {filepath}")
            return None
    except FileNotFoundError:
        logger.error(f"Feature names file not found at {filepath}")
        return None
    except Exception as e:
        logger.error(f"Error loading feature names from {filepath}: {e}")
        return None


# --- Function to train the GNN ---
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

# --- Function to evaluate the GNN and collect predictions ---
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    ground_truths = []
    filenames = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            loss = criterion(out, data.y)
            total_loss += loss.item() * data.num_graphs
            
            predictions.extend(out.cpu().numpy().flatten())
            ground_truths.extend(data.y.cpu().numpy().flatten())
            # Assuming 'filename' attribute is added to Data objects
            filenames.extend(data.filename) 
            
    avg_loss = total_loss / len(loader.dataset)
    return avg_loss, predictions, ground_truths, filenames


# --- Main Execution for Training, Feature Extraction and K-Fold Split ---
if __name__ == "__main__":
    # --- NEW DATA LOADING SECTION ---
    # This replaces the old loading of two separate Excel files
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Loading airway trees and target...")
    GRAPH_FOLDER = 'data/airways_607'
    airway_trees_raw = load_all_graphs_from_folder(GRAPH_FOLDER)
    graph_filenames = [f for f in os.listdir(GRAPH_FOLDER) if f.endswith('.xlsx')]
    graph_filenames.sort() # Ensure filenames are sorted to match the order of airway_trees_raw if loaded sequentially

    # Load the combined data file
    COMBINED_DATA_FILE = 'data/data.xlsx'
    try:
        combined_df = pd.read_excel(COMBINED_DATA_FILE)
        if combined_df.isnull().values.any():
            print("\n**Warning: NaN values are present in the combined DataFrame.**")
            rows_with_nan = combined_df[combined_df.isnull().any(axis=1)]
            print("Rows with NaN values:\n", rows_with_nan)

        # Filter out rows with 'NONE' in SUBJID
        combined_df = combined_df[combined_df['SUBJID'] != 'NONE']

    except FileNotFoundError:
        logger.error(f"Combined data file not found at {COMBINED_DATA_FILE}. Exiting.")
        exit()


    # Create mappings from filename to all data
    filename_to_all_data = {row['Filename']: row for _, row in combined_df.iterrows()}

    final_airway_trees = []
    final_targets_data = []
    final_subj_id = []
    aligned_graph_filenames = [] # To store filenames that successfully align

    # Iterate through the graph filenames to ensure correct alignment
    # and only include data for which both graph and target exist
    for i, file in enumerate(graph_filenames):
        if file in filename_to_all_data:
            row = filename_to_all_data[file]
            final_airway_trees.append(airway_trees_raw[i])
            final_targets_data.append([row[target_col] for target_col in TARGET_COLS])
            final_subj_id.append(row['SUBJID'])
            aligned_graph_filenames.append(file) # Add filename to the aligned list
        else:
            logger.warning(f"Data for file {file} not found in {COMBINED_DATA_FILE}. Skipping.")

    targets = torch.tensor(final_targets_data).float().to(device)
    airway_trees = final_airway_trees

    OUTPUT_DIR = 'models/EdgeGNN_regressor'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    MODEL_SAVE_PATH = f'{OUTPUT_DIR}/EdgeGNN_regressor.pth'
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_node_features = airway_trees[0].x.shape[1]
    output_dim = 1

    target_idx = TARGET_COLS.index(TARGET_NAME)
    processed_airway_trees = [
        Data(x=tree.x, edge_index=tree.edge_index, edge_attr=tree.edge_attr, 
             y=torch.tensor([targets[i, target_idx]]), filename=aligned_graph_filenames[i]) # Pass filename
        for i, tree in enumerate(airway_trees)
    ]
    loader_train_all = DataLoader(processed_airway_trees, batch_size=BATCH_SIZE, shuffle=True)

    # --- Initialize and Train the GNN ---
    model_train = AirwayEdgeGNN(
        num_node_features=num_node_features,
        output_channels=output_dim,
        num_conv_layers=best_num_conv_layers,
        mlp1_layers=best_mlp1_layers_dims,
        mlp2_layers=best_mlp2_layers_dims,
        mlp_activation=best_mlp_activation,
        edgeconv_aggr=best_edgeconv_aggr,
        use_batchnorm=best_use_batchnorm
    ).to(device)
    optimizer = torch.optim.Adam(model_train.parameters(), lr=best_lr)
    criterion = nn.MSELoss()

    if not os.path.exists(MODEL_SAVE_PATH):
        logger.info("Starting GNN training...")
        for epoch in range(best_epochs):
            loss = train(model_train, loader_train_all, optimizer, criterion, device)
            logger.info(f"Epoch {epoch+1}/{best_epochs}, Loss: {loss:.4f}")
        torch.save(model_train.state_dict(), MODEL_SAVE_PATH)
        logger.info(f"Finished GNN training. Model saved to {MODEL_SAVE_PATH}")
    else:
        try:
            model_train.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
            logger.info(f"Loaded pre-trained GNN model from {MODEL_SAVE_PATH}")
        except RuntimeError as e:
            logger.error(f"Error loading pre-trained model: {e}")
            logger.info("Using a newly initialized model and retraining it.")
            for epoch in range(best_epochs):
                loss = train(model_train, loader_train_all, optimizer, criterion, device)
                logger.info(f"Epoch {epoch+1}/{best_epochs}, Loss: {loss:.4f}")
            torch.save(model_train.state_dict(), MODEL_SAVE_PATH)
            logger.info(f"Finished GNN retraining. Model saved to {MODEL_SAVE_PATH}")




Loading airway trees and target...


2025-08-25 14:06:28,916 - INFO - Starting GNN training...
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
2025-08-25 14:06:29,274 - INFO - Epoch 1/500, Loss: 0.3590
2025-08-25 14:06:29,364 - INFO - Epoch 2/500, Loss: 0.0189
2025-08-25 14:06:29,450 - INFO - Epoch 3/500, Loss: 0.0154
2025-08-25 14:06:29,535 - INFO - Epoch 4/500, Loss: 0.0152
2025-08-25 14:06:29,623 - INFO - Epoch 5/500, Loss: 0.0153
2025-08-25 14:06:29,709 - INFO - Epoch 6/500, Loss: 0.0150
2025-08-25 14:06:29,794 - INFO - Epoch 7/500, Loss: 0.0150
2025-08-25 14:06:29,883 - INFO - Epoch 8/500, Loss: 0.0151
2025-08-25 14:06:29,969 - INFO - Epoch 9/500, Loss: 0.0150
2025-08-25 14:06:30,055 - INFO - Epoch 10/500, Loss: 0.0150
2025-08-25 14:06:30,141 - INFO - Epoch 11/500, Loss: 0.0151
2025-08-25 14:06:30,225 - INFO - Epoch 12/500, Loss: 0.0149
2025-08-25 14:06:30,310 - INFO - Epoch 13/500, Loss: 0.0150
2025-08-25 14:06:30,396 - INFO - Epoch 14/500, Lo