In [1]:
import matplotlib
%matplotlib tk
%autosave 180
%load_ext autoreload
%autoreload 2

import nest_asyncio
%config Completer.use_jedi = False

import os
os.chdir('/home/cat/code/manifolds/')

#
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import scipy
import numpy as np
import pandas as pd

from calcium import calcium
from wheel import wheel
from visualize import visualize
from tqdm import trange, tqdm

from scipy.io import loadmat

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scipy.spatial import ConvexHull, convex_hull_plot_2d

# 
np.set_printoptions(suppress=True)


from utils import load_UMAP, load_binarized_traces, find_ensemble_order, load_data, HMM, get_footprint_contour, load_footprints, computing_ensemble_loadings



############### SSM FUNCTIONS ##########################
import autograd.numpy as np
import autograd.numpy.random as npr
npr.seed(0)

import ssm
from ssm.util import find_permutation
from ssm.plots import gradient_cmap, white_to_color_cmap




Autosaving every 180 seconds


In [4]:
############################################
################ LOAD DATA #################
############################################
root_dir = '/media/cat/4TB/donato/'
animal_id = 'DON-006084'
session = '20210519'
dim_type = 'pca'
#bin_type = 'F_onphase'  # ['upphase','onphase','spikes','spikes_smooth']
bin_type = 'F_upphase'  # ['upphase','onphase','spikes','spikes_smooth']

use_pca = True
use_rasters = False

pca_n_dim = 50
num_states = 100

#
X_pca, rasters = load_data(root_dir, 
                           animal_id, 
                             session,
                             dim_type,
                             bin_type)
rasters = rasters.T
print ("X_pca: ", X_pca.shape)
print ("rasters: ", rasters.shape)


X_pca:  (55740, 50)
rasters:  (55740, 531)


In [None]:
#########################################
################ RUN HMM ################
#########################################

#
if use_pca: 
    data = X_pca[:,:pca_n_dim]
elif use_rasters: 
    data = rasters #[:,:10]


obs_dim = data.shape[1]
print ("Data into hmm: ", data.shape)

#
hmm = ssm.HMM(num_states, 
              obs_dim, 
              observations="gaussian")

hmm_z = hmm.most_likely_states(data)
unique_states = np.unique(hmm_z)
print ("# of dis covered states: ", unique_states.shape[0])

# 
ctr=0
hmm_z2 = hmm_z.copy()*0
for id_ in unique_states:
    idx = np.where(hmm_z==id_)[0]
    hmm_z2[idx]=ctr
    
    ctr+=1
    
hmm_z = hmm_z2.copy()
    

In [56]:
np.save('/home/cat/hmm_z_pca.npy',hmm_z)

In [2]:
########################################################
########### PLOT ALL STATES TOGETHER ###################
########################################################

hmm_z = np.load('/home/cat/hmm_z_pca.npy')

# 
fig=plt.figure()
plt.subplot(111)
ctr=0
split = 20
img = []      # stack all times
yticks = []   # compute time windows as strings
for k in range(0,hmm_z.shape[0],hmm_z.shape[0]//split):
    
    #
    img.append(hmm_z.copy()[k:k+hmm_z.shape[0]//split][None,:])
    yticks.append(str(k//30)+" - "+str((k+hmm_z.shape[0]//split)/30))

img = np.vstack(img)
print ("IMG: ", img.shape)

#################################################
#################################################
#################################################
# change the background / most frequent state to darker color
if False:
    idx0 = np.where(img==0)
    idx11 = np.where(img==11)
    img[idx0]=11
    img[idx11]=0
    
    idx0 = np.where(hmm_z==0)[0]
    idx11 = np.where(hmm_z==11)[0]
    print (idx0.shape, idx11.shape, hmm_z.shape)
    hmm_z[idx0] = 11
    hmm_z[idx11] = 0

##############################################
##############################################
##############################################
ax=plt.subplot(1,1,1)

cmap = matplotlib.cm.get_cmap('gist_ncar', img.shape[0])

cax = ax.imshow(img, 
       aspect="auto", 
       extent=[k//30,(k+hmm_z.shape[0]//split)/30.,0,1],
       cmap=cmap,
       interpolation='none')

##############################################################
h = HMM()
h.get_hmm_stats(hmm_z)
ids = h.ids
n_occurance = h.n_occurance
total_durations = h.total_durations

print ("ids: ", ids)
ticks2 = []
for k in range(np.unique(ids).shape[0]):
    temp = str(k) + ",  #"+str(n_occurance[k])+",  "+str(round(total_durations[k],2))+"sec"
    ticks2.append(temp)

cbar = fig.colorbar(cax, ticks=np.arange(np.unique(ids).shape[0]))
cbar.ax.set_yticklabels(ticks2)  # vertically oriented colorbar





print (img.shape, yticks)
plt.yticks(np.arange(img.shape[0])/img.shape[0]+1/2./img.shape[0], 
           yticks[::-1])
           #rotation='vertical')

ax.yaxis.labelpad = 50
plt.xticks([])
    
plt.ylabel("time")
plt.show()

IMG:  (20, 2787)
ids:  [24 18  4 ... 30 28 24]
(20, 2787) ['0 - 92.9', '92 - 185.8', '185 - 278.7', '278 - 371.6', '371 - 464.5', '464 - 557.4', '557 - 650.3', '650 - 743.2', '743 - 836.1', '836 - 929.0', '929 - 1021.9', '1021 - 1114.8', '1114 - 1207.7', '1207 - 1300.6', '1300 - 1393.5', '1393 - 1486.4', '1486 - 1579.3', '1579 - 1672.2', '1672 - 1765.1', '1765 - 1858.0']


  ax=plt.subplot(1,1,1)


In [25]:
########################################################
########### PLOT EACH STATES VS TIME ###################
########################################################
fig=plt.figure()
unique_states = np.unique(hmm_z)
print ("unique states: ", unique_states)
img_temp = hmm_z.copy()

    
# 
ctr=0
for id_ in unique_states:
    ax=plt.subplot(unique_states.shape[0]//2+1,2,ctr+1)
    
    #
    
    if True:
        idx = np.where(img_temp==id_)[0]
        img2 = img_temp.copy()*0
        for id2 in idx:
            img2[id2:id2+10]=1
        
    # 
    ax.imshow(img2[None,:], 
           aspect="auto", 
           extent=[0,img2.shape[0]//30.,0,1],
           cmap='Greys')
                             
   
    plt.ylabel(str(id_),
               fontsize=10, rotation=0)
    #plt.xticks(fontsize=10)
    #plt.yticks(rotation=180)
    ax.yaxis.labelpad = 10
    plt.yticks([])
    plt.xticks([])
    ctr+=1
    
#plt.xlim(0, time_bins)
#plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")

#plt.tight_layout()
plt.show()

unique states:  [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30]


In [26]:
########################################
####### STATISTICS OF STATES ###########
########################################

#
h = HMM()
h.get_hmm_stats(hmm_z)

#lens, ids, lens_per, n_occurance, total_durations = get_hmm_stats(hmm_z)

###########################################
fig=plt.figure()

# plot all distributions
ax=plt.subplot(221)
y = np.histogram(h.lens, bins=np.arange(0,30,1))
plt.plot(y[1][:-1], y[0])
plt.semilogy()
plt.xlabel("Ensemble duration (sec)")
plt.ylabel("# of ensembles ")
plt.title("# of ensemble transitions "+str(ids.shape[0]))


# plot all distributions
ax=plt.subplot(222)
y=np.histogram(h.ids,bins=np.arange(0,np.unique(h.ids).shape[0],1))
plt.bar(y[1][:-1], y[0], .9)
plt.semilogy()
plt.ylabel("# of occurances of ensemble")
plt.xlabel("ensemble ID")
plt.title("# unique (non-zero) ensembles: "+str(np.unique(h.ids).shape[0]))


# plot all distributions
ax=plt.subplot(223)
plt.scatter(h.lens_per, h.n_occurance)
plt.semilogy()
plt.semilogx()
plt.plot([np.min(h.lens_per)*.5,np.max(h.lens_per)*1.5],
         [np.min(h.n_occurance)*.5,np.max(h.n_occurance)*1.5])
plt.xlabel("Median duration of ensemble (sec)")
plt.ylabel("# of occurances")

plt.show()
    
    

In [27]:
##############################################
####### VISUALIZE STATES AS RASTERS ##########
##############################################
from tqdm import tqdm

hmm_z = np.load('/home/cat/hmm_z_pca.npy')
#hmm_z = np.load('/home/cat/hmm_z_rasters.npy')

#
state_id = 1
print ("Total duration of state: ", total_durations[state_id], 'sec')

# get hmm-based 
h = HMM()
h.get_hmm_stats(hmm_z)
#
idx = np.where(h.ids==state_id)

# grab segments for each type
segs = h.windows[idx]

# test the 3 different types of data
data_in = rasters.copy()
#data_in = X_pca.copy()
#data_in = hmm_z.copy()[:,None]
print ("Data_in: ", data_in.shape)

############ STACK IMAGE ##############
img = np.zeros((0,data_in.shape[1]))
blank = np.zeros((3,data_in.shape[1]))+np.nan
img_out = []
for s in tqdm(segs):
    temp = data_in[s[0]:s[1]+1]
    img_out.append(temp)
    img_out.append(blank)
    
img = np.vstack(img_out)
img = img.T

############# SHOW IMAGE ###############
fig=plt.figure()
plt.imshow(img,
          aspect='auto',
          interpolation='none',
          extent=[0,img.shape[1]/30.,0,img.shape[0]])
plt.xlabel("time (sec)")
plt.title("State: "+str(state_id),fontsize=20)
plt.show()


100%|██████████| 38/38 [00:00<00:00, 60281.22it/s]

Total duration of state:  11.133333333333333 sec
Data_in:  (55740, 531)





In [42]:
############################################################
####### VISUALIZE SINGLE STATE AS CELL ASSEMBLY ############
############################################################
root_dir = '/media/cat/4TB/donato/'
animal_id = 'DON-006084'
session = '20210519'

c = calcium.Calcium()
c.root_dir = root_dir
c.animal_id = animal_id
c.session = session
c.load_suite2p()

# 
imgs, imgs_all, imgs_bin, contours = load_footprints(c)

#
state_id = 2

#
#hmm_z = np.load('/home/cat/hmm_z_pca.npy')
hmm_z = np.load('/home/cat/hmm_z_rasters.npy')

h=HMM()
h.get_hmm_stats(hmm_z)
cell_sums = computing_ensemble_loadings(state_id, 
                                        h,
                                        rasters)

# compute assembly blueprint
state_blueprint = imgs_bin.transpose(1,2,0)*cell_sums
state_blueprint = state_blueprint.sum(axis=2)
cell_sums_norm = cell_sums/np.max(cell_sums)

###################################################
###################################################
###################################################
fig=plt.figure(figsize=(25,8))
ax=plt.subplot(131)
plt.imshow(imgs_all,
          aspect='auto',
          cmap='jet',
          interpolation='none')
plt.title("P19")

# 
ax=plt.subplot(132)
plt.imshow(state_blueprint,
          aspect='auto',
          cmap='jet',
          interpolation='none')

# 
ax=plt.subplot(133)
for k in range(len(c.stat)):
    contour = contours[k]
    
    plt.plot(contour[:,1],contour[:,0], 
             c='black',
            alpha=cell_sums_norm[k])
    
plt.xlim(0,512)
plt.ylim(512,0)
plt.suptitle("State: " +str(state_id))

#ax.set_facecolor('xkcd:white')
if False:
    plt.savefig('/home/cat/img.svg')
    plt.close()
else:
    plt.show()







100%|██████████| 352/352 [00:00<00:00, 388729.60it/s]


raster_state:  (531, 2157)
(531,)


In [24]:
############################################################
####### VISUALIZE ALL STATE AS CELL ASSEMBLIES #############
############################################################
root_dir = '/media/cat/4TB/donato/'
animal_id = 'DON-006084'
session = '20210519'

#
dim_type = 'pca'
#bin_type = 'F_onphase'  # ['upphase','onphase','spikes','spikes_smooth']
bin_type = 'F_upphase'  # ['upphase','onphase','spikes','spikes_smooth']
X_pca, rasters = load_data(root_dir, 
                           animal_id, 
                             session,
                             dim_type,
                             bin_type)

rasters = rasters.T
print ("RASTERS: ", rasters.shape)

spike_rates = rasters.sum(0)
print ("spike rates", spike_rates.shape)

    
# 
c = calcium.Calcium()
c.root_dir = root_dir
c.animal_id = animal_id
c.session = session
c.load_suite2p()

imgs, imgs_all, imgs_bin, contours = load_footprints(c)


#
h=HMM()
#hmm_z = np.load('/home/cat/hmm_z_pca.npy')
hmm_z = np.load('/home/cat/hmm_z_rasters.npy')
h.get_hmm_stats(hmm_z)

#
state_ids = np.argsort(h.total_durations)[::-1]
print ("state ids: ", state_ids)


# 
fig=plt.figure(figsize=(12,12))
ctr=0
for state_id in state_ids:
    ax=plt.subplot(6,6,ctr+1)

    # get cel sums for each ensemble
    cell_sums = computing_ensemble_loadings(state_id,
                                            h,
                                            rasters)
    
    # normalize to the spiking rate of each cell:
    cell_sums = cell_sums/spike_rates
    
    # 
    cell_sums_norm = cell_sums/np.max(cell_sums)

    ###################################################
    ###################################################
    ###################################################
    for k in range(len(c.stat)):
        contour = contours[k]

        plt.plot(contour[:,1],contour[:,0], 
                 c='black',
                 alpha=cell_sums_norm[k])

    plt.xlim(0,512)
    plt.ylim(512,0)
    plt.suptitle("State: " +str(state_id))
    plt.xticks([])
    plt.yticks([])
    plt.title("# "+str(h.n_occurance[state_id])+
              ", "+str(round(h.total_durations[state_id],1))+"sec")
    
    ctr+=1
    
plt.suptitle("State partition using PCA")
if False:
    #plt.savefig('/home/cat/img.png',dpi=100)
    plt.savefig('/home/cat/ensembles_rasters.svg')
    plt.close()
else:
    plt.show()


RASTERS:  (55740, 531)
spike rates (531,)




100%|██████████| 1538/1538 [00:00<00:00, 399259.74it/s]

state ids:  [11  5  2 10 14  6  4 13  9  0  1  3 12  7  8]
state id:  11
raster_state:  (531, 39815)



100%|██████████| 1480/1480 [00:00<00:00, 409794.69it/s]

state id:  5
raster_state:  (531, 12352)



100%|██████████| 352/352 [00:00<00:00, 386490.84it/s]

state id:  2
raster_state:  (531, 2157)



100%|██████████| 153/153 [00:00<00:00, 337929.71it/s]

state id:  10
raster_state:  (531, 756)



100%|██████████| 38/38 [00:00<00:00, 244078.95it/s]

state id:  14
raster_state:  (531, 190)



100%|██████████| 42/42 [00:00<00:00, 251658.24it/s]

state id:  6
raster_state:  (531, 169)



100%|██████████| 26/26 [00:00<00:00, 176745.39it/s]

state id:  4
raster_state:  (531, 116)



100%|██████████| 35/35 [00:00<00:00, 210617.85it/s]

state id:  13
raster_state:  (531, 76)



100%|██████████| 23/23 [00:00<00:00, 191027.71it/s]

state id:  9
raster_state:  (531, 70)



100%|██████████| 5/5 [00:00<00:00, 65331.84it/s]

state id:  0
raster_state:  (531, 19)



100%|██████████| 6/6 [00:00<00:00, 74455.10it/s]

state id:  1
raster_state:  (531, 7)



100%|██████████| 3/3 [00:00<00:00, 41527.76it/s]

state id:  3
raster_state:  (531, 6)



100%|██████████| 3/3 [00:00<00:00, 42799.02it/s]

state id:  12
raster_state:  (531, 4)



100%|██████████| 1/1 [00:00<00:00, 15420.24it/s]

state id:  7
raster_state:  (531, 2)



100%|██████████| 1/1 [00:00<00:00, 15252.01it/s]

state id:  8
raster_state:  (531, 1)





In [25]:
###################################################################
####### VISUALIZE TIME PREOGRESSION WITH SINGLE STATE #############
###################################################################
root_dir = '/media/cat/4TB/donato/'
animal_id = 'DON-006084'
session = '20210519'


#
dim_type = 'pca'
#bin_type = 'F_onphase'  # ['upphase','onphase','spikes','spikes_smooth']
bin_type = 'F_upphase'  # ['upphase','onphase','spikes','spikes_smooth']
X_pca, rasters = load_data(root_dir, 
                           animal_id, 
                             session,
                             dim_type,
                             bin_type)

rasters = rasters.T

# 
#
c = calcium.Calcium()
c.root_dir = root_dir
c.animal_id = animal_id
c.session = session
c.load_suite2p()

imgs, imgs_all, imgs_bin, contours = load_footprints(c)


#
h=HMM()
#hmm_z = np.load('/home/cat/hmm_z_pca.npy')
hmm_z = np.load('/home/cat/hmm_z_rasters.npy')
h.get_hmm_stats(hmm_z)

#
idx = np.argsort(h.total_durations)[::-1]
state_ids = idx


# 
#state_id = state_ids[0]
state_id = 2
cell_sums_longitudinal, state_durations = h.computing_ensemble_loadings_per_occurance(state_id, 
                                                                                    rasters)

# 
state_duration = 2 # seconds to bin


#
fig=plt.figure(figsize=(12,12))
ctr=0
start_t = 0
duration = 0
for t in trange(state_durations.shape[0]):

    #print (t, " /", state_durations.shape[0], "duration: ", duration)

    duration += state_durations[t]
    if duration>state_duration:
        duration -= state_durations[t]
        ctr = ctr+1

        # 
        ax=plt.subplot(6,6,ctr)

        # compute assembly blueprint
        cell_sums = cell_sums_longitudinal[start_t:t-1].sum(0)
        
            
        # normalize to the spiking rate of each cell:
        cell_sums = cell_sums/spike_rates

        # 
        cell_sums_norm = cell_sums/np.max(cell_sums)

        ###################################################
        ###################################################
        ###################################################
        for k in range(len(c.stat)):
            contour = contours[k]

            ax.plot(contour[:,1],contour[:,0], 
                     c='black',
                    alpha=cell_sums_norm[k])

        plt.xlim(0,512)
        plt.ylim(512,0)
        plt.suptitle("State: " +str(state_id))
        plt.xticks([])
        plt.yticks([])

        #
        plt.title(str(round(duration,1))+"sec")

        # 
        duration = 0
        start_t = t
        
    
##########################################
if False:
    #plt.savefig('/home/cat/img.png',dpi=100)
    plt.savefig('/home/cat/ensembles_rasters.svg')
    plt.close()
else:
    plt.show()






100%|██████████| 352/352 [00:00<00:00, 18072.70it/s]
  0%|          | 0/352 [00:00<?, ?it/s]

state id:  2
cell sums longitudinal:  (352, 531)


100%|██████████| 352/352 [00:13<00:00, 26.89it/s]


In [40]:
import sklearn.metrics
times = []
for k in range(len(ensemble_times)):
    temp = ensemble_times[k]/30
    diff = temp[1:]-temp[:-1]
    
    mx = np.max(diff)
    #sklearn.metrics.pairwise_distances(temp)
    
    
    #
    times.append(mx)
  
times = np.hstack(times)
y = np.histogram(times, bins = np.arange(0,50000/30,1))

fig=plt.figure()
plt.plot(y[1][1:],y[0])
plt.semilogy()
plt.semilogx()
plt.ylim(bottom=1)
plt.xlabel("Maximum time between ensemble repeat (sec)")
plt.ylabel("# of ensembles")
plt.show()

In [41]:
##############################################################################
############ REORDER RASTERS BY CLUSTER - FROM GPU SAVED DATA ################
##############################################################################

root_dir = '/media/cat/4TB/donato/'
animal_id = 'DON-006084'
session = '20210519'
dim_type = 'pca'

# 
data = np.load(os.path.join(root_dir,
                     animal_id,
                     session,
                     'suite2p',
                     'plane0',
                     'res_dbscan_'+dim_type+'.npz'))



X_pca = data['X_pca']
X_clean = data['X_clean']
clusters = data['db_clean']
times = data['times']

print ("clusters: ", np.unique(clusters).shape, times.shape, times)


#########################################################
#########################################################
#########################################################
print (rasters.shape)
rasters_reordered = np.zeros(rasters.shape).T
print ("rasters reordered: ", rasters_reordered.shape)
ctr=0
for c in np.unique(clusters):
    idx = np.where(clusters==c)[0]
    times_original = times[idx]
    
    times_new = np.arange(ctr,ctr+idx.shape[0],1)

    #print (c, idx.shape, idx[:10])
    
    # 
    rasters_reordered[times_new]= rasters[:,times_original].T
    
    rasters_reordered[times_new[-4]:times_new[-1]]= 1
    #
    ctr+=idx.shape[0]
    
# 
rasters_reordered = rasters_reordered.T

#
idx = np.argsort(rasters_reordered.sum(axis=1))[::-1]

rasters_reordered = rasters_reordered[idx]

# 
fig=plt.figure()
plt.imshow(rasters_reordered,
           cmap='Greys',
           aspect='auto',
          interpolation=None)
plt.show()





clusters:  [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
 234 235 236 237 238 239 240 241 242 243

In [None]:
# ###############################################################
# ########## VISUALIZE DISTRIBUTION OF CLUSTERS SIZES ###########
# ###############################################################
# lens = []
# for k in np.unique(clusters):
#     idx = np.where(clusters==k)[0]
#     lens.append(idx.shape[0])

# y = np.histogram(lens, bins = np.arange(0,60000,10))
# plt.plot(y[1][1:]/30.,
#         y[0])
# plt.xlabel("Duration of cluster (sec)")
# plt.ylabel("# of ensembles/clusters")
# plt.xlim(0.1,y[1][-1]/30.)
# plt.ylim(bottom=0.9)
# plt.semilogy()
# plt.semilogx()

# plt.show()