In [None]:
import logging

# Configure root logging once for the whole notebook: INFO and above,
# each message prefixed with a timestamp such as "05-Mar-24 14:02:11".
# Note: basicConfig is a no-op if the root logger already has handlers.
LOG_FORMAT = "%(asctime)s %(message)s"
LOG_DATE_FORMAT = "%d-%b-%y %H:%M:%S"

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=LOG_DATE_FORMAT)


In [None]:
from src.load_data import load_data
from src.parameters import SAMPLING_FREQUENCY

# Load position, sorted spikes, multiunit rate and HSE times for session 1.
position_info, spikes, multiunit_firing_rate, multiunit_HSE_times = load_data(
    position_file_name="../Raw-Data/position4Xulu_1.csv",
    spike_file_name="../Raw-Data/df4Xulu_1.csv",
)

# Discard the first 20 s of the recording: the animal is still being placed
# on the track. The cut is done by row count, which assumes one row per sample
# at SAMPLING_FREQUENCY — TODO confirm.
TRACK_PLACEMENT_SECONDS = 20.0
start_ind = int(TRACK_PLACEMENT_SECONDS * SAMPLING_FREQUENCY)

position_info, spikes, multiunit_firing_rate = [
    frame.iloc[start_ind:]
    for frame in (position_info, spikes, multiunit_firing_rate)
]
# NOTE(review): multiunit_HSE_times is not trimmed here; it is not used by the
# cells below, but confirm whether events in the first 20 s should be dropped.


In [None]:
from replay_trajectory_classification import (
    SortedSpikesClassifier,
    Environment,
    RandomWalk,
    Uniform,
)


# Spatial discretisation: 2.0-unit position bins (same units as the x/y
# position columns).
environment = Environment(place_bin_size=2.0)

# Two-state transition model; the state order presumably matches the
# state_names passed to predict ("continuous", "fragmented") — confirm.
# First state: random-walk continuation within the state, uniform jump out.
# Second state: uniform transitions in both directions.
continuous_transition_types = [
    [RandomWalk(movement_var=12.0), Uniform()],
    [Uniform(), Uniform()],
]

# Parameters for the GPU KDE place-field likelihood.
sorted_spikes_algorithm_params = {
    "position_std": 6.0,
    "use_diffusion": False,
    "block_size": 2**12,  # 4096
}

classifier = SortedSpikesClassifier(
    environments=environment,
    continuous_transition_types=continuous_transition_types,
    sorted_spikes_algorithm="spiking_likelihood_kde_gpu",
    sorted_spikes_algorithm_params=sorted_spikes_algorithm_params,
)

classifier


In [None]:
import cupy as cp
import xarray as xr


import os

# Labels for the two decoded dynamics states.
state_names = ["continuous", "fragmented"]

# CUDA device to fit/decode on. Defaults to device 8 (the value this notebook
# was written against) but can be overridden with the GPU_ID environment
# variable, so the notebook also runs on hosts with fewer GPUs.
GPU_ID = int(os.environ.get("GPU_ID", 8))

# Decode the session in n_segments contiguous chunks (presumably to bound the
# size of each predict call — confirm). With n_segments = 1 the whole session
# is decoded at once.
n_time = len(spikes)
n_segments = 1

# Per-segment xarray results; concatenated along time after the loop.
segment_results = []

# The context manager selects which GPU (device) CuPy uses inside the block.
with cp.cuda.Device(GPU_ID):
    # Fit the place fields.
    classifier.fit(
        position=position_info[["x", "y"]],
        spikes=spikes,
    )

    for ind in range(n_segments):
        # Integer split points; the last segment always ends at n_time.
        start = ind * n_time // n_segments
        stop = (ind + 1) * n_time // n_segments
        segment_spikes = spikes.iloc[start:stop]
        segment_results.append(
            classifier.predict(
                segment_spikes,
                time=segment_spikes.index.to_numpy(),
                state_names=state_names,
                use_gpu=True,
            )
        )
    logging.info("Done!")

results = xr.concat(segment_results, dim="time")
results


In [None]:
# Write the decoding results to NetCDF without the "likelihood" and
# "causal_posterior" variables (presumably to reduce file size — confirm).
# `drop_vars` is the supported way to remove variables; `Dataset.drop` with
# variable names is deprecated in xarray.
results.drop_vars(["likelihood", "causal_posterior"]).to_netcdf(
    "../Processed-Data/results_1.nc"
)


In [None]:
# Persist the fitted classifier to disk (a pickle file, per the extension).
# Only reload pickles from trusted locations: unpickling can execute code.
classifier_path = "../Processed-Data/classifier_1.pkl"
classifier.save_model(classifier_path)


In [None]:
from src.plot_data import create_interactive_2D_decoding_figurl

# Display options for the interactive 2D decoding view: the acausal posterior
# over x/y position, binned at the environment's place-bin size, with the
# "speed" column as the speed trace.
figurl_options = dict(
    bin_size=environment.place_bin_size,
    position_name=["x", "y"],
    speed_name="speed",
    posterior_type="acausal_posterior",
    view_height=800,
)

view = create_interactive_2D_decoding_figurl(
    position_info, multiunit_firing_rate, results, **figurl_options
)

view.url(label="2D Decode")
