In [1]:
import logging

# Timestamped, INFO-level root logging so the progress messages emitted by the
# long-running fit/predict steps below show up with wall-clock times.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%d-%b-%y %H:%M:%S",
    level=logging.INFO,
)


In [2]:
from src.load_data import load_data


# load_data() yields exactly four items; bind them to the names used throughout
# the rest of the notebook.
position_info, spikes, multiunit_firing_rate, multiunit_HSE_times = load_data()


20-Feb-23 14:47:16 Note: detected 96 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
20-Feb-23 14:47:16 Note: NumExpr detected 96 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
20-Feb-23 14:47:16 NumExpr defaulting to 8 threads.


In [3]:
from replay_trajectory_classification import (
    SortedSpikesClassifier,
    Environment,
    RandomWalk,
    Uniform,
)


# Spatial discretization of the track (bin size in the units of position_info;
# presumably cm — TODO confirm).
environment = Environment(place_bin_size=3.0)

# 2x2 grid of continuous transition models, one per pair of discrete states
# (ordered like state_names used at predict time: continuous, fragmented).
# NOTE(review): assumes row = "from" state, column = "to" state, per the
# replay_trajectory_classification convention — confirm.
continuous_transition_types = [
    [RandomWalk(movement_var=12.0), Uniform()],
    [Uniform(), Uniform()],
]

# Parameters forwarded to the GPU KDE sorted-spikes likelihood.
kde_params = {
    "position_std": 6.0,
    "use_diffusion": False,
    "block_size": 2**12,  # 4096
}

classifier = SortedSpikesClassifier(
    environments=environment,
    continuous_transition_types=continuous_transition_types,
    sorted_spikes_algorithm="spiking_likelihood_kde_gpu",
    sorted_spikes_algorithm_params=kde_params,
)

classifier


  from tqdm.autonotebook import tqdm


In [4]:
import cupy as cp
import xarray as xr


# Labels for the two discrete states; order matches continuous_transition_types.
state_names = ["continuous", "fragmented"]

# Index of the CUDA device to run on; it must exist on the executing machine.
GPU_ID = 9

n_time = len(spikes)
# Number of contiguous chunks of `spikes` to decode with separate predict calls
# (presumably to bound GPU memory).
# NOTE(review): each chunk is passed to predict() on its own, so the posterior
# near a chunk boundary does not see data from the neighbouring chunk.
n_segments = 1
if not 1 <= n_segments <= n_time:
    # Out-of-range values would otherwise yield empty slices or an empty concat.
    raise ValueError(
        f"n_segments must be between 1 and the number of time bins ({n_time}); "
        f"got {n_segments}"
    )

# Per-segment xr.Dataset results; kept separate from the concatenated `results`.
segment_results = []

# use context manager to specify which GPU (device)
with cp.cuda.Device(GPU_ID):
    # Fit the place fields
    classifier.fit(
        position=position_info[["x", "y"]],
        spikes=spikes,
    )

    for ind in range(n_segments):
        # Integer-division boundaries give contiguous, non-overlapping chunks
        # that together cover every row of `spikes`.
        start = ind * n_time // n_segments
        stop = (ind + 1) * n_time // n_segments
        segment_spikes = spikes.iloc[start:stop]
        segment_results.append(
            classifier.predict(
                segment_spikes,
                time=segment_spikes.index.to_numpy(),
                state_names=state_names,
                use_gpu=True,
            )
        )
    logging.info("Done!")

results = xr.concat(segment_results, dim="time")
results


20-Feb-23 14:47:38 Fitting initial conditions...
20-Feb-23 14:47:38 Fitting continuous state transition...
  x /= x.sum(axis=1, keepdims=True)
20-Feb-23 14:47:39 Fitting discrete state transition
20-Feb-23 14:47:39 Fitting place fields...


  0%|          | 0/225 [00:00<?, ?it/s]

  np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy)
  np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy)
20-Feb-23 14:48:38 Estimating likelihood...


  0%|          | 0/225 [00:00<?, ?it/s]

20-Feb-23 14:51:03 Estimating causal posterior...
20-Feb-23 14:56:05 Estimating acausal posterior...
20-Feb-23 15:09:00 Done!


In [5]:
# Write the decoding results to netCDF without the `likelihood` and
# `causal_posterior` variables (presumably to reduce file size).
# `drop_vars` is the supported way to remove variables; `Dataset.drop` with
# variable names is deprecated in xarray.
results.drop_vars(["likelihood", "causal_posterior"]).to_netcdf(
    "../Processed-Data/results.nc"
)


In [6]:
# Persist the fitted classifier so the expensive fit need not be repeated.
# NOTE(review): the .pkl extension suggests a pickle — only reload it from a
# trusted location.
CLASSIFIER_PATH = "../Processed-Data/classifier.pkl"
classifier.save_model(CLASSIFIER_PATH)


In [7]:
from src.plot_data import create_interactive_2D_decoding_figurl

# Build a shareable figurl view of the acausal posterior over (x, y) position,
# binned at the same resolution as the decoding environment; the cell output is
# the resulting URL.
view = create_interactive_2D_decoding_figurl(
    position_info,
    multiunit_firing_rate,
    results,
    position_name=["x", "y"],
    speed_name="speed",
    posterior_type="acausal_posterior",
    bin_size=environment.place_bin_size,
    view_height=800,
)
view.url(label="2D Decode")


Computing sha1 of /stelmo/nwb/.kachery-cloud/tmp_XvGaLvpu/file.dat


'https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://ad8ed49f4678b68ec1c109fbdce77565bde46ed3&label=2D%20Decode&zone=franklab.default'