In [1]:
import matplotlib
#matplotlib.use('Agg')
%matplotlib tk
%autosave 180

%load_ext autoreload
%reload_ext autoreload
%autoreload 2

# 
import sys
sys.path.append("/home/cat/code/widefield/") # Adds higher directory to python modules path.

import numpy, scipy.optimize

import h5py
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from sklearn.cross_decomposition import CCA
import torch
from tqdm import tqdm
import pandas as pd
import scipy

import os
import time
import warnings
warnings.filterwarnings("ignore")

from scipy import signal
from scipy.fft import fftshift
from scipy.signal import savgol_filter
from numpy.polynomial import polynomial as P
from scipy.signal import argrelmax


from locanmf import LocaNMF, postprocess
from locanmf import analysis_fig4 




Autosaving every 180 seconds


In [2]:
######################################################
######################################################
######################################################

def load_raw_data(spatial_fname, temporal_fname):
    # GRAB AND RECONSTRUCT DATA
    spatial = np.load(spatial_fname) 
    temporal = np.load(temporal_fname)
    temporal = temporal.transpose(0,2,1)

    #
    print (spatial.shape)
    print (temporal.shape)

    #
    print ("reconstructing data: ")
    data = np.matmul( temporal, spatial)
    print (data.shape)

    #
    print ("getting mean of data: ")
    data_mean = data.mean(0)
    print ("data_mean: ", data_mean.shape)
    # compute variance in raw data- not used
    # var = np.var(data2d, axis=0)
    # print ("var: ", var.shape)

    ######################################
    ###### COMPUTE RAW ROI ACTIVITY ######
    ######################################
    data2D = data_mean.reshape(data_mean.shape[0], 128,128)
    print ("Data mean 2D: ", data2D.shape)

    # 
    means = []
    ctr=0
    for id_ in ordered_names:
        area_id = areas[id_]
        idx = np.where(atlas==area_id)
        print (ctr, "areaId: ", area_id, names[id_], idx[0].shape)
        mask = np.zeros((128,128),'float32') + np.nan
        mask[idx] = 1

        temp = data2D*mask
        roi = np.nanmean(temp, axis=1)
        roi = np.nanmean(roi, axis=1)
        means.append(roi)

        ctr+=1   

    #
    raw_means = np.array(means)
    print ("Raw data means: ", raw_means.shape)

    return raw_means



def load_locaNMF_data(fname_locaNMF):
    # order locaNMF components by plot color ORDER in Fig 4A
    ordered_names = np.array([15,0,14,1,   # retrosplenial areas
                          13,2,
                          12,3,
                          11,4,
                          10,5,
                          9,6,
                          8,7])[::-1]


    # load raw data
    try:
        d = np.load(fname_locaNMF)
    except:
        #print ("file missing", fname_locaNMF)
        return None, None, None, None, None
        
    locaNMF_temporal = d['temporal_trial']
    locaNMF_temporal_random = d['temporal_random']
    locaNMF_temporal = locaNMF_temporal[:,ordered_names]
    locaNMF_temporal_random = locaNMF_temporal_random[:,ordered_names]
    # print ("locanmf data: ", locaNMF_temporal.shape)

    # 
    areas = d['areas'][ordered_names]
    names = d['names'][ordered_names]
    #print ("original names: ", names.shape)

    #
    atlas = np.load('/home/cat/code/widefield/locanmf/atlas_fixed_pixel.npy')
    #print ("atlas: ",atlas.shape)
    # print (areas)
    # print (names)

    #print ("  # of ordered_names: ", ordered_names.shape)
    #print ("ORDERED NAMES: ", names[ordered_names])

    
    return atlas, areas, names, locaNMF_temporal, locaNMF_temporal_random
  
        
# #
def compute_variance_locaNMF(locaNMF_temporal):
    # 
    t = np.arange(locaNMF_temporal.shape[2])/30 - 30
    means = []
    var = []
    #for k in ordered_names:
    for k in range(locaNMF_temporal.shape[1]):
        temp = locaNMF_temporal[:,k].mean(0)
        means.append(temp)

        # 
        temp = np.var(locaNMF_temporal[:,k],axis=0)
        var.append(temp)

    #
    means = np.array(means)#[:,:900]
    var = np.array(var)#[:,:900]
    #print (means.shape, var.shape)

    return means, var


def plot_variance_locaNMF(var):
    colors = plt.cm.jet(np.linspace(0,1,var.shape[0]))

    scale1 = 1
    scale2 = .005
    linewidth = 2
    # scale3 = np.nan
    t = np.arange(var.shape[1])/30-30
    #print ("t: ", t)
    fig=plt.figure(figsize=(10,10))
    for k in range(var.shape[0]):
        ax=plt.subplot(121)
        plt.xlim(-15,0)
        plt.xticks([])
        plt.yticks([])

        # 
        temp = var[k]
        temp = temp*scale1+k*scale2

        # 
        #plt.plot(t,temp,'--', c=colors[k],
        plt.plot(t,temp, c=colors[k],
                linewidth=linewidth)

        plt.plot([-15,0], [k*scale2,k*scale2],'--',c=colors[k],alpha=.5)

    if False:
        plt.savefig('/home/cat/variance.png',dpi=300)
        plt.close()
    else:
        plt.show()
  

def load_variances(animal_id):
    session_name = 'all'
    root_dir = '/media/cat/4TBSSD/yuki/'

    # 
    loca = analysis_fig4.LocaNMFClass(root_dir, animal_id, session_name)

    #
    loca.get_sessions(session_name)
    print ("# of sessions: ", loca.sessions.shape)

    # 
    all_means = []
    saved_names = []
    all_means_random = []
    good_sessions = []
    all_vars = []
    all_vars_random = []
    all_n_trials = []
    all_trials = []
    all_random = []
    for session in tqdm(loca.sessions):

        # load data
        fname_locaNMF = os.path.join(root_dir, animal_id, 'tif_files',session,
                                     session + '_locanmf.npz')
        
        # 
        atlas, areas, names, locaNMF_temporal, random = load_locaNMF_data(fname_locaNMF)

        # 
        if atlas is not None:
            means, var = compute_variance_locaNMF(locaNMF_temporal)
            all_n_trials.append(locaNMF_temporal.shape[0])
            all_means.append(means)
            all_vars.append(var)
            # 
            means_r, var_r = compute_variance_locaNMF(random)
            all_means_random.append(means_r)
            all_vars_random.append(var_r) 

            saved_names = names
            good_sessions.append(session)
            
            #
            all_trials.append(locaNMF_temporal)
            all_random.append(random)

    #print ("# good sessions: ", len(good_sessions))
    
    return all_means, all_vars, saved_names, all_n_trials, all_trials, all_random

############################################################
def compute_and_plot_variance_peaks(var, 
                                   means,
                                   n_trials,
                                   saved_names,
                                   session_id,
                                   plotting=False,
                                   smoothing=False):
    
    # 
    colors = plt.cm.jet(np.linspace(0,1,var.shape[0]))

    scale1 = 1.5E-4
    scale2 = 1E-6
    
    linewidth = 2

    # 
    if plotting:
        fig = plt.figure(figsize=(10,10))
        ax1 = plt.subplot(3,1,1)
        ax2 = plt.subplot(3,1,2)
        plt.xticks(np.arange(16), saved_names, 
                   rotation='vertical',
                   fontsize=14)
        #
        ax3 = plt.subplot(3,1,3)
        
        ax1.set_xlim(-30,0)
        
    # 
    ratios = []
    times = []
    # loop over all ROIs
    for k in range(var.shape[0]):

        # 
        temp = var[k]
        
        if True:
            temp3 = temp                    
        else:
            temp2 = means[k]

            temp2 += 100

            temp3 = temp/temp2
            temp3 = temp3*scale1+k*scale2

            # 
        if smoothing:
            window_length = 31
            polyorder = 3
            temp3 = scipy.signal.savgol_filter(temp3, 
                                               window_length, polyorder)
                                       
        # 
        if plotting:
            t = np.arange(temp3.shape[0])/30-30

            ax1.plot(t,temp3*scale1+k*scale2, c=colors[k],
                     linewidth=linewidth)

            ax1.plot([-30,0], [k*scale2,k*scale2],'--',c='black',alpha=.5)
        
        # compute mean and plot it
        trace_control = np.mean(temp3[:temp3.shape[0]//2])
        if plotting:
            ax1.plot([-30,0], [trace_control*scale1+k*scale2,
                           trace_control*scale1+k*scale2],'--',c=colors[k], alpha=.5)

        
        #
        min_val = np.min(temp3[temp3.shape[0]//2:])
        ratio = min_val/trace_control
        
        # NOTE WE ONLY CONSIDER last 5 sec of data and exlcuding last 500ms
        #arg_min = np.argmin(temp3[temp3.shape[0]//2:-15])
        # print ('temp3 shape: ', temp3.shape)
        #arg_min = np.argmin(temp3[750:900-15])
        
        
        
        #arg_min = np.argmin(temp3[750:900-15])
        arg_min = np.argmin(temp3[450:900])
        arg_min = arg_min/30.-15

        
        
        # 
        if plotting:
            arg_min_int = int(arg_min*30+15)
            
            ax1.scatter(t[arg_min_int+temp3.shape[0]//2],
                        temp3[arg_min_int+temp3.shape[0]//2]*scale1+k*scale2, 
                        #color=colors[k],
                        color='black',
                        s=100)
        
        ratios.append(ratio)
        times.append(arg_min)

        if plotting:
            ax2.scatter(k, ratio,
                        s=300,
                       #color=colors[k]
                        color='black'
                       )
            
            ax2.set_ylim(0,1.0)
            ax2.set_xlim(-0.5,15.5)

            if k==0:
                for p in range(0,100,25):
                    ax2.plot([-0.5,16], [p/100.,p/100.],'--', c='black',alpha=.5)
                
            ax3.scatter(k, arg_min,
                        s=300,
                        #color=colors[k]
                        color='black',
                        )
            
            ax3.set_ylim(-5.0,0)
            ax3.set_xlim(-0.5,15.5)

            if k==0:
                for p in range(-5,1,1):
                    ax3.plot([-0.5,16],[p,p],'--', c='black',alpha=.5)
            #ax3.set_xticks([])
            #ax3.set_yticks([])
        
        # 
        if False:
            print (k, " area ", saved_names[k], 
               " ratio change: ", round(ratio,2),
               " time: ", round(arg_min,2))
        
        #return 
# 
    if plotting:
        plt.suptitle(animal_id+"  session: "+str(session_id) + 
             "    n trials: "+str(n_trials))

        if True:
            fig.savefig('/home/cat/variance.svg',dpi=300)
            #fig2.savefig('/home/cat/ratios.svg',dpi=300)
            #fig3.savefig('/home/cat/times.svg',dpi=300)
            plt.close('all')
        else:
            plt.show()

    return ratios, times

# 
#    
def plot_histograms_ratios(all_ratios, all_times, color,
                          min_ratio_threshold = 0.15):
    min_ratio_threshold = 0.15
    bin_width = .25
    bins = np.arange(-5,0.5,bin_width)
    good_vals = []

    plotting = False
    for k in range(all_ratios.shape[0]):
        idx = np.where(all_ratios[k]<min_ratio_threshold)[0]

        good_times = all_times[k][idx]
        good_vals.append(good_times)

        if plotting:
            y = np.histogram(good_times,bins = bins)

            plt.plot(y[1][1:],y[0],
                    c = colors[k])

    good_vals = np.hstack(good_vals)
    y = np.histogram(good_vals,bins = bins)

    plt.plot(y[1][1:],y[0],
            c = color)
    

    
def plot_variance_peaks_3_plots(all_vars, all_means, all_n_trials,
                                saved_names, plotting,
                                session_id = None,
                                smoothing = False):
    
    
    if session_id is None:
        session_ids = np.arange(len(all_vars))
        if plotting==True:
            print (" CAN't plot all sesssions, exiting")
            return
    else:
        session_ids = [session_id]

    
    all_ratios = []
    all_times = []
    for session_id in session_ids:
        # session_id = 3
        # print (session_id, "n_trials: ", all_n_trials[session_id])

        # 
        ratios, times = compute_and_plot_variance_peaks(all_vars[session_id], 
                                               all_means[session_id],
                                               all_n_trials[session_id],
                                               saved_names,
                                               session_id,
                                               plotting,
                                               smoothing)
        all_ratios.append(ratios)
        all_times.append(times)

        # if plotting input sessoin_id manually and use break
        # break

    all_ratios = np.array(all_ratios)
    all_times = np.array(all_times)
    print (all_times.shape)
    print ("DONE")
    
    return all_ratios, all_times


def plot_box_plots_times(all_ratios, all_times, 
                         min_ratio_threshold=0.15):
    
    
    codes = ['Retrosplenial', 'barrel', 'limb', 'visual','motor']
    #codes = ['limb, layer 1 - right', 'limb, layer 1 - left']
    clrs_local = ['black','blue','red','magenta', 'pink','brown']
    
    # 
    bin_width = .25
    bins = np.arange(-5,0.5,bin_width)

    # 
    edts = []
    for a in range(len(all_ratios)):
        good_vals = []
        for k in range(all_ratios[a].shape[0]):
            idx = np.where(all_ratios[a][k]<min_ratio_threshold)[0]

            if idx.shape[0]>0:
                good_times = all_times[a][k][idx]
                good_vals.append(good_times)

            # 
        good_vals = np.hstack(good_vals)

        edts.append(good_vals)
        
    print ("M2: ", edts[1])
    
    # 
    my_dict = dict(M1 = edts[0], 
                   M2 = edts[1],
                   M3 = edts[2],
                   M4 = edts[3],
                   M5 = edts[4],                      
                   M6 = edts[5]
                     )

    data = pd.DataFrame.from_dict(my_dict, orient='index')
    data = data.transpose()

    # 
    flierprops = dict(marker='o', 
                      #markerfacecolor='g', 
                      #markersize=10000,
                      linestyle='none', 
                      markeredgecolor='r')

    #
    data.boxplot(showfliers=False,
                flierprops=flierprops)


    for i,d in enumerate(data):
        y = data[d]
        x = np.random.normal(i+1, 0.04, len(y))
#         plt.plot(x, y, 
#                  mfc = clrs_local[i], 
#                  mec='k', 
#                  ms=7, 
#                  marker="o", 
#                  linestyle="None",
#                     )
        #
        colors = plt.cm.viridis(np.linspace(0,1,len(edts[i])))
        x = np.random.normal(i+1, 0.04, len(edts[i]))
        print (i,d, ' y shape: ', y.shape)
        plt.scatter(x, edts[i], 
                   #c=clrs_local[i],
                   c=colors,
                   edgecolor='black',
                   s=200,
                   #alpha=np.linspace(.2, 1.0, x.shape[0])
                   alpha=.5
                   )

        
        
        res = scipy.stats.normaltest(edts[i])
        print ("i: ", i, '  res: ', res)
        
#     pvals = [0.05,0.01,0.001,0.0001,0.00001]
#     patches = []
#     if len(codes)!=2:
#         for p in [0,2,4]:
#             res = stats.ks_2samp(edts_saved[p], edts_saved[3])
#             label_ = ''
#             for k in range(len(pvals)):
#                 if res[1]<pvals[k]:
#                     label_ = label_ + "*"
#                 else:
#                     break

#                 patches.append(mpatches.Patch(color='blue', label=label_))





def load_locaNMF_temporal(animal_id, session_name, root_dir,
                         session_lid):

    loca = analysis_fig4.LocaNMFClass(root_dir, animal_id, session_name)

    #
    loca.get_sessions(session_name)
    print ("sessions: ", loca.sessions.shape)
    print ("selected session: ", loca.sessions[session_id])

    session = loca.sessions[session_id]

    # load data
    fname_locaNMF = os.path.join(root_dir, animal_id, 'tif_files',session,
                                 session + '_locanmf.npz')


    atlas, areas, names, locaNMF_temporal, locaNMF_temporal_random = load_locaNMF_data(fname_locaNMF)

    return atlas, areas, names, locaNMF_temporal, locaNMF_temporal_random 



def fit_sin(tt, yy):
    
    '''Fit sin to the input time sequence, 
        and return fitting parameters 
        "amp", "omega", "phase", "offset", "freq", "period" and "fitfunc"
        
        
    '''
    
    # 
    tt = numpy.array(tt)
    yy = numpy.array(yy)
    ff = numpy.fft.fftfreq(len(tt), (tt[1]-tt[0]))   # assume uniform spacing
    Fyy = abs(numpy.fft.fft(yy))
    guess_freq = abs(ff[numpy.argmax(Fyy[1:])+1])   # excluding the zero frequency "peak", which is related to offset
    guess_amp = numpy.std(yy) * 2.**0.5
    guess_offset = numpy.mean(yy)
    guess = numpy.array([guess_amp, 2.*numpy.pi*guess_freq, 0., guess_offset])

    def sinfunc(t, A, w, p, c):  
        return A * numpy.sin(w*t + p) + c
    
    popt, pcov = scipy.optimize.curve_fit(sinfunc, tt, yy, p0=guess)
    A, w, p, c = popt
    f = w/(2.*numpy.pi)
    fitfunc = lambda t: A * numpy.sin(w*t + p) + c
    
    # 
    return {"amp": A, 
            "omega": w, 
            "phase": p, 
            "offset": c, 
            "freq": f, 
            "period": 1./f, 
            "fitfunc": fitfunc, 
            "maxcov": numpy.max(pcov), 
            "rawres": (guess,popt,pcov)
           }


def plot_phases(animal_id,
               session_name,
               root_dir,
               session_id,
               random_flag,
               areas_to_plot_phases,
               start,
               end,
               show_area_id,
               codes,
               clrs_local,
               plotting):
    
    (atlas, 
     areas, 
     names, 
     locaNMF_temporal, 
     locaNMF_temporal_random) = load_locaNMF_temporal(animal_id, 
                                                      session_name, 
                                                      root_dir,
                                                      session_id)
#     for ctr, name in enumerate(names):
#         print (ctr, name)
        
    if atlas is None:
        print ("session is empty ")
        return atlas
    #
    if plotting:
        fig=plt.figure(figsize=(15,15))
        colors = plt.cm.viridis(np.linspace(0,1,len(locaNMF_temporal)))

    # loop over areas 
    t0_phases=[]

    for ctr_area, area_sel in enumerate(areas_to_plot_phases):
        
        t0_phases.append([])

        #
        areas_selected = []
        for k in range(len(names)):
            if area_sel in names[k]:
                areas_selected.append(k)

        # 
        if random_flag:
            locaNMF_temporal = locaNMF_temporal_random

        #
        areas_selected = np.array(areas_selected)
        print ("areas_selected", areas_selected)
        locaNMF_temporal2 = locaNMF_temporal[:,areas_selected].squeeze()
        print ("locaNMF_temporal: ", locaNMF_temporal2.shape)

        # 
        missed_fit = 0
        t = np.arange(locaNMF_temporal2.shape[1])/30-30
        for k in range(locaNMF_temporal2.shape[0]):
            if ctr_area==show_area_id:
                if plotting:
                    ax=plt.subplot(3,1,1)
                    plt.plot(t,locaNMF_temporal2[k], 
                         c='black',
                         linewidth=3,
                         alpha=.1)

            # fit sinusoid to the single trial data
            tt = t[start:end]
            yy = locaNMF_temporal2[k][start:end]

            # 
            try:
                tt2 = np.arange(0, locaNMF_temporal2[k].shape[0],1)/30.-30
                res = fit_sin(tt, yy)
                curve = res["fitfunc"](tt2)
                #print ("CURVE: ", curve.shape)
                # if curve fit, extrapolate all the way to -30sec
                
                if plotting:
                    if ctr_area==show_area_id:
                        ax=plt.subplot(3,1,2)
                        plt.plot(tt2, curve,  
                             linewidth=2,
                             c=clrs_local[show_area_id],
                             alpha=.5
                            )

                # get peaks of curves
                t0_amp = curve[900]
                phase = (t0_amp-np.min(curve))/(np.max(curve)-np.min(curve))

                ##############################################################
                # detect phase of curve by comparing t0 and t0-1 timestep
                if (curve[899] - t0_amp)<0:
                    #pass
                    phase+=1

                phase *= 3.14159
                #
                t0_phases[ctr_area].append(phase)

            #
            except:
                missed_fit+=1
                pass


        ####################################
        if ctr_area==show_area_id:
            mean = np.mean(locaNMF_temporal2,axis=0)
            if plotting:
                ax=plt.subplot(3,1,1)
                plt.plot(t,mean, 
                         c='black',
                         linewidth=8,
                         alpha=1)

                # 
                plt.plot([-30,0],[0,0],'--',c='grey')
                plt.xlim(-15,0)
                plt.ylim(-0.15,0.15)


            # fit sinusoid to trial average
            try:
                tt = t[start:end]
                tt2 = np.arange(0,mean.shape[0],1)/30.-30
                yy = mean[start:end]

                res = fit_sin(tt, yy)
                #print( "Amplitude=%(amp)s, Angular freq.=%(omega)s, phase=%(phase)s, offset=%(offset)s, Max. Cov.=%(maxcov)s" % res )
                
                if plotting:
                    ax=plt.subplot(3,1,2)
                    #plt.plot(tt, yy, "-k", label="y", linewidth=2)
                    plt.plot(tt2, res["fitfunc"](tt2), linewidth=8,
                            c=clrs_local[show_area_id]
                            )

                    # 
                    plt.plot([-30,0],[0,0],'--',c='grey')
                    plt.xlim(-15,0)
                    plt.ylim(-0.15,0.15)
            except:
                pass
                

    ###############################################################
    ###############################################################
    ###############################################################

    my_dict = dict(restrosplenial = t0_phases[0],
                   barrel = t0_phases[1],
                   limb = t0_phases[2],
                   visual = t0_phases[3],
                   motor = t0_phases[4])

    data = pd.DataFrame.from_dict(my_dict, orient='index')
    data = data.transpose()

    # 
    flierprops = dict(marker='o', 
                      #markerfacecolor='g', 
                      #markersize=10000,
                      linestyle='none', 
                      markeredgecolor='r')

    #
    if plotting:
        ax = fig.add_subplot(313)
        data.boxplot(showfliers=False,
                flierprops=flierprops)

        # 
        for i,d in enumerate(data):
            colors = plt.cm.viridis(np.linspace(0,1,len(t0_phases[i])))
            y = data[d]
            x = np.random.normal(i+1, 0.04, len(y))
            if False:
                plt.plot(x, y, 
                     mfc = 'black', 
                     mec='k', 
                     ms=7, 
                     marker="o", 
                     linestyle="None",
                        )
            #
            else:
                x = np.random.normal(i+1, 0.04, len(t0_phases[i]))
                print (i,d, ' y shape: ', y.shape)
                plt.scatter(x, t0_phases[i], 
                           #c=clrs_local[i],
                           c=clrs_local[i],
                           edgecolor='black',
                           s=200,
                           #alpha=np.linspace(.2, 1.0, x.shape[0])
                           alpha=.2
                           )

        plt.ylim(0,2*np.pi)
        
    print ("# of non-fit trias: ", missed_fit)
    
    return t0_phases


def plot_polar_plots(t0_phases):
    
    fig2=plt.figure(figsize=(15,5))

    # 
    for k in range(len(t0_phases)):
        ax = fig2.add_subplot(1,5,k+1, projection='polar')

        N = 16

        #
        conversion= 1

        # 
        phases1 = np.array(t0_phases[k])*conversion
        y = np.histogram(phases1, 
                        bins = np.linspace(0,
                                           2*np.pi*conversion, N)
                        #bins = np.arange(0,360+45,45),
                        )
        #
        theta = y[0]#[:-1]
        theta = theta/np.max(theta)  # not necessary
        radii = y[1][1:]

        #
        width = (2*np.pi) / (N)
        ax.bar(radii, theta, 
                      width=width,
               color=clrs_local[k])
        
        # 
        ax.set_yticklabels([])
        plt.title(codes[k])
        ax.yaxis.grid(False)

    plt.suptitle(animal_id+ " "+str(session_id))

    if True:
        plt.savefig('/home/cat/polar.svg')
        plt.close()
    else:
        plt.show()

In [3]:
######################################################
###### LOAD LOCANMF DATA AND CORRECT ORDER MAPS ######
######################################################

# 
animal_id = 'IA3'
all_means, all_vars, saved_names, all_n_trials, all_trials, all_random = load_variances(animal_id)

  5%|▍         | 2/44 [00:00<00:02, 15.08it/s]

# of sessions:  (44,)


100%|██████████| 44/44 [00:03<00:00, 13.76it/s]


In [240]:
####################################################################
##### PLOT LONGITUDINAL MEANS AND VARIANCES FOR ALL SESSIONS #######
####################################################################

# def find_first_variance_decrease_point(data_in, s1, e1, std_factor, ctr):


#     if False:
#     #if ctr!=3:
#         data_in = savgol_filter(data_in, 31, 2)

#     # find std of up to 10 sec prior to pull
#     std = np.std(data_in[s1:e1], axis=0)

#     # find mean up to 10 sec prior to pull
#     mean2 = np.mean(data_in[s1:e1], axis=0)
#     #mean2 = np.mean(data_in[0:e1], axis=0)

#     # 
#     idx = np.where(data_in[700:900]<(mean2-std*std_factor[ctr]))[0]
#     #idx = np.where(np.abs(data_in-mean2)>=std*std_factor[ctr])[0]
    
#     # 
#     if idx.shape[0]==0:
#         return np.nan
#     else:
#         return idx[0]

def find_first_variance_decrease_point2(data_in, s1, e1, std_factor, ctr,
                                       animal_id,
                                       n_vals_below_thresh=30):
    
    # 
    if False:
    #if ctr!=3:
        data_in = savgol_filter(data_in, 31, 2)

    # find std of up to 10 sec prior to pull
    std = np.std(data_in[s1:e1], axis=0)

    # find mean up to 10 sec prior to pull
    mean2 = np.mean(data_in[0:e1], axis=0)

    # do rolling evalution to find location when next N vals are belw threhsold
    idx_out = np.nan
    #n_vals_below_thresh = 30
    window = [20*30,30*30 ]
    for k in range(window[0], 
                   window[1], 1):
        # ensure that several time points in a row are below
        temp = data_in[k:k+n_vals_below_thresh]
        #print ("TEMP: ", temp)
        #print ("mean2: ", mean2)
        #print ("std*std_factor[ctr]: ", std*std_factor[ctr])
        if animal_id !='IA2':
            if np.all(temp<=(mean2-std*std_factor[ctr])):
                idx_out = k 
                break
        else:
            if np.all(temp>=(mean2+std*std_factor[ctr])):
                idx_out = k 
                break

    #
    if idx_out>(900-30//2):
        idx_out=np.nan
        
    return idx_out
        
# 
animal_id = 'AQ2'
all_means, all_vars, saved_names, all_n_trials, all_trials, all_random = load_variances(animal_id)
        

# Fig 5B
# 
print (len(all_vars))
session_ids = np.arange(len(all_vars))

# 
roi_ids = [15,5,9,11,1]

# for k in range(len(saved_names)):
#     print (k, saved_names[k])
# print ("plotitng: ", saved_names[roi_ids])

colors = plt.cm.viridis(np.linspace(0,1,len(all_vars)))


###############################################
###############################################
###############################################
plot_vars = True
fig = plt.figure(figsize=(20,12))
std_factor = [2,2,2,4,2]
n_vals = [10,15,20,20,20,30]
first_decay = []

# plot first decay point
s1 = 400
e1 = 600
#
for ctr, roi_id in enumerate(roi_ids):
    
    first_decay.append([])
    
    # 
    ax = plt.subplot(2,3,ctr+1)
    all_ = []
    for ctr_sess, session_id in enumerate(session_ids):
        if plot_vars:
            temp = all_vars[session_id][roi_id].copy()
        else:
            temp = all_means[session_id][roi_id].copy()
        t = np.arange(temp.shape[0])/30.-30

        # 
        if False:
            temp = savgol_filter(temp, 31, 2)

        elif True:
            fs = 30
            fc = 5  # Cut-off frequency of the filter

            w = fc / (fs / 2) # Normalize the frequency
            b, a = signal.butter(5, w, 'low')
            temp = signal.filtfilt(b, a, temp)

        # 
        #temp[:400]=0
        #temp[1350:]=0
        plt.plot(t, temp, 
                 color=colors[ctr_sess], 
                 alpha=.5,
                 linewidth=3)

        idx = find_first_variance_decrease_point2(temp, s1, e1, std_factor, ctr,
                                                 animal_id,
                                                 n_vals[ctr])

        first_decay[ctr].append(idx)
        
        all_.append(temp)

    # 
    a = np.array(all_)
    
    if False:
        a = savgol_filter(a, 31, 2)
    
    a_mean = np.mean(a,axis=0)
    plt.plot(t, a_mean,c='red',
             linewidth=5,
             alpha=1)
    
    # 
    idx = find_first_variance_decrease_point2(a_mean, s1, e1,
                                             std_factor, ctr,
                                             animal_id, 
                                             n_vals[ctr])
    
    #
    if False:
        plt.plot([t[idx+600], t[idx+601]],
             [a_mean[idx+600],a_mean[idx+601]],
               c='black',
               linewidth=10,
           alpha=1)
    
    plt.xlim(-15,5)
    
    # 
    if plot_vars:
        plt.ylim(bottom=0)
        #plt.ylim(bottom=0)
        plt.plot([-30,30],[0,0],'--',c='grey')
        plt.plot([0,0],[-.2,.2],'--',c='grey')
    else:
        plt.ylim(-.075,.15)
        plt.plot([-30,30],[0,0],'--',c='grey')
        plt.plot([0,0],[-.2,.2],'--',c='grey')
    plt.title(saved_names[roi_id])
    
plt.suptitle(animal_id)
if True:
    plt.savefig('/home/cat/variance_analysis_var'+str(plot_vars)+'.svg')
    plt.close()
else:
    plt.show()

 14%|█▎        | 15/110 [00:00<00:00, 136.31it/s]

# of sessions:  (110,)


100%|██████████| 110/110 [00:07<00:00, 15.36it/s]


77


In [243]:
####################################################################
####### SHOW DISTRIBUTIONS OF EARLIEST VARIANCE DECAY TIMES ########
####################################################################
    
codes = ['Retrosplenial', 'barrel', 'limb', 'visual','motor']
#codes = ['limb, layer 1 - right', 'limb, layer 1 - left']
clrs_local = ['black','blue','red','magenta', 'pink','brown']
clrs_local = ['magenta','brown','pink','lightblue','darkblue', 'blue']

#
d1 = []
for k in range(5):
    #temp = (np.array(first_decay[k])+700)/30-30
    temp = (np.array(first_decay[k]))/30-30
    idx = np.where(np.isnan(temp)==False)[0]
    temp = temp[idx]
    
    print (k, "Times: ", temp)
    d1.append(temp)
    

fig = plt.figure(figsize=(10,10))
my_dict = dict(
               r = d1[0], 
               b = d1[1],
               s = d1[2],
               v = d1[3],
               m = d1[4]
           )

data = pd.DataFrame.from_dict(my_dict, orient='index')
data = data.transpose()

# 
flierprops = dict(marker='o', 
                  #markerfacecolor='g', 
                  #markersize=10000,
                  linestyle='none', 
                  markeredgecolor='r')

#
data.boxplot(showfliers=False,
            flierprops=flierprops)

# 
for i,d in enumerate(d1):
    y = d1[i]
    x = np.random.normal(i+1, 0.04, len(y))

    #
    #colors = plt.cm.viridis(np.linspace(0,1,len(data[i])))
    print (d1[i])
    x = np.random.normal(i+1, 0.04, len(d1[i]))
    #print (i,d, ' y shape: ', y.shape)
    plt.scatter(x, 
                d1[i], 
                c=clrs_local[i],
                #c=colors,
                edgecolor='black',
                s=200,
                #alpha=np.linspace(.2, 1.0, x.shape[0])
                alpha=.5
               )
    
plt.ylim(-15,0)
plt.plot([0,6],[-3,-3],'--')
plt.plot([0,6],[-5,-5],'--')
plt.xticks([])
#plt.yticks([])
if True:
    plt.savefig('/home/cat/earliest_variance_decay'+animal_id+'.svg')
    plt.close()
else:
    plt.show()

0 Times:  [-10.          -1.63333333  -4.56666667  -3.66666667  -4.16666667
  -3.46666667  -4.23333333  -9.63333333  -5.16666667  -3.96666667
  -3.73333333  -3.5         -3.66666667  -3.53333333  -3.73333333
  -3.66666667  -7.86666667  -3.86666667  -3.66666667  -4.2
  -4.2         -3.96666667  -3.7         -3.          -3.26666667
  -3.36666667  -3.5         -2.86666667  -3.16666667  -3.66666667
  -3.83333333  -3.63333333  -3.83333333]
1 Times:  [-10.          -4.23333333  -3.7         -4.5         -3.6
  -4.23333333  -3.73333333  -3.76666667  -4.5         -8.46666667
  -3.63333333  -4.1         -4.03333333  -4.1         -4.06666667
  -4.16666667  -4.06666667  -4.06666667  -3.7         -3.43333333
  -3.5         -3.6         -9.26666667  -4.43333333  -7.96666667]
2 Times:  [-10.          -4.06666667  -4.13333333  -4.5         -1.86666667
  -3.6         -3.53333333  -3.06666667  -3.76666667  -4.5
  -4.03333333  -4.43333333  -4.2         -4.33333333  -3.8
  -3.43333333  -4.2         -8.4

In [51]:
###########################################################
##### PLOT ALL TRIALS - PCA PLOTS - SAVE DATA FIRST #######
###########################################################

# Fig 6B
animal_ids = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']    
    
##########################
for animal_id in animal_ids:
    (all_means, 
     all_vars, 
     saved_names, 
     all_n_trials, 
     all_trials, 
     all_random) = load_variances(animal_id)

    session_ids = np.arange(len(all_vars))

    # colors
    roi_ids = [15,5,9,11,1]

    ##############################################################
    ##############################################################
    ##############################################################
    all_traces = []
    all_means = []
    all_vars = []
    
    # loop over sessions
    for ctr_sess, session_id in enumerate(session_ids):

        print ("all trials: ", all_trials[session_id].shape)
        
        #
        temp_means_roi = np.mean(all_trials[session_id],0)[roi_ids][:,300:1350]
        all_means.append(np.hstack(temp_means_roi))
         
        #
        temp_vars_roi = np.std(all_trials[session_id],0)[roi_ids][:,300:1350]
        all_vars.append(np.hstack(temp_vars_roi))
        
        # loop over trials
        for t in range(all_trials[session_id].shape[0]):
            
            # loop over rois
            traces_trial = []
            for roi_id in roi_ids:
                temp = all_trials[session_id][t][roi_id]
                
                traces_trial.append(temp[300:1350])
       
            all_traces.append(np.hstack(traces_trial))

    np.save('/home/cat/all_traces_stacked_'+animal_id+'.npy', all_traces)
    np.save('/home/cat/all_means_stacked_'+animal_id+'.npy', all_means)
    np.save('/home/cat/all_vars_stacked_'+animal_id+'.npy', all_vars)
        
    print (" # traces: ", len(all_traces))
    

  8%|▊         | 6/71 [00:00<00:01, 37.74it/s]

# of sessions:  (71,)


100%|██████████| 71/71 [00:01<00:00, 41.29it/s]


all trials:  (23, 16, 1801)
all trials:  (58, 16, 1801)
all trials:  (51, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (38, 16, 1801)
all trials:  (18, 16, 1801)
all trials:  (40, 16, 1801)
all trials:  (36, 16, 1801)
all trials:  (14, 16, 1801)
all trials:  (72, 16, 1801)
all trials:  (76, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (29, 16, 1801)
all trials:  (36, 16, 1801)
all trials:  (24, 16, 1801)
all trials:  (32, 16, 1801)
all trials:  (20, 16, 1801)
all trials:  (23, 16, 1801)
all trials:  (24, 16, 1801)
all trials:  (23, 16, 1801)
all trials:  (44, 16, 1801)
all trials:  (25, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (16, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (20, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (17, 16, 1801)
all trials:  (10, 16, 1801)
all trials:  (25, 16, 1801)
all trials:  (20, 16, 1801)
all trials:  (14, 16

  0%|          | 0/44 [00:00<?, ?it/s]

all trials:  (39, 16, 1801)
all trials:  (37, 16, 1801)
all trials:  (23, 16, 1801)
all trials:  (17, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (28, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (22, 16, 1801)
all trials:  (23, 16, 1801)
all trials:  (22, 16, 1801)
all trials:  (39, 16, 1801)
all trials:  (29, 16, 1801)
all trials:  (16, 16, 1801)
all trials:  (10, 16, 1801)
 # traces:  1313
# of sessions:  (44,)


100%|██████████| 44/44 [00:01<00:00, 23.69it/s]


all trials:  (37, 16, 1801)
all trials:  (88, 16, 1801)
all trials:  (67, 16, 1801)
all trials:  (19, 16, 1801)
all trials:  (48, 16, 1801)
all trials:  (24, 16, 1801)
all trials:  (45, 16, 1801)
all trials:  (67, 16, 1801)
all trials:  (42, 16, 1801)
all trials:  (31, 16, 1801)
all trials:  (59, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (63, 16, 1801)
all trials:  (32, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (79, 16, 1801)
all trials:  (87, 16, 1801)
all trials:  (63, 16, 1801)
all trials:  (98, 16, 1801)
all trials:  (82, 16, 1801)


  0%|          | 0/44 [00:00<?, ?it/s]

all trials:  (61, 16, 1801)
all trials:  (49, 16, 1801)
all trials:  (67, 16, 1801)
all trials:  (45, 16, 1801)
all trials:  (35, 16, 1801)
all trials:  (62, 16, 1801)
all trials:  (85, 16, 1801)
all trials:  (28, 16, 1801)
all trials:  (22, 16, 1801)
all trials:  (62, 16, 1801)
 # traces:  1593
# of sessions:  (44,)


100%|██████████| 44/44 [00:02<00:00, 16.12it/s]


all trials:  (28, 16, 1801)
all trials:  (65, 16, 1801)
all trials:  (91, 16, 1801)
all trials:  (53, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (13, 16, 1801)
all trials:  (82, 16, 1801)
all trials:  (48, 16, 1801)
all trials:  (66, 16, 1801)
all trials:  (49, 16, 1801)
all trials:  (38, 16, 1801)
all trials:  (30, 16, 1801)
all trials:  (24, 16, 1801)
all trials:  (11, 16, 1801)
all trials:  (65, 16, 1801)
all trials:  (46, 16, 1801)
all trials:  (23, 16, 1801)
all trials:  (33, 16, 1801)
all trials:  (51, 16, 1801)
all trials:  (38, 16, 1801)
all trials:  (34, 16, 1801)
all trials:  (40, 16, 1801)
all trials:  (65, 16, 1801)
all trials:  (55, 16, 1801)
all trials:  (22, 16, 1801)
all trials:  (34, 16, 1801)
all trials:  (47, 16, 1801)
all trials:  (56, 16, 1801)
all trials:  (80, 16, 1801)
all trials:  (106, 16, 1801)
all trials:  (99, 16, 1801)
all trials:  (108, 16, 1801)
all trials:  (130, 16, 1801)
all trials:  (88, 16, 1801)
all trials:  (48, 16, 1801)
all trials:  (95,

  0%|          | 0/44 [00:00<?, ?it/s]

all trials:  (67, 16, 1801)
all trials:  (104, 16, 1801)
 # traces:  2398
# of sessions:  (44,)


100%|██████████| 44/44 [00:01<00:00, 22.70it/s]


all trials:  (29, 16, 1801)
all trials:  (31, 16, 1801)
all trials:  (43, 16, 1801)
all trials:  (38, 16, 1801)
all trials:  (44, 16, 1801)
all trials:  (43, 16, 1801)
all trials:  (54, 16, 1801)
all trials:  (49, 16, 1801)
all trials:  (56, 16, 1801)
all trials:  (31, 16, 1801)
all trials:  (56, 16, 1801)
all trials:  (27, 16, 1801)
all trials:  (44, 16, 1801)
all trials:  (26, 16, 1801)
all trials:  (82, 16, 1801)
all trials:  (46, 16, 1801)
all trials:  (73, 16, 1801)
all trials:  (16, 16, 1801)
all trials:  (61, 16, 1801)
all trials:  (132, 16, 1801)
all trials:  (114, 16, 1801)
all trials:  (42, 16, 1801)
all trials:  (37, 16, 1801)


  0%|          | 0/44 [00:00<?, ?it/s]

all trials:  (18, 16, 1801)
all trials:  (10, 16, 1801)
all trials:  (44, 16, 1801)
all trials:  (32, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (42, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (19, 16, 1801)
all trials:  (20, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (19, 16, 1801)
all trials:  (41, 16, 1801)
all trials:  (47, 16, 1801)
all trials:  (16, 16, 1801)
all trials:  (27, 16, 1801)
all trials:  (37, 16, 1801)
all trials:  (26, 16, 1801)
 # traces:  1626
# of sessions:  (44,)


100%|██████████| 44/44 [00:02<00:00, 20.64it/s]


all trials:  (10, 16, 1801)
all trials:  (16, 16, 1801)
all trials:  (43, 16, 1801)
all trials:  (47, 16, 1801)
all trials:  (52, 16, 1801)
all trials:  (50, 16, 1801)
all trials:  (39, 16, 1801)
all trials:  (26, 16, 1801)
all trials:  (56, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (29, 16, 1801)
all trials:  (38, 16, 1801)
all trials:  (64, 16, 1801)
all trials:  (94, 16, 1801)
all trials:  (61, 16, 1801)
all trials:  (106, 16, 1801)
all trials:  (48, 16, 1801)
all trials:  (34, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (50, 16, 1801)
all trials:  (61, 16, 1801)
all trials:  (59, 16, 1801)
all trials:  (30, 16, 1801)
all trials:  (28, 16, 1801)
all trials:  (43, 16, 1801)
all trials:  (22, 16, 1801)
all trials:  (44, 16, 1801)
all trials:  (39, 16, 1801)
all trials:  (74, 16, 1801)
all trials:  (33, 16, 1801)
all trials:  (43, 16, 1801)
all trials:  (50, 16, 1801)
all trials:  (64, 16, 1801)
all trials:  (33, 16, 1801)
all trials:  (65, 16, 1801)
all trials:  (82, 1

 14%|█▎        | 15/110 [00:00<00:00, 136.24it/s]

 # traces:  1821
# of sessions:  (110,)


100%|██████████| 110/110 [00:07<00:00, 15.48it/s]


all trials:  (9, 16, 1801)
all trials:  (26, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (24, 16, 1801)
all trials:  (33, 16, 1801)
all trials:  (65, 16, 1801)
all trials:  (52, 16, 1801)
all trials:  (51, 16, 1801)
all trials:  (59, 16, 1801)
all trials:  (27, 16, 1801)
all trials:  (17, 16, 1801)
all trials:  (37, 16, 1801)
all trials:  (54, 16, 1801)
all trials:  (37, 16, 1801)
all trials:  (42, 16, 1801)
all trials:  (26, 16, 1801)
all trials:  (70, 16, 1801)
all trials:  (12, 16, 1801)
all trials:  (55, 16, 1801)
all trials:  (46, 16, 1801)
all trials:  (60, 16, 1801)
all trials:  (41, 16, 1801)
all trials:  (50, 16, 1801)
all trials:  (67, 16, 1801)
all trials:  (78, 16, 1801)
all trials:  (21, 16, 1801)
all trials:  (74, 16, 1801)
all trials:  (91, 16, 1801)
all trials:  (93, 16, 1801)
all trials:  (131, 16, 1801)
all trials:  (184, 16, 1801)
all trials:  (157, 16, 1801)
all trials:  (145, 16, 1801)
all trials:  (104, 16, 1801)
all trials:  (96, 16, 1801)
all trials:  (14

In [112]:
################################################
###### VARIANCE PCA PLOTS FOR ALL ANIMALS ######
################################################
from sklearn.decomposition import PCA
from scipy.spatial import ConvexHull
from scipy.spatial import cKDTree
from sklearn import preprocessing

def knn_triage(th, pca_wf):

    tree = cKDTree(pca_wf)
    dist, ind = tree.query(pca_wf, k=6)
    dist = np.sum(dist, 1)

    idx_keep1 = dist <= np.percentile(dist, th)
    return idx_keep1

X = []
clrs = []
for k in range(6):
    #data = np.load('/home/cat/all_traces_stacked_'+animal_ids[k]+'.npy', allow_pickle=True)
    data = np.load('/home/cat/all_vars_stacked_'+animal_ids[k]+'.npy', allow_pickle=True)
    print (data.shape)
    
    if True:
        for p in range(data.shape[0]):
            data[p] = (data[p]-np.min(data[p]))/(np.max(data[p])-np.min(data[p]))
    
    clrs.append(np.zeros(data.shape[0])+k)
    X.append(data)

X = np.vstack(X)
print (X.shape)

scaler = preprocessing.StandardScaler().fit(X)
X = scaler.transform(X)

# 
pca = PCA(n_components=2)
X_fit = pca.fit_transform(X)
print ("X fit: ", X_fit.shape)

#
print (pca.explained_variance_ratio_)
print(pca.singular_values_)

# 
#clrs = np.hstack(clrs)
colors= plt.cm.tab10(np.linspace(0,1,len(clrs)))
sums = 0
fig=plt.figure(figsize=(10,10))
for k in range(len(clrs)):
    
   # if True:

    # use knn triage to remove most outlier points
    triage_value = 0.01
    
    knn_triage_threshold = 100*(1-triage_value)

    # apply knn to all points
    points_in = X_fit[sums:sums+clrs[k].shape[0]]
    idx_keep = knn_triage(knn_triage_threshold, points_in)
    idx_keep = np.where(idx_keep==1)[0]
    points_out = points_in[idx_keep]
    
    #              
    plt.scatter(points_out[:,0],
                points_out[:,1],
                color=colors[k],
                alpha=np.arange(0,len(colors[k]),1)/(1.2*len(colors[k]))+0.2,
                edgecolor='black',
                s=200,
                label=animal_ids[k])
    
    
    # single frame convex hull    
    points = points_out
    hull = ConvexHull(points)

    # Get the indices of the hull points.
    hull_indices = hull.vertices

    # These are the actual points.
    hull_pts = points[hull_indices, :]

    #
    print ("hull pts: ", hull_pts.shape)
    for p in range(hull_pts.shape[0]-1):
        plt.plot([hull_pts[p,0], hull_pts[p+1,0]],
                 [hull_pts[p,1], hull_pts[p+1,1]],
                 color=colors[k],
                 linewidth=8
                )
        
    #
    plt.plot([hull_pts[-1,0], hull_pts[0,0]],
             [hull_pts[-1,1], hull_pts[0,1]],
                 color=colors[k],
                 linewidth=8
                )
    
    #    
    sums+=clrs[k].shape[0]
        
plt.legend() 

if True:
    plt.savefig('/home/cat/variance_pca.svg')
    plt.close()
else:
    plt.show()

(53, 5250)
(31, 5250)
(41, 5250)
(40, 5250)
(39, 5250)
(77, 5250)
(281, 5250)
X fit:  (281, 2)
[0.25841708 0.18719909]
[617.43808596 525.51446526]
hull pts:  (8, 2)
hull pts:  (7, 2)
hull pts:  (7, 2)
hull pts:  (7, 2)
hull pts:  (8, 2)
hull pts:  (10, 2)


In [43]:
print (pca.explained_variance_)


[0.28939488 0.27125434]


In [165]:
###############################
###### VARIANCE ANALYSIS ######
###############################

# FIG 6 H variance curves + times + ratios

plotting = True  # note you need to run plotting=True for a single sessoin not all!!!
session_id = -1

saved_names = [
'M1-R',
'M1-L',
'M2-R',
'M2-L',
'SomB-R',
'SomB-L',
'SomH-R',
'SomH-L',
'SomF-R',
'SomF-L',
'V1-R',
'V1-L',
'RL-R',
'RL-L',
'RD-R',
'RD-L'    
]

smoothing = False
all_ratios, all_times = plot_variance_peaks_3_plots(all_vars, all_means, all_n_trials,
                                                    saved_names, plotting,
                                                    session_id,
                                                    smoothing)

(1, 16)
DONE


In [172]:
#############################################
###### ALL ANIMAL LOWEST VARIANCE TIME ######
#############################################

animal_ids = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']    
#animal_ids = ['IA3']
session_id = None

# 
ratios = []
times= []
smoothing = True
for ctr,animal_id in enumerate(animal_ids):
    all_means, all_vars, saved_names, all_n_trials= load_variances(animal_id)

    plotting = False  # note you need to run plotting=True for a single sessoin not all!!!
    all_ratios, all_times = plot_variance_peaks_3_plots(all_vars, all_means, all_n_trials,
                                                        saved_names, 
                                                        plotting,
                                                        session_id,
                                                        smoothing)
    ratios.append(all_ratios)
    times.append(all_times)
    

fig = plt.figure(figsize=(10,10))
min_ratio_threshold = 0.25
plot_box_plots_times(ratios, times, 
                     min_ratio_threshold)
plt.ylim(-15,0)

if True:
    plt.savefig('/home/cat/min_var_times.svg')
    plt.close()
else:
    plt.show()



  4%|▍         | 3/71 [00:00<00:02, 27.62it/s]

# of sessions:  (71,)


100%|██████████| 71/71 [00:01<00:00, 40.33it/s]
  9%|▉         | 4/44 [00:00<00:01, 26.42it/s]

(53, 16)
DONE
# of sessions:  (44,)


100%|██████████| 44/44 [00:01<00:00, 22.41it/s]
  5%|▍         | 2/44 [00:00<00:02, 18.35it/s]

(31, 16)
DONE
# of sessions:  (44,)


100%|██████████| 44/44 [00:02<00:00, 15.47it/s]
  7%|▋         | 3/44 [00:00<00:01, 21.68it/s]

(41, 16)
DONE
# of sessions:  (44,)


100%|██████████| 44/44 [00:02<00:00, 21.28it/s]
  9%|▉         | 4/44 [00:00<00:01, 26.91it/s]

(40, 16)
DONE
# of sessions:  (44,)


100%|██████████| 44/44 [00:02<00:00, 19.33it/s]
 14%|█▎        | 15/110 [00:00<00:00, 130.33it/s]

(39, 16)
DONE
# of sessions:  (110,)


100%|██████████| 110/110 [00:07<00:00, 14.33it/s]


(77, 16)
DONE
M2:  [ -0.03333333  -0.03333333  -0.2         -0.23333333  -0.16666667
  -0.13333333  -0.13333333  -0.13333333 -11.86666667 -11.86666667
 -10.7        -10.7         -0.03333333  -0.03333333  -0.03333333
  -0.03333333  -0.03333333  -0.03333333  -0.03333333  -3.96666667
  -0.03333333  -0.03333333  -0.03333333  -0.03333333 -12.6
  -1.23333333  -1.26666667  -1.1         -1.06666667  -0.03333333
  -2.13333333  -0.03333333  -0.03333333  -1.96666667  -1.56666667
  -3.86666667  -0.2         -0.03333333  -0.03333333  -0.03333333
  -0.03333333  -0.03333333  -7.86666667  -7.83333333]
0 M1  y shape:  (258,)
i:  0   res:  NormaltestResult(statistic=65.04669956359484, pvalue=7.503927972521363e-15)
1 M2  y shape:  (258,)
i:  1   res:  NormaltestResult(statistic=21.733955008667895, pvalue=1.9077946866075673e-05)
2 M3  y shape:  (258,)
i:  2   res:  NormaltestResult(statistic=1.8937663104486255, pvalue=0.3879483156868674)
3 M4  y shape:  (258,)
i:  3   res:  NormaltestResult(statistic=238

In [37]:

data = np.load('/media/cat/4TBSSD/yuki/IA3/tif_files/IA3am_Mar15_30Hz/IA3am_Mar15_30Hz_locanmf.npz',
              allow_pickle=True)

A_reshape = data['A_reshape']
temp = data['temporal_trial']
names = data['areanames_area']

print (temp.shape)
mean = temp.mean(0)


fig=plt.figure()
for i in range(A_reshape.shape[2]):
    plt.subplot(4,4,i+1)
    plt.imshow(A_reshape[:,:,i])
    plt.xticks([])
    plt.yticks([])
    plt.title(names[i],fontsize=6)
plt.tight_layout(h_pad=0.5,w_pad=0.5)
plt.show()




(34, 16, 1801)


In [62]:
############################################


def get_sessions(main_dir,
                 animal_id,
                 session_id):
     # load ordered sessions from file
    sessions = np.load(os.path.join(main_dir,
                                         animal_id,
                                         'tif_files.npy'))
    # grab session names from saved .npy files
    data = []
    for k in range(len(sessions)):
        data.append(os.path.split(sessions[k])[1].replace('.tif',''))
    sessions = data

    #
    if session_id != 'all':
        final_session = []
        for k in range(len(sessions)):
            if session_id in sessions[k]:
                final_session = [sessions[k]]
                break
        sessions = final_session

    # fix binary string files issues; remove 'b and ' from file names
    for k in range(len(sessions)):
        sessions[k] = str(sessions[k]).replace("'b",'').replace("'","")
        if sessions[k][0]=='b':
            sessions[k] = sessions[k][1:]

    sessions = np.array(sessions)

    return sessions

main_dir = '/media/cat/4TBSSD/yuki/'
session_id = 'all'


    
animal_ids = ['AQ2']
fig =plt.figure(figsize=(10,10))
ax=plt.subplot()
xmax = 21
for animal_id in animal_ids: 
    sessions = get_sessions(main_dir,
                             animal_id,
                             session_id)
    # 
    colors = plt.cm.plasma(np.linspace(0,1,len(sessions)))
    for ctr,session in enumerate(sessions):
        #print ("session: ", session)

        # 
        fname_var = os.path.join(main_dir,animal_id,'tif_files',session,
                                             session+ '_whole_stack_trial_ROItimeCourses_15sec_pca_var_explained.npy')

        if os.path.exists(fname_var)==True:
        
            var = np.load(fname_var)

            var = np.cumsum(var)
            x=np.arange(var.shape[0])+1
            plt.plot(x[:xmax], 
                     var[:xmax],
                    color=colors[ctr],
                    linewidth=2,
                    alpha=.9)
            
            
            
        
#
plt.plot([0,xmax],[0.95,0.95],
        '--',c='black')
#plt.semilogy()
plt.ylim(.6,1)
plt.xlim(0,xmax-1)

cmap = plt.cm.plasma  # define the colormap

cax = fig.add_axes([0.95, 0.2, 0.02, 0.6])
cb = matplotlib.colorbar.ColorbarBase(cax, cmap=cmap, spacing='proportional')
cb.set_label('gdp_md_est')




if True:
    plt.savefig('/home/cat/variance.svg', dpi=1200)
    plt.close()
else:
    plt.show()
        
    
    
    

In [28]:
d = np.load('/media/cat/4TBSSD/yuki/IJ2/tif_files/IJ2pm_Mar2_30Hz/IJ2pm_Mar2_30Hz_code_04_trial_ROItimeCourses_30sec_pca_0.95.npy')
print (d.shape)

(106, 6, 1801)


In [38]:
data = np.load('/media/cat/4TBSSD/yuki/IJ1/tif_files/IJ1pm_Mar2_30Hz/IJ1pm_Mar2_30Hz_whole_stack_trial_ROItimeCourses_15sec_pca30components_spatial.npy')
print (data.shape)

mask = np.int32(np.loadtxt('/media/cat/4TBSSD/yuki/IJ1/genericmask.txt'))

mask1 = np.ones((128,128),'float32')

for k in range(mask.shape[0]):
    mask1[mask[k][0],
          mask[k][1]] = np.nan
    
print (mask1.shape)

(30, 16384)
(128, 128)


In [41]:
for k in range(6):
    ax=plt.subplot(2,3,k+1)
    temp = data[k].reshape(128,128)*mask1
    plt.imshow(temp)
    
plt.savefig('/home/cat/pcas.svg',dpi=1200)
plt.close()