In [1]:
%matplotlib inline
# %matplotlib widget

In [2]:
import csv
import os
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sns.set(rc={'figure.facecolor':'white'})

### Helper Functions

In [3]:
# Given a dataset stats and results path, load in the DataFrame, merge and filter

def _read_stats_rows(stats_path, has_sbj):
    """Load the stats CSV as a list of per-row dicts.

    The ``mapper`` cell is kept as text, every column whose name starts with
    ``id`` is dropped, empty cells become ``0`` and all remaining cells are
    parsed with ``float``. When ``has_sbj`` is true the raw ``id0`` cell is
    stored under the key ``SBJ``.
    """
    rows = []
    # newline='' is what the csv module asks for when it is handed a file object.
    with open(stats_path, newline='') as f:
        for row in csv.DictReader(f):
            obj = {}
            for k, v in row.items():
                if k == 'mapper':
                    obj[k] = v
                elif k.startswith('id'):
                    continue
                elif v == '':
                    obj[k] = 0
                else:
                    obj[k] = float(v)
            if has_sbj:
                obj['SBJ'] = row['id0'] # TODO! This only works for this type of data (fix!)
            rows.append(obj)
    return rows


def _read_results_rows(results_path, missing_value):
    """Load the results CSV as a list of per-row dicts.

    ``Mapper`` and ``subject`` are kept as text; the literal cells ``Inf`` and
    ``NaN`` are replaced by ``missing_value``; everything else is parsed with
    ``float``.
    """
    rows = []
    with open(results_path, newline='') as f:
        for row in csv.DictReader(f):
            obj = {}
            for k, v in row.items():
                if k in ('Mapper', 'subject'):
                    obj[k] = v
                elif v in ('Inf', 'NaN'):
                    obj[k] = missing_value
                else:
                    obj[k] = float(v)
            rows.append(obj)
    return rows


def _mapper_param_layout(filter_by):
    """Return ``[(column, token_position, is_int), ...]`` for a mapper family.

    The position indexes the ``'_'``-separated tokens of the mapper name
    (token 0 is the family name itself).

    Raises:
        ValueError: if ``filter_by`` is not one of the known mapper families.
    """
    plain = [('K', 1, True), ('R', 2, True), ('G', 3, True)]
    custom = [('preptype', 1, False), ('dist', 2, False), ('K', 3, True),
              ('R', 4, True), ('G', 5, True), ('linkbins', 6, True)]
    no_knn = [('preptype', 1, False), ('dist', 2, False), ('R', 3, True),
              ('G', 4, True), ('linkbins', 5, True)]
    layouts = {
        'BDLMapper': plain,
        'NeuMapper': plain,
        'CustomBDLMapper': custom,
        'CustomNeuMapper': custom,
        'CMDSMapperNoKNN': no_knn,
    }
    if filter_by not in layouts:
        raise ValueError('Mapper type not recognized')
    return layouts[filter_by]


def extract_dataset(stats_path, results_path, filter_by, has_sbj=True):
    """Load, merge and filter the stats and results CSVs for one mapper family.

    Args:
        stats_path: CSV with a ``mapper`` column (and an ``id0`` column when
            ``has_sbj`` is true) plus numeric statistics.
        results_path: CSV with a ``Mapper`` column (and ``subject`` when
            ``has_sbj`` is true) plus numeric scores; ``Inf``/``NaN`` cells are
            treated as missing.
        filter_by: mapper family name, or a one-element list/tuple holding it.
            Rows are kept when their ``Mapper`` value starts with this name.
        has_sbj: whether rows are per subject. If true the two files are joined
            on (mapper, subject), otherwise on the mapper name only.

    Returns:
        ``(df, max_values)``. ``df`` has ``Mapper``, optionally ``SBJ``, the
        parameter columns parsed out of the mapper name, then every other
        column. ``max_values`` maps each of ``CircleLoss`` /
        ``TransitionBetweeness`` that is present to the value substituted for
        its missing entries (1.5 x the largest valid value).

    Raises:
        ValueError: unknown mapper family, no row matching ``filter_by``, or a
            loss column without a single valid value.
        AssertionError: more than one mapper family requested, or the two
            files have a different number of rows.
    """
    if isinstance(filter_by, (list, tuple)):
        assert len(filter_by) == 1, "Can only extract 1 type of mapper"
        filter_by = filter_by[0]
    # Validate the family before touching the disk so a typo fails immediately.
    layout = _mapper_param_layout(filter_by)

    MAX_INT = 100000  # sentinel standing in for Inf/NaN scores until they are rescaled below

    stats = _read_stats_rows(stats_path, has_sbj)
    print('len(stats): ', len(stats))
    results = _read_results_rows(results_path, MAX_INT)
    print('len(results): ', len(results))

    assert len(stats) == len(results)

    dfs = pd.DataFrame(data=stats)
    dfr = pd.DataFrame(data=results)

    if has_sbj:
        df = pd.merge(dfr, dfs, how='left', left_on=['Mapper', 'subject'], right_on=['mapper', 'SBJ'])
    else:
        df = pd.merge(dfr, dfs, how='left', left_on=['Mapper'], right_on=['mapper'])

    # Copy so the column assignments below write to an independent frame
    # instead of a slice of the merged one (avoids SettingWithCopyWarning).
    df = df[df['Mapper'].str.startswith(filter_by)].copy()
    if df.empty:
        raise ValueError("No rows found for mapper type '{}'".format(filter_by))

    # Parse the hyper-parameters out of names such as 'BDLMapper_<K>_<R>_<G>'.
    tokens = df['Mapper'].str.split('_')
    param_cols = []
    for name, position, is_int in layout:
        values = tokens.str[position]
        df[name] = values.astype(int) if is_int else values
        param_cols.append(name)

    # fix CircleLoss and TransitionBetweeness: replace the sentinel by a finite
    # value just above the observed range so the columns stay plottable.
    max_values = {}
    for colname in ["CircleLoss", "TransitionBetweeness"]:
        if colname not in df.columns:
            continue
        valid = df.loc[df[colname] < MAX_INT, colname]
        if valid.empty:
            raise ValueError('There are no valid values for {}'.format(colname))
        new_max_loss = valid.max() * 1.5
        max_values[colname] = new_max_loss
        df.loc[df[colname] == MAX_INT, colname] = new_max_loss

    main_cols = ['Mapper'] + (['SBJ'] if has_sbj else []) + param_cols
    # Drop the join keys that duplicate Mapper / SBJ.
    other_cols = [c for c in df.columns.tolist() if c not in main_cols and c.lower() != 'mapper' and c != 'subject']
    df = df[main_cols + other_cols]

    return df, max_values

In [4]:
# Extract the subjects for each dataset type and the combinations that we should compute.
# for example: all SBJ2* would be SBJ20 and SBJ21
# For new datasets, this has to be changed

def get_all_parameters(df, dataset_name, silent=False):
    """List the subjects in ``df`` and build the named subject groups to plot.

    Args:
        df: frame with an ``SBJ`` column.
        dataset_name: dataset id; its prefix (``ss_``, ``wnoise_`` or
            ``hightr_``) selects which groups are created. Any other prefix
            yields an empty map.
        silent: when false, print the subjects and every multi-subject group.

    Returns:
        ``(all_sbjs, sbjs_map)`` where ``all_sbjs`` is the list of distinct
        ``SBJ`` values and ``sbjs_map`` maps a group name to its subjects
        (each subject is also registered as a group of its own).
    """
    all_sbjs = df['SBJ'].unique().tolist()
    if not silent:
        print('Total {} subjects:'.format(len(all_sbjs)))
        for subject in all_sbjs:
            print(subject)

    def ending_with(suffix):
        return [s for s in all_sbjs if s.endswith(suffix)]

    # First known prefix that matches, or None for an unknown dataset family.
    family = next(
        (prefix for prefix in ('ss_', 'wnoise_', 'hightr_') if dataset_name.startswith(prefix)),
        None)

    # Every known family starts with one singleton group per subject.
    sbjs_map = {} if family is None else {subject: [subject] for subject in all_sbjs}

    if family == 'ss_':
        # Subsampled data: hand-picked groups plus one group per subsampling level.
        sbjs_map.update({
            'SBJ2x': ['SBJ20', 'SBJ21'],
            'SBJ4x': ['SBJ40', 'SBJ41', 'SBJ42', 'SBJ43'],
            'SBJxx-50': ending_with('-50.0'),
            'SBJxx-75': ending_with('-75.0'),
            'SBJxx-83': ending_with('-83.0'),
            'SBJxx-99': ['SBJ20', 'SBJ40', 'SBJ99'],
        })
    elif family == 'hightr_':
        # Subsampled hightr data: group the subjects by their 'e<i>v' marker.
        for i in (2, 3, 4):
            marker = 'e{}v'.format(i)
            sbjs_map['SBJe{}'.format(i)] = [s for s in all_sbjs if marker in s]

    if not silent:
        print('Extra combinations:')
        for group_name, members in sbjs_map.items():
            if len(members) > 1:
                print(group_name, ':', members)

    return all_sbjs, sbjs_map

In [12]:
from matplotlib.colors import LogNorm, Normalize

def _handle_list_cols(df, col):
    if type(col) == list:
        newcol = '_'.join(col)
        df[newcol] = df.apply(lambda x: '_'.join([str(x[c]) for c in col]), axis=1)
        col = newcol
    return df, col
    

# For a DataFrame, compute a big figure with multiple subplots
# Each row would be a different metric (some metrics are in log scale `log_metrics`)
# Each column is a different value of the fixedV column (usually `R`)
# For each subplot, x-axis is colV column (usually `G`) and y-axis is indexV column (usually `K`)
# The `sbj_group_name` is the name of the group of subjects
def plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics, resdir):
    """Save a grid of heatmaps of ``target_metrics`` averaged over a subject group.

    The figure has one row per metric and one column per distinct value of
    ``fixedV``. Inside each heatmap the rows are the ``indexV`` values and the
    columns are the ``colV`` values. It is written to
    ``<resdir>/plot_results_<sbj_group_name>.png``.

    Args:
        df: frame with a ``Mapper`` column, the parameter columns and the
            metric columns; an ``SBJ`` column is required when ``sbj_group`` is
            not empty.
        sbj_group_name: label used in the output file name.
        sbj_group: subjects (``SBJ`` values) to average over; an empty
            sequence means "use every row".
        fixedV, indexV, colV: column name, or list of column names that are
            joined into one combined column.
        target_metrics: metric column names, one subplot row each.
        log_metrics: metrics drawn with a logarithmic colour scale.
        resdir: existing directory that receives the PNG.
    """
    df, fixedV = _handle_list_cols(df, fixedV)
    df, indexV = _handle_list_cols(df, indexV)
    df, colV = _handle_list_cols(df, colV)

    if len(sbj_group):
        df_filter = df['SBJ'].isin(sbj_group)
    else:
        # Keep everything without relying on a particular column being present
        # (not every mapper family has a K parameter).
        df_filter = pd.Series(True, index=df.index)

    allcols = [fixedV, indexV, colV]
    # Grouping keys that are integer hyper-parameters are cast back to int.
    newtypes = {c: 'int' for c in allcols if c in ('K', 'G', 'R', 'linkbins')}

    # Filter and group by subject/subjects. numeric_only=True keeps text
    # columns (subject labels, string parameters) out of the mean; recent
    # pandas raises on them instead of silently dropping them.
    dff = df[df_filter].groupby(['Mapper'] + allcols).mean(numeric_only=True)
    dff = dff.reset_index().astype(newtypes)

    fixed_vals = sorted(set(df[fixedV].to_list()))
    # squeeze=False keeps `axr` two-dimensional even with a single metric or a
    # single fixed value, so the nested loops below always work.
    f, axr = plt.subplots(len(target_metrics), len(fixed_vals),
                          figsize=(4 * len(fixed_vals), 4 * len(target_metrics)),
                          squeeze=False)
    try:
        last_col_idx = len(fixed_vals) - 1
        for axc, target in zip(axr, target_metrics):
            # Colour range from all results, not only this subject group, so
            # figures of different groups are comparable. Series.min/max skip NaN.
            vmin, vmax = df[target].min(), df[target].max()
            for col_idx, (K, ax) in enumerate(zip(fixed_vals, axc)):
                df_p = dff[dff[fixedV] == K].pivot(index=indexV, columns=colV, values=target)

                last_col = col_idx == last_col_idx
                if target in log_metrics:
                    ax = sns.heatmap(df_p, norm=LogNorm(vmin=vmin, vmax=vmax), ax=ax, cbar=not last_col)
                else:
                    ax = sns.heatmap(df_p, vmin=vmin, vmax=vmax, ax=ax, cbar=not last_col)
                ax.set_title('{} == {}'.format(fixedV, K))

                if last_col:
                    # The last column has no colour bar; use the freed right
                    # edge to label the row with its metric name.
                    ax1 = ax.twinx()
                    ax1.set_ylabel(target)
                    ax1.set_yticks([])

        f.tight_layout()
        f.savefig(os.path.join(resdir, 'plot_results_{}.png'.format(sbj_group_name)))
    finally:
        # Always release the figure; this function is called in long loops.
        plt.close(f)
    
    
# Similar to `plot_results`, this function has a map of target_metrics to an interval.
# If the picked metric is inside the interval, then the value is 1, otherwise it is 0
# This figure also has a row of TOTAL where all metrics are combined to yield the combination of "AND" on all metrics
def plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, resdir):
    """Save a grid of pass/fail heatmaps: is each averaged metric within its limits?

    The figure has one row per metric plus a final ``TOTAL`` row that is the
    logical AND of all metrics, and one column per distinct value of
    ``fixedV``. It is written to ``<resdir>/plot_limits_<sbj_group_name>.png``.

    Args:
        df: frame with a ``Mapper`` column, the parameter columns and the
            metric columns; an ``SBJ`` column is required when ``sbj_group`` is
            not empty.
        sbj_group_name: label used in the output file name.
        sbj_group: subjects (``SBJ`` values) to average over; an empty
            sequence means "use every row" (same convention as
            ``plot_results``).
        fixedV, indexV, colV: column name, or list of column names that are
            joined into one combined column.
        target_metrics: non-empty dict ``metric -> (low, high)``; a cell passes
            when ``low <= value <= high``.
        resdir: existing directory that receives the PNG.

    Raises:
        ValueError: if ``target_metrics`` is empty.
    """
    if not target_metrics:
        raise ValueError('target_metrics must contain at least one metric')

    df, fixedV = _handle_list_cols(df, fixedV)
    df, indexV = _handle_list_cols(df, indexV)
    df, colV = _handle_list_cols(df, colV)

    allcols = [fixedV, indexV, colV]
    # Grouping keys that are integer hyper-parameters are cast back to int.
    newtypes = {c: 'int' for c in allcols if c in ('K', 'G', 'R', 'linkbins')}

    if len(sbj_group):
        df_filter = df['SBJ'].isin(sbj_group)
    else:
        df_filter = pd.Series(True, index=df.index)

    # numeric_only=True keeps text columns (subject labels, string parameters)
    # out of the mean; recent pandas raises on them instead of dropping them.
    dff = df[df_filter].groupby(['Mapper'] + allcols).mean(numeric_only=True)
    dff = dff.reset_index().astype(newtypes)

    fixed_vals = sorted(set(df[fixedV].to_list()))
    n_rows = len(target_metrics) + 1  # one row per metric + the TOTAL row
    # squeeze=False keeps `axr` two-dimensional even with a single fixed value.
    f, axr = plt.subplots(n_rows, len(fixed_vals),
                          figsize=(4 * len(fixed_vals), 4 * n_rows),
                          squeeze=False)
    last_col_idx = len(fixed_vals) - 1

    def draw_panel(mask, ax, fixed_val, col_idx, row_label):
        # Draw one 0/1 heatmap; the last column carries the row label instead
        # of a colour bar.
        last_col = col_idx == last_col_idx
        ax = sns.heatmap(mask, vmin=0.0, vmax=1.0, ax=ax, cbar=not last_col)
        ax.set_title('{} == {}'.format(fixedV, fixed_val))
        if last_col:
            ax1 = ax.twinx()
            ax1.set_ylabel(row_label)
            ax1.set_yticks([])

    try:
        combined = {}  # fixed value -> AND of every metric's within-limits mask
        for axc, (target, lims) in zip(axr, target_metrics.items()):
            for col_idx, (K, ax) in enumerate(zip(fixed_vals, axc)):
                df_p = dff[dff[fixedV] == K].pivot(index=indexV, columns=colV, values=target)
                df_wl = (df_p >= lims[0]) & (df_p <= lims[1]) # within limits
                combined[K] = df_wl if K not in combined else combined[K] & df_wl
                draw_panel(df_wl, ax, K, col_idx, target)

        # plot the combined plot (masks were accumulated above, no need to re-pivot)
        for col_idx, (K, ax) in enumerate(zip(fixed_vals, axr[len(target_metrics)])):
            draw_panel(combined[K], ax, K, col_idx, 'TOTAL')

        f.tight_layout()
        f.savefig(os.path.join(resdir, 'plot_limits_{}.png'.format(sbj_group_name)))
    finally:
        # Always release the figure; this function is called in long loops.
        plt.close(f)


# Analysis of subjects

In [17]:
import os
import csv
from tqdm import tqdm

sns.set(rc={'figure.facecolor':'white'})

# Dataset id -> path used as the base directory of that dataset's CSV files.
DATASETS = dict(
    cmev3='/Users/dh/workspace/BDL/demapper/results/cme/mappers_cmev3.json',
    cmev4='/Users/dh/workspace/BDL/demapper/results/cme/mappers_cmev4.json',
)

# Dataset id -> mapper families to analyse, picked from the version suffix of the id.
FILTERS = {}
for k in DATASETS:
    if k.endswith(('cmev1', 'cmev3')):
        FILTERS[k] = ['BDLMapper']
    elif k.endswith(('cmev2', 'cmev4')):
        FILTERS[k] = ['NeuMapper']

def get_plot_columns(mapper_name):
    """Return the ``(fixedV, indexV, colV)`` plot layout for a mapper family.

    Each element is a column name or a list of column names.

    Raises:
        Exception: if ``mapper_name`` is not a known mapper family.
    """
    # Built on every call so callers always receive fresh lists.
    layouts = (
        (('NeuMapper', 'BDLMapper'),
         ('R', 'K', 'G')),  # Most informative
        (('CustomBDLMapper', 'CustomNeuMapper'),
         ('dist', ['G', 'K'], ['preptype', 'R', 'linkbins'])),
        (('CMDSMapperNoKNN',),
         ('dist', ['G', 'linkbins'], ['preptype', 'R'])),
    )
    for families, columns in layouts:
        if mapper_name in families:
            return columns
    raise Exception('Cannot find plot vars (for columns) for Mapper {}'.format(mapper_name))


### Run for one dataset

In [19]:
# Load a single dataset (averaged stats, no per-subject rows) into `df`.
PICKED_DATASET = 'cmev3'

datadir = DATASETS[PICKED_DATASET]
stats_path = os.path.join(datadir, 'compute_stats-averaged.csv')
results_path = os.path.join(datadir, 'degrees_TRs/combined-degrees_TRs.csv')

# Only the first configured mapper family is analysed in this cell.
filter_by = FILTERS[PICKED_DATASET][0]
df, max_values = extract_dataset(stats_path, results_path, filter_by, has_sbj=False)

# Figures go next to the data; a per-mapper subdirectory is used only when the
# dataset has several mapper families configured.
resdir = datadir
if len(FILTERS[PICKED_DATASET]) > 1:
    resdir = os.path.join(datadir, filter_by)
    os.makedirs(resdir, exist_ok=True)

print(max_values)
# Random preview: the displayed rows differ between runs (no seed is set).
df.sample(5)

len(stats):  462
len(results):  462
{}


Unnamed: 0,Mapper,K,R,G,ChangePointsIndicesError,ChangePointsResiduals,assortativity-mean,assortativity-std,coverage_TRs-mean,coverage_TRs-std,...,degree_TRs_avg-mean,degree_TRs_avg-std,degree_TRs_entropy-mean,degree_TRs_entropy-std,distances_entropy-mean,distances_entropy-std,distances_max-mean,distances_max-std,hrfdur_stat-mean,hrfdur_stat-std
159,BDLMapper_24_35_74,24,35,74,54.571429,0.11569,-0.02987,0.125166,0.989566,0.007328,...,677.301944,84.626307,8.298002,0.441748,3.458824,0.180909,20.777778,2.839958,0.221856,0.062032
10,BDLMapper_12_15_82,12,15,82,36.571429,0.029788,0.223989,0.129486,0.999945,0.000232,...,2462.012222,579.722964,7.082568,0.620325,1.770731,0.182131,5.277778,0.95828,0.612937,0.08412
71,BDLMapper_16_30_90,16,30,90,67.0,0.064864,0.32725,0.12673,1.0,0.0,...,23566.45,4030.63546,7.750886,0.601468,1.79811,0.091175,6.0,1.188177,0.656484,0.095661
96,BDLMapper_20_20_50,20,20,50,66.285714,0.054134,-0.138158,0.057809,0.918387,0.057781,...,42.037744,6.210546,5.245928,0.496352,4.058637,0.246051,27.277778,6.42427,0.143414,0.048187
454,BDLMapper_8_35_82,8,35,82,66.714286,0.094767,0.208931,0.115144,0.999618,0.000686,...,3084.927778,264.335606,8.560674,0.363695,2.901683,0.238333,12.722222,2.696524,0.372912,0.078565


In [20]:

# Metrics to draw (one heatmap row each); the *Rev metrics use a log colour scale.
target_metrics = ['ChangePointsIndicesErrorRev', 'ChangePointsResidualsRev', 'coverage_nodes-mean', 'coverage_nodes-std', 'hrfdur_stat-mean', 'hrfdur_stat-std', 'distances_entropy-mean', 'distances_entropy-std']
log_metrics = ['ChangePointsIndicesErrorRev', 'ChangePointsResidualsRev']

# Reciprocal of each error, with 100 as the fallback for non-positive errors.
# NOTE(review): this adds columns to the shared `df` in place.
df['ChangePointsIndicesErrorRev'] = df.apply(
    lambda x: 1.0 / x['ChangePointsIndicesError'] if x['ChangePointsIndicesError'] > 0 else 100,
    axis=1)
df['ChangePointsResidualsRev'] = df.apply(
    lambda x: 1.0 / x['ChangePointsResiduals'] if x['ChangePointsResiduals'] > 0 else 100,
    axis=1)
# df['TransitionBetweenessRev'] = df.apply(lambda x: 1.0 / (x['TransitionBetweeness'] + 1), axis=1)


# Empty subject group = average over every row; two layouts of the same data
# (R fixed with K x G, then K fixed with R x G).
plot_results(df, 'SBJs_RKG', [], 'R', 'K', 'G', target_metrics, log_metrics, resdir=resdir)
plot_results(df, 'SBJs_KRG', [], 'K', 'R', 'G', target_metrics, log_metrics, resdir=resdir)

# for sbj_group_name, sbj_group in tqdm(sbjs_map.items()):
#     plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics, resdir=resdir)
    

In [23]:
# Plot limits
# NOTE(review): in file order this cell cannot run on a fresh kernel:
#   - `circle_loss_threshold` and `sbjs_map` are only assigned in the
#     "Recompute for all analysis" cell further down;
#   - the dataset cell above printed `max_values` as `{}`, in which case
#     `max_values['TransitionBetweeness']` raises KeyError.
# Confirm the intended execution order before relying on this cell.

# metric -> [low, high]; a configuration passes when the averaged metric is within the interval.
# NOTE(review): `target_metrics` was a list in the previous cell and is a dict here.
target_metrics = {
    'CircleLoss': [0, circle_loss_threshold],
    'TransitionBetweeness': [0, max_values['TransitionBetweeness'] * 0.99],
    'coverage_nodes': [0.7, 1.0],
#     'hrfdur_stat': [0.15, 1.0],
    'distances_entropy': [2.0, 10000.0]
}

fixedV, indexV, colV = get_plot_columns(filter_by)

# One figure per subject group.
for sbj_group_name, sbj_group in tqdm(sbjs_map.items()):
    plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, resdir=resdir)
    


100%|█████████████████████████████████████████████████████████████████████████████████████| 20/20 [03:01<00:00,  9.10s/it]


## Recompute for all analysis that we have

In [8]:

# Metrics for plot_results (currently disabled below); the *Rev metrics use a log colour scale.
target_metrics = ['CircleLossRev', 'TransitionBetweenessRev', 'coverage_nodes', 'hrfdur_stat', 'distances_entropy']
log_metrics = ['CircleLossRev', 'TransitionBetweenessRev']

circle_loss_threshold = 2.0

# metric -> [low, high] acceptance interval used by plot_limits.
# The TransitionBetweeness entry is filled in per dataset inside the loop.
target_metrics_limits = {
#     'CircleLoss': [0, circle_loss_threshold],
#     'TransitionBetweeness': [0, max_values['TransitionBetweeness'] * 0.99], # Need the data first, add later
    'coverage_nodes': [0.7, 1.0],
    'distances_entropy': [2.0, 10000.0]
}

#     'hrfdur_stat': [0.15, 1.0], # Skip using hrfdur_stat since it doesn't have an impact at >= 15%

# hightr_w3cv2  ss_w3cv5

# NOTE(review): the DATASETS dict defined earlier in this file only has the
# keys 'cmev3' and 'cmev4', so with the filter below every dataset is skipped;
# the recorded output presumably comes from a different DATASETS — confirm.
for dataset_name in DATASETS.keys():
#     if dataset_name not in ['wnoise_w3cv4', 'wnoise_w3cv5', 'wnoise_w3cv6']:
    if dataset_name not in ['ss_w3cv3']:
        continue
    print('======= Processing', dataset_name)
    datadir = DATASETS[dataset_name]
    stats_path = os.path.join(datadir, 'compute_stats-combined.csv')
    results_path = os.path.join(datadir, 'scores-all.csv')

    # NOTE(review): the broad `except Exception` below turns any failure
    # (including programming errors) into a printed warning without a traceback.
    try:
        for filter_by in FILTERS[dataset_name]:
            # Per-mapper output directory only when several mapper families are configured.
            resdir = datadir
            if len(FILTERS[dataset_name]) > 1:
                resdir = os.path.join(datadir, filter_by)
                os.makedirs(resdir, exist_ok=True)
            df, max_values = extract_dataset(stats_path, results_path, filter_by)
            all_sbjs, sbjs_map = get_all_parameters(df, dataset_name, silent=True)

            fixedV, indexV, colV = get_plot_columns(filter_by)

            # Reciprocal scores for the (disabled) plot_results call.
            df['CircleLossRev'] = df.apply(lambda x: 1.0 / x['CircleLoss'] if x['CircleLoss'] > 0 else 100, axis=1)
            df['TransitionBetweenessRev'] = df.apply(lambda x: 1.0 / (x['TransitionBetweeness'] + 1), axis=1)    
#             for sbj_group_name, sbj_group in tqdm(sbjs_map.items(), desc='plot_results'):
#                 plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics, resdir=resdir)

            # The upper limit sits just below the value substituted for missing scores.
            target_metrics_limits['TransitionBetweeness'] = [0.0, max_values['TransitionBetweeness'] * 0.99]
            for sbj_group_name, sbj_group in tqdm(sbjs_map.items(), desc='plot_limits'):
                plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics_limits, resdir=resdir)
    except Exception as err:
        print("Warning! Didn't process '{}' because:".format(dataset_name))
        print(err)

len(stats):  10080
len(results):  10080


plot_limits: 100%|█████████████████████████████████████████████████████████████████████████| 20/20 [03:43<00:00, 11.15s/it]


### Create links to download and then create a grid

In [43]:

# Dataset id -> path of the dataset on the remote machine.
REMOTE_DATASET_PATHS = {
    'ss_w3cv1': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv1.json',
    'ss_w3cv2': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv2.json',
    'ss_w3cv3': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json',
    'ss_w3cv4': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv4.json',
    'ss_w3cv5': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv5dist.json',
    'ss_w3cv6': '/scratch/groups/saggar/demapper-w3c/mappers_w3cv6dist.json'
}

dataset_id = 'ss_w3cv3'
# NOTE(review): 'ss_w3cv3' is not a key of the DATASETS dict defined earlier in
# this file (only 'cmev3'/'cmev4'), so this lookup raises KeyError unless
# DATASETS was redefined elsewhere — confirm.
local_dataset_path = DATASETS[dataset_id]
# Local folders: downloaded per-mapper plots, and the combined grids built from them.
os.makedirs(os.path.join(local_dataset_path, 'mappers_plots'), exist_ok = True)
os.makedirs(os.path.join(local_dataset_path, 'mappers_combined'), exist_ok = True)
remote_dataset_path = REMOTE_DATASET_PATHS[dataset_id]
# The analysis tree is the dataset path with an extra 'analysis/' directory before 'mappers_w3c...'.
remote_analysis_path = REMOTE_DATASET_PATHS[dataset_id].replace('mappers_w3c', 'analysis/mappers_w3c')



In [130]:
# Mapper configurations to collect: K is fixed, R sweeps 10..40 and G sweeps
# 50..90 (both in steps of 10).
k = 20
mappers = []
for r in range(10, 50, 10):
    for g in range(50, 100, 10):
        mappers.append(dict(
            mapper='BDLMapper_{k}_{r}_{g}'.format(k=k, r=r, g=g),
            R=r,
            G=g))

# Subjects and per-mapper plot files to fetch for each configuration.
sbjs = ['SBJ99', 'SBJ99-83.0', 'SBJ99-75.0', 'SBJ99-50.0']
files = ['plot_task-G.png']


In [131]:
# Print the shell commands to run by hand: first on the remote machine to
# gather every plot into one flat folder, then locally to download that folder.
# Nothing is executed here; `table` records one row per (subject, mapper, file).
print("###### Copy the files to mappers_plots")

print('mkdir {}/mappers_plots'.format(remote_analysis_path))

table = []
for s in sbjs:
    for m in mappers:
        for f in files:
            # Copy so the shared mapper dict is not modified.
            item = m.copy()
            item['sbj'] = s
            item['file'] = f
            table.append(item)
            # Flattened target name: <subject>_<mapper>_<file>, the same
            # pattern get_item_path() rebuilds when reading the plots.
            print('cp {ds_path}/{s}/{m}/{f} {an_path}/mappers_plots/{s}_{m}_{f}'.format(
                ds_path=remote_dataset_path,
                an_path=remote_analysis_path,
                s=s, m=m['mapper'], f=f))
            
print("###### Copy the files locally")
# NOTE(review): the remote user and host are hard-coded in this command.
print('scp -r "hasegan@login.sherlock.stanford.edu:{}/mappers_plots/*" {}'.format(
    remote_analysis_path,
    os.path.join(local_dataset_path, 'mappers_plots/')))

###### Copy the files to mappers_plots
mkdir /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_50/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_50_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_60/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_60_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_70/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_70_plot_task-G.png
cp /scratch/groups/saggar/demapper-w3c/mappers_w3cv3.json/SBJ99/BDLMapper_20_10_80/plot_task-G.png /scratch/groups/saggar/demapper-w3c/analysis/mappers_w3cv3.json/mappers_plots/SBJ99_BDLMapper_20_10_80_plot_task-G.png
cp /s

In [134]:
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

sns.set(style = "whitegrid")

def create_axis(df, axis_hps, idx=0):
    """Enumerate the value combinations of ``axis_hps[idx:]`` that occur in ``df``.

    Returns a list of dicts ``{hyper-parameter: value}``, one per combination,
    nested in the order of ``axis_hps``. Values are sorted, except that the
    string ``'none'`` always comes first. A hyper-parameter with a single
    distinct value (within the rows selected so far) is left out of the dicts.
    """
    if idx == len(axis_hps):
        return [{}]

    hp_name = axis_hps[idx]
    ordered = sorted(df[hp_name].drop_duplicates().to_numpy())
    # Stable second pass: move 'none' to the front, keep the rest sorted.
    ordered.sort(key=lambda value: value != 'none')
    label_entries = len(ordered) > 1

    combos = []
    for hpv in ordered:
        sub_combos = create_axis(df[df[hp_name] == hpv], axis_hps, idx + 1)
        if label_entries:
            for entry in sub_combos:
                entry[hp_name] = hpv
        combos += sub_combos
    return combos


def get_item_path(basedir, item):
    """Build the path ``<basedir>/<sbj>_<mapper>_<file>`` for a plot.

    ``item`` is a frame with ``sbj``, ``mapper`` and ``file`` columns; only its
    first row is used.
    """
    sbj, mapper, fname = (item[column].tolist()[0] for column in ('sbj', 'mapper', 'file'))
    filename = '{s}_{m}_{f}'.format(s=sbj, m=mapper, f=fname)
    return os.path.join(basedir, filename)


def plot_image(img_path, ax):
    """Draw the image stored at ``img_path`` on ``ax`` and return ``ax``."""
    # Image.open keeps the file open until the image is closed. The context
    # manager releases the handle as soon as np.array() has read the pixels,
    # which matters when a grid loads many images.
    with Image.open(img_path) as im:
        img = np.array(im)
    ax.imshow(img)
    return ax

def plot_grid(df, sbj, hps_x, hps_y, hps_fixed_vals, analysis_path, regen=False, text_f=None, show_plot=False):
    """Assemble the per-mapper plots of one subject into a single grid image.

    Columns of the grid enumerate the value combinations of ``hps_x`` and rows
    those of ``hps_y`` (see ``create_axis``). Each cell shows the image found
    under ``<analysis_path>/mappers_plots`` for the matching row of ``df``.
    The grid is saved under ``<analysis_path>/mappers_combined``.

    Args:
        df: one row per available plot, with ``sbj``, ``mapper`` and ``file``
            columns plus one column per hyper-parameter.
        sbj: subject whose plots are assembled.
        hps_x, hps_y: hyper-parameter column names laid out along x and y.
        hps_fixed_vals: dict ``column -> value`` that every plot must match.
        analysis_path: directory containing ``mappers_plots`` (input) and
            ``mappers_combined`` (output, created if needed).
        regen: redraw even when the output file already exists.
        text_f: optional ``text_f(ax, item) -> ax`` hook to annotate a cell
            (for example to set a title) before its image is drawn.
        show_plot: also display the figure after saving it.

    Raises:
        AssertionError: if more than one row of ``df`` matches a grid cell.
    """
    ax_x = create_axis(df, hps_x)
    ax_y = create_axis(df, hps_y)

    x_keys = sorted({k for x in ax_x for k in x})
    y_keys = sorted({k for y in ax_y for k in y})
    fixed_toks = sorted('{}_{}'.format(k, v) for k, v in hps_fixed_vals.items())
    out_dir = os.path.join(analysis_path, 'mappers_combined')
    savefig_path = os.path.join(
        out_dir,
        'comb{}_{}.png'.format(
            sbj,
            '_'.join(['x'] + x_keys + ['y'] + y_keys + fixed_toks)))

    if os.path.isfile(savefig_path) and not regen:
        return

    # Do not depend on another cell having created the output folder.
    os.makedirs(out_dir, exist_ok=True)
    plots_dir = os.path.join(analysis_path, 'mappers_plots')

    ncols = len(ax_x)
    nrows = len(ax_y)
    fsize = 20
    # squeeze=False keeps `axr` two-dimensional when an axis has a single entry
    # (create_axis returns one combination for single-valued parameters).
    fig, axr = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(fsize * (ncols/nrows)*1.1, fsize),
                            squeeze=False)
    try:
        for r, (y, axc) in enumerate(zip(ax_y, axr)):
            for c, (x, ax) in enumerate(zip(ax_x, axc)):
                # Select the single row of df matching this grid cell.
                wanted = {'sbj': sbj, **x, **y, **hps_fixed_vals}
                item = df
                for k, v in wanted.items():
                    item = item[item[k] == v]
                if len(item) == 0:
                    ax.axis('off')
                    continue
                assert len(item) == 1, 'Too many ({}) items found for {}, {}, {}'.format(
                    len(item), x, y, hps_fixed_vals)

                # Bare image cell: no ticks, no frame, no grid lines.
                ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                for side in ('top', 'right', 'bottom', 'left'):
                    ax.spines[side].set_visible(False)
                ax.grid(False)
                # Axis labels only on the outer edge of the grid.
                if r == len(ax_y)-1:
                    ax.set_xlabel(' '.join([str(x[k]) for k in hps_x if k in x]))
                    ax.xaxis.set_label_position("bottom")
                if c == 0:
                    ax.set_ylabel(' '.join([str(y[k]) for k in hps_y if k in y]))

                img_path = get_item_path(plots_dir, item)
                if not os.path.isfile(img_path):
                    print('Skipping as plot not found at {}'.format(img_path))
                    continue

                if text_f:
                    ax = text_f(ax, item)
                ax = plot_image(img_path, ax)
        fig.tight_layout()

        fig.savefig(savefig_path, dpi=150)
        if show_plot:
            plt.show()
    finally:
        # Release the (large) figure even when the assertion above fires.
        plt.close(fig)
    
# plot_grid(df, 'SBJ01', hps_x, hps_y, {'knnparam_gains': 1.0})


In [138]:
# One row per (subject, mapper, file) plot collected in `table`.
# NOTE(review): this rebinds `df`, which earlier cells use for the metrics
# table; cells run after this one see the plot table instead.
df = pd.DataFrame(data=table)

# One combined grid per subject: R along x, G along y; always regenerated.
for sbj in tqdm(sbjs):
    plot_grid(df, sbj, ['R'], ['G'], {}, local_dataset_path, regen=True)


100%|█████████████████████████████████████████████████████████████████████████████████| 4/4 [00:11<00:00,  3.00s/it]


In [139]:
local_dataset_path

'/Users/dh/workspace/BDL/demapper/results/w3c_ss/analysis/mappers_w3cv3.json/'

### Other plots (deprecated)

In [None]:


def create_plot(df, ax, hparam, target, title=None):
    """Box plot with an overlaid swarm plot of ``target`` per value of ``hparam``.

    Args:
        df: frame holding the ``hparam`` and ``target`` columns.
        ax: axes to draw on, passed straight to seaborn (``None`` is accepted).
        hparam: column whose distinct values form the x categories.
        target: column whose distribution is shown.
        title: plot title; a default naming both columns is used when empty.

    Returns:
        The axes that were drawn on.
    """
    if not title:
        title = 'Distribution of {} over {}'.format(target, hparam)

    labels = sorted(df[hparam].unique())
    # One series of target values per category, in label order.
    data = [df.loc[df[hparam] == label, target] for label in labels]

    ax = sns.boxplot(data=data, ax=ax)
    ax = sns.swarmplot(data=data, color=".25", ax=ax, size=1.5)
    ax.set_xticklabels(labels, rotation=10)
    ax.set_xlabel(hparam)
    ax.set_ylabel(target)
    ax.set_title(title)
    ax.grid(alpha=0.4)
    return ax


In [None]:

target = 'ChangePointsIndicesError'

# One figure per mapper hyper-parameter.
for hparam in ['K', 'R', 'G']:
    plt.figure()
    ax = create_plot(df, None, hparam, target)
    plt.show()

In [None]:

target = 'ChangePointsResiduals'

# One figure per mapper hyper-parameter.
for hparam in ['K', 'R', 'G']:
    plt.figure()
    ax = create_plot(df, None, hparam, target)
    plt.show()

In [None]:
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments in current Matplotlib, so
# the 3D projection is requested when the axes are created.
ax = fig.add_subplot(projection='3d')
# Surface through the (R, G, K) triples present in df.
ax.plot_trisurf(df['R'], df['G'], df['K'], cmap=plt.cm.jet, linewidth=0.01)
plt.show()

# Plot of indices error

In [None]:
# NOTE(review): displays the whole frame; df.head() or df.sample(5) keeps the output short.
df

In [None]:
# Sorted 'hrfdur_stat-mean' values of the rows with R == 300.
sorted(df[df['R'] == 300]['hrfdur_stat-mean'].tolist())