In [1]:
%reload_ext autoreload
%autoreload 2

import random
import numpy as np
from matplotlib import pyplot as plt
from syd import make_viewer

from vrAnalysis.database import get_database
from vrAnalysis.sessions import B2Session
from vrAnalysis.processors import SpkmapProcessor
from dimilibi import Population

sessiondb = get_database("vrSessions")

from dimensionality_manuscript.regression_models.registry import PopulationRegistry

ModuleNotFoundError: No module named 'dimensionality_manuscript'

In [None]:
# Design notes for the population registry
#
# 1. All regression models must be compared on the same target neurons and the
#    same cross-validated timepoints. The RBF(POS) model additionally needs
#    train/train/val/test time splits (two training groups, one validation
#    group, one test group).
# 2. So a "registry" creates a Population object with the correct time splits
#    for each session and saves it; later analyses simply load it back.
#    - Population data are saved in a designated folder, under a file name
#      derived from the session name (plus the registry parameters).
#    - Whenever a session is analyzed, the registry is checked first: if the
#      population data exist they are loaded; otherwise a new population is
#      created and saved.


In [67]:
from vrAnalysis import files

# Root folder for local data; the registry cell below builds its cache paths under it.
files.local_data_path()

WindowsPath('D:/localData')

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from joblib import dump, load

# Folder layout: <local data>/dimensionality-manuscript/cache/population-registry
manuscript_path = files.local_data_path().joinpath("dimensionality-manuscript")
cache_path = manuscript_path.joinpath("cache")
registry_path = cache_path.joinpath("population-registry")

class TimeSplit:
    """Named indices for the four time-split groups (train/train/val/test)."""

    # The two training groups.
    train_0 = 0
    train_1 = 1
    # Held-out groups.
    val = 2
    test = 3
    # Both training groups together, for calls that accept several indices.
    train = (train_0, train_1)

@dataclass(frozen=True)
class RegistryParameters:
    """Split settings that define how a session's Population is built.

    The dataclass is frozen (and therefore hashable), so a parameter set can be
    folded into registry file names via ``params_hash``.

    Attributes
    ----------
    time_split_groups : int
        Number of time groups; forwarded as ``num_groups``.
    time_split_relative_size : tuple of int
        Relative size of each time group; forwarded as ``relative_size``.
        Any sequence is accepted and stored as a tuple so the instance stays
        hashable.
    time_split_chunks_per_group : int
        Forwarded as ``chunks_per_group``.
    cell_split_force_even : bool
        Forwarded as ``force_even`` in ``cell_split_prms``.
    """

    time_split_groups: int = 4
    time_split_relative_size: tuple[int, ...] = field(default_factory=lambda: (4, 4, 1, 1))
    time_split_chunks_per_group: int = 10
    cell_split_force_even: bool = False

    def __post_init__(self):
        # A list here would make hash(self) raise TypeError and break params_hash.
        # Frozen dataclasses block normal attribute assignment, hence object.__setattr__.
        object.__setattr__(self, "time_split_relative_size", tuple(self.time_split_relative_size))

    @property
    def time_split_prms(self) -> dict:
        """Keyword dict passed to Population as ``time_split_prms``."""
        return dict(
            num_groups=self.time_split_groups,
            relative_size=self.time_split_relative_size,
            chunks_per_group=self.time_split_chunks_per_group,
        )

    @property
    def cell_split_prms(self) -> dict:
        """Cell-split settings (currently only ``force_even``)."""
        return dict(
            force_even=self.cell_split_force_even,
        )

    @property
    def params_hash(self) -> int:
        """Hash of all parameters, used to tag registry file names.

        Every field is an int, a bool, or a tuple of ints. Python does not
        randomize these hashes per interpreter run (unlike ``str``), so the
        value is stable across kernel restarts on the same Python version.
        Adding a string-valued field would break that stability.
        """
        return hash(self)

class PopulationRegistry:
    """Create, cache, and reload per-session ``Population`` objects.

    Each session's population is stored on disk as the lightweight indices
    dictionary returned by ``Population.get_indices_dict()`` (written with
    joblib). On load it is rebuilt with ``Population.make_from_indices``, so
    every analysis of a session reuses the same splits.

    Parameters
    ----------
    registry_path : Path or str
        Folder where population files are stored. It is created on the first
        save if it does not exist yet.
    registry_params : RegistryParameters
        Split parameters. Their hash is part of each file name, so different
        parameter sets never share a cache file.

    Notes
    -----
    ``joblib.load`` is pickle-based; only load registry files you created.
    """

    def __init__(self, registry_path: Path, registry_params: RegistryParameters):
        # Accept str as well as Path; the "/" joins below require a Path.
        self.registry_path = Path(registry_path)
        self.registry_params = registry_params

    def get_population(self, session: B2Session) -> Population:
        """Get the population object for a session.

        If the population object already exists, it is loaded from the registry.
        Otherwise, a new population object is created and saved to the registry.

        Parameters
        ----------
        session : B2Session
            The session to get the population for.

        Returns
        -------
        npop : Population
            The population object for the session.
        """
        if self._check_population_exists(session):
            return self._load_population(session)
        npop = self._make_population(session)
        self._save_population(session, npop)
        return npop

    def _make_population(self, session: B2Session) -> Population:
        """Create a new population object for a session.

        Parameters
        ----------
        session : B2Session
            The session to create a population for.

        Returns
        -------
        npop : Population
            The population object for the session, built from ``session.spks.T``
            with this registry's time-split parameters.
        """
        # NOTE(review): registry_params.cell_split_prms is part of the file-name
        # hash but is not passed to Population here, so cell splits use
        # Population's own defaults. Confirm whether Population accepts
        # cell_split_prms and pass it if so.
        npop = Population(session.spks.T, time_split_prms=self.registry_params.time_split_prms)
        return npop

    def _save_population(self, session: B2Session, population: Population) -> None:
        """Save a population object to the registry.

        Only the population's indices dictionary is written, not the data itself.

        Parameters
        ----------
        session : B2Session
            The session to save the population for.
        population : Population
            The population object to save.
        """
        indices_dict = population.get_indices_dict()
        ppath = self._get_population_path(session)
        # joblib.dump does not create missing folders; make sure the registry exists.
        ppath.parent.mkdir(parents=True, exist_ok=True)
        dump(indices_dict, ppath)

    def _load_population(self, session: B2Session) -> Population:
        """Load a population object from the registry.

        Parameters
        ----------
        session : B2Session
            The session to load the population for.

        Returns
        -------
        population : Population
            The population rebuilt from the saved indices and ``session.spks.T``.
        """
        ppath = self._get_population_path(session)
        indices_dict = load(ppath)
        return Population.make_from_indices(indices_dict, session.spks.T)

    def _check_population_exists(self, session: B2Session) -> bool:
        """Check if a population exists for a session.

        Parameters
        ----------
        session : B2Session
            The session to check for a population.

        Returns
        -------
        bool
            True if the population file exists, False otherwise.
        """
        return self._get_population_path(session).exists()

    def _get_population_path(self, session: B2Session) -> Path:
        """Get the path to the population file for a session.

        Parameters
        ----------
        session : B2Session
            The session to get the population path for.

        Returns
        -------
        pathlib.Path
            ``<registry_path>/<unique id>.joblib``.
        """
        return self.registry_path / (self._get_unique_id(session) + ".joblib")

    def _get_unique_id(self, session: B2Session) -> str:
        """Get a unique identifier for a population.

        The identifier combines the session name with a hash of this registry's
        parameters, so each stored population is identified by both the session
        and the split settings.

        Parameters
        ----------
        session : B2Session
            The session to get the unique id for.

        Returns
        -------
        str
            The unique id, formatted as ``"<session name>_<params hash>"``.
        """
        # Presumably session_name is a sequence of strings (mouse, date, id); verify.
        session_name = ".".join(session.session_name)
        # Hash the parameters this registry was built with, not the defaults, so
        # different split settings never collide on one file.
        params_hash = self.registry_params.params_hash
        return f"{session_name}_{params_hash}"

3907744549172530631

In [75]:
# NOTE(review): uses `ses`, which is assigned in the session-selection cell
# (In [17]) placed *below* this one, so this cell fails on a fresh kernel
# unless that cell has run first. Displays the session name joined with ".".
ses.session_print(joinby=".")

'ATL020.2023-04-06.701'

In [17]:
# Pick one imaging session at random. The seeded generator makes the choice
# reproducible on Restart & Run All without touching the global `random` state;
# set SESSION_SEED = None for a fresh draw each run.
SESSION_SEED = 0
ises = sessiondb.iter_sessions(imaging=True)
ses = random.Random(SESSION_SEED).choice(ises)
ses

B2Session(mouse_name='ATL020', date='2023-04-06', session_id='701')


In [72]:
# Same values as the RegistryParameters defaults, written out explicitly.
time_split_prms = {
    "num_groups": 4,
    "relative_size": [4, 4, 1, 1],
    "chunks_per_group": 10,
}
npop = Population(
    ses.spks.T,
    time_split_prms=time_split_prms,
)

In [43]:
# Data for both training groups. Judging by the shapes, source/target are
# presumably the two cell halves (rows) over the selected timepoints (columns);
# confirm against dimilibi.Population.get_split_data.
source, target = npop.get_split_data(time_idx=list(TimeSplit.train))
print(source.shape, target.shape)

torch.Size([9018, 10574]) torch.Size([9017, 10574])


In [44]:
# Compare the total number of timepoints with the size of each time group.
num_timepoints = npop.data.shape[1]
print(num_timepoints)
# A comprehension avoids rebinding the global `source` from the previous cell.
ss = [npop.get_split_data(i)[0].shape[1] for i in range(time_split_prms["num_groups"])]
print(ss)
# Timepoints assigned to no time group. NOTE(review): presumably samples dropped
# between chunks by the time splitter; confirm in dimilibi.Population.
num_timepoints - sum(ss)

14206
[5287, 5287, 1322, 1320]


990