In [None]:
from wfield import SVDStack
from wfield import *
from wfield.local_nmf import compute_locaNMF
%matplotlib widget


# Folder with the wfield results for this session.
# NOTE(review): absolute, machine-specific path — edit it for your own data.
localdisk = '/home/data/JC111/20230520_164209/wfield/'  # the results folder

U = np.load(pjoin(localdisk, 'U.npy'))
SVT = np.load(pjoin(localdisk, 'SVTcorr.npy'))
mask = np.load(pjoin(localdisk, 'mask.npy'))

# glob() returns matches in arbitrary order: sort so the same landmarks file is
# chosen on every run, and fail with a clear message when none exists
# (indexing an empty list would only raise a bare IndexError).
lmarksfile = sorted(glob(pjoin(localdisk, '*landmarks*.json')))
if not lmarksfile:
    raise FileNotFoundError(f'No *landmarks*.json file found in {localdisk}')
lmarks = load_allen_landmarks(lmarksfile[0])


def get_U_atlas(U, M, fill_value=1e-10):
    """Warp every spatial component of ``U`` with ``im_apply_transform``.

    Parameters
    ----------
    U : ndarray
        3-D stack with components on the last axis; each ``U[:, :, k]`` slice
        is warped independently. The caller's array is not modified.
    M : object
        Transform forwarded unchanged to ``im_apply_transform`` as ``M``.
    fill_value : float, optional
        Value written onto the one-pixel border of every component before
        warping (default ``1e-10``, the previously hard-coded value).

    Returns
    -------
    ndarray
        float32 stack of warped components, components on the last axis.
    """
    U = U.copy()  # never overwrite the caller's array
    # Overwrite the outermost rows/columns of every component.
    # NOTE(review): presumably a tiny non-zero value so edge pixels are not
    # exactly zero; confirm the reason for 1e-10.
    U[0, :, :] = fill_value
    U[-1, :, :] = fill_value
    U[:, 0, :] = fill_value
    U[:, -1, :] = fill_value

    # Not a plain transpose: move components to axis 0, warp each 2-D slice via
    # runpar, then move the component axis back to the end.
    warped = runpar(im_apply_transform, U.transpose([2, 0, 1]), M=M)
    return np.stack(warped).transpose([1, 2, 0]).astype(np.float32)
    
# Warp the spatial components into the atlas frame.
Uatlas = get_U_atlas(U, M=lmarks['transform'])

# Load the CCF atlas from the landmarks file (do_transform=False: no extra warp here).
atlas, areanames, brain_mask = atlas_from_landmarks_file(lmarksfile[0], do_transform=False)

# Warp the recording mask with the same landmark transform, then keep only
# pixels that are inside it AND carry a non-zero atlas label.
mask = (im_apply_transform(mask.astype('int8'), M=lmarks['transform']) > 0) & (atlas != 0)

# Zero out atlas labels outside the combined mask so masked areas are discarded.
atlas[~mask] = 0


In [None]:
def im_apply_affine(im, transform, inverse_map=False):
    """Warp an image with the 2x3 affine part of ``transform`` using OpenCV.

    Parameters
    ----------
    im : ndarray
        Image whose first two axes are (rows, columns); the output keeps the
        same number of rows and columns.
    transform : object
        Must expose a ``params`` matrix; its first two rows are used as the
        affine matrix (presumably a skimage geometric transform — confirm).
    inverse_map : bool, optional
        False (default): the matrix maps source -> destination pixels.
        True: the matrix maps destination -> source (``cv2.WARP_INVERSE_MAP``).

    Returns
    -------
    ndarray
        Warped image (bilinear interpolation, OpenCV's default constant border).
    """
    n_rows, n_cols = im.shape[:2]
    M = transform.params[:2, :]
    # Flags must be passed by keyword: the 4th positional argument of
    # cv2.warpAffine is `dst`, so a positional WARP_INVERSE_MAP is never
    # applied as a flag. The default reproduces that effective behaviour.
    flags = cv2.INTER_LINEAR
    if inverse_map:
        flags |= cv2.WARP_INVERSE_MAP
    # OpenCV's dsize is (width, height) == (columns, rows).
    return cv2.warpAffine(im, M, (n_cols, n_rows), flags=flags)
def get_U_atlas(U, M, fill_value=1e-10):
    """Warp every spatial component of ``U`` with ``im_apply_affine`` (OpenCV).

    This redefines the ``get_U_atlas`` from the first cell (which warped with
    ``im_apply_transform``); only the per-slice warp function differs.

    Parameters
    ----------
    U : ndarray
        3-D stack with components on the last axis; each ``U[:, :, k]`` slice
        is warped independently. The caller's array is not modified.
    M : object
        Transform forwarded to ``im_apply_affine`` as ``transform``; it must
        expose a ``params`` matrix.
    fill_value : float, optional
        Written onto the one-pixel border of every component before warping
        (default ``1e-10``, the previously hard-coded value).

    Returns
    -------
    ndarray
        float32 stack of warped components, components on the last axis.
    """
    U = U.copy()  # never overwrite the caller's array
    # NOTE(review): presumably a tiny non-zero border so edge pixels are not
    # exactly zero; confirm the reason for this value.
    for edge in (np.s_[0, :, :], np.s_[-1, :, :], np.s_[:, 0, :], np.s_[:, -1, :]):
        U[edge] = fill_value

    # Component axis first -> warp each 2-D slice via runpar -> component axis last.
    slices = U.transpose([2, 0, 1])
    warped = runpar(im_apply_affine, slices, transform=M)
    return np.stack(warped).transpose([1, 2, 0]).astype(np.float32)
    
# Re-warp U into the atlas frame with the OpenCV-based helper defined above.
# NOTE(review): cv2.warpAffine treats a matrix given without WARP_INVERSE_MAP as
# source -> destination, which is presumably why the *inverse* landmark
# transform is passed here; confirm against the im_apply_transform result.
Uatlas = get_U_atlas(
    U,
    M=lmarks['transform_inverse'],
)

In [None]:
# Preview the atlas-registered SVD stack built from Uatlas and SVT.
plt.figure()
nb_play_movie(
    SVDStack(Uatlas, SVT),
    clim=[-0.1, 0.1],
)

In [None]:
# Semi-NMF: a single seed covering the whole mask (the dorsal cortex);
# mask.astype('int8') turns the boolean mask into one label (1) on a 0 background.
Asemi, Csemi, regions_semi = compute_locaNMF(
    Uatlas,
    SVT,
    mask.astype('int8'),  # seed
    mask,
    minrank=1,
    maxrank=200,
    min_pixels=100,
    loc_thresh=1,
    r2_thresh=0.99,
)

# LocaNMF: seed regions are the CCF atlas areas.
A, C, regions = compute_locaNMF(
    Uatlas,
    SVT,
    atlas,  # seed
    mask,
    minrank=1,
    maxrank=20,
    min_pixels=100,
    loc_thresh=60,
    r2_thresh=0.99,
)


In [None]:
# Play the locaNMF reconstruction: A (spatial maps) combined with C (per-component traces).
plt.figure()
nmf = SVDStack(A, C)
nb_play_movie(nmf, cmap='inferno', clim=[-0.06, 0.06])

In [None]:
# Release unused GPU memory held by PyTorch's caching allocator.
# empty_cache() can only return blocks that no live tensor occupies, so first
# collect unreachable Python objects (e.g. reference cycles still holding
# tensors). Tensors that are still referenced are NOT freed by this cell.
import gc
import torch

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Browse the spatial components, one movie frame per component.
# TODO: add a metric to discard components with a small spatial footprint.
plt.figure()
nb_play_movie(A.transpose(2, 0, 1), cmap='hot', clim=[0, 1])

In [None]:
# Correlation of the first CCA component between the traces of every pair of
# regions (exploratory example).
from sklearn.cross_decomposition import CCA

areas = regions
areainds = np.unique(regions)

# Drop time points (columns of C) where no component is finite.
# NOTE(review): columns that are only partially non-finite are kept and would
# still feed NaNs into CCA; confirm that cannot happen for this data.
keepinds = np.nonzero(np.sum(np.isfinite(C), axis=0))[0]
C = C[:, keepinds]

# Regions with no component cannot be correlated. With areainds taken from
# np.unique(regions) this list is always empty, but it matters if areainds is
# switched to a larger label set (e.g. every atlas area).
skipinds = [j for j, area_j in enumerate(areainds) if area_j not in areas]

corrmat = np.zeros((len(areainds), len(areainds)))
for i, area_i in enumerate(areainds):
    if i in skipinds:
        continue
    C_i = C[np.where(areas == area_i)[0], :].T  # (time, components of area_i)
    for j, area_j in enumerate(areainds):
        if j in skipinds:
            continue
        C_j = C[np.where(areas == area_j)[0], :].T
        cca = CCA(n_components=1)
        cca.fit(C_i, C_j)
        C_i_cca, C_j_cca = cca.transform(C_i, C_j)
        # With a single component, ravel yields the 1-D score series whether
        # transform returned (n, 1) or (n,) — no bare try/except needed.
        corrmat[i, j] = np.corrcoef(np.ravel(C_i_cca), np.ravel(C_j_cca))[0, 1]

corrmat = np.delete(corrmat, skipinds, axis=0)
corrmat = np.delete(corrmat, skipinds, axis=1)
corr_areanames = np.delete(areanames, skipinds)
# NOTE(review): areanames comes from the atlas loader, not from areainds, so it
# need not line up with the rows of corrmat. Ticks therefore show region ids.
# TODO: map region ids to area names.
corr_labels = [str(a) for a in np.delete(areainds, skipinds)]

print('plotting correlations', flush=True)
fig = plt.figure(figsize=(3, 3))
# Pass the colormap by name: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
plt.imshow(corrmat, cmap='jet')
plt.colorbar(shrink=0.8)
plt.xticks(ticks=np.arange(len(corr_labels)), labels=corr_labels, rotation=90, fontsize=6)
plt.yticks(ticks=np.arange(len(corr_labels)), labels=corr_labels, fontsize=6)
plt.title('CCA between all regions', fontsize=12)
# imshow draws row index i on the y-axis and column index j on the x-axis.
plt.xlabel('Region j', fontsize=10)
plt.ylabel('Region i', fontsize=10);


In [None]:
# Colour-code the component map with im_argmax_hsv (presumably the strongest
# component per pixel — confirm); non-finite loadings are treated as 0.
tt = A.transpose(2, 0, 1).copy()
tt[np.logical_not(np.isfinite(tt))] = 0

T = im_argmax_hsv(tt)
plt.figure()
plt.imshow(T)

In [None]:
# Play the registered activity with the atlas labels drawn semi-transparently
# on top (kept for reference; not very informative).
plt.figure()
stack = SVDStack(Uatlas, SVT)
nb_play_movie(stack, clim=[-0.1, 0.1])
plt.imshow(atlas, alpha=0.4)