<a href="https://colab.research.google.com/github/bitamass/330-Assignment-7/blob/main/Copy_of_Complete_Synapse_Classification_from_EM_Images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Synapse Classification from Electron Microscopy Images
# Based on the Microns Explorer dataset

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.backend as K
import cv2
from tqdm import tqdm
import h5py
import glob
import pickle
from scipy import ndimage as ndi
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.cm as cm
import gc
import warnings
from functools import partial
import sys
import time
import logging
import json

# Try to import the MICrONS access API, but continue if it is not installed.
# On failure the name is bound to None so later code sees an explicit
# "unavailable" sentinel instead of an undefined name (a NameError that only
# a broad ``except Exception`` would happen to catch).
try:
    from microns_utils.access_microns import MicronsCatalog
except ImportError:
    MicronsCatalog = None
    print("Warning: microns_utils not available. Will use placeholder data instead.")

# Seed NumPy's global RNG and TensorFlow's global seed so runs are repeatable.
np.random.seed(42)
tf.random.set_seed(42)
# Suppress warnings of the UserWarning category.
warnings.filterwarnings('ignore', category=UserWarning)



In [2]:
class SynapseClassifier:
    """Holds configuration and training state for the synapse-classification pipeline."""

    def __init__(self, data_path, output_path, image_size=(224, 224), batch_size=32):
        """
        Record paths, input geometry and bookkeeping for a classifier run.

        Args:
            data_path (str): Location of the Microns Explorer dataset.
            output_path (str): Directory for models and results; it is created
                (together with any missing parents) if it does not exist yet.
            image_size (tuple): Target size that input images are resized to.
            batch_size (int): Number of samples per training batch.
        """
        # Where data is read from and where artefacts are written.
        self.data_path, self.output_path = data_path, output_path
        # Input geometry and training batch size.
        self.image_size = image_size
        self.batch_size = batch_size
        # No model exists until one is built or loaded.
        self.model = None

        # exist_ok keeps repeated construction against one directory safe.
        os.makedirs(output_path, exist_ok=True)

        # One empty history list per tracked training/validation series.
        tracked_series = ('train_loss', 'val_loss', 'train_acc', 'val_acc')
        self.metrics_history = {series: [] for series in tracked_series}

In [5]:
def load_microns_data(self, sample_limit=None):
    """
    Load and preprocess data from the Microns Explorer dataset

    Synapses are read in batches of 100 from the ``'synapses'`` collection of a
    ``MicronsCatalog`` opened on ``self.data_path``. If the catalog is
    unavailable or any error escapes the loading loop, synthetic placeholder
    images are generated instead (any partially loaded real data is dropped).
    The fitted ``LabelEncoder`` is pickled to ``<output_path>/label_encoder.pkl``.

    Args:
        sample_limit (int, optional): Limit the number of samples to load for testing.
            A falsy value (None or 0) means no limit for real data and 1000
            samples for placeholder data.

    Returns:
        tuple: ``(images, encoded_labels)`` -- a numpy array of preprocessed
        images and the integer-encoded "excitatory"/"inhibitory" labels.

    Raises:
        ValueError: If no valid image could be loaded or generated.
    """
    print("Loading Microns Explorer dataset...")

    images = []
    labels = []

    try:
        # Fail fast (into the fallback below) if the optional import did not
        # succeed, whether the name is None or was never bound at all.
        if MicronsCatalog is None:
            raise ImportError("microns_utils is not installed")
        catalog = MicronsCatalog(self.data_path)
        synapse_collection = catalog.get_collection('synapses')

        # Query the collection size once rather than on every batch.
        total = synapse_collection.count()
        batch_size = 100  # Process 100 synapses at a time to bound memory use
        limit_reached = False

        for start in tqdm(range(0, total, batch_size), desc="Loading synapse batches"):
            synapse_batch = synapse_collection.get_items(start, min(start + batch_size, total))

            for synapse in synapse_batch:
                try:
                    em_image = synapse.get_em_image()

                    if hasattr(synapse, 'is_excitatory'):
                        is_excitatory = synapse.is_excitatory
                    else:
                        # Heuristic from vesicle shape: round vesicles (high
                        # circularity) -> excitatory, flattened -> inhibitory.
                        vesicle_props = synapse.get_vesicle_properties()
                        is_excitatory = vesicle_props.get('mean_circularity', 0.5) > 0.7

                    processed_img = self._preprocess_image(em_image)

                    # Only add valid images
                    if processed_img is not None:
                        images.append(processed_img)
                        labels.append("excitatory" if is_excitatory else "inhibitory")

                    del em_image  # Drop the raw image early to lower peak memory

                except Exception as e:
                    print(f"Error processing synapse: {e}")

                # Stop as soon as the limit is met instead of fetching (and
                # then discarding) the rest of the batch.
                if sample_limit and len(images) >= sample_limit:
                    limit_reached = True
                    break

            gc.collect()  # Free memory after each batch
            if limit_reached:
                break

    except Exception as e:
        # Deliberately broad: any import or runtime failure falls back to
        # synthetic data so the rest of the notebook can still run.
        print(f"Warning: MicronsCatalog not available or failed to initialize: {e}")
        if images:
            print(f"Discarding {len(images)} partially loaded real samples")
        print("Using placeholder data instead")

        n_samples = min(1000, sample_limit) if sample_limit else 1000
        images = []
        labels = []

        for _ in tqdm(range(n_samples), desc="Generating placeholder data"):
            # Random grayscale background in [0, 255)
            img = np.random.rand(256, 256) * 255
            is_excitatory = np.random.random() > 0.4  # Slightly biased toward excitatory

            # OpenCV drawing calls get plain Python ints (not NumPy scalars);
            # the RNG call order is unchanged.
            if is_excitatory:
                # Round, vesicle-like blobs
                for _blob in range(np.random.randint(10, 30)):
                    x, y = (int(v) for v in np.random.randint(20, 236, 2))
                    radius = int(np.random.randint(3, 8))
                    cv2.circle(img, (x, y), radius, 255, -1)
            else:
                # Elongated, flattened vesicle-like blobs
                for _blob in range(np.random.randint(10, 30)):
                    x, y = (int(v) for v in np.random.randint(20, 236, 2))
                    width = int(np.random.randint(2, 5))
                    height = int(np.random.randint(5, 12))
                    angle = int(np.random.randint(0, 180))
                    cv2.ellipse(img, ((x, y), (width, height), angle), 255, -1)

            # Add Gaussian noise, then quantise to 8 bits
            img += np.random.normal(0, 15, img.shape)
            img = np.clip(img, 0, 255).astype(np.uint8)

            processed_img = self._preprocess_image(img)

            # Only add valid images
            if processed_img is not None:
                images.append(processed_img)
                labels.append("excitatory" if is_excitatory else "inhibitory")

    if not images:
        raise ValueError("No valid synapse images could be loaded or generated")

    print("Converting to numpy arrays...")
    images = np.array(images)

    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(labels)

    # Persist the encoder so predictions can be mapped back to class names.
    with open(os.path.join(self.output_path, 'label_encoder.pkl'), 'wb') as f:
        pickle.dump(label_encoder, f)

    print(f"Loaded {len(images)} synapse images with shape {images.shape}")
    class_counts = pd.Series(labels).value_counts()
    print(f"Label distribution: {class_counts.to_dict()}")

    return images, encoded_labels

def _preprocess_image(self, image):
    """
    Preprocess a single EM image

    Steps:
    1. Replace NaN/inf values.
    2. Scale to [0, 1], dividing by 255 when the maximum exceeds 1.
    3. Resize to ``self.image_size``, interpreted as (height, width).
    4. Enhance contrast via ``self._enhance_contrast``.
    5. Replicate a single channel to three channels for pre-trained backbones.
    6. Cast to float32.

    Args:
        image (numpy.ndarray): Input EM image (2-D grayscale or H x W x C)

    Returns:
        numpy.ndarray: Preprocessed float32 image of shape
        ``(*self.image_size, C)``, where C is 3 for grayscale input. Never
        returns None: empty, invalid or failing inputs yield an all-zero
        ``(*self.image_size, 3)`` float32 image.
    """
    blank = np.zeros((*self.image_size, 3), dtype=np.float32)
    try:
        # Handle empty or corrupt images
        if image is None or image.size == 0:
            return blank

        # Replace NaN with 0 and +/-inf with the ends of the 8-bit range
        if np.isnan(image).any() or np.isinf(image).any():
            image = np.nan_to_num(image, nan=0.0, posinf=255.0, neginf=0.0)

        image = image.astype(np.float32)

        # Normalize to [0, 1] if needed
        if image.max() > 1.0:
            image = image / 255.0

        if image.shape[0] == 0 or image.shape[1] == 0:
            return blank

        # cv2.resize expects dsize as (width, height), whereas image_size is
        # used as (height, width) everywhere else (see ``blank``). Swap the
        # order so non-square sizes come out with shape image_size.
        height, width = self.image_size
        image = cv2.resize(image, (width, height))

        image = self._enhance_contrast(image)

        if image.ndim == 2:
            # Grayscale -> 3 identical channels for pre-trained models
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)

        # Ensure the image is properly scaled
        if image.max() > 1.0:
            image = image / image.max()

        # Final check for valid output
        if np.isnan(image).any() or np.isinf(image).any():
            return blank

        # _enhance_contrast yields float64; cast so every result (including
        # the blank fallback) shares one dtype and half the memory.
        return image.astype(np.float32, copy=False)

    except Exception as e:
        print(f"Error in preprocessing image: {e}")
        return blank

def _enhance_contrast(self, image):
    """
    Enhance contrast in the EM image

    Steps:
    1. Sanitise the input: NaN -> 0, +inf -> 1, -inf -> 0, then clip to [0, 1].
    2. Apply CLAHE, to the L channel of LAB for 3-channel input.
    3. Smooth with a 3x3 Gaussian blur.
    4. Min-max stretch the result to [0, 1].

    Args:
        image (numpy.ndarray): Input image, expected in [0, 1]; 2-D grayscale,
            H x W x 1, or H x W x 3 RGB.

    Returns:
        numpy.ndarray: Contrast-enhanced image in [0, 1] (all zeros for a
        constant image). CLAHE is skipped if colour conversion fails. The
        unmodified input is returned if any other error occurs.
    """
    try:
        # Work on a copy so the caller's array is never modified
        img = image.copy()

        if np.isnan(img).any() or np.isinf(img).any():
            img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)

        # CLAHE runs on 8-bit data, so the float input must lie in [0, 1]
        if img.min() < 0 or img.max() > 1.0:
            img = np.clip(img, 0.0, 1.0)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_uint8 = (img * 255).astype(np.uint8)

        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            img = clahe.apply(img_uint8) / 255.0
        else:
            # For colour input, equalise only the luminance (L of LAB)
            try:
                img_lab = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2LAB)
                img_lab[..., 0] = clahe.apply(img_lab[..., 0])
                img = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB) / 255.0
            except cv2.error:
                # Unsupported channel layout: skip CLAHE but keep the
                # sanitised copy. The raw input may still contain NaN/inf or
                # out-of-range values that would poison the steps below.
                pass

        # Apply Gaussian blur to reduce noise
        img = cv2.GaussianBlur(img, (3, 3), 0)

        # Min-max stretch; guard against division by zero for flat images
        min_val, max_val = img.min(), img.max()
        if max_val > min_val:
            img = (img - min_val) / (max_val - min_val)
        else:
            img = np.zeros_like(img)

        return img

    except Exception as e:
        print(f"Error enhancing contrast: {e}")
        return image  # Return original image if enhancement fails

In [19]:
def load_microns_data(self, sample_limit=None):
    """
    Load and preprocess data from the Microns Explorer dataset

    Synapses are read in batches of 100 from the ``'synapses'`` collection of a
    ``MicronsCatalog`` opened on ``self.data_path``. If the catalog is
    unavailable or any error escapes the loading loop, synthetic placeholder
    images are generated instead (any partially loaded real data is dropped).
    The fitted ``LabelEncoder`` is pickled to ``<output_path>/label_encoder.pkl``.

    Args:
        sample_limit (int, optional): Limit the number of samples to load for testing.
            A falsy value (None or 0) means no limit for real data and 1000
            samples for placeholder data.

    Returns:
        tuple: ``(images, encoded_labels)`` -- a numpy array of preprocessed
        images and the integer-encoded "excitatory"/"inhibitory" labels.

    Raises:
        ValueError: If no image could be loaded or generated.
    """
    print("Loading Microns Explorer dataset...")

    images = []
    labels = []

    try:
        # Fail fast (into the fallback below) if the optional import did not
        # succeed, whether the name is None or was never bound at all.
        if MicronsCatalog is None:
            raise ImportError("microns_utils is not installed")
        catalog = MicronsCatalog(self.data_path)
        synapse_collection = catalog.get_collection('synapses')

        # Query the collection size once rather than on every batch.
        total = synapse_collection.count()
        batch_size = 100  # Process 100 synapses at a time to bound memory use
        limit_reached = False

        for start in tqdm(range(0, total, batch_size), desc="Loading synapse batches"):
            synapse_batch = synapse_collection.get_items(start, min(start + batch_size, total))

            for synapse in synapse_batch:
                try:
                    em_image = synapse.get_em_image()

                    if hasattr(synapse, 'is_excitatory'):
                        is_excitatory = synapse.is_excitatory
                    else:
                        # Heuristic from vesicle shape: round vesicles (high
                        # circularity) -> excitatory, flattened -> inhibitory.
                        vesicle_props = synapse.get_vesicle_properties()
                        is_excitatory = vesicle_props.get('mean_circularity', 0.5) > 0.7

                    processed_img = self._preprocess_image(em_image)

                    if processed_img is not None:
                        images.append(processed_img)
                        labels.append("excitatory" if is_excitatory else "inhibitory")

                    del em_image  # Drop the raw image early to lower peak memory

                except Exception as e:
                    print(f"Error processing synapse: {e}")

                # Stop as soon as the limit is met instead of fetching (and
                # then discarding) the rest of the batch.
                if sample_limit and len(images) >= sample_limit:
                    limit_reached = True
                    break

            gc.collect()  # Free memory after each batch
            if limit_reached:
                break

    except Exception as e:
        # Deliberately broad: any import or runtime failure falls back to
        # synthetic data so the rest of the notebook can still run.
        print(f"Warning: MicronsCatalog not available or failed: {e}")
        if images:
            print(f"Discarding {len(images)} partially loaded real samples")
        print("Using placeholder data instead")

        n_samples = min(1000, sample_limit) if sample_limit else 1000
        images = []
        labels = []

        for _ in tqdm(range(n_samples), desc="Generating placeholder data"):
            # Random grayscale background in [0, 255)
            img = np.random.rand(256, 256) * 255
            is_excitatory = np.random.random() > 0.4  # Slightly biased toward excitatory

            # OpenCV drawing calls get plain Python ints (not NumPy scalars);
            # the RNG call order is unchanged.
            if is_excitatory:
                # Round, vesicle-like blobs
                for _blob in range(np.random.randint(10, 30)):
                    x, y = (int(v) for v in np.random.randint(20, 236, 2))
                    radius = int(np.random.randint(3, 8))
                    cv2.circle(img, (x, y), radius, 255, -1)
            else:
                # Elongated, flattened vesicle-like blobs
                for _blob in range(np.random.randint(10, 30)):
                    x, y = (int(v) for v in np.random.randint(20, 236, 2))
                    width = int(np.random.randint(2, 5))
                    height = int(np.random.randint(5, 12))
                    angle = int(np.random.randint(0, 180))
                    cv2.ellipse(img, ((x, y), (width, height), angle), 255, -1)

            # Add Gaussian noise, then quantise to 8 bits
            img += np.random.normal(0, 15, img.shape)
            img = np.clip(img, 0, 255).astype(np.uint8)

            processed_img = self._preprocess_image(img)

            if processed_img is not None:
                images.append(processed_img)
                labels.append("excitatory" if is_excitatory else "inhibitory")

    print("Converting to numpy arrays...")
    images = np.array(images)

    if len(images) == 0:
        raise ValueError("No valid images could be loaded or generated")

    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(labels)

    # Persist the encoder so predictions can be mapped back to class names.
    with open(os.path.join(self.output_path, 'label_encoder.pkl'), 'wb') as f:
        pickle.dump(label_encoder, f)

    print(f"Loaded {len(images)} synapse images with shape {images.shape}")
    class_counts = pd.Series(labels).value_counts()
    print(f"Label distribution: {class_counts.to_dict()}")

    return images, encoded_labels

def _preprocess_image(self, image):
    """
    Preprocess a single EM image

    Steps:
    1. Replace NaN/inf values.
    2. Scale to [0, 1], dividing by 255 when the maximum exceeds 1.
    3. Resize to ``self.image_size``, interpreted as (height, width).
    4. Enhance contrast via ``self._enhance_contrast``.
    5. Replicate a single channel to three channels for pre-trained backbones.
    6. Cast to float32.

    Args:
        image (numpy.ndarray): Input EM image (2-D grayscale or H x W x C)

    Returns:
        numpy.ndarray: Preprocessed float32 image of shape
        ``(*self.image_size, C)``, where C is 3 for grayscale input. Empty,
        invalid or failing inputs yield an all-zero
        ``(*self.image_size, 3)`` float32 image.
    """
    blank = np.zeros((*self.image_size, 3), dtype=np.float32)
    try:
        # Handle empty or corrupt images
        if image is None or image.size == 0:
            return blank

        # Use explicit replacement values: nan_to_num's default for +/-inf is
        # the float64 extreme, which overflows back to inf on the float32
        # cast below.
        if np.isnan(image).any() or np.isinf(image).any():
            image = np.nan_to_num(image, nan=0.0, posinf=255.0, neginf=0.0)

        image = image.astype(np.float32)

        # Normalize to [0, 1] if needed
        if image.max() > 1.0:
            image = image / 255.0

        if image.shape[0] == 0 or image.shape[1] == 0:
            return blank

        # cv2.resize expects dsize as (width, height), whereas image_size is
        # used as (height, width) everywhere else (see ``blank``). Swap the
        # order so non-square sizes come out with shape image_size.
        height, width = self.image_size
        image = cv2.resize(image, (width, height))

        image = self._enhance_contrast(image)

        if image.ndim == 2:
            # Grayscale -> 3 identical channels for pre-trained models
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)

        # Ensure the image is properly scaled
        if image.max() > 1.0:
            image = image / image.max()

        # Never hand NaN/inf to the model
        if np.isnan(image).any() or np.isinf(image).any():
            return blank

        # _enhance_contrast yields float64; cast so every result (including
        # the blank fallback) shares one dtype and half the memory.
        return image.astype(np.float32, copy=False)

    except Exception as e:
        print(f"Error in preprocessing image: {e}")
        return blank

def _enhance_contrast(self, image):
    """
    Enhance contrast in the EM image

    Steps:
    1. Sanitise the input: NaN -> 0, +inf -> 1, -inf -> 0, then clip to [0, 1].
    2. Equalise with CLAHE, on the L channel of LAB for 3-channel input.
    3. Smooth with a 3x3 Gaussian blur.
    4. Min-max stretch to [0, 1] and clip.

    Args:
        image (numpy.ndarray): Input image, expected in [0, 1]; 2-D grayscale,
            H x W x 1, or H x W x 3 RGB.

    Returns:
        numpy.ndarray: Contrast-enhanced image in [0, 1]. If colour
        conversion fails or the processed image is (nearly) constant, the
        sanitised input is returned instead. If any other error occurs, the
        input is returned unchanged.
    """
    try:
        # Finite, [0, 1]-clipped copy of the input. It doubles as the
        # fallback result, so the early exits below never hand back NaN/inf
        # or out-of-range values.
        safe = np.clip(np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img_uint8 = (safe * 255).astype(np.uint8)

        if safe.ndim == 2 or (safe.ndim == 3 and safe.shape[2] == 1):
            img = clahe.apply(img_uint8) / 255.0
        else:
            # For colour input, equalise only the luminance (L of LAB)
            try:
                img_lab = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2LAB)
                img_lab[..., 0] = clahe.apply(img_lab[..., 0])
                img = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB) / 255.0
            except cv2.error:
                print("Color conversion failed, using original image")
                return safe

        # Apply Gaussian blur to reduce noise
        img = cv2.GaussianBlur(img, (3, 3), 0)

        # Min-max stretch, skipped when there is no meaningful range
        denom = img.max() - img.min()
        if denom > 1e-8:
            img = (img - img.min()) / denom
        else:
            return safe

        return np.clip(img, 0.0, 1.0)

    except Exception as e:
        print(f"Error enhancing contrast: {e}")
        return image  # Return original image if enhancement fails

In [20]:
def _padded_middle_slice(volume, bbox, margin):
    """
    Crop ``volume`` to a bounding box grown by ``margin`` and return the middle
    slice (along axis 0) of the crop, or None if the crop is empty.

    Args:
        volume (numpy.ndarray): Volume whose first three axes match ``bbox``.
        bbox (tuple of slice): Three slices with exclusive ``stop`` indices,
            as produced by ``scipy.ndimage.find_objects``.
        margin (int): Voxels of context added on every side, clamped to the
            volume bounds.
    """
    lo = [max(0, s.start - margin) for s in bbox]
    hi = [min(dim, s.stop + margin) for s, dim in zip(bbox, volume.shape)]
    patch = volume[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    if patch.shape[0] == 0:
        return None
    return patch[patch.shape[0] // 2]


def extract_synapse_regions(self, em_volume, segmentation_mask, max_synapses=500):
    """
    Extract individual synapse regions from a larger EM volume using segmentation masks

    Each connected component of ``segmentation_mask`` is treated as one
    synapse. Its bounding box is padded by 10 voxels on every side (clamped
    to the volume) and cropped from ``em_volume``, and the middle slice of
    that crop along axis 0 is kept. Masks with more than 1e9 elements are
    labelled in up to 10 chunks along axis 0. Components spanning a chunk
    boundary are then split into separate pieces.

    Args:
        em_volume (numpy.ndarray): 3D EM volume (axis 0 is treated as z).
        segmentation_mask (numpy.ndarray): Synapse segmentation mask; assumed
            to share the first three dimensions of ``em_volume``.
        max_synapses (int): Maximum number of synapses to extract to avoid memory issues

    Returns:
        list: At most ``max_synapses`` 2-D patches; empty if extraction fails.
    """
    print("Extracting synapse regions...")
    margin = 10

    try:
        # NOTE: ``size`` counts elements, not bytes -- 1e9 elements is ~1 GB
        # only for 1-byte dtypes.
        if segmentation_mask.size > 1e9:
            print("Large volume detected, using memory-efficient processing...")
            # array_split yields empty chunks when the depth is < 10. Skip
            # them: calling .min() on one would abort the whole extraction.
            z_chunks = [
                chunk
                for chunk in np.array_split(np.arange(segmentation_mask.shape[0]), 10)
                if chunk.size
            ]
            # Ceil-divide the budget so max_synapses smaller than the number
            # of chunks still yields patches; the final slice caps the total.
            per_chunk = -(-max_synapses // len(z_chunks))

            all_patches = []
            for z_chunk in tqdm(z_chunks, desc="Processing volume slices"):
                z_start, z_stop = z_chunk[0], z_chunk[-1] + 1
                slice_mask = segmentation_mask[z_start:z_stop]
                slice_volume = em_volume[z_start:z_stop]

                labeled_slice, num_features = ndi.label(slice_mask)
                n_take = min(num_features, per_chunk)
                # find_objects treats max_label < 1 as "all labels", so guard it.
                if n_take > 0:
                    # One pass for all bounding boxes instead of a full
                    # ``labeled == i`` scan per component.
                    for bbox in ndi.find_objects(labeled_slice, max_label=n_take):
                        if bbox is None:
                            continue
                        middle = _padded_middle_slice(slice_volume, bbox, margin)
                        if middle is not None:
                            all_patches.append(middle)

                # Free memory before labelling the next chunk
                del labeled_slice, slice_mask, slice_volume
                gc.collect()

            return all_patches[:max_synapses]

        # Smaller volumes: label the whole mask at once
        labeled_mask, num_features = ndi.label(segmentation_mask)
        n_take = min(num_features, max_synapses)

        synapse_patches = []
        if n_take > 0:
            for bbox in tqdm(ndi.find_objects(labeled_mask, max_label=n_take),
                             desc="Extracting patches"):
                if bbox is None:
                    continue
                middle = _padded_middle_slice(em_volume, bbox, margin)
                if middle is not None:
                    synapse_patches.append(middle)

        return synapse_patches

    except Exception as e:
        print(f"Error extracting synapse regions: {e}")
        return []

In [21]:
def build_unet_segmentation_model(self, input_shape=(512, 512, 1)):
    """
    Build a U-Net model for synapse segmentation

    Architecture:
    - Encoder: four stages with 16/32/64/128 filters. Each stage is two 3x3
      convs with batch norm, then 2x2 max-pool and dropout (0.1-0.4).
    - Bridge: 256 filters, dropout 0.5.
    - Decoder: mirrors the encoder, with skip connections.
    - Output: a 1-channel sigmoid map.

    Compiled with Adam and a Dice-coefficient loss.

    Args:
        input_shape (tuple): Shape of input images as (height, width, channels).
            Known height/width must be divisible by 16 so skip connections align.

    Returns:
        tensorflow.keras.Model: U-Net segmentation model

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    # Four 2x poolings must be undone exactly by four 2x up-convolutions;
    # otherwise concatenate() fails later with a much less helpful error.
    for dim in input_shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"U-Net input height/width must be divisible by 16, got {tuple(input_shape[:2])}"
            )

    def conv_pair(x, filters):
        """Two 3x3 ReLU convolutions, each followed by batch normalisation."""
        for _ in range(2):
            x = layers.Conv2D(filters, (3, 3), activation='relu',
                              kernel_initializer='he_normal', padding='same')(x)
            x = layers.BatchNormalization()(x)
        return x

    inputs = tf.keras.Input(input_shape)

    # Encoder: conv pair -> 2x2 max-pool -> dropout. Each conv-pair output is
    # kept as a skip connection for the decoder.
    skips = []
    x = inputs
    for filters, rate in ((16, 0.1), (32, 0.2), (64, 0.3), (128, 0.4)):
        x = conv_pair(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)
        x = layers.Dropout(rate)(x)

    # Bridge
    x = conv_pair(x, 256)
    x = layers.Dropout(0.5)(x)

    # Decoder: 2x transposed conv -> concat matching skip -> dropout -> conv pair
    decoder_stages = ((128, 0.4), (64, 0.3), (32, 0.2), (16, 0.1))
    for (filters, rate), skip in zip(decoder_stages, reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = layers.Dropout(rate)(x)
        x = conv_pair(x, filters)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = models.Model(inputs=[inputs], outputs=[outputs])

    # Dice coefficient loss suits imbalanced foreground/background masks
    def dice_coef(y_true, y_pred, smooth=1.0):
        # Cast the mask to the prediction dtype. Depending on the Keras
        # version, y_true may reach a custom loss uncast, and integer masks
        # (e.g. uint8) would then fail in the multiplication below.
        y_true_f = K.flatten(K.cast(y_true, y_pred.dtype))
        y_pred_f = K.flatten(y_pred)
        intersection = K.sum(y_true_f * y_pred_f)
        return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

    def dice_coef_loss(y_true, y_pred):
        return 1 - dice_coef(y_true, y_pred)

    model.compile(
        optimizer='adam',
        loss=dice_coef_loss,
        metrics=['accuracy', dice_coef]
    )

    return model

In [22]:
def build_classification_model(self, num_classes=2, use_pretrained=True):
        """
        Build and compile the synapse classification model.

        Args:
            num_classes (int): Number of classes; size of the softmax output.
            use_pretrained (bool): If True, put a new classification head on a
                frozen ImageNet-initialised EfficientNetV2S backbone; otherwise
                build a four-block CNN from scratch.

        Returns:
            tensorflow.keras.Model: Model compiled with Adam (learning rate
            1e-3), a focal loss that expects integer class labels, and an
            accuracy metric.
        """
        if use_pretrained:
            # ImageNet-pretrained EfficientNetV2S backbone without its classifier.
            # NOTE(review): with include_preprocessing=False the backbone expects
            # inputs already scaled to [-1, 1]; confirm the image preprocessing
            # used elsewhere in this pipeline produces that range.
            base_model = applications.EfficientNetV2S(
                include_top=False,
                weights='imagenet',
                input_shape=(*self.image_size, 3),
                include_preprocessing=False
            )

            # Freeze the backbone for the first training phase.
            base_model.trainable = False

            inputs = layers.Input(shape=(*self.image_size, 3))

            # training=False keeps the backbone's BatchNormalization layers in
            # inference mode, even if the backbone is unfrozen later.
            x = base_model(inputs, training=False)

            # Classification head.
            x = layers.GlobalAveragePooling2D()(x)
            x = layers.Dropout(0.4)(x)
            x = layers.Dense(512, activation='relu')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Dropout(0.5)(x)
            x = layers.Dense(128, activation='relu')(x)
            x = layers.BatchNormalization()(x)
            outputs = layers.Dense(num_classes, activation='softmax')(x)

            model = models.Model(inputs, outputs)

        else:
            # CNN trained from scratch: four conv blocks, then a dense head.
            model = models.Sequential([
                # Block 1: wide 5x5 receptive field on the raw input.
                layers.Conv2D(32, (5, 5), padding='same', input_shape=(*self.image_size, 3)),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.MaxPooling2D((2, 2)),

                # Block 2
                layers.Conv2D(64, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.Conv2D(64, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.MaxPooling2D((2, 2)),
                layers.Dropout(0.2),

                # Block 3
                layers.Conv2D(128, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.Conv2D(128, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.MaxPooling2D((2, 2)),
                layers.Dropout(0.3),

                # Block 4
                layers.Conv2D(256, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.Conv2D(256, (3, 3), padding='same'),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.1),
                layers.MaxPooling2D((2, 2)),
                layers.Dropout(0.4),

                # Classification head
                layers.Flatten(),
                layers.Dense(512, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.5),
                layers.Dense(128, activation='relu'),
                layers.BatchNormalization(),
                layers.Dropout(0.5),
                layers.Dense(num_classes, activation='softmax')
            ])

        def focal_loss(gamma=2.0, alpha=0.25):
            """
            Return a focal-loss function for integer labels and softmax outputs.

            Args:
                gamma (float): Focusing exponent applied to (1 - p).
                alpha (float): Constant scale applied to every sample's loss.

            Returns:
                callable: ``loss(y_true, y_pred)`` returning one value per sample.
            """
            def focal_loss_fn(y_true, y_pred):
                # Some Keras versions expand rank-1 sparse targets to
                # (batch, 1) before calling a custom loss. Flatten them so the
                # one-hot is (batch, num_classes); a (batch, 1, C) one-hot would
                # broadcast against (batch, C) predictions into a
                # (batch, batch, C) tensor pairing every label with every
                # prediction.
                labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
                y_true_one_hot = tf.one_hot(labels, depth=num_classes, dtype=y_pred.dtype)

                # Clip to keep log() finite.
                epsilon = 1e-7
                y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

                # Cross-entropy is non-zero only at the true class.
                cross_entropy = -y_true_one_hot * tf.math.log(y_pred)

                # Down-weight confidently-correct predictions.
                focal_term = (1 - y_pred) ** gamma

                loss = alpha * focal_term * cross_entropy

                # One loss value per sample.
                return tf.reduce_sum(loss, axis=-1)

            return focal_loss_fn

        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            loss=focal_loss(gamma=2.0),
            metrics=['accuracy']
        )

        return model

In [28]:
def train_model(self, X, y, validation_split=0.2, epochs=50, fine_tune=True):
    """
    Train the classification model, then optionally fine-tune it.

    Data preparation:
    - The data are split once with a shuffled ``StratifiedKFold``, and its
      first fold is used for validation, so class proportions are preserved
      in both sets.
    - Training uses on-the-fly augmentation and balanced class weights.

    Callbacks: early stopping, best-model checkpointing, LR reduction on
    plateau, and TensorBoard logging.

    Fallbacks and saving:
    - If building or training the current model raises, a model built with
      ``use_pretrained=False`` is trained instead.
    - The final model is saved to ``self.output_path`` in several formats.
      Saving is best effort: each failure is reported and then skipped.

    Args:
        X (numpy.ndarray): Input images, indexable by integer index arrays.
        y (numpy.ndarray): Integer class labels, one per image.
        validation_split (float): Approximate fraction of data used for
            validation, strictly between 0 and 1.
            - The split uses ``n_splits = max(2, round(1 / validation_split))``
              stratified folds; one fold becomes the validation set.
            - The default 0.2 gives exactly 20%.
            - Values above 0.5 are capped at 50%.
        epochs (int): Number of epochs for the main training phase.
        fine_tune (bool): Whether to run a fine-tuning phase afterwards.

    Returns:
        dict: Keras history dict of the main phase, extended with the
        fine-tuning epochs when fine-tuning succeeds; ``{}`` if no history
        is available.

    Raises:
        ValueError: If ``validation_split`` is not strictly between 0 and 1.
    """
    import traceback

    if not 0.0 < validation_split < 1.0:
        raise ValueError(
            f"validation_split must be strictly between 0 and 1, got {validation_split}"
        )

    # Stratified split: the validation fold keeps the class proportions of y.
    n_splits = max(2, int(round(1.0 / validation_split)))
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    train_idx, val_idx = next(skf.split(X, y))

    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    print(f"Training set: {X_train.shape}, Validation set: {X_val.shape}")

    # 'balanced' weights are inversely proportional to class frequency.
    classes = np.unique(y_train)
    class_weights = compute_class_weight('balanced', classes=classes, y=y_train)
    # Key by the actual label value rather than its position in `classes`.
    class_weight_dict = {int(c): float(w) for c, w in zip(classes, class_weights)}
    print(f"Class weights: {class_weight_dict}")

    # Augmentation for training only.
    train_datagen = ImageDataGenerator(
        rotation_range=180,            # EM images have no canonical orientation
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.3,
        horizontal_flip=True,
        vertical_flip=True,
        brightness_range=(0.8, 1.2),
        fill_mode='reflect',
        preprocessing_function=lambda x: x + np.random.normal(0, 0.05, x.shape)  # additive Gaussian noise
    )

    # Validation data is fed unaugmented.
    val_datagen = ImageDataGenerator()

    train_generator = train_datagen.flow(
        X_train, y_train,
        batch_size=self.batch_size,
        shuffle=True
    )

    val_generator = val_datagen.flow(
        X_val, y_val,
        batch_size=self.batch_size,
        shuffle=False
    )

    gc.collect()

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True,
        verbose=1
    )

    model_checkpoint = ModelCheckpoint(
        filepath=os.path.join(self.output_path, 'best_model.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    )

    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=7,
        min_lr=1e-7,
        verbose=1
    )

    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(self.output_path, 'logs'),
        histogram_freq=1,
        write_graph=True,
        update_freq='epoch'
    )

    callbacks = [early_stopping, model_checkpoint, reduce_lr, tensorboard]

    # ---- Main training phase -------------------------------------------
    history = None
    try:
        if self.model is None:
            self.model = self.build_classification_model()
            print(f"Model successfully built with {self.model.count_params():,} parameters")
            self.model.summary()

        # Mixed precision is deliberately not enabled here: a global dtype
        # policy only applies to layers created after it is set, so it could
        # not affect a model that already exists.

        print(f"Training the model for {epochs} epochs...")
        # Step counts come from the iterators' own lengths
        # (ceil(n / batch_size)). This way the last partial batch is used,
        # and a set smaller than one batch still gets one step instead of 0.
        # `workers` / `use_multiprocessing` are not passed: Keras 3's fit()
        # no longer accepts them. The resulting TypeError would silently
        # divert training to the fallback model below.
        history = self.model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=epochs,
            callbacks=callbacks,
            class_weight=class_weight_dict,
            steps_per_epoch=len(train_generator),
            validation_steps=len(val_generator),
            verbose=1
        )
    except Exception as e:
        print(f"Error during model training: {e}")
        print("Attempting to train with a simpler model architecture...")
        self.model = self.build_classification_model(use_pretrained=False)
        history = self.model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=min(20, epochs),  # shorter run for the recovery model
            callbacks=[early_stopping, model_checkpoint],
            class_weight=class_weight_dict
        )

    # Per-epoch summary. This runs outside the try above, so a reporting
    # problem cannot trigger a retrain with the fallback model.
    main_hist = history.history
    for i, (loss, acc, val_loss, val_acc) in enumerate(zip(
        main_hist.get('loss', []),
        main_hist.get('accuracy', []),
        main_hist.get('val_loss', []),
        main_hist.get('val_accuracy', [])
    )):
        print(f"Epoch {i+1}/{epochs} - loss: {loss:.4f} - accuracy: {acc:.4f} - val_loss: {val_loss:.4f} - val_accuracy: {val_acc:.4f}")

    _extend_metrics_history(self, main_hist)

    # ---- Optional fine-tuning phase ------------------------------------
    if fine_tune and self.model is not None and len(self.model.layers) > 0:
        try:
            print("Fine-tuning the model...")

            trainable_params_before = sum(int(np.prod(w.shape)) for w in self.model.trainable_weights)

            # A pre-trained backbone appears as a nested Model inside the
            # functional model. It is not layers[0], which is the InputLayer.
            base_model = next(
                (layer for layer in self.model.layers if isinstance(layer, tf.keras.Model)),
                None
            )

            if base_model is not None and len(base_model.layers) > 0:
                num_layers_to_unfreeze = max(1, int(len(base_model.layers) * 0.3))
                # The container's own `trainable` flag gates the weights of
                # every sub-layer. So unfreeze the whole backbone first, then
                # re-freeze everything below the top 30%.
                base_model.trainable = True
                for layer in base_model.layers[:-num_layers_to_unfreeze]:
                    layer.trainable = False
                print(f"Fine-tuning top 30% of base model layers: {num_layers_to_unfreeze} layers unfrozen")
            elif base_model is not None:
                print("Base model has no layers, setting entire model to trainable")
                self.model.trainable = True
            else:
                # Model built from scratch: unfreeze everything.
                self.model.trainable = True
                print("Fine-tuning all layers")

            trainable_params_after = sum(int(np.prod(w.shape)) for w in self.model.trainable_weights)
            if trainable_params_after <= trainable_params_before:
                print(f"Warning: No additional parameters were made trainable ({trainable_params_before} -> {trainable_params_after})")
                if trainable_params_after == 0:
                    print("No trainable parameters found, skipping fine-tuning")
                    raise ValueError("Model has no trainable parameters")

            # Fine-tuning learning rate: 10x below the current one, clamped.
            try:
                optimizer = getattr(self.model, 'optimizer', None)
                if optimizer is not None and hasattr(optimizer, 'learning_rate'):
                    try:
                        current_lr = float(optimizer.learning_rate.numpy())
                    except (AttributeError, ValueError, TypeError):
                        current_lr = 0.001  # Adam's default learning rate
                    new_lr = current_lr * 0.1
                else:
                    new_lr = 1e-5
            except Exception as lr_error:
                print(f"Warning: Could not determine current learning rate: {lr_error}")
                new_lr = 1e-5

            new_lr = max(1e-7, min(new_lr, 1e-3))
            print(f"Fine-tuning with learning rate: {new_lr}")

            # Reuse the loss the model was compiled with, so both phases
            # optimise the same objective. Their val_loss values then stay
            # comparable for checkpointing, early stopping and the combined
            # history. Only infer a loss when none is recorded.
            loss_fn = getattr(self.model, 'loss', None)
            if not loss_fn:
                try:
                    n_outputs = self.model.output_shape[-1]
                except Exception as shape_error:
                    print(f"Warning: Error determining model output shape: {shape_error}")
                    n_outputs = None
                loss_fn = 'binary_crossentropy' if n_outputs == 1 else 'sparse_categorical_crossentropy'

            try:
                self.model.compile(
                    optimizer=tf.keras.optimizers.Adam(learning_rate=new_lr),
                    loss=loss_fn,
                    metrics=['accuracy']
                )
            except Exception as compile_error:
                print(f"Error compiling model for fine-tuning: {compile_error}")
                print("Attempting to compile with default parameters")
                self.model.compile(
                    optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy']
                )

            if isinstance(epochs, int) and epochs > 0:
                fine_tune_epochs = min(20, max(5, epochs // 2))  # between 5 and 20
            else:
                fine_tune_epochs = 10

            if history is not None and isinstance(getattr(history, 'history', None), dict):
                initial_epoch = len(history.history.get('loss', []))
            else:
                initial_epoch = 0
                print("Warning: No training history found, starting fine-tuning from epoch 0")

            def lr_scheduler(epoch, lr):
                # Keras passes the absolute epoch index, and fit() resumes at
                # `initial_epoch`, so the 5-epoch hold is counted from there.
                # Returns a plain float, as LearningRateScheduler requires.
                if epoch - initial_epoch < 5:
                    return float(lr)
                return float(lr) * float(np.exp(-0.1))

            ft_callbacks = list(callbacks) + [tf.keras.callbacks.LearningRateScheduler(lr_scheduler)]

            print(f"Starting fine-tuning for {fine_tune_epochs} epochs (from epoch {initial_epoch})")
            try:
                fine_tune_history = self.model.fit(
                    train_generator,
                    validation_data=val_generator,
                    epochs=initial_epoch + fine_tune_epochs,
                    initial_epoch=initial_epoch,
                    callbacks=ft_callbacks,
                    class_weight=class_weight_dict,
                    verbose=1
                )

                # Append fine-tuning epochs to the main history (shared keys only).
                for key, values in history.history.items():
                    values.extend(fine_tune_history.history.get(key, []))

                _extend_metrics_history(self, fine_tune_history.history)

                print("Fine-tuning completed successfully")
            except Exception as fit_error:
                print(f"Error during fine-tuning fit process: {fit_error}")
                print("Fine-tuning failed, continuing with the base model")
                traceback.print_exc()

        except Exception as e:
            print(f"Error during fine-tuning preparation: {e}")
            print("Skipping fine-tuning and keeping the base model")
            traceback.print_exc()

    # ---- Saving (best effort, several formats) --------------------------
    print("Saving the final model...")
    try:
        os.makedirs(self.output_path, exist_ok=True)

        # TensorFlow SavedModel directory.
        model_save_path = os.path.join(self.output_path, 'final_model')
        try:
            self.model.save(model_save_path, save_format='tf')
            print(f"Model saved successfully to {model_save_path}")
        except Exception as tf_save_error:
            print(f"Error saving in TensorFlow format: {tf_save_error}")

        # Full model in HDF5.
        try:
            h5_path = os.path.join(self.output_path, 'final_model.h5')
            self.model.save(h5_path)
            print(f"Model saved in H5 format to {h5_path}")
        except Exception as h5_save_error:
            print(f"Warning: Could not save in H5 format: {h5_save_error}")

        # Architecture only, as JSON.
        try:
            json_path = os.path.join(self.output_path, 'model_architecture.json')
            with open(json_path, 'w') as f:
                f.write(self.model.to_json())
            print(f"Model architecture saved to {json_path}")
        except Exception as json_save_error:
            print(f"Warning: Could not save model architecture: {json_save_error}")

        # Weights only.
        try:
            weights_path = os.path.join(self.output_path, 'model_weights.h5')
            self.model.save_weights(weights_path)
            print(f"Model weights saved to {weights_path}")
        except Exception as weights_save_error:
            print(f"Warning: Could not save model weights: {weights_save_error}")

    except Exception as e:
        print(f"Critical error during model saving: {e}")
        traceback.print_exc()

        # Last-ditch attempts to persist the weights.
        try:
            print("Attempting emergency save of model weights...")
            emergency_path = os.path.join(self.output_path, 'emergency_weights.h5')
            os.makedirs(os.path.dirname(emergency_path), exist_ok=True)
            self.model.save_weights(emergency_path)
            print(f"Emergency save successful: {emergency_path}")
        except Exception as emergency_error:
            print(f"Emergency save failed. Model could not be saved: {emergency_error}")
            try:
                np_path = os.path.join(self.output_path, 'emergency_weights.npz')
                np_weights = [w.numpy() for w in self.model.weights]
                np.savez(np_path, *np_weights)
                print(f"Emergency NumPy save successful: {np_path}")
            except Exception:
                print("All save attempts failed. Model could not be saved.")

    if history is not None and isinstance(getattr(history, 'history', None), dict):
        return history.history
    print("Warning: No valid history found")
    return {}


def _extend_metrics_history(owner, keras_history):
    """
    Append per-epoch values from a Keras history dict to ``owner.metrics_history``.

    A key missing on either side is skipped rather than raising. This covers
    a history without validation metrics, and an owner without a
    ``metrics_history`` dict.

    Args:
        owner: Object that may carry a ``metrics_history`` dict of lists.
        keras_history (dict): ``History.history`` from ``Model.fit``.
    """
    metrics_history = getattr(owner, 'metrics_history', None)
    if not isinstance(metrics_history, dict):
        return
    for ours, keras_key in (('train_loss', 'loss'),
                            ('val_loss', 'val_loss'),
                            ('train_acc', 'accuracy'),
                            ('val_acc', 'val_accuracy')):
        target = metrics_history.get(ours)
        if isinstance(target, list) and keras_key in keras_history:
            target.extend(keras_history[keras_key])

In [29]:
def evaluate_model(self, X_test, y_test):
        """
        Evaluate the trained model on held-out data and save diagnostic outputs.

        Metrics computed:
        - Predictions are made in batches of 32.
        - Accuracy, a scikit-learn classification report, a confusion matrix,
          and per-class precision/recall/F1.

        Files written to ``self.output_path``:
        - Confusion-matrix heatmap.
        - ROC curve (binary problems only).
        - Precision-recall curve.
        - Confidence histogram.
        - ``class_metrics.csv``.

        Args:
            X_test (numpy.ndarray): Test images, in a form ``self.model.predict``
                accepts.
            y_test (numpy.ndarray): Integer-encoded test labels, shaped (n,) or
                (n, 1).

        Returns:
            dict: ``accuracy``, ``classification_report``, ``confusion_matrix``,
            ``class_precision``, ``class_recall``, ``class_f1`` and
            ``pred_probabilities``. If the full evaluation fails, the reduced
            fallback dict ``{'accuracy', 'confusion_matrix'}`` is returned; if
            that also fails, ``{'accuracy': None, 'error': ...}``.

        Raises:
            ValueError: If no model has been trained yet.
        """
        if self.model is None:
            raise ValueError("Model has not been trained yet. Please train the model first.")

        print("Evaluating model on test data...")

        try:
            # Accept column-vector labels as well as flat ones.
            y_test = np.asarray(y_test)
            if y_test.ndim == 2 and y_test.shape[1] == 1:
                y_test = y_test.ravel()

            # Class names come from the label encoder saved alongside the model,
            # if present. (pickle: only load files this pipeline wrote itself.)
            label_encoder_path = os.path.join(self.output_path, 'label_encoder.pkl')
            if os.path.exists(label_encoder_path):
                with open(label_encoder_path, 'rb') as f:
                    label_encoder = pickle.load(f)
                # str() so names work as report keys and with .capitalize()
                # even if the encoder was fitted on non-string labels.
                class_names = [str(c) for c in label_encoder.classes_]
            else:
                class_names = ['inhibitory', 'excitatory']

            n_classes = len(class_names)
            # Integer label ids; column i of the predictions corresponds to id i.
            label_ids = np.arange(n_classes)

            # Predict in batches to bound memory use.
            batch_size = 32
            y_pred_prob = np.vstack([
                self.model.predict(X_test[start:start + batch_size], verbose=0)
                for start in tqdm(range(0, len(X_test), batch_size), desc="Predicting")
            ])
            y_pred = np.argmax(y_pred_prob, axis=1)

            accuracy = accuracy_score(y_test, y_pred)
            # Pass the full label set so the report and matrix always have one
            # entry per class, even when a class is absent from both y_test and
            # y_pred. Otherwise target_names mismatches and conf_mat[i, i]
            # below goes out of range.
            report = classification_report(y_test, y_pred, labels=label_ids,
                                           target_names=class_names, output_dict=True)
            conf_mat = confusion_matrix(y_test, y_pred, labels=label_ids)

            # Per-class precision / recall / F1 from the confusion matrix.
            class_precision = {}
            class_recall = {}
            class_f1 = {}

            for i, class_name in enumerate(class_names):
                tp = conf_mat[i, i]
                fp = conf_mat[:, i].sum() - tp
                fn = conf_mat[i, :].sum() - tp

                precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

                class_precision[class_name] = precision
                class_recall[class_name] = recall
                class_f1[class_name] = f1

            # Row-normalised confusion matrix. Rows with no true samples stay 0
            # instead of NaN: seaborn masks NaN cells, which would also drop
            # their annotations.
            row_totals = conf_mat.sum(axis=1, keepdims=True)
            cm_norm = np.divide(conf_mat, row_totals,
                                out=np.zeros(conf_mat.shape, dtype=float),
                                where=row_totals > 0)

            # Each cell shows "count\n(row %)".
            annotations = np.array([
                [f"{conf_mat[i, j]}\n({cm_norm[i, j] * 100:.1f}%)" for j in range(n_classes)]
                for i in range(n_classes)
            ])

            plt.figure(figsize=(12, 10))
            sns.heatmap(cm_norm, annot=annotations, fmt='', cmap='Blues',
                        xticklabels=class_names,
                        yticklabels=class_names)

            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.title('Confusion Matrix with Class Distribution')

            # Outline diagonal (correct) cells in green, off-diagonal in red.
            for i in range(n_classes):
                for j in range(n_classes):
                    color = 'green' if i == j else 'red'
                    plt.plot([j, j+1], [i, i], color=color, lw=2)
                    plt.plot([j, j+1], [i+1, i+1], color=color, lw=2)
                    plt.plot([j, j], [i, i+1], color=color, lw=2)
                    plt.plot([j+1, j+1], [i, i+1], color=color, lw=2)

            plt.tight_layout()
            plt.savefig(os.path.join(self.output_path, 'confusion_matrix.png'), dpi=300)

            # ROC curve (binary only; class 1 is the positive class).
            if n_classes == 2:
                from sklearn.metrics import roc_curve, auc
                fpr, tpr, _ = roc_curve(y_test, y_pred_prob[:, 1])
                roc_auc = auc(fpr, tpr)

                plt.figure(figsize=(8, 6))
                plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
                plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.05])
                plt.xlabel('False Positive Rate')
                plt.ylabel('True Positive Rate')
                plt.title('Receiver Operating Characteristic')
                plt.legend(loc="lower right")
                plt.savefig(os.path.join(self.output_path, 'roc_curve.png'), dpi=300)

            # Precision-recall curves, one-vs-rest per class.
            from sklearn.metrics import precision_recall_curve, average_precision_score

            plt.figure(figsize=(10, 8))

            for i, class_name in enumerate(class_names):
                if n_classes == 2 and i == 0:
                    # Binary: only the positive class is plotted.
                    continue

                y_true_ovr = (y_test == i).astype(int)
                y_score_ovr = y_pred_prob[:, i]

                precision, recall, _ = precision_recall_curve(y_true_ovr, y_score_ovr)
                avg_precision = average_precision_score(y_true_ovr, y_score_ovr)

                plt.plot(recall, precision, lw=2,
                         label=f'{class_name} (AP = {avg_precision:.2f})')

            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title('Precision-Recall Curve')
            plt.legend()
            plt.savefig(os.path.join(self.output_path, 'precision_recall_curve.png'), dpi=300)

            # Histogram of the probability assigned to the true class.
            plt.figure(figsize=(10, 6))

            for i, class_name in enumerate(class_names):
                class_indices = np.where(y_test == i)[0]
                if len(class_indices) == 0:
                    continue

                confidences = y_pred_prob[class_indices, i]

                plt.hist(confidences, alpha=0.7, bins=20,
                         label=f'True {class_name} samples', density=True)

            plt.xlabel('Model Confidence')
            plt.ylabel('Density')
            plt.title('Distribution of Model Confidence for True Samples')
            plt.legend()
            plt.savefig(os.path.join(self.output_path, 'confidence_distribution.png'), dpi=300)

            # Console summary.
            print(f"\n{'='*50}")
            print(f"Model Evaluation Results:")
            print(f"{'='*50}")
            print(f"Test Accuracy: {accuracy:.4f}")
            print(f"\nClassification Report:")
            print(classification_report(y_test, y_pred, labels=label_ids, target_names=class_names))

            print(f"\nConfusion Matrix:")
            print(conf_mat)

            print(f"\nPer-Class Metrics:")
            for class_name in class_names:
                print(f"  {class_name.capitalize()}:")
                print(f"    Precision: {class_precision[class_name]:.4f}")
                print(f"    Recall: {class_recall[class_name]:.4f}")
                print(f"    F1-score: {class_f1[class_name]:.4f}")

            metrics_df = pd.DataFrame({
                'Class': class_names,
                'Precision': [class_precision[c] for c in class_names],
                'Recall': [class_recall[c] for c in class_names],
                'F1-Score': [class_f1[c] for c in class_names],
                'Support': [report[c]['support'] for c in class_names]
            })

            metrics_df.to_csv(os.path.join(self.output_path, 'class_metrics.csv'), index=False)

            return {
                'accuracy': accuracy,
                'classification_report': report,
                'confusion_matrix': conf_mat,
                'class_precision': class_precision,
                'class_recall': class_recall,
                'class_f1': class_f1,
                'pred_probabilities': y_pred_prob
            }

        except Exception as e:
            print(f"Error during model evaluation: {e}")

            # Minimal fallback: accuracy and a plain confusion matrix.
            try:
                y_pred = np.argmax(self.model.predict(X_test), axis=1)
                accuracy = accuracy_score(y_test, y_pred)

                print(f"Fallback evaluation - Test Accuracy: {accuracy:.4f}")

                conf_mat = confusion_matrix(y_test, y_pred)
                plt.figure(figsize=(8, 6))
                sns.heatmap(conf_mat, annot=True, fmt='d')
                plt.xlabel('Predicted')
                plt.ylabel('True')
                plt.title('Confusion Matrix (Fallback)')
                plt.savefig(os.path.join(self.output_path, 'confusion_matrix_fallback.png'))

                return {
                    'accuracy': accuracy,
                    'confusion_matrix': conf_mat
                }

            except Exception as e2:
                print(f"Fallback evaluation also failed: {e2}")
                return {
                    'accuracy': None,
                    'error': str(e)
                }

In [32]:
def visualize_model_attention(self, image, true_label=None, layer_name=None, save_path=None):
    """
    Visualize what the model focuses on using Grad-CAM, with graceful fallbacks.

    The gradient of the predicted class score with respect to a convolutional
    feature map is averaged per channel, used to weight that feature map, and
    the ReLU'd channel mean is overlaid (JET colormap) on the input image. A
    3-panel figure (original / heatmap / overlay) is saved to disk.

    Args:
        image (numpy.ndarray): Input image, grayscale (H, W) / (H, W, 1) or
            colour (H, W, 3), with values in [0, 1] or [0, 255].
        true_label (int, optional): True class index, appended to the plot title.
        layer_name (str, optional): Layer to use for Grad-CAM. When omitted, the
            last Conv2D of a nested backbone model (first layer of
            ``self.model``) is used, else the last Conv2D of ``self.model``,
            else the last ``self.model`` layer with a 4-D output.
        save_path (str, optional): Where to save the figure. Defaults to
            ``<output_path>/grad_cam_visualization.png``.

    Returns:
        numpy.ndarray: uint8 RGB image with the heatmap superimposed. If the
        Grad-CAM model or gradients cannot be built, the edge-based overlay from
        ``self._create_attention_heatmap`` is returned instead. On any other
        error, the original image converted to uint8 (or a blank image) is returned.

    Raises:
        ValueError: If ``self.model`` is None.
    """
    if self.model is None:
        raise ValueError("Model has not been trained yet. Please train the model first.")

    # NOTE(review): index -> name mapping mirrors predict_synapse_type's fallback
    # (index 1 == excitatory); confirm it matches the saved LabelEncoder's classes_.
    class_names = ["Inhibitory", "Excitatory"]

    def _class_name(idx):
        """Human-readable name for a class index (tolerates more than two classes)."""
        idx = int(idx)
        return class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"

    def _to_display_rgb(arr):
        """Convert a grayscale or RGB image in [0, 1] or [0, 255] to uint8 RGB."""
        arr = np.asarray(arr)
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)
        if arr.max() <= 1.0:
            arr = arr * 255
        return np.clip(arr, 0, 255).astype(np.uint8)

    def _is_conv(lyr):
        return isinstance(lyr, layers.Conv2D)

    def _has_4d_output(lyr):
        return len(getattr(lyr, 'output_shape', [])) == 4

    try:
        # Replace NaN/inf so preprocessing and plotting do not propagate them.
        if np.isnan(image).any() or np.isinf(image).any():
            print("Warning: Input image contains NaN or inf values. Fixing...")
            image = np.nan_to_num(image)

        img = self._preprocess_image(image)
        img_array = np.expand_dims(img, axis=0)

        preds = self.model.predict(img_array, verbose=0)
        pred_class = np.argmax(preds[0])
        pred_prob = preds[0][pred_class]

        # A nested Model as the first layer is treated as a pretrained backbone.
        first_layer = self.model.layers[0]
        backbone = first_layer if isinstance(first_layer, tf.keras.models.Model) else None

        if layer_name is None:
            search_order = ([(backbone.layers, _is_conv)] if backbone is not None else []) + [
                (self.model.layers, _is_conv),
                (self.model.layers, _has_4d_output),
            ]
            for layer_list, matches in search_order:
                layer_name = next((lyr.name for lyr in reversed(layer_list) if matches(lyr)), None)
                if layer_name is not None:
                    break
            else:
                raise ValueError("Could not find appropriate layer for Grad-CAM visualization")

        target_in_backbone = backbone is not None and any(
            lyr.name == layer_name for lyr in backbone.layers
        )

        try:
            if target_in_backbone and isinstance(self.model, tf.keras.Sequential):
                # Tensors inside a nested backbone are not connected to the outer
                # model's graph, so Model(outer_inputs, [inner_layer.output, ...])
                # cannot be built. Instead, run the backbone and then the remaining
                # Sequential layers eagerly inside the gradient tape.
                backbone_grad_model = models.Model(
                    inputs=backbone.inputs,
                    outputs=[backbone.get_layer(layer_name).output, backbone.output],
                )
                head_layers = self.model.layers[1:]

                def grad_model(batch):
                    conv_out, features = backbone_grad_model(batch)
                    for head_layer in head_layers:
                        features = head_layer(features)
                    return conv_out, features
            else:
                source = backbone if target_in_backbone else self.model
                grad_model = models.Model(
                    inputs=[self.model.inputs],
                    outputs=[source.get_layer(layer_name).output, self.model.output],
                )
        except Exception:
            # Fallback approach if we can't create a proper grad model.
            print("Warning: Using fallback approach for Grad-CAM")
            superimposed_img = self._create_attention_heatmap(image, img_array, pred_class)

            plt.figure(figsize=(8, 4))
            try:
                plt.subplot(1, 2, 1)
                plt.imshow(_to_display_rgb(image))
                plt.title('Original Image')
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.imshow(superimposed_img)
                plt.title(f'Attention Map - {_class_name(pred_class)} ({pred_prob:.2f})')
                plt.axis('off')

                plt.tight_layout()
                if save_path:
                    plt.savefig(save_path, dpi=300)
                else:
                    plt.savefig(os.path.join(self.output_path, 'fallback_attention_map.png'))
            finally:
                plt.close()  # Free the figure even if plotting fails
            return superimposed_img

        # Compute gradients with error handling.
        try:
            with tf.GradientTape() as tape:
                conv_output, predictions = grad_model(img_array)
                loss = predictions[:, pred_class]

            # Gradients of the top predicted class w.r.t. the feature map.
            grads = tape.gradient(loss, conv_output)
            if grads is None:
                raise ValueError("Gradient is None - check model architecture")
        except Exception as e:
            print(f"Error computing gradients: {e}")
            return self._create_attention_heatmap(image, img_array, pred_class)

        # Per-channel importance = mean gradient over batch and spatial axes.
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # EagerTensors are immutable, so weight the channels by broadcasting
        # rather than in-place item assignment, then average over channels.
        heatmap = tf.reduce_mean(conv_output[0] * pooled_grads, axis=-1)

        # ReLU first, then normalise by the ReLU'd max so an all-negative map
        # yields zeros instead of a division by ~0.
        heatmap = tf.maximum(heatmap, 0)
        heatmap = heatmap / (tf.math.reduce_max(heatmap) + tf.keras.backend.epsilon())
        heatmap = np.asarray(heatmap.numpy(), dtype=np.float32)

        # Resize the heatmap to the original image size (cv2 takes (width, height)).
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.clip(heatmap, 0.0, 1.0)

        image_rgb = _to_display_rgb(image)

        # applyColorMap returns BGR; convert so it matches the RGB image and matplotlib.
        heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        superimposed_img = (heatmap * 0.4 + image_rgb * 0.6).astype(np.uint8)

        plt.figure(figsize=(12, 4))
        try:
            plt.subplot(1, 3, 1)
            plt.imshow(image_rgb)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(heatmap)
            plt.title('Attention Heatmap')
            plt.axis('off')

            plt.subplot(1, 3, 3)
            plt.imshow(superimposed_img)
            title = f'Predicted: {_class_name(pred_class)} ({pred_prob:.2f})'
            if true_label is not None:
                title += f', True: {_class_name(true_label)}'
            plt.title(title)
            plt.axis('off')

            plt.tight_layout()
            if save_path:
                plt.savefig(save_path, dpi=300)
            else:
                plt.savefig(os.path.join(self.output_path, 'grad_cam_visualization.png'), dpi=300)
        finally:
            plt.close()  # Free the figure even if plotting fails

        return superimposed_img

    except Exception as e:
        print(f"Error in Grad-CAM visualization: {e}")
        # Simple fallback: show and save the original image.
        plt.figure(figsize=(6, 6))
        try:
            plt.imshow(_to_display_rgb(image))
            plt.title("Visualization failed - Showing original image")
            plt.axis('off')
            plt.tight_layout()

            fallback_path = os.path.join(self.output_path, 'fallback_visualization.png')
            plt.savefig(fallback_path)
            print(f"Fallback image saved to {fallback_path}")

            # Return the image as uint8 for a consistent output type.
            if image.max() <= 1.0:
                return (image * 255).astype(np.uint8)
            return image.astype(np.uint8)
        except Exception as disp_error:
            print(f"Could not display fallback image: {disp_error}")
            if hasattr(image, 'shape'):
                empty_img = np.zeros_like(image)
                if len(empty_img.shape) == 2:
                    empty_img = np.stack([empty_img] * 3, axis=-1)
                return empty_img.astype(np.uint8)
            # Last resort - a small blank image.
            return np.zeros((224, 224, 3), dtype=np.uint8)
        finally:
            plt.close()  # Close the fallback figure on every path

def _create_attention_heatmap(self, image, preprocessed_input, pred_class):
    """
    Create a simple edge-based "attention" overlay as a fallback for Grad-CAM.

    The model is not used. The image is Gaussian-blurred, passed through Canny
    edge detection, and the edge map is blurred again. The result is normalised,
    colour-mapped with JET and blended 40/60 with the original image.

    Args:
        image (numpy.ndarray): Original image, (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4), with values in [0, 1] or [0, 255].
        preprocessed_input (numpy.ndarray): Preprocessed input batch (unused;
            kept for interface compatibility).
        pred_class (int): Predicted class index (unused; kept for interface
            compatibility).

    Returns:
        numpy.ndarray: uint8 RGB overlay of shape (H, W, 3). If the overlay
        cannot be built, the original image converted to uint8 is returned.
    """
    def _to_uint8(arr):
        """Scale [0, 1] images to [0, 255], then clip and cast to uint8."""
        arr = np.asarray(arr, dtype=np.float64)
        if arr.max() <= 1.0:
            arr = arr * 255
        return np.clip(arr, 0, 255).astype(np.uint8)

    try:
        # Convert once to uint8 so 0-1 and 0-255 inputs are both handled
        # without wrap-around, and the blend below sees matching dtypes.
        img_u8 = _to_uint8(image)
        if img_u8.ndim == 3 and img_u8.shape[2] == 1:
            img_u8 = img_u8[:, :, 0]

        if img_u8.ndim == 2:
            gray = img_u8
            rgb_img = np.stack([img_u8] * 3, axis=-1)
        else:
            # Drop a possible alpha channel; OpenCV needs a contiguous buffer.
            rgb_img = np.ascontiguousarray(img_u8[:, :, :3])
            gray = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)

        # Blur -> edges -> wide blur gives a soft, attention-like map.
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        edges = cv2.Canny(blurred, 50, 150)
        attention_map = cv2.GaussianBlur(edges, (15, 15), 0).astype(np.float32)

        # Normalize to [0, 1]; epsilon avoids division by zero on edge-free images.
        attention_map /= attention_map.max() + 1e-8

        heatmap = cv2.applyColorMap((attention_map * 255).astype(np.uint8), cv2.COLORMAP_JET)
        # applyColorMap returns BGR; the overlay and matplotlib expect RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        return cv2.addWeighted(heatmap, 0.4, rgb_img, 0.6, 0)

    except Exception as e:
        print(f"Error creating fallback attention map: {e}")
        # If all else fails, return the original image as uint8.
        return (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)

In [33]:
def predict_synapse_type(self, image, threshold=0.5, return_visualization=False):
    """
    Predict whether a synapse is excitatory or inhibitory, with a confidence threshold.

    Args:
        image (numpy.ndarray): Input EM image of a synapse (passed through
            ``self._preprocess_image`` before prediction).
        threshold (float): If the top class probability is below this value,
            the label becomes ``"uncertain"``.
        return_visualization (bool): If True, also call
            ``self.visualize_model_attention`` and return its image. The figure
            is saved as ``<output_path>/prediction_<random hex>.png``.

    Returns:
        tuple: ``(label, probability)``, or ``(label, probability, visualization)``
        when ``return_visualization`` is True. ``label`` is decoded with
        ``<output_path>/label_encoder.pkl`` when it can be loaded. Otherwise it
        is ``"excitatory"`` for class index 1 and ``"inhibitory"`` for any other
        index. On any error, ``("error", 0.0)`` or ``("error", 0.0, None)`` is returned.

    Raises:
        ValueError: If ``self.model`` is None.
    """
    if self.model is None:
        raise ValueError("Model has not been trained yet. Please train the model first.")

    try:
        img = self._preprocess_image(image)
        img_array = np.expand_dims(img, axis=0)

        preds = self.model.predict(img_array, verbose=0)
        pred_class = np.argmax(preds[0])
        pred_prob = preds[0][pred_class]

        try:
            # NOTE: unpickling can execute code; this only loads the encoder the
            # pipeline itself saved under output_path.
            with open(os.path.join(self.output_path, 'label_encoder.pkl'), 'rb') as f:
                label_encoder = pickle.load(f)
            pred_label = label_encoder.inverse_transform([pred_class])[0]
        except Exception:
            # Fallback if the label encoder is missing or unusable.
            pred_label = "excitatory" if pred_class == 1 else "inhibitory"

        if pred_prob < threshold:
            pred_label = "uncertain"

        if return_visualization:
            import uuid

            # uuid (not np.random) so the globally seeded NumPy RNG is untouched
            # and file names practically never collide.
            viz_path = os.path.join(self.output_path, f'prediction_{uuid.uuid4().hex[:8]}.png')
            viz_img = self.visualize_model_attention(image, layer_name=None, save_path=viz_path)
            return pred_label, pred_prob, viz_img

        return pred_label, pred_prob

    except Exception as e:
        print(f"Error predicting synapse type: {e}")
        if return_visualization:
            return "error", 0.0, None
        return "error", 0.0

In [35]:
def map_synapse_distribution(self, volume_path, segmentation_path, output_file, confidence_threshold=0.7):
    """
    Map the distribution of excitatory and inhibitory synapses in an EM volume.

    Synapses are the connected components (``scipy.ndimage.label``) of the
    segmentation. Each one is classified with ``self.predict_synapse_type``.
    On both loading paths, only predictions with confidence >=
    ``confidence_threshold`` are kept for the statistics and plots.

    Loading strategy (decided by on-disk file size):
        * Volume file > 2 GB: the volume is read lazily from HDF5.
          - The segmentation is labelled in one pass if its file is < 1 GB.
            Otherwise it is labelled in z-slabs of at most 50 planes, and a
            synapse that crosses a slab boundary is counted once per slab
            (known limitation).
          - For every centroid inside the volume, the middle z-plane of a
            window of up to 64 voxels per axis is used as the 2-D patch.
        * Otherwise: both datasets are loaded fully and the patches come from
          ``self.extract_synapse_regions``.

    Files written (``<stem>`` is ``output_file`` without its extension):
        * ``output_file``           - static 3-D Matplotlib scatter plot
        * ``<stem>_planes.png``     - XY / XZ / YZ projections
        * ``<stem>.html``           - interactive Plotly view (only if plotly imports)
        * ``<stem>_stats.json``     - counts, E/I ratio and the image file list

    NOTE(review): a later notebook cell defines another ``map_synapse_distribution``;
    whichever cell runs last is the one in effect.

    Args:
        volume_path (str): HDF5 file with the EM volume (dataset ``em_volume``,
            or the first top-level dataset).
        segmentation_path (str): HDF5 file with the synapse mask (dataset
            ``synapse_mask``, or the first top-level dataset).
        output_file (str): Path of the static 3-D plot; the other outputs share its stem.
        confidence_threshold (float): Minimum prediction confidence for a synapse to be kept.

    Returns:
        dict: On success, a dict with these keys:
            ``synapse_counts`` (type -> int), ``synapse_coords`` ((z, y, x)
            arrays), ``synapse_types``, ``synapse_confidences``, ``e_i_ratio``
            (excitatory / max(1, inhibitory)), ``total_confident_synapses`` and
            ``visualization_files``.
        On any error, a dict with ``error``, an empty ``synapse_counts`` and
        zero totals is returned instead.
    """

    def _dataset_key(h5file, preferred_key, description):
        """Return `preferred_key` if present, else the name of the first top-level dataset."""
        if preferred_key in h5file:
            return preferred_key
        for key in h5file.keys():
            if isinstance(h5file[key], h5py.Dataset):
                return key
        raise ValueError(f"Could not find a dataset in the {description} file")

    def _component_centroids(mask):
        """Label connected components of `mask`; return each one's mean (z, y, x) voxel coordinate."""
        labeled, num_features = ndi.label(mask)
        centroids = []
        # find_objects gives each label's bounding box, so only that box is
        # scanned instead of the whole array once per component.
        for label_id, bbox in enumerate(ndi.find_objects(labeled), start=1):
            if bbox is None:
                continue
            local = np.nonzero(labeled[bbox] == label_id)
            centroids.append(np.array([idx.mean() + sl.start for idx, sl in zip(local, bbox)], dtype=float))
        return centroids

    try:
        print(f"Loading volume from {volume_path}...")

        # On-disk size decides between lazy HDF5 access and full in-memory loading.
        volume_size = os.path.getsize(volume_path) / (1024 * 1024 * 1024)  # GB
        print(f"Volume file size: {volume_size:.2f} GB")

        if volume_size > 2.0:
            print("Large volume detected, using chunked loading...")
            synapse_coords, patches = [], []

            with h5py.File(volume_path, 'r') as vol_f:
                volume_ds = vol_f[_dataset_key(vol_f, 'em_volume', 'volume')]
                volume_shape = volume_ds.shape
                print(f"Volume shape: {volume_shape}")

                print(f"Loading segmentation from {segmentation_path}...")
                with h5py.File(segmentation_path, 'r') as seg_f:
                    seg_ds = seg_f[_dataset_key(seg_f, 'synapse_mask', 'segmentation')]
                    seg_size = os.path.getsize(segmentation_path) / (1024 * 1024 * 1024)

                    if seg_size < 1.0:
                        # Small enough to label in a single pass.
                        centroids = _component_centroids(seg_ds[:])
                    else:
                        print("Large segmentation detected, finding synapse coordinates first...")
                        centroids = []
                        depth = seg_ds.shape[0]
                        # ~10 slabs of at most 50 planes; max(1, ...) keeps the
                        # step non-zero for volumes shallower than 10 planes.
                        chunk_size = max(1, min(50, depth // 10))
                        for z_start in range(0, depth, chunk_size):
                            z_end = min(z_start + chunk_size, depth)
                            print(f"Processing segmentation chunk {z_start}-{z_end}...")
                            for centroid in _component_centroids(seg_ds[z_start:z_end, :, :]):
                                centroid[0] += z_start  # slab-local z -> volume z
                                centroids.append(centroid)

                # Read one 2-D slice per synapse from the already-open dataset.
                patch_size = 64
                half = patch_size // 2
                for centroid in tqdm(centroids, desc="Extracting synapse patches"):
                    z, y, x = (int(c) for c in centroid)
                    if not (0 <= z < volume_shape[0] and 0 <= y < volume_shape[1] and 0 <= x < volume_shape[2]):
                        continue  # centroid outside the EM volume: no patch
                    window = volume_ds[
                        max(0, z - half):min(volume_shape[0], z + half),
                        max(0, y - half):min(volume_shape[1], y + half),
                        max(0, x - half):min(volume_shape[2], x + half),
                    ]
                    # Append coordinates and patches together so they stay index-aligned.
                    synapse_coords.append(centroid)
                    patches.append(window[window.shape[0] // 2])

            print(f"Extracted {len(patches)} synapse patches directly")
        else:
            with h5py.File(volume_path, 'r') as vol_f:
                volume = vol_f[_dataset_key(vol_f, 'em_volume', 'volume')][:]

            print(f"Loading segmentation from {segmentation_path}...")
            with h5py.File(segmentation_path, 'r') as seg_f:
                segmentation = seg_f[_dataset_key(seg_f, 'synapse_mask', 'segmentation')][:]

            # NOTE(review): assumes extract_synapse_regions returns one patch per
            # connected component, in ndi.label order - confirm.
            patches = self.extract_synapse_regions(volume, segmentation)
            print(f"Extracted {len(patches)} synapse patches")
            synapse_coords = _component_centroids(segmentation)

        # Classify every synapse; components without a patch become "unknown".
        synapse_types, confidences = [], []
        for idx in tqdm(range(len(synapse_coords)), desc="Predicting synapse types"):
            if idx < len(patches):
                syn_type, conf = self.predict_synapse_type(patches[idx], threshold=confidence_threshold)
            else:
                syn_type, conf = "unknown", 0.0
            synapse_types.append(syn_type)
            confidences.append(conf)

        # The same confidence filter applies regardless of the loading path.
        confident = [i for i, conf in enumerate(confidences) if conf >= confidence_threshold]
        filtered_types = [synapse_types[i] for i in confident]
        filtered_coords = [synapse_coords[i] for i in confident]
        filtered_confidences = [confidences[i] for i in confident]
        print(f"Using {len(filtered_types)} synapses with confidence >= {confidence_threshold}")

        synapse_counts = pd.Series(filtered_types, dtype=object).value_counts()
        print("\nSynapse Type Distribution:")
        for syn_type, count in synapse_counts.items():
            percentage = (count / len(filtered_types)) * 100
            print(f"  {str(syn_type).capitalize()}: {count} ({percentage:.1f}%)")

        color_map = {
            'excitatory': 'red',
            'inhibitory': 'blue',
            'uncertain': 'gray',
            'unknown': 'black'
        }

        # 1. Interactive 3-D view with plotly (optional).
        try:
            import plotly.graph_objects as go
            from plotly.subplots import make_subplots
        except ImportError:
            go = None
            print("Plotly not available, falling back to Matplotlib for visualization")

        if go is not None:
            try:
                fig = make_subplots(
                    rows=1, cols=2,
                    specs=[[{'type': 'scatter3d'}, {'type': 'pie'}]],
                    subplot_titles=('Synapse Distribution in 3D', 'Synapse Type Proportions')
                )

                # One trace per type gives a clickable legend.
                for syn_type in dict.fromkeys(filtered_types):
                    members = [i for i, t in enumerate(filtered_types) if t == syn_type]
                    type_coords = [filtered_coords[i] for i in members]
                    type_confs = [filtered_confidences[i] for i in members]
                    fig.add_trace(
                        go.Scatter3d(
                            x=[c[2] for c in type_coords],
                            y=[c[1] for c in type_coords],
                            z=[c[0] for c in type_coords],
                            mode='markers',
                            marker=dict(
                                size=[conf * 10 + 5 for conf in type_confs],
                                color=color_map.get(syn_type, 'purple'),
                                opacity=0.8,
                                symbol='circle'
                            ),
                            name=f"{str(syn_type).capitalize()} Synapses",
                            hovertext=[f"Type: {syn_type}<br>Confidence: {conf:.2f}" for conf in type_confs]
                        ),
                        row=1, col=1
                    )

                labels = synapse_counts.index.tolist()
                fig.add_trace(
                    go.Pie(
                        labels=labels,
                        values=synapse_counts.values.tolist(),
                        textinfo='label+percent',
                        marker=dict(colors=[color_map.get(label, 'purple') for label in labels])
                    ),
                    row=1, col=2
                )

                fig.update_layout(
                    title_text="Synapse Distribution Analysis",
                    height=800,
                    width=1200,
                    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')
                )

                html_output = os.path.splitext(output_file)[0] + '.html'
                fig.write_html(html_output)
                print(f"Interactive 3D visualization saved to {html_output}")
            except Exception as plot_error:
                # The interactive view is optional; continue with the static outputs.
                print(f"Interactive visualization failed: {plot_error}")

        # 2. Static 3-D scatter with matplotlib (one scatter call per type).
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        for syn_type in dict.fromkeys(filtered_types):
            members = [i for i, t in enumerate(filtered_types) if t == syn_type]
            pts = np.array([filtered_coords[i] for i in members], dtype=float)
            ax.scatter(
                pts[:, 2], pts[:, 1], pts[:, 0],
                c=color_map.get(syn_type, 'black'),
                s=[filtered_confidences[i] * 50 + 20 for i in members],
                alpha=0.7,
                edgecolors='white',
                linewidth=0.5
            )

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Distribution of Synapses\n'
                     f'Excitatory: {synapse_counts.get("excitatory", 0)} | '
                     f'Inhibitory: {synapse_counts.get("inhibitory", 0)} | '
                     f'Uncertain: {synapse_counts.get("uncertain", 0)}')

        from matplotlib.lines import Line2D

        legend_elements = [
            Line2D([0], [0], marker='o', color='w',
                   label=f'{name} ({synapse_counts.get(name.lower(), 0)})',
                   markerfacecolor=color, markersize=10)
            for name, color in [
                ('Excitatory', 'red'),
                ('Inhibitory', 'blue'),
                ('Uncertain', 'gray'),
                ('Unknown', 'black')
            ]
            if name.lower() in synapse_counts
        ]
        ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1))

        if filtered_coords:
            padding = 20  # padding around min/max on every axis
            x_coords = [c[2] for c in filtered_coords]
            y_coords = [c[1] for c in filtered_coords]
            z_coords = [c[0] for c in filtered_coords]
            ax.set_xlim(min(x_coords) - padding, max(x_coords) + padding)
            ax.set_ylim(min(y_coords) - padding, max(y_coords) + padding)
            ax.set_zlim(min(z_coords) - padding, max(z_coords) + padding)

        ax.grid(True, linestyle='--', alpha=0.5)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300)
        plt.close()
        print(f"3D visualization saved to {output_file}")

        # 3. 2-D projections onto the XY, XZ and YZ planes.
        #    Coordinates are stored (z, y, x): index 2 = X, 1 = Y, 0 = Z.
        plane_specs = [
            ((2, 1), ('X', 'Y'), 'XY'),
            ((2, 0), ('X', 'Z'), 'XZ'),
            ((1, 0), ('Y', 'Z'), 'YZ'),
        ]
        plt.figure(figsize=(18, 6))
        for panel, ((h_axis, v_axis), (h_label, v_label), plane) in enumerate(plane_specs, start=1):
            plt.subplot(1, 3, panel)
            for syn_type, color in (('excitatory', 'red'), ('inhibitory', 'blue')):
                pts = [filtered_coords[i] for i, t in enumerate(filtered_types) if t == syn_type]
                if not pts:
                    continue
                plt.scatter([p[h_axis] for p in pts], [p[v_axis] for p in pts],
                            c=color, alpha=0.5, label=syn_type.capitalize())
            plt.xlabel(h_label)
            plt.ylabel(v_label)
            plt.title(f'Synapse Distribution ({plane} Plane)')
            plt.legend()
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        planes_output = os.path.splitext(output_file)[0] + '_planes.png'
        plt.savefig(planes_output, dpi=300)
        plt.close()
        print(f"2D plane visualizations saved to {planes_output}")

        # Native Python ints/strs so the counts are JSON-serialisable.
        counts = {str(k): int(v) for k, v in synapse_counts.items()}
        distribution = {
            'synapse_counts': counts,
            'synapse_coords': filtered_coords,
            'synapse_types': filtered_types,
            'synapse_confidences': filtered_confidences,
            'e_i_ratio': counts.get('excitatory', 0) / max(1, counts.get('inhibitory', 0)),
            'total_confident_synapses': len(filtered_types),
            'visualization_files': [output_file, planes_output]
        }

        # Coordinates/types are omitted from the JSON summary (too large).
        json_safe_distribution = {
            'synapse_counts': distribution['synapse_counts'],
            'e_i_ratio': float(distribution['e_i_ratio']),
            'total_confident_synapses': distribution['total_confident_synapses'],
            'visualization_files': distribution['visualization_files'],
        }

        json_output = os.path.splitext(output_file)[0] + '_stats.json'
        with open(json_output, 'w') as f:
            json.dump(json_safe_distribution, f, indent=2)
        print(f"Distribution statistics saved to {json_output}")

        return distribution

    except Exception as e:
        print(f"Error mapping synapse distribution: {e}")
        import traceback
        traceback.print_exc()

        return {
            'error': str(e),
            'synapse_counts': {},
            'total_synapses': 0,
            'total_confident_synapses': 0
        }

In [38]:
def map_synapse_distribution(self, volume_path, segmentation_path, output_file, confidence_threshold=0.7):
    """
    Map the distribution of excitatory and inhibitory synapses in a volume with optimized memory handling

    Volumes whose file is larger than 2 GB are never loaded whole. Synapse centroids are found
    from the segmentation, in one pass if its file is under 1 GB and otherwise in z-slabs. Then
    only the middle z-slice of a 64-voxel cube around each centroid is read from the volume file.
    Smaller volumes are loaded fully and patches come from ``self.extract_synapse_regions``.

    In both cases, predictions with confidence below ``confidence_threshold`` are dropped, unless
    that would drop every synapse. The function then writes:
      - a Matplotlib 3D scatter to ``output_file``;
      - 2D plane projections to ``<output_file stem>_planes.png``;
      - an optional interactive Plotly page to ``<stem>.html``;
      - a JSON summary to ``<stem>_stats.json``.

    Args:
        volume_path (str): Path to the EM volume (HDF5; dataset ``em_volume`` or the first dataset)
        segmentation_path (str): Path to the synapse segmentation (HDF5; dataset ``synapse_mask``
            or the first dataset)
        output_file (str): Path to save the distribution map
        confidence_threshold (float): Threshold for prediction confidence

    Returns:
        dict: On success, a dict with these keys:
            ``synapse_counts``, ``synapse_coords`` ((z, y, x) centroids),
            ``synapse_types``, ``synapse_confidences``, ``e_i_ratio``,
            ``total_confident_synapses`` and ``visualization_files`` (only the image files
            that were actually written).
            On any error, a dict with ``error``, an empty ``synapse_counts`` and
            ``total_synapses`` = 0.
    """
    try:
        print(f"Loading volume from {volume_path}...")

        # Check file size before loading to decide between in-memory and out-of-core processing
        volume_size = os.path.getsize(volume_path) / (1024 * 1024 * 1024)  # Size in GB
        print(f"Volume file size: {volume_size:.2f} GB")

        if volume_size > 2.0:  # More than 2GB
            print("Large volume detected, using chunked loading...")
            synapse_types, confidences, synapse_coords = _classify_synapses_out_of_core(
                self, volume_path, segmentation_path, confidence_threshold
            )
        else:
            synapse_types, confidences, synapse_coords = _classify_synapses_in_memory(
                self, volume_path, segmentation_path, confidence_threshold
            )

        # The same confidence filter is applied whichever loading path produced the predictions
        filtered_types, filtered_coords, filtered_confidences = _filter_synapses_by_confidence(
            synapse_types, synapse_coords, confidences, confidence_threshold
        )

        if not filtered_types:
            print("No synapses found to visualize. Returning empty results.")
            return {
                'synapse_counts': {},
                'synapse_coords': [],
                'synapse_types': [],
                'synapse_confidences': [],
                'e_i_ratio': 0.0,
                'total_confident_synapses': 0,
                'visualization_files': []
            }

        # Count the number of each type
        synapse_counts = pd.Series(filtered_types).value_counts()
        print("\nSynapse Type Distribution:")
        for syn_type, count in synapse_counts.items():
            percentage = (count / len(filtered_types)) * 100
            print(f"  {str(syn_type).capitalize()}: {count} ({percentage:.1f}%)")

        # 1. Interactive Plotly view (best effort; not listed in visualization_files)
        _save_interactive_synapse_plot(
            filtered_types, filtered_coords, filtered_confidences, synapse_counts, output_file
        )

        # 2. + 3. Static Matplotlib views; only files that were actually written are returned
        viz_files = _save_static_synapse_plots(
            filtered_types, filtered_coords, filtered_confidences, synapse_counts, output_file
        )

        # E/I ratio; the denominator is clamped to 1 so a volume without inhibitory synapses
        # does not divide by zero (the ratio then equals the excitatory count)
        e_i_ratio = synapse_counts.get('excitatory', 0) / max(1, synapse_counts.get('inhibitory', 0))

        distribution = {
            'synapse_counts': synapse_counts.to_dict(),
            'synapse_coords': filtered_coords,
            'synapse_types': filtered_types,
            'synapse_confidences': filtered_confidences,
            'e_i_ratio': e_i_ratio,
            'total_confident_synapses': len(filtered_types),
            'visualization_files': viz_files
        }

        _save_distribution_stats(distribution, output_file)

        return distribution

    except Exception as e:
        print(f"Error mapping synapse distribution: {e}")
        import traceback
        traceback.print_exc()

        # Return basic error info
        return {
            'error': str(e),
            'synapse_counts': {},
            'total_synapses': 0
        }


def _find_h5_dataset_key(h5_file, preferred_key, description):
    """Return ``preferred_key`` if it exists in ``h5_file``, else the name of its first top-level dataset.

    Raises:
        ValueError: if the file contains no top-level dataset.
    """
    if preferred_key in h5_file:
        return preferred_key
    for key in h5_file.keys():
        if isinstance(h5_file[key], h5py.Dataset):
            return key
    raise ValueError(f"Could not find a dataset in the {description} file")


def _synapse_centroids(mask, z_offset=0):
    """Label connected components of ``mask`` and return one float centroid per component, in label order.

    ``z_offset`` is added to the first coordinate so centroids of a z-slab are expressed
    in whole-volume coordinates.
    """
    labeled, _ = ndi.label(mask)
    centroids = []
    # find_objects gives each label's bounding box, so every component is scanned only inside
    # its own box instead of comparing the whole array once per label (which is quadratic).
    for label_id, bbox in enumerate(ndi.find_objects(labeled), start=1):
        if bbox is None:
            continue
        local_idx = np.nonzero(labeled[bbox] == label_id)
        centroid = np.array(
            [idx.mean() + axis_slice.start for idx, axis_slice in zip(local_idx, bbox)],
            dtype=float,
        )
        centroid[0] += z_offset
        centroids.append(centroid)
    return centroids


def _read_synapse_center_slices(volume_ds, synapse_coords, patch_size=64):
    """Read, for each centroid, the middle z-slice of a ``patch_size`` cube (clipped to the volume)."""
    half = patch_size // 2
    depth, height, width = volume_ds.shape[:3]
    patches = []
    for coords in tqdm(synapse_coords, desc="Extracting synapse patches"):
        z, y, x = (int(c) for c in coords)
        z_min, z_max = max(0, z - half), min(depth, z + half)
        y_min, y_max = max(0, y - half), min(height, y + half)
        x_min, x_max = max(0, x - half), min(width, x + half)
        # Only the middle plane of the cube is used, so read just that plane from disk.
        # One patch is appended per centroid, keeping patches aligned with synapse_coords.
        z_mid = z_min + (z_max - z_min) // 2
        patches.append(volume_ds[z_mid, y_min:y_max, x_min:x_max])
    return patches


def _classify_synapses_out_of_core(self, volume_path, segmentation_path, confidence_threshold,
                                   patch_size=64, max_slab_depth=50):
    """Classify synapses of a large volume without loading it into memory.

    Returns:
        tuple: (synapse_types, confidences, synapse_coords), aligned index-by-index.
    """
    with h5py.File(volume_path, 'r') as vol_f:
        volume_ds = vol_f[_find_h5_dataset_key(vol_f, 'em_volume', 'volume')]
        volume_shape = volume_ds.shape
        print(f"Volume shape: {volume_shape}")

        print(f"Loading segmentation from {segmentation_path}...")
        with h5py.File(segmentation_path, 'r') as seg_f:
            seg_ds = seg_f[_find_h5_dataset_key(seg_f, 'synapse_mask', 'segmentation')]
            seg_size = os.path.getsize(segmentation_path) / (1024 * 1024 * 1024)
            if seg_size < 1.0:  # Less than 1GB: label the whole segmentation in one pass
                synapse_coords = _synapse_centroids(seg_ds[:])
            else:
                print("Large segmentation detected, finding synapse coordinates first...")
                synapse_coords = []
                # Slabs of at most `max_slab_depth` slices (a tenth of the depth for shallow
                # volumes); max(1, ...) keeps the step valid for volumes with < 10 slices.
                chunk_size = max(1, min(max_slab_depth, volume_shape[0] // 10))
                for z_start in range(0, volume_shape[0], chunk_size):
                    z_end = min(z_start + chunk_size, volume_shape[0])
                    print(f"Processing segmentation chunk {z_start}-{z_end}...")
                    # NOTE: components are labelled per slab, so a synapse that crosses a slab
                    # boundary is counted once in every slab it touches.
                    synapse_coords.extend(
                        _synapse_centroids(seg_ds[z_start:z_end], z_offset=z_start)
                    )

        # Reuse the already-open volume dataset instead of reopening the file per synapse
        patches = _read_synapse_center_slices(volume_ds, synapse_coords, patch_size)

    print(f"Extracted {len(patches)} synapse patches directly")

    synapse_types = []
    confidences = []
    for patch in tqdm(patches, desc="Predicting synapse types"):
        syn_type, conf = self.predict_synapse_type(patch, threshold=confidence_threshold)
        synapse_types.append(syn_type)
        confidences.append(conf)

    return synapse_types, confidences, synapse_coords


def _classify_synapses_in_memory(self, volume_path, segmentation_path, confidence_threshold):
    """Load volume and segmentation fully, then classify each labelled synapse.

    Returns:
        tuple: (synapse_types, confidences, synapse_coords), aligned index-by-index.
    """
    with h5py.File(volume_path, 'r') as f:
        volume = f[_find_h5_dataset_key(f, 'em_volume', 'volume')][:]

    print(f"Loading segmentation from {segmentation_path}...")
    with h5py.File(segmentation_path, 'r') as f:
        segmentation = f[_find_h5_dataset_key(f, 'synapse_mask', 'segmentation')][:]

    patches = self.extract_synapse_regions(volume, segmentation)
    print(f"Extracted {len(patches)} synapse patches")
    del volume  # No longer needed; release it before labelling the segmentation

    # Centroid i belongs to label i + 1 and is paired with patches[i].
    # NOTE(review): assumes extract_synapse_regions returns patches in ndi.label order — confirm.
    synapse_coords = _synapse_centroids(segmentation)

    synapse_types = []
    confidences = []
    for idx in tqdm(range(len(synapse_coords)), desc="Predicting synapse types"):
        if idx < len(patches):
            syn_type, conf = self.predict_synapse_type(patches[idx], threshold=confidence_threshold)
        else:
            # No patch for this component: keep it, but mark it as unknown
            syn_type, conf = "unknown", 0.0
        synapse_types.append(syn_type)
        confidences.append(conf)

    return synapse_types, confidences, synapse_coords


def _filter_synapses_by_confidence(synapse_types, synapse_coords, confidences, threshold):
    """Keep synapses whose confidence is >= ``threshold``; fall back to all of them if none qualify."""
    keep = [i for i, conf in enumerate(confidences) if conf >= threshold]
    if keep:
        print(f"Using {len(keep)} synapses with confidence >= {threshold}")
        return (
            [synapse_types[i] for i in keep],
            [synapse_coords[i] for i in keep],
            [confidences[i] for i in keep],
        )
    if confidences:
        print("Warning: No synapses met the confidence threshold. Using all synapses.")
    return list(synapse_types), list(synapse_coords), list(confidences)


def _indices_by_type(synapse_types):
    """Group list positions by synapse type, preserving first-seen type order."""
    groups = {}
    for idx, syn_type in enumerate(synapse_types):
        groups.setdefault(syn_type, []).append(idx)
    return groups


def _save_interactive_synapse_plot(filtered_types, filtered_coords, filtered_confidences,
                                   synapse_counts, output_file):
    """Write a Plotly 3D scatter + pie chart to ``<output_file stem>.html``; return its path or None."""
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except ImportError:
        print("Plotly not available, falling back to Matplotlib for visualization")
        return None

    try:
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scatter3d'}, {'type': 'pie'}]],
            subplot_titles=('Synapse Distribution in 3D', 'Synapse Type Proportions')
        )

        color_map = {
            'excitatory': 'red',
            'inhibitory': 'blue',
            'uncertain': 'gray',
            'unknown': 'black'
        }

        # One trace per type so the interactive legend can toggle types
        for syn_type, indices in _indices_by_type(filtered_types).items():
            type_coords = [filtered_coords[i] for i in indices]
            type_confs = [filtered_confidences[i] for i in indices]
            fig.add_trace(
                go.Scatter3d(
                    x=[coord[2] for coord in type_coords],
                    y=[coord[1] for coord in type_coords],
                    z=[coord[0] for coord in type_coords],
                    mode='markers',
                    marker=dict(
                        size=[conf * 10 + 5 for conf in type_confs],  # Size points by confidence
                        color=color_map.get(syn_type, 'purple'),
                        opacity=0.8,
                        symbol='circle'
                    ),
                    name=f"{str(syn_type).capitalize()} Synapses",
                    hovertext=[f"Type: {syn_type}<br>Confidence: {conf:.2f}" for conf in type_confs]
                ),
                row=1, col=1
            )

        labels = synapse_counts.index.tolist()
        fig.add_trace(
            go.Pie(
                labels=labels,
                values=synapse_counts.values.tolist(),
                textinfo='label+percent',
                marker=dict(colors=[color_map.get(label, 'purple') for label in labels])
            ),
            row=1, col=2
        )

        fig.update_layout(
            title_text="Synapse Distribution Analysis",
            height=800,
            width=1200,
            scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z')
        )

        html_output = os.path.splitext(output_file)[0] + '.html'
        fig.write_html(html_output)
        print(f"Interactive 3D visualization saved to {html_output}")
        return html_output

    except Exception as viz_error:
        print(f"Error creating interactive 3D visualization: {viz_error}")
        print("Falling back to standard Matplotlib visualization")
        return None


def _save_synapse_3d_plot(filtered_types, filtered_coords, filtered_confidences,
                          synapse_counts, output_file):
    """Save a Matplotlib 3D scatter (colour = type, size = confidence) to ``output_file``; return the path."""
    from matplotlib.lines import Line2D

    color_map = {'excitatory': 'red', 'inhibitory': 'blue', 'uncertain': 'gray'}

    fig = plt.figure(figsize=(12, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # One scatter call per type (rather than per synapse) keeps the artist count small
        for syn_type, indices in _indices_by_type(filtered_types).items():
            ax.scatter(
                [filtered_coords[i][2] for i in indices],
                [filtered_coords[i][1] for i in indices],
                [filtered_coords[i][0] for i in indices],
                c=color_map.get(syn_type, 'black'),
                s=[filtered_confidences[i] * 50 + 20 for i in indices],
                alpha=0.7,
                edgecolors='white',
                linewidth=0.5
            )

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Distribution of Synapses\n'
                     f'Excitatory: {synapse_counts.get("excitatory", 0)} | '
                     f'Inhibitory: {synapse_counts.get("inhibitory", 0)} | '
                     f'Uncertain: {synapse_counts.get("uncertain", 0)}')

        legend_elements = [
            Line2D([0], [0], marker='o', color='w',
                   label=f'{label} ({synapse_counts.get(label.lower(), 0)})',
                   markerfacecolor=color, markersize=10)
            for label, color in (
                ('Excitatory', 'red'),
                ('Inhibitory', 'blue'),
                ('Uncertain', 'gray'),
                ('Unknown', 'black')
            )
            if label.lower() in synapse_counts.index
        ]
        if legend_elements:  # Only add legend if we have elements
            ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1))

        if filtered_coords:
            padding = 20  # Add padding around min/max
            xs = [coord[2] for coord in filtered_coords]
            ys = [coord[1] for coord in filtered_coords]
            zs = [coord[0] for coord in filtered_coords]
            ax.set_xlim(min(xs) - padding, max(xs) + padding)
            ax.set_ylim(min(ys) - padding, max(ys) + padding)
            ax.set_zlim(min(zs) - padding, max(zs) + padding)

        ax.grid(True, linestyle='--', alpha=0.5)

        fig.tight_layout()
        fig.savefig(output_file, dpi=300)
    finally:
        plt.close(fig)  # Free the figure even if drawing/saving failed

    print(f"3D visualization saved to {output_file}")
    return output_file


def _save_synapse_plane_plots(filtered_types, filtered_coords, output_file):
    """Save XY/XZ/YZ scatter projections of excitatory/inhibitory synapses; return the image path."""
    planes_output = os.path.splitext(output_file)[0] + '_planes.png'
    # (horizontal label, coord index, vertical label, coord index); coords are ordered (z, y, x)
    planes = (('X', 2, 'Y', 1), ('X', 2, 'Z', 0), ('Y', 1, 'Z', 0))
    groups = _indices_by_type(filtered_types)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    try:
        for ax, (h_label, h_idx, v_label, v_idx) in zip(axes, planes):
            for syn_type, color in (('excitatory', 'red'), ('inhibitory', 'blue')):
                indices = groups.get(syn_type)
                if not indices:
                    continue
                ax.scatter(
                    [filtered_coords[i][h_idx] for i in indices],
                    [filtered_coords[i][v_idx] for i in indices],
                    c=color, alpha=0.5, label=syn_type.capitalize()
                )
            ax.set_xlabel(h_label)
            ax.set_ylabel(v_label)
            ax.set_title(f'Synapse Distribution ({h_label}{v_label} Plane)')
            if ax.get_legend_handles_labels()[0]:  # Avoid an empty-legend warning
                ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(planes_output, dpi=300)
    finally:
        plt.close(fig)  # Free the figure even if drawing/saving failed

    print(f"2D plane visualizations saved to {planes_output}")
    return planes_output


def _save_static_synapse_plots(filtered_types, filtered_coords, filtered_confidences,
                               synapse_counts, output_file):
    """Write the Matplotlib 3D and plane plots; return the list of files actually saved."""
    saved_files = []
    try:
        saved_files.append(_save_synapse_3d_plot(
            filtered_types, filtered_coords, filtered_confidences, synapse_counts, output_file
        ))
        saved_files.append(_save_synapse_plane_plots(filtered_types, filtered_coords, output_file))
    except Exception as matplotlib_error:
        print(f"Error creating matplotlib visualizations: {matplotlib_error}")
        import traceback
        traceback.print_exc()
        print("Visualization failed, but continuing with data analysis")
    return saved_files


def _save_distribution_stats(distribution, output_file):
    """Write counts, E/I ratio and file list (not per-synapse data) to ``<stem>_stats.json``."""
    import json

    try:
        # Coordinates/types are omitted (too large); values are cast to native JSON types
        json_safe_distribution = {
            'synapse_counts': {str(k): int(v) for k, v in distribution['synapse_counts'].items()},
            'e_i_ratio': float(distribution['e_i_ratio']),
            'total_confident_synapses': int(distribution['total_confident_synapses']),
            'visualization_files': list(distribution['visualization_files']),
        }

        json_output = os.path.splitext(output_file)[0] + '_stats.json'
        with open(json_output, 'w') as f:
            json.dump(json_safe_distribution, f, indent=2)

        print(f"Distribution statistics saved to {json_output}")
    except Exception as json_error:
        print(f"Error saving JSON statistics: {json_error}")

In [None]:
cv_results[f'precision_{label}'].append(class_metrics['precision'])
                            cv_results[f'recall_{label}'].append(class_metrics['recall'])
                            cv_results[f'f1_{label}'].append(class_metrics['f1-score'])

                    # Free memory
                    del self.model
                    gc.collect()

                # Calculate means and standard deviations
                cv_summary = {}
                for metric, values in cv_results.items():
                    if metric != 'confusion_matrices':
                        cv_summary[f'{metric}_mean'] = np.mean(values)
                        cv_summary[f'{metric}_std'] = np.std(values)

                # Save cross-validation results
                with open(os.path.join(self.output_path, 'cv_results.pkl'), 'wb') as f:
                    pickle.dump(cv_results, f)

                # Create a summary report
                cv_report = pd.DataFrame({
                    'Metric': list(cv_summary.keys()),
                    'Value': list(cv_summary.values())
                })
                cv_report.to_csv(os.path.join(self.output_path, 'cv_summary.csv'), index=False)

                print(f"\nCross-validation completed in {time.time() - cv_start_time:.2f} seconds")
                print("\nCross-validation Summary:")
                print(f"  Mean Accuracy: {cv_summary['accuracy_mean']:.4f} ± {cv_summary['accuracy_std']:.4f}")

                # Plot cross-validation results
                plt.figure(figsize=(12, 6))
                metrics_to_plot = ['accuracy']
                for label in set([0, 1]):  # Binary classification
                    metrics_to_plot.extend([f'precision_{label}', f'recall_{label}', f'f1_{label}'])

                for i, metric in enumerate(metrics_to_plot):
                    if metric in cv_results:
                        plt.subplot(1, len(metrics_to_plot), i+1)
                        plt.boxplot(cv_results[metric])
                        plt.title(metric.replace('_', ' ').title())
                        plt.ylim(0, 1)

                plt.tight_layout()
                plt.savefig(os.path.join(self.output_path, 'cv_results.png'))

            # 3. Train on full dataset for final model
            print("\n3. Training final model on full dataset...")
            train_start_time = time.time()

            # Split data for final evaluation
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )

            # Build and train the final model
            self.model = self.build_classification_model(use_pretrained=True)
            history = self.train_model(X_train, y_train, epochs=50, fine_tune=True)

            print(f"Model training completed in {time.time() - train_start_time:.2f} seconds")

            # 4. Evaluate the final model
            print("\n4. Evaluating final model...")
            eval_start_time = time.time()
            metrics = self.evaluate_model(X_test, y_test)
            print(f"Evaluation completed in {time.time() - eval_start_time:.2f} seconds")

            # 5. Visualize model attention on sample images
            print("\n5. Generating attention visualizations...")
            viz_start_time = time.time()

            # Create a directory for visualizations
            viz_dir = os.path.join(self.output_path, 'visualizations')
            os.makedirs(viz_dir, exist_ok=True)

            # Randomly select some test samples from each class
            sample_indices = []
            for class_label in np.unique(y_test):
                class_indices = np.where(y_test == class_label)[0]
                selected = np.random.choice(class_indices,
                                           size=min(3, len(class_indices)),
                                           replace=False)
                sample_indices.extend(selected)

            # Generate attention maps
            for i, idx in enumerate(sample_indices):
                try:
                    viz_path = os.path.join(viz_dir, f'attention_sample_{i+1}.png')
                    self.visualize_model_attention(
                        X_test[idx],
                        true_label=y_test[idx],
                        save_path=viz_path
                    )
                except Exception as e:
                    print(f"Error generating visualization for sample {i+1}: {e}")

            print(f"Visualizations generated in {time.time() - viz_start_time:.2f} seconds")

            # 6. Save and summarize results
            print("\n6. Saving results and generating summary...")

            # Create learning curve plot
            plt.figure(figsize=(12, 5))

            # Plot training & validation accuracy
            plt.subplot(1, 2, 1)
            plt.plot(history['accuracy'], label='Train')
            plt.plot(history['val_accuracy'], label='Validation')
            plt.title('Model Accuracy')
            plt.ylabel('Accuracy')
            plt.xlabel('Epoch')
            plt.legend(loc='lower right')
            plt.grid(True, alpha=0.3)

            # Plot training & validation loss
            plt.subplot(1, 2, 2)
            plt.plot(history['loss'], label='Train')
            plt.plot(history['val_loss'], label='Validation')
            plt.title('Model Loss')
            plt.ylabel('Loss')
            plt.xlabel('Epoch')
            plt.legend(loc='upper right')
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(os.path.join(self.output_path, 'learning_curves.png'), dpi=300)

            # Save training history
            with open(os.path.join(self.output_path, 'training_history.json'), 'w') as f:
                # Convert numpy values to Python native types for JSON serialization
                json_history = {}
                for k, v in history.items():
                    json_history[k] = [float(val) for val in v]
                json.dump(json_history, f, indent=2)

            # Save the results
            results = {
                'history': history,
                'metrics': metrics,
                'model_path': os.path.join(self.output_path, 'final_model'),
                'training_time': time.time() - train_start_time,
                'total_time': time.time() - pipeline_start_time
            }

            # Create a summary report
            with open(os.path.join(self.output_path, 'results_summary.txt'), 'w') as f:
                f.write(f"Synapse Classification Results Summary\n")
                f.write(f"{'='*50}\n\n")
                f.write(f"Dataset Information:\n")
                f.write(f"  Total samples: {len(X)}\n")
                f.write(f"  Training samples: {len(X_train)}\n")
                f.write(f"  Test samples: {len(X_test)}\n")
                f.write(f"  Class distribution: {pd.Series(y).value_counts().to_dict()}\n\n")

                f.write(f"Model Performance:\n")
                f.write(f"  Test Accuracy: {metrics['accuracy']:.4f}\n\n")

                f.write(f"Classification Report:\n")
                class_report = classification_report(y_test, np.argmax(metrics['pred_probabilities'], axis=1))
                f.write(f"{class_report}\n\n")

                f.write(f"Training Information:\n")
                f.write(f"  Training time: {results['training_time']:.2f} seconds\n")
                f.write(f"  Total pipeline time: {results['total_time']:.2f} seconds\n")
                f.write(f"  Final validation accuracy: {history['val_accuracy'][-1]:.4f}\n")
                f.write(f"  Final validation loss: {history['val_loss'][-1]:.4f}\n\n")

                f.write(f"Model files saved to: {self.output_path}\n")

            print(f"\nPipeline completed successfully in {time.time() - pipeline_start_time:.2f} seconds!")
            print(f"Results saved to {self.output_path}")

            return results

        except Exception as e:
            print(f"Error in pipeline execution: {e}")
            import traceback
            traceback.print_exc()

            return {
                'error': str(e),
                'traceback': traceback.format_exc()
            }

In [None]:
def _plot_training_history(self, history):
        """
        Plot the training history with enhanced visualizations

        Args:
            history (dict): Training history
        """
        # Create a multi-faceted visualization of training history
        plt.figure(figsize=(15, 10))

        # 1. Accuracy plot
        plt.subplot(2, 2, 1)
        plt.plot(history['accuracy'], linewidth=2)
        plt.plot(history['val_accuracy'], linewidth=2)
        plt.title('Model Accuracy', fontsize=14)
        plt.ylabel('Accuracy', fontsize=12)
        plt.xlabel('Epoch', fontsize=12)
        plt.legend(['Train', 'Validation'], loc='lower right')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.ylim([0, 1.0])

        # Highlight best validation accuracy
        best_val_acc_epoch = np.argmax(history['val_accuracy'])
        best_val_acc = history['val_accuracy'][best_val_acc_epoch]
        plt.axhline(y=best_val_acc, color='r', linestyle='--', alpha=0.3)
        plt.axvline(x=best_val_acc_epoch, color='r', linestyle='--', alpha=0.3)
        plt.scatter(best_val_acc_epoch, best_val_acc, s=100, c='red', alpha=0.5, zorder=5)
        plt.annotate(f'Best: {best_val_acc:.4f}',
                    (best_val_acc_epoch, best_val_acc),
                    xytext=(best_val_acc_epoch+1, best_val_acc),
                    fontsize=10)

        # 2. Loss plot
        plt.subplot(2, 2, 2)
        plt.plot(history['loss'], linewidth=2)
        plt.plot(history['val_loss'], linewidth=2)
        plt.title('Model Loss', fontsize=14)
        plt.ylabel('Loss', fontsize=12)
        plt.xlabel('Epoch', fontsize=12)
        plt.legend(['Train', 'Validation'], loc='upper right')
        plt.grid(True, linestyle='--', alpha=0.6)

        # Highlight best validation loss
        best_val_loss_epoch = np.argmin(history['val_loss'])
        best_val_loss = history['val_loss'][best_val_loss_epoch]
        plt.axhline(y=best_val_loss, color='r', linestyle='--', alpha=0.3)
        plt.axvline(x=best_val_loss_epoch, color='r', linestyle='--', alpha=0.3)
        plt.scatter(best_val_loss_epoch, best_val_loss, s=100, c='red', alpha=0.5, zorder=5)
        plt.annotate(f'Best: {best_val_loss:.4f}',
                    (best_val_loss_epoch, best_val_loss),
                    xytext=(best_val_loss_epoch+1, best_val_loss),
                    fontsize=10)

        # 3. Accuracy vs Loss scatter plot
        plt.subplot(2, 2, 3)
        plt.scatter(history['loss'], history['accuracy'], alpha=0.5, s=70, c='blue', label='Train')
        plt.scatter(history['val_loss'], history['val_accuracy'], alpha=0.5, s=70, c='red', label='Validation')

        # Add arrows to show training progression
        for i in range(len(history['loss'])-1):
            plt.arrow(history['loss'][i], history['accuracy'][i],
                    history['loss'][i+1] - history['loss'][i],
                    history['accuracy'][i+1] - history['accuracy'][i],
                    head_width=0.01, head_length=0.01, fc='blue', ec='blue', alpha=0.3)

            plt.arrow(history['val_loss'][i], history['val_accuracy'][i],
                    history['val_loss'][i+1] - history['val_loss'][i],
                    history['val_accuracy'][i+1] - history['val_accuracy'][i],
                    head_width=0.01, head_length=0.01, fc='red', ec='red', alpha=0.3)

        plt.title('Accuracy vs Loss', fontsize=14)
        plt.xlabel('Loss', fontsize=12)
        plt.ylabel('Accuracy', fontsize=12)
        plt.legend(loc='lower left')
        plt.grid(True, linestyle='--', alpha=0.6)

        # 4. Learning rate (if available) or train vs val metrics
        plt.subplot(2, 2, 4)
        if 'lr' in history:
            plt.semilogy(history['lr'], linewidth=2)
            plt.title('Learning Rate', fontsize=14)
            plt.ylabel('Learning Rate', fontsize=12)
            plt.xlabel('Epoch', fontsize=12)
            plt.grid(True, linestyle='--', alpha=0.6)
        else:
            # Plot the difference between train and validation
            train_val_acc_diff = np.array(history['accuracy']) - np.array(history['val_accuracy'])
            train_val_loss_diff = np.array(history['loss']) - np.array(history['val_loss'])

            plt.plot(train_val_acc_diff, label='Acc Diff (Train-Val)', color='green')
            plt.plot(train_val_loss_diff, label='Loss Diff (Train-Val)', color='purple')
            plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
            plt.title('Model Overfitting Indicators', fontsize=14)
            plt.ylabel('Difference (Train-Val)', fontsize=12)
            plt.xlabel('Epoch', fontsize=12)
            plt.legend(loc='upper right')
            plt.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plt.savefig(os.path.join(self.output_path, 'training_history.png'), dpi=300)
        plt.close()


def main(argv=None):
    """
    Run the synapse classification pipeline from the command line.

    Parses command-line options and sets up logging to both
    ``<output_path>/pipeline.log`` and the console. It then builds a
    ``SynapseClassifier`` and takes one of two paths:

    * ``--test_only``: loads a saved model from ``--output_path`` (the
      ``final_model`` path first, ``final_model.h5`` as a fallback). It
      reloads the data and re-creates a stratified 80/20 split with
      ``random_state=42``. It evaluates on the 20% test portion and pickles
      the metrics to ``test_results.pkl``.
    * otherwise: runs ``classifier.run_full_pipeline`` and reports its
      test accuracy.

    Args:
        argv (list[str] | None): Arguments to parse instead of
            ``sys.argv[1:]``. This allows the pipeline to be driven from a
            notebook or another script, where ``sys.argv`` holds unrelated
            flags (e.g. the Jupyter kernel's ``-f``).

    Raises:
        SystemExit: code 1 if any stage fails; argparse also exits on
            invalid arguments (code 2) or ``--help`` (code 0).
    """
    import time
    import argparse
    import traceback

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Synapse Classification from EM Images')
    parser.add_argument('--data_path', type=str, required=True, help='Path to the Microns dataset')
    parser.add_argument('--output_path', type=str, default='./synapse_classification_results',
                        help='Path to save results')
    parser.add_argument('--sample_limit', type=int, default=None,
                        help='Limit the number of samples (for testing)')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--image_size', type=int, default=224, help='Image size for processing')
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs')
    parser.add_argument('--cross_validation', action='store_true',
                        help='Use cross-validation for evaluation')
    parser.add_argument('--no_pretrained', action='store_true',
                        help='Do not use pretrained model')
    parser.add_argument('--test_only', action='store_true',
                        help='Only run evaluation on existing model')

    args = parser.parse_args(argv)

    try:
        # Create output directory if it doesn't exist
        os.makedirs(args.output_path, exist_ok=True)

        # Log to a file inside the output directory and to the console
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(os.path.join(args.output_path, 'pipeline.log')),
                logging.StreamHandler()
            ]
        )

        logger = logging.getLogger(__name__)
        logger.info(f"Starting synapse classification pipeline with args: {args}")

        # --epochs and --no_pretrained are accepted on the command line but are
        # not passed to SynapseClassifier or run_full_pipeline below, so they
        # have no effect. Say so instead of ignoring them silently.
        if args.epochs != parser.get_default('epochs'):
            logger.warning(f"--epochs is not forwarded to the pipeline; value {args.epochs} is ignored")
        if args.no_pretrained:
            logger.warning("--no_pretrained is not forwarded to the pipeline; flag is ignored")

        # Initialize the classifier
        start_time = time.time()
        logger.info("Initializing SynapseClassifier...")

        classifier = SynapseClassifier(
            args.data_path,
            args.output_path,
            image_size=(args.image_size, args.image_size),
            batch_size=args.batch_size
        )

        if args.test_only:
            logger.info("Test only mode activated. Loading existing model...")
            try:
                # Prefer the TensorFlow-format model path; fall back to H5.
                model_path = os.path.join(args.output_path, 'final_model')
                h5_path = os.path.join(args.output_path, 'final_model.h5')
                if os.path.exists(model_path):
                    classifier.model = tf.keras.models.load_model(model_path)
                elif os.path.exists(h5_path):
                    classifier.model = tf.keras.models.load_model(h5_path)
                else:
                    raise FileNotFoundError("No model file found in the output directory")

                logger.info("Model loaded successfully. Preparing test data...")

                # Re-create a stratified 80/20 split; only the test part is used.
                X, y = classifier.load_microns_data(sample_limit=args.sample_limit)
                _, X_test, _, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=42, stratify=y
                )

                logger.info("Evaluating the model...")
                metrics = classifier.evaluate_model(X_test, y_test)

                # Persist the evaluation metrics for later inspection
                with open(os.path.join(args.output_path, 'test_results.pkl'), 'wb') as f:
                    pickle.dump(metrics, f)

                logger.info(f"Test completed in {time.time() - start_time:.2f} seconds")
                logger.info(f"Test accuracy: {metrics['accuracy']:.4f}")

            except Exception as e:
                # logger.exception records the message together with the traceback.
                logger.exception(f"Error in test mode: {e}")
                sys.exit(1)

        else:
            logger.info("Running full pipeline...")
            results = classifier.run_full_pipeline(
                sample_limit=args.sample_limit,
                cross_validation=args.cross_validation
            )

            logger.info(f"Pipeline completed in {time.time() - start_time:.2f} seconds")

            # run_full_pipeline reports failures via an 'error' key, not by raising.
            if 'error' in results:
                logger.error(f"Pipeline failed with error: {results['error']}")
                sys.exit(1)

            logger.info("Pipeline completed successfully!")
            logger.info(f"Test accuracy: {results['metrics']['accuracy']:.4f}")

        logger.info(f"Results saved to {args.output_path}")

    except Exception as e:
        # Logging may not be configured yet (e.g. makedirs failed), so use print.
        # SystemExit from the sys.exit calls above is not an Exception and passes through.
        print(f"Error: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point. These stdlib imports repeat the ones at the top of
    # the file; they are harmless duplicates kept as a safeguard for main().
    import json
    import logging
    import sys
    import time

    main()