In [1]:
%matplotlib inline
%reload_ext autoreload

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import seaborn as sns
import logging

import logging
import sys

def setup_logging(epoch_key, date_format='%d-%b-%y %H:%M:%S', format='%(asctime)s %(message)s'):
    animal, day, epoch = epoch_key
    log_filename = f"{animal}_{day:02d}_{epoch:02d}_test_speed_cpu.log"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format, datefmt=date_format)

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(formatter)

    file_handler = logging.FileHandler(log_filename)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stdout_handler)

sns.set_context("talk")

In [3]:
epoch_key = 'Jaq', 3, 12

In [4]:
setup_logging(epoch_key)

In [5]:
from dask.distributed import Client

# for virga
client = Client(n_workers=16, threads_per_worker=6, processes=True)

client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 38443 instead
  http_address["port"], self.http_server.port


0,1
Client  Scheduler: tcp://127.0.0.1:36825  Dashboard: http://127.0.0.1:38443/status,Cluster  Workers: 16  Cores: 96  Memory: 2.16 TB


In [6]:
logging.info(client)

19-Sep-21 21:56:31 <Client: 'tcp://127.0.0.1:36825' processes=16 threads=96, memory=2.16 TB>


In [7]:
from src.load_data import load_data


data = load_data(epoch_key,
                 position_to_linearize=['nose_x', 'nose_y'],
                 max_distance_from_well=5,
                 min_distance_traveled=30)

19-Sep-21 21:56:33 Loading position info...
19-Sep-21 21:57:14 Loading multiunit...
19-Sep-21 21:57:46 Loading spikes...
19-Sep-21 21:59:34 Finding ripple times...


In [8]:
client.dashboard_link

'http://127.0.0.1:38443/status'

#### Dask Integer

In [9]:
from replay_trajectory_classification.misc import NumbaKDE
from replay_trajectory_classification import ClusterlessClassifier
from src.parameters import WTRACK_EDGE_ORDER, WTRACK_EDGE_SPACING
import pprint

continuous_transition_types = (
    [['random_walk', 'uniform'],
     ['uniform',     'uniform']])


clusterless_algorithm = 'multiunit_likelihood_integer'
clusterless_algorithm_params = {
    'mark_std': 20.0,
    'position_std': 8.0,
    'chunks': 100,
}

classifier_parameters = {
    'movement_var': 6.0,
    'replay_speed': 1,
    'place_bin_size': 2.5,
    'continuous_transition_types': continuous_transition_types,
    'discrete_transition_diag': 0.968,
    'clusterless_algorithm': clusterless_algorithm,
    'clusterless_algorithm_params': clusterless_algorithm_params
}

logging.info(pprint.pprint(classifier_parameters))

19-Sep-21 22:00:36 Cupy is not installed. Required if using gpu state space.
{'clusterless_algorithm': 'multiunit_likelihood_integer',
 'clusterless_algorithm_params': {'chunks': 100,
                                  'mark_std': 20.0,
                                  'position_std': 8.0},
 'continuous_transition_types': [['random_walk', 'uniform'],
                                 ['uniform', 'uniform']],
 'discrete_transition_diag': 0.968,
 'movement_var': 6.0,
 'place_bin_size': 2.5,
 'replay_speed': 1}
19-Sep-21 22:00:37 None


In [None]:
logging.info('Dask Integer')
state_names = ['Continuous', 'Fragmented']

classifier = ClusterlessClassifier(**classifier_parameters)
classifier.fit(
    position=data["position_info"].linear_position,
    multiunits=data["multiunits"],
    track_graph=data["track_graph"],
    edge_order=WTRACK_EDGE_ORDER,
    edge_spacing=WTRACK_EDGE_SPACING,
)

results_dask_integer = classifier.predict(
    data["multiunits"],
    time=data["position_info"].index / np.timedelta64(1, "s"),
    state_names=state_names
)
logging.info('Done...')

19-Sep-21 22:00:37 Dask Integer
19-Sep-21 22:00:37 Fitting initial conditions...
19-Sep-21 22:00:38 Fitting state transition...
19-Sep-21 22:00:38 Fitting multiunits...
19-Sep-21 22:00:56 Estimating likelihood...




In [None]:
client.close()