In [16]:
# Import the required libraries and load the cross-session matching results from match.pkl
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import pickle



# Load the matching results. The pickled object is unpacked into exactly four
# names, in this order: matched_cells, matched_im, template_masks, template_im.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open('/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/registration_data/match.pkl','rb') as f:  
    [matched_cells, matched_im, template_masks, template_im] = pickle.load(f)


In [17]:
len(template_masks)

5454

In [18]:
template_masks[0]

{'id': 1,
 'ipix': array([63238, 63239, 63240, 65093, 65094, 65095, 65096, 65097, 65098,
        65099, 66948, 66949, 66950, 66951, 66952, 66953, 66954, 66955,
        66956, 68804, 68805, 68806, 68807, 68808, 68809, 68810, 68811,
        68812, 70661, 70662, 70663, 70664, 70665, 70666, 70667, 70668,
        72518, 72519, 72520, 72521, 72522, 72523, 72524, 74376, 74377,
        74378, 74379]),
 'xpix': array([134, 135, 136, 133, 134, 135, 136, 137, 138, 139, 132, 133, 134,
        135, 136, 137, 138, 139, 140, 132, 133, 134, 135, 136, 137, 138,
        139, 140, 133, 134, 135, 136, 137, 138, 139, 140, 134, 135, 136,
        137, 138, 139, 140, 136, 137, 138, 139]),
 'ypix': array([34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38,
        38, 38, 39, 39, 39, 39, 39, 39, 39, 40, 40, 40, 40]),
 'med': [37.0, 136.0],
 'lam': array([ 4.00372271,  4.61244836,  4.10467189,  4.86417759,  6.13918126,
 

In [12]:
# Function to analyze putative_cells structure and find cells missing specific sessions
def analyze_putative_cells(putative_cells, target_missing_sessions=[7, 8, 9], min_sessions=20, n_sessions=35):
    """
    Find well-tracked cell clusters that lack one or more target sessions.

    Args:
        putative_cells: List of clusters, where each cluster is a list of cell
            dictionaries that are read through their 'id' and 'session' keys.
        target_missing_sessions: Session numbers to check for absence. The
            default list is only read, never mutated.
        min_sessions: Clusters present in fewer distinct sessions are skipped.
        n_sessions: Total number of sessions; 'all_missing' reports every
            session in range(n_sessions) the cluster lacks (default 35, i.e.
            sessions numbered 0-34).

    Returns:
        List of dicts with keys 'cluster_idx', 'cell_id', 'sessions',
        'missing_targets', 'total_sessions' and 'all_missing', sorted by
        'total_sessions' in descending order. Session labels are strings.
    """
    def _session_order(label):
        # Numeric labels sort first, by value; any other label sorts after
        # them alphabetically. Returning tuples of one fixed shape avoids the
        # int-vs-str comparison TypeError when both kinds of label occur.
        if label.isdecimal():
            return (0, int(label), '')
        return (1, 0, label)

    results = []
    n_clusters = len(putative_cells)

    for cluster_idx, cluster in enumerate(putative_cells):
        if cluster_idx % 100 == 0:
            print(f"Processing cluster {cluster_idx}/{n_clusters}...")

        # Skip empty clusters
        if not cluster:
            continue

        # The ID is taken from the first cell (should be the same across the cluster)
        cell_id = cluster[0].get('id')
        if cell_id is None:
            continue

        # All sessions this cell appears in, normalised to strings
        sessions = {
            str(cell_data.get('session'))
            for cell_data in cluster
            if cell_data.get('session') is not None
        }

        # Skip cells that don't appear in enough sessions
        if len(sessions) < min_sessions:
            continue

        missing_sessions = [str(s) for s in target_missing_sessions if str(s) not in sessions]
        if not missing_sessions:
            continue

        # This cell is missing at least one of the target sessions
        results.append({
            'cluster_idx': cluster_idx,
            'cell_id': cell_id,
            'sessions': sorted(sessions, key=_session_order),
            'missing_targets': missing_sessions,
            'total_sessions': len(sessions),
            'all_missing': [str(s) for s in range(n_sessions) if str(s) not in sessions],
        })

    # Sort by number of sessions (most first)
    return sorted(results, key=lambda x: x['total_sessions'], reverse=True)

# Function to load a pickle file
def load_pickle(file_path):
    """Open ``file_path`` in binary mode and return the unpickled object."""
    from pickle import load as _unpickle
    with open(file_path, mode='rb') as stream:
        payload = _unpickle(stream)
    return payload

# Function to display putative cells analysis results
def display_putative_cells_results(results, max_display=10):
    """Print the clusters found by analyze_putative_cells and a per-session summary.

    Args:
        results: List of result dicts; the keys read here are 'cluster_idx',
            'cell_id', 'total_sessions', 'missing_targets', 'all_missing'
            and 'sessions'.
        max_display: Maximum number of individual results to print.

    Returns:
        None. All output goes to stdout.
    """
    def _session_order(item):
        # Sort numeric session labels by value and put any other label after
        # them; a fixed tuple shape avoids comparing int with str.
        label = item[0]
        if label.isdecimal():
            return (0, int(label), '')
        return (1, 0, label)

    print(f"Found {len(results)} cell clusters missing target sessions")

    if not results:
        return

    print("\nTop results (cells with most sessions but missing target sessions):")
    for i, result in enumerate(results[:max_display]):
        print(f"\nResult {i+1}:")
        print(f"Cluster index: {result['cluster_idx']}")
        print(f"Cell ID: {result['cell_id']}")
        print(f"Total sessions: {result['total_sessions']}")
        print(f"Missing target sessions: {result['missing_targets']}")
        print(f"All missing sessions: {result['all_missing']}")
        print(f"Present sessions: {result['sessions']}")

    # Summary of missing sessions: how many results lack each target session
    missing_counts = defaultdict(int)
    for result in results:
        for session in result['missing_targets']:
            missing_counts[session] += 1

    print("\nSummary of missing target sessions:")
    for session, count in sorted(missing_counts.items(), key=_session_order):
        print(f"Session {session}: missing in {count} results")

# To use these functions:
# 1. If you need to load the match.pkl file:

# 2. If you already have the putative_cells loaded (from info or elsewhere):
# NOTE(review): `putative_cells` is not assigned in any cell shown in this
# notebook, so this call relies on leftover kernel state. Presumably it is the
# list of matched clusters loaded from match.pkl - TODO confirm and assign it
# explicitly before this call.
results = analyze_putative_cells(putative_cells, 
                                target_missing_sessions=[7, 8, 9], 
                                min_sessions=20)
display_putative_cells_results(results)

# 3. To directly access a specific cluster:
# if results:
#     cluster_idx = results[0]['cluster_idx']
#     cluster = putative_cells[cluster_idx]
#     print(f"First result cluster has {len(cluster)} cells across different sessions")

Processing cluster 0/5972...
Processing cluster 100/5972...
Processing cluster 200/5972...
Processing cluster 300/5972...
Processing cluster 400/5972...
Processing cluster 500/5972...
Processing cluster 600/5972...
Processing cluster 700/5972...
Processing cluster 800/5972...
Processing cluster 900/5972...
Processing cluster 1000/5972...
Processing cluster 1100/5972...
Processing cluster 1200/5972...
Processing cluster 1300/5972...
Processing cluster 1400/5972...
Processing cluster 1500/5972...
Processing cluster 1600/5972...
Processing cluster 1700/5972...
Processing cluster 1800/5972...
Processing cluster 1900/5972...
Processing cluster 2000/5972...
Processing cluster 2100/5972...
Processing cluster 2200/5972...
Processing cluster 2300/5972...
Processing cluster 2400/5972...
Processing cluster 2500/5972...
Processing cluster 2600/5972...
Processing cluster 2700/5972...
Processing cluster 2800/5972...
Processing cluster 2900/5972...
Processing cluster 3000/5972...
Processing cluster 3

In [None]:
# Import the zarr library and other necessary packages
import zarr
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Open the zarr file
zarr_path = '/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/vr2p.zarr'
z = zarr.open(zarr_path, mode='r')

# Define a function to process one entry of info at a time
def analyze_id_session_in_item(item, item_index=None):
    """Collect and report which sessions each ID occurs in within one item.

    Every dictionary nested anywhere inside ``item`` (through dicts and lists)
    that carries non-None 'id' and 'session' values contributes one
    str(id) -> str(session) pair.

    Args:
        item: Arbitrarily nested structure of dicts and lists.
        item_index: Optional label that is only used in the printed header.

    Returns:
        defaultdict(set) mapping each ID string to its set of session strings.
    """
    print(f"\nAnalyzing item {item_index if item_index is not None else ''}")

    id_to_sessions = defaultdict(set)

    def _walk(node):
        # Pre-order walk: record the pair held by a dict, then descend.
        if isinstance(node, list):
            for child in node:
                _walk(child)
            return
        if not isinstance(node, dict):
            return

        id_val = node.get('id')
        session_val = node.get('session')
        if id_val is not None and session_val is not None:
            id_to_sessions[str(id_val)].add(str(session_val))

        for child in node.values():
            _walk(child)

    _walk(item)

    # Nothing found: report and hand back the empty mapping
    if not id_to_sessions:
        print("No ID-session relationships found in this item")
        return id_to_sessions

    print(f"Found {len(id_to_sessions)} unique IDs across sessions")

    # Keep only the IDs seen in more than one session
    multi_session_ids = {}
    for id_val, sessions in id_to_sessions.items():
        if len(sessions) > 1:
            multi_session_ids[id_val] = sessions

    if not multi_session_ids:
        print("No IDs appear in multiple sessions")
        return id_to_sessions

    print(f"\n{len(multi_session_ids)} IDs appear in multiple sessions:")

    # Show at most ten IDs, those with the most sessions first
    ranked = sorted(multi_session_ids.items(), key=lambda pair: len(pair[1]), reverse=True)
    for id_val, sessions in ranked[:10]:
        print(f"ID {id_val}: {len(sessions)} sessions - {sorted(sessions)}")

    return id_to_sessions

# Function to process items one at a time
def process_info_step_by_step(info):
    """Run analyze_id_session_in_item over ``info`` and collect the non-empty mappings.

    Args:
        info: Either a list of items (each analysed separately, with progress
            output) or a single dictionary. Any other type is only reported.

    Returns:
        List of the non-empty mappings returned by analyze_id_session_in_item.
    """
    all_mappings = []

    # Single dictionary: analyse it as one item
    if isinstance(info, dict):
        print("Info is a dictionary (single item)")
        mapping = analyze_id_session_in_item(info)
        if mapping:
            all_mappings.append(mapping)
        return all_mappings

    if not isinstance(info, list):
        print(f"Info has unexpected type: {type(info)}")
        return all_mappings

    print(f"Info contains {len(info)} items")

    for i, item in enumerate(info):
        print(f"\n===== Processing Item {i} =====")
        mapping = analyze_id_session_in_item(item, i)

        if mapping:
            all_mappings.append(mapping)

            # Display first few IDs and their sessions
            print("\nExample IDs and their sessions:")
            preview = list(mapping.items())[:5]
            for id_val, sessions in preview:
                print(f"ID {id_val}: {sorted(sessions)}")

        # Progress line after every item except the last one
        if i < len(info) - 1:
            print(f"\nProgress: {i+1}/{len(info)} items processed")

    return all_mappings

# Use this function to analyze the loaded info variable
# First make sure info is loaded from the zarr file
# For example: info = z['path/to/info'][:]

# Then run:
# all_mappings = process_info_step_by_step(info)

# Print summary information from all mappings
def summarize_all_mappings(all_mappings):
    """Print and plot a summary of session counts per ID across several mappings.

    Args:
        all_mappings: List of mappings from ID to a collection of sessions.
            Mappings are merged by ID before counting.

    Returns:
        None. Prints a text summary and shows a histogram through the
        notebook's ``plt`` (matplotlib.pyplot).
    """
    if not all_mappings:
        print("No mappings to summarize")
        return

    # Combine all mappings
    combined = defaultdict(set)
    for mapping in all_mappings:
        for id_val, sessions in mapping.items():
            combined[id_val].update(sessions)

    # A list made only of empty mappings leaves nothing to count; without this
    # guard max() below would be called on an empty sequence.
    if not combined:
        print("No mappings to summarize")
        return

    # Count sessions per ID
    session_counts = {id_val: len(sessions) for id_val, sessions in combined.items()}

    # Find IDs with most sessions
    top_ids = sorted(session_counts.items(), key=lambda x: x[1], reverse=True)

    print(f"\nSummary across all {len(all_mappings)} processed items:")
    print(f"Total unique IDs: {len(combined)}")

    # Count unique sessions
    all_sessions = set()
    for sessions in combined.values():
        all_sessions.update(sessions)
    print(f"Total unique sessions: {len(all_sessions)}")

    # Top IDs with most sessions
    print("\nTop 10 IDs by number of sessions:")
    for id_val, count in top_ids[:10]:
        sessions = combined[id_val]
        print(f"ID {id_val}: {count} sessions - {sorted(sessions)}")

    # Distribution: how many IDs have each session count
    count_distribution = defaultdict(int)
    for count in session_counts.values():
        count_distribution[count] += 1

    print("\nDistribution of session counts:")
    for count, num_ids in sorted(count_distribution.items()):
        print(f"{count} session(s): {num_ids} IDs")

    # Visualize: one bin per integer session count
    plt.figure(figsize=(10, 6))
    counts = list(session_counts.values())
    plt.hist(counts, bins=range(1, max(counts) + 2), alpha=0.7, color='skyblue', edgecolor='black')
    plt.xlabel('Number of Sessions')
    plt.ylabel('Number of IDs')
    plt.title('Distribution of Number of Sessions per ID')
    plt.grid(axis='y', alpha=0.75)
    plt.show()

# To summarize after processing all items:
# summarize_all_mappings(all_mappings)

In [None]:
# Import the functions from the analyze_sessions_ids.py script
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import zarr

# Add the parent of the current working directory to sys.path so the script can be imported.
# os.path.abspath('') resolves to the working directory and os.path.dirname() steps one
# level up, so despite its name `current_dir` holds the parent directory.
current_dir = os.path.dirname(os.path.abspath(''))
if current_dir not in sys.path:
    sys.path.append(current_dir)

# Import functions from the script
from analyze_sessions_ids import find_keys, process_item, analyze_id_session_counts, visualize_results

# Open the zarr file
zarr_path = '/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/vr2p.zarr'
z = zarr.open(zarr_path, mode='r')
# Load the contents of match.pkl through NumPy; allow_pickle=True permits unpickling
# arbitrary Python objects.
# NOTE(review): unpickling can execute arbitrary code - only load trusted files.
info_orig = np.load('/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/registration_data/match.pkl', allow_pickle=True)
# Example of how to use the functions:
# Assuming info is already loaded (as shown in your original message)
# If not, you can load it from the zarr file depending on its location:

# Option 1: If info is at the root level
# info = z['info'][:]

# Option 2: If info is nested inside another group
# info = z['some_group']['info'][:]
# First element of the loaded match.pkl contents
matched_cells = info_orig[0]
# Process a single item from info
# NOTE(review): `info` is not assigned in this cell (only `info_orig` and
# `matched_cells` are), so the code below relies on a value left in the kernel
# by another cell. Presumably it should be `matched_cells` - TODO confirm.
if isinstance(info, list):
    print(f"Info contains {len(info)} items")
    
    # Process one item as an example (index 1 selects the second item)
    item_index = 1  # Change this to process different items
    if len(info) > item_index:
        print(f"Processing item {item_index}...")
        item = info[item_index]
        
        # Find ID-session relationships in this item
        mapping = process_item(item, item_index)
        
        # Analyze and display the results
        # NOTE(review): the result is used below as a mapping from ID to
        # sessions; this is the contract of the imported function - confirm
        # it is not shadowed by a notebook-local definition with another return shape.
        id_session_map = analyze_id_session_counts(mapping)
        
        # If you want to visualize (this would normally be done after processing all items)
        # visualize_results([id_session_map])
        
        # You can also access the mapping directly to inspect specific IDs
        if id_session_map:
            print("\nExample of direct access to the mapping:")
            
            # Get an example ID (first one in the mapping)
            example_id = next(iter(id_session_map))
            sessions = id_session_map[example_id]
            
            print(f"ID {example_id} appears in {len(sessions)} sessions: {sorted(sessions)}")
    else:
        print(f"Index {item_index} is out of range (info has {len(info)} items)")
        
elif isinstance(info, dict):
    print("Info is a dictionary (single item)")
    mapping = process_item(info)
    id_session_map = analyze_id_session_counts(mapping)
    
    # Example of direct access
    if id_session_map:
        example_id = next(iter(id_session_map))
        sessions = id_session_map[example_id]
        print(f"ID {example_id} appears in {len(sessions)} sessions: {sorted(sessions)}")
        
else:
    print(f"Info has unexpected type: {type(info)}")

Info contains 5972 items
Processing item 1...

Processing item 1
Found 28 'session' keys and 28 'id' keys

Found 1 unique IDs across multiple sessions

1 IDs appear in multiple sessions:
ID 2: 28 sessions - ['1', '10', '12', '13', '14', '15', '19', '2', '21', '22', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '4', '5', '6', '7', '8', '9']

Example of direct access to the mapping:
ID 2 appears in 28 sessions: ['1', '10', '12', '13', '14', '15', '19', '2', '21', '22', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '4', '5', '6', '7', '8', '9']


In [None]:
# Comprehensive analysis of sessions and IDs in the nested info structure
from collections import defaultdict
import json

# First, let's define functions to analyze the structure
def find_keys(obj, target_key, path=None, results=None):
    """Collect every occurrence of ``target_key`` in a nested dict/list structure.

    Args:
        obj: Structure to search; only dicts and lists are descended into.
        target_key: Dictionary key to look for.
        path: Path components leading to ``obj`` (used by the recursion).
        results: Accumulator list; a new one is created when omitted.

    Returns:
        The accumulator: a list of (path, value) tuples in depth-first order,
        where path is a list of dict keys and "[i]" list-index markers ending
        with the matching key.
    """
    path = [] if path is None else path
    results = [] if results is None else results

    if isinstance(obj, list):
        # List positions are recorded as "[i]" and are never matched themselves
        for position, element in enumerate(obj):
            find_keys(element, target_key, path + [f"[{position}]"], results)
    elif isinstance(obj, dict):
        for key, value in obj.items():
            here = path + [key]
            if key == target_key:
                results.append((here, value))
            # Keep searching below the value even when the key matched
            find_keys(value, target_key, here, results)

    return results

def find_id_session_relationships(data):
    """Find all possible relationships between IDs and sessions.

    Three heuristics are applied and their results merged:

    1. Every 'id' and 'session' key found by find_keys is located; an ID is
       linked to a session when the objects holding them sit under the same
       top-level branch of ``data`` (their parent paths share the first
       component). Keys found directly on ``data`` itself have an empty
       parent path and are not linked by this method.
    2. Known layouts are read directly: ``{session: {'ids': [...]}}`` and
       ``{'cells': [{'id': ..., 'session': ...}, ...]}``.
    3. Entries of a 'components', 'cells' or 'items' list are scanned for
       keys named or ending in 'id' / 'session' (case-insensitive).

    Args:
        data: Nested structure of dicts and lists.

    Returns:
        defaultdict(set) mapping str(id) to a set of str(session).
    """
    id_to_sessions = defaultdict(set)

    # Method 1: Find keys and path relationships
    session_results = find_keys(data, 'session')
    id_results = find_keys(data, 'id')

    print(f"Found {len(session_results)} 'session' keys")
    print(f"Found {len(id_results)} 'id' keys")

    # Group sessions by the first component of their parent path. Comparing
    # whole components (instead of the characters of a joined path string)
    # keeps branches such as "[1]" and "[10]" apart and also handles
    # single-character keys; grouping once replaces the id x session double loop.
    sessions_by_branch = defaultdict(set)
    for session_path, session_value in session_results:
        parent = session_path[:-1]
        if parent:
            sessions_by_branch[parent[0]].add(str(session_value))

    for id_path, id_value in id_results:
        parent = id_path[:-1]
        # Only touch id_to_sessions when there is something to link, so no
        # empty entries are created for unrelated IDs.
        if parent and parent[0] in sessions_by_branch:
            id_to_sessions[str(id_value)].update(sessions_by_branch[parent[0]])

    # Method 2: If data has specific patterns we can directly extract
    # For example, if it's organized by session with IDs inside
    if isinstance(data, dict):
        for key, value in data.items():
            # Pattern: {'session1': {'ids': [1, 2, 3]}, 'session2': {'ids': [2, 4, 5]}}
            if isinstance(value, dict) and 'ids' in value and isinstance(value['ids'], list):
                for id_val in value['ids']:
                    id_to_sessions[str(id_val)].add(str(key))

            # Pattern: {'cells': [{'id': 1, 'session': 'A'}, {'id': 2, 'session': 'B'}]}
            if key == 'cells' and isinstance(value, list):
                for cell in value:
                    if isinstance(cell, dict) and 'id' in cell and 'session' in cell:
                        id_to_sessions[str(cell['id'])].add(str(cell['session']))

    # Method 3: If data has components with both session and id
    components = []
    if isinstance(data, dict):
        # Try to find all individual components/items that might have both id and session
        if 'components' in data and isinstance(data['components'], list):
            components = data['components']
        elif 'cells' in data and isinstance(data['cells'], list):
            components = data['cells']
        elif 'items' in data and isinstance(data['items'], list):
            components = data['items']

        # Check each component for id and session
        for comp in components:
            if not isinstance(comp, dict):
                continue

            id_val = None
            session_val = None

            # Look for id and session keys or their variations
            for key in comp:
                # Non-string keys cannot be name-matched (and have no .lower())
                if not isinstance(key, str):
                    continue
                lowered = key.lower()
                if lowered == 'id' or lowered.endswith('_id'):
                    id_val = comp[key]
                if lowered == 'session' or lowered.endswith('_session'):
                    session_val = comp[key]

            if id_val is not None and session_val is not None:
                id_to_sessions[str(id_val)].add(str(session_val))

    return id_to_sessions

# Function to analyze and display the results
def analyze_id_session_counts(mapping):
    """Print session-count statistics for an ID -> sessions mapping.

    Args:
        mapping: Mapping from ID to a collection of sessions.

    Returns:
        None when ``mapping`` is empty (a notice is printed instead).
        Otherwise a tuple ``(sorted_ids, all_sessions)``: ``sorted_ids`` is a
        list of (id, session_count) pairs ordered by count descending and
        ``all_sessions`` is the set of every session seen.
    """
    if not mapping:
        print("No ID-session relationships found.")
        return

    print(f"\nFound {len(mapping)} unique IDs across multiple sessions")

    # Number of sessions recorded for each ID
    session_counts = {}
    for id_val, sessions in mapping.items():
        session_counts[id_val] = len(sessions)

    # IDs ranked by session count, largest first (ties keep mapping order)
    sorted_ids = sorted(session_counts.items(), key=lambda entry: entry[1], reverse=True)

    print("\nTop 10 IDs by number of sessions:")
    for id_val, count in sorted_ids[:10]:
        sessions = mapping[id_val]
        print(f"ID {id_val}: {count} sessions - {sorted(sessions)}")

    # How many IDs share each session count
    count_distribution = {}
    for count in session_counts.values():
        count_distribution[count] = count_distribution.get(count, 0) + 1

    print("\nDistribution of session counts:")
    for count in sorted(count_distribution):
        num_ids = count_distribution[count]
        print(f"{count} session(s): {num_ids} IDs")

    # Basic statistics over the per-ID counts
    all_counts = list(session_counts.values())
    avg_count = sum(all_counts) / len(all_counts)
    max_count = max(all_counts)
    min_count = min(all_counts)

    print("\nStatistics:")
    print(f"Average sessions per ID: {avg_count:.2f}")
    print(f"Maximum sessions per ID: {max_count}")
    print(f"Minimum sessions per ID: {min_count}")

    # Union of every session that occurs for any ID
    all_sessions = set()
    for sessions in mapping.values():
        all_sessions |= set(sessions)

    print(f"Total unique sessions: {len(all_sessions)}")

    return sorted_ids, all_sessions

# Now let's analyze the info variable
import os  # Needed for common prefix calculations
print("Analyzing the info structure...")
id_session_mapping = find_id_session_relationships(info)

# analyze_id_session_counts returns None when the mapping is empty; fall back
# to empty results instead of failing on tuple-unpacking None.
count_summary = analyze_id_session_counts(id_session_mapping)
top_ids, all_sessions = count_summary if count_summary is not None else ([], set())

# Visualize the distribution of sessions per ID
import matplotlib.pyplot as plt

if id_session_mapping:
    # Count sessions per ID
    session_counts = [len(sessions) for sessions in id_session_mapping.values()]
    
    # Create a histogram with one bin per integer session count
    plt.figure(figsize=(10, 6))
    plt.hist(session_counts, bins=range(1, max(session_counts) + 2), alpha=0.7, color='skyblue', edgecolor='black')
    plt.xlabel('Number of Sessions')
    plt.ylabel('Number of IDs')
    plt.title('Distribution of Number of Sessions per ID')
    plt.grid(axis='y', alpha=0.75)
    plt.xticks(range(1, max(session_counts) + 1))
    plt.show()
    
    # Create a pie chart for distribution
    count_distribution = defaultdict(int)
    for count in session_counts:
        count_distribution[count] += 1
    
    # For pie chart, group small slices
    threshold = 0.03  # Minimum fraction of all IDs (3%) to show as separate slice
    total_ids = len(session_counts)
    
    pie_data = {}
    other_count = 0
    
    for count, num_ids in count_distribution.items():
        percentage = num_ids / total_ids
        if percentage >= threshold:
            pie_data[f"{count} session(s)"] = num_ids
        else:
            other_count += num_ids
    
    if other_count > 0:
        pie_data["Other"] = other_count
    
    plt.figure(figsize=(10, 8))
    plt.pie(pie_data.values(), labels=pie_data.keys(), autopct='%1.1f%%', 
            startangle=90, shadow=True, explode=[0.05] * len(pie_data))
    plt.axis('equal')
    plt.title('Distribution of IDs by Number of Sessions')
    plt.show()

In [None]:
# Alternative approach: Directly mapping IDs to sessions
# This is especially useful if your data structure has a clear pattern

def extract_id_session_mapping(data):
    """
    Extract a mapping from IDs to sessions based on the specific structure of your data.
    This function needs to be customized based on your actual data structure.

    Two layouts are recognised:

    * ``{'sessions': [{'id': ..., 'session': ...}, ...]}`` - entries that are
      not dictionaries, or lack either key, are skipped.
    * ``{session: [id, id, ...], ...}`` - values that are not lists are skipped.

    When nothing can be extracted, a short description of the first few
    elements of ``data`` is printed to help adapt the code.

    Args:
        data: The structure to inspect.

    Returns:
        defaultdict(set) mapping each ID to the set of sessions it appears in
        (empty when no layout matched). IDs and sessions are stored as found,
        without string conversion.
    """
    id_to_sessions = defaultdict(set)

    def _describe(label, value):
        # Print the type of one element plus a value preview capped at 100 characters
        print(f"{label}: {type(value)}")
        if isinstance(value, (dict, list)) and len(str(value)) > 100:
            print(f"  Sample: {str(value)[:100]}...")
        else:
            print(f"  Value: {value}")

    # Option 1: If data has a list of dictionaries with both 'id' and 'session' keys
    if isinstance(data, dict) and 'sessions' in data and isinstance(data['sessions'], list):
        for item in data['sessions']:
            # Only dictionaries can carry the two keys; other entries are skipped
            if isinstance(item, dict) and 'id' in item and 'session' in item:
                id_to_sessions[item['id']].add(item['session'])

    # Option 2: If data is structured with sessions as keys and lists of IDs as values
    # e.g., {'session1': [id1, id2], 'session2': [id2, id3]}
    elif isinstance(data, dict):
        for session, ids in data.items():
            if isinstance(ids, list):
                for id_val in ids:
                    id_to_sessions[id_val].add(session)

    # Option 3: Nothing matched - print the first few levels of the structure
    # so its pattern can be identified. This runs for any input type, including
    # dictionaries from which neither option above extracted anything.
    if not id_to_sessions:
        print("Structure not recognized. Printing data structure:")
        if isinstance(data, dict):
            for key, value in list(data.items())[:5]:  # Print first 5 items
                _describe(key, value)
        elif isinstance(data, list):
            for i, item in enumerate(data[:5]):  # Print first 5 items
                _describe(f"[{i}]", item)
        else:
            print(f"Data type: {type(data)}")

    return id_to_sessions

# Try the direct mapping approach
# NOTE(review): `info` is not assigned in this cell; it must already exist in
# the kernel when the cell runs.
print("Trying direct mapping approach...")
direct_mapping = extract_id_session_mapping(info)

if direct_mapping:
    print(f"\nFound mapping with {len(direct_mapping)} unique IDs")
    
    # Display IDs with multiple sessions
    multi_session_ids = {id_val: sessions for id_val, sessions in direct_mapping.items() if len(sessions) > 1}
    print(f"IDs in multiple sessions: {len(multi_session_ids)}")
    
    if multi_session_ids:
        print("\nTop 10 IDs with most sessions:")
        # Rank by number of sessions, largest first, and show at most ten
        for id_val, sessions in sorted(multi_session_ids.items(), key=lambda x: len(x[1]), reverse=True)[:10]:
            print(f"ID {id_val}: {len(sessions)} sessions - {sorted(sessions)}")
else:
    print("No mapping found with direct approach. The data structure might need more specific handling.")
    
    # Let's print more detailed information about the structure to help customize the code
    print("\nAttempting to identify the structure pattern...")
    
    # Recursive function to show the structure pattern with a limited depth
    def show_structure_pattern(obj, depth=0, max_depth=3):
        """Return an indented text outline of ``obj``.

        Dicts list their keys and lists their length; for each, only the
        first key / first element is followed one level deeper, and only
        while depth < max_depth - 1. At depth >= max_depth the outline is
        cut off with "...".
        """
        prefix = "  " * depth
        if depth >= max_depth:
            return f"{prefix}..."
        
        if isinstance(obj, dict):
            result = f"{prefix}dict with keys: {list(obj.keys())}"
            if depth < max_depth - 1 and obj:
                # Follow only the first key as a representative sample
                sample_key = next(iter(obj))
                result += f"\n{prefix}Sample for key '{sample_key}':\n"
                result += show_structure_pattern(obj[sample_key], depth + 1, max_depth)
            return result
        elif isinstance(obj, list):
            result = f"{prefix}list with {len(obj)} items"
            if depth < max_depth - 1 and obj:
                result += f"\n{prefix}Sample for first item:\n"
                result += show_structure_pattern(obj[0], depth + 1, max_depth)
            return result
        else:
            return f"{prefix}value of type {type(obj).__name__}"
    
    # Show the structure pattern of the info variable
    print(show_structure_pattern(info))

In [None]:
# Analyze the nested info structure to find sessions and IDs
from collections import defaultdict

# Function to recursively search for keys in a nested dictionary
def find_keys(obj, target_key, path=None, results=None):
    """Recursively collect (path, value) pairs for every ``target_key`` in ``obj``.

    NOTE: a function with the same name and behaviour is defined in an earlier
    cell of this notebook; running this cell rebinds ``find_keys`` to this copy.

    Only dicts and lists are descended into. Each path is a list of dict keys
    and "[i]" list-index markers ending with the matching key. The accumulator
    list ``results`` is returned (a new one is created when it is omitted).
    """
    if path is None:
        path = []
    if results is None:
        results = []

    if isinstance(obj, list):
        # Descend into each element, recording its position as "[i]"
        for index, member in enumerate(obj):
            find_keys(member, target_key, path + [f"[{index}]"], results)
        return results

    if isinstance(obj, dict):
        for name, child in obj.items():
            child_path = path + [name]
            if name == target_key:
                results.append((child_path, child))
            # The value under a matching key is searched as well
            find_keys(child, target_key, child_path, results)

    return results

# Find all 'session' keys and values
# NOTE(review): `info` is not assigned in this cell; it must already exist in the kernel.
session_results = find_keys(info, 'session')
print(f"Found {len(session_results)} 'session' keys")

# Find all 'id' keys and values
id_results = find_keys(info, 'id')
print(f"Found {len(id_results)} 'id' keys")

# Create a mapping from ID to sessions
id_to_sessions = defaultdict(set)

# Extract session identifiers and IDs
# This approach depends on how your data is structured
# We'll need to extract the session information and match it with IDs
# This is a generalized approach that might need adjustment based on your actual data structure

# Option 1: If IDs and sessions are at the same level in the hierarchy
# Index the IDs by their parent path once, so each session only looks up the
# IDs stored in the same parent instead of scanning every ID.
ids_by_parent = defaultdict(list)
for id_path, id_value in id_results:
    ids_by_parent['.'.join(map(str, id_path[:-1]))].append(id_value)

for path, session_value in session_results:
    parent_str = '.'.join(map(str, path[:-1]))
    for id_value in ids_by_parent.get(parent_str, ()):
        # Store the session as a string, exactly like Option 2 below, so one
        # session is never counted twice (e.g. 7 and '7') and the sets stay sortable.
        id_to_sessions[id_value].add(str(session_value))

# Option 2: If the structure is more complex, we might need a different approach
# For example, if sessions contain IDs or vice versa
for path, session_value in session_results:
    # Try to extract session identifier - this depends on your data structure
    session_id = str(session_value)  # Adjust based on your data

    # Path of the object that holds this session
    session_parent = path[:-1]

    # Find IDs that might be related to this session
    for id_path, id_value in id_results:
        id_parent = id_path[:-1]

        # The ID is inside the session's object (or vice versa) when one parent
        # path is a leading part of the other. Paths are compared component by
        # component; a substring test on joined strings would also match
        # unrelated branches whose names merely contain one another.
        shared = min(len(session_parent), len(id_parent))
        if session_parent[:shared] == id_parent[:shared]:
            id_to_sessions[id_value].add(session_id)

# Display the results
print("\nNumber of sessions per ID:")
for id_val, sessions in id_to_sessions.items():
    print(f"ID {id_val}: {len(sessions)} sessions - {sorted(sessions)}")

# Show IDs with multiple sessions
print("\nIDs appearing in multiple sessions:")
multi_session_ids = {id_val: sessions for id_val, sessions in id_to_sessions.items() if len(sessions) > 1}
for id_val, sessions in sorted(multi_session_ids.items(), key=lambda x: len(x[1]), reverse=True):
    print(f"ID {id_val}: {len(sessions)} sessions - {sorted(sessions)}")

# Calculate statistics
if id_to_sessions:
    session_counts = [len(sessions) for sessions in id_to_sessions.values()]
    avg_sessions = sum(session_counts) / len(session_counts)
    max_sessions = max(session_counts)
    min_sessions = min(session_counts)
    
    print(f"\nStatistics:")
    print(f"Average sessions per ID: {avg_sessions:.2f}")
    print(f"Maximum sessions per ID: {max_sessions}")
    print(f"Minimum sessions per ID: {min_sessions}")
    print(f"Total unique IDs: {len(id_to_sessions)}")
    
    # Count IDs by number of sessions
    session_count_distribution = defaultdict(int)
    for count in session_counts:
        session_count_distribution[count] += 1
    
    print("\nDistribution of IDs by number of sessions:")
    for count, num_ids in sorted(session_count_distribution.items()):
        print(f"{count} session(s): {num_ids} IDs")
else:
    print("No ID-session relationships found with the current approach.")
    print("You may need to adjust the code based on the specific structure of your data.")

# Exploring Zarr File Contents

This notebook will open and explore the zarr file at `/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/vr2p.zarr`

In [None]:
# Import necessary libraries
import zarr
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Open the zarr file
zarr_path = '/nrs/spruston/Gaby_imaging/raw/M54/multi_day_demix/vr2p.zarr'
z = zarr.open(zarr_path, mode='r')

# Display the root group
print("Root group structure:")
print(z)

In [None]:
# List all groups and arrays in the zarr file
def explore_zarr(group, path=''):
    """Print every group and array reachable from ``group``, depth first.

    Args:
        group: Object exposing ``keys()`` and item access by key.
        path: Prefix prepended to each printed item path.

    Groups are printed and then descended into; arrays are printed with their
    shape, dtype and chunks. Items of any other type are ignored.
    """
    for key in group.keys():
        item_path = f"{path}/{key}"
        item = group[key]

        if isinstance(item, zarr.Group):
            print(f"Group: {item_path}")
            explore_zarr(item, item_path)
            continue

        if isinstance(item, zarr.Array):
            print(f"Array: {item_path}, Shape: {item.shape}, Dtype: {item.dtype}, Chunks: {item.chunks}")

print("Full zarr structure:")
explore_zarr(z)

In [None]:
# Get attributes if available
print("\nRoot attributes:")
try:
    # Print each key/value pair exposed by the root object's `attrs`
    for key, value in z.attrs.items():
        print(f"{key}: {value}")
except AttributeError:
    # Any AttributeError raised inside the loop lands here, e.g. when `z` has no `attrs`
    print("No attributes found at root level")

In [None]:
# Additional example: Examining a specific chunk of data
"""
# If you have a large dataset, you might want to examine specific chunks
# array_path = 'example_array_path'  # Replace with actual path
# array = z[array_path]

# # Determine chunk size and shape
# print(f"Chunk shape: {array.chunks}")

# # Access a specific chunk (for a 2D array)
# if len(array.shape) == 2:
#     # Get a specific chunk - adjust indices based on your data
#     chunk_data = array[0:array.chunks[0], 0:array.chunks[1]]
#     
#     plt.figure(figsize=(10, 8))
#     plt.imshow(chunk_data)
#     plt.colorbar()
#     plt.title(f'First chunk of {array_path}')
#     plt.show()
"""