In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import logging

FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(level='INFO', format=FORMAT, datefmt='%d-%b-%y %H:%M:%S')

sns.set_context("talk")

In [3]:
from src.parameters import ANIMALS
from loren_frank_data_processing import make_epochs_dataframe


epoch_info = make_epochs_dataframe(ANIMALS)
epoch_info = epoch_info.loc[(epoch_info.type == 'run') &
                            (epoch_info.environment == 'wtrack')]
epoch_info

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,environment,type,exposure
animal,day,epoch,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Jaq,3,2,wtrack,run,1.0
Jaq,3,4,wtrack,run,2.0
Jaq,3,6,wtrack,run,3.0
Jaq,3,8,wtrack,run,4.0
Jaq,3,10,wtrack,run,5.0
...,...,...,...,...,...
Roqui,5,6,wtrack,run,25.0
Roqui,5,8,wtrack,run,26.0
Roqui,6,2,wtrack,run,27.0
Roqui,6,4,wtrack,run,28.0


In [4]:
import os
from src.load_data import load_data
import xarray as xr
from src.parameters import PROCESSED_DATA_DIR

epoch_key = epoch_info.index[0]

epoch_identifier = f"{epoch_key[0]}_{epoch_key[1]:02d}_{epoch_key[2]:02d}"
results_filename = os.path.join(
    PROCESSED_DATA_DIR,
    f"{epoch_identifier}_clusterless_forward_reverse_results.nc"
)

data = load_data(epoch_key,
                 position_to_linearize=['nose_x', 'nose_y'],
                 max_distance_from_well=30,
                 min_distance_traveled=50,
                 )
results = xr.open_dataset(results_filename)

24-Oct-21 11:52:23 Loading position info...
24-Oct-21 11:53:05 Loading multiunit...
24-Oct-21 11:53:26 Loading spikes...
24-Oct-21 11:54:26 Finding ripple times...


In [5]:
results.isel(time=slice(0, 300), state=0).acausal_posterior

In [6]:
import logging
import os
import sys
from pprint import pprint

import numpy as np
import pandas as pd
import xarray as xr
from loren_frank_data_processing import make_epochs_dataframe
from replay_trajectory_classification import ClusterlessClassifier
from sklearn.model_selection import KFold
from src.load_data import load_data
from src.parameters import (ANIMALS, PROCESSED_DATA_DIR, WTRACK_EDGE_ORDER,
                            WTRACK_EDGE_SPACING)

In [7]:
epoch_identifier = f"{epoch_key[0]}_{epoch_key[1]:02d}_{epoch_key[2]:02d}"
results_filename = os.path.join(
    PROCESSED_DATA_DIR,
    f"{epoch_identifier}_clusterless_forward_reverse_results.nc"
)

In [8]:
data = load_data(epoch_key,
                 position_to_linearize=['nose_x', 'nose_y'],
                 max_distance_from_well=30,
                 min_distance_traveled=50,
                 )

continuous_transition_types = [['random_walk_direction2', 'random_walk',            'uniform', 'random_walk',            'random_walk',            'uniform'],  # noqa
                               ['random_walk',            'random_walk_direction1', 'uniform', 'random_walk',            'random_walk',            'uniform'],  # noqa
                               ['uniform',                'uniform',                'uniform', 'uniform',                'uniform',                'uniform'],  # noqa
                               ['random_walk',            'random_walk',            'uniform', 'random_walk_direction1', 'random_walk',            'uniform'],  # noqa
                               ['random_walk',            'random_walk',            'uniform', 'random_walk',            'random_walk_direction2', 'uniform'],  # noqa
                               ['uniform',                'uniform',                'uniform', 'uniform',                'uniform',                'uniform'],  # noqa
                               ]
encoding_group_to_state = ['Inbound', 'Inbound', 'Inbound',
                           'Outbound', 'Outbound', 'Outbound']

clusterless_algorithm = 'multiunit_likelihood_gpu'
clusterless_algorithm_params = {
    'mark_std': 20.0,
    'position_std': 8.0,
}

classifier_parameters = {
    'movement_var': 6.0,
    'replay_speed': 1,
    'place_bin_size': 2.0,
    'continuous_transition_types': continuous_transition_types,
    'discrete_transition_diag': 0.968,
    'clusterless_algorithm': clusterless_algorithm,
    'clusterless_algorithm_params': clusterless_algorithm_params
}

inbound_outbound_labels = np.asarray(
    data["position_info"].task).astype(str)

notnull = pd.notnull(data["position_info"].task)

state_names = [
    'Inbound-Forward', 'Inbound-Reverse', 'Inbound-Fragmented',
    'Outbound-Forward', 'Outbound-Reverse', 'Outbound-Fragmented']

cv = KFold()
results = []

24-Oct-21 11:55:04 Loading position info...
24-Oct-21 11:55:43 Loading multiunit...
24-Oct-21 11:56:03 Loading spikes...
24-Oct-21 11:57:04 Finding ripple times...


In [31]:
fold_ind = 0

train, test = next(cv.split(data["position_info"].index))

In [10]:
classifier = ClusterlessClassifier(**classifier_parameters)
logging.info("Fitting model...")
classifier.fit(
    position=data["position_info"].iloc[train].linear_position,
    multiunits=data["multiunits"].isel(time=train),
    is_training=notnull.iloc[train],
    track_graph=data["track_graph"],
    edge_order=WTRACK_EDGE_ORDER,
    edge_spacing=WTRACK_EDGE_SPACING,
    encoding_group_labels=inbound_outbound_labels[train],
    encoding_group_to_state=encoding_group_to_state
)

24-Oct-21 11:57:35 Fitting model...
24-Oct-21 11:57:35 Fitting initial conditions...
24-Oct-21 11:57:36 Fitting state transition...
24-Oct-21 11:57:41 Fitting multiunits...
24-Oct-21 11:57:41 init
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 385052 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 4 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 40116 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 4 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 99856 bytes
24-Oct-21 11:57:42 add pending dealloc: cuMemFree_v2 4 bytes
24-Oct-21 11:57:42 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 11:57:42 dealloc: cuMemFree_v2 385052 bytes
24-O

ClusterlessClassifier(clusterless_algorithm='multiunit_likelihood_gpu',
                      clusterless_algorithm_params={'mark_std': 20.0,
                                                    'position_std': 8.0},
                      continuous_transition_types=[['random_walk_direction2',
                                                    'random_walk', 'uniform',
                                                    'random_walk',
                                                    'random_walk', 'uniform'],
                                                   ['random_walk',
                                                    'random_walk_direction1',
                                                    'uniform', 'random_walk',
                                                    'random_walk', 'uniform'],
                                                   ['unif...
                                                    'random_walk', 'uniform',
                                         

In [14]:
test = slice(0, 300)

r = classifier.predict(
    data["multiunits"].isel(time=test),
    time=data["position_info"].iloc[test].index /
    np.timedelta64(1, "s"),
    state_names=state_names,
    use_gpu=True,
)

24-Oct-21 11:59:12 Estimating likelihood...
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 115420 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 200580 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 20 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 362180 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 499280 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 20 bytes
24-Oct-21 11:59:12 add pending dealloc: cuMemFree_v2 59700 bytes
24-Oct-21 11:59:12 dealloc: cuMemFree_v2 2388 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 3184 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 19104 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 3980 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 115420 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 200580 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 20 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 362180 bytes
24-Oct-21 11:59:13 dealloc: cuMemFree_v2 499280 bytes
24-Oct-21 11:59:13 dealloc: 

In [15]:
r.acausal_posterior

In [60]:
results = xr.open_dataset(results_filename)
np.isnan(results.isel(time=slice(31_460, 31462)).causal_posterior)

In [61]:
results.isel(time=slice(31_460, 31462)).causal_posterior

In [63]:
results.isel(time=slice(31_460, 31463)).likelihood

In [65]:
results.isel(time=31461).likelihood

In [68]:
data['multiunits'].isel(time=31461).T

In [82]:
import math

import numpy as np
from numba import cuda
from numba.types import float32
from replay_trajectory_classification.bins import atleast_2d
from replay_trajectory_classification.multiunit_likelihood_gpu import estimate_pdf, estimate_log_intensity


multiunits = np.asarray(data['multiunits'].isel(time=[31461]))
encoding_marks = classifier.encoding_model_['Inbound']['encoding_marks']
mark_std = 20
place_bin_centers = classifier.place_bin_centers_
encoding_positions = classifier.encoding_model_['Inbound']['encoding_positions']
position_std = 8
occupancy = classifier.encoding_model_['Inbound']['occupancy']
mean_rates = classifier.encoding_model_['Inbound']['mean_rates']
summed_ground_process_intensity = classifier.encoding_model_['Inbound']['summed_ground_process_intensity']
is_track_interior = classifier.is_track_interior_
time_bin_size = 1
n_streams = 2

n_time = multiunits.shape[0]
log_likelihood = (-time_bin_size * summed_ground_process_intensity *
                  np.ones((n_time, 1)))
n_electrodes = multiunits.shape[-1]
multiunits = np.moveaxis(multiunits, -1, 0)
streams = [cuda.stream() for _ in range(min(n_streams, n_electrodes))]
pdfs = []
is_spikes = []

for elec_ind, (multiunit, enc_marks, enc_pos) in enumerate(zip(
        multiunits, encoding_marks, encoding_positions)):
    is_spike = np.any(~np.isnan(multiunit), axis=1)
    is_spikes.append(is_spike)
    n_spikes = is_spike.sum()
    if n_spikes > 0:
        pdfs.append(estimate_pdf(
            multiunit[is_spike],
            enc_marks,
            mark_std,
            place_bin_centers[is_track_interior],
            enc_pos,
            position_std,
            stream=streams[elec_ind % n_streams]
        ))
    else:
        pdfs.append([])

24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 add pending dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 20 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 3980 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 182400 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 20 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 bytes
24-Oct-21 12:14:36 dealloc: cuMemFree_v2 796 by

In [94]:
p = []
li = []

n_interior_place_bins = is_track_interior.sum()
for elec_ind, (pdf, mean_rate, is_spike) in enumerate(
        zip(pdfs, mean_rates, is_spikes)):
    n_spikes = is_spike.sum()
    if n_spikes > 0:
        # Copy results from GPU to CPU and
        # reshape to (n_decoding_spikes, n_interior_place_bins)
        pdf = (pdf
               .copy_to_host(stream=streams[elec_ind % n_streams])
               .reshape((n_spikes, n_interior_place_bins), order='F'))
        p.append(pdf)
        log_intensity = estimate_log_intensity(
            pdf,
            occupancy[is_track_interior],
            mean_rate)
        li.append(log_intensity)
        log_likelihood[np.ix_(
            is_spike, is_track_interior)] += log_intensity
        
log_likelihood[:, ~is_track_interior] = np.nan

In [96]:
p

[array([[1.36014192e-33, 3.16328983e-33, 6.91347166e-33, 1.41983305e-32,
         2.73997830e-32, 4.96836735e-32, 8.46501518e-32, 1.35513210e-31,
         2.03830180e-31, 2.88059945e-31, 3.82490416e-31, 4.77175287e-31,
         5.59311262e-31, 6.15955231e-31, 6.37347442e-31, 6.19687708e-31,
         5.66308698e-31, 4.86799529e-31, 3.94477990e-31, 3.03248132e-31,
         2.24996531e-31, 1.68261249e-31, 1.38228849e-31, 1.37574581e-31,
         1.67551227e-31, 2.29175954e-31, 3.25262061e-31, 4.65165684e-31,
         6.75112502e-31, 1.01737908e-30, 1.62085926e-30, 2.72316394e-30,
         4.72014171e-30, 8.21305903e-30, 1.40377347e-29, 2.32550790e-29,
         3.70800767e-29, 5.67269725e-29, 8.31550647e-29, 1.16717975e-28,
         1.56757011e-28, 2.01240263e-28, 2.46604472e-28, 2.87977732e-28,
         3.19885384e-28, 3.37368326e-28, 3.37229491e-28, 3.18982580e-28,
         2.85119002e-28, 2.40543161e-28, 5.02667640e-30, 2.43749531e-30,
         1.11493363e-30, 4.81021081e-31, 1.95735161

In [95]:
li

[array([[-71.01998 , -70.214386, -69.43183 , -68.6785  , -67.96174 ,
         -67.28991 , -66.67176 , -66.11562 , -65.62839 , -65.214714,
         -64.87688 , -64.615   , -64.427765, -64.31304 , -64.268394,
         -64.29141 , -64.3798  , -64.531166, -64.74224 , -65.00619 ,
         -65.305595, -65.59734 , -65.795685, -65.80298 , -65.6093  ,
         -65.30027 , -64.95473 , -64.60172 , -64.23412 , -63.829407,
         -63.370438, -62.861107, -62.32517 , -61.792244, -61.286503,
         -60.82367 , -60.412346, -60.05587 , -59.753536, -59.50167 ,
         -59.294624, -59.12596 , -58.98948 , -58.87986 , -58.79297 ,
         -58.725807, -58.676357, -58.643433, -58.62666 , -58.626747,
         -61.281776, -62.119564, -63.01321 , -63.95061 , -64.92863 ,
         -65.94947 , -67.01846 , -68.142525, -69.32887 , -70.58361 ,
         -71.910515, -73.31011 , -74.77967 , -76.31133 , -77.86795 ,
         -79.13243 , -79.02678 , -77.97812 , -76.81432 , -75.67079 ,
         -74.56459 , -73.49682 , -