In [None]:
import os 
import sys
import scipy
import random
import h5py
from tqdm import tqdm
import numpy as np
import helpers as h
import spontHelpers as sh
import matplotlib as mpl
import utils as u
from glob import glob
from scipy.io import loadmat
from scipy import signal
from scipy.stats import skew, kurtosis
from sklearn.cluster import KMeans
from mpl_toolkits import mplot3d
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
%pylab inline

In [33]:
def compute_cvPCA(resp0, maxcols=np.inf):
    ss0 = u.shuff_cvPCA(resp0, nshuff=10, maxcols=np.inf)
    ss0 = ss0.mean(axis=0)
    ss0 = ss0 / ss0.sum()

    return ss0

def load_xcorr_lags(f0, f1, **kwargs):
    xcorr = np.load(f0, **kwargs)
    lags =np.load(f1, **kwargs)

    return xcorr, lags


def compute_cross_corr(sig0, sig1):
    xcorr = signal.correlate(sig0, sig1)
    lags = signal.correlation_lags(len(sig0), len(sig1))

    return xcorr, lags



def figure_sanity_checks(d0=None, dlags0=None, dxcorrs=None, idx0=None, nnrois_ids=None, nnrois_corr=None, iter=None):    
    mpl.rcParams['lines.linewidth'] = 0.5

    # headers 
    plt.figure(figsize=(20, 20))
    fig0, axs = plt.subplots(nrows=11, ncols=2, figsize=(15, 15), layout='tight', frameon=False)

    axs[0, 0].plot(d0[:, idx0][80:1880], label=f'ROI {idx0} : avg correlation with nbrs={0}')
    axs[0, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[0, 0].legend()
    
    
    axs[0, 1].plot(d0[:, idx0][80:1880], label=f'ROI {idx0}: avg correlation with nbrs={0}')
    axs[0, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[0, 1].legend()
    

    # Subplots for signals
    
    axs[1, 0].plot(d0[:, nnrois_ids[0]][80:1880], label=f' Neuron {nnrois_ids[0]} : xcorr \w ROI av={mean(nnrois_corr[0])}')
    axs[1, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[1, 0].legend()

    
    axs[2, 0].plot(d0[:, nnrois_ids[1]][80:1880], label= f"Neuron {nnrois_ids[1]} : xcorr \w ROI av={mean(nnrois_corr[1])}")
    axs[2, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[2, 0].legend()

    
    axs[3, 0].plot(d0[:, nnrois_ids[2]][80:1880], label=f"Neuron {nnrois_ids[2]} : xcorr \w ROI av={mean(nnrois_corr[2])}")
    axs[3, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[3, 0].legend()

    
    axs[4, 0].plot(d0[:, nnrois_ids[3]][80:1880], label=f"Neuron {nnrois_ids[3]} : xcorr \w ROI av={mean(nnrois_corr[3])}")
    axs[4, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[4, 0].legend()

    
    axs[5, 0].plot(d0[:, nnrois_ids[4]][80:1880], label=f"Neuron {nnrois_ids[4]} : xcorr \w ROI av={mean(nnrois_corr[4])}")
    axs[5, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[5, 0].legend()

    
    axs[6, 0].plot(d0[:, nnrois_ids[5]][80:1880], label=f"Neuron {nnrois_ids[5]}: xcorr \w ROI av={mean(nnrois_corr[5])}")
    axs[6, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[6, 0].legend()

   
    axs[7, 0].plot(d0[:, nnrois_ids[6]][80:1880], label=f"Neuron {nnrois_ids[6]} : xcorr \w ROI av={mean(nnrois_corr[6])}")
    axs[7, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[7, 0].legend()

    
    axs[8, 0].plot(d0[:, nnrois_ids[7]][80:1880], label=f"Neuron {nnrois_ids[7]} : xcorr \w ROI av={mean(nnrois_corr[7])}")
    axs[8, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[8, 0].legend()

    
    axs[9, 0].plot(d0[:, nnrois_ids[8]][80:1880], label=f"Neuron {nnrois_ids[8]} : xcorr \w ROI av={mean(nnrois_corr[8])}")
    axs[9, 0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[9, 0].legend()

    
    axs[10, 0].plot(d0[:, nnrois_ids[9]][80:1880], label=f"Neuron {nnrois_ids[9]} : xcorr \w ROI av={mean(nnrois_corr[9])}")
    axs[10, 0].set_xlabel('Signal')
    axs[10, 0].set_ylabel('AU')
    axs[10, 0].legend()


    # subplots for lags and xcorrs
    axs[1, 1].plot(dlags0[0], dxcorrs[0], label='X - Correlation with ROI')
    axs[1, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[1, 1].legend()

    
    axs[2, 1].plot(dlags0[1], dxcorrs[1], label='X - Correlation with ROI')
    axs[2, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[2, 1].legend()

    
    axs[3, 1].plot(dlags0[2], dxcorrs[2], label='X - Correlation with ROI')
    axs[3, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[3, 1].legend()

    axs[4, 1].plot(dlags0[3], dxcorrs[3], label='X - Correlation with ROI')
    axs[4, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[4, 1].legend()

    axs[5, 1].plot(dlags0[4], dxcorrs[4], label='X - Correlation with ROI')
    axs[5, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[5, 1].legend()


    axs[6, 1].plot(dlags0[5], dxcorrs[5], label='X - Correlation with ROI')
    axs[6, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[6, 1].legend()

    
    axs[7, 1].plot(dlags0[6], dxcorrs[6], label='X - Correlation with ROI')
    axs[7, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[7, 1].legend()


    axs[8, 1].plot(dlags0[7], dxcorrs[7], label='X - Correlation with ROI')
    axs[8, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[8, 1].legend()

    axs[9, 1].plot(dlags0[8], dxcorrs[8], label='X - Correlation with ROI')
    axs[9, 1].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axs[9, 1].legend()
    
    
    axs[10, 1].plot(dlags0[9], dxcorrs[9], label='X - Correlation with ROI')
    axs[10, 1].set_xlabel("Lags")
    axs[10, 1].set_ylabel("X - Correlation")
    axs[10, 1].legend()

    fig0.suptitle('Signal          SANITY CHECKS             Correlations and Lags');
    plt.savefig(f"sanity_check_cross_correlation_plots_{idx0}_{iter}.png");
    plt.close();

    return 0



In [None]:
dp0='/Users/duuta/ppp/data/zebf00/TimeSeries.h5'
dp1='/Users/duuta/ppp/data/zebf00/data_full.mat'

In [None]:
d1 = loadmat(dp1, simplify_cells=True)
dholder = h5py.File(dp0, 'r')
d0 = dholder['CellResp'][:]

In [None]:
x, y, z = d1['data']['CellXYZ'].T

In [None]:
d00 = sh.ssplit(d0)

In [None]:
d01 = d00[:, :, :10000]

In [None]:
ss0 =compute_cvPCA(d00)

In [None]:
d01.shape

In [None]:
ss01 = compute_cvPCA(d01)

In [None]:
a, _, _= u.get_powerlaw(ss0, np.arange(11, 500).astype('int'))

In [None]:
a1, _, _ = u.get_powerlaw(ss01, np.arange(11, 500).astype('int'))

In [34]:
nums = list(range(50))
ID= 1408
sig0 = d0[:, ID]
for i in range(5):
    _lags = {}
    _ccorr = {}
    for id in nums[i*10: i*10 +10]: 
        id0 = id % 10
        sig1 = d0[:, id]
        ccorr = signal.correlate(sig0, sig1)
        lags = signal.correlation_lags(len(sig0), len(sig1))
        _lags[id0] =lags
        _ccorr[id0] = ccorr
    nnrois_ids = list(_ccorr.keys())
    nnrois_corr = _ccorr
    figure_sanity_checks(d0=d0, dlags0=_lags, dxcorrs=_ccorr, idx0=ID, nnrois_corr=nnrois_corr, nnrois_ids=nnrois_ids, iter=i)


<Figure size 2000x2000 with 0 Axes>

<Figure size 2000x2000 with 0 Axes>

<Figure size 2000x2000 with 0 Axes>

<Figure size 2000x2000 with 0 Axes>

<Figure size 2000x2000 with 0 Axes>

In [50]:
_lags[3][np.argmax(_ccorr[3])]

-14

In [None]:
import numpy as np

def smooth_foobar(foobar, positions, nn_list):
    # define near
    smooth_version = 0*foobar 
    for idx, val in enumerate(foobar):
        x0 = mean(foobar[nn_list[idx]]) 
        smooth_version[idx] = x0

    return smooth_version

In [42]:
nn= np.array([1, 2, 3, 4])
rn= np.array([2, 2, 2, 2])


In [43]:
nn.sum()/rn.sum()

1.25

In [45]:
nn.mean()/rn.mean()

1.25

In [46]:
nn.sum()

10

In [47]:
nn.mean()

2.5

In [48]:
rn.mean()

2.0

In [None]:
import os
import random
import time
from glob import glob

import h5py
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import TwoSlopeNorm
from scipy.io import loadmat
from scipy.spatial.distance import pdist, squareform


def get_tag_paths(zebfolderpath):
    tag = zebfolderpath.split("/")[8]

    files = os.listdir(zebfolderpath)
    files.sort()
    files.remove("README_LICENSE.rtf")

    path0, path1 = [zebfolderpath + "/" + fname for fname in files]

    return tag, path0, path1


def read_data(path0, path1, num_rois, scells=True):
    dholder = h5py.File(path0, "r")
    d0 = dholder["CellResp"][:]  # responses
    d1 = loadmat(path1, simplify_cells=scells)

    eliminated_rois = d1["data"]["IX_inval_anat"]
    all_rois = d1["data"]["CellXYZ"]

    used_rois_coor = np.array(
        [row for j, row in enumerate(all_rois) if j not in list(eliminated_rois)]
    )

    x, y, z = used_rois_coor[:num_rois, :].T

    return x, y, z, d0


def find_nearest_nbrs(ds, roi_idx, n=10):
    nn_idx = ds[roi_idx,].argsort()[1 : n + 1]

    return nn_idx


def compute_distance_matrix(x, y):
    ds = squareform(pdist(np.array([x, y]).T, metric="euclidean"))

    return ds


def pick_random_nbrs(roi_idx, len0=100, n=10):
    all_idx = list(range(len0))
    all_idx.remove(roi_idx)
    rn_idx = random.sample(all_idx, n)

    return rn_idx


def reject_outliers(data, m=2):
    X = data[abs(data - np.mean(data)) < m * np.std(data)]

    return X


def compute_render_ratio_corr(
    x, y, d0, dS, num_rois=30000, nnpop=10, rnpop=10, seed=None, tag=None, sdir=None
):
    nnidx_dict = {}
    rnidx_dict = {}
    nncorr_dict = {}
    rncorr_dict = {}
    collect_nn_min_max = []
    collect_rn_min_max = []

    for roi_idx in range(
        num_rois
    ):  # need to account for cases where the pass a list of roi indices
        random.seed(seed)
        roi = d0[:, roi_idx]
        nn_idx = find_nearest_nbrs(dS, roi_idx, n=nnpop)
        nn_roi = d0[:, nn_idx]
        rn_idx = pick_random_nbrs(roi_idx, len0=num_rois, n=rnpop)
        rn_roi = d0[:, rn_idx]
        nrcorr = []
        rncorr = []

        for j in range(nn_roi.shape[1]):
            nn_corr = np.corrcoef(roi, nn_roi[:, j])[0, 1]
            rn_corr = np.corrcoef(roi, rn_roi[:, j])[0, 1]
            nrcorr.append(nn_corr)
            rncorr.append(rn_corr)
            collect_nn_min_max.append(nn_corr)
            collect_rn_min_max.append(rn_corr)

        nnidx_dict[roi_idx] = nn_idx
        rnidx_dict[roi_idx] = rn_idx
        nncorr_dict[roi_idx] = nrcorr  # groups of near correlations
        rncorr_dict[roi_idx] = rncorr  # groups of random correlations

    # srnr_arr = np.array(collect_nn_min_max) / np.array(collect_rn_min_max)
    # sPRN = round(np.percentile(srnr_arr, 90), 3)
    # filtered0 = [a for a in srnr_arr if a > 0 and a <= sPRN]

    # outliers for reasonable distribution
    # filtered1 = reject_outliers(srnr_arr)
    # mid = np.median(filtered0)
    # print('mid', mid)
    # vmin = min(filtered0)
    # print('vmin', vmin)
    # vmax = max(filtered0)
    # print('vmax', vmax)

    plt.figure(figsize=(20, 20))
    custom_norm = TwoSlopeNorm(vcenter=1, vmin=0.001, vmax=9)
    ax = plt.axes()

    for roi_idx in range(
        num_rois
    ):  # need to take care of cases where num_rois is of indexes
        plt.scatter(
            x[nnidx_dict[roi_idx]],
            y[nnidx_dict[roi_idx]],
            marker=".",
            norm=custom_norm,
            cmap="rainbow",
            s=0.5,
            c=[
                np.array(nncorr_dict[roi_idx]).sum()
                / np.array(rncorr_dict[roi_idx]).sum()
            ]
            * len(nnidx_dict[roi_idx]) 
        )

    plt.colorbar(shrink=0.5)
    plt.xlabel("ROI X Positions", fontsize=20)
    plt.ylabel("ROI Y Positions", fontsize=20)
    plt.margins(x=0, y=0)
    plt.title(
        f"{tag}:Raw correlation ratios of near ROIs:{nnpop} to random ROIs:{rnpop} seed:{seed}",
        fontsize=20,
    )
    ax.set_facecolor("black")
    plt.tight_layout()

    plt.savefig(
        f"{sdir}testing_Rawratiocorrelations_{tag}_ROIs:{num_rois}_NN:{nnpop}_seed:{seed}_RN:{rnpop}.png",
    )
    plt.close()
    # need to run this for different vmax, vmin, mid=1 (first make a function initerim a class later)


def main():
    num_rois = 2000
    seed = None
    nnpop = 10
    sdir = "/camp/home/duuta/working/duuta/jobs/plots/ratioCorr/"

    for rnpop in [10, 100, 1000]:
        for fpath in glob("/camp/home/duuta/working/duuta/ppp0/data/zebf*")[:1]:
            print(f"reading {fpath}..........")

            # get tag and paths
            tag, path0, path1 = get_tag_paths(fpath)

            # read file paths
            x, y, _, d0 = read_data(path0, path1, num_rois=num_rois, scells=True)
            print("can read the file path... yes.... frantically reading files.....")

            # compute distances between rois
            dS = compute_distance_matrix(x, y)

            print("franticall computing distances......")

            # compute correlation ratio and render plot
            compute_render_ratio_corr(
                x,
                y,
                d0=d0,
                dS=dS,
                num_rois=num_rois,
                nnpop=nnpop,
                rnpop=rnpop,
                seed=seed,
                tag=tag,
                sdir=sdir,
            )


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"{time.time() - start_time} ----all done")
