In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
%qtconsole

In [2]:
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import logging
logging.basicConfig(level=logging.INFO)

In [3]:
from src.parameters import ANIMALS, SAMPLING_FREQUENCY, BRAIN_AREAS

# Recording session to analyze; passed to load_data in the next cell.
# NOTE(review): presumably (animal, day, epoch) — confirm against src.load_data.
epoch_key = ('bon', 3, 2)

In [4]:
from src.load_data import load_data

# Signal types the replay detector builds likelihoods from (the detector log
# later reports "Predicting lfp_power likelihood" / "Predicting spikes
# likelihood"). The same list is passed to identify_replays so both agree.
use_likelihoods = ['spikes', 'lfp_power']

# Load the epoch's data. Later cells index the result by key, e.g.
# data['spikes'] and data['position_info'].
data = load_data(epoch_key, ANIMALS, SAMPLING_FREQUENCY, use_likelihoods, BRAIN_AREAS)

INFO:src.load_data:Loading Data...
INFO:spectral_connectivity.transforms:Multitaper(sampling_frequency=1500, time_halfbandwidth_product=1,
           time_window_duration=0.02, time_window_step=0.02,
           detrend_type='constant', start_time=3729.0257, n_tapers=1)
INFO:src.load_data:Finding multiunit high synchrony events...
INFO:src.load_data:Finding ripple times...


In [7]:
from src.analysis import identify_replays

# Fit the replay detector (speed, LFP power and spiking models, per the log
# below) and run it over the epoch. `results` is an xarray Dataset of
# per-time replay probability / posterior / prior / likelihood; `detector` is
# the fitted model, whose settings decode_replays reuses further down.
results, detector = identify_replays(data, use_likelihoods)

INFO:replay_identification.decoders:Fitting speed model...
INFO:replay_identification.decoders:Fitting LFP power model...
INFO:replay_identification.decoders:Fitting spiking model...


HBox(children=(IntProgress(value=0, description='neurons', max=19), HTML(value='')))

INFO:replay_identification.decoders:Fitting movement state transition...
INFO:replay_identification.decoders:Fitting replay state transition...





INFO:replay_identification.decoders:Predicting lfp_power likelihood...
INFO:replay_identification.decoders:Predicting spikes likelihood...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




INFO:replay_identification.decoders:Predicting replay probability and density...


In [306]:
from src.analysis import get_replay_times

# Segment the detector output into discrete replay events.
# `replay_times`: one row per event (start_time, end_time, duration,
# max_probability), indexed by replay_number — see the output below.
# `labels`: carries a `replay_number` field that is later compared against a
# replay number to mask rows of data['spikes'] (presumably one label per
# sample — confirm in src.analysis).
replay_times, labels = get_replay_times(results)

In [307]:
# Peek at the first few detected replay events.
replay_times.head()

Unnamed: 0_level_0,start_time,end_time,duration,max_probability
replay_number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,01:02:22.825700,01:02:22.985033,00:00:00.159333,0.999903
2,01:02:27.105700,01:02:27.265033,00:00:00.159333,0.999999
3,01:02:35.965700,01:02:36.045033,00:00:00.079333,0.994431
4,01:02:38.145700,01:02:38.245033,00:00:00.099333,0.996903
5,01:03:19.745700,01:03:19.837033,00:00:00.091333,0.996957


In [308]:
# Overview of the detector output: dimensions, coordinates and variables.
results

<xarray.Dataset>
Dimensions:             (position: 187, time: 1396490)
Coordinates:
  * time                (time) timedelta64[ns] 01:02:09.025700 ... 01:17:40.018366
  * position            (position) float64 0.4979 1.494 2.489 ... 184.7 185.7
Data variables:
    replay_probability  (time) float64 0.0 1.233e-06 ... 1.306e-06 1.306e-06
    replay_posterior    (time, position) float64 0.0 0.0 ... 1.491e-06 1.491e-06
    prior               (time, position) float64 0.0 0.0 ... 1.478e-06 1.479e-06
    likelihood          (time, position) float64 1.743 1.744 ... 1.008 1.008

In [11]:
# Detector posterior over position for every time sample (time x position).
results.replay_posterior

<xarray.DataArray 'replay_posterior' (time: 1396490, position: 187)>
array([[0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00],
       [1.246875e-06, 1.247435e-06, 1.247771e-06, ..., 1.255199e-06,
        1.255388e-06, 1.255522e-06],
       [2.493327e-06, 2.494431e-06, 2.495045e-06, ..., 2.517551e-06,
        2.518078e-06, 2.518424e-06],
       ...,
       [1.381562e-06, 1.381475e-06, 1.380409e-06, ..., 1.487029e-06,
        1.488281e-06, 1.488965e-06],
       [1.382364e-06, 1.382276e-06, 1.381207e-06, ..., 1.488225e-06,
        1.489480e-06, 1.490165e-06],
       [1.383086e-06, 1.382998e-06, 1.381926e-06, ..., 1.489313e-06,
        1.490571e-06, 1.491257e-06]])
Coordinates:
  * time      (time) timedelta64[ns] 01:02:09.025700 ... 01:17:40.018366
  * position  (position) float64 0.4979 1.494 2.489 3.485 ... 183.7 184.7 185.7

In [310]:
# Position info sampled at each replay's start time, relabeled with the
# replay numbers so it lines up row-for-row with replay_times.
replay_position_info = data['position_info'].loc[replay_times.start_time]
replay_position_info.index = replay_times.index
replay_info = pd.concat([replay_times, replay_position_info], axis=1)


In [309]:
def _posterior_from_replay_start(replay):
    """Detector posterior for one replay, with time re-zeroed at its start.

    `replay` is a row from ``replay_times.itertuples()``. The returned
    DataArray covers the label slice start_time..end_time, and its `time`
    coordinate is the offset from start_time.
    """
    segment = results.replay_posterior.sel(
        time=slice(replay.start_time, replay.end_time))
    return segment.assign_coords(time=segment.time - replay.start_time)


# Stack the per-replay posteriors along a new dimension taken from
# replay_times.index (replay_number). Replays differ in length, so shorter
# ones end up NaN-padded on the shared relative-time axis.
replay_posterior = xr.concat(
    [_posterior_from_replay_start(replay) for replay in replay_times.itertuples()],
    dim=replay_times.index)

In [338]:
from replay_classification import SortedSpikeDecoder


def _get_test_spikes(data, labels, replay_number, sampling_frequency):
    test_spikes = data['spikes'][labels.replay_number == replay_number].T
    n_time = test_spikes.shape[1]
    time = np.arange(0, n_time) / sampling_frequency
    return test_spikes, time
    
def decode_replays(data, detector, labels, replay_times, sampling_frequency,
                   speed_threshold=4, confidence_threshold=0.8):
    """Fit a SortedSpikeDecoder on fast, correct-trial samples and decode each replay.

    The decoder takes its replay speed-up factor, number of position bins and
    spike-model knot spacing from the fitted replay detector, so its settings
    match the detector's.

    Parameters
    ----------
    data : mapping
        Needs 'position_info' (columns linear_speed, is_correct,
        linear_distance, lagged_linear_distance, task) and 'spikes', whose
        rows are selectable with the same boolean mask as position_info.
    detector : fitted replay detector
        Source of replay_speed, place_bin_centers and spike_model_knot_spacing.
    labels : object with a ``replay_number`` attribute
        Replay labels used to pull each replay's spikes.
    replay_times : pandas.DataFrame
        One row per replay; only its index (the replay numbers) is used.
    sampling_frequency : float
        Samples per second, for each replay's time axis.
    speed_threshold : float, optional
        Training samples must have linear_speed strictly above this (default 4).
        NOTE(review): units follow position_info.linear_speed — presumably
        cm/s, confirm.
    confidence_threshold : float, optional
        Forwarded to SortedSpikeDecoder (default 0.8).

    Returns
    -------
    decoder_results : list
        ``decoder.predict`` output for each replay, in replay_times.index order.
    decoder
        The object returned by ``SortedSpikeDecoder(...).fit()``.
    """
    position_info = data['position_info']
    is_training = ((position_info.linear_speed > speed_threshold)
                   & position_info.is_correct)
    train_position_info = position_info.loc[is_training]
    train_spikes = data['spikes'][is_training]

    decoder = SortedSpikeDecoder(
        position=train_position_info.linear_distance.values,
        lagged_position=train_position_info.lagged_linear_distance.values,
        trajectory_direction=train_position_info.task.values,
        spikes=train_spikes.T,
        replay_speedup_factor=detector.replay_speed,
        n_position_bins=detector.place_bin_centers.size,
        confidence_threshold=confidence_threshold,
        knot_spacing=detector.spike_model_knot_spacing,
    ).fit()

    decoder_results = [
        decoder.predict(*_get_test_spikes(data, labels, replay_number,
                                          sampling_frequency))
        for replay_number in replay_times.index]

    return decoder_results, decoder


decoder_results, decoder = decode_replays(
    data, detector, labels, replay_times, SAMPLING_FREQUENCY)

INFO:replay_classification.decoders:Fitting state transition model...
INFO:replay_classification.decoders:Fitting observation model...


In [386]:
# Presumably a per-replay summary table plus per-replay detector / decoder
# posteriors (inferred from the names — confirm).
# NOTE(review): `summarize_replays` is not imported or defined in any visible
# cell, so this cell fails on a fresh kernel ("Restart & Run All"). It likely
# lives next to identify_replays / get_replay_times in src.analysis — confirm
# and import it explicitly.
# NOTE(review): this rebinds `replay_info` (first built in an earlier cell
# from replay_times + position info). The interactive plot below reads this
# version (e.g. replay_type, replay_motion_slope columns). detector_posterior
# and decoder_posterior are not used by any visible cell; the plot uses
# replay_posterior and decoder_results instead.
replay_info, detector_posterior, decoder_posterior = summarize_replays(
    replay_times, results, decoder_results, data)

In [390]:
from ipywidgets import IntSlider, interact

# Last valid positional index into replay_info (not a count of replays).
n_replays = replay_info.shape[0] - 1


# The slider is built explicitly. Passed as a keyword to interact(),
# `continuous_update=False` matches no parameter of plot_posterior and is
# silently ignored, so the figures re-rendered on every drag step.
@interact(index=IntSlider(min=0, max=n_replays, continuous_update=False))
def plot_posterior(index):
    """Compare detector and decoder posteriors for one replay.

    ``index`` is a 0-based position into replay_info, replay_posterior and
    decoder_results, which are all used positionally here. The function draws:

    1. the detector's replay posterior, with the animal's position (dashed)
       and the MAP position trace;
    2. the decoder's posterior density and state probabilities, with the
       decoder's confidence threshold;
    3. the decoder's posterior density summed over states, with its MAP trace.

    It then displays a few summary columns for the replay.
    """
    replay_row = replay_info.iloc[index]

    # 1. Detector posterior. dropna('time') removes the NaN padding left by
    #    stacking replays of unequal length.
    _, detector_ax = plt.subplots(1, 1, figsize=(10, 7))
    detector_density = (
        replay_posterior
        .assign_coords(time=replay_posterior.time.to_index().total_seconds())
        .isel(replay_number=index)
        .dropna('time'))
    detector_density.plot(x='time', y='position', robust=True, ax=detector_ax)
    detector_ax.axhline(replay_row.linear_distance,
                        color='white', linestyle='--', linewidth=5,
                        label='animal position')
    # argmax of a non-negative density already gives the MAP bin; taking the
    # log first only adds divide-by-zero warnings for zero-probability bins.
    detector_map = detector_density.position.values[
        detector_density.argmax('position').values]
    detector_ax.plot(detector_density.time.values, detector_map, label='MAP')
    detector_ax.legend()

    # 2. Decoder posterior density and state probabilities.
    decoder_result = decoder_results[index]
    plt.figure()
    decoder_result.plot_posterior_density()
    state_ax = decoder_result.plot_state_probability()
    state_ax.axhline(decoder.confidence_threshold,
                     linestyle='--', color='black')
    state_ax.set_ylim((0, 1))
    state_ax.set_ylabel('State Probability')

    # 3. Decoder density summed over states. It is named so it does not shadow
    #    pandas' `pd`. It is plotted against its own time axis: the decoder's
    #    time comes from _get_test_spikes and need not match the detector's
    #    time points or length.
    plt.figure()
    decoder_density = decoder_result.results.sum('state').posterior_density
    decoder_density.plot(x='time', y='position', robust=True)
    decoder_map = decoder_density.position.values[
        decoder_density.argmax('position').values]
    plt.plot(decoder_density.time.values, decoder_map, label='MAP')
    plt.legend()

    COLUMNS = ['linear_distance', 'linear_speed', 'replay_type',
               'replay_type_confidence', 'replay_motion_slope',
               'replay_motion_type', 'replay_movement_distance',
               'credible_interval_size', 'duration']
    # Positional lookup, consistent with iloc/isel/list indexing above
    # (label lookup `index + 1` assumed replay numbers are exactly 1..N).
    display(replay_row[COLUMNS])


interactive(children=(IntSlider(value=159, description='index', max=318), Output()), _dom_classes=('widget-int…