In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from syd import make_viewer, Viewer

from sklearn.decomposition import PCA
from umap import UMAP
from rastermap import Rastermap

from vrAnalysis2.helpers import Timer, tic, toc, nearestpoint
from vrAnalysis2.helpers import get_average_frame_position, get_place_field, crossCorrelation
from vrAnalysis2.sessions import create_b2session, B2Session, B2SessionParams
from vrAnalysis2.processors import spkmaps as SMPs
from vrAnalysis2.processors.support import median_zscore
from vrAnalysis2.tracking import Tracker
from vrAnalysis2.multisession import MultiSessionSpkmaps
from vrAnalysis2.metrics import FractionActive

# --- Session selection ---
mouse_name = "ATL027"
date = "2023-07-27"
session_id = "701"
spks_type = "oasis"

# An ROI is kept as "reliable" if its reliability exceeds this threshold in at least
# one entry of reliability.values (named here instead of a bare 0.5 so it is easy to tune).
RELIABILITY_THRESHOLD = 0.5

session = create_b2session(mouse_name, date, session_id, dict(spks_type=spks_type))
smp = SMPs.SpkmapProcessor(session)

# Per-frame behavior; frame_position / frame_environment are used later to color the
# frame-wise embeddings.
frame_position, frame_speed, frame_environment, frame_trial = smp.get_frame_behavior()

# One boolean mask per entry of reliability.values, stacked on axis 0; an ROI passes if it
# exceeds the threshold in any entry. NOTE(review): presumably one entry per environment — confirm.
reliability = smp.get_reliability()
is_reliable = np.stack([rval > RELIABILITY_THRESHOLD for rval in reliability.values], axis=0)
idx_reliable = np.where(np.any(is_reliable, axis=0))[0]
print("Num reliable:", idx_reliable.shape)

Num reliable: (940,)


In [2]:
# Number of ROIs to subsample (UMAP on every ROI is slow), and a fixed seed so the
# subset is identical across kernel restarts.
N_ROIS_SUBSAMPLE = 2000
ROI_SUBSAMPLE_SEED = 0

spks = session.spks[:, session.idx_rois]
spkmap = smp.get_env_maps()
spkmap.pop_nan_positions()

# Random ROI subset; clamp the size so sessions with fewer ROIs do not crash
# (choice without replacement cannot draw more items than exist).
n_rois = spkmap.spkmap[0].shape[0]
rng = np.random.default_rng(ROI_SUBSAMPLE_SEED)
idx_random = rng.choice(n_rois, size=min(N_ROIS_SUBSAMPLE, n_rois), replace=False)
spkmap.filter_rois(idx_random)
# NOTE(review): assumes spkmap ROI order matches the columns of session.spks[:, session.idx_rois] — confirm.
spks = spks[:, idx_random]
print(spkmap)
print([s.shape for s in spkmap.spkmap])
print(spks.shape)

# Each map is (rois, trials, positions) (rois_first=True in the printed repr). Flatten every
# (trial, position) bin into one sample of ROI activity, stacking environments along samples.
spkmap_samples = np.concatenate([s.reshape(s.shape[0], -1) for s in spkmap.spkmap], axis=1).T
# Per-sample labels in the same row-major (trial, position) order as spkmap_samples.
environment_index = np.concatenate([np.full(np.prod(s.shape[1:]), ienv, dtype=float) for ienv, s in enumerate(spkmap.spkmap)], axis=0)
position_index = np.concatenate([np.tile(np.arange(s.shape[2]), s.shape[1]) for s in spkmap.spkmap], axis=0)
print(spkmap_samples.shape, spks.shape, environment_index.shape, position_index.shape)

n_neighbors = 20
n_components = 2
metric = "correlation"
spread = 0.8
min_dist = 0.2
# UMAP is stochastic; add random_state=<int> here for reproducible (but single-threaded) fits.
umap_params = dict(n_neighbors=n_neighbors, n_components=n_components, metric=metric, spread=spread, min_dist=min_dist)
umap_spks = UMAP(**umap_params).fit(spks)
umap_maps = UMAP(**umap_params).fit(spkmap_samples)

pca_spks = PCA(n_components=n_components).fit(spks)
pca_maps = PCA(n_components=n_components).fit(spkmap_samples)

Maps(num_trials={45, 53}, num_positions=195, num_rois=2000, environments={1, 3}, rois_first=True)
[(2000, 45, 195), (2000, 53, 195)]
(12367, 2000)
(19110, 2000) (12367, 2000) (19110,) (19110,)




In [3]:
plt.close('all')

# Precompute the embeddings for every reduction method once, when this cell runs.
# This keeps the viewer callback cheap and insulates it from later cells that overwrite
# globals such as `spks` (projecting a different ROI subset through pca_spks would be wrong).
# Each value is (embedding of frame-wise spks, embedding of spkmap samples).
embeddings_by_method = {
    "umap": (umap_spks.embedding_, umap_maps.embedding_),
    "pca": (pca_spks.transform(spks), pca_maps.transform(spkmap_samples)),
}

def plot(state):
    """Scatter the first two embedding dims of frame-wise spks (left) and spkmap samples (right).

    The top row is colored by environment and the bottom row by position.

    Parameters
    ----------
    state : dict
        Viewer state; ``state["redmethod"]`` must be a key of ``embeddings_by_method``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``state["redmethod"]`` is not a known reduction method.
    """
    method = state["redmethod"]
    if method not in embeddings_by_method:
        raise ValueError(f"Unknown reduction method {method!r}; expected one of {list(embeddings_by_method)}")
    emb_spks, emb_maps = embeddings_by_method[method]

    fig, ax = plt.subplots(2, 2, figsize=(7, 6), layout="constrained")
    ax[0, 0].scatter(emb_spks[:, 0], emb_spks[:, 1], s=5, c=frame_environment, alpha=0.2)
    ax[0, 1].scatter(emb_maps[:, 0], emb_maps[:, 1], s=5, c=environment_index, alpha=0.2)
    ax[1, 0].scatter(emb_spks[:, 0], emb_spks[:, 1], s=5, c=frame_position, alpha=0.2)
    ax[1, 1].scatter(emb_maps[:, 0], emb_maps[:, 1], s=5, c=position_index, alpha=0.2)
    ax[0, 0].set_title(f"{method}: spks (frames)")
    ax[0, 1].set_title(f"{method}: spkmap (trial x position)")
    ax[0, 0].set_ylabel("color = environment")
    ax[1, 0].set_ylabel("color = position")
    return fig

viewer = make_viewer(plot)
viewer.add_selection("redmethod", options=list(embeddings_by_method))
viewer.show()

HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), Dropdown(description='redmethod'…

In [5]:
# Subsampling configuration: fixed seed so the ROI subset is reproducible across runs.
N_ROIS_SUBSAMPLE = 2000
ROI_SUBSAMPLE_SEED = 0
# Set True to restrict to reliable ROIs before subsampling (replaces the former commented-out code).
FILTER_RELIABLE = False
RELIABILITY_THRESHOLD = 0.5

smp = SMPs.SpkmapProcessor(session)
reliability = smp.get_reliability()
# ROIs whose reliability exceeds the threshold in any entry of reliability.values.
is_reliable = np.stack([rval > RELIABILITY_THRESHOLD for rval in reliability.values], axis=0)
idx_reliable = np.where(np.any(is_reliable, axis=0))[0]

spks = session.spks[:, session.idx_rois]
spkmap = smp.get_env_maps()
spkmap.pop_nan_positions()
if FILTER_RELIABLE:
    spkmap.filter_rois(idx_reliable)
    spks = spks[:, idx_reliable]
    print("Num reliable:", len(idx_reliable))

# Clamp the subsample size: choice without replacement cannot draw more ROIs than exist
# (e.g. ~940 reliable ROIs < 2000 would otherwise raise).
n_rois = spkmap.spkmap[0].shape[0]
rng = np.random.default_rng(ROI_SUBSAMPLE_SEED)
idx_random = rng.choice(n_rois, size=min(N_ROIS_SUBSAMPLE, n_rois), replace=False)
spkmap.filter_rois(idx_random)
# NOTE(review): assumes spkmap ROI order matches the columns of `spks` — confirm.
spks = spks[:, idx_random]

# Flatten each (rois, trials, positions) map into (trial, position) samples of ROI activity,
# with matching per-sample environment and position labels in the same row-major order.
spkmap_samples = np.concatenate([s.reshape(s.shape[0], -1) for s in spkmap.spkmap], axis=1).T
environment_index = np.concatenate([np.full(np.prod(s.shape[1:]), ienv, dtype=float) for ienv, s in enumerate(spkmap.spkmap)], axis=0)
position_index = np.concatenate([np.tile(np.arange(s.shape[2]), s.shape[1]) for s in spkmap.spkmap], axis=0)

print(spkmap)
print(spkmap_samples.shape)
print(spks.shape)

viewer = make_viewer()
viewer.add_integer("n_neighbors", value=20, min=5, max=80)
viewer.add_float("min_dist", value=0.2, min=0, max=2, step=0.002)
viewer.add_selection("metric", value="correlation", options=["cosine", "correlation", "manhattan", "euclidean"])
# spread must be positive for UMAP's curve parameters to be well-defined.
viewer.add_float("spread", value=0.8, min=0.1, max=2, step=0.1)
viewer.add_integer("n_components", value=2, min=1, max=10)
viewer.add_selection("spk_source", value="spkmap", options=["spks", "spkmap"])

def plot(state):
    """Fit UMAP with the viewer's current parameters and scatter the first two embedding dims.

    ``state["spk_source"]`` selects frame-wise ``spks`` or trial-by-position ``spkmap_samples``.
    The left panel is colored by environment, the right by position. Samples with a NaN
    position are dropped before fitting.

    Parameters
    ----------
    state : dict
        Viewer state with keys n_neighbors, min_dist, metric, spread, n_components, spk_source.

    Returns
    -------
    matplotlib.figure.Figure
    """
    use_spks = state["spk_source"] == "spks"
    source = spks if use_spks else spkmap_samples
    pos = frame_position if use_spks else position_index
    env = frame_environment if use_spks else environment_index

    # Keep only samples with a defined position so every point has a color.
    idx_not_nan = ~np.isnan(pos)
    message = f"Valid Samples: {np.sum(idx_not_nan)} Total Samples: {len(idx_not_nan)}"
    print(message)
    source = source[idx_not_nan]
    pos = pos[idx_not_nan]
    env = env[idx_not_nan]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    if state["min_dist"] > state["spread"]:
        # UMAP rejects min_dist > spread; explain in the figure rather than raising inside the callback.
        fig.suptitle(f"min_dist ({state['min_dist']}) must be <= spread ({state['spread']})")
        return fig

    reducer = UMAP(
        n_neighbors=state["n_neighbors"],
        n_components=state["n_components"],
        min_dist=state["min_dist"],
        metric=state["metric"],
        spread=state["spread"],
    ).fit(source)
    embedding = reducer.embedding_
    x = embedding[:, 0]
    # n_components may be 1: draw a 1-D embedding along a flat line instead of indexing a missing column.
    y = embedding[:, 1] if embedding.shape[1] > 1 else np.zeros_like(x)

    ax[0].scatter(x, y, s=5, c=env, alpha=0.2)
    ax[1].scatter(x, y, s=5, c=pos, alpha=0.2)
    ax[0].set_title(message)
    ax[1].set_title("color = position")
    for a in ax:
        a.set_xlabel("UMAP 1")
        a.set_ylabel("UMAP 2" if embedding.shape[1] > 1 else "")
    return fig

viewer.set_plot(plot)
viewer.show()

Maps(num_trials={45, 53}, num_positions=195, num_rois=2000, environments={1, 3}, rois_first=True)
(19110, 2000)
(12367, 2000)


HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Parameters</b>'), IntSlider(value=20, continuous_u…

Valid Samples: 19110 Total Samples: 19110


