In [1]:
%matplotlib inline
# Re-import edited modules (e.g. the local `src` package) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2
# Render inline figures at high (retina / 2x) resolution.
%config InlineBackend.figure_format = 'retina'

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import seaborn as sns
import logging

# Larger fonts/markers suited to big on-screen figures.
sns.set_context("talk")

# Timestamped INFO-level logging (used for the progress messages in the CV loops).
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(
    format=FORMAT,
    datefmt='%d-%b-%y %H:%M:%S',
    level='INFO',
)

In [3]:
# Load the data for the recording epoch identified by `epoch_key` into the
# dict-like `data` used by every later cell.
from src.load_data import load_data

epoch_key = ('Jaq', 1, 2) # animal, day, epoch

# NOTE(review): the saved output of this cell is
# "TypeError: iteration over a 0-d array" (presumably raised inside load_data),
# so `data` is undefined and every later cell that uses it fails.
# TODO: fix the loader for this epoch before re-running the notebook.
data = load_data(epoch_key)

TypeError: iteration over a 0-d array

In [None]:
# Persist the linearised position table to netCDF, named
# <animal>_<day:02>_<epoch:02>_linearised_position_nose.nc.
data['position_info'].to_xarray().to_netcdf(
    "{}_{:02d}_{:02d}_linearised_position_nose.nc".format(*epoch_key)
)

In [None]:
fig, ax = plt.subplots(figsize=(30, 10))

# One scatter call per track segment, so each segment gets its own colour from the cycle.
for _, segment_positions in data['position_info'].groupby('track_segment_id'):
    # Convert the time index to float seconds for the x axis.
    seconds = segment_positions.index / np.timedelta64(1, 's')
    ax.scatter(seconds, segment_positions.linear_position, s=1)

ax.set_xlabel('Time [s]')
ax.set_ylabel('Position [cm]');

In [None]:
from src.parameters import EDGE_ORDER, EDGE_SPACING, ANIMALS
from src.load_data import make_track_graph

track_graph, center_well_id = make_track_graph(epoch_key, ANIMALS)

# Samples counted as "running": absolute tail-base velocity above 4
# (units not shown here; presumably cm/s, matching "vel_4" in the output file names).
# Alternative criterion, kept for reference: use "forepawR_vel" instead of "tailBase_vel".
is_running = data["position_info"]["tailBase_vel"].abs() > 4

# Boolean mask of samples labelled with the "Outbound" task.
is_outbound = data["position_info"]["task"] == "Outbound"

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(30, 10))

# Grey background layer with every sample. groupby("task") drops NaN keys by
# default, so unlabelled samples stay visible only through this layer.
ax.scatter(
    data["position_info"].index / np.timedelta64(1, "s"),
    data["position_info"].linear_position,
    s=10,
    color='lightgrey',
)
# Coloured overlay per task label, converted to seconds like the layer above.
for task, df in data["position_info"].groupby("task"):
    ax.scatter(
        df.index / np.timedelta64(1, "s"), df.linear_position, s=10, label=task,
    )

# Label the axes to match the per-segment position plot, and drive the legend
# and despine through the explicit Axes instead of the pyplot state machine.
ax.set_xlabel("Time [s]")
ax.set_ylabel("Position [cm]")
ax.legend()
sns.despine(ax=ax)

### Continuous vs. Jump

In [None]:
from replay_trajectory_classification import ClusterlessClassifier
from src.parameters import classifier_parameters, discrete_state_transition

from sklearn.model_selection import KFold
from tqdm.auto import tqdm

# Unshuffled K-fold cross-validation (scikit-learn default: 5 contiguous folds):
# each block of time is decoded by a classifier fit on the remaining blocks.
cv = KFold()
cv_classifier_clusterless_results = []

for fold_ind, (train, test) in tqdm(
    enumerate(cv.split(data["position_info"].index)),
    # enumerate() has no len(), so tqdm cannot infer the fold count on its own.
    total=cv.get_n_splits(),
    desc="CV folds",
):

    # Optional: restrict training to outbound-task samples.
    # train = train[is_outbound[train].values]

    cv_classifier = ClusterlessClassifier(**classifier_parameters)

    # `is_training` presumably selects which training-fold samples build the
    # encoding model -- here only those flagged by `is_running` (confirm).
    cv_classifier.fit(
        position=data["position_info"].iloc[train].linear_position,
        multiunits=data["multiunits"].isel(time=train),
        is_training=is_running.iloc[train],
        track_graph=track_graph,
        center_well_id=center_well_id,
        edge_order=EDGE_ORDER,
        edge_spacing=EDGE_SPACING,
    )
    # Overwrite the discrete-state transition after fitting with the value from
    # src.parameters. NOTE(review): confirm this post-fit override is intended.
    cv_classifier.discrete_state_transition_ = discrete_state_transition
    logging.info('Predicting posterior...')
    # Decode the held-out fold; the time coordinate is given in seconds.
    cv_classifier_clusterless_results.append(
        cv_classifier.predict(
            data["multiunits"].isel(time=test),
            time=data["position_info"].iloc[test].index / np.timedelta64(1, "s"),
        )
    )

# After the loop, `cv_classifier` is the model from the last fold; the plotting
# and replay-distance cells use that model.

In [None]:
# Concatenate the per-fold classifier results along time. With an unshuffled
# KFold the test folds are contiguous and in order, so the result is chronological.
# The isinstance guard keeps the cell idempotent. After the first run the name
# holds the concatenated Dataset; passing that Dataset to xr.concat would iterate
# over variable names (strings) and raise a TypeError.
if isinstance(cv_classifier_clusterless_results, list):
    cv_classifier_clusterless_results = xr.concat(
        cv_classifier_clusterless_results, dim="time"
    )
cv_classifier_clusterless_results

In [None]:
# Write the cross-validated classifier results to netCDF (readable from MATLAB via ncread).
cv_classifier_clusterless_results.to_netcdf(
    "{}_{:02d}_{:02d}_cv_classifier_clusterless_vel_4_nose_alltime_results.nc".format(*epoch_key)
)

In [None]:
# Latest decoded time -- a guide for choosing `time_slice` in the plots.
cv_classifier_clusterless_results["time"].max()

In [None]:
from src.visualization import plot_classifier_time_slice

# Window to plot, in the same units (seconds) as the results' time coordinate.
time_slice = slice(32_500, 34_500)
# Alternative window used previously: slice(2180, 2190)

plot_classifier_time_slice(
    time_slice,
    cv_classifier,  # whichever model `cv_classifier` currently holds (last CV fold on a top-to-bottom run)
    cv_classifier_clusterless_results,
    data,
    figsize=(30, 15),
    posterior_type="acausal_posterior",
)

In [None]:
# Resume path: restore a saved classifier instead of re-running cross-validation.
# load_model returns the loaded model, so bind it to `cv_classifier`, which the
# plotting and replay-distance cells read; previously the result was discarded.
# The import makes this cell work on a fresh kernel that skipped the CV cell.
# NOTE(review): this .pkl is not written anywhere in this notebook -- confirm its origin.
# Unpickling can execute arbitrary code: only load files you created yourself.
from replay_trajectory_classification import ClusterlessClassifier

cv_classifier = ClusterlessClassifier.load_model(
    f"{epoch_key[0]}_{epoch_key[1]:02d}_{epoch_key[2]:02d}_cv_classifier.pkl"
)
cv_classifier

In [None]:
# Resume path: reload the concatenated CV classifier results saved earlier.
cv_classifier_clusterless_results = xr.open_dataset(
    "{}_{:02d}_{:02d}_cv_classifier_clusterless_vel_4_nose_alltime_results.nc".format(*epoch_key)
)

In [None]:
from src.analysis import calculate_replay_distance

# Presumably the distance between the decoded position (posterior summed over
# discrete states) and the animal's 2D nose position.
# NOTE(review): this uses the *causal* posterior while the plots use the
# acausal one -- confirm that is intended.
replay_distance_from_animal_position = calculate_replay_distance(
    track_graph=track_graph,
    decoder=cv_classifier,
    posterior=cv_classifier_clusterless_results.causal_posterior.sum('state'),
    track_segment_id=data['position_info'].loc[:, ["track_segment_id"]],
    position_2D=data['position_info'].loc[:, ["nose_x", "nose_y"]],
)
replay_distance_from_animal_position

In [None]:
# Replay distance per time bin. The x axis is the sample index because no time
# values are passed. Distance units are presumably cm, matching linear_position -- confirm.
fig, ax = plt.subplots(figsize=(30, 5))
ax.plot(replay_distance_from_animal_position)
ax.set_xlabel("Time bin index")
ax.set_ylabel("Replay distance from animal");

In [None]:
# Attach the replay distance to the classifier results as a new variable
# along the existing `time` dimension, so it is written out with them.
cv_classifier_clusterless_results['replay_distance_from_animal_position'] = (
    ("time",),
    replay_distance_from_animal_position,
)

In [None]:
# Write the classifier results, now including replay distance, to netCDF
# (readable from MATLAB via ncread).
cv_classifier_clusterless_results.to_netcdf(
    "{}_{:02d}_{:02d}_cv_classifier_clusterless_vel_4_nose_alltimedist_results.nc".format(*epoch_key)
)

### Local vs. Non-Local

In [None]:
from replay_identification import ReplayDetector
from src.parameters import detector_parameters
from sklearn.model_selection import KFold
from tqdm.auto import tqdm

# Unshuffled K-fold cross-validation (scikit-learn default: 5 contiguous folds).
cv = KFold()
cv_clusterless_results = []

# NOTE(review): this section reads data["multiunit"], while the classifier
# section reads data["multiunits"]; presumably only one of these keys exists in
# what load_data returns -- confirm before running.
for train, test in tqdm(
    cv.split(data["position_info"].index),
    # split() is a generator, so tqdm cannot infer the fold count on its own.
    total=cv.get_n_splits(),
    desc="CV folds",
):

    cv_detector = ReplayDetector(**detector_parameters)

    cv_detector.fit(
        is_ripple=data["is_ripple"].iloc[train],
        speed=data["position_info"].iloc[train].speed,
        position=data["position_info"].iloc[train].linear_position,
        multiunit=data["multiunit"].isel(time=train),
        track_graph=track_graph,
        center_well_id=center_well_id,
        edge_order=EDGE_ORDER,
        edge_spacing=EDGE_SPACING,
    )

    logging.info('Predicting posterior...')
    # Decode the held-out fold using only the multiunit likelihood; time in seconds.
    cv_clusterless_results.append(
        cv_detector.predict(
            speed=data["position_info"].iloc[test].speed,
            position=data["position_info"].iloc[test].linear_position,
            multiunit=data["multiunit"].isel(time=test),
            use_likelihoods=["multiunit"],
            time=data["position_info"].iloc[test].index / np.timedelta64(1, "s"),
        )
    )

# After the loop, `cv_detector` is the model from the last fold (used by the plot cell).

In [None]:
# Concatenate the per-fold detector results along time.
# The isinstance guard keeps the cell idempotent. After the first run the name
# holds the concatenated Dataset, which xr.concat would reject (iterating a
# Dataset yields variable names, not Datasets).
if isinstance(cv_clusterless_results, list):
    cv_clusterless_results = xr.concat(cv_clusterless_results, dim="time")
cv_clusterless_results

In [None]:
# Write the cross-validated local/non-local detector results to netCDF.
cv_clusterless_results.to_netcdf(
    "{}_{:02d}_{:02d}_cv_clusterless_results.nc".format(*epoch_key)
)

In [None]:
from src.visualization import plot_local_non_local_time_slice

# Window to plot, in the same units (seconds) as the results' time coordinate.
time_slice = slice(32_500, 34_500)

plot_local_non_local_time_slice(
    time_slice,
    cv_detector,  # whichever detector `cv_detector` currently holds (last CV fold on a top-to-bottom run)
    cv_clusterless_results,
    data,
    figsize=(30, 15),
    posterior_type="acausal_posterior",
)