# Dec 30th, 2022 (AD: full brain)

**Motivation**: Examine SVINET results on the full-brain AD data. Which level of analysis works better: run level or session level? <br>

In [1]:
# HIDE CODE


import os
import sys
from pprint import pprint
from copy import deepcopy as dc
from os.path import join as pjoin
from scipy.ndimage import gaussian_filter
from IPython.display import display, IFrame, HTML

# tmp & extras dir
# expanduser('~') returns $HOME on POSIX when it is set (same as before) and
# still resolves a home directory when $HOME is unset (e.g. on Windows), where
# os.environ['HOME'] would raise KeyError.
git_dir = pjoin(os.path.expanduser('~'), 'Dropbox/git')
extras_dir = pjoin(git_dir, 'jb-Ca-fMRI/_extras')
fig_base_dir = pjoin(git_dir, 'jb-Ca-fMRI/figs')
tmp_dir = pjoin(git_dir, 'jb-Ca-fMRI/tmp')

# GitHub
sys.path.insert(0, pjoin(git_dir, '_Ca-fMRI'))
from figures.fighelper import *
from analysis.final import *
from utils.render import *

# warnings, tqdm, & style
# Silence DeprecationWarnings for the rest of the notebook.
# NOTE(review): `warnings` is not imported explicitly in this cell; presumably
# it arrives through one of the star imports above — confirm.
warnings.filterwarnings('ignore', category=DeprecationWarning)
from tqdm.notebook import tqdm  # notebook widget version of the progress bar
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
# Project plotting style; presumably provided by figures.fighelper — confirm.
set_style()

In [2]:
def decode_key(k):
    """Split a result key into ``(group, subject, session, run)``.

    ``k`` is split on ``'_'`` into one to three fields, read in the order
    subject, session, run. Missing trailing fields become ``None``. Each
    field keeps only the text after its last ``'-'``. For example,
    ``'sub-SLC01_ses-2_run-3'`` -> ``('SLC', 1, 2, 3)``.

    The subject field is then split further:

    - ``group`` is the last run of letters.
    - ``subject`` is the last run of non-letters cast to ``int``, or
      ``None`` when there is no such run.

    Returns
    -------
    tuple
        ``(group, subject, session, run)``. ``session`` and ``run`` are
        ``int`` or ``None``.

    Raises
    ------
    ValueError
        If ``k`` has more than three ``'_'``-separated fields, or if the
        session/run field is not an integer.
    IndexError
        If the subject field contains no letters.
    """
    import re  # explicit, rather than relying on a star import

    fields = k.split('_')
    if len(fields) > 3:
        raise ValueError(
            f"expected at most 3 '_'-separated fields, got {len(fields)}: {k!r}")
    # Pad to exactly (sub, ses, run); absent trailing fields are None.
    sub, ses, run = fields + [None] * (3 - len(fields))

    if ses is not None:
        ses = int(ses.split('-')[-1])
    if run is not None:
        run = int(run.split('-')[-1])
    sub = sub.split('-')[-1]
    g = re.findall(r"[a-zA-Z]+", sub).pop()
    try:
        s = int(re.findall(r"[^a-zA-Z]+", sub).pop())
    except IndexError:
        # Subject label carries no numeric part.
        s = None
    return g, s, ses, run



def do_group_tmp(
        data_dict,
        sessions: List[int],
        match_metric: str = 'correlation',
        n_clusters: int = None,
):
    """Fit group centroids and align every sample's rows to them.

    Steps:

    1. Pool all arrays stored under ``data_dict[ses]`` for the requested
       sessions.
    2. Call ``fit_kmeans`` on the columns that are NaN-free in the pooled
       (``flatten_arr``-ed) data.
    3. Reorder the centroids with ``bs.align_centroid_to_structs``.
    4. Permute the rows of every sample of every item so they match the
       centroids, using a Hungarian assignment on ``match_metric``
       distances.

    Parameters
    ----------
    data_dict : mapping
        Session key -> list of arrays. Presumably each array is
        (n_samples, n_clusters, n_nodes) — TODO confirm against the caller.
    sessions : int or iterable
        Session key(s) to pool. A non-iterable value is wrapped in a list.
    match_metric : str
        ``scipy.spatial.distance.cdist`` metric used for the per-sample
        alignment. The centroid fit itself is called with
        ``match_metric='euclidean'``.
    n_clusters : int, optional
        Number of centroids. Defaults to the notebook-level global
        ``num_k``, which is what the function always used before this
        parameter existed.

    Returns
    -------
    np.ndarray
        Aligned samples of all items concatenated along axis 0, divided by
        their NaN-ignoring sum over axis -2.

    Notes
    -----
    Reads the notebook-level global ``bs`` (and ``num_k`` when
    ``n_clusters`` is not given).
    """
    if n_clusters is None:
        n_clusters = num_k
    if not isinstance(sessions, Iterable):
        sessions = [sessions]

    # Pool every item from every requested session.
    data_list = list(itertools.chain.from_iterable(
        data_dict[ses] for ses in sessions))
    data = np.concatenate([flatten_arr(x) for x in data_list])
    # Columns that contain no NaN in any pooled row.
    nonan = np.where(np.isnan(data).sum(0) == 0)[0]

    # Fit on NaN-free columns only; the remaining centroid columns stay 0.
    _centroids, _ = fit_kmeans(
        data=data[:, nonan],
        n_clusters=n_clusters,
        match_metric='euclidean',
        random_state=42,
        kw_kmeans={
            'n_init': 10,
            'max_iter': 300,
            'tol': 0.0001},
    )
    centroids = np.zeros((n_clusters, data.shape[-1]))
    centroids[:, nonan] = _centroids

    # Presumably maps centroids onto a canonical order — confirm in Base.
    global_mapping = bs.align_centroid_to_structs(
        pi_mv=centroids,
        metric='cosine',
        global_order=False,
    )
    centroids = centroids[global_mapping]

    pi = []
    for item in data_list:
        # Columns NaN-free both in this item's average and in the pool.
        good = np.where(np.isnan(avg(item)).sum(0) == 0)[0]
        good = np.intersect1d(good, nonan)
        _x = item[..., good]
        _x = np.where(~np.isnan(_x), _x, 0.0)
        ref = centroids[:, good]  # invariant across this item's samples

        aligned = []
        for sample, u in zip(item, _x):
            dist = sp_dist.cdist(XA=ref, XB=u, metric=match_metric)
            # Hungarian matching: col_ind[r] is the row of `u` paired
            # with centroid r.
            _, col_ind = sp_optim.linear_sum_assignment(dist)
            aligned.append(sample[col_ind])
        pi.append(np.stack(aligned))
    pi = np.concatenate(pi)
    # Normalize so that, per sample, each column sums to 1 over axis -2.
    pi /= np.nansum(pi, -2, keepdims=True)
    return pi

## Prepare

In [3]:
# Dataset handle built from Config(128, resolution=25), with the parcellation loaded.
mice = Mice(Config(128, resolution=25), load_parcel=True)
# BOLD band setting; (0.008, 0.28) is presumably a pass band in Hz — confirm.
mice.set_band(band_bo=(0.008,0.28))

# Template volume cast to uint32; the second return value is discarded.
# NOTE(review): `mcc` presumably wraps the Allen connectivity cache — confirm.
template, _ = mice.al.mcc.get_template_volume()
template = template.astype('uint32')
root = mice.al.get_masks('root')
# Full-brain parcel label volume and (presumably) its unique region indices.
brn = mice.parcel['brain'][:]
region_idxs = unique_idxs(brn)

### Load node2id

In [4]:
# Node -> parcel-id dict saved earlier for the AD data, plus its inverse.
# allow_pickle is needed because the .npy stores a dict (unwrapped with .item()).
n2i = np.load(pjoin(tmp_dir, 'AD_n2i.npy'), allow_pickle=True).item()
i2n = {i: n for n, i in n2i.items()}

# Human-readable label per node: "hemis-region-layer" for Isocortex parcels,
# "hemis-region" for every other parcel.
n2l = {}
for n, i in n2i.items():
    info = mice.parcel.get(i)
    if info['acro'] == 'Isocortex':
        n2l[n] = f"{info['hemis']}-{info['region']}-{info['layer']}"
    else:
        n2l[n] = f"{info['hemis']}-{info['region']}"
mice.node_lookup['bold'] = n2l
# 'ca2' lookup keeps only labels with exactly three '-'-separated parts, i.e.
# the Isocortex labels built above (assuming hemis/region contain no '-').
mice.node_lookup['ca2'] = {
    n: lbl for n, lbl in n2l.items()
    if len(lbl.split('-')) == 3
}
# `bs` is also read as a global by do_group_tmp.
bs = Base(mice, mode='ca2')

### Extract all keys

In [5]:
# Collect SVINET result keys. For every entry in svinet_dir whose name contains
# "n-{nn}*{ll}", keep the '_'-joined suffix starting at the first '_'-token
# that contains 'sub'. Entries without such a token are skipped.
# NOTE(review): `in` is a plain substring test, so '*' matches a literal
# asterisk, not a glob wildcard. Presumably the entry names embed this string
# verbatim — confirm.
keys = []
for f in os.listdir(mice.cfg.svinet_dir):
    if f"n-{mice.cfg.nn}*{mice.cfg.ll}" not in f:
        continue
    key = f.split('_')
    try:
        i = next(
            i for i, e in
            enumerate(key)
            if 'sub' in e
        )
    except StopIteration:
        continue
    keys.append('_'.join(key[i:]))
keys = sorted(keys)
len(keys)  # shown as the cell output