In [1]:
%matplotlib inline
# %matplotlib widget

# Analysis of subjects

In [2]:
import csv
import os
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sns.set(rc={'figure.facecolor':'white'})

In [3]:
from demapper.code.utils.analyze import extract_dataset, plot_results, plot_limits
from demapper.code.utils.config_cme import DATASETS, FILTERS, get_plot_columns, extract_params_f



### Run for one dataset

In [5]:
PICKED_DATASET = 'cmev4euc'

datadir = DATASETS[PICKED_DATASET]
stats_path = os.path.join(datadir, 'compute_stats-averaged.csv')
results_path = os.path.join(datadir, 'compute_degrees_from_TCM/combined-compute_degrees_from_TCM.csv')

filter_by = FILTERS[PICKED_DATASET][0]
df, max_values = extract_dataset(stats_path, results_path, filter_by, extract_params_f, has_sbj=False)

resdir = datadir
if len(FILTERS[PICKED_DATASET]) > 1:
    resdir = os.path.join(datadir, filter_by)
    os.makedirs(resdir, exist_ok=True)

print(max_values)
df.sample(5)

len(stats):  539
len(results):  539
{}


Unnamed: 0,Mapper,K,R,G,ChangePointsIndicesError,ChangePointsCount,assortativity-mean,assortativity-std,coverage_TRs-mean,coverage_TRs-std,...,degree_TRs_avg-mean,degree_TRs_avg-std,degree_TRs_entropy-mean,degree_TRs_entropy-std,distances_entropy-mean,distances_entropy-std,distances_max-mean,distances_max-std,hrfdur_stat-mean,hrfdur_stat-std
219,EucNeuMapper_28_220_34,28,220,34,9.857143,8.0,-0.036726,0.048154,0.967825,0.03865,...,83.401378,37.156836,6.908817,0.507282,3.024135,0.494365,23.111111,11.390857,0.09128,0.039381
328,EucNeuMapper_36_260_50,36,260,50,8.571429,8.0,-0.064762,0.124836,0.999399,0.000962,...,1777.849444,451.534446,9.562615,0.163344,1.659228,0.331847,6.5,2.007339,0.346187,0.067882
457,EucNeuMapper_48_180_34,48,180,34,10.714286,8.0,-0.084566,0.085587,0.990003,0.014266,...,208.012889,79.247691,8.103008,0.352344,2.236421,0.324564,11.944444,4.683792,0.13808,0.056013
227,EucNeuMapper_28_260_38,28,260,38,9.428571,8.0,-0.036798,0.060692,0.988747,0.015338,...,120.377011,28.67783,7.515664,0.367054,2.759768,0.489873,18.388889,8.374604,0.094544,0.02891
299,EucNeuMapper_36_100_46,36,100,46,5.857143,8.0,-0.042472,0.019234,0.999891,0.000318,...,2215.538889,393.982985,9.653462,0.077854,1.048827,0.202728,2.944444,0.416176,0.848275,0.100217


In [6]:

target_metrics = ['ChangePointsIndicesErrorRev', 'ChangePointsCount', 'coverage_nodes-mean', 'coverage_nodes-std', 'hrfdur_stat-mean', 'hrfdur_stat-std', 'distances_entropy-mean', 'distances_entropy-std']
log_metrics = ['ChangePointsIndicesErrorRev']

df['ChangePointsIndicesErrorRev'] = df.apply(
    lambda x: 1.0 / x['ChangePointsIndicesError'] if x['ChangePointsIndicesError'] > 0 else 100,
    axis=1)
# df['TransitionBetweenessRev'] = df.apply(lambda x: 1.0 / (x['TransitionBetweeness'] + 1), axis=1)


plot_results(df, 'SBJs_RKG', [], 'R', 'K', 'G', target_metrics, log_metrics, resdir=resdir)
plot_results(df, 'SBJs_KRG', [], 'K', 'R', 'G', target_metrics, log_metrics, resdir=resdir)

# for sbj_group_name, sbj_group in tqdm(sbjs_map.items()):
#     plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics, resdir=resdir)
    

In [6]:
# Plot limits

target_metrics = {
    'ChangePointsIndicesError': [0, 12],
    'coverage_nodes-mean': [0.7, 1.0],
    'hrfdur_stat-mean': [0.15, 1.0],
    'distances_entropy-mean': [2.0, 10000.0]
}

fixedV, indexV, colV = get_plot_columns(filter_by)

plot_limits(df, 'SBJs_RKG', [], 'R', 'K', 'G', target_metrics, resdir=resdir)


## Recompute for all analysis that we have

In [None]:

target_metrics_detailed = ['ChangePointsIndicesErrorRev', 'ChangePointsCount', 'coverage_nodes-mean', 'coverage_nodes-std', 'hrfdur_stat-mean', 'hrfdur_stat-std', 'distances_entropy-mean', 'distances_entropy-std']
target_metrics = ['ChangePointsIndicesErrorRev', 'coverage_nodes-mean', 'hrfdur_stat-mean', 'distances_entropy-mean']
log_metrics = ['ChangePointsIndicesErrorRev']

target_metrics_limits = {
    'ChangePointsIndicesError': [0, 12],
    'coverage_nodes-mean': [0.7, 1.0],
    'hrfdur_stat-mean': [0.15, 1.0],
    'distances_entropy-mean': [2.0, 10000.0]
}

for dataset_name in DATASETS.keys():
    if '_fast' not in dataset_name:
        continue
    print('======= Processing', dataset_name)
    datadir = DATASETS[dataset_name]
    stats_path = os.path.join(datadir, 'compute_stats-averaged.csv')
    results_path = os.path.join(datadir, 'compute_degrees_from_TCM/combined-compute_degrees_from_TCM.csv')

    try:
        for filter_by in FILTERS[dataset_name]:
            resdir = datadir
            if len(FILTERS[dataset_name]) > 1:
                resdir = os.path.join(datadir, filter_by)
                os.makedirs(resdir, exist_ok=True)
            df, max_values = extract_dataset(stats_path, results_path, filter_by, extract_params_f, has_sbj=False)

            fixedV, indexV, colV = get_plot_columns(filter_by)
            df['ChangePointsIndicesErrorRev'] = df.apply(lambda x: 1.0 / x['ChangePointsIndicesError'] if x['ChangePointsIndicesError'] > 0 else 10, axis=1)

            plot_results(df, 'SBJs', [], fixedV, indexV, colV, target_metrics, log_metrics, resdir=resdir)
            plot_results(df, 'SBJs_detailed', [], fixedV, indexV, colV, target_metrics_detailed, log_metrics, resdir=resdir)
            plot_limits(df, 'SBJs', [], fixedV, indexV, colV, target_metrics_limits, resdir=resdir)
    except Exception as err:
        print("Warning! Didn't process '{}' because:".format(dataset_name))
        print(err)
        raise err

len(stats):  640
len(results):  640


### Create links to download and then create a grid

In [5]:

REMOTE_DATASET_PATHS = {
    'ss_w3cv1': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv1.json',
    'ss_w3cv2': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv2.json',
    'ss_w3cv3': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json',
    'ss_w3cv4': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv4.json',
    'ss_w3cv5': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv5dist.json',
    'ss_w3cv6': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv6dist.json'
}

dataset_id = 'ss_w3cv3'
local_dataset_path = DATASETS[dataset_id]
os.makedirs(os.path.join(local_dataset_path, 'mappers_plots'), exist_ok = True)
os.makedirs(os.path.join(local_dataset_path, 'mappers_combined'), exist_ok = True)
remote_dataset_path = REMOTE_DATASET_PATHS[dataset_id]
remote_analysis_path = REMOTE_DATASET_PATHS[dataset_id].replace('mappers_w3c', 'analysis/mappers_w3c')



KeyError: 'ss_w3cv3'

In [130]:
mappers = []
k = 20
for r in range(10, 50, 10):
    for g in range(50, 100, 10):
        mappers.append({
            'mapper': 'BDLMapper_{k}_{r}_{g}'.format(k=k, r=r, g=g),
            'R': r,
            'G': g,
        })
        

sbjs = ['SBJ99', 'SBJ99-83.0', 'SBJ99-75.0', 'SBJ99-50.0']

files = ['plot_task-G.png']


In [131]:
print("###### Copy the files to mappers_plots")

print('mkdir {}/mappers_plots'.format(remote_analysis_path))

table = []
for s in sbjs:
    for m in mappers:
        for f in files:
            item = m.copy()
            item['sbj'] = s
            item['file'] = f
            table.append(item)
            print('cp {ds_path}/{s}/{m}/{f} {an_path}/mappers_plots/{s}_{m}_{f}'.format(
                ds_path=remote_dataset_path,
                an_path=remote_analysis_path,
                s=s, m=m['mapper'], f=f))
            
print("###### Copy the files locally")
print('scp -r "hasegan@login.sherlock.stanford.edu:{}/mappers_plots/*" {}'.format(
    remote_analysis_path,
    os.path.join(local_dataset_path, 'mappers_plots/')))

###### Copy the files to mappers_plots
mkdir /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_50/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_50_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_60/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_60_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_70/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_70_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_80/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_80_plot_task-G.png
cp /s

In [1]:
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

sns.set(style = "whitegrid")

def create_axis(df, axis_hps, idx=0):
    if idx == len(axis_hps):
        return [{}]
    hp_name = axis_hps[idx]
    hp_vals = df[hp_name].drop_duplicates().to_numpy()
    vals = []
    hp_vals = sorted(sorted(hp_vals), key=lambda x: x != 'none')
    for hpv in hp_vals:
        df_hp = df[df[hp_name] == hpv]
        val = create_axis(df_hp, axis_hps, idx+1)
        if len(hp_vals) > 1:
            for v in val:
                v[hp_name] = hpv
        vals.extend(val)
    return vals


def get_item_path(basedir, item):
    return os.path.join(
        basedir, '{s}_{m}_{f}'.format(
            s=item['sbj'].tolist()[0],
            m=item['mapper'].tolist()[0],
            f=item['file'].tolist()[0]))


def plot_image(img_path, ax):
    im = Image.open(img_path)
    img = np.array(im)
    ax.imshow(img)
    del img
    del im
    return ax

def plot_grid(df, sbj, hps_x, hps_y, hps_fixed_vals, analysis_path, regen=False, text_f=None, show_plot=False):
    ax_x = create_axis(df, hps_x)
    ax_y = create_axis(df, hps_y)
    
    x_keys = sorted(list(set([k for x in ax_x for k,v in x.items()])))
    y_keys = sorted(list(set([k for y in ax_y for k,v in y.items()])))
    fixed_toks = sorted(['{}_{}'.format(k,v) for k,v in hps_fixed_vals.items()])
    savefig_path = os.path.join(
        os.path.join(analysis_path, 'mappers_combined'),
        'comb{}_{}.png'.format(
            sbj,
            '_'.join(['x'] + x_keys + ['y'] + y_keys + fixed_toks)))
    
    if os.path.isfile(savefig_path) and not regen:
        return

    ncols = len(ax_x)
    nrows = len(ax_y)
    fsize = 20
    fig, axr = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fsize * (ncols/nrows)*1.1,fsize))

    for r,(y,axc) in enumerate(zip(ax_y,axr)):
        for c,(x,ax) in enumerate(zip(ax_x,axc)):
            K = {'sbj':sbj, **x, **y, **hps_fixed_vals}
            df_filtered = df
            for k,v in K.items():
                df_filtered = df_filtered[df_filtered[k] == v]
            if len(df_filtered) == 0:
                ax.axis('off')
                continue
            assert len(df_filtered) == 1, 'Too many ({}) items found for {}, {}, {}'.format(
                len(df_filtered), x, y, hps_fixed_vals)
            item = df_filtered
            sbj = item['sbj'].to_numpy()[0]

            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.grid(False)
            if r == len(ax_y)-1:
                ax.set_xlabel(' '.join([str(x[k]) for k in hps_x if k in x]))
                ax.xaxis.set_label_position("bottom")
            if c == 0:
                ax.set_ylabel(' '.join([str(y[k]) for k in hps_y if k in y]))

            img_path = get_item_path(os.path.join(analysis_path, 'mappers_plots'), item)
            if not os.path.isfile(img_path):
    #             ax.axis('off')
                print('Skipping as plot not found at {}'.format(img_path))
                continue
        
#             modd = df_filtered[['preembed_mod_full', 'preembed_mod_wout_ac', 'embed_modd_bin', 'embed_modd_wout_ac']].to_numpy()[0]
#             modd = [str(round(m,3)) for m in modd]
#             title = ' -> '.join([', '.join(modd[0:2]), ', '.join(modd[2:4])])
#             ax.set_title(title)
            if text_f:
                ax = text_f(ax, item)
            ax = plot_image(img_path, ax)
    plt.tight_layout()
    
    plt.savefig(savefig_path, dpi=150)
    if show_plot:
        plt.show()
    plt.close()
    
# plot_grid(df, 'SBJ01', hps_x, hps_y, {'knnparam_gains': 1.0})


NameError: name 'sns' is not defined

In [138]:
df = pd.DataFrame(data=table)

for sbj in tqdm(sbjs):
    plot_grid(df, sbj, ['R'], ['G'], {}, local_dataset_path, regen=True)


100%|█████████████████████████████████████████████████████████████████████████████████| 4/4 [00:11<00:00,  3.00s/it]


In [139]:
local_dataset_path

'/Users/dh/workspace/BDL/demapper/results/w3c_ss/analysis/mappers_w3cv3.json/'