In [262]:
import re
import h5py
from functools import lru_cache
from pathlib import Path
from typing import Tuple, Optional
import logging
from scipy.io import loadmat
import numpy as np
from typing import Callable
from scipy.spatial.distance import pdist, squareform


# Logging setup: timestamped records at INFO level and above go through a
# stream handler; the module-level logger is named after this module.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)

logger = logging.getLogger(__name__)


class DataLoader:
    def __init__(self, 
    cells_datapath: str=None | Path,
    cells_positionsdataKey: str='data',
    cells_positionsKey: str=None, 
    cells_dholderkey: str='CellResp', 
    cells_positions_path: str = None | Path,
    eliminateNeurons: bool =True,
    simplify_cells: bool =True,
    number_of_cells: Optional[int] = None,
    sessions: Optional[int] = None,
    ):

        self.cells_datapath = Path(cells_datapath)
        self.cells_positionsKey = cells_positionsKey or 'CellXYZ'
        self.cells_dholderkey = cells_dholderkey
        self.cells_positions_path = Path(cells_positions_path) if cells_positions_path else None
        self.number_of_cells = number_of_cells
        self.sessions = sessions
        self.cells_positionsdataKey = cells_positionsdataKey

        # Validation 
        if not self.cells_datapath.exists():
            raise FileNotFoundError(f"Cell data file not found: {self.cells_datapath}")
        self.eliminateNeurons=eliminateNeurons
        self.eliminateNeuronsKey = 'IX_inval_anat' if eliminateNeurons else None
        self.simplify_cells = simplify_cells

        # cache loaded data 
        self._cached_positons = None
        self._cached_responses = None
        self._cache_size = None

        self._response_cache = {}
        self._max_response_cache_size = 5


    @lru_cache(maxsize=1)
    def _load_cell_positions_cached(self) -> np.ndarray:

        if not self.cells_positions_path:
            raise ValueError("cells_positions_path must be provided")
        
        if not self.cells_positions_path.exists():
            raise FileNotFoundError(f"Positions files not foundf")
        
        try:
            tmpdata = loadmat(self.cells_positions_path, 
                                simplify_cells=self.simplify_cells)
            logging.debug(f"Loaded cell positions from {tmpdata}")
        except Exception as e:
            raise IOError(f"Failed to load cell position file {self.cells_positions_path}: {e}")
        if self.cells_positionsdataKey not in tmpdata:
            raise KeyError(f"Key '{self.cells_positionsdataKey}' not found in positions file")
        
        if self.eliminateNeuronsKey not in tmpdata[self.cells_positionsdataKey]:
            raise KeyError(f"Key '{self.eliminateNeuronsKey}' not found in position data")

        roi_positions_to_discard = tmpdata[self.cells_positionsdataKey][self.eliminateNeuronsKey]
        all_roi_positions =  tmpdata[self.cells_positionsdataKey][self.cells_positionsKey]
        
        if self.eliminateNeurons:
            mask = np.ones(len(all_roi_positions), dtype=bool)
            mask[roi_positions_to_discard] = False
            used_rois_positions = all_roi_positions[mask]
        
        else:
            used_rois_positions = all_roi_positions
        
        logger.info(f"Loaded {len(used_rois_positions)} cell postions (cached)")

        return used_rois_positions

        
    def _load_cell_responses(self) -> np.ndarray:
        """
        Load cell responses from h5 file

        Args:
            number_of_cells: Number of cells to load. If None, loads all the cells
            sessions: length of time window of responses to load. If None, loads all the sessions

        Returns:
            np.ndarry: Cell responses for data

        """

        cache_key = self.number_of_cells

        # check cache
        if cache_key in self._response_cache:
            logger.debug(f" Using cached responses for {self.number_of_cells, self.sessions} cells")
            return self._response_cache[cache_key]
        
        with h5py.File(str(self.cells_datapath), 'r') as tmpdata:
            responses = tmpdata[self.cells_dholderkey][:self.number_of_cells, :self.sessions]
        
        if len(self._response_cache) >= self._max_response_cache_size:
            # Remove oldest entry (FIFO)
            first_key = next(iter(self._response_cache))
            del self._response_cache[first_key]
            logger.debug(f"Evicted cache entry for {first_key} cells")
        
        self._response_cache[cache_key] = responses
        logger.info(f"Loaded and cached responses for {len(responses)}: {self.number_of_cells} cells, {self.sessions} sessions")

        return responses
        


    def clear_response_cache(self):
       """ Clear response cache """
       self._response_cache.clear()
       logger.info("Response cache cleared") 
    

    def get_cache_stats(self):
        """ Get cache stats """
        return {
            "position_cache": self.cache_info(),
            "response_cache": len(self._response_cache),
            "response_cache_keys": list(self._response_cache.keys())
        }


    def _load_cell_positions(self) -> np.ndarray:
        """Load 3D cell positions from MATLAB file"""
        logger.info(f"Loading cell positions for {self.number_of_cells or 'all'} cells")

        positions = self._load_cell_positions_cached()[:self.number_of_cells, :].T

        logger.info(f"Returning positions for {self.number_of_cells or 'all'} cells")
        return positions


    def load_cell_responses_and_positions(self) -> tuple[np.ndarray, np.ndarray]:
        responses = self._load_cell_responses()
        positions = self._load_cell_positions()
        return responses, positions 


    def clear_cache(self):

        self._load_cell_positions_cached.cache_clear()
        logger.info("Positions cache cleared")


    def cache_info(self):
        """
            Get cache stats
        """
        return self._load_cell_postions_cached.cache_info()



    @classmethod
    def from_zebrafish_folder(cls,
                        zebfolderpath: str = None | Path,
                        zfolder: str = None | Path ,
                        pattern: str = 'z.*f\d+',
                        rmfname : str = 'README_LICENSE.rtf',
                        sortkey = None,
                        reverse: bool = False,
                    
        ) -> Tuple[str, Path, Path]:
        """ Constructor for loading from zebrafish data folders

        Args:
          zebfolderpath: Path to parent folder containing zebrafish data subfolders.
                         Will search for a subfolder matching the pattern.
          zfolder: Direct path to a specific zebrafish data folder (bypasses pattern matching)
          pattern: Regex pattern to match subfolder names (default: 'z.*f\d+')
          rmfname: Filename to exclude from the folder (default: 'README_LICENSE.rtf')
          sortkey: Optional key function for sorting files
          reverse: Reverse sort order if True (default: False)
          
        Returns:
            Tuple of (tag, file1_path, file2_path)
            
        Raises:
            ValueError: If neither zebfolderpath nor zfolder is provided,
                        if no matching folders found, or if folder doesn't contain exactly 2 files
        """

        # Validate inputs 
        if not zebfolderpath and not zfolder:
            raise ValueError("Must provide either zebfolderpath or zfolder")
        
        if zfolder: 
            matched_dir = Path(zfolder)
            if not matched_dir.exists():
                raise FileNotFoundError(f"Folder {matched_dir} does not exist")
            if not matched_dir.is_dir():
                raise NotADirectoryError(f"Path is not a directory: {zfolder}")

        else:
            datapath_list = Path(zebfolderpath)
            if not datapath_list.exists():
                raise FileNotFoundError(f"Folder {datapath_list} does not exist")
            if not datapath_list.is_dir():
                raise NotADirectoryError(f"Path is not a directory: {zebfolderpath}")
            matches = [path for path in datapath_list.iterdir() if path.is_dir() and  re.search(pattern, path.name)]
            if not matches:
                raise ValueError(f"No matches found for pattern {pattern} in {zebfolderpath}")

            if len(matches) > 1:
                match_names = [m.name for m in matches]
                raise ValueError(f"Multiple matches found for pattern '{pattern}':{match_names}." 
                                f"Use zfolder to specify which folder")
                
            # set matched dir to match
            matched_dir = Path(matches[0])

        tag = matched_dir.name

        files = [f for f in matched_dir.iterdir() if f.is_file() and f.name != rmfname]
        if sortkey:
            files.sort(key=sortkey)
        else:
            files.sort(reverse=reverse)
        
        if len(files) != 2:
            file_names = [f.name for f in files]
            raise ValueError(f"Expected 2 files in {zebfolderpath}, found {len(file_names)}")
        
        return tag, files[0], files[1]
    


In [263]:
# Testing for robustSigs: resolve the folder tag and the two data file paths.
tag, cellpath, cellpositionspath = DataLoader.from_zebrafish_folder(
    zfolder='/home/duuta/ppp/data/zebf00',
)


In [264]:
# Loader limited to number_of_cells=100 and sessions=100.
dataloader = DataLoader(
    cells_datapath=str(cellpath),
    cells_positions_path=str(cellpositionspath),
    number_of_cells=100,
    sessions=100,
)

In [265]:
# Fetch both arrays in one call: responses first, positions second.
cell_responses, cell_positions = (
    dataloader.load_cell_responses_and_positions()
)

2025-12-16 12:48:20 INFO Loaded and cached responses for 100: 100 cells, 100 sessions
2025-12-16 12:48:20 INFO Loading cell positions for 100 cells
2025-12-16 12:48:21 INFO Loaded 83205 cell postions (cached)
2025-12-16 12:48:21 INFO Returning positions for 100 cells


In [266]:
cell_positions.shape

(3, 100)

In [267]:
# cleared bugs, code works and loads data as desired
# TODO:
# 1.  create a class for measures
# 2.  class needs to take different measures to compute the local information
# 3.  does it take one roi or all rois? which methods should take all rois and which should take one roi?
# 4.  


In [268]:
from functools import lru_cache
import numpy as np
from typing import Callable


In [None]:
class ComputeLocalSigs:
    def __init__(self, 
                cell_positions: np.ndarray, 
                cell_responses: np.ndarray,
                metric: str = 'euclidean',  
                measure: Callable[[np.ndarray], np.ndarray] | None=None, 
            ) -> None:
        self.measure = measure
        self.positions = cell_positions
        self.responses = cell_responses
        self.metric = metric
        self.n_cells = cell_positions.shape[1]
        self._distance_matrix = self._compute_cell_distances()

    def _compute_cell_distances(self) -> np.ndarray:
        """ Compute pairwise cell distances"""
        ds = squareform(pdist(self.positions.T, metric=self.metric)) 
        return ds

    
    def get_nearest_neighbor_indices(self, cell_idx: int, n_neighbors: int) -> np.ndarray:
        """ Get indices of nearest neighbors for a cell (excluding itself)"""
        if not 0 <= cell_idx < self.n_cells:
            raise ValueError(f"Cell index {cell_idx} out of range")
        if n_neighbors >= self.n_cells:
            raise ValueError(f" Nearest Neighbours must be < {self.n_cells}")

        # argsort and take [1:n_neighbors+1] to exclude self at index 0 
        neighbors_indices = np.argsort(self._distance_matrix[cell_idx])[:n_neighbors]
        
        return neighbors_indices
    
    def get_nearest_neighbors_distances(self, cell_idx: int, n_neighbors: int) -> np.ndarray:
        """ Get distances of nearest neighbors for a cell (excluding itself)"""
        if not 0 <= cell_idx < self.n_cells:
            raise ValueError(f"Cell index {cell_idx} out of range [0, {self.n_cells}]")
        
        sorted_distances = np.sort(self._distance_matrix[cell_idx]) # [1: neighbors+1] excludes self (distance=0)
        return sorted_distances[1:n_neighbors+1] 
    
    def get_random_neighbors_indices(self, cell_idx: int, n_random: int, exclude_self: bool = True) -> np.ndarray:
        """ Get random cell indices (optionally excluding self)"""
        if not 0 <= cell_idx < self.n_cells:
            raise ValueError(f"Cell index {cell_idx} out of range [0, {self.n_cells}]")

        available_indices = np.arange(self.n_cells)  

        if exclude_self:
            available_indices = available_indices[available_indices != cell_idx]

        if n_random > len(available_indices):
           raise ValueError(f"Number random {n_random} exceeds available cells {len(available_indices)} ")

        return  np.random.choice(available_indices, n_random, replace=False)

    def compute_local_measure(self, cell_idx: int, n_neighbors: int, n_random:int) -> float:
        """ Compute local measure for a cell using its neighbors """
        neighbors_indices = self.get_nearest_neighbor_indices(cell_idx, n_neighbors) 

        # Include the cell itself with its neighbors for the measure 
        local_indices = np.concatenate((np.array([cell_idx]), neighbors_indices))
        local_responses = self.responses[local_indices, :]

        return self.measure(local_responses)


    def compute_measure_for_nearest_neighbors(
        self, 
        cell_idx: int,
        n_neighbors: int,
        measure: Callable[[np.ndarray], np.ndarray],
    ) -> float | np.ndarray:
        """ Compute measure for a cell using its nearest spatial neighbors 

        Args:
            cell_idx (int): Index of the cell to analyze
            n_neighbors (int): Number of nearest neighbors to consider
            measure (Callable[[np.ndarray], np.ndarray]): Measure to compute
        
        Returns:
            float | np.ndarray: Local measure for the cell
        """
        # Combine reference cell with neighbors: reference first, then neighbors 
        neighbors_indices = self.get_nearest_neighbor_indices(cell_idx, n_neighbors)
        local_indices = np.concatenate((np.array([cell_idx]), neighbors_indices))
        # Note responses sha[e is (n_timepoints, n_cells), transpose to (n_cells, n_timepoints)]
        local_responses = self.responses[local_indices, :]
        return measure(local_responses)


    def compute_measure_for_random_neighbors(
        self, 
        cell_idx: int,
        n_random: int,
        measure: Callable[[np.ndarray], float | np.ndarray],
        exclude_self: bool = True,
    ) -> float | np.ndarray:
        """ Compute measure for a cell using random neighbors.

        Args:
            cell_idx (int): Index of the cell to analyze
            n_random (int): Number of random neighbors to sample
            measure (Callable): Function that takes (n_cells, n_timepoints) responses
                               and returns a scale ro array measure
            exclude_self: Whether to exclude the cell itself from random sampling
            
        Returns:
             Result of measure function    float | np.ndarray: Local measure for the cell
        """
        random_indices = self.get_random_neighbors_indices(cell_idx, n_random, exclude_self)
        logger.info(f"number of random indices {random_indices.shape}...")
        # Combine reference cell with random neighbors
        local_indices = np.concatenate((np.array([cell_idx]), random_indices))
        local_responses = self.responses[local_indices, :]
        logger.info(f"the shape of the local_responses {local_responses.shape}")
        return measure(local_responses)
        

        
    def compute_measure_for_all_cells(
        self,
        n_neighbors: int,
        measure: Callable[[np.ndarray], float | np.ndarray],
        neighbor_type: str = 'nearest',

    ) -> np.ndarray:
        """ Compute measure for all cells.

        Args:
            n_neighbors (int): Number of neighbors to consider 
            measure : Measure functipon to apply
            neighbor_type: "nearest" or "random"

        Returns:
             Array of measure results fro each cell
        """
        results = []
        for cell_idx in range(self.n_cells):
            if neighbor_type == 'nearest':
                result = self.compute_measure_for_nearest_neighbors(
                    cell_idx, n_neighbors, measure
                )
            elif neighbor_type == 'random':
                result = self.compute_measure_for_random_neighbors(
                    cell_idx, n_neighbors, measure
                )

            else:
                raise ValueError(f"neighbor type must be 'nearest' or 'random' got type {neighbor_type}")

            results.append(result)

        return np.array(results)

In [270]:
# Create generic correlation measure functions

def compute_neighbor_correlations(responses: np.ndarray) -> np.ndarray:
    """
        Correlate the first neuron with every other neuron in ``responses``.

        Args:
             responses: (n_neurons, n_timepoints(sessions)) array; row 0 is the
                        reference neuron, the remaining rows are its neighbors

        Returns:
            np.ndarray: Array of shape (n_neurons - 1,) holding, for each
                        neighbor row in order, its correlation coefficient
                        with the reference row.
    """
    roi_trace = responses[0, :]
    neighbor_traces = responses[1:, :]

    # corrcoef returns a 2x2 matrix; the off-diagonal entry is the pair's coefficient.
    return np.array(
        [np.corrcoef(roi_trace, trace)[0, 1] for trace in neighbor_traces],
        dtype=float,
    )


def compute_mean_neighbor_correlations(responses: np.ndarray) -> float:
    """
        Average the correlations between the reference neuron and its neighbors.

        Args:
            responses: (n_neurons, n_timepoints) array; row 0 is the reference
                       neuron

        Returns:
           Mean of the values returned by compute_neighbor_correlations
    """
    return np.mean(compute_neighbor_correlations(responses))


In [271]:
# Resolve the folder tag and the two data file paths.
tag, cellpath, cellpositionspath = DataLoader.from_zebrafish_folder(
    zfolder='/home/duuta/ppp/data/zebf00',
)

In [272]:
# Loader limited to number_of_cells=100 and sessions=1000.
dataloader = DataLoader(
    cells_datapath=cellpath,
    cells_positions_path=cellpositionspath,
    number_of_cells=100,
    sessions=1000,
)

In [273]:
# Fetch both arrays in one call: responses first, positions second.
cell_responses, cell_positions = (
    dataloader.load_cell_responses_and_positions()
)

2025-12-16 12:48:21 INFO Loaded and cached responses for 100: 100 cells, 1000 sessions
2025-12-16 12:48:21 INFO Loading cell positions for 100 cells
2025-12-16 12:48:21 INFO Loaded 83205 cell postions (cached)
2025-12-16 12:48:21 INFO Returning positions for 100 cells


In [274]:
# Build the neighbor-analysis helper; no default measure is configured here.
local_sigs = ComputeLocalSigs(
    metric='euclidean',
    cell_responses=cell_responses,
    cell_positions=cell_positions,
)

In [280]:
# Reference cell and neighborhood size for the single-cell checks.
cell_idx, n_neighbors = 0, 10

# Per-neighbor correlations for the nearest spatial neighbors of the cell.
nn_correlations = local_sigs.compute_measure_for_nearest_neighbors(
    measure=compute_neighbor_correlations,
    n_neighbors=n_neighbors,
    cell_idx=cell_idx,
)

In [281]:
nn_correlations
#  need to check that the computation is correct...

array([1.        , 0.27722458, 0.28741727, 0.19400512, 0.20721784,
       0.26599003, 0.13600898, 0.17352471, 0.13088984, 0.11748192])

In [277]:
# Per-neighbor correlations for randomly sampled cells, same reference cell.
rn_correlations = local_sigs.compute_measure_for_random_neighbors(
    measure=compute_neighbor_correlations,
    n_random=n_neighbors,
    cell_idx=cell_idx,
)

2025-12-16 12:48:21 INFO number of random indices (10,)...
2025-12-16 12:48:21 INFO the shape of the local_responses (11, 1000)


In [278]:
rn_correlations.shape

(10,)

In [284]:
# Mean neighbor correlation for every cell, using nearest neighbors...
all_nn_mean_corrs = local_sigs.compute_measure_for_all_cells(
    neighbor_type='nearest',
    measure=compute_mean_neighbor_correlations,
    n_neighbors=10,
)

# ...and using randomly sampled cells for comparison.
all_rn_mean_corrs = local_sigs.compute_measure_for_all_cells(
    neighbor_type='random',
    measure=compute_mean_neighbor_correlations,
    n_neighbors=10,
)

AttributeError: 'ComputeLocalSigs' object has no attribute 'computer_measure_nearest_neighbors'

# Load data
# Usage example..
tag, cellpath, cellpositionspath = DataLoader.from_zebrafish_folder(
    zfolder='/Users/duuta/ppp/data/zebf00'
)
dataloader = DataLoader(
    cells_datapath=str(cellpath),
    cells_positions_path=str(cellpositionspath),
    number_of_cells=100,
    sessions=100,
)
cell_responses, cell_positions = dataloader.load_cell_responses_and_positions()

# Create local sigs computer
local_sigs = ComputeLocalSigs(
    cell_positions=cell_positions,
    cell_responses=cell_responses,
    metric='euclidean'
)

# Example 1: Compute nearest neighbor correlations for a single cell
cell_idx = 0
n_neighbors = 10
nn_correlations = local_sigs.compute_measure_for_nearest_neighbors(
    cell_idx=cell_idx,
    n_neighbors=n_neighbors,
    measure=compute_neighbor_correlations  # Returns array of correlations
)
print(f"Nearest neighbor correlations: {nn_correlations}")

# Example 2: Compute random neighbor correlations for a single cell
rn_correlations = local_sigs.compute_measure_for_random_neighbors(
    cell_idx=cell_idx,
    n_random=n_neighbors,
    measure=compute_neighbor_correlations
)
print(f"Random neighbor correlations: {rn_correlations}")

# Example 3: Compute mean correlation for nearest neighbors
# (the measure is named compute_mean_neighbor_correlations, plural)
mean_nn_corr = local_sigs.compute_measure_for_nearest_neighbors(
    cell_idx=cell_idx,
    n_neighbors=n_neighbors,
    measure=compute_mean_neighbor_correlations  # Returns scalar
)
print(f"Mean nearest neighbor correlation: {mean_nn_corr}")

# Example 4: Compute for all cells
all_nn_mean_corrs = local_sigs.compute_measure_for_all_cells(
    n_neighbors=10,
    measure=compute_mean_neighbor_correlations,
    neighbor_type='nearest'
)
all_rn_mean_corrs = local_sigs.compute_measure_for_all_cells(
    n_neighbors=10,
    measure=compute_mean_neighbor_correlations,
    neighbor_type='random'
)

print(f"Mean NN correlation across all cells: {np.mean(all_nn_mean_corrs)}")
print(f"Mean RN correlation across all cells: {np.mean(all_rn_mean_corrs)}")
Key Benefits
- Single measure function: `compute_neighbor_correlations` replaces both `compute_nearest_neighbor_corr_roi` and `compute_random_neighbor_corr_roi`.
- Flexibility: new measures (mean, median, max, custom functions) are easy to add.
- Clean separation: the class handles neighbor selection; the measure functions handle the computation.
- Reusable: measure functions can be used independently or with the class.
- Type safety: clear callable signatures with proper type hints.
Would you like me to help you implement this refactoring in your notebook?
