In [1]:
%matplotlib inline
%reload_ext autoreload

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import seaborn as sns
import logging

import logging
import sys

def setup_logging(epoch_key, date_format='%d-%b-%y %H:%M:%S', format='%(asctime)s %(message)s'):
    """Send INFO-level records from the root logger to stdout and a per-epoch log file.

    Parameters
    ----------
    epoch_key : tuple
        ``(animal, day, epoch)``. ``day`` and ``epoch`` must be integers because
        they are zero-padded into the log filename.
    date_format : str, optional
        ``datefmt`` passed to :class:`logging.Formatter`.
    format : str, optional
        Record format string passed to :class:`logging.Formatter`.

    Returns
    -------
    logging.Logger
        The configured root logger.

    Notes
    -----
    The log file ``{animal}_{day:02d}_{epoch:02d}_test_speed_cpu.log`` is opened
    relative to the current working directory (``FileHandler`` appends by
    default). Calling this function again, e.g. by re-running the notebook cell,
    replaces the handlers installed by the previous call instead of stacking
    duplicates.
    """
    animal, day, epoch = epoch_key
    log_filename = f"{animal}_{day:02d}_{epoch:02d}_test_speed_cpu.log"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format, datefmt=date_format)

    # Without this, each re-run adds another file/stdout handler pair to the
    # root logger, and every message is emitted once per accumulated pair. The
    # stale FileHandlers would also keep their files open. Only handlers tagged
    # by a previous call are removed, so handlers installed by others survive.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, '_installed_by_setup_logging', False):
            logger.removeHandler(old_handler)
            old_handler.close()

    file_handler = logging.FileHandler(log_filename)
    stdout_handler = logging.StreamHandler(sys.stdout)

    # Same attachment order as before: file handler first, then stdout.
    for handler in (file_handler, stdout_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        handler._installed_by_setup_logging = True
        logger.addHandler(handler)

    return logger

# Use seaborn's "talk" plotting context for all figures in this notebook.
sns.set_context(context="talk")

In [3]:
epoch_key = ('Jaq', 3, 12)  # (animal, day, epoch)

In [4]:
# Attach stdout and per-epoch file handlers, then mark the start of this run.
setup_logging(epoch_key=epoch_key)
logging.info('Numba KDE')

In [5]:
from dask.distributed import Client

# Local Dask cluster sized for virga: 16 worker processes x 6 threads each.
client = Client(
    processes=True,
    n_workers=16,
    threads_per_worker=6,
)

client

0,1
Client  Scheduler: tcp://127.0.0.1:35589  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 16  Cores: 96  Memory: 2.16 TB


In [None]:
# Record the client's string summary in the log file as well as on stdout.
logging.info('%s', client)

In [6]:
from src.load_data import load_data


load_data_kwargs = {
    'position_to_linearize': ['nose_x', 'nose_y'],
    'max_distance_from_well': 5,
    'min_distance_traveled': 30,
}
data = load_data(epoch_key, **load_data_kwargs)

19-Sep-21 16:20:57 Loading position info...
19-Sep-21 16:21:34 Loading multiunit...
19-Sep-21 16:22:06 Loading spikes...
19-Sep-21 16:23:55 Finding ripple times...


In [7]:
# Display the Dask dashboard URL for monitoring the workers during the run.
client.dashboard_link

'http://127.0.0.1:8787/status'

### Numba KDE

In [8]:
from replay_trajectory_classification.misc import NumbaKDE
from replay_trajectory_classification import ClusterlessClassifier
from src.parameters import WTRACK_EDGE_ORDER, WTRACK_EDGE_SPACING

import pprint

# 2x2 grid of continuous-state transition types.
# NOTE(review): presumably indexed [from_state][to_state] in the order of
# `state_names` used in `classifier.predict` below — confirm against
# replay_trajectory_classification.
continuous_transition_types = [
    ['random_walk', 'uniform'],
    ['uniform',     'uniform'],
]

clusterless_algorithm = 'multiunit_likelihood'
clusterless_algorithm_params = {
    'model': NumbaKDE,
    'model_kwargs': {
        # amplitude 1, amplitude 2, amplitude 3, amplitude 4, position
        'bandwidth': np.array([20.0, 20.0, 20.0, 20.0, 8.0]),
    },
}

classifier_parameters = {
    'movement_var': 6.0,
    'replay_speed': 1,
    'place_bin_size': 2.5,
    'continuous_transition_types': continuous_transition_types,
    'discrete_transition_diag': 0.968,
    'clusterless_algorithm': clusterless_algorithm,
    'clusterless_algorithm_params': clusterless_algorithm_params,
}

# pprint.pprint writes to stdout and returns None, so logging its result only
# records "None". pformat returns the formatted text, so it reaches the log file.
logging.info(pprint.pformat(classifier_parameters))

19-Sep-21 16:24:52 Cupy is not installed. Required if using gpu state space.


In [None]:

state_names = ['Continuous', 'Fragmented']

position_info = data["position_info"]
multiunits = data["multiunits"]

# Fit the clusterless decoder on linearized position and multiunit marks.
classifier = ClusterlessClassifier(**classifier_parameters)
classifier.fit(
    position=position_info.linear_position,
    multiunits=multiunits,
    track_graph=data["track_graph"],
    edge_order=WTRACK_EDGE_ORDER,
    edge_spacing=WTRACK_EDGE_SPACING,
)

# Convert the timedelta index to float seconds for the decoder's time axis.
time_in_seconds = position_info.index / np.timedelta64(1, "s")
results_numba_kde = classifier.predict(
    multiunits,
    time=time_in_seconds,
    state_names=state_names,
)
logging.info('Done...')

19-Sep-21 16:24:53 Numba KDE
19-Sep-21 16:24:54 Fitting initial conditions...
19-Sep-21 16:24:54 Fitting state transition...
19-Sep-21 16:24:55 Fitting multiunits...
19-Sep-21 16:24:59 Estimating likelihood...


In [None]:
# Release the Dask client created at the start of the notebook.
client.close()