# In [None]:
import numpy as np
import pandas as pd
import os, sys
import h5py # Not strictly needed if not reading/writing HDF5, but can keep
import matplotlib.pyplot as plt # Not used in the provided snippet
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import nibabel as nib # For reading .nii files

import CNN_model_extractor # Assuming this module defines CNN_model()

# Enumerate CUDA devices in PCI bus order and expose only device "0" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Switch every visible GPU to on-demand memory growth. This runs at import time,
# ahead of any other TensorFlow work in this file.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs configured for memory growth.")
    except RuntimeError as e:
        # Raised when memory growth is requested after the GPUs were already
        # initialized; report it and carry on with the default allocation policy.
        print(e)


def _load_fold_volumes(paths, indices, expected_shape, short_name, set_name):
    """Load the volumes at ``indices`` and report which of them were usable.

    Args:
        paths (list[str]): Image file paths, indexable by the values in ``indices``.
        indices (iterable[int]): Positions in ``paths`` to load.
        expected_shape (tuple[int, ...]): Shape every accepted volume must have.
        short_name (str): Split name used in the load-error message.
        set_name (str): Split name used in the shape-mismatch message.

    Returns:
        tuple[list[np.ndarray], list[int]]: The float32 volumes that loaded with
        the expected shape, and the matching positions in ``paths``. The two
        lists always have the same length, so labels can be selected with the
        second one without drifting out of step with the images.
    """
    volumes = []
    kept_indices = []
    for i in indices:
        img_path = paths[i]
        try:
            img_data = nib.load(img_path).get_fdata().astype(np.float32)
        except Exception as e:
            # Best-effort loading: one unreadable file must not abort the fold.
            print(f"Error loading {short_name} image {img_path}: {e}. Skipping.")
            continue
        if img_data.shape != expected_shape:
            print(f"Skipping {img_path} in {set_name} set: unexpected shape {img_data.shape}. Expected {expected_shape}")
            continue
        volumes.append(img_data)
        kept_indices.append(i)
    return volumes, kept_indices


def _min_max_scale(volumes):
    """Scale an array to [0, 1] with its own global min/max (zeros if constant)."""
    lo, hi = np.min(volumes), np.max(volumes)
    if (hi - lo) != 0:
        return (volumes - lo) / (hi - lo)
    return np.zeros_like(volumes)


def trainCNN(directory, prm_nii_dir, labels_csv_path, nb_iters=500, prm_filename_col='SUBID', target_col='End'):
    """
    Train a 3D CNN with 5-fold stratified cross-validation on PRM images.

    Images are read from .nii / .nii.gz files and labels from a CSV file. For
    every fold the model with the best validation accuracy is saved to
    ``<directory>/model/CNN/model_CNN_fold<k>_best.h5``, reloaded, and scored;
    per-fold and averaged metrics are written to
    ``<directory>/CNN_kfold_evaluation_results.txt``.

    Args:
        directory (str): Base directory for saving models and results.
        prm_nii_dir (str): Directory containing the .nii PRM image files.
        labels_csv_path (str): Path to the CSV file holding subject IDs and labels.
        nb_iters (int): Number of single-epoch fit/evaluate rounds per fold.
        prm_filename_col (str): CSV column whose value is the image file stem.
        target_col (str): CSV column holding the label.

    Returns:
        None. Returns early, before creating any output, when no image listed
        in the CSV can be found on disk.
    """
    print("Starting direct loading and CNN training...")

    # 1. Read labels from CSV and resolve each subject's image path.
    labels_df = pd.read_csv(labels_csv_path)

    CNN_data_paths = []
    CNN_labels = []

    # Iterate column-wise rather than with iterrows(): iterrows() builds one
    # Series per row and may upcast values (e.g. an integer ID becomes 1.0),
    # which would produce a wrong filename.
    for subject_id, label in zip(labels_df[prm_filename_col], labels_df[target_col]):
        if pd.isna(label):
            # A missing label cannot be stratified on or trained against.
            print(f"Warning: missing label for {subject_id}. Skipping.")
            continue

        # Prefer .nii, fall back to .nii.gz.
        nii_file_path = os.path.join(prm_nii_dir, f"{subject_id}.nii")
        if not os.path.exists(nii_file_path):
            nii_file_path = os.path.join(prm_nii_dir, f"{subject_id}.nii.gz")

        if os.path.exists(nii_file_path):
            CNN_data_paths.append(nii_file_path)
            CNN_labels.append(label)
        else:
            print(f"Warning: PRM file not found for {subject_id} at {nii_file_path}. Skipping.")

    if not CNN_data_paths:
        print("No valid PRM images found based on CSV. Exiting training.")
        return

    CNN_labels = np.array(CNN_labels)

    # Expected 3D shape of every image; anything else is skipped.
    expected_image_shape = (64, 64, 64)

    # Setup output directories
    model_dir = os.path.join(directory, "model", "CNN")
    os.makedirs(directory, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=5)
    n_iter_CNN = nb_iters

    # Evaluation results, one dict per successfully evaluated fold
    all_fold_results = []

    # Iterate through K-folds, splitting on file paths and labels
    for j, (idx_train, idx_val) in enumerate(kfold.split(CNN_data_paths, CNN_labels), start=1):
        train_volumes, train_kept = _load_fold_volumes(
            CNN_data_paths, idx_train, expected_image_shape, "train", "train")
        val_volumes, val_kept = _load_fold_volumes(
            CNN_data_paths, idx_val, expected_image_shape, "val", "validation")

        if not train_volumes or not val_volumes:
            print(f"Skipping fold {j}: Not enough data loaded for training or validation.")
            continue

        # Select labels with the indices that actually loaded. Using the raw
        # fold indices would misalign labels and images as soon as one image
        # is skipped. Labels are expanded to shape (N, 1).
        current_train_labels = np.expand_dims(CNN_labels[train_kept], axis=1)
        current_val_labels = np.expand_dims(CNN_labels[val_kept], axis=1)

        # Normalize data to [0, 1] range.
        # NOTE(review): the validation array is scaled with its own min/max
        # rather than the training statistics, so the two splits are not
        # guaranteed to share an intensity scale.
        train_CNN_np = _min_max_scale(np.array(train_volumes))
        val_CNN_np = _min_max_scale(np.array(val_volumes))

        print('\n\n###################CNN part start!!!!###################\n\n')
        print('Train min=%.3f, max=%.3f' % (np.min(train_CNN_np), np.max(train_CNN_np)))

        # Drop Keras state left over from the previous fold before building a
        # fresh model, so graph state does not accumulate across folds.
        tf.keras.backend.clear_session()
        model = CNN_model_extractor.CNN_model()

        # Append a trailing channel axis: (N, 64, 64, 64) -> (N, 64, 64, 64, 1)
        train_CNN_reshaped = train_CNN_np.reshape(-1, *expected_image_shape, 1)
        val_CNN_reshaped = val_CNN_np.reshape(-1, *expected_image_shape, 1)

        best_accuracy_fold = 0
        best_model_path_fold = ""
        save_path = os.path.join(model_dir, f'model_CNN_fold{j}_best.h5')

        # One fit epoch followed by one validation pass, n_iter_CNN times.
        for epoch in range(1, n_iter_CNN + 1):
            # Adjust batch_size here for training if needed due to OOM
            model.fit(train_CNN_reshaped, current_train_labels, batch_size=8, epochs=1, verbose=1)
            # Adjust batch_size here for evaluation if needed due to OOM
            evaluate_model = model.evaluate(val_CNN_reshaped, current_val_labels, batch_size=8, verbose=0)

            # Assumes index 1 is accuracy, i.e. the model was compiled with
            # accuracy as its first metric — TODO confirm in CNN_model_extractor.
            val_accuracy = evaluate_model[1]

            # "<=" means a tie overwrites the checkpoint with the later epoch.
            if best_accuracy_fold <= val_accuracy:
                best_accuracy_fold = val_accuracy
                print(f"################### Fold {j} - Epoch {epoch} - New Best Accuracy: {best_accuracy_fold} ###################")

                model.save(save_path)
                best_model_path_fold = save_path
                print(f"Model saved to {save_path}")

        # After training for the fold, load the best model and evaluate.
        # NOTE(review): the checkpoint is selected and scored on the same
        # validation fold, so the reported metrics are optimistic.
        if os.path.exists(best_model_path_fold):
            print(f"Loading best model for fold {j} from {best_model_path_fold} for final evaluation.")
            loaded_model = tf.keras.models.load_model(best_model_path_fold)

            # Assumes binary classification with a single sigmoid output.
            y_pred_probs = loaded_model.predict(val_CNN_reshaped, batch_size=8)
            y_pred_classes = (y_pred_probs > 0.5).astype(int)

            # Flatten for sklearn metrics
            true_labels = current_val_labels.flatten()
            predicted_classes = y_pred_classes.flatten()
            predicted_probs = y_pred_probs.flatten()

            try:
                roc_auc = roc_auc_score(true_labels, predicted_probs)
            except ValueError as e:
                # roc_auc_score is undefined when the fold holds a single class;
                # record NaN instead of losing every fold's results.
                print(f"Fold {j}: ROC AUC could not be computed ({e}). Recording NaN.")
                roc_auc = float('nan')

            fold_results = {
                'fold': j,
                'accuracy': accuracy_score(true_labels, predicted_classes),
                'precision': precision_score(true_labels, predicted_classes, zero_division=0),
                'recall': recall_score(true_labels, predicted_classes, zero_division=0),
                'f1_score': f1_score(true_labels, predicted_classes, zero_division=0),
                'roc_auc': roc_auc,
                'confusion_matrix': confusion_matrix(true_labels, predicted_classes).tolist()  # list for easier saving
            }
            all_fold_results.append(fold_results)
            print(f"Fold {j} Evaluation Results: {fold_results}")
        else:
            print(f"No best model saved for fold {j}. Skipping evaluation for this fold.")

    # (result key, label written to the report) for every scalar metric
    metric_titles = [
        ('accuracy', 'Accuracy'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('f1_score', 'F1-Score'),
        ('roc_auc', 'ROC AUC'),
    ]

    # Save all fold results to a text file (written even when no fold succeeded)
    results_file_path = os.path.join(directory, "CNN_kfold_evaluation_results.txt")
    with open(results_file_path, 'w') as f:
        f.write("CNN K-Fold Cross-Validation Results:\n\n")
        for fold_res in all_fold_results:
            f.write(f"--- Fold {fold_res['fold']} ---\n")
            for key, title in metric_titles:
                f.write(f"{title}: {fold_res[key]:.4f}\n")
            f.write(f"Confusion Matrix: {fold_res['confusion_matrix']}\n")
            f.write("\n")

        # Calculate and save average metrics
        if all_fold_results:
            f.write("\n--- Average Metrics Across Folds ---\n")
            for key, title in metric_titles:
                # nanmean so a fold with an undefined ROC AUC does not turn
                # the whole average into NaN.
                average = np.nanmean([res[key] for res in all_fold_results])
                f.write(f"Average {title}: {average:.4f}\n")

    print(f"Evaluation results saved to {results_file_path}")
    print("CNN training and evaluation complete.")

def extract_cnn_features(model_path, prm_nii_dir, labels_csv_path, output_dir,
                         prm_filename_col='SUBID', target_col='End',
                         feature_layer_name='dense_1',  # Or another layer name from your CNN_model
                         normalization='per_image'):
    """
    Extract features from PRM images with a trained CNN and save them to CSV.

    The model at ``model_path`` is loaded, truncated at ``feature_layer_name``,
    and run over every image listed in the labels CSV that exists on disk and
    has shape (64, 64, 64). The result is written to
    ``<output_dir>/cnn_features_<model file stem>.csv`` with the columns
    'SUBID', 'Label' and one 'cPRM_<i>' column per feature.

    Args:
        model_path (str): Path to the trained .h5 CNN model file.
        prm_nii_dir (str): Directory containing the .nii PRM image files.
        labels_csv_path (str): Path to the CSV file holding subject IDs and labels.
        output_dir (str): Directory where the extracted features CSV will be saved.
        prm_filename_col (str): Column name in CSV for subject ID.
        target_col (str): Column name in CSV for the target label.
        feature_layer_name (str): The name of the layer from which to extract
            features. Print ``base_model.summary()`` to see layer names.
        normalization (str): 'per_image' (default, the historical behaviour)
            min-max scales every image with its own min/max. 'global' scales
            all loaded images with one shared min/max, which is the way
            trainCNN scales an array of training images.

    Returns:
        None. Problems with the model file, the layer name or the input data
        are reported with a message and an early return.

    Raises:
        ValueError: If ``normalization`` is not 'per_image' or 'global'.
    """
    if normalization not in ('per_image', 'global'):
        raise ValueError(
            f"normalization must be 'per_image' or 'global', got {normalization!r}")

    def _scale(arr):
        # Min-max scale to [0, 1]; a constant array maps to all zeros.
        lo, hi = np.min(arr), np.max(arr)
        if (hi - lo) != 0:
            return (arr - lo) / (hi - lo)
        return np.zeros_like(arr)

    print(f"\n--- Starting Feature Extraction using model: {os.path.basename(model_path)} ---")

    # Load the trained model
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}.")
        return

    try:
        base_model = tf.keras.models.load_model(model_path)
        print(f"Successfully loaded model from {model_path}")
    except Exception as e:
        print(f"Error loading model from {model_path}: {e}")
        return

    # Build a model that stops at the requested layer.
    try:
        feature_extractor = tf.keras.Model(
            inputs=base_model.inputs,
            outputs=base_model.get_layer(feature_layer_name).output
        )
        print(f"Feature extractor created using layer: '{feature_layer_name}'")
    except ValueError as e:
        print(f"Error creating feature extractor: {e}")
        print("Please check if 'feature_layer_name' is a valid layer in your model.")
        base_model.summary()  # Print summary to help user identify layer names
        return

    # Prepare data paths and labels
    labels_df = pd.read_csv(labels_csv_path)

    CNN_data_paths = []
    CNN_labels = []
    CNN_subids = []

    # Column-wise iteration keeps each column's dtype; iterrows() may upcast
    # an integer ID to float and so build a wrong filename.
    for subject_id, label in zip(labels_df[prm_filename_col], labels_df[target_col]):
        nii_file_path = os.path.join(prm_nii_dir, f"{subject_id}.nii")
        if not os.path.exists(nii_file_path):
            nii_file_path = os.path.join(prm_nii_dir, f"{subject_id}.nii.gz")

        if os.path.exists(nii_file_path):
            CNN_data_paths.append(nii_file_path)
            CNN_labels.append(label)
            CNN_subids.append(subject_id)
        else:
            print(f"Warning: PRM file not found for {subject_id}. Skipping for feature extraction.")

    if not CNN_data_paths:
        print("No valid PRM images found based on CSV. Exiting feature extraction.")
        return

    # Load and preprocess images; IDs and labels are collected in lockstep so
    # that skipped images never shift the rows of the output table.
    extracted_images_list = []
    extracted_subids = []
    extracted_labels = []
    expected_image_shape = (64, 64, 64)

    print(f"Loading {len(CNN_data_paths)} images for feature extraction...")
    for img_path, subject_id, label in zip(CNN_data_paths, CNN_subids, CNN_labels):
        try:
            img_data = nib.load(img_path).get_fdata().astype(np.float32)
        except Exception as e:
            # Best-effort loading: one unreadable file must not abort the run.
            print(f"Error loading image {img_path}: {e}. Skipping.")
            continue

        if img_data.shape != expected_image_shape:
            print(f"Skipping {img_path}: unexpected shape {img_data.shape}. Expected {expected_image_shape}")
            continue

        if normalization == 'per_image':
            img_data = _scale(img_data)

        extracted_images_list.append(img_data)
        extracted_subids.append(subject_id)
        extracted_labels.append(label)

    if not extracted_images_list:
        print("No images successfully loaded for feature extraction. Exiting.")
        return

    images_np = np.array(extracted_images_list)
    if normalization == 'global':
        images_np = _scale(images_np)

    # Append a trailing channel axis: (N, 64, 64, 64) -> (N, 64, 64, 64, 1)
    images_reshaped = images_np.reshape(-1, *expected_image_shape, 1)

    # Extract features
    print("Extracting features...")
    # Adjust batch_size here for prediction if needed due to OOM
    features = feature_extractor.predict(images_reshaped, batch_size=8)
    print(f"Features shape: {features.shape}")

    # Collapse any spatial/channel axes so each image yields one feature row.
    if features.ndim > 2:
        features_flat = features.reshape(features.shape[0], -1)
    else:
        features_flat = features

    feature_column_names = [f'cPRM_{i}' for i in range(features_flat.shape[1])]

    features_df = pd.DataFrame(features_flat, columns=feature_column_names)
    features_df.insert(0, 'SUBID', extracted_subids)
    features_df.insert(1, 'Label', extracted_labels)  # Include the label

    # Save to CSV. splitext drops only the trailing extension, whereas
    # str.replace(".h5", "") would remove every occurrence in the name.
    os.makedirs(output_dir, exist_ok=True)
    model_stem = os.path.splitext(os.path.basename(model_path))[0]
    output_filename = os.path.join(output_dir, f'cnn_features_{model_stem}.csv')
    features_df.to_csv(output_filename, index=False)
    print(f"Extracted features saved to {output_filename}")
    print("--- Feature Extraction Complete ---")

if __name__ == '__main__':
    # Locations used by this run; edit to match the local layout.
    base_dir = 'models/cPRM_extractor'
    nii_dir = 'data/PRM_607'      # Folder holding CEMENT001.nii, CEMENT002.nii etc.
    csv_path = 'data/data.csv'    # CSV with 'SUBID' and 'End' columns

    # Step 1: cross-validated training; writes one best model per fold.
    print("\n--- Running trainCNN ---")
    trainCNN(directory=base_dir, prm_nii_dir=nii_dir, labels_csv_path=csv_path)
    print("--- Finished trainCNN ---")

    # Step 2: feature extraction with the best model saved for fold 1.
    print("\n--- Running Feature Extraction ---")
    fold1_model = os.path.join(base_dir, "model", "CNN", 'model_CNN_fold1_best.h5')
    features_dir = os.path.join(base_dir, "extracted_features")

    extract_cnn_features(
        model_path=fold1_model,
        prm_nii_dir=nii_dir,
        labels_csv_path=csv_path,
        output_dir=features_dir,
        feature_layer_name='feature_layer_128',
    )
    print("--- Finished Feature Extraction ---")
