In [None]:
import h5py as h5
import os
import json
import sys
import re
import importlib
from scipy.spatial.transform import Rotation
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
import numpy as np
from opm_img.util import load_from_h5
import wholebrain as wb
from matplotlib import pyplot as plt
from types import SimpleNamespace  
from scipy.optimize import minimize
%matplotlib widget
# Diverging colormap: blue -> black -> red.
bwrblack = LinearSegmentedColormap.from_list('map0', [(.25, .25, 1), (0, 0, 0), (1, .25, .25)], N=256)


def _with_alpha(cmap, alpha):
    """Return a ListedColormap copy of `cmap` with its alpha channel replaced.

    `alpha` is a scalar or an array with one entry per colour of `cmap`
    (i.e. length `cmap.N`).
    """
    rgba = cmap(np.arange(cmap.N))
    rgba[:, -1] = alpha
    return ListedColormap(rgba)


# Same diverging map, nearly transparent at the centre and opaque at both ends.
bwrblack_a = _with_alpha(bwrblack, np.abs(np.linspace(-1, 1, bwrblack.N)))

magma = plt.get_cmap('magma')
# Magma with alpha ramping linearly from 0 (low end) to 1 (high end).
magma_a = _with_alpha(magma, np.linspace(0, 1, magma.N))


In [None]:
import matplotlib

# Embed fonts as TrueType (Type 42) so text in exported PDF/PS stays editable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Millimetres per inch: figure sizes below are written as (width_mm / i2m, height_mm / i2m).
i2m = 25.4

# Colours for the three predictor families in the R$^2$-vs-#predictors figure.
c_voxel = "#3B7BBF"
c_random = '#F9A91C'
c_pca = '#ED1B50'

In [None]:
# TODO: absolute, user-specific path -- move to a configurable data directory.
out = '/home/hoffmmax/project/max/analysis/all_2/'
names = ['20211105_1332_no_stimulus', "20220610_1532_No_stimulus", "20220610_1428_No_stimulus"]

# Per recording: small results/parameters loaded eagerly, large results kept as an
# open h5py file (later cells slice it, e.g. res_large['pd'][:]).
all_dsets = {}
n_s = []  # presumably number of cells per recording (rows of 'cc') -- verify
for name in names:
    print(name)
    run_dir = os.path.join(out, name)
    # Intentionally not closed: later cells read from this handle.
    res_large = h5.File(os.path.join(run_dir, 'res_large.h5'), 'r')
    res = load_from_h5(os.path.join(run_dir, 'res.h5'))
    pars = load_from_h5(os.path.join(run_dir, 'pars.h5'))
    all_dsets[name] = dict(res=res, pars=pars, res_large=res_large)
    n_s.append(res_large['cc'].shape[0])



In [None]:
def binfcn(x):
    """Summarize the values falling into one distance bin.

    Returns (mean, 5th percentile, 95th percentile), all NaN-aware. The mean
    already ignored NaNs; the quantiles now do too, so a single NaN no longer
    turns the whole 5%-95% band for the bin into NaN.
    """
    return np.nanmean(x), np.nanquantile(x, 0.05), np.nanquantile(x, 0.95)
n_step = 100  # keep every n_step-th (distance, correlation) pair to bound cost
dist_bins = np.r_[10:1000:5]  # bin edges in µm: 10, 15, ..., 995

# One binned summary per recording, in all_dsets order. Each row presumably
# follows binfcn's return order (mean, 5%, 95%) -- verify against apply_to_bins.
binstats_list = []
for dset in all_dsets.values():
    large = dset['res_large']
    dists = large['pd'][:].ravel()[::n_step]
    corrs = large['cc'][:].ravel()[::n_step]
    binstats_list.append(wb.legacy.apply_to_bins(dists, dist_bins, corrs, binfcn))

In [None]:
# R^2 against 'alphas' for every recording (presumably a regularization sweep),
# printing the 'alpha' value stored for each recording.
fig, ax = plt.subplots()
for dset in all_dsets.values():
    sweep = dset['res']
    ax.plot(sweep['alphas'], sweep['r2s'])
    print(sweep['alpha'])
ax.set_xscale('log')

In [None]:
# Mean correlation vs. distance with its 5%-95% band, one curve per recording.
# The first bin is skipped.
fig, ax = plt.subplots(figsize=(60/i2m, 50/i2m))
x_plot = dist_bins[1:-1]
for binstats in binstats_list:
    mean_r = binstats[1:, 0]
    q05 = binstats[1:, 1]
    q95 = binstats[1:, 2]
    ax.plot(x_plot, mean_r, 'k', label='average')
    ax.fill_between(x_plot, q05, q95, label='5%-95%', alpha=0.2)
ax.set_xlabel('Distance (µm)')
ax.set_ylabel("Pearson's r")
plt.tight_layout()

In [None]:
# Fit r(d) = a * d**k to each recording's mean correlation (least squares in
# linear space), one panel per recording; the panel title is the fitted exponent k.
n_panels = len(binstats_list)
fig, axa = plt.subplots(1, n_panels, figsize=(220/i2m, 50/i2m), squeeze=False)
axa = axa[0]


def powlaw(x, a, k):
    """Power law a * x**k evaluated at the given x (the old lambda ignored x)."""
    return a * x**k


for ii, binstats in enumerate(binstats_list):
    ax = axa[ii]
    xx = dist_bins[:-1]
    yy = binstats[:, 0]

    # Empty bins yield NaN means; a single NaN would make the loss NaN and the
    # fit meaningless, so only finite bins enter the fit.
    finite = np.isfinite(yy)
    x_fit, y_fit = xx[finite], yy[finite]

    def mse(p, x_fit=x_fit, y_fit=y_fit):
        """Sum of squared residuals of the power law with p = (a, k)."""
        return np.sum((powlaw(x_fit, p[0], p[1]) - y_fit)**2)

    ax.plot(xx, yy, label='average', lw=0.5)
    if x_fit.size >= 2:
        # Start from amplitude = peak mean correlation, exponent = -1.
        fit = minimize(mse, [y_fit.max(), -1])
        ax.plot(xx, powlaw(xx, *fit.x), label='power-law fit', lw=0.5)
        ax.set_title(f"{fit.x[1]:.2}")
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Distance (µm)')
    ax.set_ylabel("Pearson's r")
plt.tight_layout()

In [None]:
# Mean correlation vs. distance on log-log axes, one curve per recording
# (first bin skipped).
fig, ax = plt.subplots(figsize=(80/i2m, 50/i2m))
for binstats in binstats_list:
    ax.plot(dist_bins[1:-1], binstats[1:, 0], label='average', lw=0.5)
ax.set_ylim(0.01, 0.4)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Distance (µm)')
ax.set_ylabel("Pearson's r")
plt.tight_layout()
# File name must differ from the linear-axis correlation figure's; both used
# 'correaltions.pdf', so whichever cell ran last silently overwrote the other.
plt.savefig(os.path.join(out, 'correlations_loglog.pdf'))

In [None]:
# Mean correlation vs. distance on log-log axes, labelled by recording name
# (binstats_list is built in the same order as `names`).
fig, ax = plt.subplots(figsize=(60/i2m, 50/i2m))
x_plot = dist_bins[1:-1]
for label, binstats in zip(names, binstats_list):
    ax.plot(x_plot, binstats[1:, 0], label=label)
ax.set_ylim(0.01, 1)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Distance (µm)')
ax.set_ylabel("Pearson's r")
plt.tight_layout()

In [None]:
# Mean correlation vs. distance on linear axes, one black curve per recording
# (first bin skipped).
fig, ax = plt.subplots(figsize=(60/i2m, 50/i2m))
for binstats in binstats_list:
    ax.plot(dist_bins[1:-1], binstats[1:, 0], 'k', label='average')
ax.set_ylim(0.01, 1)
ax.set_xlabel('Distance (µm)')
ax.set_ylabel("Pearson's r")
plt.tight_layout()
# File name must differ from the log-log correlation figure's; both used
# 'correaltions.pdf', so whichever cell ran last silently overwrote the other.
plt.savefig(os.path.join(out, 'correlations_linear.pdf'))

In [None]:
# For each recording: the PCA dimension with the highest mean R^2 and that R^2.
for name, dset in all_dsets.items():
    res, pars = dset['res'], dset['pars']
    mean_r2 = res["r2_bcvpca"].mean(0)
    best_idx = np.argmax(mean_r2)
    print("Estimated Dim:", pars["dims_list"][best_idx])
    print("Explained R2:", np.max(mean_r2))

In [None]:
## Single recording: R^2 vs. number of predictors for three predictor families.
name = '20211105_1332_no_stimulus'
dset = all_dsets[name]
res = dset['res']
pars = dset['pars']

## Plot all in one
fig, ax = plt.subplots(1, 1, figsize=(80/i2m, 46/i2m))
print(pars['stripe_period'])

# Voxel predictors, placed at the mean number of non-empty voxels.
nnz_mean = res['nnz'].mean(0)
ax.plot(nnz_mean, res["r2_voxelate"].mean(0), '.-', label='# nonzero voxels', c=c_voxel)
# Each 5%-95% band is given its line's colour explicitly. Without `color`,
# fill_between takes the next patch-cycle colour, which does not match the line.
ax.fill_between(nnz_mean, np.quantile(res['r2_voxelate'], 0.05, axis=0),
                np.quantile(res["r2_voxelate"], 0.95, axis=0), color=c_voxel, alpha=0.3)

# Cell predictors (r2_rand_pred), placed at the requested predictor counts.
n_pred = np.array(pars["n_pred_list"])
ax.plot(n_pred, res["r2_rand_pred"].mean(0), '.-', label='# cells', c=c_random)
ax.fill_between(n_pred, np.quantile(res["r2_rand_pred"], 0.05, axis=0),
                np.quantile(res["r2_rand_pred"], 0.95, axis=0), color=c_random, alpha=0.3)

# PCA dimensions as predictors.
pca_r2_mean = res["r2_bcvpca"].mean(0)
ax.plot(pars["dims_list"], pca_r2_mean, '.-', label='# PCA dimensions', c=c_pca)
ax.fill_between(pars["dims_list"], np.quantile(res["r2_bcvpca"], 0.05, axis=0),
                np.quantile(res["r2_bcvpca"], 0.95, axis=0), color=c_pca, alpha=0.3)

ax.set(xlabel='# predictors', ylabel='R$^2$', xscale='log')
ax.set_xticks(10**np.arange(5))
plt.tight_layout()
plt.savefig(os.path.join(out, 'PCA_vs_cells_voxels.pdf'))
# PCA dimension with the highest mean R^2, and that R^2.
print(pars["dims_list"][np.argmax(pca_r2_mean)])
print(np.max(pca_r2_mean))

In [None]:
# Mean R^2 of regressing cellular activity on voxelized activity vs. voxel side
# length, one curve per recording.
fig, ax = plt.subplots(figsize=(86/i2m, 50/i2m))
fig.suptitle('regression of cellular activity against voxelized activity')
for name, dset in all_dsets.items():
    res = dset['res']
    pars = dset['pars']
    print(pars['stripe_period'])
    mean_r2 = res["r2_voxelate"][:].squeeze().mean(0)
    ax.plot(pars["s_bins_dim"], mean_r2, '.-', label=f"{name}", color='k')
    # Peak mean R^2 for this recording. Printed inside the loop; a print after
    # the loop would only report the last recording.
    print(np.max(mean_r2))
ax.set_xlim(0, None)
ax.set(xlabel='voxel side length (µm)', ylabel='R$^2$')
plt.tight_layout()
plt.savefig(os.path.join(out, 'voxelsize_R2_all.pdf'))

In [None]:
# Mean R^2 vs. voxel side length for a single recording.
fig, ax = plt.subplots(figsize=(86/i2m, 50/i2m))
name = '20211105_1332_no_stimulus'
dset = all_dsets[name]
res, pars = dset['res'], dset['pars']
print(pars['stripe_period'])
mean_r2 = res["r2_voxelate"][:].squeeze().mean(0)
ax.plot(pars["s_bins_dim"], mean_r2, '.-', label=f"{name}", color='k')
ax.set_xlim(0, None)
ax.set(xlabel='voxel side length (µm)', ylabel='R$^2$')
plt.tight_layout()
plt.savefig(os.path.join(out, 'voxelsize_R2.pdf'))
# Peak mean R^2 over voxel sizes.
print(np.max(mean_r2))

In [None]:
# Per recording: R^2 vs. voxel side length for each fixed number of predictor
# voxels, with a cell-shuffled control. One figure per recording.
for name, dset in all_dsets.items():
    res = dset['res']
    pars = dset['pars']
    fig, ax = plt.subplots(figsize=(106/i2m, 106/i2m))
    # Public stand-in for the private `ax._get_lines.prop_cycler`, which was
    # removed in Matplotlib 3.8. Every line below passes an explicit colour, so
    # the axes' line cycle never advances and colour i is entry i of the cycle.
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for i, n_pred in enumerate(pars["n_pred_list"]):
        color = cycle_colors[i % len(cycle_colors)]
        ax.fill_between(pars["s_bins"], np.quantile(res["R2s"][i], 0.05, axis=0),
                        np.quantile(res["R2s"][i], 0.95, axis=0), color=color,
                        alpha=0.3, label='5%-95%' if i == 0 else None)
        ax.fill_between(pars["s_bins"], np.quantile(res["R2s_shuffled"][i], 0.05, axis=0),
                        np.quantile(res["R2s_shuffled"][i], 0.95, axis=0),
                        color=color, alpha=0.1)
        ax.plot(pars["s_bins"], res["R2s_shuffled"][i].mean(0), ':', color=color,
                label='cells shuffled' if i == 0 else None, alpha=0.5)
        ax.plot(pars["s_bins"], res["R2s"][i].mean(0), label=f'{n_pred} voxels', color=color)
    ax.set(ylim=(0, 0.5), xlim=(0, None), xlabel='voxel side length (µm)', ylabel='R$^2$')
    ax.legend()
    fig.savefig(os.path.join(out, f'{name}_constant_R2.pdf'))

    # Largest absolute deviation of the mean R^2 from its value in the first
    # voxel-size bin, across predictor counts. Assumes R2s is
    # (n_pred, repeats, n_sizes) -- TODO confirm.
    r0 = res["R2s"].mean(1)
    print(np.nanmax(np.abs(r0 - r0[:, 0][:, None])))


In [None]:
# Multiplies the last axis of `res.coords` by [1, 1, -1] (presumably flipping z).
# NOTE(review): hidden state. `res` is whatever the previous cell left behind.
# After the per-recording loop above, that is the LAST entry of all_dsets
# ("20220610_1428_No_stimulus"), not the recording plotted below
# ("20211105_1332_no_stimulus").
# NOTE(review): not idempotent. Re-running this cell flips the sign back.
# NOTE(review): other cells index results as res['coords']; confirm that
# load_from_h5 results also support attribute access.
res.coords=res.coords*[1,1,-1]

In [None]:
# Recording shown in the brain scatter plots that follow.
name = '20211105_1332_no_stimulus'

In [None]:
# Show which arrays are stored in this recording's results.
all_dsets[name]["res"].keys()

In [None]:
# Each cell's correlation with its neighborhood ('ipsi_cc'), drawn at its position.
res_sel = all_dsets[name]['res']
wb.visualize.scatter_brain(
    res_sel['coords'][:],
    c=res_sel['ipsi_cc'][:],
    vmin=0,
    vmax=0.4,
    cmap=magma_a,
    off=[-800, 1500],
)
plt.title('correlation with neighborhood')
plt.savefig(os.path.join(out, 'ipsi.pdf'))

In [None]:
# Each cell's contralateral neighborhood correlation ('contra_cc'), with the midline drawn.
res_sel = all_dsets[name]['res']
fig, ax = wb.visualize.scatter_brain(res_sel['coords'][:], c=res_sel['contra_cc'][:], vmin=0, vmax=0.5,
                                     cmap=magma_a, off=[-800, 1500], plot_midline=True)

plt.title('correlation with neighborhood contra')
# Output file name corrected from the typo 'conta.pdf'.
plt.savefig(os.path.join(out, 'contra.pdf'))

In [None]:
# Per-cell difference ipsi_cc - contra_cc.
res_sel = all_dsets[name]['res']
# Same entry point as the other brain scatter plots. This cell previously called
# wb.scatter_brain, which may not be exported at package level.
wb.visualize.scatter_brain(res_sel['coords'][:], c=res_sel['ipsi_cc'][:] - res_sel['contra_cc'][:],
                           vmin=0, vmax=0.1, cmap=magma_a, off=[-800, 1500])
# NOTE(review): vmin=0 maps negative differences (contra > ipsi) to the lowest colour.
# The title names the quantity shown; it previously duplicated the ipsi plot's title.
plt.title('correlation with neighborhood (ipsi - contra)')
plt.savefig(os.path.join(out, 'ipsi_contra.pdf'))