In [None]:
from IPython.core.getipython import get_ipython
from matplotlib import pyplot as plt
import numpy as np
import sys
import h5py
import os
import pandas as pd
import seaborn as sns
import plotly.graph_objects as go
sys.path.append("..")
from placecode import utils as ut
from placecode.from_caiman import *

# Turn on the autoreload extension so edited modules are picked up without a kernel restart.
# If the name `__IPYTHON__` is not defined, the NameError is swallowed and nothing is loaded.
try:
    if __IPYTHON__:
        for magic_name, magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
            get_ipython().run_line_magic(magic_name, magic_arg)
except NameError:
    pass

import scipy

In [None]:
# TODO: sankey diagram to plot how cells behave over recordings, e.g. place cell, no place cell, still place cell, still no place cell
# TODO: do a heatmap plotting all persistent cells (that made through analysis): each column is condition, each row contains same cell spatial component.
# TODO: work on the % of cells analysis

## Open (hdf5) files

In [None]:
# Collect hdf5 paths one dialog at a time; the loop ends when ut.open_file() returns ".",
# which the original check treats as "user pressed cancel".
open_prompt = "Open hdf5 file, or press Cancel to finish"
files_list = []
fpath = ut.open_file(open_prompt)
while fpath != ".":
    files_list.append(fpath)
    fpath = ut.open_file(open_prompt)

In [None]:
# Per-recording containers: entry i of each filled list comes from files_list[i].
# Y_list and templates are created here but not filled by this cell.
Y_list = []
A_list = []
dims_list = []  # Cn entry in workspace # TODO: A_sparse always have lower resolution, probably from cropping... should I define that as dims?
templates = []
p_vals = []
conditions = []
tv_angles = []
tv_lengths = []
for fpath in files_list:
    with h5py.File(fpath, "r") as hf:
        # file-level attributes
        resolution = hf.attrs["resolution"][()]
        n_components = hf.attrs["n_units"]  # read but not used further in this cell
        condition = hf.attrs["condition"]

        # datasets are read fully into memory, so nothing depends on the file staying open
        ps = hf["p_values_tuned"][()]
        A_data, A_indices, A_indptr, A_shape = (
            hf[key][()] for key in ("A_data", "A_indices", "A_indptr", "A_shape")
        )
        tv_a = hf["tuned_vector_angles"][()]
        tv_l = hf["tuned_vector_lengths"][()]

        # Rebuild the sparse spatial components from their CSC parts.
        # NOTE(review): the original comment mentioned a swap
        # "(n_units, n_pixels) -> (n_pixels, n_units)", but no transpose is applied here;
        # confirm that A_shape is already (n_pixels, n_units).
        spatial = scipy.sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape)

        for container, item in (
            (dims_list, resolution),
            (A_list, spatial),
            (p_vals, ps),
            (conditions, condition),
            (tv_angles, tv_a),
            (tv_lengths, tv_l),
        ):
            container.append(item)

In [None]:
# Convert tuned vector data into numpy arrays. Recordings (conditions) differ in their number
# of units, so every column is padded with np.nan up to a common length.
# The common length is the longest array across ALL three variables that get padded below
# (angles, lengths, p values), not just the angles; otherwise a longer array in another
# variable could not be padded to a rectangular matrix.
max_len = max(len(arr) for arrs in (tv_angles, tv_lengths, p_vals) for arr in arrs)
def convert_to_np(list_of_arrs, pad_len=None):
    """
    Stack a list of 1D arrays as the columns of one 2D float array, padding with np.nan.

    Parameters
    ----------
    list_of_arrs : sequence of 1D array-likes
        One entry per recording (condition); entries may differ in length.
    pad_len : int, optional
        Number of rows of the result. Defaults to the notebook-level `max_len`,
        which is what this function always padded to before the parameter existed.

    Returns
    -------
    np.ndarray of shape (pad_len, len(list_of_arrs))
        Column j holds list_of_arrs[j] followed by np.nan up to pad_len rows.
        An empty input list gives an array of shape (pad_len, 0).

    Raises
    ------
    ValueError
        If an entry is not 1D or is longer than pad_len (it cannot be padded).
    """
    if pad_len is None:
        pad_len = max_len  # notebook-level default, computed just above this function
    # start from an all-NaN matrix and overwrite the leading rows of every column
    padded = np.full((pad_len, len(list_of_arrs)), np.nan)
    for i_col, arr in enumerate(list_of_arrs):
        arr = np.asarray(arr, dtype=float)
        if arr.ndim != 1:
            raise ValueError(f"entry {i_col} must be 1D, got {arr.ndim} dimensions")
        if len(arr) > pad_len:
            # previously this produced a negative pad count and ragged rows
            raise ValueError(
                f"entry {i_col} has length {len(arr)}, which exceeds the padding length {pad_len}"
            )
        padded[:len(arr), i_col] = arr
    return padded

# One padded 2D array per variable: rows are units, columns are recordings.
tv_angles_padded, tv_lengths_padded, p_vals_padded = map(
    convert_to_np, (tv_angles, tv_lengths, p_vals)
)


In [None]:
# Centre-crop every template image to the (smaller) shape stored in dims_list[0].
# Nothing happens while `templates` is empty.
if len(templates) > 0:
    templates_cropped = []
    # NOTE(review): assumes dims_list[0] is ordered like template.shape (axis 0, axis 1);
    # the parity checks below already pair the axes this way.
    cropped_shape = tuple(int(n) for n in dims_list[0][:2])
    for template in templates:
        FOV_shape = template.shape

        x_excess = FOV_shape[0] - cropped_shape[0]
        y_excess = FOV_shape[1] - cropped_shape[1]
        # a template smaller than the target cannot be cropped, and an odd excess
        # cannot be split evenly between both sides
        assert x_excess >= 0 and x_excess % 2 == 0
        assert y_excess >= 0 and y_excess % 2 == 0
        x_crop_onesided = x_excess // 2
        y_crop_onesided = y_excess // 2

        # Explicit end indices: a `c:-c` slice would return an EMPTY array when c == 0.
        # Each axis is cropped by the amount computed for that same axis.
        template_cropped = template[
            x_crop_onesided:FOV_shape[0] - x_crop_onesided,
            y_crop_onesided:FOV_shape[1] - y_crop_onesided,
        ]
        assert template_cropped.shape[:2] == cropped_shape
        templates_cropped.append(template_cropped)

## Use `register_multisession()`

The function `register_multisession()` requires 3 arguments:
- `A`: A list of ndarrays or scipy.sparse.csc matrices with (# pixels X # component ROIs) for each session
- `dims`: Dimensions of the FOV, needed to restore spatial components to a 2D image
- `templates`: List of ndarray matrices of size `dims`, template image of each session

In [None]:
# Match ROIs across all loaded sessions. The FOV dimensions of the first recording are used
# for every session; no templates are passed in this call.
fov_dims = dims_list[0]
spatial_union, assignments, matchings = register_multisession(
    A=A_list,
    dims=fov_dims,
)

The function returns 3 variables for further analysis:
- `spatial_union`: csc_matrix (# pixels X # total distinct components), the union of all ROIs across all sessions aligned to the FOV of the last session.
- `assignments`: ndarray (# total distinct components X # sessions). `assignments[i,j]=k` means that component `k` from session `j` has been identified as component `i` from the union of all components, otherwise it takes a `NaN` value. Note that for each `i` there is at least one session index `j` where `assignments[i,j]!=NaN`.
- `matchings`: list of (# sessions) lists. Saves `spatial_union` indices of individual components in each session. `matchings[j][k] = i` means that component `k` from session `j` is represented by component `i` in the union of all components `spatial_union`. In other words `assignments[matchings[j][k], j] = k`.

## Create various subgroups

### Filter conditions

In [None]:
# create indices for conditions used
assert "bl" in conditions  # TODO: not all bl are called bl1! Some bl_d1_1, bl_d1_2, bl_d2...
conditions_to_use = ["bl", "30min", "60min"]

# All later cells address column k of the filtered arrays as conditions_to_use[k].
# That only holds if every requested condition maps to exactly one loaded recording,
# so fail loudly here instead of silently mislabelling columns further down.
conditions_loaded = list(conditions)
for cond_name in conditions_to_use:
    n_found = conditions_loaded.count(cond_name)
    if n_found != 1:
        raise ValueError(
            f"condition {cond_name!r} must occur exactly once among the loaded recordings, "
            f"found {n_found} in {conditions_loaded}"
        )

# boolean mask over the loaded recordings (file-loading order); kept for backward compatibility
i_cond_filtered = np.isin(np.array(conditions), np.array(conditions_to_use))
# integer column indices in the order of conditions_to_use; unlike the mask, this also
# re-orders the columns when the files were opened in a different order
i_cond_ordered = np.array([conditions_loaded.index(cond_name) for cond_name in conditions_to_use])

assignments_filtered = assignments[:, i_cond_ordered]
assignments_filtered = assignments_filtered[~np.isnan(assignments_filtered).all(axis=1)]  # filter out rows full of np.nan

# do not throw away any rows in tuned vector data, but throw away conditions we do not use.
# The vector lengths of one unit over different conditions can be accessed with tv_lengths_filtered[]
tv_lengths_filtered = tv_lengths_padded[:, i_cond_ordered]
tv_angles_filtered = tv_angles_padded[:, i_cond_ordered]
p_vals_filtered = p_vals_padded[:, i_cond_ordered]

### Take only omnipresent cells
(omnipresent cell = cell that could be identified in all recordings)

In [None]:
# Keep only rows without a single NaN, i.e. cells that were matched in every kept recording.
# With no NaN left, the unit indices can be stored as integers.
found_in_all = np.logical_not(np.isnan(assignments_filtered)).all(axis=1)
assignments_omnipresent = assignments_filtered[found_in_all].astype(np.int16)

### Match (pair) values for same cell from different conditions (recordings) 

In [None]:
# For every omnipresent unit (row) and every included condition (column), look up the
# unit's value in that condition through its per-recording index.
paired_shape = assignments_omnipresent.shape
tv_lengths_paired = np.zeros(paired_shape)
tv_angles_paired = np.zeros(paired_shape)
p_vals_paired = np.zeros(paired_shape)
for i_cond in range(len(conditions_to_use)):
    unit_idx = assignments_omnipresent[:, i_cond]  # index of each unit within recording i_cond
    tv_lengths_paired[:, i_cond] = tv_lengths_filtered[unit_idx, i_cond]
    tv_angles_paired[:, i_cond] = tv_angles_filtered[unit_idx, i_cond]
    p_vals_paired[:, i_cond] = p_vals_filtered[unit_idx, i_cond]

# NaNs mark cells that did not fulfil the inclusion criteria of the analysis; a row must
# contain a NaN in one variable exactly when it contains one in the others.
row_has_nan_p = np.isnan(p_vals_paired).any(axis=1)
row_has_nan_angle = np.isnan(tv_angles_paired).any(axis=1)
row_has_nan_length = np.isnan(tv_lengths_paired).any(axis=1)
assert np.array_equal(row_has_nan_p, row_has_nan_angle)
assert np.array_equal(row_has_nan_p, row_has_nan_length)

### Get persistent cells
(persistent cell = omnipresent cell that fulfilled requirement for getting included in place coding analysis for each condition)

In [None]:
# Boolean row mask: True where a cell has a (non-NaN) p value in every condition.
i_persistent = np.logical_not(np.isnan(p_vals_paired)).all(axis=1)

In [None]:
# Apply the same row mask to all four arrays so that their rows stay aligned.
tv_lengths_persistent, tv_angles_persistent, p_vals_persistent, assignments_persistent = (
    paired[i_persistent]
    for paired in (tv_lengths_paired, tv_angles_paired, p_vals_paired, assignments_omnipresent)
)

# all four results must describe the same (cells x conditions) grid
assert (
    tv_lengths_persistent.shape
    == tv_angles_persistent.shape
    == p_vals_persistent.shape
    == assignments_persistent.shape
)

* `assignments_persistent` contains one row per persistent cell, i.e. a cell that fulfilled the analysis criteria (minimum number of events...) in all included conditions. For each row, each column contains the original cell index in the recording of the corresponding `conditions_to_use` condition. For example, `assignments_persistent[0][0]==8` means the first persistent cell is cell 8 (with indexing starting at 0) in the baseline recording; the same cell might be cell 253 in the second condition (`assignments_persistent[0][1]==253`).
* `tv_lengths_persistent`, `tv_angles_persistent`, `p_vals_persistent` contain the tuning vector lengths, angles, and the p value, each row one neuron tracked over the conditions (that fulfilled analysis criteria). The rows and columns match those of `assignments_persistent` (i.e. the same cell, same condition is in the same row and column)

### Get persistent cells that are initially place coding (ipc) and not initially place coding (nipc)

In [None]:
# Integer row indices of persistent cells whose p value in the first condition is <= 0.05.
i_ipc = np.flatnonzero(p_vals_persistent[:, 0] <= 0.05)

In [None]:
# Split the persistent cells into initially place-coding (ipc) and not initially
# place-coding (nipc) cells.
# `i_ipc` holds integer row indices, so `~i_ipc` would be a bitwise NOT producing the
# negative indices -(i+1) rather than "all other rows". Build an explicit boolean mask and
# take its complement instead (this also works should `i_ipc` already be a boolean mask).
ipc_mask = np.zeros(len(p_vals_persistent), dtype=bool)
ipc_mask[i_ipc] = True
nipc_mask = ~ipc_mask

tv_lengths_ipc = tv_lengths_persistent[i_ipc]
tv_angles_ipc = tv_angles_persistent[i_ipc]
p_vals_ipc = p_vals_persistent[i_ipc]
assignments_ipc = assignments_persistent[i_ipc]

tv_lengths_nipc = tv_lengths_persistent[nipc_mask]
tv_angles_nipc = tv_angles_persistent[nipc_mask]
p_vals_nipc = p_vals_persistent[nipc_mask]
assignments_nipc = assignments_persistent[nipc_mask]

# every persistent cell belongs to exactly one of the two groups
assert len(p_vals_ipc) + len(p_vals_nipc) == len(p_vals_persistent)

# Analysis

## Check movement between place-coding, non-place-coding, "quiet" cells
Quiet cells: cells that were not included in PC analysis (minimum event number criterion not fulfilled)

In [None]:
# Quiet cells appear as np.nan in p_vals. Make sure NaN is False for both the PC (<= 0.05)
# and the nPC (> 0.05) test, so that only np.isnan() picks quiet cells up.
assert not(np.nan > 0.05)
assert not(np.nan <= 0.05)
assert np.isnan(np.nan)

# Three nodes per condition, in this order: PC (red), nPC (blue), Q (black).
labels = []
colors = []
for condition in conditions_to_use:
    labels += [f"PC {condition}", f"nPC {condition}", f"Q {condition}"]
    colors += ["red", "blue", "black"]

# Links connect the 3 nodes of a condition to the 3 nodes of the next condition.
# Per condition pair the links are grouped by TARGET category:
#   sources: 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, ...
#   targets: 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, ...
sources = []
targets = []
values = []
link_colors = []
for i_condition in range(len(conditions_to_use) - 1):  # last condition does not have output
    p_before = p_vals_paired[:, i_condition]
    p_after = p_vals_paired[:, i_condition + 1]
    # category membership in node order: PC, nPC, Q
    masks_before = (p_before <= 0.05, p_before > 0.05, np.isnan(p_before))
    masks_after = (p_after <= 0.05, p_after > 0.05, np.isnan(p_after))

    for i_target, mask_after in enumerate(masks_after):
        # PC, nPC and Q of this condition flow into category `i_target` of the next one
        sources.extend(3*i_condition + i_source for i_source in range(3))
        targets.extend([3*(i_condition+1) + i_target] * 3)
        values.extend(
            np.sum(np.logical_and(mask_before, mask_after)) for mask_before in masks_before
        )
        # a link takes the colour of its SOURCE: PC -> x light red, nPC -> x light blue,
        # Q -> x "light black"
        link_colors.extend(["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)", "rgba(0, 0, 0, 0.4)"])

sankey_nodes = dict(
    pad = 15,
    thickness = 20,
    line = dict(color = "black", width = 0.5),
    label = labels,
    color = colors,
)
sankey_links = dict(
    source = sources,  # indices refer to positions in `labels`
    target = targets,
    value = values,
    color = link_colors,
)
fig = go.Figure(data=[go.Sankey(node = sankey_nodes, link = sankey_links)])

fig.update_layout(title_text="Place coding (PC) - non-place coding (nPC) - quiet (Q)", font_size=10)
# NOTE(review): absolute, machine-specific output path
fig.write_html("D:\\Downloads\\pc_npc_s.html")
fig.show()

## % of place cells at each time point

In [None]:
# Fraction of place cells (p <= 0.05) per loaded recording. The denominator is the total
# number of entries, so cells with a NaN p value count as "not a place cell".
place_cell_ratio = np.array(
    [np.sum(p_vals_cond <= 0.05) / len(p_vals_cond) for p_vals_cond in p_vals],
    dtype=float,
)


In [None]:
# Left panel: percentage of place cells per recording; right panel: total cell count.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
n_cells_total = [len(p_vals_cond) for p_vals_cond in p_vals]
panels = (
    (axs[0], place_cell_ratio*100., "% place cells"),
    (axs[1], n_cells_total, "# cells (total)"),
)
for panel_ax, series, y_label in panels:
    panel_ax.plot(series)
    panel_ax.set_xticks(range(len(conditions)), conditions)  # one tick per loaded recording
    panel_ax.set_ylabel(y_label)

plt.show()

## Check persistent cells behaviour

In [None]:
# Two nodes per condition, in this order: PC (red), nPC (blue).
labels = []
colors = []
for condition in conditions_to_use:
    labels += [f"PC {condition}", f"nPC {condition}"]
    colors += ["red", "blue"]

# Links connect the 2 nodes of a condition to the 2 nodes of the next condition.
# Per condition pair the links are grouped by TARGET category:
#   sources: 0, 1, 0, 1, 2, 3, 2, 3, ...
#   targets: 2, 2, 3, 3, 4, 4, 5, 5, ...
sources = []
targets = []
values = []
link_colors = []
for i_condition in range(len(conditions_to_use) - 1):  # last condition does not have output
    p_before = p_vals_persistent[:, i_condition]
    p_after = p_vals_persistent[:, i_condition + 1]
    # category membership in node order: PC, nPC
    masks_before = (p_before <= 0.05, p_before > 0.05)
    masks_after = (p_after <= 0.05, p_after > 0.05)

    for i_target, mask_after in enumerate(masks_after):
        # PC and nPC of this condition flow into category `i_target` of the next one
        sources.extend(2*i_condition + i_source for i_source in range(2))
        targets.extend([2*(i_condition+1) + i_target] * 2)
        values.extend(
            np.sum(np.logical_and(mask_before, mask_after)) for mask_before in masks_before
        )
        # a link takes the colour of its SOURCE: PC -> x light red, nPC -> x light blue
        link_colors.extend(["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)"])

sankey_nodes = dict(
    pad = 15,
    thickness = 20,
    line = dict(color = "black", width = 0.5),
    label = labels,
    color = colors,
)
sankey_links = dict(
    source = sources,  # indices refer to positions in `labels`
    target = targets,
    value = values,
    color = link_colors,
)
fig = go.Figure(data=[go.Sankey(node = sankey_nodes, link = sankey_links)])

fig.update_layout(title_text="Place coding (PC) - non-place coding (nPC) of persistent cells", font_size=10)
# NOTE(review): absolute, machine-specific output path
fig.write_html("D:\\Downloads\\pc_npc.html")
fig.show()

### Check tuning vector direction change/stability for initially place-coding cells

In [None]:
# One line per initially place-coding cell: the angle is the tuning vector direction and the
# radius is the 1-based position of the condition (1 = first condition in conditions_to_use).
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(6,6))
# (a log radial scale, ax.set_yscale('log'), was tried and left disabled)
for angles in tv_angles_ipc:
    radii = list(range(1, len(angles) + 1))
    ax.plot(angles, radii, linewidth=0.3, marker='o')  # original note: angles span -pi to pi
plt.show()

# Save multisession registration results

In [None]:
#spatial_union, assignments, matchings