In [1]:
%reload_ext autoreload
%autoreload 2

from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
from syd import Viewer, make_viewer

from vrAnalysis2.database import get_database
from vrAnalysis2.helpers import Timer, color_violins
from vrAnalysis2.sessions import create_b2session
from vrAnalysis2.processors.spkmaps import SpkmapProcessor
from vrAnalysis2.tracking import Tracker

def make_processor(mouse_name, date, session_id, spks_type="significant"):
    """Create a SpkmapProcessor wrapping a single B2 session.

    Parameters
    ----------
    mouse_name, date, session_id
        Identify the session; forwarded positionally to ``create_b2session``.
    spks_type : str, default "significant"
        Passed to ``create_b2session`` as the ``spks_type`` session option.

    Returns
    -------
    SpkmapProcessor
        Processor built around the newly created session object.
    """
    session_options = {"spks_type": spks_type}
    b2session = create_b2session(mouse_name, date, session_id, session_options)
    return SpkmapProcessor(b2session)

# Database handles for the 'vrSessions' and 'vrMice' tables.
sessiondb = get_database("vrSessions")
mousedb = get_database("vrMice")

# Unique names of the mice whose table entry has trackerExists=True.
tracked_mice = (
    mousedb
    .get_table(trackerExists=True)["mouseName"]
    .unique()
)

In [2]:
from typing import Optional

class ReliabilityViewer(Viewer):
    """Interactive ``syd`` viewer of ROI reliability across tracked sessions.

    For the selected mouse, every session of its ``Tracker`` is processed with
    ``SpkmapProcessor.get_reliability``. ROIs are split using the
    ``mpciROIs.redCellIdx`` mask into RED (mask True) and CTL (mask False).
    The mask is assumed to be boolean because it is inverted with ``~``.

    ``plot`` draws one column per environment:

    - Top row: the mean reliability of each group per session.
    - Bottom row: the RED - CTL difference of those means.

    Viewer parameters
    -----------------
    mouse : which tracked mouse to show.
    use_session_filters : if True, restrict to ROIs in ``session.idx_rois``;
        otherwise use every ROI.
    show_distribution : if True, also draw split violins (CTL on the low side,
        RED on the high side) of the per-ROI reliability for each session.
    """

    # Smoothing width forwarded to ``SpkmapProcessor.get_reliability`` via
    # ``params=dict(smooth_width=...)``. Its units are defined by that method.
    smooth_width: float = 5.0

    def __init__(self, tracked_mice: list[str]):
        self.tracked_mice = list(tracked_mice)
        # Lazily-filled per-mouse caches (see get_tracker / get_reliability).
        self._reliability = {mouse: None for mouse in self.tracked_mice}
        self.trackers = {mouse: None for mouse in self.tracked_mice}
        self.add_selection("mouse", value=self.tracked_mice[0], options=self.tracked_mice)
        self.add_boolean("use_session_filters", value=False)
        # plot() reads this flag, so it must be registered as a viewer parameter.
        self.add_boolean("show_distribution", value=False)

    def get_tracker(self, mouse: str) -> Tracker:
        """Return the Tracker for ``mouse``, constructing and caching it on first use."""
        if self.trackers[mouse] is None:
            self.trackers[mouse] = Tracker(mouse)
        return self.trackers[mouse]

    def get_environments(self, track: Tracker) -> np.ndarray:
        """Return the sorted unique environments across all of ``track``'s sessions.

        Returns an empty integer array when the tracker has no sessions
        (``np.concatenate`` would otherwise fail on an empty list).
        """
        per_session = [np.asarray(session.environments) for session in track.sessions]
        if not per_session:
            return np.array([], dtype=int)
        return np.unique(np.concatenate(per_session))

    def get_reliability(
        self,
        mouse: str,
        exclude_environments: Optional[list[int] | tuple[int, ...] | int] = (-1,),
    ) -> dict:
        """Return RED/CTL reliability for every tracked session of ``mouse``.

        The expensive per-session computation runs once per mouse and is cached
        in ``self._reliability``. Exclusions are applied to a copy on each call,
        so the cache is never modified and different ``exclude_environments``
        values can be requested safely.

        Parameters
        ----------
        mouse : str
            One of ``self.tracked_mice``.
        exclude_environments : int, sequence of int, or None, default (-1,)
            Environment(s) to drop from the result. ``None`` or an empty
            sequence keeps everything. Environments that are not present are
            ignored.

        Returns
        -------
        dict
            ``environments``
                List of the remaining environments.
            ``reliability_ctl`` / ``reliability_red``
                env -> list of reliability arrays for CTL/RED ROIs within
                ``session.idx_rois``, one entry per session containing that env.
            ``reliability_ctl_all`` / ``reliability_red_all``
                The same, computed over all ROIs.
            ``sessions``
                env -> list of indices into ``Tracker.sessions``, aligned with
                the lists above.
        """
        if self._reliability[mouse] is None:
            self._reliability[mouse] = self._compute_reliability(mouse)
        return self._exclude_environments(self._reliability[mouse], exclude_environments)

    def _compute_reliability(self, mouse: str) -> dict:
        """Load reliability for every tracked session of ``mouse``, split by RED/CTL and environment."""
        track = self.get_tracker(mouse)
        environments = list(self.get_environments(track))
        reliability_ctl = {env: [] for env in environments}
        reliability_red = {env: [] for env in environments}
        reliability_ctl_all = {env: [] for env in environments}
        reliability_red_all = {env: [] for env in environments}
        sessions = {env: [] for env in environments}

        for isession, session in enumerate(tqdm(track.sessions, desc=f"{mouse} reliability")):
            idx_rois = session.idx_rois
            idx_red_all = session.loadone("mpciROIs.redCellIdx")
            idx_red = idx_red_all[idx_rois]

            smp = SpkmapProcessor(session)
            reliability_all = smp.get_reliability(
                use_session_filters=False, params=dict(smooth_width=self.smooth_width)
            )
            reliability_selected = reliability_all[:, idx_rois]

            # Assumes row ienv of reliability_all corresponds to session.environments[ienv].
            for ienv, env in enumerate(session.environments):
                reliability_red_all[env].append(reliability_all[ienv, idx_red_all])
                reliability_ctl_all[env].append(reliability_all[ienv, ~idx_red_all])
                reliability_red[env].append(reliability_selected[ienv, idx_red])
                reliability_ctl[env].append(reliability_selected[ienv, ~idx_red])
                sessions[env].append(isession)

        return dict(
            environments=environments,
            reliability_ctl=reliability_ctl,
            reliability_red=reliability_red,
            reliability_ctl_all=reliability_ctl_all,
            reliability_red_all=reliability_red_all,
            sessions=sessions,
        )

    @staticmethod
    def _exclude_environments(results: dict, exclude_environments) -> dict:
        """Return a shallow copy of ``results`` without the excluded environments.

        Missing environments are skipped rather than raising. ``0`` is treated
        as a real environment, not as "no exclusion".
        """
        if exclude_environments is None:
            excluded = set()
        elif isinstance(exclude_environments, (int, np.integer)):
            excluded = {exclude_environments}
        else:
            excluded = set(exclude_environments)

        filtered = {}
        for key, value in results.items():
            if isinstance(value, dict):
                filtered[key] = {env: list(vals) for env, vals in value.items() if env not in excluded}
            else:
                filtered[key] = [env for env in value if env not in excluded]
        return filtered

    @staticmethod
    def _draw_half_violin(ax, values, position, side: str, color: str) -> None:
        """Draw one half-violin of ``values`` at x=``position``.

        Skipped when the data cannot support a KDE: fewer than two finite
        values, or all values identical.
        """
        values = np.asarray(values, dtype=float)
        values = values[np.isfinite(values)]
        if values.size < 2 or np.all(values == values[0]):
            return
        parts = ax.violinplot(values, positions=[position], showextrema=True, side=side)
        color_violins(parts, facecolor=(color, 0.1))

    def plot(self, state):
        """Plot per-session reliability for the mouse selected in ``state``.

        Top row: mean reliability of CTL (black) and RED (red) ROIs per session,
        one column per environment, optionally with per-ROI violins behind it.
        Bottom row: RED mean minus CTL mean for each session.
        Returns the matplotlib Figure.
        """
        reliability = self.get_reliability(state["mouse"])
        environments = reliability["environments"]
        sessions = reliability["sessions"]
        if state["use_session_filters"]:
            reliability_ctl = reliability["reliability_ctl"]
            reliability_red = reliability["reliability_red"]
        else:
            reliability_ctl = reliability["reliability_ctl_all"]
            reliability_red = reliability["reliability_red_all"]
        show_distribution = state.get("show_distribution", False)

        figwidth = 3
        figheight = 3
        num_envs = len(environments)
        if num_envs == 0:
            # plt.subplots rejects zero columns; show a placeholder instead of crashing.
            fig, ax = plt.subplots(figsize=(figwidth, figheight), layout="constrained")
            ax.text(
                0.5,
                0.5,
                f"No environments for {state['mouse']}",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            ax.set_axis_off()
            return fig

        # squeeze=False keeps ``ax`` 2-D even with a single environment, so
        # ax[row, col] always works. Y-limits are shared per row because the
        # bottom row shows a difference, not a reliability value.
        fig, ax = plt.subplots(
            2,
            num_envs,
            figsize=(num_envs * figwidth, 2 * figheight),
            layout="constrained",
            sharex=True,
            sharey="row",
            squeeze=False,
        )
        for ienv, env in enumerate(environments):
            session_ids = sessions[env]
            if show_distribution:
                # Position violins at the same x (session index) as the mean traces.
                for session_id, ctl, red in zip(session_ids, reliability_ctl[env], reliability_red[env]):
                    self._draw_half_violin(ax[0, ienv], ctl, session_id, side="low", color="k")
                    self._draw_half_violin(ax[0, ienv], red, session_id, side="high", color="r")

            # Use arrays so the difference below is elementwise (lists cannot be
            # subtracted). A group with no ROIs in a session gives nan.
            ctl_mean = np.array([np.mean(r) if len(r) > 0 else np.nan for r in reliability_ctl[env]])
            red_mean = np.array([np.mean(r) if len(r) > 0 else np.nan for r in reliability_red[env]])

            ax[0, ienv].plot(session_ids, ctl_mean, color="black", label="CTL", marker="o")
            ax[0, ienv].plot(session_ids, red_mean, color="red", label="RED", marker="o")
            ax[0, ienv].legend()
            ax[0, ienv].set_title(f"{env}")

            ax[1, ienv].plot(session_ids, red_mean - ctl_mean, color="blue", label="DIFF", marker="o")
            ax[1, ienv].legend()
            ax[1, ienv].set_title(f"{env} (all)")
            ax[1, ienv].set_xlabel("Session")

        ax[0, 0].set_ylabel("Reliability")
        ax[1, 0].set_ylabel("RED - CTL")

        return fig

In [3]:
# Build the interactive reliability viewer over every tracked mouse.
rv = ReliabilityViewer(tracked_mice=tracked_mice)