# (03) Marija

**Motivation**: notebook to generate results: fit per-animal networks, run and fit SVINET, fit group-level results, and plot them.

### Note

Run the ```Imports``` and ```Setup config``` cells below every time, before running anything else.
- Say you want to run ```Fit network```:
    - Run ```Imports``` >> ```Setup config``` >> ```Fit network```
- Or maybe you want to run ```Fit SVINET```:
    - Run ```Imports``` >> ```Setup config``` >> ```Fit SVINET```
    
...and so on.
<br>

## Imports

In [1]:
# HIDE CODE


import os, sys
from IPython.display import display

# Code path + imports
code_path = '/home/mm4347/Desktop/multicell/MarijaHadi/_Ca-fMRI-Aug03'
sys.path.insert(0, code_path)
from analysis.bootstrap import *
from figures.fighelper import *
from utils.render import *

# warnings, tqdm, & style
warnings.filterwarnings('ignore', category=DeprecationWarning) #
from tqdm.notebook import tqdm
%matplotlib inline
set_style()

## Setup config

In [2]:
def get_config(
        *,
        base_replacement='/ca2data3/lake_lab_analyses/MarijaHadiAnalysis/Ca-fMRI',
        svinet_root='/ca2data3/lake_lab_analyses/MarijaHadiAnalysis/svinet',
        **kwargs,
):
    """Build a ``Config`` whose directories point at the shared analysis storage.

    In every string-valued attribute whose name contains ``'_dir'``, the
    config's original ``base_dir`` is replaced by ``base_replacement``.
    ``svinet_dir`` is then set to ``svinet_root``, and the download / raw /
    tx / warped directories are set to ``None``. Finally
    ``cfg._get_all_dirs()`` is called.

    Args:
        base_replacement: path substituted for the original ``base_dir``
            (keyword-only; the default is the previous hard-coded value).
        svinet_root: value assigned to ``cfg.svinet_dir`` (keyword-only;
            the default is the previous hard-coded value).
        **kwargs: forwarded unchanged to ``Config``.

    Returns:
        The modified ``Config`` instance.
    """
    cfg = Config(**kwargs)
    original_base = cfg.base_dir
    # Snapshot the pristine *_dir strings before overwriting any of them.
    # The old version built a second, throw-away Config for this.
    originals = {
        name: getattr(cfg, name)
        for name in dir(cfg)
        if '_dir' in name and isinstance(getattr(cfg, name), str)
    }
    for name, path in originals.items():
        setattr(cfg, name, path.replace(original_base, base_replacement))
    cfg.svinet_dir = svinet_root
    # These directories are not used from the shared storage.
    cfg.download_dir = cfg.raw_dir = cfg.tx_dir = cfg.warped_dir = None
    cfg._get_all_dirs()
    return cfg

In [None]:
# nn: number of ROIs per hemisphere in cortex (per the original note).
# Directories are not created here; the last line displays the resolved paths.
N_ROIS_PER_HEMISPHERE = 512
cfg = get_config(makedirs=False, nn=N_ROIS_PER_HEMISPHERE)
cfg.all_dirs

In [None]:
# Build the `Mice` object from the config; every section below passes `mice`
# around (presumably it holds the per-animal data — see the analysis package).
mice = Mice(cfg)

## Fit network

In [None]:
# Frequency bands to process:
#   None        -> no band-pass filtering at all
#   (0.01, 0.5) -> the band used in the paper (units presumably Hz)
task, bands = 'rest', [None, (0.01, 0.5)]

In [None]:
# Options forwarded to every `Network(...)` below.
# percentiles evaluates to [25, 20, 15, 10] (presumably graph-density
# thresholds — confirm in the Network class).
props = dict(
    mice=mice,
    mode='ca2',
    metric='pearson',
    percentiles=np.linspace(25, 10, 4).astype(int),
    prep_data=True,
    binarize=False,
    verbose=False,
)
# Data-selection kwargs for `mice.setup_func_data`; the band is supplied per
# iteration of the fitting loop.
kws = dict(
    task=task,
    desc_ca2='preproc',
    runs_only=True,
    exclude=True,
)

In [None]:
# For each frequency band: load the band-filtered data, then fit and save a
# network for every data key. The outer progress bar shows the current item.
pbar1 = tqdm(bands, leave=True)
for b in pbar1:
    # Build the per-band kwargs fresh instead of writing into the shared
    # `kws` dict, so re-running cells starts from the same state.
    mice.setup_func_data(**{**kws, 'band_ca2': b})
    pbar2 = tqdm(
        mice.get_data_containers('ca2')[0],
        leave=False,
    )
    # key: per the original note, the mouse ID, e.g. "SLC..."
    for key in pbar2:
        msg = f"running net (task-{task}, "
        msg += f"b-{str(b).replace(' ', '')})"
        msg += f": {key}"
        pbar1.set_description(msg)
        # fit network
        net = Network(key=key, **props)
        net.fit_network(force=True, full=False, save=True)

## Apply SVINET (bash)

- Open a terminal
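- Navigate to where the code is located (`code_path` in the ```Imports``` cell) and `cd` into its ```scripts``` directory, i.e. go here:
    - ```.../_Ca-fMRI-Aug03/scripts```
- Run the following command. The arguments `512`, `7`, `p15-sample`, `ca2`, `preproc` and `rest` presumably need to match `nn`, `num_k`, `perc`, the mode, `desc_ca2` and `task` used in the Python cells: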

```bash
time ./loop_svinet.sh 512 3 7 p15-sample 500 ca2 preproc [0-9.,]+ rest
```

## Fit SVINET

In [None]:
# SVINET settings. num_k is presumably the number of communities;
# perc is the tag of the network percentile to use.
num_k, perc = 7, 'p15-sample'
task, bands = 'rest', [None, (0.01, 0.5)]

In [None]:
# Options forwarded to every `SVINET(...)` below.
props = dict(
    mice=mice,
    num_k=num_k,
    perc=perc,
    mode='ca2',
    metric='pearson',
    match_metric='euclidean',
    match_using='gam',
    graph_type='real',
    verbose=False,
)
# Data-selection kwargs for `mice.setup_func_data`; the band is supplied per
# iteration of the fitting loop.
kws = dict(
    task=task,
    desc_ca2='preproc',
    runs_only=True,
    exclude=True,
)

In [None]:
# For each frequency band: load the band-filtered data, then fit SVINET for
# every data key. The outer progress bar shows the current item.
pbar1 = tqdm(bands)
for b in pbar1:
    # Build the per-band kwargs fresh instead of writing into the shared
    # `kws` dict, so re-running cells starts from the same state.
    mice.setup_func_data(**{**kws, 'band_ca2': b})
    pbar2 = tqdm(
        mice.get_data_containers('ca2')[0],
        leave=False,
    )
    for key in pbar2:
        msg = f"running svinet (task-{task}, "
        msg += f"b-{str(b).replace(' ', '')})"
        msg += f": {key}"
        pbar1.set_description(msg)
        # fit svinet
        sv = SVINET(key=key, **props)
        sv.fit_svinet()

## Fit Group results

In [None]:
# Group settings. These repeat the SVINET values on purpose, so this section
# can run right after "Imports" and "Setup config".
num_k, perc = 7, 'p15-sample'
task, bands = 'rest', [None, (0.01, 0.5)]

In [None]:
# Options forwarded to `Group(...)`; sv_props configures how the individual
# SVINET fits are matched.
sv_props = dict(match_metric='euclidean', match_using='gam')
props = dict(
    mice=mice,
    num_k=num_k,
    perc=perc,
    mode='ca2',
    metric='pearson',
    sv_props=sv_props,
    graph_type='real',
    dist_metric='cosine',
    ordering='infer',
    verbose=False,
)
# Data-selection kwargs for `mice.setup_func_data`.
kws = dict(
    task=task,
    desc_ca2='preproc',
    runs_only=True,
    exclude=True,
)

In [None]:
# For each frequency band: select the data, then fit the group result if the
# per-animal SVINET fits exist.
pbar = tqdm(bands)
for b in pbar:
    msg = f"running group (task-{task}, "
    msg += f"b-{str(b).replace(' ', '')}, kk-{num_k})"
    pbar.set_description(msg)
    # Use the same selection kwargs (task, desc_ca2, runs_only, exclude) as the
    # network/SVINET fits, so the group looks up the same data keys.
    # The per-band dict is built fresh; the shared `kws` is not mutated.
    mice.setup_func_data(**{**kws, 'band_ca2': b})
    gr = Group(**props)
    if gr.svinets_exist():
        gr.fit_group(force=False)
    else:
        # Report the skip instead of silently producing nothing.
        pbar.write(f"skipped {msg}: SVINET fits not found")

## Plot results

In [None]:
# Settings for the plotting section.
task = 'rest'
desc = 'preproc'
# Must match the `num_k` used for the SVINET / group fits. It is defined here
# because the Show cell needs it, so this section works after only
# "Imports" and "Setup config".
num_k = 7
# Band -> display label. `CAS` is not defined in this notebook; presumably it
# comes from one of the star imports.
# NOTE(review): (0.0, 5.0) presumably refers to the data fitted with band
# `None` above — confirm.
bands = {
    (0.0, 5.0): 'Unfiltered',
    (0.01, 0.5): CAS,
}
bs = Base(mice, mode='ca2')

### Load results

In [None]:
# Load the group average for every band. The key format is "label\nband";
# results_avg holds `avg` of each entry.
results = {}
for b, label in bands.items():
    mice.set_kws(band_ca2=b, desc_ca2=desc, task=task)
    gr = Group(mice=mice, mode='ca2', perc='p15-sample', num_k=7)
    results[f"{label}\n{b}"] = gr.avg(ndim_start=1)
results_avg = {key: avg(val) for key, val in results.items()}

### Align

In [None]:
# Align cluster indices across bands: match every band's averaged result to
# the reference band, then apply a fixed manual relabelling.
src_band = (0.01, 0.5)
# Build the reference key in exactly the format used when loading results
# (f"{label}\n{band}"); this also works if the label is not a str.
src_key = f"{bands[src_band]}\n{src_band}"
global_order = 'infer'
match_metric = 'correlation'
# Manual relabelling on top of the automatic matching (4->5, 5->6, 6->4,
# identity elsewhere). Uses the same `num_k` as the Show cell instead of an
# unrelated section's `props`. Assumes num_k >= 7 so indices 4-6 exist.
final_perm = {4: 5, 5: 6, 6: 4}
final_perm = {
    i: final_perm.get(i, i) for
    i in range(num_k)
}

# get gmap: gmap[k] is used below as {output position -> index on axis 1}
gmap = find_best_mappings(
    data=results_avg,
    centroids=results_avg[src_key],
    match_metric=match_metric,
    # Pass the whole option; `global_order[0]` sent only the character 'i'.
    global_order=global_order,
)
if final_perm:
    # Compose the manual relabelling with the automatic mapping. Build a new
    # dict rather than rewriting gmap while iterating over it.
    gmap = {
        k: {s: v[t] for s, t in final_perm.items()}
        for k, v in gmap.items()
    }
    # Results with no automatic mapping fall back to the manual one.
    gmap = {
        k: gmap.get(k, final_perm)
        for k in results_avg
    }

# apply gmap: reorder axis 1 of every result, then average
pi_sorted = {
    k: v[:, list(gmap[k].values()), :]
    for k, v in results.items()
}
pi_sorted_avg = {
    k: avg(v) for k, v
    in pi_sorted.items()
}

### Show

In [None]:
# Plot the aligned, band-averaged results (`pi_sorted_avg`).
plot_kws = show_kws(num_k)
fig, axes = bs.show(pi_sorted_avg, **plot_kws)