In [3]:
%reload_ext autoreload
%autoreload 2

import joblib
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt

from vrAnalysis2.database import get_database
from vrAnalysis2.helpers import Timer, print_all_keys
from vrAnalysis2.sessions import create_b2session
from vrAnalysis2.processors.spkmaps import SpkmapProcessor
from vrAnalysis2.tracking import Tracker
from vrAnalysis2 import files
from vrAnalysis2 import helpers
from syd import Viewer

sessiondb = get_database("vrSessions")
mousedb = get_database("vrMice")
tracked_mice = mousedb.get_table(tracked=True)["mouseName"].unique()

In [None]:
class TrackerViewer(Viewer):
    def __init__(self, tracked_mice: list[str]):
        self.tracked_mice = list(tracked_mice)
        self.add_selection("mouse", value=self.tracked_mice[0], options=self.tracked_mice)
        self.add_boolean("distribution", value=True)
        self._trackers = {mouse: None for mouse in self.tracked_mice}
        self._mouse_data = {mouse: None for mouse in self.tracked_mice}

    def get_tracker(self, mouse: str) -> Tracker:
        if self._trackers[mouse] is None:
            self._trackers[mouse] = Tracker(mouse)
        return self._trackers[mouse]
    
    def get_mouse_data(self, mouse: str):
        if self._mouse_data[mouse] is not None:
            return self._mouse_data[mouse]
        
        tracker = self.get_tracker(mouse)
        num_clusters = tracker.cluster_silhouettes.shape[0]
        num_sessions = len(tracker.sessions)
        cluster_in_session = np.zeros((num_clusters, num_sessions), dtype=bool)
        red_assignment = np.full((num_clusters, num_sessions), np.nan)
        cluster_silhouettes = tracker.cluster_silhouettes

        for isession, session in enumerate(tracker.sessions):
            idx_valid_labels = tracker.labels[isession] != -1
            valid_labels = tracker.labels[isession][idx_valid_labels]
            cluster_in_session[valid_labels, isession] = True
            red_values = session.get_red_idx()
            red_assignment[valid_labels, isession] = red_values[idx_valid_labels]

        self._mouse_data[mouse] = (cluster_in_session, red_assignment, cluster_silhouettes)
        return self._mouse_data[mouse]

    def plot(self, state):
        cluster_in_session, red_assignment, cluster_silhouettes = self.get_mouse_data(state["mouse"])

        num_sessions = np.sum(cluster_in_session, axis=1)

        any_red = np.any(red_assignment == 1, axis=1)
        red_assignment = red_assignment[any_red]
        num_sessions_red = num_sessions[any_red]
        cluster_silhouettes = cluster_silhouettes[any_red]
        fraction_assignment = []
        for inumses in range(2, num_sessions_red.max() + 1):
            itracked = num_sessions_red == inumses
            fraction = np.sum(red_assignment[itracked] == 1, axis=1) / inumses
            fraction_assignment.append(fraction)

        fraction_sessions_red = np.sum(red_assignment == 1, axis=1) / num_sessions_red

        fig, ax = plt.subplots(1, 3, figsize=(8, 3), layout="constrained")
        ax[0].hist(num_sessions, bins=np.arange(0, num_sessions.max() + 1) + 0.5, rwidth=0.9)
        ax[0].set_xlabel("Number of sessions per cluster")
        ax[0].set_ylabel("Number of clusters")

        if state["distribution"]:
            ax[1].violinplot(fraction_assignment, range(2, num_sessions_red.max() + 1), showmeans=False, showextrema=False)
        else:
            ax[1].plot(range(2, num_sessions_red.max() + 1), [np.mean(frac) for frac in fraction_assignment])
        ax[1].set_xlabel("Number of sessions per cluster")
        ax[1].set_ylabel("Fraction of sessions per cluster")


        fraction_bins = np.linspace(0, 1, 10)
        idx_fraction = np.digitize(fraction_sessions_red, fraction_bins, right=True)
        bins = np.linspace(-1, 1, 9)
        counts = []
        for ifraction in range(1, len(fraction_bins)):
            idx = np.where(idx_fraction == ifraction)[0]
            ccounts = np.histogram(cluster_silhouettes[idx], bins=bins)[0]
            counts.append(ccounts / np.sum(ccounts))
        ax[2].imshow(np.stack(counts, axis=0), cmap="plasma", aspect="auto", interpolation="none", origin="lower", extent=[bins[0], bins[-1], 0, 1])
        ax[2].set_xlabel("Cluster silhouette")
        ax[2].set_ylabel("Fraction of sessions per cluster")

        return fig
    
viewer = TrackerViewer(tracked_mice).deploy(env="notebook")


HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), Dropdown(description='mouse', la…

  counts.append(ccounts / np.sum(ccounts))


In [8]:
# NOTES FOR TOMORROW
# Gotta include the red cell in the tracking (and look at cluster quality stuffs etc)
# Maybe consider only some clusters and also ignore the red cell inclusion based on sample silhouette?

In [None]:
from vrAnalysis2.tracking import Tracker
tracker = Tracker("ATL022")

In [6]:
print(print_all_keys(tracker.rundata[0], ignore_keys=["_system_info"]))

data : dict
    _verbose : bool
    type : dict
        __module__ : str
        __doc__ : str
        __init__ : dict
        import_FOV_images : dict
        import_spatialFootprints : dict
        import_neuropil_masks : dict
        _make_shifts : dict
        _transform_statFile_to_spatialFootprints : dict
            __module__ : str
            __name__ : str
            __qualname__ : str
            __doc__ : str
            __annotations__ : dict
                frame_height_width : dict
                    __repr__ : str
                stat : dict
                    __new__ : builtin_function_or_method
                    __repr__ : wrapper_descriptor
                    __str__ : wrapper_descriptor
                    __lt__ : wrapper_descriptor
                    __le__ : wrapper_descriptor
                    __eq__ : wrapper_descriptor
                    __ne__ : wrapper_descriptor
                    __gt__ : wrapper_descriptor
                    __ge__ : wrapper_d