In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and select mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import seaborn as sns
import logging

# Seaborn plotting context used for the figures in this notebook.
sns.set_context("talk")

# Log lines are rendered as "<timestamp> <message>",
# e.g. "17-Jan-21 16:34:13 Loading multiunits...".
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(
    format=FORMAT,
    datefmt='%d-%b-%y %H:%M:%S',
    level='INFO',
)

In [3]:
# Recording session to analyse, identified by animal name, day and epoch.
animal, day, epoch = 'chimi', 5, 2
epoch_key = (animal, day, epoch)

In [4]:
from src.load_data import load_data


# Load the data for the selected epoch. Later cells read the
# "position_info", "multiunits" and "track_graph" entries of `data`.
data = load_data(epoch_key)

17-Jan-21 16:34:13 Loading position information and linearizing...
17-Jan-21 16:34:16 Loading multiunits...
17-Jan-21 16:34:31 Loading theta...


In [8]:
from replay_trajectory_classification import ClusterlessClassifier
from sklearn.model_selection import KFold
from src.parameters import EDGE_ORDER, EDGE_SPACING

# 2x2 table of continuous transition model names. Rows/columns presumably
# correspond to the classifier's discrete states -- TODO confirm against the
# replay_trajectory_classification documentation.
continuous_transition_types = [
    ['random_walk', 'uniform'],
    ['uniform', 'uniform'],
]

# Keyword arguments passed unchanged to every ClusterlessClassifier built below.
classifier_parameters = dict(
    movement_var=6.0,
    replay_speed=1,
    place_bin_size=2.5,
    continuous_transition_types=continuous_transition_types,
    discrete_transition_diag=0.968,
    model_kwargs=dict(
        bandwidth=np.array([20.0, 20.0, 20.0, 20.0, 8.0]),
    ),
)

# K-fold splitter with scikit-learn defaults (the logged run shows five folds).
cv = KFold()

position_info = data["position_info"]

# A fresh classifier is fitted on the training samples of each fold.
# NOTE: the `test` indices are not used inside this loop, and `classifier` is
# rebound on every pass, so after the loop it refers to the last fold's model.
for fold_ind, (train, test) in enumerate(cv.split(position_info.index)):
    logging.info(f'Fitting Fold #{fold_ind + 1}...')

    train_position = position_info.iloc[train].linear_position
    train_multiunits = data["multiunits"].isel(time=train)

    classifier = ClusterlessClassifier(**classifier_parameters)
    classifier.fit(
        position=train_position,
        multiunits=train_multiunits,
        track_graph=data["track_graph"],
        edge_order=EDGE_ORDER,
        edge_spacing=EDGE_SPACING,
    )

17-Jan-21 16:35:57 Fitting Fold #1...
17-Jan-21 16:35:58 Fitting initial conditions...
17-Jan-21 16:35:59 Fitting state transition...
17-Jan-21 16:35:59 Fitting multiunits...
17-Jan-21 16:36:05 Fitting Fold #2...
17-Jan-21 16:36:05 Fitting initial conditions...
17-Jan-21 16:36:05 Fitting state transition...
17-Jan-21 16:36:06 Fitting multiunits...
17-Jan-21 16:36:11 Fitting Fold #3...
17-Jan-21 16:36:12 Fitting initial conditions...
17-Jan-21 16:36:12 Fitting state transition...
17-Jan-21 16:36:13 Fitting multiunits...
17-Jan-21 16:36:18 Fitting Fold #4...
17-Jan-21 16:36:19 Fitting initial conditions...
17-Jan-21 16:36:19 Fitting state transition...
17-Jan-21 16:36:19 Fitting multiunits...
17-Jan-21 16:36:24 Fitting Fold #5...
17-Jan-21 16:36:25 Fitting initial conditions...
17-Jan-21 16:36:25 Fitting state transition...
17-Jan-21 16:36:26 Fitting multiunits...


In [9]:
# Persist the currently bound `classifier` (the one left by the fitting loop)
# under a name built from the epoch key: <animal>_<day>_<epoch>_model.pkl,
# with day and epoch zero-padded to two digits.
model_filename = "{0}_{1:02d}_{2:02d}_model.pkl".format(*epoch_key)
classifier.save_model(model_filename)

In [10]:
import xarray as xr

# Open the stored results dataset for this epoch. The file name is derived
# from `epoch_key` (same <animal>_<day>_<epoch> pattern used for the model
# file) rather than hard-coded, so changing the epoch cannot silently load
# another session's results. For ('chimi', 5, 2) this resolves to
# 'chimi_05_02_results.nc'.
results_filename = (
    f"{epoch_key[0]}_{epoch_key[1]:02d}_{epoch_key[2]:02d}_results.nc")
results = xr.open_dataset(results_filename)
results

In [13]:
from trajectory_analysis_tools import (get_ahead_behind_distance,
                                       get_trajectory_data)

# Sum the acausal posterior over the "state" dimension, leaving one
# distribution that no longer distinguishes between states.
posterior = results["acausal_posterior"].sum(dim="state")

track_graph = data['track_graph']

# `trajectory_data` is unpacked positionally into get_ahead_behind_distance,
# so its element order must match what that function expects.
trajectory_data = get_trajectory_data(
    posterior, track_graph, classifier, data['position_info'])

ahead_behind_distance = get_ahead_behind_distance(
    track_graph, *trajectory_data)

In [14]:
from trajectory_analysis_tools import get_highest_posterior_threshold, get_HPD_spatial_coverage

# Coverage level passed to the threshold computation (presumably the fraction
# of posterior mass inside the highest-posterior-density region -- TODO
# confirm against trajectory_analysis_tools).
hpd_coverage = 0.95
hpd_threshold = get_highest_posterior_threshold(
    posterior, coverage=hpd_coverage)

# Spatial coverage derived from the posterior and the threshold above.
spatial_coverage = get_HPD_spatial_coverage(posterior, hpd_threshold)