In [1]:
%matplotlib inline
%reload_ext autoreload

In [2]:
%conda list replay_trajectory_classification

# packages in environment at /home/edeno/miniconda3/envs/pose_analysis:
#
# Name                    Version                   Build  Channel
replay_trajectory_classification 0.9.11.dev0        pyh7b7c402_0    edeno

Note: you may need to restart the kernel to use updated packages.


In [3]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import seaborn as sns
import logging

import logging
import sys

def setup_logging(epoch_key, date_format='%d-%b-%y %H:%M:%S', format='%(asctime)s %(message)s'):
    animal, day, epoch = epoch_key
    log_filename = f"{animal}_{day:02d}_{epoch:02d}_test_speed_cpu.log"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format, datefmt=date_format)

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(formatter)

    file_handler = logging.FileHandler(log_filename)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stdout_handler)

sns.set_context("talk")

In [4]:
epoch_key = 'Jaq', 3, 12

In [5]:
setup_logging(epoch_key)

In [6]:
from src.load_data import load_data

logging.info('CPU Dask')
data = load_data(epoch_key,
                 position_to_linearize=['nose_x', 'nose_y'],
                 max_distance_from_well=5,
                 min_distance_traveled=30)

03-Oct-21 11:10:32 CPU Dask
03-Oct-21 11:10:32 Loading position info...
03-Oct-21 11:11:05 Loading multiunit...
03-Oct-21 11:11:30 Loading spikes...
03-Oct-21 11:12:57 Finding ripple times...


#### CPU

In [18]:
from replay_trajectory_classification import ClusterlessClassifier
from src.parameters import WTRACK_EDGE_ORDER, WTRACK_EDGE_SPACING
import pprint

continuous_transition_types = (
    [['random_walk', 'uniform'],
     ['uniform',     'uniform']])


clusterless_algorithm = 'multiunit_likelihood_integer'
clusterless_algorithm_params = {
    'mark_std': 20.0,
    'position_std': 8.0,
    'chunks': (100, 4),
}

classifier_parameters = {
    'movement_var': 6.0,
    'replay_speed': 1,
    'place_bin_size': 2.5,
    'continuous_transition_types': continuous_transition_types,
    'discrete_transition_diag': 0.968,
    'clusterless_algorithm': clusterless_algorithm,
    'clusterless_algorithm_params': clusterless_algorithm_params
}

logging.info(pprint.pprint(classifier_parameters))

{'clusterless_algorithm': 'multiunit_likelihood_integer',
 'clusterless_algorithm_params': {'chunks': (100, 4),
                                  'mark_std': 20.0,
                                  'position_std': 8.0},
 'continuous_transition_types': [['random_walk', 'uniform'],
                                 ['uniform', 'uniform']],
 'discrete_transition_diag': 0.968,
 'movement_var': 6.0,
 'place_bin_size': 2.5,
 'replay_speed': 1}
03-Oct-21 20:54:03 None


In [19]:
from dask.distributed import Client

# for virga
client = Client(n_workers=8, threads_per_worker=8, processes=True)

client

0,1
Client  Scheduler: tcp://127.0.0.1:40981  Dashboard: http://127.0.0.1:37159/status,Cluster  Workers: 8  Cores: 64  Memory: 1.62 TB


In [21]:
logging.info(client)

03-Oct-21 20:54:39 <Client: 'tcp://127.0.0.1:40981' processes=8 threads=64, memory=1.62 TB>


In [None]:
state_names = ['Continuous', 'Fragmented']

classifier = ClusterlessClassifier(**classifier_parameters)
classifier.fit(
    position=data["position_info"].linear_position,
    multiunits=data["multiunits"],
    track_graph=data["track_graph"],
    edge_order=WTRACK_EDGE_ORDER,
    edge_spacing=WTRACK_EDGE_SPACING,
)

results = classifier.predict(
    data["multiunits"],
    time=data["position_info"].index / np.timedelta64(1, "s"),
    state_names=state_names,
    use_gpu=False
)
logging.info('Done...')

03-Oct-21 20:54:48 Fitting initial conditions...
03-Oct-21 20:54:48 Fitting state transition...
03-Oct-21 20:54:49 Fitting multiunits...
03-Oct-21 20:55:01 Estimating likelihood...


In [None]:
import matplotlib.pyplot as plot

time_slice = slice(100_000, 150_000)

(results
 .acausal_posterior
 .sum('state')
 .isel(time=time_slice)
 .plot(x='time', y='position', robust=True, size=10, aspect=2, cmap='bone_r'))

plt.scatter(data['position_info'].iloc[time_slice].index / np.timedelta64(1, 's'),
         data['position_info'].iloc[time_slice].linear_position,
         color='magenta', s=1)

In [11]:
classifier

ClusterlessClassifier(clusterless_algorithm='multiunit_likelihood_integer',
                      clusterless_algorithm_params={'chunks': 1000,
                                                    'mark_std': 20.0,
                                                    'position_std': 8.0},
                      continuous_transition_types=[['random_walk', 'uniform'],
                                                   ['uniform', 'uniform']],
                      discrete_transition_diag=0.968,
                      discrete_transition_type='strong_diagonal',
                      infer_track_interior=True,
                      initial_conditions_type='uniform_on_track',
                      movement_var=6.0, place_bin_size=2.5, position_range=None,
                      replay_speed=1)

In [12]:
np.arange(10, dtype=np.int64).nbytes

80