In [None]:
# Download the dataset from the Charité ownCloud share and store it as data.h5
# in the working directory (-O sets the output file name, -q silences wget).
!wget https://owncloud-ext.charite.de/owncloud/index.php/s/0fe625BKTbPy5V5/download -O data.h5 -q

In [3]:
import importlib 

In [1]:
import h5py as h5
import numpy as np
from tqdm.auto import tqdm
from wholebrain import util,cluster,spatial,regression
from functools import partialmethod
# Silence tqdm progress bars: wrap the constructor so that disable=True is
# supplied by default whenever a tqdm object is created.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Re-execute the wholebrain.util module so that edits to its source file are
# picked up without restarting the kernel.
importlib.reload(util)

<module 'wholebrain.util' from '/home/hoffmmax/repos/hoffmann_et_al_2023/src/wholebrain/util.py'>

In [5]:
def load_dset(name, path='./data.h5'):
    """Read one recording from the HDF5 data file.

    Parameters
    ----------
    name : str
        Key of the top-level group to read (one of the names returned by
        ``fh5.keys()``).
    path : str, optional
        Location of the HDF5 file. Defaults to ``'./data.h5'``, the file name
        used by the download cell.

    Returns
    -------
    tuple
        ``(coords, dff, traces)``: the group's ``'coords'``, ``'dff'`` and
        ``'traces'`` datasets. Each is read completely with ``[:]`` so the
        returned arrays remain usable after the file has been closed.
    """
    with h5.File(path, 'r') as fh5:
        group = fh5[name]  # look the group up once instead of three times
        return tuple(group[key][:] for key in ('coords', 'dff', 'traces'))

In [6]:
data = {}

# Open the file read-only just long enough to collect the top-level group names.
with h5.File('./data.h5', 'r') as fh5:
    names = list(fh5.keys())

print('Datasets:', names)

Datasets: ['20211027_1238_no_stimulus', '20211105_1332_no_stimulus', '20211209_1205_no_stimulus', '20220610_1532_No_stimulus', '20221124_1003_No_stimulus', '20221127_1643_no_stimulus']


In [7]:
# Repeat counts shared by several analysis sections below.
n_repeats = 1
alpha_repeats = 1

# Analysis configuration: one sub-dictionary per analysis step.
# NOTE(review): the "do" flags are not consulted by any cell visible in this
# notebook - presumably they are read by a script version of the pipeline.
pars = {
    # Settings shared across analyses ("alphas" is the grid used by the
    # ridge alpha scan; "target_fraction" sets n_targets further down).
    "global": {
        "stripe_period": 300,
        "val_fraction": 0.2,
        "test_fraction": 0.2,
        "target_fraction": 0.1,
        "alphas": np.geomspace(1, 10000, 20).tolist(),
        "alpha_repeats": alpha_repeats,
    },
    "distortion": {
        "do": True,
        "split": 2,
    },
    "correlation": {
        "do": True,
        "ringlims": [25, 50],
    },
    "reconstruct": {
        "do": True,
        "num": 500,
        "s_bins": [1, 100],
    },
    "pca_reg": {
        "do": True,
        "dims_list": np.geomspace(1, 1500, 200).tolist(),
        "n_repeats": n_repeats,
    },
    "random_predictors": {
        "do": True,
        "n_pred_list": [100, 200, 400, 1000, 2000, 5000, 10000, 20000],
        "n_repeats": n_repeats,
    },
    "voxelate": {
        "do": True,
        "s_bins_dim": [5, 10, 25, 50, 100, 150, 200, 500],
        "n_pred_list": [100, 200, 400, 1000, 2000, 5000, 10000, 20000],
        "n_repeats": n_repeats,
    },
    "r2_scan": {
        "do": True,
        "s_bins": [5, 10, 25, 50, 100, 150, 200, 500],
        "n_pred_list": [100, 200, 400, 1000, 2000, 5000, 10000, 20000],
        "alphas": np.geomspace(1, 10000, 20).tolist(),
        "n_repeats": n_repeats,
    },
    "snr": {
        "do": False,
        "snr_level": [20, 30, 50, 60, 70, 80],
        "alphas": np.geomspace(50, 1000, 20).tolist(),
        "n_repeats": n_repeats,
    },
    "random_projections": {
        "do": False,
        "alphas": np.geomspace(0.1, 1e9, 10).tolist(),
        "n_repeats": n_repeats,
    },
    "corr_fc": {
        "do": True,
        "min_dist": 400,
        "radius_list": [5, 10, 15, 25, 50, 100, 150, 200],
        "save_full_at": [5, 50, 100, 200],
        "dec_factor": 20,
    },
}

array([False, False, False, ...,  True,  True,  True])

In [8]:
# Work on the first dataset listed in the file.
coords, dff, traces = load_dset(names[0])

# Centre the data and drop NaNs (motion): subtract the NaN-aware mean taken
# along axis 0, then keep only the rows (axis 0) whose sum is not NaN.
col_mean = np.nanmean(dff, 0, keepdims=True)
keep_rows = ~np.isnan(dff.sum(1))
dff_mu = (dff - col_mean)[keep_rows]

cv_train, cv_test, cv_val = util.create_crossvalidation_mask(dff_mu)

# Split the columns of dff_mu into a target share and the remaining share.
target_fraction = pars["global"]["target_fraction"]
n_targets = int(dff_mu.shape[1] * target_fraction)
max_npred = (1 - target_fraction) * dff_mu.shape[1]

### Bi-crossvalidated PCA

In [9]:
# Run util.pca_run over every entry of dims_list, using the combined
# cv_test + cv_train mask.
pca_cfg = pars["pca_reg"]
r2_bcvpca = util.pca_run(
    dff_mu,
    cv_test + cv_train,
    pca_cfg["dims_list"],
    n_targets=n_targets,
    n_repeats=pca_cfg["n_repeats"],
)


#### Determination of regularization parameter for ridge regression

In [None]:
print("Alpha Scan for regression")

alphas = pars["global"]["alphas"]
s_bin = pars["voxelate"]["s_bins_dim"]

# The scan uses a single bin size: the middle entry of the bin-size list.
mid_bin_size = s_bin[len(s_bin) // 2]
scan_repeats = pars["global"]["alpha_repeats"]

# Mean r2 of the voxelised regression for each candidate alpha (cv_test mask).
r2_l = []
for alpha in alphas:
    nnz, r2_voxelate = util.voxelate_regression(
        [mid_bin_size],
        dff_mu,
        coords,
        cv_test,
        n_targets=n_targets,
        n_repeats=scan_repeats,
        alpha=alpha,
    )
    r2_l.append(r2_voxelate.mean())

# Update global alpha parameter with the candidate that gave the highest mean r2.
best_idx = np.argmax(r2_l)
pars["global"]["alpha"] = alphas[best_idx]

#### Voxelized regression with determined alpha parameter

In [56]:
# Voxelised regression over every bin size in s_bins_dim, using the alpha
# stored in pars["global"]["alpha"] and the combined cv_test + cv_train mask.
voxel_cfg = pars["voxelate"]
nnz, r2_voxelate = util.voxelate_regression(
    voxel_cfg["s_bins_dim"],
    dff_mu,
    coords,
    cv_test + cv_train,
    n_targets=n_targets,
    n_repeats=voxel_cfg["n_repeats"],
    alpha=pars["global"]["alpha"],
)
       


#### Random Predictors

In [58]:

# Ridge regression via util.ridge_random for every predictor count in
# n_pred_list, with the alpha stored in pars["global"]["alpha"].
rand_cfg = pars["random_predictors"]
r2_rand_pred = util.ridge_random(
    dff_mu,
    rand_cfg["n_pred_list"],
    cv_test + cv_train,
    n_repeats=rand_cfg["n_repeats"],
    alpha=pars["global"]["alpha"],
    n_targets=n_targets,
)


#### De-coupling of voxel size and number of predictors.
- Determine alpha for each combination of number of predictors and voxel size
- Determine R2 on validation set

In [None]:
scan_cfg = pars["r2_scan"]

# Step 1: scan the alpha grid for every (bin size, predictor count) pair.
R2s = util.voxelate_alpha_scan(
    scan_cfg["s_bins"],
    scan_cfg["n_pred_list"],
    coords,
    dff_mu,
    cv_test,
    cv_train,
    scan_cfg["alphas"],
    n_repeats=pars["global"]["alpha_repeats"],
    n_targets=n_targets,
)

# Average over "repeats", then take the "alphas" coordinate with the largest
# mean (R2s is presumably an xarray object - TODO confirm).
alphas_max = R2s.mean("repeats").idxmax("alphas")

# Step 2: re-run with the selected alphas on the combined cv_train + cv_test mask.
R2s_val, batch_id = util.voxelate_all(
    scan_cfg["s_bins"],
    scan_cfg["n_pred_list"],
    coords,
    dff_mu,
    cv_train + cv_test,
    alphas_max,
    n_repeats=scan_cfg["n_repeats"],
    n_targets=n_targets,
)

# Saving is disabled; enabling it requires `import os` and an output
# directory `p_out`, neither of which is defined in this notebook.
#R2s.to_netcdf(os.path.join(p_out, "R2s.netcdf"))
#batch_id.to_netcdf(os.path.join(p_out, "batch_id.netcdf"))
#R2s_val.to_netcdf(os.path.join(p_out, "R2s_val.netcdf"))



#### Reconstruct traces from limited number of predictors

In [None]:

from types import SimpleNamespace

# `res` and `dff_traces_m` were never defined in this notebook, so the cell
# failed with NameError on a fresh kernel. Use the notebook's own `coords` and
# `dff_mu`, and create `res` here as a plain attribute container so that
# `res.clix` / `res.traces_recon` keep their names.
res = SimpleNamespace(coords=coords)

# Randomly split the columns of dff_mu into target and predictor cells.
cix = np.arange(dff_mu.shape[1])
target_cells = np.random.choice(cix, size=n_targets, replace=False)
predictor_cells = np.setdiff1d(cix, target_cells)

# Random fraction of the bin size that is added to the coordinates before voxelating.
offset = np.random.rand()

res.clix = cluster.embed1D.gpu(dff_mu)

n = pars["reconstruct"]["num"]
res.traces_recon = []
for s in pars["reconstruct"]["s_bins"]:
    # Voxelate the predictor cells at bin size s; only the first return value is used.
    X = spatial.voxelate(coords[predictor_cells] + offset * s, dff_mu[:, predictor_cells], s)[0]

    # Pick n predictor columns. np.random.choice samples with replacement by
    # default, so a column may be selected more than once.
    indx = np.random.choice(X.shape[1], n)
    X_sub = X[:, indx]

    # Fit ridge weights from the selected predictors to all of dff_mu and
    # store the reconstruction X_sub @ W.
    W, r2 = regression.ridgeCV(X_sub, dff_mu, cv_train + cv_test, alpha=pars["global"]["alpha"])
    res.traces_recon.append(X_sub @ W)

#### Compute distortion measures

In [None]:
# The original cell referenced `wb`, `res.coords` and `dff_traces_m`, none of
# which are defined in this notebook. Import `legacy` from the same package as
# the other helpers and use the notebook's `coords` / `dff_mu` instead.
from wholebrain import legacy

scan_cfg = pars["r2_scan"]
sim_val = legacy.distortion(
    scan_cfg["s_bins"],
    scan_cfg["n_pred_list"],
    coords,
    dff_mu,
    n_repeats=scan_cfg["n_repeats"],
    # Number of columns of dff_mu divided (integer division) by the "split" setting.
    n_targets=dff_mu.shape[1] // pars["distortion"]["split"],
)

# Saving is disabled, as in the r2-scan cell: enabling it requires `import os`
# and an output directory `p_out`, neither of which is defined in this notebook.
#sim_val.to_netcdf(os.path.join(p_out, "similarity.netcdf"))