### Imports

In [7]:
import os
import sys

import caiman as cm
from pathlib import Path
from caiman.source_extraction.cnmf.cnmf import load_CNMF
import pickle
import numpy as np
import matplotlib.pyplot as plt
import glob
from ipywidgets import interact, widgets, interact_manual, HBox, Label

### Loading data

In [9]:
name = ""  # Insert name of folder
# Root of the preprocessed data. pathlib collapses the doubled slash that
# appears when `name` is left blank.
data_path = Path(f'/vol/battaglialab/imaging1/{name}/preprocessing')
sessions = {'456225_Freddy': ['20240501']}

# Session folder holding the CNMF results (neuron location data) and metrics.
# Built from `sessions` (first animal, its first session date) rather than
# re-typing the path, so the two cannot drift apart.
animal, session_dates = next(iter(sessions.items()))
data_path = data_path / animal / session_dates[0]

cnmf_file = data_path.joinpath('cnmf.hdf5')
cnmf = load_CNMF(cnmf_file)
ests = cnmf.estimates

metrics_file = data_path.joinpath('metrics.pickle')
# `with` closes the file handle once loading is done (it previously leaked).
# NOTE: unpickling can execute arbitrary code -- only load metrics files
# produced by our own preprocessing pipeline.
with open(metrics_file, 'rb') as file:
    metrics = pickle.load(file)

### Multiple neuron locator

In [10]:
from matplotlib.patches import Circle
from sklearn.cluster import KMeans

def _parse_neuron_input(neuron_input, n_available):
    """Parse a comma-separated string of neuron indices.

    Blank entries (an empty text box, a trailing comma) are skipped. On bad
    input a message is printed, so it shows under the interactive widget, and
    ``None`` is returned. Otherwise the list of integer indices is returned.
    """
    tokens = [token.strip() for token in neuron_input.split(',')]
    try:
        neurons = [int(token) for token in tokens if token]
    except ValueError:
        print(f"Could not parse {neuron_input!r}: enter comma-separated integer neuron indices")
        return None
    if not neurons:
        print("Enter at least one neuron index")
        return None
    out_of_range = [n for n in neurons if not 0 <= n < n_available]
    if out_of_range:
        print(f"Neuron indices must be between 0 and {n_available - 1}; got {out_of_range}")
        return None
    return neurons


def plot_neuron_location_with_clusters(neuron_input, n_clusters=3):
    """Outline selected neurons and group their centroids with K-means.

    Parameters
    ----------
    neuron_input : str
        Comma-separated positions within the accepted components
        (``ests.idx_components``), e.g. ``"0, 4, 12"``.
    n_clusters : int
        Number of K-means clusters fitted to the neuron centroids. Clustering
        only runs when there are more centroids than clusters.

    Reads the module-level CNMF ``ests`` estimates, draws a new matplotlib
    figure and shows it; returns nothing.
    """
    # Dense (pixels x accepted components) footprint matrix.
    good_footprints = ests.A[:, ests.idx_components].toarray()
    neurons = _parse_neuron_input(neuron_input, good_footprints.shape[1])
    if neurons is None:
        return

    # CaImAn flattens spatial footprints in Fortran (column-major) pixel
    # order, so reshape with order='F' to get the same orientation as CaImAn's
    # own plots (C order transposes or scrambles the field of view).
    combined_footprints = good_footprints.sum(axis=1).reshape(ests.dims, order='F')
    # Pixel coordinate grids are the same for every neuron; build them once.
    y_coords, x_coords = np.indices(combined_footprints.shape)

    plt.figure(figsize=(8, 8))
    # Named `cmap` so it does not shadow the module-level `cm` (caiman) alias.
    cmap = plt.get_cmap('tab20')

    handles = []
    centroids = []
    for i, neuron in enumerate(neurons):
        color = cmap(i / len(neurons))
        footprint = good_footprints[:, neuron].reshape(ests.dims, order='F')
        plt.contour(footprint, levels=1, colors=[color], linewidths=2, alpha=1)
        handles.append(plt.Line2D([], [], color=color, lw=2, label=f'Neuron {neuron}'))

        # Intensity-weighted centroid; all-zero footprints are left out.
        total_intensity = footprint.sum()
        if total_intensity > 0:
            centroids.append((
                (x_coords * footprint).sum() / total_intensity,
                (y_coords * footprint).sum() / total_intensity,
            ))

    centroids_array = np.array(centroids)
    if 1 <= n_clusters < len(centroids_array):
        # Explicit n_init keeps results stable across scikit-learn versions
        # (its default changed in 1.4); a fixed seed makes reruns reproducible.
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
        kmeans.fit(centroids_array)
        for k, center in enumerate(kmeans.cluster_centers_):
            cluster_color = plt.cm.tab20(k / n_clusters)
            distances = np.linalg.norm(centroids_array[kmeans.labels_ == k] - center, axis=1)
            # A cluster can end up empty when centroids coincide; avoid
            # calling max() on an empty array.
            cluster_radius = distances.max() if distances.size else 0.0
            circle = Circle(center, cluster_radius, color=cluster_color, fill=False,
                            linestyle='--', lw=2, label=f'Cluster {k+1}')
            plt.gca().add_patch(circle)
            plt.scatter(*center, color=cluster_color, s=100, marker='x', label=f'Cluster {k+1} Center')
            # The legend is built from explicit handles, so the cluster outline
            # must be added here or its label never appears.
            handles.append(circle)
    else:
        print("For clusters, enter more neuron samples than the selected amount of clusters")

    # Display combined footprint and legend
    plt.imshow(combined_footprints, cmap=plt.cm.gnuplot2, alpha=0.9)
    plt.legend(handles=handles, loc='upper right', bbox_to_anchor=(1.25, 1.01))
    plt.title(f'Neuron Locations with {n_clusters} clusters')
    plt.tight_layout()
    plt.show()

# Interactive widget to select neurons and specify the number of clusters.
# Valid entries are positions 0..max_neuron_index within ests.idx_components.
max_neuron_index = len(ests.idx_components) - 1

neuron_input = widgets.Text(
    value='',
    placeholder=f'Enter the neurons to locate (0-{max_neuron_index})',
    description='Neurons: ',
    disabled=False,
)

# .options() returns a configured copy, so the shared interact_manual factory
# (used by any other cell) keeps its default button label.
locate_neurons = interact_manual.options(manual_name='Locate neurons')
locate_neurons(
    plot_neuron_location_with_clusters,
    neuron_input=neuron_input,
    n_clusters=widgets.IntSlider(min=1, max=10, value=3),
);


interactive(children=(Text(value='', continuous_update=False, description='Neurons: ', placeholder='Enter the …