In [None]:
import sys
import os 
from PIL import Image
import numpy as np
import pandas as pd
from PIL import Image
import cv2
sys.path.append('..')  #adds the Root Directory to the system path
from BL_CalciumAnalysis.image_analysis_methods import ImageAnalysis


In [None]:
#test 

In [None]:
# Show which Python interpreter is running this notebook; useful for confirming that
# the interpreter of the intended (e.g. Conda) environment is the one in use.
print(sys.executable)

In [None]:

from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
from skimage.color import rgb2gray
from skimage import io
import glob
import ast 


class ImageAnalysis:
    def __init__(self, project_folder):
        self.project_folder = project_folder
        self.directory_df = self.initialize_directory_df() 
        
    def initialize_directory_df(self):
        directories = [d for d in os.listdir(self.project_folder) if os.path.isdir(os.path.join(self.project_folder, d))]
        directory_data = [{'directory_name': d, 'directory_path': os.path.join(self.project_folder, d)} for d in directories]
        return pd.DataFrame(directory_data, columns=['directory_name', 'directory_path'])
    
    def list_directories(self):
        return [d for d in os.listdir(self.project_folder) if os.path.isdir(os.path.join(self.project_folder, d))]
    
    def list_files(self, folder_name):
        folder_path = os.path.join(self.project_folder, folder_name)
        all_files = []
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                all_files.append(os.path.join(root, file))
        return all_files
    
    def generate_dark_image(self, tiff_path, num_frames=200):
        """
        Generates a median 'dark' image from the first frames of a multi-frame TIFF file.

        This method is used for compensating the dark pixel offset in bioluminescence imaging data.

        Parameters:
        tiff_path (str): Path to the multi-frame TIFF file.
        num_frames (int, optional): Number of leading frames to combine. Defaults to 200.
            If the file contains fewer frames, all of its frames are used.

        Returns:
        numpy.ndarray: Pixel-wise median of the selected frames, shape (height, width).

        Raises:
        ValueError: If num_frames is less than 1.
        """
        if num_frames < 1:
            raise ValueError(f"num_frames must be at least 1, got {num_frames}")

        with Image.open(tiff_path) as img:
            # Single-page images may not define n_frames; treat them as one frame.
            frame_count = min(num_frames, getattr(img, 'n_frames', 1))
            frames = []
            for i in range(frame_count):
                # Move to frame i before reading it; without seek() every iteration
                # re-reads the first frame and the "median" is just frame 0.
                img.seek(i)
                frames.append(np.array(img.getdata(), dtype=np.float32).reshape(img.size[::-1]))
            return np.median(np.stack(frames), axis=0)

    def subtract_dark_image(self, raw_tiff_path, dark_image):
        """
        Subtracts a 'dark' image from each frame of a multi-frame TIFF file.

        This method is used to compensate for the average dark pixel offset in bioluminescence imaging data.

        Parameters:
        raw_tiff_path (str): Path to the raw multi-frame TIFF file.
        dark_image (numpy.ndarray or scalar): The 'dark' image (shape (height, width), e.g. from
            generate_dark_image) or a scalar offset. It is converted to float32.

        Returns:
        list of numpy.ndarray: One float32 frame per TIFF page with the dark image subtracted.
            No clipping is applied, so pixels darker than the reference become negative.

        Raises:
        ValueError: If dark_image is not a scalar and its shape differs from a frame's shape.
        """
        # Plain float32 NumPy subtraction gives the same values as cv2.subtract on float32
        # data (OpenCV only saturates integer types), but also accepts a float64 dark image
        # instead of failing on the dtype mismatch.
        dark = np.asarray(dark_image, dtype=np.float32)
        with Image.open(raw_tiff_path) as img:
            compensated_images = []
            for i in range(getattr(img, 'n_frames', 1)):
                img.seek(i)
                frame_shape = img.size[::-1]
                # Reject mismatched shapes explicitly instead of relying on broadcasting.
                if dark.ndim != 0 and dark.shape != tuple(frame_shape):
                    raise ValueError(
                        f"dark_image shape {dark.shape} does not match frame shape {tuple(frame_shape)}"
                    )
                frame = np.array(img.getdata(), dtype=np.float32).reshape(frame_shape)
                compensated_images.append(frame - dark)
            return compensated_images
        
    def expand_directory_df(self):
        # Add new columns with default empty lists
        self.directory_df['sensor_type'] = ''
        self.directory_df['session_id'] = ''
        self.directory_df['stimulation_ids'] = [[] for _ in range(len(self.directory_df))]
        self.directory_df['stimulation_frame_number'] = [[] for _ in range(len(self.directory_df))]

        for index, row in self.directory_df.iterrows():
            folder_name = row['directory_name']
            folder_path = row['directory_path']
            
            # Parse folder name for sensor type and session id
            parts = folder_name.split('_')
            sensor_type = 'gcamp8' if parts[0].startswith('g') else 'cablam'
            session_id = parts[0][1:] + parts[1]  # Assuming the first part is always the experiment ID

            # Update DataFrame with sensor_type and session_id
            self.directory_df.at[index, 'sensor_type'] = sensor_type
            self.directory_df.at[index, 'session_id'] = session_id

            # Check for CSV file ending in 'biolumi' or 'fluor'
            csv_filename = [f for f in os.listdir(folder_path) if (f.endswith('biolumi.csv') or f.endswith('fluor.csv'))]
            if csv_filename:
                csv_file_path = os.path.join(folder_path, csv_filename[0])
                df_csv = pd.read_csv(csv_file_path, header=None)
                stimulation_ids = df_csv.iloc[1].dropna().tolist()
                stimulation_frame_number = df_csv.iloc[0].dropna().tolist()

                # Update DataFrame with stimulation information
                self.directory_df.at[index, 'stimulation_ids'] = stimulation_ids
                self.directory_df.at[index, 'stimulation_frame_number'] = stimulation_frame_number

        return self.directory_df
    
    def get_session_raw_data(self, session_id):
        # Check if the session_id is in the 'session_id' column of the directory_df
        if session_id in self.directory_df['session_id'].tolist():
            # Find the directory path for the given session_id
            directory_path = self.directory_df[self.directory_df['session_id'] == session_id]['directory_path'].values[0]
            
            # Search for the .tif file within that directory
            for file_name in os.listdir(directory_path):
                if file_name.endswith('.tif'):
                    return os.path.join(directory_path, file_name)

            # If no .tif file is found in the directory
            return f"No .tif file found in the directory for session {session_id}."
        else:
            # If the session_id is not present in the DataFrame
            return f"Session ID {session_id} is not present in the directory DataFrame."
        
    def max_projection_mean_values(self, tif_path):
        """
        Computes the mean (average-intensity) projection of a multi-frame TIF file and saves it.

        Despite the method name, each output pixel is the arithmetic mean of that pixel
        over all frames, not the maximum. The float32 result is written to the
        subdirectory 'processed_data/processed_image_analysis_output' next to the input,
        named '<input name without extension>_max_projection.tif' (the suffix is kept
        for compatibility with existing output files).

        Parameters:
        tif_path (str): Path to the multi-frame TIF file.

        Returns:
        str: Path to the saved projection image.
        """
        with Image.open(tif_path) as img:
            frame_count = getattr(img, 'n_frames', 1)
            # Accumulate in float32 rather than the pixel dtype to avoid integer overflow.
            sum_image = np.zeros((img.height, img.width), dtype=np.float32)
            for i in range(frame_count):
                img.seek(i)
                sum_image += np.array(img, dtype=np.float32)
            mean_image = sum_image / frame_count

        processed_dir = os.path.join(os.path.dirname(tif_path), 'processed_data', 'processed_image_analysis_output')
        os.makedirs(processed_dir, exist_ok=True)

        # splitext strips only the real extension; str.replace('.tif', ...) also rewrote
        # '.tif' occurring elsewhere in the name and left '.TIF' files without the suffix.
        stem = os.path.splitext(os.path.basename(tif_path))[0]
        max_proj_image_path = os.path.join(processed_dir, stem + '_max_projection.tif')

        Image.fromarray(mean_image).save(max_proj_image_path)
        return max_proj_image_path
    
    def analyze_all_sessions(self, function_to_apply):
        """
        Iterates over all session IDs in the directory DataFrame and applies the given function to each.

        Parameters:
        function_to_apply (callable): Function to be applied to each session. It should accept a session ID.

        Returns:
        dict: A dictionary with session_ids as keys and function return values as values.
        """
        results = {}
        for session_id in self.directory_df['session_id']:
            try:
                result = function_to_apply(session_id)
                results[session_id] = result
            except Exception as e:
                print(f"An error occurred while processing session {session_id}: {e}")
        return results
    
    def add_tiff_dimensions(self):
        """
        Analyzes the dimensions of TIF files in the directory DataFrame and adds this data as new columns.
        """
        # Ensure the DataFrame has the columns for dimensions; initialize them with None or appropriate defaults
        if 'x_dim' not in self.directory_df.columns:
            self.directory_df['x_dim'] = None
            self.directory_df['y_dim'] = None
            self.directory_df['z_dim_frames'] = None

        # Iterate over each session_id and update the dimensions
        for index, row in self.directory_df.iterrows():
            tif_path = self.get_session_raw_data(row['session_id'])
            if isinstance(tif_path, str) and tif_path.endswith('.tif'):
                try:
                    with Image.open(tif_path) as img:
                        self.directory_df.at[index, 'x_dim'] = img.width
                        self.directory_df.at[index, 'y_dim'] = img.height
                        # For z-dimension, count the frames
                        img.seek(0)  # Ensure the pointer is at the beginning
                        frames = 0
                        while True:
                            try:
                                img.seek(img.tell() + 1)
                                frames += 1
                            except EOFError:
                                break
                        self.directory_df.at[index, 'z_dim_frames'] = frames
                except Exception as e:
                    print(f"Could not process TIF dimensions for session {row['session_id']}: {e}")
    
    def analyze_roi(self, session_id):
        """
        Analyzes ROI of the 'labels_postexport.tif' file for a given session and saves two results:
        one with labels and another without labels.I t also saves the labeled image data as numpy array for future use.
        """
        
        # SETP 1: DEFINE PATHS
        # define the paths, including the directory where processed images will be saved (processed_dir) 
        # and the name of the TIF file that contains the ROI labels (consistent_file_name)
        processed_dir = 'processed_data/processed_image_analysis_output'
        consistent_file_name = 'labels_postexport.tif'
        output_suffix_with_labels = '_roi_analysis_with_labels.png'
        output_suffix_without_labels = '_roi_analysis_without_labels.png'

        # STEP 2: RETRIEVE SESSION DATA 
        # Retrieve the directory path from the DataFrame
        # looks up the session's directory path from a DataFrame (directory_df) using the provided session_id. 
        # If the session ID isn't found, it returns a message indicating no directory entry was found for that session.
        
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            return f"No directory entry found for session {session_id}"

        # STEP 3: VERIFY AND LOAD THE ROI TIF FILE 
        # constructs the full path to the labels_postexport.tif file and checks if it exists. If it does, the file is opened and loaded. 
        # If the file is in RGB format, it's converted to grayscale using rgb2gray from skimage.color. 
        # This conversion is crucial for analyzing the image as a binary mask where non-white pixels are considered ROIs.
        directory_path = directory_entry['directory_path'].values[0]
        
        # Build the path to the postexport TIFF file
        tiff_file_path = os.path.join(directory_path, processed_dir, consistent_file_name)

        # Verify that the file exists
        if not os.path.exists(tiff_file_path):
            return f"File not found for session {session_id}"

        
        
        # STEP 4: CREATE AND SAVE THE BINARY MASK
        # k: The method then converts the grayscale image to a binary mask, identifying all non-white pixels as ROIs 
        # (pixels with value less than 1 after normalization are set to 1, and others to 0). 
        # This binary mask is labeled using label from skimage.measure, assigning a unique label to each connected component (ROI).
        
        # Load the image
        mask_image = Image.open(tiff_file_path)

        # Convert RGB image to grayscale if necessary
        if mask_image.mode == 'RGB':
            # Convert to grayscale using skimage's rgb2gray
            image_array = rgb2gray(np.array(mask_image))

        # Assuming that all non-white pixels are ROIs
        binary_mask = np.where(image_array < 1, 1, 0)  # Here, 1 corresponds to white in the normalized grayscale image

        # Label the regions
        labeled_image = label(binary_mask, connectivity=1)
        num_rois = np.max(labeled_image)
        
        # Save the labeled image data as a NumPy array file for future processing
        labeled_image_path = os.path.join(directory_path, processed_dir, f"{session_id}_labeled_image.npy")
        np.save(labeled_image_path, labeled_image)
        
        
        # STEP 5: SAVE THE UNLABELED ROI IMAGE 
        # Save Unlabeled ROI Image: The method saves a version of the labeled image without any annotations to a specified path (output_path_without_labels). 
        # This image is saved in the processed_image_analysis_output directory with a specific suffix to indicate it's the unlabeled version.
        
        # Save the image without labels
        output_path_without_labels = os.path.join(directory_path, processed_dir, session_id + output_suffix_without_labels)
        plt.imsave(output_path_without_labels, labeled_image, cmap='nipy_spectral')

        
        # STEP 6: ANALYZE AND SAVE LABELED ROI IMAGE 
        # Iterates through each detected region using regionprops, extracts the centroid, 
        # and annotates the image with the region's label. 
        # This annotated image is saved separately, indicating it includes ROI labels.
        
        # Analyze regions and save properties
        regions = regionprops(labeled_image)

        # Prepare to save the ROI analysis image with labels
        output_path_with_labels = os.path.join(directory_path, processed_dir, session_id + output_suffix_with_labels)
        
        fig, ax = plt.subplots()
        ax.imshow(labeled_image, cmap='nipy_spectral')
        ax.axis('off')

        # Annotate each ROI with its corresponding label (ID)
        for region in regions:
            # Get the coordinates of the centroid of the region
            y, x = region.centroid
            # Annotate the ROI ID at the centroid position
            ax.text(x, y, str(region.label), color='white', ha='center', va='center')

        plt.savefig(output_path_with_labels)
        plt.close()

        # Return the paths of the saved figures LABELED AND UNLABELED and number of ROIs
        return (output_path_with_labels, output_path_without_labels), num_rois
    
    def analyze_all_rois(self):
        """
        Applies ROI analysis to all sessions and saves the results.
        """
        results = {}
        for session_id in self.directory_df['session_id']:
            result = self.analyze_roi(session_id)
            results[session_id] = result
        return results
    
    def extract_calcium_signals(self, session_id):
        """
        Extracts calcium signals from time-series data using the saved labeled ROI mask
        and saves the results as a CSV file in the 'processed_image_analysis_output' directory.

        Parameters:
        session_id (str): Session ID for which to perform the analysis.

        Returns:
        str: Path to the saved CSV file containing calcium signal data.
        """
        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'

        # Retrieve the directory path from the DataFrame
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            return f"No directory entry found for session {session_id}"

        directory_path = directory_entry['directory_path'].values[0]

        # Path to the saved labeled image numpy file
        labeled_image_path = os.path.join(directory_path, processed_dir, session_id + '_labeled_image.npy')

        # Verify and load the labeled image numpy file
        if not os.path.exists(labeled_image_path):
            return f"Labeled image file not found for session {session_id}"
        labeled_image = np.load(labeled_image_path)

        # Locate and load the time-series TIFF file
        tif_files = glob.glob(os.path.join(directory_path, '*.tif'))
        tif_files = [f for f in tif_files if 'postexport' not in f and 'labels' not in f]  # Ensure it's the correct TIFF
        if not tif_files:
            return f"No time-series .tif file found in the directory for session {session_id}"
        time_series_path = tif_files[0]  # Assuming there's only one relevant TIFF file
        time_series = io.imread(time_series_path)

        # Initialize an array to store calcium signal data
        num_rois = np.max(labeled_image)
        num_frames = time_series.shape[0]
        calcium_signals = np.zeros((num_rois, num_frames))

        # Extract the signal from each ROI in each frame
        for t in range(num_frames):
            frame = time_series[t]
            for roi in range(1, num_rois + 1):
                roi_mask = labeled_image == roi
                roi_data = frame[roi_mask]
                calcium_signals[roi - 1, t] = np.mean(roi_data)

        # Create and save the DataFrame with calcium signals
        calcium_df = pd.DataFrame(calcium_signals.T, columns=[f"ROI_{i}" for i in range(1, num_rois + 1)])
        calcium_df['Frame'] = np.arange(1, num_frames + 1)
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)
        calcium_df.to_csv(csv_path, index=False)

        return csv_path
    
    def analyze_all_calcium_signals(self):
        """
        Applies calcium signal extraction to all session_ids in the directory DataFrame and stores the results.
        """
        results = {}
        for session_id in self.directory_df['session_id']:
            # Ensure the ROI analysis has been done to get the labeled images
            roi_results = self.analyze_roi(session_id)
            # Check if analyze_roi returned a path to labeled images
            if isinstance(roi_results, tuple):
                # Extract calcium signals using the labeled ROI mask
                calcium_csv_path = self.extract_calcium_signals(session_id)
                results[session_id] = calcium_csv_path
            else:
                # If roi_results is an error message, pass it through
                results[session_id] = roi_results
                
            #results is a dictionary where each key is a session_id and the corresponding value is the path to the saved CSV file containing calcium signal data.
        return results
            
    def plot_session_calcium_signals(self, session_id, use_corrected_data=False):
        """
        Attempt to plot calcium signals for a given session and return a status message.

        Parameters
        ----------
        session_id : str
            The identifier for the session for which calcium signals are to be plotted.
        use_corrected_data : bool, optional
            Flag indicating whether to use corrected calcium signals or not (default is False).

        Returns
        -------
        str
            A message indicating whether the plotting was successful or failed. If it failed,
            the message includes the reason for the failure.
        """
        try:
            self.plot_calcium_signals(session_id, use_corrected_data=use_corrected_data)
            data_type = 'corrected' if use_corrected_data else 'uncorrected'
            return f"Plotted {data_type} calcium signals for session {session_id}"
        except Exception as e:
            return f"Failed to plot calcium signals for session {session_id}: {e}"
    
    def plot_all_sessions_calcium_signals(self, use_corrected_data=False):
        """
        Apply the plot_session_calcium_signals method to all sessions in the dataset.

        Iterates over all session IDs and plots calcium signals for each session using
        the plot_session_calcium_signals method. Collects and returns the outcomes of
        the plotting process for each session.

        Parameters
        ----------
        use_corrected_data : bool, optional
            Flag indicating whether to use corrected calcium signals or not (default is False).

        Returns
        -------
        dict
            A dictionary where each key is a session ID and the corresponding value is
            a message indicating the success or failure of the plotting operation for
            that session.
        """
        results = {}
        for session_id in self.directory_df['session_id'].tolist():
            result = self.plot_session_calcium_signals(session_id, use_corrected_data=use_corrected_data)
            results[session_id] = result
        return results
        
    def plot_calcium_signals(self, session_id, use_corrected_data=False):
        processed_dir = 'processed_data/processed_image_analysis_output'
        
        # Choose the file suffix based on whether corrected data should be used
        calcium_csv_suffix = '_corrected_calcium_signals.csv' if use_corrected_data else '_calcium_signals.csv'

        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return

        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, str(session_id) + calcium_csv_suffix)

        if not os.path.exists(csv_path):
            print(f"Calcium signals file {'corrected ' if use_corrected_data else ''}not found for session {session_id}")
            return

        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        calcium_signals = calcium_signals_df.drop('Frame', axis=1)

        # Normalize the signals and calculate offsets
        normalized_signals = (calcium_signals - calcium_signals.min()) / (calcium_signals.max() - calcium_signals.min())
        offsets = np.arange(len(normalized_signals.columns)) * 1.2

        plt.figure(figsize=(15, 10))
        
        # Assuming 'stimulation_frame_number' contains a list or similar; if it's a single number, adjust accordingly
        stim_frame_numbers = self.directory_df.loc[
            self.directory_df['session_id'] == session_id, 'stimulation_frame_number'
        ].values[0]

        # Plot red dotted lines for stimulation timestamps
        for frame_number in stim_frame_numbers:
            plt.axvline(x=frame_number, color='r', linestyle='--', linewidth=0.5)

        # Plot each normalized calcium signal with an offset
        for i, (roi_label, signal) in enumerate(normalized_signals.items()):
            plt.plot(frame_numbers, signal + offsets[i], label=roi_label)

        plt.xlabel('Frame Number')
        plt.ylabel('Normalized Calcium Signal (A.U.)')
        plt.title(f"Time Series of ROIs for Session {session_id} ({'Corrected' if use_corrected_data else 'Uncorrected'})")

        plt.yticks(ticks=offsets + 0.5, labels=normalized_signals.columns)
        plt.grid(False)
        plt.tight_layout()

        save_suffix = 'corrected_' if use_corrected_data else ''
        save_dir = os.path.join(directory_path, 'processed_data', 'processed_image_analysis_output')
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{session_id}_{save_suffix}calcium_signals_plot.png")

        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
        plt.show()

    def save_individual_roi_plots(self, session_id, use_corrected_data=False):
        processed_dir = 'processed_data/processed_image_analysis_output'
        cell_roi_dir = 'cell_roi_processed_data'
        calcium_csv_suffix = '_corrected_calcium_signals.csv' if use_corrected_data else '_calcium_signals.csv'

        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return

        directory_path = directory_entry['directory_path'].values[0]
        stimulation_times = directory_entry['stimulation_frame_number'].values[0]

        processed_data_path = os.path.join(directory_path, processed_dir)
        csv_path = os.path.join(processed_data_path, str(session_id) + calcium_csv_suffix)
        roi_output_dir = os.path.join(processed_data_path, cell_roi_dir)

        os.makedirs(roi_output_dir, exist_ok=True)

        if not os.path.exists(csv_path):
            print(f"{'Corrected' if use_corrected_data else 'Uncorrected'} calcium signals file not found for session {session_id}")
            return

        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        calcium_signals = calcium_signals_df.drop('Frame', axis=1)

        for roi_label in calcium_signals:
            plt.figure(figsize=(10, 5))
            plt.plot(frame_numbers, calcium_signals[roi_label], label=roi_label)

            for stim_time in stimulation_times:
                plt.axvline(x=stim_time, color='red', linestyle='--', linewidth=1)

            plt.title(f"{roi_label} - Session {session_id} ({'Corrected' if use_corrected_data else 'Uncorrected'})")
            plt.xlabel('Frame Number')
            plt.ylabel('Calcium Signal Intensity')
            plt.legend()
            plt.grid(True)

            save_path = os.path.join(roi_output_dir, f"{roi_label}_{'corrected_' if use_corrected_data else ''}signal_plot.png")
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close()

            print(f"Plot for {roi_label} saved to {save_path}")
              
    def save_individual_roi_plots_all_sessions(self, use_corrected_data=False):
        results = {}
        for session_id in self.directory_df['session_id'].tolist():
            result = self.save_individual_roi_plots(session_id, use_corrected_data=use_corrected_data)
            results[session_id] = result
        return results
        
    def plot_roi_with_zoomed_stimulations(self, session_id):
        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'

        # Retrieve directory path and stimulation frames from DataFrame
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return

        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)
        stimulation_frames = directory_entry['stimulation_frame_number'].iloc[0]  # Used directly as a list

        # Read the calcium signals
        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        calcium_signals = calcium_signals_df.drop('Frame', axis=1)
        
        # Determine the spacing factor based on the max calcium signal value
        spacing_factor = calcium_signals.max().max() * 0.1  # For example, 10% of the max signal
        
        # Number of subplots based on the number of stimulations
        num_stimulations = len(stimulation_frames)
        num_rois = calcium_signals.shape[1]

        # Create the figure with multiple subplots
        fig, axs = plt.subplots(num_stimulations + 1, 1, figsize=(15, 5 * (num_stimulations + 1)), gridspec_kw={'height_ratios': [3] + [1]*num_stimulations})
        
        # Plot the full session signal in the first subplot
        for i, col in enumerate(calcium_signals.columns):
            axs[0].plot(frame_numbers, calcium_signals[col] + (i * spacing_factor), label=col)

        # Add stimulation markers to the full session plot
        for stim_frame in stimulation_frames:
            axs[0].axvline(x=stim_frame, color='red', linestyle='--', linewidth=1)

        # Zoom into each stimulation event in the subsequent subplots
        for i, stim_frame in enumerate(stimulation_frames, start=1):
            start_frame = max(stim_frame - 100, 0)
            end_frame = min(stim_frame + 200, len(frame_numbers))
            for j, col in enumerate(calcium_signals.columns):
                signal_segment = calcium_signals[col][start_frame:end_frame]
                frame_segment = frame_numbers[start_frame:end_frame]
                axs[i].plot(frame_segment, signal_segment + (j * spacing_factor), label=col)

            # Add a vertical line for the stimulation moment
            axs[i].axvline(x=stim_frame, color='red', linestyle='--', linewidth=1)
            axs[i].set_xlim([start_frame, end_frame])

        # Adjust the layout and save the figure
        plt.tight_layout()
        plt.show()
        save_path = os.path.join(directory_path, processed_dir, f"{session_id}_detailed_ROI_analysis.png")
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

        print(f"Detailed ROI analysis figure saved to {save_path}")
        
    def plot_and_save_roi_stimulations(self, session_id):
        """
        Plot every ROI's calcium trace for a session and save one figure per ROI.

        Each figure has a top panel with the full-session trace (stimulation frames
        marked by dashed red lines), followed by one zoomed panel per stimulation.
        Each zoomed panel spans frames ``[stim - 100, stim + 200)``, clipped to the
        recorded frame range.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.

        Returns
        -------
        None
            Figures are written as ``<roi>_stimulation_plot.png`` into
            ``<directory_path>/processed_data/processed_image_analysis_output/cell_roi_processed_data``.
            If the session entry or its calcium CSV is missing, a message is printed
            and nothing is written.
        """
        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'

        # Retrieve directory path from DataFrame
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return

        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)

        if not os.path.exists(csv_path):
            print(f"Calcium signals file not found for session {session_id}")
            return

        # Read calcium signals; every column except 'Frame' is treated as an ROI trace.
        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        calcium_signals = calcium_signals_df.drop('Frame', axis=1)
        last_frame = frame_numbers.max()

        # Stimulation frames are stored per session in directory_df.
        stimulation_frames = directory_entry['stimulation_frame_number'].iloc[0]

        # Ensure the output directory exists
        roi_output_dir = os.path.join(directory_path, processed_dir, 'cell_roi_processed_data')
        os.makedirs(roi_output_dir, exist_ok=True)

        num_plots = len(stimulation_frames) + 1
        for roi_label in calcium_signals:
            roi_signal = calcium_signals[roi_label]

            # squeeze=False keeps `axs` 2-D even for a single panel (no stimulations),
            # so indexing below never hits a bare Axes object.
            fig, axs = plt.subplots(num_plots, 1, figsize=(10, num_plots * 5), squeeze=False)
            axs = axs[:, 0]

            # Full-session trace with every stimulation marked
            axs[0].plot(frame_numbers, roi_signal)
            axs[0].set_title(f"Full Session Calcium Signal for ROI {roi_label}")
            for stim_frame in stimulation_frames:
                axs[0].axvline(x=stim_frame, color='red', linestyle='--', linewidth=0.5)

            # Zoomed-in view around each stimulation
            for ax, stim_frame in zip(axs[1:], stimulation_frames):
                zoom_start = max(stim_frame - 100, 0)
                zoom_end = min(stim_frame + 200, last_frame)
                # Select by frame number, not row position, so the window is correct
                # even when the 'Frame' column does not start at 0.
                in_window = (frame_numbers >= zoom_start) & (frame_numbers < zoom_end)

                ax.plot(frame_numbers[in_window], roi_signal[in_window])
                ax.axvline(x=stim_frame, color='red', linestyle='--', linewidth=0.5)
                ax.set_xlim(zoom_start, zoom_end)
                ax.set_title(f"Stimulation at Frame {stim_frame}")

            # Finalize and save figure
            fig.tight_layout()
            fig.savefig(os.path.join(roi_output_dir, f"{roi_label}_stimulation_plot.png"), dpi=300)
            plt.close(fig)

            print(f"ROI {roi_label} plots saved in {roi_output_dir}")
            
    def plot_and_save_roi_stimulations_all_sessions(self):
        """
        Run ``plot_and_save_roi_stimulations`` for every session row in ``directory_df``.

        Sessions are processed in row order. If a session id appears more than once,
        it is plotted once per row and the last result is kept.

        Returns
        -------
        dict
            Maps each session id to the value returned by
            ``plot_and_save_roi_stimulations`` for that session.
        """
        return {
            sid: self.plot_and_save_roi_stimulations(sid)
            for sid in self.directory_df['session_id']
        }
                          
    def find_responsive_rois_first_stim_mean(self, session_id, pre_stim_duration=5, post_stim_duration=5, threshold=2):
        """
        Identify ROIs whose mean signal rises after the first stimulation.

        For each ROI column, the mean over frames ``[first, first + post_stim_duration)``
        is compared with the mean over ``[first - pre_stim_duration, first)``, where
        ``first`` is the session's first stimulation frame. The difference is divided by
        the sample standard deviation (ddof=1) of the baseline window to form a z-score.
        An ROI counts as responsive when that z-score is strictly greater than
        ``threshold`` (positive responses only).

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.
        pre_stim_duration : int, optional
            Number of baseline frames before the first stimulation.
        post_stim_duration : int, optional
            Number of response frames starting at the first stimulation.
        threshold : float, optional
            Minimum z-score for an ROI to count as responsive.

        Returns
        -------
        list or None
            ROI column names that passed the test, in CSV column order. Returns an
            empty list if the session has no stimulation frames, and None if the
            session entry or its calcium CSV is missing.
        """
        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'

        # Retrieve directory path from DataFrame
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return None

        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)

        if not os.path.exists(csv_path):
            print(f"Calcium signals file not found for session {session_id}")
            return None

        # Read calcium signals; every column except 'Frame' is an ROI trace.
        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        calcium_signals = calcium_signals_df.drop('Frame', axis=1)

        stimulation_frames = directory_entry['stimulation_frame_number'].iloc[0]
        if len(stimulation_frames) == 0:
            # Without a first stimulation there is no window to test.
            print(f"No stimulation frames recorded for session {session_id}")
            return []
        first_stim_frame = stimulation_frames[0]

        # Baseline [first - pre, first) and response [first, first + post) masks are
        # the same for every ROI, so build them once.
        pre_mask = (frame_numbers >= first_stim_frame - pre_stim_duration) & (frame_numbers < first_stim_frame)
        post_mask = (frame_numbers >= first_stim_frame) & (frame_numbers < first_stim_frame + post_stim_duration)

        responsive_rois = []
        for roi_label in calcium_signals.columns:
            pre_stim_signal = calcium_signals.loc[pre_mask, roi_label]
            post_stim_signal = calcium_signals.loc[post_mask, roi_label]

            signal_change = post_stim_signal.mean() - pre_stim_signal.mean()
            pre_std = pre_stim_signal.std(ddof=1)  # sample standard deviation
            # The std is 0 for a flat baseline and NaN for fewer than two baseline
            # frames. `pre_std > 0` is False in both cases, so such ROIs get z = 0.
            z_score = signal_change / pre_std if pre_std > 0 else 0

            # Only increases count; compare abs(z_score) to also accept decreases.
            if z_score > threshold:
                responsive_rois.append(roi_label)
                print(f"ROI {roi_label} is responsive. Change: {signal_change:.2f}, Z-score: {z_score:.2f}")

        return responsive_rois
    
    def plot_responsive_rois_around_stim(self, session_id, pre_stim_duration=100, post_stim_duration=200, threshold=2):
        """
        Overlay the raw traces of responsive ROIs around the session's first stimulation.

        Responsiveness is decided by ``find_responsive_rois_first_stim_mean`` using a
        fixed 3-frame baseline and a 3-frame response window. ``pre_stim_duration`` and
        ``post_stim_duration`` only control how much of each trace is plotted. The figure
        is saved as ``<session_id>_responsive_ROIs_around_stim.png`` in the session's
        processed-output directory.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.
        pre_stim_duration : int, optional
            Frames shown before the first stimulation.
        post_stim_duration : int, optional
            Frames shown after the first stimulation.
        threshold : float, optional
            z-score threshold passed to the responsiveness test.

        Returns
        -------
        None
        """
        responsive_rois = self.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=threshold)

        if not responsive_rois:
            print("No responsive ROIs found.")
            return

        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)

        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        first_stim_frame = directory_entry['stimulation_frame_number'].iloc[0][0]

        # The plotted window (inclusive bounds) is the same for every ROI, so compute it once.
        stim_start = max(first_stim_frame - pre_stim_duration, 0)
        stim_end = min(first_stim_frame + post_stim_duration, frame_numbers.max())
        in_window = (frame_numbers >= stim_start) & (frame_numbers <= stim_end)

        fig, ax = plt.subplots(figsize=(10, 6))
        plotted_rois = [roi for roi in responsive_rois if roi in calcium_signals_df.columns]
        for roi_label in plotted_rois:
            ax.plot(frame_numbers[in_window], calcium_signals_df.loc[in_window, roi_label], label=f'ROI {roi_label}')

        # Draw a single stimulus marker for the whole figure, not one per ROI.
        ax.axvline(x=first_stim_frame, color='red', linestyle='--')

        ax.set_xlabel('Frame Number')
        ax.set_ylabel('Signal Intensity')
        ax.set_title(f'Calcium Signals Around First Stimulation for Responsive ROIs in Session {session_id}')
        if plotted_rois:  # legend() warns when there are no labelled lines
            ax.legend()
        fig.tight_layout()

        save_path = os.path.join(directory_path, processed_dir, f"{session_id}_responsive_ROIs_around_stim.png")
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

        print(f"Plot saved at {save_path}")
        
    def plot_mean_and_sem_of_responsive_rois(self, session_id, pre_stim_duration=100, post_stim_duration=200, threshold=2):
        """
        Plot the mean +/- SEM raw calcium trace of responsive ROIs around the first stimulation.

        ROIs are selected by ``find_responsive_rois_first_stim_mean`` using the same
        ``pre_stim_duration``, ``post_stim_duration`` and ``threshold`` that define the
        plotted window. The figure is saved as
        ``<session_id>_responsive_ROIs_mean_sem.png`` in the session's processed-output
        directory.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.
        pre_stim_duration : int, optional
            Frames before the first stimulation (used for detection and plotting).
        post_stim_duration : int, optional
            Frames after the first stimulation (used for detection and plotting).
        threshold : float, optional
            z-score threshold for the responsiveness test.

        Returns
        -------
        None
        """
        responsive_rois = self.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration, post_stim_duration, threshold)

        if not responsive_rois:
            print("No responsive ROIs found with a positive z-score.")
            return

        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)

        calcium_signals_df = pd.read_csv(csv_path)
        frame_numbers = calcium_signals_df['Frame']
        # The first element of the stored list is taken as the first stimulation.
        first_stim_frame = directory_entry['stimulation_frame_number'].iloc[0][0]

        # Compute mean and SEM across the responsive ROIs only (per frame)
        mean_response = calcium_signals_df[responsive_rois].mean(axis=1)
        sem_response = calcium_signals_df[responsive_rois].sem(axis=1)

        # Row positions bounding the window. searchsorted assumes 'Frame' is sorted
        # ascending and never returns a negative index.
        stim_start = frame_numbers.searchsorted(first_stim_frame - pre_stim_duration)
        stim_end = min(frame_numbers.searchsorted(first_stim_frame + post_stim_duration), len(frame_numbers) - 1)
        window = slice(stim_start, stim_end + 1)

        # Create exactly one figure so nothing is left open after plt.close(fig).
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.errorbar(frame_numbers.iloc[window], mean_response.iloc[window], yerr=sem_response.iloc[window],
                    fmt='-', color='blue', ecolor='lightblue', label='Mean +/- SEM')

        # Stimulation line
        ax.axvline(x=first_stim_frame, color='red', linestyle='--', label='First Stimulus')

        # Labels and title
        ax.set_xlabel('Frame Number')
        ax.set_ylabel('Calcium Signal Intensity')
        ax.set_title(f'Calcium Signals Around First Stimulation for Responsive ROIs in Session {session_id}')
        ax.legend()
        fig.tight_layout()

        # Save the plot
        save_path = os.path.join(directory_path, processed_dir, f"{session_id}_responsive_ROIs_mean_sem.png")
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

        print(f"Plot saved at {save_path}")
        
    def plot_normalized_mean_and_sem_of_responsive_rois(self, session_id, pre_stim_duration=100, post_stim_duration=200, threshold=2):
        """
        Plot the min-max normalised mean +/- SEM of responsive ROIs around the first stimulation.

        ROIs are chosen by ``find_responsive_rois_first_stim_mean`` with the same
        ``pre_stim_duration``, ``post_stim_duration`` and ``threshold``. Each chosen ROI
        trace is rescaled to [0, 1] over the whole recording before averaging across
        ROIs.

        The plotted frames run from ``max(first - pre_stim_duration, 0) - 9`` to
        ``min(first + post_stim_duration, n_rows) + 100``, both inclusive. Here
        ``first`` is the first stimulation frame and ``n_rows`` is the number of CSV
        rows. The figure is saved as
        ``<session_id>_responsive_ROIs_normalized_mean_sem.png``.

        Returns
        -------
        None
        """
        responsive_rois = self.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration, post_stim_duration, threshold)
        if not responsive_rois:
            print("No responsive ROIs found with a positive z-score.")
            return

        session_row = self.directory_df[self.directory_df['session_id'] == session_id]
        output_dir = os.path.join(session_row['directory_path'].values[0], 'processed_data/processed_image_analysis_output')
        signals = pd.read_csv(os.path.join(output_dir, session_id + '_calcium_signals.csv'))
        first_stim = session_row['stimulation_frame_number'].iloc[0][0]

        roi_traces = signals[responsive_rois]
        # Rescale every ROI column independently to [0, 1]; a constant column becomes NaN.
        scaled = (roi_traces - roi_traces.min()) / (roi_traces.max() - roi_traces.min())
        avg = scaled.mean(axis=1)
        err = scaled.sem(axis=1)

        lower = max(first_stim - pre_stim_duration, 0)
        upper = min(first_stim + post_stim_duration, len(signals))
        frames = signals['Frame']
        # NOTE(review): the extra 9 frames before / 100 frames after the window, and the
        # -1 shift of the stimulus marker below, are hard-coded offsets; confirm intent.
        shown = frames.between(lower - 9, upper + 100)

        plt.figure(figsize=(10, 6))
        plt.errorbar(frames[shown],
                     avg[shown],
                     yerr=err[shown],
                     label='Normalized Mean +/- SEM',
                     color='blue',
                     ecolor='lightblue')
        plt.axvline(x=first_stim - 1, color='red', linestyle='--', label='First Stimulus')
        plt.xlabel('Frame Number')
        plt.ylabel('Normalized Calcium Signal Intensity')
        plt.title(f'Normalized Calcium Signals Around First Stimulation for Responsive ROIs in Session {session_id}')
        plt.legend()
        plt.tight_layout()

        save_path = os.path.join(output_dir, f"{session_id}_responsive_ROIs_normalized_mean_sem.png")
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"Normalized plot saved at {save_path}")

    def plot_normalized_mean_and_sem_of_all_stims(self, session_id, pre_stim_duration=100, post_stim_duration=200, threshold=2):
        """
        Plot normalised mean +/- SEM of responsive ROIs around every stimulation and save the traces.

        Responsive ROIs are chosen by ``find_responsive_rois_first_stim_mean``. Each one
        is min-max scaled to [0, 1] over the whole recording.

        For each stimulation frame ``s``, the window covers frames
        ``max(s - pre_stim_duration, 0) - 9`` through
        ``min(s + post_stim_duration, n_rows) + 100`` (inclusive), where ``n_rows`` is the
        number of CSV rows.

        Outputs:

        - One subplot per stimulation, saved as
          ``<session_id>_all_stims_responsive_ROIs_normalized_mean_sem.png``.
        - The per-stimulation mean and SEM traces, saved as
          ``<session_id>_mean_responses.npy`` and ``<session_id>_sem_responses.npy``.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.
        pre_stim_duration, post_stim_duration : int, optional
            Window half-widths (frames), also forwarded to the responsiveness test.
        threshold : float, optional
            z-score threshold for the responsiveness test.

        Returns
        -------
        None

        Notes
        -----
        The saved arrays have shape
        ``(num_stimulations, pre_stim_duration + post_stim_duration + 110)``, which is the
        length of an unclipped window. Windows clipped at the recording edges are
        zero-padded at the end; any longer window is truncated.
        """
        responsive_rois = self.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration, post_stim_duration, threshold)

        if not responsive_rois:
            print("No responsive ROIs found with a positive z-score.")
            return

        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)
        stimulation_frames = directory_entry['stimulation_frame_number'].iloc[0]

        num_stimulations = len(stimulation_frames)
        if num_stimulations == 0:
            # plt.subplots(0, 1) cannot build a figure, and there is nothing to save.
            print(f"No stimulation frames recorded for session {session_id}")
            return

        calcium_signals_df = pd.read_csv(csv_path)
        frame_column = calcium_signals_df['Frame']

        # Normalize the signals for each responsive ROI to [0, 1]
        normalized_signals = calcium_signals_df[responsive_rois].apply(lambda x: (x - x.min()) / (x.max() - x.min()))

        # Extra frames kept before / after the [stim - pre, stim + post] window.
        pre_margin = 9
        post_margin = 100
        # Length of an unclipped window: both ends inclusive, hence the +1. This must
        # match the mask below; a narrower width makes the stored rows inconsistent.
        response_window_size = pre_stim_duration + post_stim_duration + pre_margin + post_margin + 1
        mean_responses = np.zeros((num_stimulations, response_window_size))
        sem_responses = np.zeros((num_stimulations, response_window_size))

        # squeeze=False keeps `axs` indexable even for a single stimulation.
        fig, axs = plt.subplots(num_stimulations, 1, figsize=(10, 6 * num_stimulations), squeeze=False)
        axs = axs[:, 0]

        for i, (ax, stim_frame) in enumerate(zip(axs, stimulation_frames)):
            stim_start = max(stim_frame - pre_stim_duration, 0)
            stim_end = min(stim_frame + post_stim_duration, len(calcium_signals_df))
            frame_mask = (frame_column >= stim_start - pre_margin) & (frame_column <= stim_end + post_margin)

            window_signals = normalized_signals.loc[frame_mask]
            mean_normalized_response = window_signals.mean(axis=1)
            sem_normalized_response = window_signals.sem(axis=1)

            # Copy into the fixed-width rows. Shorter (edge-clipped) windows keep the
            # zero initialisation as padding; longer ones are truncated.
            n_stored = min(len(mean_normalized_response), response_window_size)
            mean_responses[i, :n_stored] = mean_normalized_response.to_numpy()[:n_stored]
            sem_responses[i, :n_stored] = sem_normalized_response.to_numpy()[:n_stored]

            frame_numbers_for_plot = frame_column[frame_mask]

            # Shaded SEM band around the mean
            ax.fill_between(frame_numbers_for_plot,
                            mean_normalized_response - sem_normalized_response,
                            mean_normalized_response + sem_normalized_response,
                            color='lightblue', alpha=0.5, label='SEM')
            ax.plot(frame_numbers_for_plot,
                    mean_normalized_response,
                    color='blue', label=f'Normalized Mean (Stim {i+1})')

            # NOTE(review): the marker is drawn one frame before the stimulation frame; confirm intent.
            ax.axvline(x=stim_frame - 1, color='red', linestyle='--', label='Stimulus')
            ax.set_xlabel('Frame Number')
            ax.set_ylabel('Normalized Calcium Signal Intensity')
            ax.set_title(f'Normalized Signals Around Stim {i+1} for Responsive ROIs in Session {session_id}')
            ax.legend()

        fig.tight_layout()

        # Save the composite figure
        save_path = os.path.join(directory_path, processed_dir, f"{session_id}_all_stims_responsive_ROIs_normalized_mean_sem.png")
        fig.savefig(save_path, dpi=300)
        plt.close(fig)

        # Save the numpy arrays for later use
        np.save(os.path.join(directory_path, processed_dir, f"{session_id}_mean_responses.npy"), mean_responses)
        np.save(os.path.join(directory_path, processed_dir, f"{session_id}_sem_responses.npy"), sem_responses)

        print(f"Composite normalized plot saved and Mean and SEM response data saved at {save_path}")
        
    def plot_overlaid_normalized_responses(self, session_id, pre_stim_duration=100, post_stim_duration=200, threshold=2):
        """
        Draw the normalised mean +/- SEM response for every stimulation on one axes.

        Responsive ROIs are chosen by ``find_responsive_rois_first_stim_mean`` and each is
        min-max scaled to [0, 1]. Each stimulation gets its own viridis colour.

        For stimulation ``s``, frames from ``max(s - pre_stim_duration, 0) - 9`` to
        ``min(s + post_stim_duration, n_rows) + 100`` (inclusive) are plotted at their
        absolute frame numbers.

        NOTE(review): because x is the absolute frame number, responses to different
        stimulations sit side by side rather than truly overlapping; aligning on
        ``frame - s`` may be what "overlaid" intends.

        The figure is saved as ``<session_id>_overlaid_normalized_responses.png``.

        Returns
        -------
        None
        """
        responsive_rois = self.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration, post_stim_duration, threshold)
        if not responsive_rois:
            print("No responsive ROIs found with a positive z-score.")
            return

        entry = self.directory_df[self.directory_df['session_id'] == session_id]
        out_dir = os.path.join(entry['directory_path'].values[0], 'processed_data/processed_image_analysis_output')
        stim_frames = entry['stimulation_frame_number'].iloc[0]

        traces = pd.read_csv(os.path.join(out_dir, session_id + '_calcium_signals.csv'))
        picked = traces[responsive_rois]
        # Per-ROI min-max scaling to [0, 1]
        scaled = (picked - picked.min()) / (picked.max() - picked.min())
        frames = traces['Frame']
        n_rows = len(traces)

        plt.figure(figsize=(10, 6))
        palette = plt.cm.viridis(np.linspace(0, 1, len(stim_frames)))  # one colour per stimulation

        for stim_number, (stim_frame, shade) in enumerate(zip(stim_frames, palette), start=1):
            lo = max(stim_frame - pre_stim_duration, 0)
            hi = min(stim_frame + post_stim_duration, n_rows)
            keep = frames.between(lo - 9, hi + 100)

            avg = scaled[keep].mean(axis=1)
            err = scaled[keep].sem(axis=1)
            x = frames[keep]

            plt.fill_between(x, avg - err, avg + err, color=shade, alpha=0.5)
            plt.plot(x, avg, color=shade, label=f'Stim {stim_number}')

        plt.axvline(x=stim_frames[0] - 1, color='red', linestyle='--', label='First Stimulus')
        plt.xlabel('Frame Number')
        plt.ylabel('Normalized Calcium Signal Intensity')
        plt.title(f'Overlaid Normalized Responses for Responsive ROIs in Session {session_id}')
        plt.legend()
        plt.tight_layout()

        save_path = os.path.join(out_dir, f"{session_id}_overlaid_normalized_responses.png")
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"Overlaid normalized responses plot saved at {save_path}")
        
    def plot_mean_responses_from_file(self, session_id):
        """
        Plot per-stimulation mean responses saved by ``plot_normalized_mean_and_sem_of_all_stims``.

        Loads ``<session_id>_mean_responses.npy`` from the session's processed-output
        directory and prints its shape. It then draws one line per row (stimulation)
        against the sample index within the response window. The figure is shown,
        not saved.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.

        Returns
        -------
        None
            Prints a message and returns early if the session entry or file is
            missing, or if the stored array is empty.
        """
        processed_dir = 'processed_data/processed_image_analysis_output'
        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return
        directory_path = directory_entry['directory_path'].values[0]

        # Construct the file path and load the data
        mean_responses_path = os.path.join(directory_path, processed_dir, f"{session_id}_mean_responses.npy")
        if not os.path.exists(mean_responses_path):
            print(f"Mean responses file not found for session {session_id}")
            return
        mean_responses = np.load(mean_responses_path)

        if mean_responses.size == 0:
            print("No data found in the mean responses file.")
            return

        print(mean_responses.shape)

        # Treat a 1-D array as a single stimulation so row iteration still works.
        mean_responses = np.atleast_2d(mean_responses)

        fig, ax = plt.subplots(figsize=(10, 6))
        for stim_number, response in enumerate(mean_responses, start=1):
            ax.plot(response, label=f'Stim {stim_number}')

        ax.set_xlabel('Time Point')
        ax.set_ylabel('Normalized Mean Response')
        ax.set_title(f'Normalized Mean Responses for All Stimulations in Session {session_id}')
        ax.legend()
        fig.tight_layout()
        plt.show()
    
    def create_trial_locked_calcium_signals(self, session_id, use_corrected_data=False):
        """
        Cut a window of calcium signal around every stimulation for each ROI.

        Columns whose name contains ``'ROI'`` are treated as ROI traces. For each
        ``(stimulation_id, stim_frame)`` pair, rows ``max(stim_frame - 10, 0)`` through
        ``min(stim_frame + 100, n_rows)`` are taken. The slice uses label-based
        ``.loc`` on the default row index, so both ends are inclusive.

        Parameters
        ----------
        session_id : str
            Session identifier matched against ``self.directory_df['session_id']``.
        use_corrected_data : bool, optional
            If True, read ``<session_id>_corrected_calcium_signals.csv``; otherwise read
            ``<session_id>_calcium_signals.csv``.

        Returns
        -------
        tuple
            ``(stim_frame_numbers, roi_data, stimulation_ids)``, where
            ``roi_data[roi][(stim_id, stim_frame)]`` is a NumPy array. Returns
            ``(None, None, None)`` if the session entry or the CSV is missing.
        """
        data_kind = 'corrected' if use_corrected_data else 'uncorrected'
        suffix = '_corrected_calcium_signals.csv' if use_corrected_data else '_calcium_signals.csv'

        session_rows = self.directory_df[self.directory_df['session_id'] == session_id]
        if session_rows.empty:
            print(f"No directory entry found for session {session_id}")
            return None, None, None

        csv_path = os.path.join(session_rows['directory_path'].values[0],
                                'processed_data/processed_image_analysis_output',
                                f"{session_id}{suffix}")
        if not os.path.exists(csv_path):
            print(f"Calcium signals file not found for session {session_id} using {data_kind} data")
            return None, None, None

        signals = pd.read_csv(csv_path)
        stim_frames = session_rows['stimulation_frame_number'].values[0]
        stim_ids = session_rows['stimulation_ids'].values[0]

        frames_before, frames_after = 10, 100
        n_rows = len(signals)
        roi_columns = [col for col in signals.columns if 'ROI' in col]
        roi_data = {col: {} for col in roi_columns}

        for stim_id, stim_frame in zip(stim_ids, stim_frames):
            first_row = max(stim_frame - frames_before, 0)
            last_row = min(stim_frame + frames_after, n_rows)
            for col in roi_columns:
                roi_data[col][(stim_id, stim_frame)] = signals.loc[first_row:last_row, col].to_numpy()

        return stim_frames, roi_data, stim_ids
    
    def process_all_sessions(self, use_corrected_data=False):
        """
        Build trial-locked calcium data for every unique session in ``directory_df``.

        Parameters
        ----------
        use_corrected_data : bool, optional
            Passed to ``create_trial_locked_calcium_signals``. When True it reads the
            ``_corrected_calcium_signals.csv`` file instead of ``_calcium_signals.csv``.
            Defaults to False.

        Returns
        -------
        dict
            ``{session_id: {'stim_frame_numbers': ..., 'roi_data': ..., 'stimulation_ids': ...}}``.
            Sessions for which any of the three values is None or empty are omitted.
        """
        def _has_items(value):
            # Length-based check. Plain truthiness raises ValueError for NumPy arrays
            # with more than one element and misreads a one-element array holding 0.
            return value is not None and len(value) > 0

        all_data = {}
        for session_id in self.directory_df['session_id'].unique():
            stim_frame_numbers, roi_data, stimulation_ids = self.create_trial_locked_calcium_signals(session_id, use_corrected_data=use_corrected_data)
            if _has_items(stim_frame_numbers) and _has_items(roi_data) and _has_items(stimulation_ids):
                all_data[session_id] = {
                    'stim_frame_numbers': stim_frame_numbers,
                    'roi_data': roi_data,
                    'stimulation_ids': stimulation_ids
                }
        return all_data
    
    def process_all_sessions_entire_recording(self, use_corrected_data=False):
        """
        Load the full calcium-signal table of every unique session into a dict.

        Parameters
        ----------
        use_corrected_data : bool, optional
            If True, load ``<session_id>_corrected_calcium_signals.csv``; otherwise load
            ``<session_id>_calcium_signals.csv``. Defaults to False.

        Returns
        -------
        dict
            Maps session id to the DataFrame read from its CSV. Sessions without a
            matching directory row or CSV are reported with a printed message and
            left out.
        """
        data_kind = 'corrected' if use_corrected_data else 'uncorrected'
        suffix = '_corrected_calcium_signals.csv' if use_corrected_data else '_calcium_signals.csv'
        loaded = {}

        for sid in self.directory_df['session_id'].unique():
            rows = self.directory_df[self.directory_df['session_id'] == sid]
            if rows.empty:
                print(f"No directory entry found for session {sid}")
                continue

            csv_path = os.path.join(rows['directory_path'].values[0],
                                    'processed_data/processed_image_analysis_output',
                                    f"{sid}{suffix}")
            if os.path.exists(csv_path):
                loaded[sid] = pd.read_csv(csv_path)
            else:
                print(f"Calcium signals file not found for session {sid} using {data_kind} data")

        return loaded
            
            
    def preprocess_and_extract_signals(self, session_id):
        """
        Dark-correct each ROI trace and cut trial-locked windows around every stimulation.

        Dark correction, for every column whose name contains ``'ROI'``:

        - The mean of the first 100 rows is taken as the dark signal.
        - That mean is subtracted from the whole trace.
        - Values that become negative are set to NaN. The corrected traces are
          therefore float, not int.

        Windows then span rows ``max(stim_frame - 10, 0)`` through
        ``min(stim_frame + 100, n_rows)``. They are sliced with label-based ``.loc`` on
        the default RangeIndex from ``read_csv``, so both ends are inclusive.

        Parameters
        ----------
        session_id : str
            Unique identifier for the experimental session.

        Returns
        -------
        tuple or None
            ``(stim_frame_numbers, roi_data, stimulation_ids)``, where
            ``roi_data[roi][(stimulation_id, stim_frame)]`` is a float NumPy array that
            may contain NaN. Returns None if the session entry or its calcium CSV is
            missing.
        """
        processed_dir = 'processed_data/processed_image_analysis_output'
        calcium_csv_suffix = '_calcium_signals.csv'

        directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]
        # Check for a missing session before indexing: `.values[0]` on an empty
        # selection raises IndexError.
        if directory_entry.empty:
            print(f"No directory entry found for session {session_id}")
            return

        stim_frame_numbers = directory_entry['stimulation_frame_number'].values[0]
        stimulation_ids = directory_entry['stimulation_ids'].values[0]

        directory_path = directory_entry['directory_path'].values[0]
        csv_path = os.path.join(directory_path, processed_dir, session_id + calcium_csv_suffix)

        if not os.path.exists(csv_path):
            print(f"Calcium signals file not found for session {session_id}")
            return

        calcium_signals_df = pd.read_csv(csv_path)
        roi_columns = [col for col in calcium_signals_df.columns if 'ROI' in col]

        # Dark-signal correction, done once per ROI column.
        for roi in roi_columns:
            dark_signal_mean = calcium_signals_df[roi].iloc[:100].mean()
            corrected = calcium_signals_df[roi] - dark_signal_mean
            # Below-dark values become NaN. The column stays float because
            # astype(int) cannot represent NaN.
            calcium_signals_df[roi] = corrected.where(corrected >= 0)

        # Parameters for alignment
        pre_stim_frames = 10    # frames before stimulation to include
        post_stim_frames = 100  # frames after stimulation to include

        # roi -> {(stimulation_id, stim_frame): window of corrected signal}
        roi_data = {roi: {} for roi in roi_columns}
        for stim_id, stim_frame in zip(stimulation_ids, stim_frame_numbers):
            start_idx = max(stim_frame - pre_stim_frames, 0)
            end_idx = min(stim_frame + post_stim_frames, len(calcium_signals_df))
            for roi in roi_columns:
                trial = calcium_signals_df.loc[start_idx:end_idx, roi]
                roi_data[roi][(stim_id, stim_frame)] = trial.to_numpy(dtype=float)

        return stim_frame_numbers, roi_data, stimulation_ids


In [None]:
# Instantiate an ImageAnalysis object over the experiment folders inside project_folder.
# NOTE(review): machine-specific absolute path; adjust before running on another machine.
project_folder = '/Volumes/MannySSD/cablam_imaging/raw_data_for_analysis' #path to the folder containing the raw data to be analyzed (i.e. the folder containing the folders for each experiment)
analysis = ImageAnalysis(project_folder)
# Prints directory_df as it is at construction time, i.e. before expand_directory_df() below.
print(analysis.directory_df)
# NOTE(review): IPython by default displays only the value of a cell's last expression,
# so this mid-cell bare expression has no visible effect.
analysis.directory_df
# Expand the directory dataframe with additional columns (expand_directory_df is defined outside this chunk).
analysis.expand_directory_df()

In [None]:
# One ImageAnalysis instance per experimental condition, each with a filtered directory_df.

# CaBLAM sessions whose directory name contains '05xfz' (plotted later as "CaBLAM with 1/2x Fz Dilution").
analysis_cablam = ImageAnalysis(project_folder)
analysis_cablam.expand_directory_df()
analysis_cablam.directory_df = analysis_cablam.directory_df[(analysis_cablam.directory_df['sensor_type'] == 'cablam') & (analysis_cablam.directory_df['directory_name'].str.contains('05xfz'))]
analysis_cablam.directory_df

# All GCaMP8 sessions.
analysis_gcamp8 = ImageAnalysis(project_folder)
analysis_gcamp8.expand_directory_df()
analysis_gcamp8.directory_df = analysis_gcamp8.directory_df[(analysis_gcamp8.directory_df['sensor_type'] == 'gcamp8')]
analysis_gcamp8.directory_df

# CaBLAM sessions whose directory name contains '1xfz' (plotted later as "CaBLAM with 1x Fz Dilution").
# NOTE(review): str.contains is a substring (regex) match, so any name containing
# '1xfz' (e.g. '01xfz', '11xfz') would also match — confirm the naming scheme.
analysis_cablam1x = ImageAnalysis(project_folder)
analysis_cablam1x.expand_directory_df()
analysis_cablam1x.directory_df = analysis_cablam1x.directory_df[(analysis_cablam1x.directory_df['sensor_type'] == 'cablam') & (analysis_cablam1x.directory_df['directory_name'].str.contains('1xfz'))]
analysis_cablam1x.directory_df



In [None]:
# Session IDs of every session known to the main (unfiltered) analysis object.
session_ids = analysis.directory_df['session_id']
print(session_ids)


In [None]:
# Usage example with a given session_id, assuming the session_id is present in the directory DataFrame and has a .tif file
session_id = '1212232023'  # Replace with your actual session_id
raw_data_path = analysis.get_session_raw_data(session_id) #get the raw data path for the given session_id

# NOTE(review): get_session_raw_data apparently returns either a path or a
# message string starting with "No .tif file" — the check below relies on that.
# below is the code to generate the max projection image for the raw data at the raw_data_path 
if raw_data_path and not raw_data_path.startswith("No .tif file"): #if the raw data path was found and does not start with "No .tif file"
    max_proj_image = analysis.max_projection_mean_values(raw_data_path) #generate the max projection image 


In [None]:
# Assuming you've already initialized your ImageAnalysis instance and populated directory_df:
# add_tiff_dimensions() is called for its side effect on analysis.directory_df (its return value is unused).
analysis.add_tiff_dimensions()

# Now the directory_df has the dimensions for each TIFF file included
print(analysis.directory_df.head())  # Display the updated DataFrame to verify


In [None]:
# Display the current directory DataFrame of the main analysis object.
analysis.directory_df

In [None]:
# Run the max-projection analysis on every session in directory_df.

# Per-session callback handed to analysis.analyze_all_sessions below.
def analyze_session_max_projection(session_id):
    """
    Apply ``max_projection_mean_values`` to one session's raw TIF file.

    Uses the notebook-global ``analysis`` object (an ``ImageAnalysis`` instance)
    to look up the session's raw data path.

    Parameters:
    session_id (str): The session whose TIF file should be processed.

    Returns:
    Whatever ``analysis.max_projection_mean_values`` returns when the session's
    raw data path is a string ending in '.tif' (documented as the path of the
    saved max projection); otherwise the message
    "No valid TIF file found for session <session_id>".
    """
    raw_path = analysis.get_session_raw_data(session_id)
    has_tif = isinstance(raw_path, str) and raw_path.endswith('.tif')
    if not has_tif:
        return f"No valid TIF file found for session {session_id}"
    return analysis.max_projection_mean_values(raw_path)

# Apply max_projection_mean_values to all sessions.
results = analysis.analyze_all_sessions(analyze_session_max_projection)

# Output the results.
# analyze_session_max_projection reports a missing TIF by returning a string that
# starts with "No valid TIF file found", so a bare isinstance(str) check cannot tell
# success from failure; treat that sentinel (and any non-string result) as an error.
for session_id, result_path in results.items():
    if isinstance(result_path, str) and not result_path.startswith("No valid TIF file found"):
        print(f"Session ID {session_id}: Max projection image saved at {result_path}")
    else:
        print(f"Session ID {session_id}: {result_path}")

In [None]:
# Analyze ROIs for all sessions (the return value, if any, is not used here).
analysis.analyze_all_rois()

In [None]:
# Extract the calcium signal for a single session_id.
# NOTE: this cell rebinds the notebook-global session_id.

session_id = '1212232023'  # Replace with your actual session_id
calcium_signals_path = analysis.extract_calcium_signals(session_id)
print(f"Calcium signals saved at: {calcium_signals_path}")


In [None]:
# Extract calcium signals for all sessions in directory_df and report where each was saved.
# NOTE(review): a session is reported as an error only when its result is not a str;
# if analyze_all_calcium_signals reports failures as strings they would be printed
# as "saved at" — confirm its return convention.

all_results = analysis.analyze_all_calcium_signals()

for session_id, csv_path in all_results.items():
    if isinstance(csv_path, str):
        print(f"Session {session_id} - Calcium signals saved at: {csv_path}")
    else:
        print(f"Session {session_id} - Error: {csv_path}")

In [None]:
# Trial-aligned data for the unfiltered main object and the three condition-specific
# objects. Only the CaBLAM groups are processed with use_corrected_data=True.
all_data = analysis.process_all_sessions(use_corrected_data=False)
all_data_gcamp8 = analysis_gcamp8.process_all_sessions(use_corrected_data=False)
all_data_cablam = analysis_cablam.process_all_sessions(use_corrected_data=True)
all_data_cablam1x = analysis_cablam1x.process_all_sessions(use_corrected_data=True)

# Extract whole-recording ROI traces for every session of each group.
# NOTE(review): the original intent was also to keep these as an instance attribute;
# confirm that process_all_sessions_entire_recording does so.
all_data_cablam_session_data = analysis_cablam.process_all_sessions_entire_recording(use_corrected_data=True)
all_data_cablam1x_session_data = analysis_cablam1x.process_all_sessions_entire_recording(use_corrected_data=True)
all_data_gcamp8_session_data = analysis_gcamp8.process_all_sessions_entire_recording(use_corrected_data=False)



In [None]:
import importlib
import BL_CalciumAnalysis.image_analysis_methods as iam

# Reload the module so edits to image_analysis_methods.py are picked up without
# restarting the kernel, then re-import the helper from the reloaded module.
importlib.reload(iam)
from BL_CalciumAnalysis.image_analysis_methods import print_cells_per_recording

# Presumably prints the number of cells (ROIs) per recording for each labelled group.
print_cells_per_recording({
    "GCaMP": all_data_gcamp8,
    "Cablam 1x": all_data_cablam1x,
    "Cablam 0.5x": all_data_cablam,
})

In [None]:
# Inspect the per-group results; Jupyter only displays the last expression of a cell
# (all_data_cablam), the earlier lines have no visible effect.
all_data
all_data_gcamp8
all_data_cablam1x
all_data_cablam


In [None]:
# Print every level of all_data_gcamp8 with descriptive text: per session the
# stimulation frames and IDs, then per ROI every (stim ID, frame) key and its trace.
# NOTE: the loop rebinds the notebook-globals session_id and roi_data.
for session_id, session_data in all_data_gcamp8.items():
    print(f"Session ID: {session_id}")
    print(f"Stimulation Frame Numbers: {session_data['stim_frame_numbers']}")
    print(f"Stimulation IDs: {session_data['stimulation_ids']}")
    for roi, roi_data in session_data['roi_data'].items():
        print(f"ROI: {roi}")
        for key, value in roi_data.items():
            print(f"Stimulation ID, Frame Number: {key}")
            print(f"Data: {value}")
    print("\n")

In [None]:
# This function plots the median response (± SEM) of one session for a specific stim_id.

def plot_mean_response_for_session_with_nans(roi_data, stim_id, pre_stim_frames=10):
    """
    Plot the median response for a given stim_id across all ROIs of a single session,
    ignoring NaN samples.

    Every trial whose key matches ``stim_id`` contributes one trace, so repeated
    presentations of the same stimulus to an ROI are all included instead of only
    the last one (this matches how compile_pooled_responses_with_nans pools data).
    Despite the function name, the central line is the NaN-ignoring median.

    :param roi_data: The roi_data for a single session:
        ``{roi: {(stim_id, stim_frame): trace}}``. All matching traces must have
        the same length.
    :param stim_id: The stimulation ID to plot data for.
    :param pre_stim_frames: Number of samples preceding the stimulus at the start of
        each trace; used to place time 0 on the x-axis. Defaults to 10.
    :return: Number of non-NaN traces contributing at each time point, or None if no
        trace matches ``stim_id``.
    """
    traces = [
        trace
        for stim_data in roi_data.values()
        for (current_stim_id, _), trace in stim_data.items()
        if current_stim_id == stim_id
    ]

    # Ensure there is data to plot.
    if not traces:
        print(f"No data found for stim_id {stim_id}")
        return None

    # (n_traces, n_time); NaNs are ignored by the nan* reductions below.
    stacked_data = np.stack(traces)
    valid_counts = np.sum(~np.isnan(stacked_data), axis=0)
    median_response = np.nanmedian(stacked_data, axis=0)
    sem_response = np.nanstd(stacked_data, axis=0, ddof=1) / np.sqrt(valid_counts)
    print(f"Number of contributing traces at each time point: {valid_counts}")

    # Derive the time axis from the data instead of a hard-coded range.
    time_points = np.arange(stacked_data.shape[1]) - pre_stim_frames

    plt.figure(figsize=(10, 5))
    plt.plot(time_points, median_response, label='Median Response')
    plt.fill_between(time_points, median_response - sem_response, median_response + sem_response, alpha=0.3, label='SEM')

    # Red dashed line marks the stimulus (time point 0).
    plt.axvline(x=0, color='red', linestyle='--', linewidth=1)

    plt.xlabel('Time (relative to stimulus)')
    plt.ylabel('Calcium Signal')
    plt.title(f'Median Calcium Response for Stim ID {stim_id}')
    plt.legend()
    plt.show()

    return valid_counts

        
# Iterate over all sessions and plot the response to one stim_id for each.
def plot_all_sessions_with_nans(all_data, stim_id=12):
    """
    Plot the response to ``stim_id`` for every session in ``all_data``.

    Each session's 'roi_data' entry is passed, together with ``stim_id``, to
    plot_mean_response_for_session_with_nans. The per-session return values are
    discarded; this function returns None.
    """
    for current_session in all_data:
        print(f"Plotting for session {current_session}")
        session_rois = all_data[current_session]['roi_data']
        plot_mean_response_for_session_with_nans(session_rois, stim_id)
        
# Plot the per-session response to stim_id 12 for the CaBLAM (0.5x Fz) and GCaMP8 groups.


plot_all_sessions_with_nans(all_data_cablam, stim_id=12)
plot_all_sessions_with_nans(all_data_gcamp8, stim_id=12)






In [None]:
def compile_pooled_responses_with_nans(all_data):
    """
    Compile the mean response and SEM for each stim_id, pooled across all trials,
    ROIs and sessions, ignoring NaN samples.

    :param all_data: ``{session_id: {'roi_data': {roi: {(stim_id, stim_frame): trace}}, ...}}``.
        All traces of a given stim_id must have the same length.
    :return: ``{stim_id: {'mean_response': array, 'sem_response': array}}`` with
        stim_ids in ascending order. Time points with fewer than two valid samples
        get a NaN SEM.
    """
    # Group every trace by stim_id in a single pass; within each group the traces
    # keep their session -> ROI -> trial order.
    traces_by_stim = {}
    for session_data in all_data.values():
        for stim_data in session_data['roi_data'].values():
            for (stim_id, _), trace in stim_data.items():
                traces_by_stim.setdefault(stim_id, []).append(trace)

    pooled_responses = {}
    for stim_id in sorted(traces_by_stim):
        # Shape (n_traces, n_time); the nan* reductions skip missing samples.
        stacked_data = np.stack(traces_by_stim[stim_id])
        valid_counts = np.sum(~np.isnan(stacked_data), axis=0)
        pooled_responses[stim_id] = {
            'mean_response': np.nanmean(stacked_data, axis=0),
            'sem_response': np.nanstd(stacked_data, axis=0, ddof=1) / np.sqrt(valid_counts),
        }

    return pooled_responses

# Pool each group's trial-aligned traces by stim_id (NaN-ignoring mean and SEM).
pooled_nans_responses_gcamp8 = compile_pooled_responses_with_nans(all_data_gcamp8)
pooled_nans_responses_cablam = compile_pooled_responses_with_nans(all_data_cablam)
pooled_nans_responses_cablam1x = compile_pooled_responses_with_nans(all_data_cablam1x)



In [None]:
def plot_delta_f_over_f_with_nans(pooled_responses, specific_stim_ids=None, base_color='green', subtitle=''):
    """
    Plot the normalized calcium responses (ΔF/F) for all or specific stimulation IDs,
    handling NaN values in the responses. Line opacity increases with the stim ID.

    For each stim ID the baseline is the NaN-ignoring median of the first 10 samples
    of 'mean_response'; ΔF/F = (response - baseline) / baseline, or all zeros when
    the baseline is 0.

    Parameters
    ----------
    pooled_responses : dict
        A dictionary with stim IDs as keys and 'mean_response' (and 'sem_response')
        as subkeys, potentially containing NaN values.
    specific_stim_ids : list of int, optional
        Stim IDs to plot. If None, all stim IDs are plotted in ascending order.
    base_color : str, optional
        Colour of the plot lines. Opacity is scaled from 0.1 (lowest stim ID) to
        1.0 (highest); a single stim ID is drawn fully opaque.
    subtitle : str, optional
        Text shown under the main title.

    Raises
    ------
    ValueError
        If a specified stim ID is not present in the data, or there are no stim
        IDs to plot.
    """
    # Resolve and validate the stim IDs before creating a figure, so a bad call
    # does not leave an empty figure behind.
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = list(specific_stim_ids)
    else:
        stim_ids = sorted(pooled_responses.keys())
    if not stim_ids:
        raise ValueError("No stim IDs to plot")

    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id

    plt.figure(figsize=(10, 5))
    for stim_id in stim_ids:
        response = pooled_responses[stim_id]['mean_response']
        # Pre-stimulus baseline: median of the first 10 samples, ignoring NaNs.
        baseline = np.nanmedian(response[:10])
        delta_f_over_f = (response - baseline) / baseline if baseline != 0 else np.zeros_like(response)

        # Guard the division: with a single stim ID the span is 0.
        alpha = (0.1 + 0.9 * (stim_id - min_stim_id) / stim_span) if stim_span else 1.0
        time_points = np.arange(-10, len(response) - 10)

        plt.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

    plt.xlabel('Time (relative to stimulus)')
    plt.ylabel(r'$\Delta F/F$')
    # Set the main title and the custom subtitle.
    plt.suptitle('Normalized Calcium Responses (ΔF/F) for Selected Stimulation IDs', fontsize=14)
    plt.title(subtitle, fontsize=10)
    plt.legend()
    plt.show()



# ΔF/F of the selected stim IDs, one figure per group
# (green = GCaMP8s, red = CaBLAM 1/2x Fz, blue = CaBLAM 1x Fz).
plot_delta_f_over_f_with_nans(pooled_nans_responses_gcamp8, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='green', subtitle='GCaMP8s')
plot_delta_f_over_f_with_nans(pooled_nans_responses_cablam, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='red', subtitle='CaBLAM with 1/2x Fz Dilution')
plot_delta_f_over_f_with_nans(pooled_nans_responses_cablam1x, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='blue', subtitle='CaBLAM with 1x Fz Dilution')


In [None]:

def plot_delta_f_over_f_with_nans(pooled_responses, specific_stim_ids=None, base_color='green', subtitle='', plot_style='mean'):
    """
    Plot the normalized calcium responses (ΔF/F) for all or specific stimulation IDs,
    either as the mean only or as the mean with its SEM drawn as dashed lines.

    For each stim ID the baseline is the NaN-ignoring median of the first 10 samples
    of 'mean_response'; ΔF/F = (response - baseline) / baseline and the SEM is
    divided by the same baseline (both become zeros when the baseline is 0).

    Parameters
    ----------
    pooled_responses : dict
        A dictionary with stim IDs as keys and 'mean_response' and 'sem_response' as subkeys,
        potentially containing NaN values.
    specific_stim_ids : list of int, optional
        Stim IDs to plot. If None, all stim IDs are plotted in ascending order.
    base_color : str, optional
        Colour of the plot lines; opacity increases with the stim ID (a single
        stim ID is drawn fully opaque).
    subtitle : str, optional
        Custom subtitle text for the plot.
    plot_style : {'mean', 'mean_sem'}, optional
        'mean' (default) draws only the mean; 'mean_sem' adds mean ± SEM as dashed lines.

    Raises
    ------
    ValueError
        If plot_style is unknown, a requested stim ID is missing, or there is
        nothing to plot.
    """
    # Validate before creating a figure so a bad call leaves no empty figure behind.
    if plot_style not in ('mean', 'mean_sem'):
        raise ValueError("plot_style must be 'mean' or 'mean_sem'")
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = list(specific_stim_ids)
    else:
        stim_ids = sorted(pooled_responses.keys())
    if not stim_ids:
        raise ValueError("No stim IDs to plot")

    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id

    plt.figure(figsize=(10, 5))
    for stim_id in stim_ids:
        response = pooled_responses[stim_id]['mean_response']
        baseline = np.nanmedian(response[:10])
        if baseline != 0:
            delta_f_over_f = (response - baseline) / baseline
            sem = pooled_responses[stim_id]['sem_response'] / baseline
        else:
            delta_f_over_f = np.zeros_like(response)
            sem = np.zeros_like(response)

        # Guard the division: with a single stim ID the span is 0.
        alpha = (0.1 + 0.9 * (stim_id - min_stim_id) / stim_span) if stim_span else 1.0
        time_points = np.arange(-10, len(response) - 10)

        # Plot mean.
        plt.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

        # If 'mean_sem' style is chosen, also plot mean ± SEM as dashed lines.
        if plot_style == 'mean_sem':
            plt.plot(time_points, delta_f_over_f + sem, linestyle='--', color=base_color, alpha=alpha)
            plt.plot(time_points, delta_f_over_f - sem, linestyle='--', color=base_color, alpha=alpha)

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.title(subtitle, fontsize=10)
    plt.xlabel('Time (relative to stimulus)')
    plt.ylabel(r'$\Delta F/F$')
    plt.legend()
    plt.show()
    
# Same ΔF/F figures per group, now with mean ± SEM drawn as dashed lines.
plot_delta_f_over_f_with_nans(pooled_nans_responses_gcamp8, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='green', subtitle='GCaMP8s', plot_style='mean_sem')
plot_delta_f_over_f_with_nans(pooled_nans_responses_cablam, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='red', subtitle='CaBLAM with 1/2x Fz Dilution', plot_style='mean_sem')
plot_delta_f_over_f_with_nans(pooled_nans_responses_cablam1x, specific_stim_ids=[12, 24, 36, 60, 120, 480], base_color='blue', subtitle='CaBLAM with 1x Fz Dilution', plot_style='mean_sem')


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_delta_f_over_f_with_nans(pooled_responses, specific_stim_ids=None, base_color='green', 
                                  subtitle='', plot_style='mean', plot_mode='overlay'):
    """
    Plot normalized calcium responses (ΔF/F) for all or specific stimulation IDs,
    either overlaid on one axis or as one subplot per stim ID.

    For each stim ID the baseline is the NaN-ignoring median of the first 10 samples
    of 'mean_response'; ΔF/F = (response - baseline) / baseline and the SEM is
    divided by the same baseline (both become zeros when the baseline is 0).

    Parameters
    ----------
    pooled_responses : dict
        Stim IDs as keys, each mapping to 'mean_response' and 'sem_response' arrays
        (may contain NaN).
    specific_stim_ids : list of int, optional
        Stim IDs to plot. If None, all stim IDs are plotted in ascending order.
    base_color : str, optional
        Line colour; opacity scales from 0.1 (lowest stim ID) to 1.0 (highest),
        and a single stim ID is drawn fully opaque.
    subtitle : str, optional
        Text added to the figure title.
    plot_style : {'mean', 'mean_sem'}, optional
        'mean' draws only the mean; 'mean_sem' also draws the SEM (dashed lines in
        overlay mode, a shaded band in subplot mode).
    plot_mode : {'overlay', 'subplot'}, optional
        'overlay' draws every stim ID on one axis; 'subplot' gives each stim ID its
        own panel with a shared y-axis.

    Raises
    ------
    ValueError
        If plot_mode or plot_style is unknown, a requested stim ID is missing, or
        there is nothing to plot.
    """
    # Validate everything before touching matplotlib so errors leave no empty figure.
    if plot_mode not in ['overlay', 'subplot']:
        raise ValueError("plot_mode must be 'overlay' or 'subplot'")
    if plot_style not in ('mean', 'mean_sem'):
        raise ValueError("plot_style must be 'mean' or 'mean_sem'")
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = list(specific_stim_ids)
    else:
        stim_ids = sorted(pooled_responses.keys())
    if not stim_ids:
        raise ValueError("No stim IDs to plot")

    # Consistent colour intensity for all plots; guard against a zero span (single stim ID).
    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id
    alpha_values = {
        stim_id: (0.1 + 0.9 * (stim_id - min_stim_id) / stim_span) if stim_span else 1.0
        for stim_id in stim_ids
    }

    def _normalized(stim_id):
        """Return (time_points, ΔF/F, baseline-scaled SEM) for one stim ID."""
        response = pooled_responses[stim_id]['mean_response']
        baseline = np.nanmedian(response[:10])
        if baseline != 0:
            delta_f_over_f = (response - baseline) / baseline
            sem = pooled_responses[stim_id]['sem_response'] / baseline
        else:
            delta_f_over_f = np.zeros_like(response)
            sem = np.zeros_like(response)
        time_points = np.arange(-10, len(response) - 10)
        return time_points, delta_f_over_f, sem

    if plot_mode == 'overlay':
        # Only overlay mode needs this figure; subplot mode creates its own below.
        plt.figure(figsize=(10, 5))
        for stim_id in stim_ids:
            time_points, delta_f_over_f, sem = _normalized(stim_id)
            alpha = alpha_values[stim_id]

            # Plot mean.
            plt.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

            # If 'mean_sem' style is chosen, also plot mean ± SEM as dashed lines.
            if plot_style == 'mean_sem':
                plt.plot(time_points, delta_f_over_f + sem, linestyle='--', color=base_color, alpha=alpha)
                plt.plot(time_points, delta_f_over_f - sem, linestyle='--', color=base_color, alpha=alpha)

        plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
        plt.title(subtitle, fontsize=10)
        plt.xlabel('Time (relative to stimulus)')
        plt.ylabel(r'$\Delta F/F$')
        plt.legend()
        plt.show()
    else:
        num_plots = len(stim_ids)
        # squeeze=False always returns a 2-D array of axes, even for a single panel.
        _, axs = plt.subplots(1, num_plots, figsize=(5 * num_plots, 5), sharey=True, squeeze=False)

        for idx, (ax, stim_id) in enumerate(zip(axs[0], stim_ids)):
            time_points, delta_f_over_f, sem = _normalized(stim_id)
            alpha = alpha_values[stim_id]

            # Plot mean.
            ax.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

            # If 'mean_sem' style is chosen, also plot the SEM as a shaded band.
            if plot_style == 'mean_sem':
                ax.fill_between(time_points, delta_f_over_f + sem, delta_f_over_f - sem, color=base_color, alpha=alpha * 0.3)

            ax.set_title(f'Stim ID {stim_id}', fontsize=10)
            ax.set_xlabel('Time (relative to stimulus)')
            if idx == 0:  # y-axis is shared, so label only the first subplot
                ax.set_ylabel(r'$\Delta F/F$')
            ax.legend()

        plt.suptitle(f'Normalized Calcium Responses (ΔF/F) - {subtitle}', fontsize=14)
        plt.tight_layout()
        plt.show()
        

    
    
# One panel per stim ID with shaded SEM bands, one figure per group.
plot_delta_f_over_f_with_nans(pooled_responses=pooled_nans_responses_gcamp8, 
                              specific_stim_ids=[12, 24, 36, 60, 120, 480], 
                              base_color='green', 
                              subtitle='GCaMP8s', 
                              plot_style='mean_sem', 
                              plot_mode='subplot')

plot_delta_f_over_f_with_nans(pooled_responses=pooled_nans_responses_cablam,
                                specific_stim_ids=[12, 24, 36, 60, 120, 480],
                                base_color='red',
                                subtitle='CaBLAM with 1/2x Fz Dilution',
                                plot_style='mean_sem',
                                plot_mode='subplot')

plot_delta_f_over_f_with_nans(pooled_responses=pooled_nans_responses_cablam1x,
                                specific_stim_ids=[12, 24, 36, 60, 120, 480],
                                base_color='blue',
                                subtitle='CaBLAM with 1x Fz Dilution',
                                plot_style='mean_sem',
                                plot_mode='subplot')



In [None]:
def plot_overlay_delta_f_over_f_with_nans(pooled_responses_1, pooled_responses_2, specific_stim_ids=None,
                                          color_1='green', color_2='red', subtitle_1='', subtitle_2='',
                                          plot_style='mean_sem', frame_interval_ms=10):
    """
    Overlay the ΔF/F responses of two pooled datasets, one panel per stim ID.

    For each dataset and stim ID the baseline is the NaN-ignoring median of the first
    10 samples of 'mean_response'; ΔF/F = (response - baseline) / baseline and the
    SEM is divided by the same baseline (both become zeros when the baseline is 0).

    Parameters
    ----------
    pooled_responses_1, pooled_responses_2 : dict
        Stim IDs as keys, each mapping to 'mean_response' and 'sem_response' arrays.
    specific_stim_ids : list of int
        Stim IDs to plot (required, non-empty; each must exist in both datasets).
    color_1, color_2 : str, optional
        Line colours for the first and second dataset.
    subtitle_1, subtitle_2 : str, optional
        Legend prefixes for the first and second dataset.
    plot_style : {'mean', 'mean_sem'}, optional
        'mean_sem' (default) adds a shaded SEM band; 'mean' draws only the mean.
    frame_interval_ms : float, optional
        Duration of one sample in ms, used to convert sample offsets to the ms
        x-axis. Defaults to 10 (the previously hard-coded value) — confirm it
        matches the acquisition frame rate.

    Raises
    ------
    ValueError
        If specific_stim_ids is missing/empty, a stim ID is absent from either
        dataset, or plot_style is unknown.
    """
    # Validate before sizing the figure: len(None) used to raise a TypeError
    # before the intended ValueError could be reached.
    if specific_stim_ids is None or len(specific_stim_ids) == 0:
        raise ValueError("specific_stim_ids must be provided")
    if plot_style not in ('mean', 'mean_sem'):
        raise ValueError("plot_style must be 'mean' or 'mean_sem'")
    for dataset_name, dataset in (('pooled_responses_1', pooled_responses_1),
                                  ('pooled_responses_2', pooled_responses_2)):
        missing = set(specific_stim_ids) - set(dataset.keys())
        if missing:
            raise ValueError(f"Stim IDs missing from {dataset_name}: {missing}")

    num_panels = len(specific_stim_ids)
    plt.figure(figsize=(5 * num_panels, 5))

    for idx, stim_id in enumerate(specific_stim_ids):
        ax = plt.subplot(1, num_panels, idx + 1)

        for pooled_responses, base_color, subtitle in zip([pooled_responses_1, pooled_responses_2],
                                                          [color_1, color_2], [subtitle_1, subtitle_2]):
            response = pooled_responses[stim_id]['mean_response']
            baseline = np.nanmedian(response[:10])
            if baseline != 0:
                delta_f_over_f = (response - baseline) / baseline
                sem = pooled_responses[stim_id]['sem_response'] / baseline
            else:
                delta_f_over_f = np.zeros_like(response)
                sem = np.zeros_like(response)

            # Sample offsets relative to the stimulus, converted to milliseconds.
            time_points = np.arange(-10, len(response) - 10) * frame_interval_ms

            # Avoid a dangling ", " prefix when no subtitle is given.
            label = f'{subtitle}, Stim ID {stim_id}' if subtitle else f'Stim ID {stim_id}'
            ax.plot(time_points, delta_f_over_f, label=label, color=base_color)
            if plot_style == 'mean_sem':
                ax.fill_between(time_points, delta_f_over_f + sem, delta_f_over_f - sem, color=base_color, alpha=0.3)

        ax.set_title(f'Stim ID {stim_id}', fontsize=10)
        ax.set_xlabel('Time (relative to stimulus, ms)')
        if idx == 0:
            ax.set_ylabel(r'$\Delta F/F$')
        ax.legend()

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.tight_layout()
    plt.show()
    
# Side-by-side comparison of GCaMP8s (green) against each CaBLAM group
# (blue = 1x Fz, red = 1/2x Fz). The subtitles are empty, so the legend does
# not name the groups and colour is the only identifier.
plot_overlay_delta_f_over_f_with_nans(pooled_responses_1=pooled_nans_responses_gcamp8, pooled_responses_2=pooled_nans_responses_cablam1x, specific_stim_ids=[12, 24, 36, 60, 120, 480],
                                          color_1='green', color_2='blue', subtitle_1='', subtitle_2='',
                                          plot_style='mean_sem')

plot_overlay_delta_f_over_f_with_nans(pooled_responses_1=pooled_nans_responses_gcamp8, pooled_responses_2=pooled_nans_responses_cablam, specific_stim_ids=[12, 24, 36, 60, 120, 480],
                                          color_1='green', color_2='red', subtitle_1='', subtitle_2='',
                                          plot_style='mean_sem')

In [None]:
def plot_delta_f_over_f_with_nans(pooled_responses, specific_stim_ids=None, base_color='green', 
                                  subtitle='', plot_style='mean', normalization_method='peak'):
    """
    Plot the normalized calcium responses (ΔF/F) for all or specific stimulation IDs,
    with an optional shaded SEM band and an optional per-trace peak normalization.

    For each stim ID the baseline is the NaN-ignoring median of the first 10 samples
    of 'mean_response'; ΔF/F = (response - baseline) / baseline and the SEM is
    divided by the same baseline (both become zeros when the baseline is 0).

    Parameters
    ----------
    pooled_responses : dict
        A dictionary with stim IDs as keys and 'mean_response' and 'sem_response' as subkeys,
        potentially containing NaN values.
    specific_stim_ids : list of int, optional
        Stim IDs to plot. If None, all stim IDs are plotted in ascending order.
    base_color : str, optional
        Colour of the plot lines; opacity increases with the stim ID (a single
        stim ID is drawn fully opaque).
    subtitle : str, optional
        Custom subtitle text for the plot.
    plot_style : {'mean', 'mean_sem'}, optional
        'mean' (default) draws only the mean; 'mean_sem' adds a shaded SEM band.
    normalization_method : {'peak', 'baseline'}, optional
        'peak' (default) additionally divides each ΔF/F trace and its SEM by the
        trace's NaN-ignoring maximum (skipped when that maximum is 0 or not
        finite); 'baseline' plots plain ΔF/F.

    Raises
    ------
    ValueError
        If plot_style or normalization_method is unknown, a requested stim ID is
        missing, or there is nothing to plot.
    """
    # Validate before creating a figure so a bad call leaves no empty figure behind.
    if plot_style not in ('mean', 'mean_sem'):
        raise ValueError("plot_style must be 'mean' or 'mean_sem'")
    if normalization_method not in ('baseline', 'peak'):
        raise ValueError("normalization_method must be 'baseline' or 'peak'")
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = list(specific_stim_ids)
    else:
        stim_ids = sorted(pooled_responses.keys())
    if not stim_ids:
        raise ValueError("No stim IDs to plot")

    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id

    plt.figure(figsize=(10, 5))
    for stim_id in stim_ids:
        response = pooled_responses[stim_id]['mean_response']
        baseline = np.nanmedian(response[:10])
        if baseline != 0:
            delta_f_over_f = (response - baseline) / baseline
            sem = pooled_responses[stim_id]['sem_response'] / baseline
        else:
            delta_f_over_f = np.zeros_like(response)
            sem = np.zeros_like(response)

        if normalization_method == 'peak':
            peak_delta = np.nanmax(delta_f_over_f)
            # Skip rescaling when the peak is 0 or undefined (e.g. an all-NaN trace).
            if peak_delta != 0 and np.isfinite(peak_delta):
                delta_f_over_f = delta_f_over_f / peak_delta
                sem = sem / peak_delta

        # Guard the division: with a single stim ID the span is 0.
        alpha = (0.1 + 0.9 * (stim_id - min_stim_id) / stim_span) if stim_span else 1.0
        time_points = np.arange(-10, len(response) - 10)

        # Plot mean.
        plt.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

        # If 'mean_sem' style is chosen, also plot SEM as a shaded area.
        if plot_style == 'mean_sem':
            plt.fill_between(time_points, delta_f_over_f - sem, delta_f_over_f + sem, color=base_color, alpha=0.3 * alpha)

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.title(subtitle, fontsize=10)
    plt.xlabel('Time (relative to stimulus)')
    if normalization_method == 'peak':
        plt.ylabel(r'$\Delta F/F$ (peak-normalized)')
    else:
        plt.ylabel(r'$\Delta F/F$')
    plt.legend()
    plt.show()

# Peak-normalized ΔF/F per group: each ΔF/F trace is divided by its own maximum,
# so response shapes can be compared independently of amplitude.
plot_delta_f_over_f_with_nans(
    pooled_responses=pooled_nans_responses_gcamp8, 
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='green', 
    subtitle='GCaMP8s', 
    plot_style='mean_sem', 
    normalization_method='peak')


plot_delta_f_over_f_with_nans(
    pooled_responses=pooled_nans_responses_cablam, 
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='red', 
    subtitle='cablam', 
    plot_style='mean_sem', 
    normalization_method='peak')

plot_delta_f_over_f_with_nans(
    pooled_responses=pooled_nans_responses_cablam1x, 
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='blue', 
    subtitle='cablam1x', 
    plot_style='mean_sem', 
    normalization_method='peak')

In [None]:
def plot_delta_f_over_f_with_nans_maxnorm(pooled_responses, specific_stim_ids=None, base_color='green', 
                                  subtitle='', plot_style='mean', normalization_method='peak'):
    """
    Plot the normalized calcium responses (ΔF/F) for all or specific stimulation IDs,
    with mean ± SEM optionally drawn as dashed lines and an optional per-trace peak
    normalization.

    For each stim ID the baseline is the NaN-ignoring median of the first 10 samples
    of 'mean_response'; ΔF/F = (response - baseline) / baseline and the SEM is
    divided by the same baseline (both become zeros when the baseline is 0).

    Parameters
    ----------
    pooled_responses : dict
        A dictionary with stim IDs as keys and 'mean_response' and 'sem_response' as subkeys,
        potentially containing NaN values.
    specific_stim_ids : list of int, optional
        Stim IDs to plot. If None, all stim IDs are plotted in ascending order.
    base_color : str, optional
        Colour of the plot lines; opacity increases with the stim ID (a single
        stim ID is drawn fully opaque).
    subtitle : str, optional
        Custom subtitle text for the plot.
    plot_style : {'mean', 'mean_sem'}, optional
        'mean' (default) draws only the mean; 'mean_sem' adds mean ± SEM as dashed lines.
    normalization_method : {'peak', 'baseline'}, optional
        'peak' (default) additionally divides each ΔF/F trace and its SEM by the
        trace's NaN-ignoring maximum (skipped when that maximum is 0 or not
        finite); 'baseline' plots plain ΔF/F.

    Raises
    ------
    ValueError
        If plot_style or normalization_method is unknown, a requested stim ID is
        missing, or there is nothing to plot.
    """
    # Validate before creating a figure so a bad call leaves no empty figure behind.
    if plot_style not in ('mean', 'mean_sem'):
        raise ValueError("plot_style must be 'mean' or 'mean_sem'")
    if normalization_method not in ('baseline', 'peak'):
        raise ValueError("normalization_method must be 'baseline' or 'peak'")
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = list(specific_stim_ids)
    else:
        stim_ids = sorted(pooled_responses.keys())
    if not stim_ids:
        raise ValueError("No stim IDs to plot")

    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id

    plt.figure(figsize=(10, 5))
    for stim_id in stim_ids:
        response = pooled_responses[stim_id]['mean_response']
        baseline = np.nanmedian(response[:10])
        if baseline != 0:
            delta_f_over_f = (response - baseline) / baseline
            sem = pooled_responses[stim_id]['sem_response'] / baseline
        else:
            delta_f_over_f = np.zeros_like(response)
            sem = np.zeros_like(response)

        if normalization_method == 'peak':
            peak_delta = np.nanmax(delta_f_over_f)
            # Skip rescaling when the peak is 0 or undefined (e.g. an all-NaN trace).
            if peak_delta != 0 and np.isfinite(peak_delta):
                delta_f_over_f = delta_f_over_f / peak_delta
                sem = sem / peak_delta

        # Guard the division: with a single stim ID the span is 0.
        alpha = (0.1 + 0.9 * (stim_id - min_stim_id) / stim_span) if stim_span else 1.0
        time_points = np.arange(-10, len(response) - 10)

        # Plot mean.
        plt.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)

        # If 'mean_sem' style is chosen, also plot mean ± SEM as dashed lines.
        if plot_style == 'mean_sem':
            plt.plot(time_points, delta_f_over_f + sem, linestyle='--', color=base_color, alpha=alpha)
            plt.plot(time_points, delta_f_over_f - sem, linestyle='--', color=base_color, alpha=alpha)

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.title(subtitle, fontsize=10)
    plt.xlabel('Time (relative to stimulus)')
    if normalization_method == 'peak':
        plt.ylabel(r'$\Delta F/F$ (peak-normalized)')
    else:
        plt.ylabel(r'$\Delta F/F$')
    plt.legend()
    plt.show()
    
# Peak-normalized ΔF/F for GCaMP8s with mean ± SEM drawn as dashed lines.
plot_delta_f_over_f_with_nans_maxnorm(
    pooled_responses=pooled_nans_responses_gcamp8, 
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='green', 
    subtitle='GCaMP8s', 
    plot_style='mean_sem', 
    normalization_method='peak')




In [None]:
def plot_overlay_delta_f_over_f_with_nans(pooled_responses_1, pooled_responses_2, specific_stim_ids=None,
                                          color_1='green', color_2='red', subtitle_1='', subtitle_2='',
                                          plot_style='mean_sem'):
    """
    Overlay peak-normalized calcium responses (ΔF/F) of two datasets, one subplot per stim ID.

    For each dataset and stim ID the mean response is converted to ΔF/F using the median of its
    first 10 samples as baseline F0 (an all-zero trace is used when F0 == 0). It is then divided
    by its own maximum ΔF/F so the two datasets can be compared by shape. The SEM is scaled by
    the same factors. The x axis is the sample index relative to the stimulus, which is taken to
    be sample 10.

    Parameters
    ----------
    pooled_responses_1 : dict
        First set of pooled responses:
        ``{stim_id: {'mean_response': array, 'sem_response': array}}``.
    pooled_responses_2 : dict
        Second set of pooled responses, same layout as ``pooled_responses_1``.
    specific_stim_ids : list of int
        Stim IDs to plot, one subplot each. Required despite the ``None`` default.
    color_1, color_2 : str, optional
        Line colors for the first and second dataset.
    subtitle_1, subtitle_2 : str, optional
        Legend prefixes for the first and second dataset. Omitted from the label when empty.
    plot_style : str, optional
        ``'mean'`` for the mean trace only, ``'mean_sem'`` to add a shaded ±SEM band.

    Raises
    ------
    ValueError
        If ``specific_stim_ids`` is None or empty.
    KeyError
        If a requested stim ID is missing from either dataset.
    """
    # Validate before creating the figure: len(None) would otherwise raise an unhelpful TypeError.
    if specific_stim_ids is None:
        raise ValueError("specific_stim_ids must be provided")
    if len(specific_stim_ids) == 0:
        raise ValueError("specific_stim_ids must contain at least one stim ID")

    n_panels = len(specific_stim_ids)
    plt.figure(figsize=(5 * n_panels, 5))

    for idx, stim_id in enumerate(specific_stim_ids):
        ax = plt.subplot(1, n_panels, idx + 1)

        for pooled_responses, base_color, subtitle in zip([pooled_responses_1, pooled_responses_2],
                                                          [color_1, color_2], [subtitle_1, subtitle_2]):
            response = pooled_responses[stim_id]['mean_response']
            baseline = np.nanmedian(response[:10])
            delta_f_over_f = (response - baseline) / baseline if baseline != 0 else np.zeros_like(response)

            # Peak-normalize; a flat (all-zero) trace is left unscaled.
            peak_delta = np.nanmax(delta_f_over_f)
            scale = peak_delta if peak_delta != 0 else 1
            delta_f_over_f /= scale

            sem = pooled_responses[stim_id]['sem_response'] / baseline if baseline != 0 else np.zeros_like(response)
            sem /= scale

            time_points = np.arange(-10, len(response) - 10)

            # Avoid a dangling ", " prefix when no subtitle is given.
            label = f'{subtitle}, Stim ID {stim_id}' if subtitle else f'Stim ID {stim_id}'
            ax.plot(time_points, delta_f_over_f, label=label, color=base_color)
            if plot_style == 'mean_sem':
                # Shaded ±SEM band around the mean trace.
                ax.fill_between(time_points, delta_f_over_f + sem, delta_f_over_f - sem,
                                color=base_color, alpha=0.3)

        ax.set_title(f'Stim ID {stim_id}', fontsize=10)
        ax.set_xlabel('Time (relative to stimulus)')
        if idx == 0:
            ax.set_ylabel(r'$\Delta F/F$')
        ax.legend()

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.tight_layout()
    plt.show()

# gcamp8 vs. cablam1x pooled responses, mean with shaded SEM
plot_overlay_delta_f_over_f_with_nans(
    pooled_responses_1=pooled_nans_responses_gcamp8,
    pooled_responses_2=pooled_nans_responses_cablam1x,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    color_1='green',
    color_2='blue',
    subtitle_1='',
    subtitle_2='',
    plot_style='mean_sem',
)

# cablam1x vs. cablam pooled responses, mean with shaded SEM
plot_overlay_delta_f_over_f_with_nans(
    pooled_responses_1=pooled_nans_responses_cablam1x,
    pooled_responses_2=pooled_nans_responses_cablam,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    color_1='red',
    color_2='blue',
    subtitle_1='',
    subtitle_2='',
    plot_style='mean_sem',
)

# cablam1x vs. cablam over the extended stim range, mean only
plot_overlay_delta_f_over_f_with_nans(
    pooled_responses_1=pooled_nans_responses_cablam1x,
    pooled_responses_2=pooled_nans_responses_cablam,
    specific_stim_ids=[12, 24, 36, 60, 120, 480, 960, 1920],
    color_1='red',
    color_2='blue',
    subtitle_1='',
    subtitle_2='',
    plot_style='mean',
)

In [None]:
def plot_overlay_delta_f_over_f_with_nans(pooled_responses_1, pooled_responses_2, specific_stim_ids=None,
                                          color_1='green', color_2='red', subtitle_1='', subtitle_2='',
                                          plot_style='mean_sem', bin_size_ms=100, pre_stimulus_bins=10):
    """
    Overlay peak-normalized ΔF/F responses of two datasets and report their kinetics.

    One subplot is drawn per stim ID. For each dataset the mean response is converted to ΔF/F
    using the median of the first ``pre_stimulus_bins`` samples as baseline F0 (an all-zero
    trace is used when F0 == 0). It is then divided by its own maximum ΔF/F, and the SEM is
    scaled by the same factors. For each trace the time to peak and the half-decay time are
    printed, and a dashed vertical line marks the time to peak.

    Both times are on the plotted x axis, i.e. in milliseconds relative to the stimulus, which
    is assumed to fall at sample index ``pre_stimulus_bins``:

    - time to peak: time of the maximum of the normalized trace;
    - half-decay time: first time at or after the peak where the normalized trace drops below
      half of its peak value (NaN if it never does). Subtract the time to peak to get the
      peak-to-half-decay interval.

    Parameters
    ----------
    pooled_responses_1 : dict
        First set of pooled responses:
        ``{stim_id: {'mean_response': array, 'sem_response': array}}``.
    pooled_responses_2 : dict
        Second set of pooled responses, same layout as ``pooled_responses_1``.
    specific_stim_ids : list of int
        Stim IDs to plot, one subplot each. Required despite the ``None`` default.
    color_1, color_2 : str, optional
        Line colors for the first and second dataset.
    subtitle_1, subtitle_2 : str, optional
        Legend and printout prefixes for the two datasets. Omitted when empty.
    plot_style : str, optional
        ``'mean'`` for the mean trace only, ``'mean_sem'`` to add dashed ±SEM lines.
    bin_size_ms : float, optional
        Duration of one sample in milliseconds. Default 100.
    pre_stimulus_bins : int, optional
        Number of samples before the stimulus. This sets the baseline window and the x-axis
        offset. Default 10.

    Raises
    ------
    ValueError
        If ``specific_stim_ids`` is None or empty.
    KeyError
        If a requested stim ID is missing from either dataset.
    """
    # Validate before creating the figure: len(None) would otherwise raise an unhelpful TypeError.
    if specific_stim_ids is None:
        raise ValueError("specific_stim_ids must be provided")
    if len(specific_stim_ids) == 0:
        raise ValueError("specific_stim_ids must contain at least one stim ID")

    n_panels = len(specific_stim_ids)
    plt.figure(figsize=(5 * n_panels, 5))

    for idx, stim_id in enumerate(specific_stim_ids):
        ax = plt.subplot(1, n_panels, idx + 1)

        for pooled_responses, base_color, subtitle in zip([pooled_responses_1, pooled_responses_2],
                                                          [color_1, color_2], [subtitle_1, subtitle_2]):
            response = pooled_responses[stim_id]['mean_response']
            baseline = np.nanmedian(response[:pre_stimulus_bins])
            delta_f_over_f = (response - baseline) / baseline if baseline != 0 else np.zeros_like(response)

            # Peak-normalize; a flat (all-zero) trace is left unscaled.
            peak_delta = np.nanmax(delta_f_over_f)
            scale = peak_delta if peak_delta != 0 else 1
            delta_f_over_f /= scale

            sem = pooled_responses[stim_id]['sem_response'] / baseline if baseline != 0 else np.zeros_like(response)
            sem /= scale

            # x axis in ms, with the stimulus sample at 0.
            time_points = np.arange(-pre_stimulus_bins, len(response) - pre_stimulus_bins) * bin_size_ms

            if np.all(np.isnan(delta_f_over_f)):
                # nanargmax raises on an all-NaN trace; report the kinetics as undefined.
                time_to_peak = half_decay_time = np.nan
            else:
                peak_index = int(np.nanargmax(delta_f_over_f))
                time_to_peak = (peak_index - pre_stimulus_bins) * bin_size_ms
                # The trace is already divided by peak_delta, so the half level must come from the
                # normalized trace itself (peak_delta / 2 would be in un-normalized units).
                half_decay_value = delta_f_over_f[peak_index] / 2
                below_half = np.flatnonzero(delta_f_over_f[peak_index:] < half_decay_value)
                # Express on the same stimulus-relative axis as time_to_peak.
                half_decay_time = ((peak_index + below_half[0] - pre_stimulus_bins) * bin_size_ms
                                   if below_half.size else np.nan)

            # Avoid a dangling ", " prefix when no subtitle is given.
            label = f'{subtitle}, Stim ID {stim_id}' if subtitle else f'Stim ID {stim_id}'
            ax.plot(time_points, delta_f_over_f, label=label, color=base_color)
            if plot_style == 'mean_sem':
                ax.plot(time_points, delta_f_over_f + sem, linestyle='--', color=base_color)
                ax.plot(time_points, delta_f_over_f - sem, linestyle='--', color=base_color)

            print(f'{label}: Time to peak = {time_to_peak} ms; Half-decay time = {half_decay_time} ms')

            # Mark the time to peak.
            if not np.isnan(time_to_peak):
                ax.axvline(time_to_peak, color=base_color, linestyle='--', alpha=0.5)

        ax.set_title(f'Stim ID {stim_id}', fontsize=10)
        ax.set_xlabel('Time (relative to stimulus) [ms]')
        if idx == 0:
            ax.set_ylabel(r'$\Delta F/F$')
        ax.legend()

    plt.suptitle('Normalized Calcium Responses (ΔF/F)', fontsize=14)
    plt.tight_layout()
    plt.show()

# gcamp8 vs. cablam pooled responses, mean only, with printed kinetics
plot_overlay_delta_f_over_f_with_nans(
    pooled_responses_1=pooled_nans_responses_gcamp8,
    pooled_responses_2=pooled_nans_responses_cablam,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    color_1='green',
    color_2='blue',
    subtitle_1='',
    subtitle_2='',
    plot_style='mean',
)

# cablam vs. cablam1x over the extended stim range, mean only, with printed kinetics
plot_overlay_delta_f_over_f_with_nans(
    pooled_responses_1=pooled_nans_responses_cablam,
    pooled_responses_2=pooled_nans_responses_cablam1x,
    specific_stim_ids=[12, 24, 36, 60, 120, 480, 960, 1920],
    color_1='red',
    color_2='blue',
    subtitle_1='',
    subtitle_2='',
    plot_style='mean',
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_delta_f_over_f_subplots(pooled_responses, specific_stim_ids=None, base_color='green', pre_stim_frames=10):
    """
    Plot ΔF/F calcium responses for all or selected stimulation IDs, one stacked subplot each.

    Each mean response is converted to ΔF/F using the mean of its first ``pre_stim_frames``
    samples as baseline F0 (an all-zero trace is drawn when F0 == 0). Line opacity scales
    linearly from 0.1 (smallest plotted stim ID) to 1.0 (largest). When all plotted stim IDs
    are equal, for example a single one, opacity is 1.0.

    Parameters
    ----------
    pooled_responses : dict
        ``{stim_id: {'mean_response': array, ...}}``. Only ``'mean_response'`` is used here.
    specific_stim_ids : list of int, optional
        Stim IDs to plot, in the given order. If None (default), all stim IDs are plotted in
        sorted order.
    base_color : str, optional
        Matplotlib color for every line. Default ``'green'``.
    pre_stim_frames : int, optional
        Number of samples before the stimulus. This sets the baseline window and the x-axis
        offset (the stimulus sample is plotted at x = 0). Default 10.

    Raises
    ------
    ValueError
        If a requested stim ID is not present in ``pooled_responses``, or there is nothing to plot.

    Examples
    --------
    >>> plot_delta_f_over_f_subplots(pooled_responses)  # doctest: +SKIP
    >>> plot_delta_f_over_f_subplots(pooled_responses, specific_stim_ids=[12, 24, 36], base_color='blue')  # doctest: +SKIP
    """
    if specific_stim_ids is not None:
        invalid_stim_ids = set(specific_stim_ids) - set(pooled_responses.keys())
        if invalid_stim_ids:
            raise ValueError(f"Invalid stim IDs provided: {invalid_stim_ids}")
        stim_ids = specific_stim_ids
    else:
        stim_ids = sorted(pooled_responses.keys())

    # plt.subplots(0, ...) and min([]) would fail obscurely; report it clearly instead.
    if len(stim_ids) == 0:
        raise ValueError("No stim IDs to plot.")

    num_subplots = len(stim_ids)
    _, axes = plt.subplots(num_subplots, 1, figsize=(10, 5 * num_subplots), sharex=True)
    if num_subplots == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    min_stim_id, max_stim_id = min(stim_ids), max(stim_ids)
    stim_span = max_stim_id - min_stim_id

    for ax, stim_id in zip(axes, stim_ids):
        response = pooled_responses[stim_id]['mean_response']
        baseline = np.nanmean(response[:pre_stim_frames])
        # Same zero-baseline guard as the other ΔF/F plotting helpers.
        delta_f_over_f = (response - baseline) / baseline if baseline != 0 else np.zeros_like(response)

        # Opacity in [0.1, 1.0]; guard the 0/0 that a single (or repeated) stim ID would cause.
        alpha = 0.1 + 0.9 * (stim_id - min_stim_id) / stim_span if stim_span != 0 else 1.0

        time_points = np.arange(-pre_stim_frames, len(response) - pre_stim_frames)

        ax.plot(time_points, delta_f_over_f, label=f'Stim ID {stim_id}', color=base_color, alpha=alpha)
        ax.set_ylabel(r'$\Delta F/F$')
        ax.legend()

    # Shared x axis: label only the bottom subplot.
    axes[-1].set_xlabel('Time (relative to stimulus)')
    plt.suptitle('Normalized Calcium Responses (ΔF/F) for Selected Stimulation IDs')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()

# Example usage: stacked ΔF/F subplots for each sensor's pooled responses.
plot_delta_f_over_f_subplots(
    pooled_nans_responses_gcamp8,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='green',
)
plot_delta_f_over_f_subplots(
    pooled_nans_responses_cablam,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='red',
)
plot_delta_f_over_f_subplots(
    pooled_nans_responses_cablam1x,
    specific_stim_ids=[12, 24, 36, 60, 120, 480],
    base_color='blue',
)


In [None]:

import numpy as np
import pandas as pd
from scipy.stats import ttest_ind

def _event_response_metrics(signal_data, pre_stim_frames, post_stim_frames, alpha, frame_duration_ms):
    """
    Compute responsiveness and kinetics metrics for one stimulus-aligned trace.

    Returns None when ``signal_data`` has fewer than ``pre_stim_frames + post_stim_frames + 1``
    samples. Otherwise returns a dict of raw (not percent-scaled) metrics keyed by the names
    used in ``calculate_responsiveness``.
    """
    if signal_data.size < pre_stim_frames + post_stim_frames + 1:
        return None

    pre_stim_signal = signal_data[:pre_stim_frames]
    # Index ``pre_stim_frames`` is the stimulus frame; the post window starts one frame later.
    post_start = pre_stim_frames + 1
    post_stim_signal = signal_data[post_start:post_start + post_stim_frames]

    pre_stim_mean = np.mean(pre_stim_signal)
    # NaN (with a RuntimeWarning) when the post window is entirely NaN.
    post_stim_peak = np.nanmax(post_stim_signal)

    # Welch's t-test (two-sided by default) between the raw pre- and post-stimulus windows.
    _, p_value = ttest_ind(pre_stim_signal, post_stim_signal, equal_var=False)
    is_responsive = p_value < alpha if not np.isnan(p_value) else False

    if np.isnan(post_stim_peak):
        time_to_peak = half_rise_time = half_decay_time = np.nan
    else:
        peak_index = int(np.nanargmax(post_stim_signal))
        # Half-amplitude level above the pre-stimulus baseline (not half of the raw F value).
        half_level = pre_stim_mean + (post_stim_peak - pre_stim_mean) / 2
        rise_hits = np.flatnonzero(post_stim_signal >= half_level)
        decay_hits = np.flatnonzero(post_stim_signal[peak_index:] <= half_level)
        # Post-window index k lies k + 1 frames after the stimulus frame.
        time_to_peak = (peak_index + 1) * frame_duration_ms
        half_rise_time = (rise_hits[0] + 1) * frame_duration_ms if rise_hits.size else np.nan
        half_decay_time = ((peak_index + decay_hits[0] + 1) * frame_duration_ms
                           if decay_hits.size else np.nan)

    return {
        'pre_stim_mean': pre_stim_mean,
        'pre_stim_sd': np.std(pre_stim_signal),
        'post_stim_peak': post_stim_peak,
        # Drops the first post-window frame (the stimulus frame itself is already excluded).
        'post_stim_sd': np.std(post_stim_signal[1:]),
        'p_value': p_value,
        'post_stim_mean': np.mean(post_stim_signal),
        'delta_f_f_post_stim': (post_stim_signal - pre_stim_mean) / pre_stim_mean,
        'pre_stim_median': np.median(pre_stim_signal),
        'post_stim_median': np.median(post_stim_signal),
        'peak_delta_f_f_post_stim': (post_stim_peak - pre_stim_mean) / pre_stim_mean,
        'delta_f_f_full_array': (signal_data - pre_stim_mean) / pre_stim_mean,
        'time_to_peak': time_to_peak,
        'half_rise_time': half_rise_time,
        'half_decay_time': half_decay_time,
        'is_responsive': is_responsive,
    }


def calculate_responsiveness(all_data, pre_stim_frames=10, post_stim_frames=10, alpha=0.01, return_dataframe=False,
                             frame_duration_ms=100):
    """
    Identify stimulus-responsive ROIs and compute per-event response metrics.

    Each stimulus-aligned trace is split into a pre-stimulus window
    ``signal[:pre_stim_frames]`` and a post-stimulus window
    ``signal[pre_stim_frames + 1 : pre_stim_frames + 1 + post_stim_frames]``. The sample at
    index ``pre_stim_frames`` is treated as the stimulus frame and excluded from both windows.
    A Welch t-test (``scipy.stats.ttest_ind(..., equal_var=False)``, two-sided) compares the two
    windows, and an event is responsive when the p-value is below ``alpha``. Significant
    decreases therefore also count as responsive.

    Parameters
    ----------
    all_data : dict
        ``{session_id: {'roi_data': {roi: {(stim_id, stim_frame): signal}}}}``. Each ``signal``
        must be a 1-D NumPy array aligned to the stimulus.
    pre_stim_frames : int
        Number of frames before the stimulus used as the baseline window.
    post_stim_frames : int
        Number of frames after the stimulus frame used as the response window.
    alpha : float
        Significance level for ``is_responsive``.
    return_dataframe : bool
        If True, also return a pandas DataFrame with one row per analysed event.
    frame_duration_ms : float
        Duration of one frame in ms, used for the kinetics times. Default 100.

    Returns
    -------
    dict or (dict, pandas.DataFrame)
        Nested dict ``{session_id: {roi: {(stim_id, stim_frame): metrics}}}`` where ``metrics``
        holds:

        - ``pre_stim_mean``, ``pre_stim_sd``, ``pre_stim_median``: statistics of the pre window.
        - ``post_stim_peak``, ``post_stim_mean``, ``post_stim_median``: statistics of the post window.
        - ``post_stim_sd``: SD of the post window without its first frame.
        - ``delta_f_f_post_stim``: post window as (F - F0) / F0, with F0 = ``pre_stim_mean``.
        - ``peak_delta_f_f_post_stim``: (peak - F0) / F0.
        - ``p_value``, ``is_responsive``: t-test result.

        The DataFrame has the same values, with the ΔF/F columns multiplied by 100 (percent).
        It additionally holds:

        - ``delta_f_f_full_array``: the whole trace as ΔF/F, in percent.
        - ``raw_signal``: the unmodified trace.
        - ``time_to_peak``, ``half_rise_time``, ``half_decay_time``: in ms after the stimulus
          frame.

        The half-times use the level F0 + (peak - F0) / 2. ``half_rise_time`` is when the post
        window first reaches that level. ``half_decay_time`` is when the post window, at or
        after the peak, first falls to it. Times are NaN when undefined.

    Notes
    -----
    Events whose trace is shorter than ``pre_stim_frames + post_stim_frames + 1`` are skipped
    and appear neither in the dict nor in the DataFrame.
    """
    responsiveness_data = {}
    dataframe_rows = []

    for session_id, session_data in all_data.items():
        session_responsiveness = {}
        for roi, roi_data in session_data['roi_data'].items():
            roi_responsiveness = {}
            for (stim_id, stim_frame), signal_data in roi_data.items():
                m = _event_response_metrics(signal_data, pre_stim_frames, post_stim_frames, alpha,
                                            frame_duration_ms)
                if m is None:
                    # Too short to analyse: skip it instead of reusing the previous event's metrics.
                    continue

                roi_responsiveness[(stim_id, stim_frame)] = {
                    'pre_stim_mean': m['pre_stim_mean'],
                    'pre_stim_sd': m['pre_stim_sd'],
                    'post_stim_peak': m['post_stim_peak'],
                    'post_stim_sd': m['post_stim_sd'],
                    'p_value': m['p_value'],
                    'post_stim_mean': m['post_stim_mean'],
                    'delta_f_f_post_stim': m['delta_f_f_post_stim'],
                    'pre_stim_median': m['pre_stim_median'],
                    'post_stim_median': m['post_stim_median'],
                    'peak_delta_f_f_post_stim': m['peak_delta_f_f_post_stim'],
                    'is_responsive': m['is_responsive'],
                }

                dataframe_rows.append({
                    'session_id': session_id,
                    'roi': roi,
                    'stimulation_id': stim_id,
                    'stim_frame_number': stim_frame,
                    'pre_stim_mean': m['pre_stim_mean'],
                    'pre_stim_sd': m['pre_stim_sd'],
                    'post_stim_peak': m['post_stim_peak'],
                    'post_stim_sd': m['post_stim_sd'],
                    'post_stim_mean': m['post_stim_mean'],
                    'delta_f_f_post_stim': m['delta_f_f_post_stim'] * 100,
                    'pre_stim_median': m['pre_stim_median'],
                    'post_stim_median': m['post_stim_median'],
                    'peak_delta_f_f_post_stim': m['peak_delta_f_f_post_stim'] * 100,
                    'delta_f_f_full_array': m['delta_f_f_full_array'] * 100,
                    'raw_signal': signal_data,
                    'p_value': m['p_value'],
                    'time_to_peak': m['time_to_peak'],
                    'half_rise_time': m['half_rise_time'],
                    'half_decay_time': m['half_decay_time'],
                    'is_responsive': m['is_responsive'],
                })

            session_responsiveness[roi] = roi_responsiveness
        responsiveness_data[session_id] = session_responsiveness

    if return_dataframe:
        return responsiveness_data, pd.DataFrame(dataframe_rows)
    return responsiveness_data
    
    

def filter_responsive_rois(all_data, responsiveness_data):
    """
    Return a copy of ``all_data``'s structure that keeps only responsive stimulus events.

    Per-session ``'stim_frame_numbers'`` and ``'stimulation_ids'`` are carried over unchanged.
    Inside ``'roi_data'``, a ``(stim_id, stim_frame)`` trace is kept only when
    ``responsiveness_data[session_id][roi][stim_key]['is_responsive']`` is truthy. A missing
    event entry counts as not responsive. ROIs left with no responsive events are omitted.

    Parameters:
    all_data (dict): Original dataset,
        ``{session_id: {'stim_frame_numbers', 'roi_data', 'stimulation_ids'}}``.
    responsiveness_data (dict): Output of ``calculate_responsiveness``,
        ``{session_id: {roi: {stim_key: metrics}}}``.

    Returns:
    dict: Filtered dataset with the same layout as ``all_data``.
    """
    filtered_data = {}

    for session_id, session_content in all_data.items():
        session_out = {
            'stim_frame_numbers': session_content['stim_frame_numbers'],
            'roi_data': {},
            'stimulation_ids': session_content['stimulation_ids'],
        }
        filtered_data[session_id] = session_out
        kept_rois = session_out['roi_data']

        for roi, stim_traces in session_content['roi_data'].items():
            responsive_traces = {
                stim_key: trace
                for stim_key, trace in stim_traces.items()
                if responsiveness_data[session_id][roi].get(stim_key, {}).get('is_responsive', False)
            }
            # Only ROIs with at least one responsive event are kept.
            if responsive_traces:
                kept_rois[roi] = responsive_traces

    return filtered_data


# Per-event responsiveness metrics (dict + DataFrame) for each dataset.
responsiveness_data, responsiveness_df = calculate_responsiveness(
    all_data, return_dataframe=True)
responsiveness_data_gcamp8, responsiveness_df_gcamp8 = calculate_responsiveness(
    all_data_gcamp8, return_dataframe=True)
responsiveness_data_cablam, responsiveness_df_cablam = calculate_responsiveness(
    all_data_cablam, return_dataframe=True)
responsiveness_data_cablam1x, responsiveness_df_cablam1x = calculate_responsiveness(
    all_data_cablam1x, return_dataframe=True)

# gcamp8: keep responsive events only, pool, and plot.
filtered_data_gcamp8 = filter_responsive_rois(
    all_data_gcamp8, responsiveness_data_gcamp8)
pooled_responses_filtered_gcamp8 = compile_pooled_responses(filtered_data_gcamp8)
plot_delta_f_over_f(
    pooled_responses_filtered_gcamp8,
    specific_stim_ids=[12, 36, 60, 120, 480],
    base_color='green',
)

# cablam: keep responsive events only, pool, and plot.
filtered_data_cablam = filter_responsive_rois(
    all_data_cablam, responsiveness_data_cablam)
pooled_responses_filtered_cablam = compile_pooled_responses(filtered_data_cablam)
plot_delta_f_over_f(
    pooled_responses_filtered_cablam,
    specific_stim_ids=[12, 36, 60, 120, 480],
    base_color='blue',
)

# cablam1x: keep responsive events only, pool, and plot.
filtered_data_cablam1x = filter_responsive_rois(
    all_data_cablam1x, responsiveness_data_cablam1x)
pooled_responses_filtered_cablam1x = compile_pooled_responses(filtered_data_cablam1x)
plot_delta_f_over_f(
    pooled_responses_filtered_cablam1x,
    specific_stim_ids=[12, 36, 60, 120, 480],
    base_color='red',
)






In [None]:
# Display the GCaMP8 per-event responsiveness DataFrame as this cell's output.
responsiveness_df_gcamp8

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


class SensorDataPlotter:
    def __init__(self, data_frames, sensor_names, sensor_box_colors, sensor_strip_colors):
        """
        Initialize the object with a list of data frames, corresponding sensor names, and specific colors for each sensor.
        :param data_frames: List of pandas DataFrames containing the sensor data.
        :param sensor_names: List of strings representing the names of the sensors.
        :param sensor_box_colors: Dictionary mapping sensor names to boxplot colors.
        :param sensor_strip_colors: Dictionary mapping sensor names to stripplot colors.
        """
        self.data_frames = data_frames
        self.sensor_names = sensor_names
        self.sensor_box_colors = sensor_box_colors
        self.sensor_strip_colors = sensor_strip_colors
        self.combined_df = None
        
    def prepare_for_plotting(self, df_column_name):
        """
        Prepares a single dataframe suitable for plotting from multiple sensor dataframes.
        :param df_column_name: The name of the column to use for the value in the plot.
        """
        # Add a 'sensor_name' column to each DataFrame and concatenate them into a single DataFrame
        frames = []
        for df, name in zip(self.data_frames, self.sensor_names):
            df = df.copy()  # Make a copy to avoid modifying the original DataFrame
            df['sensor_name'] = name
            df['value'] = df[df_column_name]
            frames.append(df)
        # Concatenate all Dataframes into a single DataFrame
        self.combined_df = pd.concat(frames, ignore_index=True)
        
        #filter the self.combined_df to only include responsive ROIs when is_responsive is True based on if the entry is True or False
        self.combined_df = self.combined_df[self.combined_df['is_responsive'] == True]
        
        # Ensure the value column is present and has the correct data type for plotting
        if df_column_name not in self.combined_df.columns:
            raise ValueError(f"The column '{df_column_name}' does not exist in the DataFrame.") # Raise an error if the column does not exist
        self.combined_df[df_column_name] = pd.to_numeric(self.combined_df[df_column_name], errors='coerce') # Convert to numeric where eoors are coerced which means invalid parsing will be set as NaN

    
    def plot_data(self, df_column_name, selected_stim_ids, box_width=.8, strip_size=3, fig_size=(12, 8), dpi=300):
        """
        Plots the data using boxplot and stripplot for selected stimulation IDs.
        :param df_column_name: The name of the column to use for the value in the plot.
        :param selected_stim_ids: List of stimulation IDs to plot. If None, plot all.
        :param box_width: The width of the boxplots.
        :param strip_size: The size of the points in the stripplots.
        :param fig_size: Tuple representing the figure size (width, height) in inches.
        :param dpi: The resolution in dots per inch.
        """
        
        ### check varuable types and raise errors if necessary ###
        ##########################################################
       
        df_column_name = str(df_column_name)
        
        if selected_stim_ids is not None:
            if not isinstance(selected_stim_ids, list):
                raise ValueError("The selected_stim_ids parameter must be a list of stimulation IDs.")
            if not all(isinstance(stim_id, int) for stim_id in selected_stim_ids):
                raise ValueError("All elements in the selected_stim_ids list must be integers.")
            
        if not isinstance(box_width, (int, float)):
            raise ValueError("The box_width parameter must be an integer or float.")
        if not isinstance(strip_size, (int, float)):
            raise ValueError("The strip_size parameter must be an integer or float.")
        if not isinstance(fig_size, tuple) or len(fig_size) != 2:
            raise ValueError("The fig_size parameter must be a tuple of two integers.")
        if not all(isinstance(val, (int, float)) for val in fig_size):
            raise ValueError("The fig_size parameter must contain only integers or floats.")
        if not isinstance(dpi, int):
            raise ValueError("The dpi parameter must be an integer.")
        
    
        
        ### import, and filter the combined_df for the selected stimulation IDs if provided###
        #####################################################################################
       
        #check if the combined_df is None and if so, call the prepare_for_plotting method 
        if self.combined_df is None:
            self.prepare_for_plotting(df_column_name)
       
            
        # Filter the combined DataFrame for the selected stimulation IDs if provided
        if selected_stim_ids is not None:
            self.combined_df = self.combined_df[self.combined_df['stimulation_id'].isin(selected_stim_ids)]
            
        # Raise an error if no data is available after filtering
        if self.combined_df.empty:
            raise ValueError("No data available for the selected stimulation IDs.")
        

        
        ### set up the boxplot and stripplot properties ###
        ####################################################
        
        # Boxplot properties will have black edges and lines
        boxprops = {'edgecolor': 'k', 'linewidth': 1.5}
        lineprops = {'color': 'k', 'linewidth': 1.5}
        
        boxplot_kwargs = {
            'boxprops': boxprops, 'medianprops': lineprops,
            'whiskerprops': lineprops, 'capprops': lineprops,
            'width': box_width, 'palette': self.sensor_box_colors,
            'hue_order': self.sensor_names
        }

        # Stripplot properties
        stripplot_kwargs = {
            'linewidth': 0.1, 'size': strip_size, 'alpha': 0.3,
            'palette': self.sensor_strip_colors, 'hue_order': self.sensor_names
        }
        
        # Plotting with specified figure size and resolution
        plt.figure(figsize=fig_size, dpi=dpi)
        ax = plt.subplot()

        sns.stripplot(
            x='stimulation_id', y=df_column_name, hue='sensor_name',
            data=self.combined_df, ax=ax, jitter=0.3, dodge=True,
            **stripplot_kwargs
        )
        
        ## error bars on the boxplot are the 95% confidence interval
        sns.boxplot(
            x='stimulation_id', y=df_column_name, hue='sensor_name',
            data=self.combined_df, ax=ax, fliersize=0,
            **boxplot_kwargs
        )
        
        ### set the axis labels and ticks properties ### 
        ################################################
            
        # Set the font size of the x-axis and y-axis labels 
        ax.set_xlabel('Stimulation ID', fontsize=18)
        ax.set_ylabel(df_column_name, fontsize=18)
        
        # Set the font size of the x-axis and y-axis ticks
        ax.tick_params(axis='x', labelsize=24)
        ax.tick_params(axis='y', labelsize=24)
        ax.legend_.remove()
        
        #set the font to arial and the font size to 24
        plt.rcParams['font.sans-serif'] = 'Arial'
        plt.tight_layout()  # Adjust layout to fit legend#
        plt.show()

    def plot_mean_with_error(self, df_column_name, error_type='SEM', selected_stim_ids=None, xlim=None, ylim=None, fig_size=(8, 6), dpi=300):
        """
        Plots the mean values with error bars for selected stimulation IDs across sensors, using consistent colors.

        One error-bar line is drawn per sensor in ``self.sensor_names``, coloured with
        ``self.sensor_box_colors[sensor_name]``. If ``self.combined_df`` is None it is built
        first via ``self.prepare_for_plotting(df_column_name)``.

        :param df_column_name: The name of the column to use for the value in the plot.
        :param error_type: 'SD' (standard deviation) or 'SEM' (standard error of the mean),
            matched case-insensitively.
        :param selected_stim_ids: List of stimulation IDs to plot. If None, plot all.
        :param xlim: Optional (min, max) x-axis limits; ignored when falsy.
        :param ylim: Optional (min, max) y-axis limits; ignored when falsy.
        :param fig_size: Tuple representing the figure size (width, height) in inches.
        :param dpi: The resolution in dots per inch.
        :raises ValueError: If ``error_type`` is neither 'SD' nor 'SEM'.
        """
        # Validate before doing any work: previously any value other than exactly 'SD'
        # (e.g. 'sd' or 'std') silently fell through to SEM error bars.
        error_key = str(error_type).upper()
        if error_key not in ('SD', 'SEM'):
            raise ValueError(f"error_type must be 'SD' or 'SEM', got {error_type!r}")

        if self.combined_df is None:
            self.prepare_for_plotting(df_column_name)

        # Restrict to the requested stimulation IDs, if any were given
        if selected_stim_ids is not None:
            plot_df = self.combined_df[self.combined_df['stimulation_id'].isin(selected_stim_ids)]
        else:
            plot_df = self.combined_df

        plt.figure(figsize=fig_size, dpi=dpi)
        ax = plt.subplot()

        # One mean +/- error line per sensor, coloured consistently with the box plots
        for sensor_name in self.sensor_names:
            sensor_data = plot_df[plot_df['sensor_name'] == sensor_name]
            grouped = sensor_data.groupby('stimulation_id')[df_column_name]
            means = grouped.mean()
            errors = grouped.std() if error_key == 'SD' else grouped.sem()

            ax.errorbar(means.index, means, yerr=errors, label=sensor_name,
                        fmt='-o', capsize=5, color=self.sensor_box_colors[sensor_name])

        # Axis labels and tick label sizes
        ax.set_xlabel('Stimulation ID', fontsize=18)
        ax.set_ylabel(df_column_name, fontsize=18)
        ax.tick_params(axis='x', labelsize=24)
        ax.tick_params(axis='y', labelsize=24)

        ax.set_title('Mean ' + df_column_name + ' by Stimulation ID across Sensors', fontsize=14)
        ax.legend(title='Sensor', loc='upper left')

        # Custom axis limits, if provided
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)

        plt.tight_layout()
        plt.show()
                
    def plot_time_series(self, full_array_column, selected_stim_ids=None, fig_size=(10, 8), dpi=300,
                         stim_index=10, frame_interval_ms=100):
        """
        Plots time series data for selected stimulation IDs for each sensor separately.

        The grid has one row per sensor (``self.sensor_names``) and one column per stimulation
        ID. Each panel shows every trace in light grey and the NaN-ignoring median trace in
        ``self.sensor_box_colors[sensor_name]``. The x-axis is
        ``(frame index - stim_index) * frame_interval_ms`` and is labelled in ms.

        :param full_array_column: The name of the column with time series data (one trace per
            row; traces within one panel must share the same length).
        :param selected_stim_ids: List of stimulation IDs to plot. If None, plot all available.
        :param fig_size: Tuple representing the figure size of each subplot.
        :param dpi: The resolution in dots per inch.
        :param stim_index: Frame index treated as time zero. Defaults to 10.
        :param frame_interval_ms: Milliseconds per frame. Defaults to 100.
        :raises ValueError: If data has not been prepared, or there are no sensors or
            stimulation IDs to plot.
        """
        if self.combined_df is None:
            raise ValueError("Data has not been prepared for plotting. Call prepare_for_plotting first.")

        stim_ids = list(selected_stim_ids) if selected_stim_ids is not None else list(self.combined_df['stimulation_id'].unique())
        num_stim_ids = len(stim_ids)
        num_sensors = len(self.sensor_names)
        if num_stim_ids == 0 or num_sensors == 0:
            raise ValueError("Nothing to plot: no stimulation IDs or no sensors selected.")

        # squeeze=False always returns a 2-D (sensors x stim IDs) array of Axes,
        # so no special-casing is needed for a single row or column.
        fig, axes = plt.subplots(num_sensors, num_stim_ids,
                                 figsize=(fig_size[0] * num_stim_ids, fig_size[1] * num_sensors),
                                 dpi=dpi, sharey=True, squeeze=False)

        for row_idx, sensor_name in enumerate(self.sensor_names):
            for col_idx, stim_id in enumerate(stim_ids):
                ax = axes[row_idx, col_idx]
                sensor_stim_data = self.combined_df[(self.combined_df['sensor_name'] == sensor_name) & (self.combined_df['stimulation_id'] == stim_id)]

                if not sensor_stim_data.empty:
                    # One row per trace; built once and reused for plotting and the median
                    traces = np.vstack([np.asarray(trace, dtype=float) for trace in sensor_stim_data[full_array_column]])
                    time_vector = (np.arange(traces.shape[1]) - stim_index) * frame_interval_ms

                    # All individual responses in light grey
                    for trace in traces:
                        ax.plot(time_vector, trace, color='gainsboro', alpha=0.3)

                    # Median response in the sensor's colour
                    median_response = np.nanmedian(traces, axis=0)
                    ax.plot(time_vector, median_response, color=self.sensor_box_colors[sensor_name])

                ax.set_title(f'Stim ID {stim_id} - {sensor_name}', fontsize=14)
                ax.set_xlabel('Time (ms)', fontsize=18)
                if col_idx == 0:
                    ax.set_ylabel('ΔF/F', fontsize=18)
                ax.tick_params(axis='x', labelsize=24)
                ax.tick_params(axis='y', labelsize=24)

        plt.tight_layout()
        plt.rcParams['font.sans-serif'] = 'Arial'
        plt.show()
        
    def exp_decay(self, t, A, tau, C):
        """
        Single-exponential decay model ``A * exp(-t / tau) + C``.

        Used as the model function that ``correct_photobleaching`` fits with ``curve_fit``.

        :param t: Time variable (scalar or NumPy array).
        :param A: Amplitude of the decaying component.
        :param tau: Decay constant, in the same units as ``t``.
        :param C: Constant offset the curve decays towards.
        :return: The model evaluated at ``t`` (same shape as ``t``).
        """
        decaying_part = A * np.exp(-t / tau)
        return decaying_part + C

    def correct_photobleaching(self, full_array_column, fit_start):
        """
        Corrects photobleaching by dividing each trace by a fitted single-exponential decay.

        For every row of ``self.combined_df``, the trace in ``full_array_column`` is fitted with
        ``self.exp_decay`` over the samples from index ``fit_start`` onward (non-finite samples
        are excluded from the fit), and the whole trace is divided by the fitted curve.
        Rows whose fit cannot be performed keep their original trace and a message is printed.

        :param full_array_column: Column of ``self.combined_df`` holding one time series per row.
        :param fit_start: Index where fitting starts, typically after the response has decayed.
        :return: ``self.combined_df`` with a new object-dtype column
            ``f"{full_array_column}_corrected"`` holding, per row, the corrected trace as a list
            or the original trace when fitting failed.
        """
        # Local import: curve_fit is not among the imports visible at the top of this file.
        from scipy.optimize import curve_fit

        corrected_column = f"{full_array_column}_corrected"
        # Collect results first and assign the column once at the end. Writing lists into a
        # float64 column (the old ``= np.nan`` initialisation) with ``.at`` raises
        # "setting an array element with a sequence".
        corrected_traces = []

        for _, row in self.combined_df.iterrows():
            original = row[full_array_column]
            time_series = np.asarray(original, dtype=float)
            t = np.arange(len(time_series))

            # Fit only finite samples in the fit window: curve_fit rejects NaN/inf input.
            t_fit = t[fit_start:]
            y_fit = time_series[fit_start:]
            finite = np.isfinite(y_fit)
            t_fit, y_fit = t_fit[finite], y_fit[finite]

            # curve_fit needs at least as many data points as model parameters (3).
            if y_fit.size < 3:
                print(f"Error fitting photobleaching: only {y_fit.size} finite samples from index {fit_start}")
                corrected_traces.append(original)
                continue

            try:
                # Initial guess: A ~ mean of the window, tau = 100 samples, C ~ window minimum
                p0 = [np.mean(y_fit), 100, np.min(y_fit)]
                params, _ = curve_fit(self.exp_decay, t_fit, y_fit, p0=p0)
            except (RuntimeError, ValueError) as e:
                print(f"Error fitting photobleaching: {e}")
                # If fitting fails, keep the original data
                corrected_traces.append(original)
                continue

            # Evaluate the fitted bleaching curve over the whole trace and divide it out
            bleach_correction = self.exp_decay(t, *params)
            corrected_traces.append(list(time_series / bleach_correction))

        # dtype=object keeps one list per cell instead of being expanded into 2-D
        self.combined_df[corrected_column] = pd.Series(corrected_traces, index=self.combined_df.index, dtype=object)
        return self.combined_df

###################################################################
# Compare all three sensors
###################################################################
sensor_names = ['CaBLAM', 'CaBLAM1x', 'GCaMP8s']

# Per-sensor colours, in the same order as sensor_names:
# pale fills for the boxplots (light red / light blue / light grey) and
# saturated colours for the strip points (red / blue / grey).
sensor_box_colors = dict(zip(sensor_names, ['#ffcccc', '#ccccff', '#d3d3d3']))
sensor_strip_colors = dict(zip(sensor_names, ['#ff0000', '#0000ff', '#808080']))

allthree_plotter = SensorDataPlotter(
    data_frames=[responsiveness_df_cablam, responsiveness_df_cablam1x, responsiveness_df_gcamp8],
    sensor_names=sensor_names,
    sensor_box_colors=sensor_box_colors,
    sensor_strip_colors=sensor_strip_colors,
)

# Box + strip plots of the per-ROI metrics
allthree_plotter.plot_data('peak_delta_f_f_post_stim', selected_stim_ids=[12, 36, 60, 120, 480])
allthree_plotter.plot_data('peak_delta_f_f_post_stim', selected_stim_ids=[12])
allthree_plotter.plot_data('post_stim_mean', selected_stim_ids=[12, 36, 60, 120, 480])
allthree_plotter.plot_data('time_to_peak', selected_stim_ids=[12, 36, 60, 120, 480])

# Mean +/- SEM of the peak response for every stimulation ID
allthree_plotter.plot_mean_with_error('peak_delta_f_f_post_stim', error_type='SEM')

###################################################################
# Compare CaBLAM and CaBLAM1x
###################################################################
cablamvscablam1x_plotter = SensorDataPlotter(
    data_frames=[responsiveness_df_cablam, responsiveness_df_cablam1x],
    sensor_names=['CaBLAM', 'CaBLAM1x'],
    # Pale fills for the boxes, saturated colours for the points: CaBLAM red, CaBLAM1x blue
    sensor_box_colors=dict(CaBLAM='#ffcccc', CaBLAM1x='#ccccff'),
    sensor_strip_colors=dict(CaBLAM='#ff0000', CaBLAM1x='#0000ff'),
)

cablamvscablam1x_plotter.plot_data('peak_delta_f_f_post_stim', selected_stim_ids=[12, 36, 60, 120, 480])
cablamvscablam1x_plotter.plot_data('time_to_peak', selected_stim_ids=[12, 36, 60, 120, 480])
cablamvscablam1x_plotter.plot_mean_with_error(
    'peak_delta_f_f_post_stim',
    error_type='SEM',
    selected_stim_ids=[12, 36, 60, 120, 480],
    ylim=None,
)
cablamvscablam1x_plotter.plot_time_series(
    'delta_f_f_full_array',
    selected_stim_ids=[12, 36, 60, 120, 480],
)



###################################################################
# Compare CaBLAM and GCaMP8s
###################################################################
sensor_names_cablamvsgcamp = ['CaBLAM', 'GCaMP8s']

# In this comparison CaBLAM is drawn in blue and GCaMP8s in grey
# (pale fills for the boxes, saturated colours for the strip points).
sensor_box_colors2 = dict(zip(sensor_names_cablamvsgcamp, ['#ccccff', '#d3d3d3']))
sensor_strip_colors2 = dict(zip(sensor_names_cablamvsgcamp, ['#0000ff', '#808080']))

cablamvsgcamp_plotter = SensorDataPlotter(
    data_frames=[responsiveness_df_cablam, responsiveness_df_gcamp8],
    sensor_names=sensor_names_cablamvsgcamp,
    sensor_box_colors=sensor_box_colors2,
    sensor_strip_colors=sensor_strip_colors2,
)

cablamvsgcamp_plotter.plot_data('peak_delta_f_f_post_stim', selected_stim_ids=[12, 36, 60, 120, 480])
cablamvsgcamp_plotter.plot_data('time_to_peak', selected_stim_ids=[12, 36, 60, 120, 480])

cablamvsgcamp_plotter.plot_mean_with_error(
    'peak_delta_f_f_post_stim',
    error_type='SEM',
    selected_stim_ids=[12, 36, 60, 120, 480],
    ylim=(0, 1.5),
)
cablamvsgcamp_plotter.plot_time_series('delta_f_f_full_array', selected_stim_ids=[12, 36, 60, 120, 480])



###################################################################
# Analyze GCaMP8s only
###################################################################
sensor_names_gcamp = ['GCaMP8s']

# Pale fill for the boxplot, saturated colour for the strip points
sensor_box_colors_gcamp = {
    'GCaMP8s': '#d3d3d3'   # Light grey
}
sensor_strip_colors_gcamp = {
    'GCaMP8s': '#808080'   # Dark grey
}

gcamp_plotter = SensorDataPlotter(
    data_frames=[responsiveness_df_gcamp8],
    sensor_names=sensor_names_gcamp,
    sensor_box_colors=sensor_box_colors_gcamp,
    sensor_strip_colors=sensor_strip_colors_gcamp
)

gcamp_plotter.plot_data('peak_delta_f_f_post_stim', selected_stim_ids=[12, 36, 60, 120, 480])
gcamp_plotter.plot_data('time_to_peak', selected_stim_ids=[12, 36, 60, 120, 480])
gcamp_plotter.plot_mean_with_error('peak_delta_f_f_post_stim', error_type='SEM', selected_stim_ids=[12, 36, 60, 120, 480], ylim=None)
gcamp_plotter.plot_time_series('delta_f_f_full_array', selected_stim_ids=[12, 36, 60, 120, 480])




In [None]:
def get_n_per_sensor_stim_from_plotter(plotter, metric, stim_ids):
    """
    Count rows of ``plotter.combined_df`` per (sensor, stimulation ID) pair.

    ``plotter.prepare_for_plotting(metric)`` is called first so the counts reflect a freshly
    rebuilt ``combined_df``. NOTE(review): the counts are described as responsive neurons
    (the rows used by the plot_data boxplots); that filtering is assumed to happen inside
    prepare_for_plotting — confirm there.

    :param plotter: Object exposing ``prepare_for_plotting(column)`` and ``combined_df``.
    :param metric: Column name forwarded to ``prepare_for_plotting``.
    :param stim_ids: Stimulation IDs to keep.
    :return: DataFrame with columns sensor_name, stimulation_id, n_responsive_neurons,
        sorted by sensor then stimulation ID.
    """
    # Rebuild combined_df so stale state from earlier plots is not counted
    plotter.prepare_for_plotting(metric)

    keep = plotter.combined_df['stimulation_id'].isin(stim_ids)
    selected = plotter.combined_df[keep]

    group_sizes = selected.groupby(['sensor_name', 'stimulation_id']).size()
    counts = group_sizes.reset_index(name='n_responsive_neurons')
    return counts.sort_values(['sensor_name', 'stimulation_id'])

# Per-sensor counts for the 12- and 120-stimulus conditions, all three sensors
stim_ids = [12, 120]
n_summary = get_n_per_sensor_stim_from_plotter(
    plotter=allthree_plotter,
    metric='peak_delta_f_f_post_stim',
    stim_ids=stim_ids,
)
print(n_summary)

In [None]:
def report_sample_sizes_from_plotter(plotter, df_column_name, selected_stim_ids=None, auto_reset=True):
    """
    Print and return the number of rows per (sensor, stimulation ID) in ``plotter.combined_df``.

    :param plotter: Object exposing ``combined_df`` and ``prepare_for_plotting(column)``.
    :param df_column_name: Column forwarded to ``prepare_for_plotting`` when ``auto_reset`` is True.
    :param selected_stim_ids: Stimulation IDs to keep; None keeps every ID.
    :param auto_reset: If True, rebuild ``combined_df`` via ``prepare_for_plotting`` first.
    :return: DataFrame with columns sensor_name, stimulation_id, n_responsive_rois,
        sorted by sensor then stimulation ID.
    """
    if auto_reset:
        plotter.prepare_for_plotting(df_column_name)

    # Work on a copy so the plotter's frame is never touched
    data = plotter.combined_df.copy()
    if selected_stim_ids is not None:
        data = data.loc[data['stimulation_id'].isin(selected_stim_ids)]

    group_sizes = data.groupby(['sensor_name', 'stimulation_id']).size()
    summary_df = group_sizes.reset_index(name='n_responsive_rois').sort_values(
        by=['sensor_name', 'stimulation_id']
    )

    print("\nSample size report (responsive ROIs used per group):")
    for rec in summary_df.itertuples(index=False):
        print(f"  {rec.sensor_name} — Stim {rec.stimulation_id}: n = {rec.n_responsive_rois}")

    return summary_df

# Sample sizes behind the CaBLAM vs GCaMP8s peak-ΔF/F comparison
report_sample_sizes_from_plotter(
    plotter=cablamvsgcamp_plotter,
    df_column_name='peak_delta_f_f_post_stim',
    selected_stim_ids=[12, 60, 480],
)

In [None]:
# Sample sizes behind the CaBLAM vs CaBLAM1x peak-ΔF/F comparison
report_sample_sizes_from_plotter(
    plotter=cablamvscablam1x_plotter,
    df_column_name='peak_delta_f_f_post_stim',
    selected_stim_ids=[12, 36, 60, 120, 480],
)

In [None]:


def plot_stim_responsiveness(df, stim_ids=None, include='responsive', y_lim=None, x_lim=None, mean_color='black', figsize=(15, 5), stim_index=10):
    """
    Plots the delta F/F response for given stimulation IDs, filtering based on responsiveness if specified.

    Individual traces are plotted in light grey and their NaN-ignoring mean in ``mean_color``.
    A red dashed line marks stimulation onset (x = 0) when it lies within ``x_lim``.
    Stimulation IDs with no matching rows get an empty panel labelled 'No traces'
    instead of raising.

    Parameters:
    - df (pd.DataFrame): Needs 'stimulation_id', 'is_responsive' and 'delta_f_f_full_array' columns.
    - stim_ids (list): Stimulation IDs to plot (sorted). If None, all unique IDs remaining after filtering are used.
    - include (str): 'responsive' or 'non-responsive' filters on 'is_responsive'; any other value
      (e.g. 'both') keeps every row.
    - y_lim (tuple): (min, max) y-axis limits. If None, limits are automatic.
    - x_lim (tuple): (min, max) x-axis limits in frames relative to onset. If None, the full trace is shown.
    - mean_color (str): Any Matplotlib colour specification (a name such as 'blue', a hex string such as '#ff0000', ...).
    - figsize (tuple): Figure size as (width, height).
    - stim_index (int): Frame index of stimulation onset; x = frame index - stim_index. Defaults to 10.

    Returns:
    - fig (plt.Figure): The created figure.
    """
    # Filter based on responsiveness if required
    if include == 'responsive':
        df = df[df['is_responsive'] == True]
    elif include == 'non-responsive':
        df = df[df['is_responsive'] == False]

    stim_ids = sorted(df['stimulation_id'].unique()) if stim_ids is None else sorted(stim_ids)

    n_stims = len(stim_ids)
    fig, axes = plt.subplots(1, n_stims, figsize=figsize, sharey=True)
    # A single subplot is returned as a bare Axes; wrap it so the loop below works
    if n_stims == 1:
        axes = [axes]

    if y_lim:
        plt.setp(axes, ylim=y_lim)

    for ax, stim_id in zip(axes, stim_ids):
        stim_df = df[df['stimulation_id'] == stim_id]

        if stim_df.empty:
            # np.vstack raises on an empty sequence; show an annotated empty panel instead
            ax.text(0.5, 0.5, 'No traces', transform=ax.transAxes,
                    horizontalalignment='center', verticalalignment='center')
        else:
            # One row per trace
            delta_f_f_values = np.vstack(stim_df['delta_f_f_full_array'].values)
            # Time axis derived from the actual trace length (was hard-coded to 111 frames)
            time_vector = np.arange(delta_f_f_values.shape[1]) - stim_index

            for trace in delta_f_f_values:
                ax.plot(time_vector, trace, color='lightgrey', linewidth=0.5)

            mean_response = np.nanmean(delta_f_f_values, axis=0)
            ax.plot(time_vector, mean_response, color=mean_color, label=f'Stim ID {stim_id}')

        if x_lim:
            ax.set_xlim(x_lim)

        # Onset marker only when x = 0 is visible
        if not x_lim or x_lim[0] <= 0 <= x_lim[1]:
            ax.axvline(x=0, color='red', linestyle='--', label='Stimulation Onset')

        ax.set_title(f'Stim ID {stim_id}')
        ax.set_xlabel('Time (relative to stimulus)')
        ax.set_ylabel('ΔF/F')

    # To prevent x-axis labels from overlapping
    plt.tight_layout()

    return fig




# Responsive ROIs only. mean_color accepts any Matplotlib colour specification
# (named colours such as 'blue' or 'tab:orange', hex strings such as '#ff0000', RGB tuples).

# CaBLAM, wide window: 15 frames before to 105 frames after onset
plot_fig = plot_stim_responsiveness(
    df=responsiveness_df_cablam, stim_ids=[12, 24, 36, 60, 120, 480], include='responsive',
    x_lim=(-15, 105), mean_color='blue', figsize=(20, 6),
)

# CaBLAM, zoomed on onset: 15 frames before to 10 frames after
plot_fig = plot_stim_responsiveness(
    df=responsiveness_df_cablam, stim_ids=[12, 24, 36, 60, 120, 480], include='responsive',
    x_lim=(-15, 10), mean_color='blue', figsize=(20, 6),
)

# GCaMP8s, same zoomed window
plot_fig = plot_stim_responsiveness(
    df=responsiveness_df_gcamp8, stim_ids=[12, 24, 36, 60, 120, 480], include='responsive',
    x_lim=(-15, 10), mean_color='blue', figsize=(20, 6),
)

plt.show()


In [None]:
def plot_stim_responsiveness(df, stim_ids=None, include='both', y_lim=None, x_lim=None, mean_color='black', figsize=(15, 5), stim_index=10):
    """
    Plots the delta F/F response for given stimulation IDs, filtering based on responsiveness if specified.

    Individual traces are drawn in light grey (non-responsive ones semi-transparent). The
    coloured line is the NaN-ignoring mean of the responsive traces when ``include`` is
    'responsive' or 'both', and of the non-responsive traces when ``include`` is
    'non-responsive'. A red dashed line marks stimulation onset (x = 0) when it lies within
    ``x_lim``. Each panel is annotated with the trace counts for the included categories.

    Parameters:
    - df (pd.DataFrame): Needs 'stimulation_id', 'is_responsive' and 'delta_f_f_full_array' columns.
    - stim_ids (list): Stimulation IDs to plot (sorted). If None, all unique IDs in ``df`` are used.
    - include (str): 'responsive', 'non-responsive', or 'both'.
    - y_lim (tuple): (min, max) y-axis limits. If None, limits are automatic.
    - x_lim (tuple): (min, max) x-axis limits in frames relative to onset. If None, the full trace is shown.
    - mean_color (str): Any Matplotlib colour specification for the mean line.
    - figsize (tuple): Figure size as (width, height).
    - stim_index (int): Frame index of stimulation onset; x = frame index - stim_index. Defaults to 10.

    Returns:
    - fig (plt.Figure): The created figure.

    Raises:
    - ValueError: If ``include`` is not one of the accepted values.
    """
    # Previously 'responsive' / 'non-responsive' hit a NameError because only one of the two
    # trace arrays was ever defined; validate and always define both.
    if include not in ('responsive', 'non-responsive', 'both'):
        raise ValueError(f"include must be 'responsive', 'non-responsive' or 'both', got {include!r}")
    show_responsive = include in ('responsive', 'both')
    show_unresponsive = include in ('non-responsive', 'both')

    def _stack_traces(rows):
        # One trace per row as a 2-D array; (0, 0) when empty because np.vstack raises on [].
        if rows.empty:
            return np.empty((0, 0))
        return np.vstack(rows['delta_f_f_full_array'].values)

    if stim_ids is None:
        stim_ids = sorted(df['stimulation_id'].unique())
    else:
        stim_ids = sorted(stim_ids)

    n_stims = len(stim_ids)
    fig, axes = plt.subplots(1, n_stims, figsize=figsize, sharey=True)
    # A single subplot is returned as a bare Axes; wrap it so the loop below works
    if n_stims == 1:
        axes = [axes]

    if y_lim:
        plt.setp(axes, ylim=y_lim)
    if x_lim:
        plt.setp(axes, xlim=x_lim)

    for ax, stim_id in zip(axes, stim_ids):
        stim_df = df[df['stimulation_id'] == stim_id]
        responsive = _stack_traces(stim_df[stim_df['is_responsive'] == True]) if show_responsive else np.empty((0, 0))
        unresponsive = _stack_traces(stim_df[stim_df['is_responsive'] == False]) if show_unresponsive else np.empty((0, 0))

        # Time axis from the actual trace length (was hard-coded to 111 frames)
        reference = responsive if len(responsive) else unresponsive
        time_vector = np.arange(reference.shape[1]) - stim_index

        # Individual traces in light grey; non-responsive ones fainter
        for trace in unresponsive:
            ax.plot(time_vector, trace, color='lightgrey', linewidth=0.5, alpha=0.5)
        for trace in responsive:
            ax.plot(time_vector, trace, color='lightgrey', linewidth=0.5)

        mean_source = responsive if show_responsive else unresponsive
        if len(mean_source):
            mean_response = np.nanmean(mean_source, axis=0)
            ax.plot(time_vector, mean_response, color=mean_color, label=f'Stim ID {stim_id}')

        # Onset marker only when x = 0 is visible
        if not x_lim or x_lim[0] <= 0 <= x_lim[1]:
            ax.axvline(x=0, color='red', linestyle='--', label='Stimulation Onset')

        # Counts for the categories that were actually included
        count_lines = []
        if show_responsive:
            count_lines.append(f'Responsive: {len(responsive)}')
        if show_unresponsive:
            count_lines.append(f'Unresponsive: {len(unresponsive)}')
        ax.text(0.95, 0.95, '\n'.join(count_lines),
                verticalalignment='top', horizontalalignment='right',
                transform=ax.transAxes, color='black', fontsize=8, bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='white', alpha=0.5))

        ax.set_title(f'Stim ID {stim_id}')
        ax.set_xlabel('Time (relative to stimulus)')
        ax.set_ylabel('ΔF/F')

    # To prevent x-axis labels from overlapping
    plt.tight_layout()

    return fig
# CaBLAM: responsive and non-responsive traces for the selected stimulation IDs
plot_fig = plot_stim_responsiveness(
    df=responsiveness_df_cablam,
    stim_ids=[12, 24, 36, 60, 120, 480],
    include='both',
    y_lim=None,
    x_lim=(-10, 100),
    mean_color='red',
    figsize=(20, 6),
)

# GCaMP8s: same view, for every stimulation ID present in its table
plot_fig = plot_stim_responsiveness(
    df=responsiveness_df_gcamp8,
    include='both', y_lim=None, x_lim=(-10, 100),
    mean_color='black', figsize=(20, 6),
)
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_stim_responsiveness(df, stim_ids=None, include='both', y_lim=None, x_lim=None, mean_color='black', figsize=(15, 5), stim_index=9, sampling_interval=100):
    """
    Plots ΔF/F traces per stimulation ID with their median, on a time axis in milliseconds.

    Individual traces are drawn in light grey and the NaN-ignoring *median* trace in
    ``mean_color`` (the parameter keeps its historical name; the line is a median).
    A red dashed line marks stimulation onset (t = 0) when it lies within ``x_lim``.
    Each panel is annotated with the number of responsive traces it contains.

    Parameters:
    - df (pd.DataFrame): Needs 'stimulation_id', 'is_responsive' and 'delta_f_f_full_array' columns.
    - stim_ids (list): Stimulation IDs to plot (sorted). If None, all unique IDs in ``df`` are used.
    - include (str): 'both' keeps all rows, 'responsive' keeps responsive rows; any other
      value keeps only non-responsive rows.
    - y_lim (tuple): (min, max) y-axis limits. If None, limits are automatic.
    - x_lim (tuple): (min, max) x-axis limits in *frames* relative to onset; converted to ms
      with ``sampling_interval``. If None, the full trace is shown.
    - mean_color (str): Colour of the median line (any Matplotlib colour specification).
    - figsize (tuple): Figure size as (width, height).
    - stim_index (int): Frame index at which stimulation occurs. Defaults to 9.
    - sampling_interval (float): Milliseconds per frame. Defaults to 100 (10 Hz sampling).

    Returns:
    - fig (plt.Figure): The created figure.
    """
    if stim_ids is None:
        stim_ids = sorted(df['stimulation_id'].unique())
    else:
        stim_ids = sorted(stim_ids)

    n_stims = len(stim_ids)
    fig, axes = plt.subplots(1, n_stims, figsize=figsize, sharey=True)
    # A single subplot is returned as a bare Axes; wrap it so the loop below works
    if n_stims == 1:
        axes = [axes]

    if y_lim:
        plt.setp(axes, ylim=y_lim)

    # x_lim is given in frames; the axis itself is in ms
    if x_lim is not None:
        plt.setp(axes, xlim=(x_lim[0] * sampling_interval, x_lim[1] * sampling_interval))

    for ax, stim_id in zip(axes, stim_ids):
        stim_df = df[df['stimulation_id'] == stim_id]
        if include != 'both':
            stim_df = stim_df[stim_df['is_responsive'] == (include == 'responsive')]

        # np.vstack raises on an empty selection, so only plot traces when there are some
        if not stim_df.empty:
            delta_f_f_values = np.vstack(stim_df['delta_f_f_full_array'].values)
            # Time axis from the actual trace length (was hard-coded to 111 frames)
            time_vector = (np.arange(delta_f_f_values.shape[1]) - stim_index) * sampling_interval

            for trace in delta_f_f_values:
                ax.plot(time_vector, trace, color='lightgrey', linewidth=0.5)

            median_response = np.nanmedian(delta_f_f_values, axis=0)
            ax.plot(time_vector, median_response, color=mean_color, label=f'Stim ID {stim_id}')

        # Onset marker only when t = 0 is visible (the frame->ms scale is positive)
        if x_lim is None or x_lim[0] <= 0 <= x_lim[1]:
            ax.axvline(x=0, color='red', linestyle='--', label='Stimulation Onset')

        num_responsive = int((stim_df['is_responsive'] == True).sum())
        ax.text(0.95, 0.95, f'Responsive: {num_responsive}', transform=ax.transAxes, fontsize=9,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(facecolor='white', alpha=0.5, edgecolor='black', boxstyle='round'))

        ax.set_title(f'Stim ID {stim_id}')
        ax.set_xlabel('ms', fontsize=24)
        ax.set_ylabel('ΔF/F$_o$ (%)', fontsize=24)
        ax.tick_params(axis='y', labelsize=18)
        ax.tick_params(axis='x', labelsize=18)

    # To prevent x-axis labels from overlapping
    plt.tight_layout()

    return fig

# Responsive ROIs, window -10..100 frames around onset, all stimulation IDs
plot_stim_responsiveness(df=responsiveness_df_cablam, include='responsive', y_lim=None,
                         x_lim=(-10, 100), mean_color='red', figsize=(20, 6))
plot_stim_responsiveness(df=responsiveness_df_cablam1x, include='responsive', y_lim=None,
                         x_lim=(-10, 100), mean_color='blue', figsize=(20, 6))

# Same window, restricted to the standard stimulation IDs
plot_stim_responsiveness(df=responsiveness_df_cablam, include='responsive',
                         stim_ids=[12, 24, 36, 60, 120, 480], y_lim=None,
                         x_lim=(-10, 100), mean_color='red', figsize=(20, 6))
plot_stim_responsiveness(df=responsiveness_df_cablam1x, stim_ids=[12, 24, 36, 60, 120, 480],
                         include='responsive', y_lim=None,
                         x_lim=(-10, 100), mean_color='blue', figsize=(20, 6))

# GCaMP8s, all stimulation IDs
plot_stim_responsiveness(df=responsiveness_df_gcamp8, include='responsive', y_lim=None,
                         x_lim=(-10, 100), mean_color='black', figsize=(20, 6))

# Zoom on onset (-10..10 frames): CaBLAM, standard stimulation IDs
plot_fig = plot_stim_responsiveness(df=responsiveness_df_cablam, stim_ids=[12, 24, 36, 60, 120, 480],
                                    include='responsive', y_lim=None,
                                    x_lim=(-10, 10), mean_color='red', figsize=(20, 6))

# Zoom on onset (-10..10 frames): GCaMP8s, all stimulation IDs
plot_fig = plot_stim_responsiveness(df=responsiveness_df_gcamp8, include='responsive', y_lim=None,
                                    x_lim=(-10, 10), mean_color='black', figsize=(20, 6))

# GCaMP8s, window -5..20 frames
plot_fig = plot_stim_responsiveness(df=responsiveness_df_gcamp8, include='responsive', y_lim=None,
                                    x_lim=(-5, 20), mean_color='black', figsize=(20, 6))
plt.show()

In [None]:
import numpy as np
from scipy.signal import fftconvolve

def calculate_template(dataframe, stimulation_id_col, is_responsive_col, delta_f_f_full_array_col, stimulation_id_val=12, window_start=10, window_stop=14):
    """
    Build a matched-filter template from the median responsive trace of one stimulation ID.

    Rows with ``stimulation_id_col == stimulation_id_val`` and ``is_responsive_col == True``
    are stacked (one trace per row). The NaN-ignoring median is taken per frame, and the median
    trace is divided by its Euclidean norm, computed while ignoring NaN frames. Frames
    ``window_start:window_stop`` are then returned.

    Normalisation happens before slicing (as before), so the returned window is generally
    not unit-norm.

    :param dataframe: Table with one trace per row.
    :param stimulation_id_col: Column holding the stimulation ID.
    :param is_responsive_col: Boolean responsiveness column.
    :param delta_f_f_full_array_col: Column holding the traces (equal-length sequences).
    :param stimulation_id_val: Stimulation ID to build the template from. Defaults to 12.
    :param window_start: First frame of the returned window (inclusive). Defaults to 10.
    :param window_stop: End frame of the returned window (exclusive). Defaults to 14.
    :return: 1-D NumPy array of length ``window_stop - window_start`` (fewer if the traces are shorter).
    :raises ValueError: If no rows match the selection.
    """
    selection = (dataframe[stimulation_id_col] == stimulation_id_val) & (dataframe[is_responsive_col] == True)
    template_data = dataframe[selection]
    if template_data.empty:
        raise ValueError(
            f"No responsive rows with {stimulation_id_col} == {stimulation_id_val!r}; cannot build a template."
        )

    # Stack the per-ROI traces row-wise and take the median across ROIs for each frame
    traces = np.vstack(template_data[delta_f_f_full_array_col].values)
    template = np.nanmedian(traces, axis=0)

    # NaN-ignoring Euclidean norm: np.linalg.norm would return NaN (and turn the whole
    # template into NaN) if any frame's median is NaN. Skip normalisation for a zero template.
    norm = np.sqrt(np.nansum(template ** 2))
    if norm > 0:
        template = template / norm

    print(f'Number of NaNs in the template: {np.sum(np.isnan(template))}')

    return template[window_start:window_stop]


def interpolate_signal(time_series_np):
    """
    Fill missing values in a 1-D time series.

    Interior gaps are filled by pandas' default (linear) interpolation. Any NaNs left at the
    start are back-filled from the first valid value, and any left at the end are
    forward-filled. An all-NaN input is returned unchanged.

    :param time_series_np: 1-D array-like of numbers, possibly containing NaN.
    :return: NumPy array of the same length with the gaps filled.
    """
    time_series_pd = pd.Series(time_series_np)
    time_series_interpolated = time_series_pd.interpolate()
    # .bfill()/.ffill() replace fillna(method=...), which is deprecated since pandas 2.1
    # and removed in pandas 3.0.
    time_series_interpolated = time_series_interpolated.bfill().ffill()
    return time_series_interpolated.values


def template_matching(session_id, roi_data, template_data):
    """
    Cross-correlate a template with every ROI trace of one session and flag threshold crossings.

    Parameters:
    session_id (int or str): Session to process; looked up as ``str(session_id)``.
    roi_data (dict): ``{session_id (str): {roi_id: trace}}``. Each trace may be a pandas
        Series or any 1-D array-like.
    template_data (np.ndarray): 1-D template (kernel) to match.

    Returns:
    dict: ``{str(session_id): {roi_id: result}}``. Each ``result`` has the keys
        'peaks' (indices where the filtered signal exceeds the threshold),
        'filtered_signal', 'threshold', 'original_signal' (after NaN interpolation)
        and 'template'. ROIs whose NaNs cannot be interpolated are skipped.
    """
    session_key = str(session_id)
    session_data = roi_data[session_key]
    session_results = {}
    results = {session_key: session_results}
    kernel = np.asarray(template_data)[::-1]

    for roi_id, time_series_data in session_data.items():
        # Accept pandas Series (.values) as well as plain arrays/lists.
        signal = np.asarray(getattr(time_series_data, 'values', time_series_data), dtype=float)
        nan_count = np.sum(np.isnan(signal))

        print(f'Number of NaNs in the time series data for ROI {roi_id}: {nan_count}')
        print(f'Total number of elements in the time series data for ROI {roi_id}: {len(signal)}')

        if nan_count > 0:
            signal = interpolate_signal(signal)
            # An all-NaN trace cannot be interpolated.
            if np.sum(np.isnan(signal)) > 0:
                print(f'Interpolation failed for ROI {roi_id}: Still contains NaNs after interpolation.')
                continue

        # Convolving with the time-reversed template is cross-correlation (matched filter).
        filtered_signal = fftconvolve(signal, kernel, mode='same')

        # 3-sigma threshold measured from the mean of the filtered signal. A trace with a
        # baseline offset shifts the whole filtered signal, so a bare 3*std threshold
        # (relative to zero) would flag almost every sample.
        threshold = np.nanmean(filtered_signal) + 3 * np.nanstd(filtered_signal)

        # Every above-threshold sample is reported, not just local maxima.
        peaks = np.flatnonzero(filtered_signal > threshold)
        session_results[roi_id] = {
            'peaks': peaks,
            'filtered_signal': filtered_signal,
            'threshold': threshold,
            'original_signal': signal,
            'template': template_data,
        }

    return results

session_id = 2112242023  # Choose a session ID
stimulation_id_col = 'stimulation_id'
is_responsive_col = 'is_responsive'
# Alternative input column: 'delta_f_f_full_array'
delta_f_f_full_array_col = 'raw_signal'

# Build the template from the responsive CaBLAM trials, then match it against every ROI of the session.
cablam_template = calculate_template(responsiveness_df_cablam, stimulation_id_col, is_responsive_col, delta_f_f_full_array_col)

print(list(cablam_filtered_responsive_rois.keys()))

template_matching_results = template_matching(session_id, cablam_filtered_responsive_rois, cablam_template)

# template_matching() keys its output by str(session_id).
example_roi = template_matching_results[str(session_id)]['ROI_10']
example_peaks = example_roi['peaks']

# The template on its own figure.
plt.figure()
plt.plot(cablam_template)
plt.title('Template')

# One 2x1 grid: mixing different subplot geometries on the same figure
# produces overlapping axes.
fig, (ax_signal, ax_filtered) = plt.subplots(2, 1, figsize=(15, 8))

ax_signal.set_title('Original Signal and Detected Peaks')
ax_signal.plot(example_roi['original_signal'], label='Original Signal')
ax_signal.plot(example_peaks, example_roi['original_signal'][example_peaks], 'ro', label='Detected Peaks')
ax_signal.legend()

ax_filtered.set_title('Filtered Signal and Detection Threshold')
ax_filtered.plot(example_roi['filtered_signal'], label='Filtered Signal')
ax_filtered.axhline(y=example_roi['threshold'], color='r', linestyle='--', label='Detection Threshold')
ax_filtered.legend()

fig.tight_layout()
plt.show()

###################################################################
# null distribution and significance threshold
##################################################################



def create_null_distribution_and_match(dataframe, norm_template, stimulation_id_col, is_responsive_col, delta_f_f_full_array_col, stimulation_id_val=12, num_iterations=8000, percentile=95, rng=None):
    """
    Build a null distribution of template-match scores from time-shuffled trials.

    On each of ``num_iterations`` iterations, the frames of every responsive trial of
    ``stimulation_id_val`` are randomly permuted (destroying stimulus time-locking).
    The shuffled trials are then averaged, the average is cross-correlated with
    ``norm_template``, and the maximum filter output is recorded.

    Parameters:
    dataframe (pd.DataFrame): Experimental data, one row per trial.
    norm_template (np.ndarray): 1-D (normalised) template used for matching.
    stimulation_id_col (str): Column holding stimulation IDs.
    is_responsive_col (str): Column flagging responsive trials.
    delta_f_f_full_array_col (str): Column holding each trial's trace. All selected
        traces must have the same length (they are stacked with np.vstack).
    stimulation_id_val (int): Stimulation ID whose responsive trials are used.
    num_iterations (int): Number of shuffles, i.e. the size of the null distribution.
    percentile (float): Percentile of the null scores returned as the threshold. Defaults to 95.
    rng (np.random.Generator, optional): Source of permutations. Defaults to the global
        ``np.random`` state.

    Returns:
    tuple: ``(scores_null, threshold)``. ``scores_null`` is a list with one max filter
        output per iteration; ``threshold`` is its ``percentile``-th percentile.

    Raises:
    ValueError: If no responsive trials match ``stimulation_id_val``.
    """
    mask = (dataframe[stimulation_id_col] == stimulation_id_val) & (dataframe[is_responsive_col] == True)
    traces = dataframe.loc[mask, delta_f_f_full_array_col].values
    if len(traces) == 0:
        raise ValueError(f'No responsive trials found for {stimulation_id_col} == {stimulation_id_val}.')
    print("Data type and shape of one sample array:", type(traces[0]), np.asarray(traces[0]).shape)

    # Interpolate NaNs so that permuting/averaging is not poisoned by missing frames.
    stacked = np.vstack([interpolate_signal(np.array(trace)) for trace in traces])
    kernel = np.asarray(norm_template)[::-1]
    permute = np.random.permutation if rng is None else rng.permutation

    scores_null = []
    for _ in range(num_iterations):
        # Shuffle the frames of each trial independently, then average across trials.
        shuffled = np.vstack([permute(row) for row in stacked])
        filtered_signal = fftconvolve(shuffled.mean(axis=0), kernel, mode='same')
        scores_null.append(np.max(filtered_signal))

    threshold = np.percentile(scores_null, percentile)

    return scores_null, threshold

def template_matching_vs_null(session_id, roi_data, norm_template, significance_threshold):
    """
    Matched-filter every ROI of one session and test its best match against a null threshold.

    Parameters:
    session_id (int or str): Session to process; looked up as ``str(session_id)``.
    roi_data (dict): ``{session_id (str): {roi_id: trace}}``. Each trace may be a pandas
        Series or any 1-D array-like.
    norm_template (np.ndarray): 1-D (normalised) template.
    significance_threshold (float): Score the maximum filter output must exceed to be
        called significant. NOTE(review): this is only meaningful if the null distribution
        was built on data on the same scale as the ``roi_data`` traces — confirm at the call site.

    Returns:
    dict: ``{str(session_id): {roi_id: result}}``. Each ``result`` has the keys 'peaks',
        'filtered_signal', 'local_threshold', 'max_filtered_value', 'is_significant',
        'original_signal' (after NaN interpolation) and 'template'. ROIs whose NaNs
        cannot be interpolated are skipped.
    """
    session_key = str(session_id)
    session_data = roi_data[session_key]
    session_results = {}
    results = {session_key: session_results}
    kernel = np.asarray(norm_template)[::-1]

    for roi_id, time_series_data in session_data.items():
        # Accept pandas Series (.values) as well as plain arrays/lists.
        signal = np.asarray(getattr(time_series_data, 'values', time_series_data), dtype=float)
        nan_count = np.sum(np.isnan(signal))

        print(f'Number of NaNs in the time series data for ROI {roi_id}: {nan_count}')
        print(f'Total number of elements in the time series data for ROI {roi_id}: {len(signal)}')

        if nan_count > 0:
            signal = interpolate_signal(signal)
            # An all-NaN trace cannot be interpolated.
            if np.sum(np.isnan(signal)) > 0:
                print(f'Interpolation failed for ROI {roi_id}: Still contains NaNs after interpolation.')
                continue

        # Convolving with the time-reversed template is cross-correlation (matched filter).
        filtered_signal = fftconvolve(signal, kernel, mode='same')

        # Local 3-sigma threshold measured from the mean. A bare 3*std (relative to zero)
        # flags almost every sample when the trace has a baseline offset.
        local_threshold = np.nanmean(filtered_signal) + 3 * np.nanstd(filtered_signal)
        peaks = np.flatnonzero(filtered_signal > local_threshold)

        max_filtered_value = np.max(filtered_signal)
        is_significant = max_filtered_value > significance_threshold

        session_results[roi_id] = {
            'peaks': peaks,
            'filtered_signal': filtered_signal,
            'local_threshold': local_threshold,
            'max_filtered_value': max_filtered_value,
            'is_significant': is_significant,
            'original_signal': signal,
            'template': norm_template,
        }

    return results
# Example usage: significance threshold for the CaBLAM template.
session_id = 2112242023  # Choose a session ID

# Shuffle-based null distribution from the responsive stimulation-12 trials
# of responsiveness_df_cablam, matched against cablam_template.
null_scores, threshold = create_null_distribution_and_match(
    responsiveness_df_cablam,
    cablam_template,
    stimulation_id_col='stimulation_id',
    is_responsive_col='is_responsive',
    delta_f_f_full_array_col='delta_f_f_full_array',
    stimulation_id_val=12,
)

print("Null distribution scores:", null_scores)

# Threshold = 95th percentile of the null scores, recomputed here from the
# returned scores.
threshold = np.percentile(null_scores, q=95)
print("Significance Threshold:", threshold)

In [None]:
# Run template matching for several sessions against one null-distribution threshold.
# The null distribution is built from responsiveness_df_cablam alone (it does not
# depend on the session), so it is computed once rather than once per session.
null_scores, threshold = create_null_distribution_and_match(responsiveness_df_cablam, cablam_template, stimulation_id_col='stimulation_id', is_responsive_col='is_responsive', delta_f_f_full_array_col='delta_f_f_full_array', stimulation_id_val=12)

all_results = {}  # session_id -> template_matching_vs_null() output for that session
session_ids = ['2112242023', '2212242023', '2312242023']
for session_id in session_ids:
    print(f"Processing session ID: {session_id}")
    if session_id not in cablam_filtered_responsive_rois:
        print(f"Skipping session {session_id}: no ROI data found.")
        continue
    # Store the whole return value. It is itself keyed by session_id, which the
    # reporting cell (session_results[session_id]) relies on.
    all_results[session_id] = template_matching_vs_null(session_id, cablam_filtered_responsive_rois, cablam_template, threshold)

In [None]:
# Report, per session, whether each ROI's best template match beat the null threshold.
# Each value of all_results is itself a {session_id: {roi_id: result}} dict.
for session_id, session_results in all_results.items():
    print(f"Session ID: {session_id}")
    per_roi = session_results[session_id]
    for roi_id in per_roi:
        roi_result = per_roi[roi_id]
        print(f"  ROI {roi_id}:")
        print(f"    Significant = {roi_result['is_significant']}")



In [None]:
def debug_peak_alignment(template_matching_results, session_id, window_radius, num_segments_to_plot=5, roi_id=None):
    """
    Overlay normalised signal segments around detected peaks on the template, on a shared
    time-relative-to-peak axis, for debugging one ROI.

    Parameters:
    template_matching_results (dict): ``{session_id (str): {roi_id: result}}`` where each
        result has 'original_signal', 'peaks' and 'template'.
    session_id (int or str): Session to inspect; looked up as ``str(session_id)``.
    window_radius (int): Number of samples to include on either side of each peak.
    num_segments_to_plot (int): Maximum number of peaks (taken from the start of
        ``peaks``) to consider. Peaks too close to the signal edges are skipped.
        When peaks come from ``template_matching`` they are every above-threshold sample,
        so consecutive entries often belong to the same event.
    roi_id (hashable, optional): ROI to inspect. Defaults to the first ROI of the session.
    """
    session_id = str(session_id)
    session_results = template_matching_results[session_id]
    if not session_results:
        print(f'No ROI results for session {session_id}; nothing to plot.')
        return
    if roi_id is None:
        roi_id = next(iter(session_results))
    roi_results = session_results[roi_id]

    def _normalize(values):
        # Min-max scale to [0, 1]; a flat input has zero range, so map it to zeros.
        values = np.asarray(values, dtype=float)
        span = np.max(values) - np.min(values)
        return np.zeros_like(values) if span == 0 else (values - np.min(values)) / span

    template = np.asarray(roi_results['template'], dtype=float)
    # With fftconvolve(signal, template[::-1], mode='same'), a filter output at index i
    # aligns template sample k with signal index i - len(template)//2 + k. Centring the
    # template at 0 puts it on the same axis as the segments below.
    template_x = np.arange(len(template)) - len(template) // 2

    plt.figure(figsize=(10, 6))
    plt.plot(template_x, _normalize(template), label='Template', color='black', linewidth=2, alpha=0.7)

    original_signal = roi_results['original_signal']
    offsets = np.arange(-window_radius, window_radius)
    for peak_idx, peak in enumerate(roi_results['peaks'][:num_segments_to_plot]):
        # Skip windows that would run past either end of the signal.
        if peak - window_radius < 0 or peak + window_radius > len(original_signal):
            continue
        segment = original_signal[peak - window_radius:peak + window_radius]
        plt.plot(offsets, _normalize(segment), alpha=0.5, label=f'Segment at Peak {peak_idx+1}')

    plt.title(f'Individual Detected Events vs Template for ROI {roi_id}')
    plt.xlabel('Time Relative to Peak')
    plt.ylabel('Normalized Signal Amplitude')
    plt.legend()
    plt.show()
    
# Example: overlay up to five detected events of session 2112242023 on the template,
# with 5 samples either side of each peak.
debug_peak_alignment(template_matching_results, session_id=2112242023, window_radius=5)



In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_peak_matches_from_results(template_matching_results, session_id, window_radius):
    """
    For every ROI of a session, overlay the signal around each detected peak on the
    template, all on a shared time-relative-to-peak axis (one figure per ROI).

    Parameters:
    template_matching_results (dict): ``{session_id (str): {roi_id: result}}`` where each
        result has 'original_signal', 'peaks' and 'template'.
    session_id (int or str): Session to plot; looked up as ``str(session_id)``.
    window_radius (int): Number of samples to include on either side of each peak.
        Peaks too close to the signal edges are skipped.
    """
    session_results = template_matching_results[str(session_id)]
    offsets = np.arange(-window_radius, window_radius)

    for roi_id, roi_results in session_results.items():
        original_signal = roi_results['original_signal']
        peaks = roi_results['peaks']
        template = np.asarray(roi_results['template'], dtype=float)
        # Centre the template at 0. For fftconvolve(signal, template[::-1], mode='same'),
        # output index i aligns template sample k with signal index i - len(template)//2 + k.
        template_x = np.arange(len(template)) - len(template) // 2
        template_max = np.max(template)

        plt.figure(figsize=(14, 7))
        plt.title(f'ROI {roi_id} Detected Peaks vs Template')
        plt.plot(template_x, template, label='Template', color='black', linewidth=2)

        for peak in peaks:
            if peak - window_radius < 0 or peak + window_radius > len(original_signal):
                continue
            segment = np.asarray(original_signal[peak - window_radius:peak + window_radius], dtype=float)
            span = segment.max() - segment.min()
            # Min-max scale to [0, template max]; a flat segment is drawn as zeros.
            scaled = np.zeros_like(segment) if span == 0 else (segment - segment.min()) / span * template_max
            plt.plot(offsets, scaled, alpha=0.5)

        plt.xlabel('Time Relative to Peak')
        plt.ylabel('Signal Amplitude')
        plt.legend()
        plt.show()

# Example: one figure per ROI of session '2112242023', showing 50 samples
# either side of each detected peak.
plot_peak_matches_from_results(template_matching_results, session_id='2112242023', window_radius=50)


In [None]:
def plot_responsive_vs_nonresponsive_histogram(dataframe, metric, bins=30, alpha=0.5, title=None):
    """
    Plot overlaid histograms of ``metric`` for responsive vs. non-responsive trials.

    Both histograms use the same bin edges, computed from the combined non-NaN values,
    so bar heights are directly comparable.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'is_responsive' and ``metric``.
    metric (str): Column to histogram.
    bins (int or sequence): Number of equal-width bins over the combined data range,
        or explicit bin edges.
    alpha (float): Transparency of the bars.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    responsive_units = dataframe[dataframe['is_responsive'] == True]
    non_responsive_units = dataframe[dataframe['is_responsive'] == False]

    responsive_values = responsive_units[metric].dropna().to_numpy(dtype=float)
    non_responsive_values = non_responsive_units[metric].dropna().to_numpy(dtype=float)
    # Shared edges: independent per-group binning would make the overlay misleading.
    bin_edges = np.histogram_bin_edges(np.concatenate([responsive_values, non_responsive_values]), bins=bins)

    plt.figure(figsize=(10, 6))
    plt.hist(responsive_values, bins=bin_edges, alpha=alpha, label='Responsive Trials')
    plt.hist(non_responsive_values, bins=bin_edges, alpha=alpha, label='Non-responsive Trials')

    plt.xlabel(metric)
    plt.ylabel('Count')
    plt.title(title if title else f'Distribution of {metric} for Responsive vs. Non-responsive Trials')

    # Summary text in the centre of the axes; guard the percentage against an empty frame.
    total_units = len(dataframe)
    total_responsive = len(responsive_units)
    share = f'{total_responsive / total_units:.1%}' if total_units else 'n/a'
    plt.text(0.5, 0.5, f'Total Trials: {total_units}\nResponsive: {total_responsive} ({share})',
             horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

    plt.legend()
    plt.show()
    
def plot_responsive_vs_nonresponsive_scatter(dataframe, x_metric, y_metric, title=None):
    """
    Scatter ``y_metric`` against ``x_metric``, colouring responsive and non-responsive
    trials as separate series.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'is_responsive', ``x_metric`` and ``y_metric``.
    x_metric (str): Column plotted on the x-axis.
    y_metric (str): Column plotted on the y-axis.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    responsive_flag = dataframe['is_responsive']
    series = [
        ('Responsive Trials', dataframe[responsive_flag == True]),
        ('Non-responsive Trials', dataframe[responsive_flag == False]),
    ]

    plt.figure(figsize=(10, 6))
    # Responsive first so both colours and legend order follow the default cycle.
    for series_label, rows in series:
        plt.scatter(rows[x_metric], rows[y_metric], alpha=0.7, label=series_label, edgecolors='w')

    default_title = f'Relationship between {x_metric} and {y_metric}'
    plt.xlabel(x_metric)
    plt.ylabel(y_metric)
    plt.title(title if title else default_title)
    plt.legend()
    plt.grid(True)
    plt.show()




# Example usage: responsive vs. non-responsive comparisons for every sensor
# (GCaMP8, CaBLAM, CaBLAM1x, in that order).
sensor_dfs = (responsiveness_df_gcamp8, responsiveness_df_cablam, responsiveness_df_cablam1x)

for sensor_df in sensor_dfs:
    plot_responsive_vs_nonresponsive_histogram(sensor_df, 'post_stim_peak', bins=30)

for y_column in ('post_stim_peak', 'peak_delta_f_f_post_stim'):
    for sensor_df in sensor_dfs:
        plot_responsive_vs_nonresponsive_scatter(sensor_df, 'pre_stim_mean', y_column)


In [None]:
def plot_aggregated_responsive_vs_nonresponsive_scatter(dataframe, x_metric, y_metric, stimulus_id, title=None):
    """
    Scatter per-ROI means of two metrics for one stimulus ID, split by responsiveness,
    with a least-squares trendline per group.

    Rows are filtered to ``stimulus_id`` and grouped by ('session_id', 'roi'). A ROI counts
    as responsive if it was responsive in at least one trial. ``x_metric`` and
    ``y_metric`` are averaged per ROI.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'stimulation_id', 'session_id', 'roi',
        'is_responsive', ``x_metric`` and ``y_metric``.
    x_metric (str): Column plotted (as a per-ROI mean) on the x-axis.
    y_metric (str): Column plotted (as a per-ROI mean) on the y-axis.
    stimulus_id (int): Stimulation ID to keep.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    specific_stim_data = dataframe[dataframe['stimulation_id'] == stimulus_id]

    aggregated_data = specific_stim_data.groupby(['session_id', 'roi']).agg({
        'is_responsive': 'max',  # True if the ROI responded at least once
        x_metric: 'mean',
        y_metric: 'mean',
    }).reset_index()

    responsive_units = aggregated_data[aggregated_data['is_responsive'] == True]
    non_responsive_units = aggregated_data[aggregated_data['is_responsive'] == False]

    plt.figure(figsize=(10, 6))
    plt.scatter(responsive_units[x_metric], responsive_units[y_metric],
                alpha=0.7, label='Responsive Units', edgecolors='w', color='blue')
    plt.scatter(non_responsive_units[x_metric], non_responsive_units[y_metric],
                alpha=0.7, label='Non-responsive Units', edgecolors='w', color='red')

    # Linear trendline per group. np.polyfit needs finite values and at least two
    # distinct x values, so groups that cannot be fitted are skipped.
    for units, color in ((responsive_units, 'blue'), (non_responsive_units, 'red')):
        x = units[x_metric].to_numpy(dtype=float)
        y = units[y_metric].to_numpy(dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        x, y = x[finite], y[finite]
        if np.unique(x).size < 2:
            continue
        slope, intercept = np.polyfit(x, y, 1)
        # Draw the line between the extreme x values rather than through unsorted points.
        x_line = np.array([x.min(), x.max()])
        plt.plot(x_line, slope * x_line + intercept, color=color, linestyle='--')

    plt.xlabel('Mean ' + x_metric)
    plt.ylabel('Mean ' + y_metric)
    plt.title(title if title else f'Relationship between {x_metric} and {y_metric} for Stimulus ID {stimulus_id}')
    plt.legend()
    plt.grid(True)
    plt.show()
    
    
def plot_responsive_vs_nonresponsive_histogram_for_stimulus(dataframe, metric, stimulus_id, bins=30, alpha=0.5, title=None):
    """
    Plot overlaid filled histograms of ``metric`` for responsive vs. non-responsive units
    of one stimulus ID.

    NOTE: the notebook defines a second function with this same name immediately after
    this one, so once that definition runs it replaces this version.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'stimulation_id', 'is_responsive' and ``metric``.
    metric (str): Column to histogram.
    stimulus_id (int): Stimulation ID to keep.
    bins (int or sequence): Number of equal-width bins over the data range, or explicit edges.
    alpha (float): Transparency of the bars.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    df_filtered = dataframe[dataframe['stimulation_id'] == stimulus_id]

    responsive_units = df_filtered[df_filtered['is_responsive'] == True]
    non_responsive_units = df_filtered[df_filtered['is_responsive'] == False]

    # Shared bin edges so the two overlaid histograms are comparable.
    bin_edges = np.histogram_bin_edges(df_filtered[metric].dropna().to_numpy(dtype=float), bins=bins)

    plt.figure(figsize=(10, 6))
    plt.hist(responsive_units[metric].dropna(), bins=bin_edges, alpha=alpha, label='Responsive Units')
    plt.hist(non_responsive_units[metric].dropna(), bins=bin_edges, alpha=alpha, label='Non-responsive Units')

    plt.xlabel(metric)
    plt.ylabel('Count')
    plt.title(title if title else f'Distribution of {metric} for Stimulus ID {stimulus_id}')

    total_units = len(df_filtered)
    total_responsive = len(responsive_units)
    # Avoid ZeroDivisionError when the stimulus has no rows.
    share = f'{total_responsive / total_units:.1%}' if total_units else 'n/a'
    plt.text(0.5, 0.5, f'Total Units: {total_units}\nResponsive: {total_responsive} ({share})',
             horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

    plt.legend()
    plt.show()

# Outline-style variant. This redefinition replaces the filled-bar function of the same name above.
def plot_responsive_vs_nonresponsive_histogram_for_stimulus(dataframe, metric, stimulus_id, bins=30, title=None):
    """
    Plot overlaid outline-only histograms of ``metric`` for responsive vs. non-responsive
    units of one stimulus ID.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'stimulation_id', 'is_responsive' and ``metric``.
    metric (str): Column to histogram.
    stimulus_id (int): Stimulation ID to keep.
    bins (int or sequence): Number of equal-width bins spanning the data range, or the
        sequence of bin edges. Both groups share the same edges.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    df_filtered = dataframe[dataframe['stimulation_id'] == stimulus_id]

    # For an int, this equals np.linspace(min, max, bins + 1). A sequence of edges is
    # passed through. NaNs are dropped so they cannot poison the range.
    bin_edges = np.histogram_bin_edges(df_filtered[metric].dropna().to_numpy(dtype=float), bins=bins)

    responsive_units = df_filtered[df_filtered['is_responsive'] == True]
    non_responsive_units = df_filtered[df_filtered['is_responsive'] == False]

    plt.figure(figsize=(10, 6))
    plt.hist(responsive_units[metric].dropna(), bins=bin_edges, edgecolor='blue', linewidth=1.5, facecolor='none', label='Responsive Units')
    plt.hist(non_responsive_units[metric].dropna(), bins=bin_edges, edgecolor='orange', linewidth=1.5, facecolor='none', label='Non-responsive Units')

    plt.xlabel(metric)
    plt.ylabel('Count')
    plt.title(title if title else f'Distribution of {metric} for Stimulus ID {stimulus_id}')

    total_units = len(df_filtered)
    total_responsive = len(responsive_units)
    total_non_responsive = len(non_responsive_units)
    # Avoid ZeroDivisionError when the stimulus has no rows.
    share = f'{total_responsive / total_units:.2%}' if total_units else 'n/a'
    plt.text(0.7, 0.85, f'Total Units: {total_units}\nResponsive: {total_responsive}\nNon-responsive: {total_non_responsive}\nPercentage Responsive: {share}',
             horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes, fontsize=9)

    plt.legend(loc='upper right')
    plt.show()




def plot_histogram_with_complete_outline(dataframe, metric, stimulus_id, bins=30, responsive_color='blue', non_responsive_color='orange', title=None):
    """
    Plot semi-transparent histograms of ``metric`` for responsive and non-responsive units
    of one stimulus ID, each with a closed step outline drawn around it.

    Parameters:
    dataframe (pd.DataFrame): Must contain 'stimulation_id', 'is_responsive' and ``metric``.
    metric (str): Column to histogram.
    stimulus_id (int): Stimulation ID to keep.
    bins (int or sequence): Number of equal-width bins over the data range, or explicit
        edges. Both groups share the same edges.
    responsive_color (str): Fill/outline colour for responsive units.
    non_responsive_color (str): Fill/outline colour for non-responsive units.
    title (str, optional): Plot title. A default title is used when None.

    Returns:
    None
    """
    df_filtered = dataframe[dataframe['stimulation_id'] == stimulus_id]

    # Compare with True/False, as the sibling plotting helpers do. `~column` is a bitwise
    # NOT for non-bool (e.g. object) dtypes and would not select the complement.
    responsive = df_filtered.loc[df_filtered['is_responsive'] == True, metric].dropna().to_numpy(dtype=float)
    non_responsive = df_filtered.loc[df_filtered['is_responsive'] == False, metric].dropna().to_numpy(dtype=float)

    # Shared edges over all non-NaN values (equal-width for an int, like np.linspace(min, max, bins + 1)).
    bin_edges = np.histogram_bin_edges(df_filtered[metric].dropna().to_numpy(dtype=float), bins=bins)

    plt.figure(figsize=(10, 6))
    groups = (
        (responsive, responsive_color, 'Responsive Units'),
        (non_responsive, non_responsive_color, 'Non-responsive Units'),
    )
    for values, color, label in groups:
        plt.hist(values, bins=bin_edges, color=color, alpha=0.5, label=label)

    # Closed outline. The extra leading edge with count 0 draws the left wall, and the
    # trailing 0 drops the line back to the axis at the right edge.
    outline_x = np.concatenate(([bin_edges[0]], bin_edges))
    for values, color, _ in groups:
        counts, _ = np.histogram(values, bins=bin_edges)
        plt.step(outline_x, np.concatenate(([0], counts, [0])), where='post', color=color, linewidth=2)

    plt.xlabel(metric)
    plt.ylabel('Count')
    plt.title(title if title is not None else f'Distribution of {metric} for Stimulus ID {stimulus_id}')
    plt.legend(loc='upper right')
    plt.show()

# Example usage: per-stimulus distributions of peak post-stimulus dF/F for each sensor.
plot_responsive_vs_nonresponsive_histogram_for_stimulus(responsiveness_df_cablam, 'peak_delta_f_f_post_stim', 36, bins=20)
plot_responsive_vs_nonresponsive_histogram_for_stimulus(responsiveness_df_gcamp8, 'peak_delta_f_f_post_stim', 36, bins=20)
plot_responsive_vs_nonresponsive_histogram_for_stimulus(responsiveness_df_cablam1x, 'peak_delta_f_f_post_stim', 36, bins=20)

# Same bin count for all three sensors so the stimulus-24 distributions are comparable
# (a single bin would collapse a distribution into one bar).
plot_histogram_with_complete_outline(responsiveness_df_gcamp8, 'peak_delta_f_f_post_stim', 24, bins=15)
plot_histogram_with_complete_outline(responsiveness_df_cablam, 'peak_delta_f_f_post_stim', 24, bins=15)
plot_histogram_with_complete_outline(responsiveness_df_cablam1x, 'peak_delta_f_f_post_stim', 24, bins=15)



In [None]:
import pandas as pd

def get_sample_size_summary(df, stim_ids, label, count_responsive=False):
    """
    Count unique neurons (and optionally responsive neurons) per stimulation ID.

    A neuron is one (session_id, roi) pair. When a neuron has several rows for the
    same stimulation ID, it counts as responsive if any of those rows is responsive.

    Parameters:
    df (pd.DataFrame): Must contain 'session_id', 'roi', 'stimulation_id' and 'is_responsive'.
    stim_ids (list of int): Stimulation IDs to summarize, one output row each, in this order.
    label (str): Sensor/group name written into the 'Sensor' column (e.g. 'CaBLAM').
    count_responsive (bool): When True, also add 'n_responsive' and 'percent_responsive'
        (rounded to one decimal; 0 when a stimulation ID has no neurons).

    Returns:
    pd.DataFrame: Columns ['Sensor', 'Stim', 'n_total'], plus
    ['n_responsive', 'percent_responsive'] only when count_responsive is True.
    """
    def _summarize(stim_id):
        # One row per neuron; 'max' on a boolean means "responsive on at least one trial".
        per_neuron = (
            df.loc[df['stimulation_id'] == stim_id]
            .groupby(['session_id', 'roi'])
            .agg({'is_responsive': 'max'})
            .reset_index()
        )
        total = len(per_neuron)
        row = {'Sensor': label, 'Stim': stim_id, 'n_total': total}
        if not count_responsive:
            return row

        responsive = per_neuron['is_responsive'].sum()
        pct = (responsive / total) * 100 if total > 0 else 0
        row['n_responsive'] = int(responsive)
        row['percent_responsive'] = round(pct, 1)
        return row

    return pd.DataFrame([_summarize(stim) for stim in stim_ids])

# Stimulation IDs to summarize. count_responsive is not passed, so each summary
# only has the 'Sensor', 'Stim' and 'n_total' columns.
stim_ids = [12, 120]
summary_cablam = get_sample_size_summary(responsiveness_df_cablam, stim_ids, label='CaBLAM')
summary_cablam1x = get_sample_size_summary(responsiveness_df_cablam1x, stim_ids, label='CaBLAM1x')
summary_gcamp8s = get_sample_size_summary(responsiveness_df_gcamp8, stim_ids, label='GCaMP8s')

# Combine all sensors into one table with a fresh 0..n-1 index
summary_all = pd.concat([summary_cablam, summary_cablam1x, summary_gcamp8s], ignore_index=True)

# View
print(summary_all)

The cells below process the calcium signals and correct them for the dark signal (camera noise).

In [None]:
def process_biolumi_calcium_signal(session_id, directory_df, dark_frames=300):
    """
    Subtract a per-ROI dark offset from one session's calcium signals and save the result.

    Reads ``<directory_path>/processed_data/processed_image_analysis_output/<session_id>_calcium_signals.csv``.
    For every column whose name contains 'ROI', it subtracts the median of that column's
    first ``dark_frames`` rows and sets samples that drop below zero to NaN. The result is
    written next to the input as ``<session_id>_corrected_calcium_signals.csv``.

    Parameters:
    session_id (int or str): Session identifier. Numeric strings (e.g. '2212242023') are
        converted to int for the lookup so they match the int-typed 'session_id' column.
    directory_df (pd.DataFrame): Must contain 'session_id' and 'directory_path' columns.
        NOTE: its 'session_id' column is cast to int in place; later cells compare it
        against integer session IDs, so this side effect is kept on purpose.
    dark_frames (int, optional): Number of leading frames used to estimate the dark offset.
        Defaults to 300.

    Returns:
    pd.DataFrame or None: The corrected signals, or None if the session has no directory
    entry or its calcium-signals CSV does not exist.
    """
    processed_dir = 'processed_data/processed_image_analysis_output'
    calcium_csv_suffix = '_calcium_signals.csv'

    # Deliberate in-place cast (see docstring).
    directory_df['session_id'] = directory_df['session_id'].astype(int)

    # A numeric string would never equal an int column entry, so normalise it for the lookup.
    lookup_id = session_id
    if isinstance(session_id, str) and session_id.strip().isdigit():
        lookup_id = int(session_id)
    directory_entry = directory_df[directory_df['session_id'] == lookup_id]

    if directory_entry.empty:
        print(f"No directory entry found for session {session_id}. Please check the session_id.")
        return None

    directory_path = directory_entry['directory_path'].values[0]
    output_dir = os.path.join(directory_path, processed_dir)
    csv_path = os.path.join(output_dir, str(session_id) + calcium_csv_suffix)

    if not os.path.exists(csv_path):
        print(f"Calcium signals file not found for session {session_id}")
        return None

    calcium_signals_df = pd.read_csv(csv_path)

    # ROI columns are identified by name (assumes names like 'ROI_1').
    roi_columns = [col for col in calcium_signals_df.columns if 'ROI' in col]
    for roi in roi_columns:
        # Dark offset: median of the first `dark_frames` frames (NaNs skipped by median()).
        dark_signal_median = calcium_signals_df[roi].iloc[:dark_frames].median()
        corrected = calcium_signals_df[roi] - dark_signal_median
        # Samples below the dark level are treated as invalid and masked with NaN.
        calcium_signals_df[roi] = corrected.where(corrected >= 0)

    # Save the corrected signals alongside the input CSV.
    corrected_csv_path = os.path.join(output_dir, str(session_id) + '_corrected' + calcium_csv_suffix)
    calcium_signals_df.to_csv(corrected_csv_path, index=False)

    return calcium_signals_df

# New function to process all session IDs
def process_all_sessions(directory_df):
    """
    Run process_biolumi_calcium_signal for every distinct session in directory_df.

    Parameters:
    directory_df (pd.DataFrame): Must contain a 'session_id' column; it is passed
        through to process_biolumi_calcium_signal unchanged.

    Returns:
    None. Progress is printed before and after each session.
    """
    # Snapshot the IDs once up front: the per-session call rewrites the
    # 'session_id' column of directory_df in place.
    session_ids = directory_df['session_id'].unique()
    for current_id in session_ids:
        print(f"Processing session ID: {current_id}")
        process_biolumi_calcium_signal(current_id, directory_df)
        print(f"Completed processing for session ID: {current_id}")
    return None


# Example usage:
# Correct one session explicitly and keep the returned DataFrame for inspection in the next cell,
# then run the same correction over every session listed in directory_df.
# Assumes `analysis` is the ImageAnalysis instance created earlier in the notebook.
# NOTE(review): if 2212242023 is listed in directory_df, process_all_sessions processes it
# (and rewrites its corrected CSV) a second time.
session_id = 2212242023
biolumi_calcium_signals_df = process_biolumi_calcium_signal(session_id, analysis.directory_df)
process_all_sessions(analysis.directory_df)
                


In [None]:
# Display the corrected DataFrame returned for the example session in the previous cell
biolumi_calcium_signals_df

In [None]:
def process_all_sessions(self):
    """
    Build trial-locked calcium signals for every session in ``self.directory_df``.

    NOTE(review): this top-level definition shadows the earlier
    ``process_all_sessions(directory_df)``. It expects ``self`` to expose ``directory_df``
    and ``create_trial_locked_calcium_signals`` (presumably an ImageAnalysis instance with
    that function attached — confirm).

    Returns
    -------
    all_data : dict
        Keyed by session ID. Each value is a dict with:

        stim_frame_numbers : list of int
            Frame numbers at which stimuli were applied.
        roi_data : dict of dict
            ROI name (e.g. 'ROI_1') -> {(stimulation_id, stim_frame_number): np.ndarray}
            holding the calcium signal window around that stimulation.
        stimulation_ids : list of int
            Stimulation identifier for each stimulation event.

        Sessions for which ``create_trial_locked_calcium_signals`` returns None (missing
        directory entry or CSV) are skipped with a message instead of aborting the loop.

    Example
    -------
    >>> all_data['2312072023']['stim_frame_numbers']
    [3582, 3784, 3986, 4187, 4389, 4590, 4792, 4994, 5195, 5397]
    >>> all_data['2312072023']['roi_data']['ROI_1'][(60, 3582)]
    array([...signal values...])
    """
    all_data = {}
    for session_id in self.directory_df['session_id'].unique():
        result = self.create_trial_locked_calcium_signals(session_id)
        # The per-session function returns None on failure; unpacking None would raise TypeError.
        if result is None:
            print(f"Skipping session {session_id}: no trial-locked data available")
            continue
        stim_frame_numbers, roi_data, stimulation_ids = result
        all_data[session_id] = {
            'stim_frame_numbers': stim_frame_numbers,
            'roi_data': roi_data,
            'stimulation_ids': stimulation_ids,
        }
    return all_data

def create_trial_locked_calcium_signals(self, session_id, pre_stim_frames=10, post_stim_frames=100):
    """
    Cut each ROI's calcium trace into windows around every stimulation of one session.

    Parameters
    ----------
    self : object
        Anything exposing ``directory_df`` with the columns 'session_id', 'directory_path',
        'stimulation_frame_number' (list of frames) and 'stimulation_ids' (list of IDs),
        e.g. the notebook's ImageAnalysis instance.
    session_id : int or str
        Compared with ``self.directory_df['session_id']``. ``str(session_id)`` builds the
        CSV filename ``<session_id>_calcium_signals.csv``.
    pre_stim_frames : int, optional
        Frames before the stimulation frame to include. Defaults to 10.
    post_stim_frames : int, optional
        Frames after the stimulation frame to include. Defaults to 100.

    Returns
    -------
    tuple or None
        ``(stim_frame_numbers, roi_data, stimulation_ids)``, or None when the session has
        no directory entry or its calcium-signals CSV is missing.

        roi_data maps each ROI column name (columns containing 'ROI', e.g. 'ROI_1') to a
        dict keyed by ``(stimulation_id, stim_frame_number)``. Each value is an int NumPy
        array covering frames ``stim_frame - pre_stim_frames`` through
        ``stim_frame + post_stim_frames`` inclusive, because ``.loc`` slicing is
        label-inclusive. Windows are clipped at the start and end of the recording, so
        windows near the edges are shorter.

        stimulation_ids and stim_frame_numbers are paired element-wise (via ``zip``).
        process_all_sessions gathers these tuples per session into one dict.
    """
    processed_dir = 'processed_data/processed_image_analysis_output'
    calcium_csv_suffix = '_calcium_signals.csv'

    directory_entry = self.directory_df[self.directory_df['session_id'] == session_id]

    # Check before indexing: .values[0] on an empty selection raises IndexError.
    if directory_entry.empty:
        print(f"No directory entry found for session {session_id}")
        return

    stim_frame_numbers = directory_entry['stimulation_frame_number'].values[0]
    stimulation_ids = directory_entry['stimulation_ids'].values[0]

    directory_path = directory_entry['directory_path'].values[0]
    # str() so integer session IDs can build the filename as well as string ones.
    csv_path = os.path.join(directory_path, processed_dir, str(session_id) + calcium_csv_suffix)

    if not os.path.exists(csv_path):
        print(f"Calcium signals file not found for session {session_id}")
        return

    # Convert all values to integers (no decimal points)
    calcium_signals_df = pd.read_csv(csv_path).astype(int)

    roi_columns = [col for col in calcium_signals_df.columns if 'ROI' in col]
    roi_data = {roi: {} for roi in roi_columns}
    n_frames = len(calcium_signals_df)

    for stim_id, stim_frame in zip(stimulation_ids, stim_frame_numbers):
        # Clip the window to the recording; .loc end bound is inclusive.
        start_idx = max(stim_frame - pre_stim_frames, 0)
        end_idx = min(stim_frame + post_stim_frames, n_frames)
        window = calcium_signals_df.loc[start_idx:end_idx, roi_columns]
        for roi in roi_columns:
            roi_data[roi][(stim_id, stim_frame)] = window[roi].to_numpy().astype(int)

    return stim_frame_numbers, roi_data, stimulation_ids

# Example usage:
# Build trial-locked windows for a single session. `analysis` is the ImageAnalysis
# instance from an earlier cell and is passed explicitly as `self`, because
# create_trial_locked_calcium_signals is defined here as a free function.
session_id = 2212242023
stim_frame_numbers, roi_data, stimulation_ids = create_trial_locked_calcium_signals(analysis, session_id)





BELOW HERE ARE OLDER FUNCTIONS -- these still need to be sorted through.

In [None]:
def plot_mean_response(roi_data, stim_id=12, pre_stim_frames=10):
    """
    Plot the mean ± SEM calcium response across ROIs for one stimulation ID.

    Each ROI's trials with ``stim_id`` are first averaged. The mean and SEM (ddof=1) are
    then taken across ROIs, so the result is always 1-D, one value per frame, however
    many trials exist. With a single trial per ROI this is the same as averaging the
    raw windows directly.

    Parameters:
    roi_data (dict): {roi_name: {(stimulation_id, stim_frame_number): np.ndarray}}, as
        produced by create_trial_locked_calcium_signals. All selected windows must have
        the same length, and every ROI must have the same number of trials for stim_id.
    stim_id (int): Stimulation ID to plot. Defaults to 12.
    pre_stim_frames (int): Number of frames before the stimulus in each window; sets the
        x-axis so that the stimulus frame is 0. Defaults to 10.

    Returns:
    tuple of np.ndarray: (mean_response, sem_response), each 1-D with one value per frame.

    Raises:
    ValueError: If no ROI has a trial for stim_id.
    """
    selected_data = {}
    for roi, stim_data in roi_data.items():
        trials = [data for stim_key, data in stim_data.items() if stim_key[0] == stim_id]
        if trials:
            selected_data[roi] = np.stack(trials, axis=0)

    if not selected_data:
        raise ValueError(f"No trials found for stim_id {stim_id}")

    # (n_rois, n_trials, n_frames) -> average trials per ROI -> (n_rois, n_frames)
    all_roi_data = np.stack(list(selected_data.values()), axis=0)
    per_roi_mean = all_roi_data.mean(axis=1)
    mean_response = per_roi_mean.mean(axis=0)
    sem_response = np.std(per_roi_mean, axis=0, ddof=1) / np.sqrt(per_roi_mean.shape[0])

    # Derive the time axis from the data so it matches any window length.
    time_points = np.arange(-pre_stim_frames, -pre_stim_frames + mean_response.shape[0])

    plt.figure(figsize=(10, 5))
    plt.plot(time_points, mean_response, label='Mean Response')
    plt.fill_between(time_points, mean_response - sem_response, mean_response + sem_response, alpha=0.3, label='SEM')

    # Red dashed marker at frame -1 (kept from the original; confirm whether 0 was intended).
    plt.axvline(x=-1, color='red', linestyle='--', linewidth=1)

    plt.xlabel('Frame Number (relative to stimulus)')
    plt.ylabel('Calcium Signal')
    plt.title(f'Mean Calcium Response for stim_id {stim_id}')
    plt.legend()
    plt.show()

    return mean_response, sem_response

# One mean-response plot per entry of `stimulation_ids` (from the trial-locking cell above).
# NOTE(review): stimulation_ids is paired element-wise with stim_frame_numbers, so if it
# holds one entry per stimulation event, repeated IDs produce duplicate plots — consider
# iterating over unique IDs.
for stimulation_id in stimulation_ids:
    plot_mean_response(roi_data, stim_id=stimulation_id)


In [None]:
# Plot the calcium signals for all sessions in directory_df, using the dark-corrected data.
# There is no need to re-run this if the analysis has already been done.
#results = analysis.plot_all_sessions_calcium_signals()
results = analysis.plot_all_sessions_calcium_signals(use_corrected_data=True)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import fftconvolve

# Synthetic matched-filter demo: a noisy step "event" is correlated with a template
# cut from the same signal, and samples above a threshold are flagged.

# Parameters
signal_length = 1000  # Length of the synthetic calcium signal
template_length = 50  # Length of the template (NOTE(review): unused; the slices below hard-code 100:150)
noise_level = 0.5  # Level of Gaussian noise to add to the signal

# Create synthetic calcium signal with noise
np.random.seed(0)
calcium_signal = np.zeros(signal_length)
calcium_signal[100:150] = 1  # Simulate a 1AP signal
calcium_signal += noise_level * np.random.randn(signal_length)  # Add noise

# Create template from a segment of the calcium signal
# NOTE(review): the template is taken after noise is added, so it is not a clean event shape.
template = calcium_signal[100:150]

# Normalize the template
template_norm = template / np.linalg.norm(template)

# Apply matched filtering
# Convolving with the reversed template is a cross-correlation. With mode='same' the
# response to the event at 100:150 is expected to peak near its centre (about index 125),
# not at its onset.
filtered_signal = fftconvolve(calcium_signal, template_norm[::-1], mode='same')

# Detect peaks (threshold at 1.5 times the standard deviation of the filtered signal)
# NOTE(review): this returns every sample above threshold, not local maxima, so one event
# usually yields many indices.
threshold = 1.5 * np.std(filtered_signal)
peaks = np.where(filtered_signal > threshold)[0]

# Plotting
plt.figure(figsize=(12, 6))
plt.subplot(3, 1, 1)
plt.plot(calcium_signal, label='Calcium Signal')
plt.legend()

plt.subplot(3, 1, 2)
plt.plot(template, label='Template')
plt.legend()

plt.subplot(3, 1, 3)
plt.plot(filtered_signal, label='Filtered Signal')
plt.plot(peaks, filtered_signal[peaks], 'ro', label='Detected Peaks')
plt.axhline(y=threshold, color='r', linestyle='--', label='Threshold')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Generate a ground truth (for demonstration purposes, assume 1AP occurs at index 100)
# NOTE(review): matching below is by exact sample index. A detection counts as a true positive
# only if index 100 itself is above threshold, and every other above-threshold sample counts as
# a false positive. With the 'same'-mode correlation above, the response likely centres near
# index 125, so these rates are a rough illustration only.
ground_truth_peaks = [100]

# Calculate True Positives, False Positives, and False Negatives
TP = len(set(ground_truth_peaks) & set(peaks))
FP = len(set(peaks) - set(ground_truth_peaks))
FN = len(set(ground_truth_peaks) - set(peaks))

# Calculate True Negatives (assuming all other points are TN)
# Total number of points - (TP + FP + FN)
TN = len(calcium_signal) - (TP + FP + FN)

# Calculate False Positive Rate (FPR) and True Positive Rate (TPR)
FPR = FP / (FP + TN)
TPR = TP / (TP + FN)

print(f"False Positive Rate (FPR): {FPR}")
print(f"True Positive Rate (TPR): {TPR}")

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_calcium_signals(session_id)  # Replace with your actual session ID

In [None]:
# Save per-ROI plots for the module-level `session_id` set in an earlier cell.
analysis.save_individual_roi_plots(session_id)  # Replace with your actual session ID

In [None]:
# Plot the individual ROI signals for all sessions in directory_df.
# The active call uses the raw (uncorrected) data; the alternatives are kept commented out.
#analysis.save_individual_roi_plots_all_sessions()
analysis.save_individual_roi_plots_all_sessions(use_corrected_data=False)
#analysis.save_individual_roi_plots_all_sessions(use_corrected_data=True)

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_roi_with_zoomed_stimulations(session_id)

In [None]:
# Plot the calcium signals for one session in directory_df, showing the full trace and
# zoomed-in traces around the stimulations.
analysis.plot_and_save_roi_stimulations(session_id)  # Replace with your actual session ID

In [None]:
# Plot the calcium signals for all sessions in directory_df, showing the full trace and
# zoomed-in traces around the stimulations.
analysis.plot_and_save_roi_stimulations_all_sessions()

In [None]:
# Same pre/post durations and threshold are reused by the responsive-ROI cells below.
analysis.find_responsive_rois_first_stim_mean(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_responsive_rois_around_stim(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  # Replace with your actual session ID

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_mean_and_sem_of_responsive_rois(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  # Replace with your actual session ID

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_normalized_mean_and_sem_of_responsive_rois(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  # Replace with your actual session ID

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_normalized_mean_and_sem_of_all_stims(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  # Replace with your actual session ID

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_mean_responses_from_file(session_id)  # Replace with your actual session ID

In [None]:
# Uses the module-level `session_id` set in an earlier cell.
analysis.plot_overlaid_normalized_responses(session_id, pre_stim_duration=3, post_stim_duration=3, threshold=2)  # Replace with your actual session ID

In [None]:
# Since the folders are directly inside the project_folder, you don't need to append any subdirectory name
# list_files('') walks the project folder recursively (os.walk), so this returns the full paths
# of every file in every subdirectory, not just the top level.
data_files = analysis.list_files('') 
print(data_files) #print the list of files in the project folder 

# Only the immediate subdirectories of the project folder
directories = analysis.list_directories()
print(directories) #print the list of directories in the project folder

In [None]:
# Assuming analysis is an instance of ImageAnalysis
# NOTE(review): iloc[3] selects the fourth row of directory_df, not the first, despite the name.
first_row = analysis.directory_df.iloc[3]
directory_path = first_row['directory_path']

# Automatically generate file paths based on the directory path
# NOTE(review): these two paths are only printed; the dark image below is built from the
# hard-coded tiff_path instead.
dark_frames_path = os.path.join(directory_path, "dark_frames.tiff")
raw_image_path = os.path.join(directory_path, "raw_image.tiff")

print("dark_frames_path:", dark_frames_path)
print("raw_image_path:", raw_image_path)

# Absolute path to a specific recording on an external drive; edit for your machine.
tiff_path = '/Volumes/MannySSD/cablam_imaging/raw_data_for_analysis/c11_12232023_estim_10hz_1xfz/c11_12232023_estim_10hz_1xfz_biolumi_combined.tif'

# Generate the dark image
dark_image = analysis.generate_dark_image(tiff_path) #generates a dark image from the first 200 frames of the tiff file 

In [None]:
import matplotlib.pyplot as plt
# Display the dark image (produced by generate_dark_image in the previous cell) in grayscale, without axes
plt.imshow(dark_image, cmap='gray')
plt.axis('off')
plt.show()


In [None]:
# Plot the calcium signals
# NOTE(review): `num_rois` and `calcium_signals` are defined in the next cell; run that first,
# or this cell raises NameError (or plots stale values).
plt.figure(figsize=(10, 6))
for roi in range(num_rois):
    plt.plot(calcium_signals[roi], label=f'ROI {roi + 1}')

plt.xlabel('Time (frames)')
plt.ylabel('Mean Intensity')
plt.title('Calcium Signals Over Time')
plt.legend()
plt.show()


In [None]:

# Assuming 'tiff_path' contains the path to your time series TIFF file
# and 'labeled_image' is your ROI mask loaded as a numpy array
# (labeled_image is not defined in this part of the notebook — confirm where it comes from)
time_series = io.imread(tiff_path)  # This should be a 3D numpy array (time, y, x)
num_rois = np.max(labeled_image)
num_frames = time_series.shape[0]

# Initialize an array to hold the calcium signal data
calcium_signals = np.zeros((num_rois, num_frames))

# Process each frame to extract ROI signals
# NOTE(review): the ROI masks are recomputed for every frame; building them once before the
# frame loop would avoid num_frames * num_rois full-image comparisons.
# NOTE(review): this loop rebinds the global `roi_data` (the trial-locked dict from an earlier
# cell) to a pixel array, which breaks later cells that expect the dict.
for t in range(num_frames):
    frame = time_series[t]
    for roi in range(1, num_rois + 1):  # ROIs are labeled from 1 to num_rois
        roi_mask = labeled_image == roi
        roi_data = frame[roi_mask]
        calcium_signals[roi - 1, t] = np.mean(roi_data) if roi_data.size > 0 else 0

# Plotting the calcium signals
offset = 10  # Change this value to adjust the vertical spacing between ROIs
plt.figure(figsize=(15, 8))
for roi in range(num_rois):
    plt.plot(calcium_signals[roi] + offset * roi, label=f'ROI {roi + 1}')  # Offset each ROI signal

plt.xlabel('Time')
plt.ylabel('ROI')
plt.title('Timeseries of ROIs')
#plt.yticks(ticks=np.arange(num_rois) * offset, labels=np.arange(1, num_rois + 1))  # Set y-ticks to show ROI IDs
plt.grid(True)
plt.show()


In [None]:
# Variant of the previous cell: empty ROIs yield NaN instead of 0, and y-ticks label each ROI.
# NOTE(review): like the previous cell, this rebinds the global `roi_data` and recomputes
# each ROI mask for every frame.
# Load the time series TIFF file from the given path
time_series = io.imread(tiff_path)  # 3D numpy array: (time, y, x)
num_rois = np.max(labeled_image)    # Assuming labeled_image is already defined as shown before
num_frames = time_series.shape[0]

# Initialize an array to hold the calcium signal data for each ROI over time
calcium_signals = np.zeros((num_rois, num_frames))

# Process each frame to extract ROI signals
for t in range(num_frames):
    frame = time_series[t]
    for roi in range(1, num_rois + 1):  # ROI labels start from 1
        roi_mask = labeled_image == roi
        roi_data = frame[roi_mask]
        calcium_signals[roi - 1, t] = np.mean(roi_data) if np.any(roi_mask) else np.nan

# Plotting the calcium signals
plt.figure(figsize=(20, 10))  # Adjust the figure size as necessary

# Define vertical offset between lines to ensure clear separation
vertical_offset = 10  # Change as needed to match the plot scale and ROI separation

# Iterate over the ROIs to plot each one with an offset
for roi_idx in range(num_rois):
    plt.plot(calcium_signals[roi_idx] + (vertical_offset * roi_idx), label=f'ROI {roi_idx + 1}')

# Set the y-ticks to correspond to the ROIs
# Here, we create a list of y-tick positions based on the number of ROIs and the vertical offset
plt.yticks(ticks=np.arange(num_rois) * vertical_offset, labels=np.arange(1, num_rois + 1))

plt.xlabel('Time (frames)')
plt.ylabel('ROI')
plt.title('Timeseries of ROIs')
plt.grid(True)  # Include grid for better readability

# Optional: Adjust the limits of the y-axis if needed to fit your data range
plt.ylim(-5, (num_rois - 1) * vertical_offset + 15)

# Optional: If you want to show a legend mapping colors to ROI IDs
# plt.legend(loc='upper right')

plt.show()


In [None]:
# NOTE(review): in a notebook only the last expression of a cell is displayed, so the first two
# .head() calls below produce no output; only stimulation_data_no_header.head() is shown.
# The /mnt/data paths are environment-specific.
# Load and display the first few rows of the CSV file to understand its structure
file_path = '/mnt/data/c11_12232023_estim_10hz_1xfz_biolumi_combined_calcium_signals.csv'
calcium_data = pd.read_csv(file_path)

calcium_data.head()


# Load and display the first few rows of the second CSV file to understand its structure
stimulation_file_path = '/mnt/data/c11_12232023_estim_10hz_1xfz_biolumi.csv'
stimulation_data = pd.read_csv(stimulation_file_path)

stimulation_data.head()

# Re-load the data assuming there is no header and display it again to understand its structure
stimulation_data_no_header = pd.read_csv(stimulation_file_path, header=None)
stimulation_data_no_header.head()

In [None]:
import math

# Modified ROI plotting function: each ROI's calcium signal is drawn in its own subplot with a white background and no grid lines.

def plot_roi_signals_no_grid(calcium_data, stimulation_frames, num_rois=46):
    """
    Plot each ROI's calcium trace in its own subplot (5 per row) with stimulation markers.

    Parameters:
    calcium_data (pd.DataFrame): Must contain a 'Frame' column and the columns
        'ROI_1' ... 'ROI_<num_rois>'.
    stimulation_frames (array-like or pd.DataFrame): Stimulation frame numbers; flattened
        before use, so a one-row/one-column DataFrame read with header=None works.
    num_rois (int): Number of ROI columns to plot. Defaults to 46.

    Each trace gets a random colour from np.random.rand. Stimulations are drawn as red dotted
    vertical lines, and axes left over in the last row are hidden.
    """
    n_cols = 5
    num_rows = math.ceil(num_rois / n_cols)

    # Use the argument; the previous version ignored it and read the global
    # `stimulation_data_no_header` instead.
    stimulation_points = np.asarray(stimulation_frames).flatten()

    # squeeze=False keeps `axs` 2-D even when everything fits in a single row.
    fig, axs = plt.subplots(num_rows, n_cols, figsize=(25, 5 * num_rows), facecolor='white', squeeze=False)
    for i in range(num_rois):
        ax = axs[i // n_cols, i % n_cols]
        roi_name = f'ROI_{i+1}'
        random_color = np.random.rand(3,)
        ax.plot(calcium_data['Frame'], calcium_data[roi_name], label=roi_name, color=random_color)
        for stim_point in stimulation_points:
            ax.axvline(x=stim_point, color='red', linestyle='dotted')
        ax.set_title(roi_name)
        ax.set_xlabel('Frame')
        ax.set_ylabel('Calcium Signal')
        ax.set_facecolor('white')
        ax.grid(False)  # Disable grid lines

    # Hide the empty axes when num_rois is not a multiple of n_cols.
    for j in range(num_rois, num_rows * n_cols):
        axs[j // n_cols, j % n_cols].axis('off')

    plt.tight_layout()
    plt.show()

# Call the function to display the plot without grid lines
# Grid of per-ROI plots using the CSVs loaded in the inspection cell above.
plot_roi_signals_no_grid(calcium_data, stimulation_data_no_header)


In [None]:
def plot_all_rois_aligned(calcium_data, stimulation_frames, num_rois=46):
    """
    Plot all ROI traces on one axis, stacked with a fixed vertical offset, with stimulation markers.

    Parameters:
    calcium_data (pd.DataFrame): Must contain a 'Frame' column and the columns
        'ROI_1' ... 'ROI_<num_rois>'.
    stimulation_frames (array-like or pd.DataFrame): Stimulation frame numbers; flattened before use.
    num_rois (int): Number of ROI columns to plot. Defaults to 46.

    Trace i is shifted up by i * 100. Its y-tick sits at the median of the shifted trace,
    and each trace gets a random colour from np.random.rand.
    """
    # Use the argument; the previous version read the global `stimulation_data_no_header`.
    stimulation_points = np.asarray(stimulation_frames).flatten()

    plt.figure(figsize=(20, 15))

    y_ticks = []
    y_tick_labels = []

    # Fixed spacing between consecutive ROI traces; adjust to the signal amplitude.
    fixed_offset = 100

    for i in range(num_rois):
        roi_name = f'ROI_{i+1}'
        random_color = np.random.rand(3,)
        shifted = calcium_data[roi_name] + i * fixed_offset
        plt.plot(calcium_data['Frame'], shifted, color=random_color, label=roi_name)
        # nanmedian so NaN-masked samples (e.g. dark-corrected data) don't yield a NaN tick.
        y_ticks.append(np.nanmedian(shifted))
        y_tick_labels.append(roi_name)

    for stim_point in stimulation_points:
        plt.axvline(x=stim_point, color='red', linestyle='dotted', linewidth=1)

    plt.xlabel('Frame')
    plt.ylabel('ROI')
    plt.yticks(y_ticks, y_tick_labels)
    plt.title('All ROIs Aligned with Corresponding Data')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Move the legend outside of the plot
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to make room for the legend
    plt.show()

# Call the function to display the plot with correctly aligned ROIs and their data
# Stacked view of all ROIs using the CSVs loaded in the inspection cell above.
plot_all_rois_aligned(calcium_data, stimulation_data_no_header)


In [None]:
def plot_roi_aligned_extended_frames(calcium_data, stimulation_frames, num_rois=46, frames_before_stim=1000):
    """
    Plot stacked ROI traces from shortly before the first stimulation to the end of the recording.

    Parameters:
    calcium_data (pd.DataFrame): Must contain a 'Frame' column and the columns
        'ROI_1' ... 'ROI_<num_rois>'.
    stimulation_frames (array-like or pd.DataFrame): Stimulation frame numbers; flattened
        before use (previously required a pandas object with `.values`).
    num_rois (int): Number of ROI columns to plot. Defaults to 46.
    frames_before_stim (int): Frames shown before the first stimulation. The window is
        clipped at frame 0. Defaults to 1000.

    Trace i is shifted up by i * 100 and gets a random colour from np.random.rand. Only
    stimulations inside the plotted window are marked.
    """
    stimulation_points = np.asarray(stimulation_frames).flatten()
    first_stim_frame = np.min(stimulation_points)

    # Plot from (first_stim_frame - frames_before_stim), clipped at 0, to the last frame.
    start_frame = max(0, first_stim_frame - frames_before_stim)
    end_frame = calcium_data['Frame'].max()

    in_window = (calcium_data['Frame'] >= start_frame) & (calcium_data['Frame'] <= end_frame)
    limited_data = calcium_data[in_window]

    plt.figure(figsize=(20, 15))

    y_ticks = []
    y_tick_labels = []

    fixed_offset = 100  # Vertical spacing between consecutive ROI traces

    for i in range(num_rois):
        roi_name = f'ROI_{i+1}'
        random_color = np.random.rand(3,)
        shifted = limited_data[roi_name] + i * fixed_offset
        plt.plot(limited_data['Frame'], shifted, color=random_color, label=roi_name)
        # nanmedian so NaN-masked samples don't yield a NaN tick position.
        y_ticks.append(np.nanmedian(shifted))
        y_tick_labels.append(roi_name)

    # Add stimulation markers within the range
    for stim_point in stimulation_points:
        if start_frame <= stim_point <= end_frame:
            plt.axvline(x=stim_point, color='red', linestyle='dotted', linewidth=1)

    plt.xlabel('Frame')
    plt.ylabel('ROI')
    plt.yticks(y_ticks, y_tick_labels)
    plt.xlim(start_frame, end_frame)
    plt.title('Aligned ROIs with Extended Frame Range')
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))  # Move the legend outside of the plot
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to make room for the legend
    plt.show()

# Call the function with the extended frame range
# Stacked view starting 1000 frames (the default) before the first stimulation.
plot_roi_aligned_extended_frames(calcium_data, stimulation_data_no_header)
