In [1]:
%matplotlib inline
# %matplotlib widget

In [2]:
import csv
import os
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sns.set(rc={'figure.facecolor':'white'})

### Helper Functions

In [40]:
# Given a dataset stats and results path, load in the DataFrame, merge and filter

def extract_dataset(stats_path, results_path, dataset_name):
    """Load the stats and scores CSVs, merge them and keep one mapper family.

    Parameters
    ----------
    stats_path : str
        CSV with a ``mapper`` column, ``id*`` columns and numeric statistics.
        Empty cells are read as 0, ``id*`` columns are dropped and ``id0`` is
        stored as the subject (``SBJ``).
    results_path : str
        CSV with ``Mapper`` and ``subject`` columns plus numeric scores; it must
        contain ``CircleLoss`` and ``TransitionBetweeness``.
    dataset_name : str
        Names containing ``'w3cv1'`` keep rows whose ``Mapper`` starts with
        ``'BDLMapper'``; every other name keeps ``'NeuMapper'`` rows.

    Returns
    -------
    (pandas.DataFrame, dict)
        The merged frame with ``Mapper, SBJ, K, R, G`` first (``K``, ``R``, ``G``
        are the 2nd-4th ``_``-separated parts of ``Mapper``), and a dict mapping
        ``'CircleLoss'`` / ``'TransitionBetweeness'`` to the value used to
        replace their invalid entries (1.5 x the largest valid value).

    Raises
    ------
    AssertionError
        If the two files do not have the same number of rows.
    Exception
        If ``CircleLoss`` or ``TransitionBetweeness`` has no valid value.
    """
    filter_by = 'BDLMapper' if 'w3cv1' in dataset_name else 'NeuMapper'

    stats = []
    with open(stats_path, newline='') as f:
        for row in csv.DictReader(f):
            obj = {}
            for k, v in row.items():
                if k == 'mapper':
                    obj[k] = v
                elif k.startswith('id'):
                    continue
                elif v == '':
                    obj[k] = 0
                else:
                    obj[k] = float(v)
            obj['SBJ'] = row['id0'] # TODO! This only works for this type of data (fix!)
            stats.append(obj)
    print('len(stats): ', len(stats))

    # Sentinel stored in place of non-finite scores until they are rescaled below.
    MAX_INT = 100000
    results = []
    with open(results_path, newline='') as f:
        for row in csv.DictReader(f):
            obj = {}
            for k, v in row.items():
                if k == 'Mapper' or k == 'subject':
                    obj[k] = v
                    continue
                # float() accepts every spelling of inf/nan ('Inf', 'inf', '-Inf',
                # 'NaN', 'nan', ...), so test the parsed value instead of the text;
                # otherwise e.g. a lowercase 'inf' would slip through as float('inf').
                value = float(v)
                obj[k] = value if np.isfinite(value) else MAX_INT
            results.append(obj)
    print('len(results): ', len(results))

    assert len(stats) == len(results), \
        'stats has {} rows but results has {}'.format(len(stats), len(results))

    dfs = pd.DataFrame(data=stats)
    dfr = pd.DataFrame(data=results)

    df = pd.merge(dfr, dfs, how='left', left_on=['Mapper', 'subject'], right_on=['mapper', 'SBJ'])

    # Mapper names look like '<family>_<K>_<R>_<G>'.
    mapper_parts = df['Mapper'].str.split('_')
    for name, pos in (('K', 1), ('R', 2), ('G', 3)):
        df[name] = mapper_parts.str[pos].astype(int)

    # fix CircleLoss and TransitionBetweeness: replace the sentinel by 1.5 x the
    # largest valid value so invalid runs stay the worst without dominating the scale.
    max_values = {}
    for colname in ["CircleLoss", "TransitionBetweeness"]:
        valid = df.loc[df[colname] < MAX_INT, colname]
        if valid.empty:
            raise Exception('There are no valid values for {}'.format(colname))
        new_max_loss = float(valid.max()) * 1.5
        max_values[colname] = new_max_loss
        df[colname] = df[colname].where(df[colname] != MAX_INT, new_max_loss)

    main_cols = ['Mapper', 'SBJ', 'K', 'R', 'G']
    other_cols = [c for c in df.columns.tolist() if c not in main_cols and c != 'subject']
    df = df[main_cols + other_cols]

    if filter_by:
        # .copy() so callers can add columns without SettingWithCopyWarning.
        df = df[df['Mapper'].str.startswith(filter_by)].copy()

    return df, max_values

In [12]:
# Extract the subjects for each dataset type and the subject combinations that we should compute.
# for example: all SBJ2* would be SBJ20 and SBJ21
# For new datasets, this has to be changed

def get_all_parameters(df, dataset_name, silent=False):
    """List the subjects in ``df`` and build the named subject groups to plot.

    Every subject becomes its own single-element group when ``dataset_name``
    starts with ``'ss_'``, ``'wnoise_'`` or ``'hightr_'``; any other name yields
    an empty map. ``'ss_'`` and ``'hightr_'`` datasets additionally get the
    hard-coded combined groups below.

    Returns ``(all_sbjs, sbjs_map)`` where ``all_sbjs`` is the list of unique
    ``SBJ`` values and ``sbjs_map`` maps a group name to a list of subjects.
    Unless ``silent`` is true, the subjects and the multi-subject groups are printed.
    """
    all_sbjs = df['SBJ'].unique().tolist()
    if not silent:
        print('Total {} subjects:'.format(len(all_sbjs)))
        for subject in all_sbjs:
            print(subject)

    sbjs_map = {}

    # Identity groups are shared by all three known dataset families.
    if dataset_name.startswith(('ss_', 'wnoise_', 'hightr_')):
        sbjs_map.update((subject, [subject]) for subject in all_sbjs)

    if dataset_name.startswith('ss_'):
        # Combinations for the subsampled data
        sbjs_map['SBJ2x'] = ['SBJ20', 'SBJ21']
        sbjs_map['SBJ4x'] = ['SBJ40', 'SBJ41', 'SBJ42', 'SBJ43']
        for pct in ('50', '75', '83'):
            suffix = '-{}.0'.format(pct)
            sbjs_map['SBJxx-' + pct] = [subject for subject in all_sbjs if subject.endswith(suffix)]
        sbjs_map['SBJxx-99'] = ['SBJ20', 'SBJ40', 'SBJ99']
    elif dataset_name.startswith('hightr_'):
        # Combinations for the subsampled hightr data: one group per 'e<i>v' marker
        for i in (2, 3, 4):
            marker = 'e{}v'.format(i)
            sbjs_map['SBJe{}'.format(i)] = [subject for subject in all_sbjs if marker in subject]

    if not silent:
        print('Extra combinations:')
        for group_name, members in sbjs_map.items():
            if len(members) > 1:
                print(group_name, ':', members)

    return all_sbjs, sbjs_map

In [23]:
from matplotlib.colors import LogNorm, Normalize

def plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics, out_dir=None):
    """Save a grid of heatmaps of the metrics averaged over a group of subjects.

    Each row of the figure is one metric from ``target_metrics`` and each column
    is one value of the ``fixedV`` column (usually ``R``). Inside every subplot
    the x-axis is the ``colV`` column (usually ``G``) and the y-axis is the
    ``indexV`` column (usually ``K``). Metrics listed in ``log_metrics`` use a
    logarithmic colour scale. The colour range of a row is taken from the whole
    of ``df``, not just the group, so figures of different groups are comparable.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``SBJ``, ``Mapper``, ``K``, ``R``, ``G`` and every metric.
    sbj_group_name : str
        Name of the group; used in the output file name.
    sbj_group : list of str
        ``SBJ`` values whose rows are averaged per ``Mapper``.
    fixedV, indexV, colV : str
        Column names used for subplot columns, heatmap rows and heatmap columns.
    target_metrics : list of str
    log_metrics : list of str
    out_dir : str, optional
        Directory for the PNG. Defaults to the notebook-level ``datadir`` global.

    Raises
    ------
    ValueError
        If ``sbj_group`` is empty.

    The figure is written to ``<out_dir>/plot_results_<sbj_group_name>.png`` and closed.
    """
    if len(sbj_group) == 0:
        raise ValueError("Subject group '{}' is empty".format(sbj_group_name))
    save_dir = out_dir if out_dir is not None else datadir

    dff = df[df['SBJ'].isin(sbj_group)]
    # numeric_only: the string columns (SBJ, mapper) cannot be averaged and make
    # .mean() raise on pandas >= 2.
    dff = dff.groupby('Mapper').mean(numeric_only=True)
    # Don't recompute CircleLossRev from the averaged CircleLoss: average over the CircleLossRev!
    dff = dff.astype({'K': 'int', 'G': 'int', 'R': 'int'})

    fixed_vals = sorted(set(df[fixedV].to_list()))
    # squeeze=False keeps `axr` two-dimensional even for a single row or column.
    f, axr = plt.subplots(
        len(target_metrics), len(fixed_vals),
        figsize=(4 * len(fixed_vals), 4 * len(target_metrics)),
        squeeze=False,
    )

    for axc, target in zip(axr, target_metrics):
        # Range over all results, not only this subject group (NaN-safe).
        vmin, vmax = df[target].min(), df[target].max()
        if target in log_metrics:
            scale = {'norm': LogNorm(vmin=vmin, vmax=vmax)}
        else:
            scale = {'vmin': vmin, 'vmax': vmax}

        for col_idx, (fixed_val, ax) in enumerate(zip(fixed_vals, axc)):
            df_p = dff[dff[fixedV] == fixed_val].pivot(index=indexV, columns=colV, values=target)

            last_col = col_idx == len(axc) - 1
            ax = sns.heatmap(df_p, ax=ax, cbar=not last_col, **scale)
            ax.set_title('{} == {}'.format(fixedV, fixed_val))

            if last_col:
                # The last column has no colour bar; label the row with the metric instead.
                ax1 = ax.twinx()
                ax1.set_ylabel(target)
                ax1.set_yticks([])

    f.tight_layout()
    f.savefig(os.path.join(save_dir, 'plot_results_{}.png'.format(sbj_group_name)))
    plt.close(f)
    
    
def plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, out_dir=None):
    """Save a grid of pass/fail heatmaps for metrics checked against intervals.

    Similar to ``plot_results``, but ``target_metrics`` maps a metric name to a
    ``[low, high]`` interval. A cell is 1 when the group-averaged metric lies
    inside the interval (bounds included) and 0 otherwise. A final ``TOTAL`` row
    shows the logical AND over all metrics.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``SBJ``, ``Mapper``, ``K``, ``R``, ``G`` and every metric.
    sbj_group_name : str
        Name of the group; used in the output file name.
    sbj_group : list of str
        ``SBJ`` values whose rows are averaged per ``Mapper``.
    fixedV, indexV, colV : str
        Column names used for subplot columns, heatmap rows and heatmap columns.
    target_metrics : dict
        ``{metric: [low, high]}``.
    out_dir : str, optional
        Directory for the PNG. Defaults to the notebook-level ``datadir`` global.

    Raises
    ------
    ValueError
        If ``sbj_group`` is empty.

    The figure is written to ``<out_dir>/plot_limits_<sbj_group_name>.png`` and closed.
    """
    if len(sbj_group) == 0:
        raise ValueError("Subject group '{}' is empty".format(sbj_group_name))
    save_dir = out_dir if out_dir is not None else datadir

    dff = df[df['SBJ'].isin(sbj_group)]
    # numeric_only: the string columns (SBJ, mapper) cannot be averaged and make
    # .mean() raise on pandas >= 2.
    dff = dff.groupby('Mapper').mean(numeric_only=True)
    dff = dff.astype({'K': 'int', 'G': 'int', 'R': 'int'})

    fixed_vals = sorted(set(df[fixedV].to_list()))
    row_labels = list(target_metrics) + ['TOTAL']
    # squeeze=False keeps `axr` two-dimensional even for a single column.
    f, axr = plt.subplots(
        len(row_labels), len(fixed_vals),
        figsize=(4 * len(fixed_vals), 4 * len(row_labels)),
        squeeze=False,
    )

    for col_idx, fixed_val in enumerate(fixed_vals):
        last_col = col_idx == len(fixed_vals) - 1
        subset = dff[dff[fixedV] == fixed_val]

        # One boolean "within limits" table per metric, computed once and reused
        # for the combined TOTAL row.
        masks = []
        df_comb = None
        for target, lims in target_metrics.items():
            df_p = subset.pivot(index=indexV, columns=colV, values=target)
            df_wl = (df_p >= lims[0]) & (df_p <= lims[1])
            masks.append(df_wl)
            df_comb = df_wl if df_comb is None else df_comb & df_wl
        masks.append(df_comb)

        for row_idx, (label, mask) in enumerate(zip(row_labels, masks)):
            ax = sns.heatmap(mask, vmin=0.0, vmax=1.0, ax=axr[row_idx][col_idx], cbar=not last_col)
            ax.set_title('{} == {}'.format(fixedV, fixed_val))
            if last_col:
                # The last column has no colour bar; label the row instead.
                ax1 = ax.twinx()
                ax1.set_ylabel(label)
                ax1.set_yticks([])

    f.tight_layout()
    f.savefig(os.path.join(save_dir, 'plot_limits_{}.png'.format(sbj_group_name)))
    plt.close(f)


# Analysis of subjects

In [7]:
import os
import csv
from tqdm import tqdm

sns.set(rc={'figure.facecolor':'white'})

# Root folder holding the demapper results. Set the DEMAPPER_RESULTS_ROOT
# environment variable to run the notebook on another machine; the default is the
# original author's location.
RESULTS_ROOT = os.environ.get('DEMAPPER_RESULTS_ROOT', '/Users/dh/workspace/BDL/demapper/results')

# '<kind>_<mappers version>' -> analysis directory (with trailing slash) containing
# 'compute_stats-combined.csv' and 'scores-all.csv'.
DATASETS = {
    '{}_{}'.format(kind, version): '{}/w3c_{}/analysis/mappers_{}.json/'.format(
        RESULTS_ROOT.rstrip('/'), kind, version)
    for version in ('w3cv1', 'w3cv2')
    for kind in ('ss', 'wnoise', 'hightr')
}

# Upper bound of the accepted CircleLoss interval used by the plot_limits cells.
circle_loss_threshold = 10.0

### Run for one dataset

In [35]:
# Key of DATASETS analysed by the single-dataset cells below.
PICKED_DATASET = 'hightr_w3cv1'

# NOTE: `datadir` is also read as a global by plot_results / plot_limits when
# they save their figures.
datadir = DATASETS[PICKED_DATASET]
stats_path = os.path.join(datadir, 'compute_stats-combined.csv')
results_path = os.path.join(datadir, 'scores-all.csv')

# Merged, mapper-filtered frame plus the replacement values used for invalid
# CircleLoss / TransitionBetweeness entries.
df, max_values = extract_dataset(stats_path, results_path, PICKED_DATASET)


# Subject list and the {group name: [subjects]} map that drives the plots.
all_sbjs, sbjs_map = get_all_parameters(df, PICKED_DATASET)

print(max_values)
df.head()

len(stats):  2772
len(results):  2772
Total 6 subjects:
SBJe2v0
SBJe2v1
SBJe3v0
SBJe3v1
SBJe4v0
SBJe4v1
Extra combinations:
SBJe2 : ['SBJe2v0', 'SBJe2v1']
SBJe3 : ['SBJe3v0', 'SBJe3v1']
SBJe4 : ['SBJe4v0', 'SBJe4v1']
{'CircleLoss': 686.7, 'TransitionBetweeness': 4.5}


Unnamed: 0,Mapper,SBJ,K,R,G,CircleLoss,TransitionBetweeness,mapper,coverage_nodes,coverage_TRs,hrfdur_stat,distances_max,distances_entropy,assortativity,degree_TRs_avg,degree_TRs_entropy
0,BDLMapper_12_10_50,SBJe2v0,12,10,50,40.367647,4.5,BDLMapper_12_10_50,1.0,1.0,0.75,10.0,3.14311,0.398689,25.4424,2.94179
1,BDLMapper_12_10_50,SBJe2v1,12,10,50,40.0,4.5,BDLMapper_12_10_50,1.0,0.9988,0.716667,10.0,3.14311,0.297631,23.7551,3.4306
2,BDLMapper_12_10_50,SBJe3v0,12,10,50,686.7,4.5,BDLMapper_12_10_50,0.84375,0.595324,0.59375,8.0,2.92892,0.118397,7.95144,2.70063
3,BDLMapper_12_10_50,SBJe3v1,12,10,50,686.7,4.5,BDLMapper_12_10_50,0.8,0.591727,0.68,14.0,3.66479,0.385375,14.1295,2.90305
4,BDLMapper_12_10_50,SBJe4v0,12,10,50,686.7,4.5,BDLMapper_12_10_50,0.609756,0.551559,0.609756,7.0,2.52094,0.075265,7.03357,2.63697


In [None]:

# The two reversed scores are drawn on a log colour scale; the remaining metrics are linear.
log_metrics = ['CircleLossRev', 'TransitionBetweenessRev']
target_metrics = log_metrics + ['coverage_nodes', 'hrfdur_stat', 'distances_entropy']

# Reversed scores: reciprocal of the loss (100 when the loss is not positive)
# and reciprocal of (betweenness + 1).
df['CircleLossRev'] = df['CircleLoss'].apply(lambda loss: 1.0 / loss if loss > 0 else 100)
df['TransitionBetweenessRev'] = df['TransitionBetweeness'].apply(lambda tb: 1.0 / (tb + 1))

fixedV, indexV, colV = 'R', 'K', 'G' # Most informative

# One figure per subject group.
for sbj_group_name, sbj_group in tqdm(sbjs_map.items()):
    plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics)
    

In [None]:
# Plot limits

# Accepted [low, high] interval per metric (bounds included by plot_limits).
# NOTE: this rebinds `target_metrics` from the list used by the plot_results
# cell to a dict.
target_metrics = {
    'CircleLoss': [0, circle_loss_threshold],
    # Just below the replacement value assigned to invalid entries by extract_dataset.
    'TransitionBetweeness': [0, max_values['TransitionBetweeness'] * 0.99],
    'coverage_nodes': [0.7, 1.0],
#     'hrfdur_stat': [0.15, 1.0],
    'distances_entropy': [2.0, 10000.0]
}

fixedV, indexV, colV = 'R', 'K', 'G' # Most informative

# One figure per subject group.
for sbj_group_name, sbj_group in tqdm(sbjs_map.items()):
    plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics)
    


## Recompute for all analysis that we have

In [46]:

# The two reversed scores are drawn on a log colour scale; the remaining metrics are linear.
log_metrics = ['CircleLossRev', 'TransitionBetweenessRev']
target_metrics = log_metrics + ['coverage_nodes', 'hrfdur_stat', 'distances_entropy']

circle_loss_threshold = 2.0

# Accepted [low, high] interval per metric.
# - 'TransitionBetweeness' needs the data first, so it is added per dataset inside the loop.
# - 'hrfdur_stat': [0.15, 1.0] is skipped since it doesn't have an impact at >= 15%.
target_metrics_limits = {
    'CircleLoss': [0, circle_loss_threshold],
    'coverage_nodes': [0.7, 1.0],
    'distances_entropy': [2.0, 10000.0]
}

fixedV, indexV, colV = 'R', 'K', 'G' # Most informative


# `datadir` is (re)bound here on purpose: the plotting helpers read it as a global
# to decide where the figures are saved.
for dataset_name, datadir in DATASETS.items():
    print('======= Processing', dataset_name)
    stats_path = os.path.join(datadir, 'compute_stats-combined.csv')
    results_path = os.path.join(datadir, 'scores-all.csv')

    try:
        df, max_values = extract_dataset(stats_path, results_path, dataset_name)
        all_sbjs, sbjs_map = get_all_parameters(df, dataset_name, silent=True)

        # Reversed scores: reciprocal of the loss (100 when the loss is not
        # positive) and reciprocal of (betweenness + 1).
        df['CircleLossRev'] = df['CircleLoss'].apply(lambda loss: 1.0 / loss if loss > 0 else 100)
        df['TransitionBetweenessRev'] = df['TransitionBetweeness'].apply(lambda tb: 1.0 / (tb + 1))
        for sbj_group_name, sbj_group in tqdm(sbjs_map.items(), desc='plot_results'):
            plot_results(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics, log_metrics)

        # Just below the replacement value assigned to invalid entries for this dataset.
        target_metrics_limits['TransitionBetweeness'] = [0.0, max_values['TransitionBetweeness'] * 0.99]
        for sbj_group_name, sbj_group in tqdm(sbjs_map.items(), desc='plot_limits'):
            plot_limits(df, sbj_group_name, sbj_group, fixedV, indexV, colV, target_metrics_limits)
    except Exception as err:
        # Best effort: report the failing dataset and carry on with the next one.
        print("Warning! Didn't process '{}' because:".format(dataset_name))
        print(err)

len(stats):  7392
len(results):  7392


plot_results: 100%|███████████████████████████████████████████████████████████████████| 22/22 [06:05<00:00, 16.59s/it]
plot_limits: 100%|████████████████████████████████████████████████████████████████████| 22/22 [05:13<00:00, 14.24s/it]


len(stats):  2772
len(results):  2772


plot_results: 100%|█████████████████████████████████████████████████████████████████████| 6/6 [01:53<00:00, 18.89s/it]
plot_limits: 100%|██████████████████████████████████████████████████████████████████████| 6/6 [01:19<00:00, 13.26s/it]


len(stats):  2772
len(results):  2772


plot_results: 100%|█████████████████████████████████████████████████████████████████████| 9/9 [02:46<00:00, 18.51s/it]
plot_limits: 100%|██████████████████████████████████████████████████████████████████████| 9/9 [02:29<00:00, 16.58s/it]


len(stats):  14784
len(results):  14784


plot_results: 100%|███████████████████████████████████████████████████████████████████| 22/22 [08:06<00:00, 22.12s/it]
plot_limits: 100%|████████████████████████████████████████████████████████████████████| 22/22 [07:23<00:00, 20.16s/it]


len(stats):  5544
len(results):  5544


plot_results: 100%|█████████████████████████████████████████████████████████████████████| 6/6 [02:05<00:00, 20.84s/it]
plot_limits: 100%|██████████████████████████████████████████████████████████████████████| 6/6 [01:47<00:00, 17.90s/it]

len(stats):  5544
len(results):  5544
There are no valid values for TransitionBetweeness





### Other plots (deprecated)

In [None]:


def create_plot(df, ax, hparam, target, title=None):
    """Draw the distribution of ``target`` for every value of ``hparam``.

    One box (with the individual points overlaid as a swarm) is drawn per unique
    value of ``df[hparam]``, in sorted order. ``ax`` is forwarded to seaborn as
    is. When ``title`` is empty or None a default title is generated. Returns
    the axes the plot was drawn on.
    """
    labels = sorted(df[hparam].unique())
    per_label = [df.loc[df[hparam] == label, target] for label in labels]

    ax = sns.boxplot(data=per_label, ax=ax)
    ax = sns.swarmplot(data=per_label, color=".25", ax=ax, size=1.5)

    heading = title if title else 'Distribution of {} over {}'.format(target, hparam)
    ax.set_xticklabels(labels, rotation=10)
    ax.set_xlabel(hparam)
    ax.set_ylabel(target)
    ax.set_title(heading)
    ax.grid(alpha=0.4)
    return ax


In [None]:

target = 'ChangePointsIndicesError'

# One figure per hyper-parameter, in the order K, R, G.
for hparam in ('K', 'R', 'G'):
    plt.figure()
    ax = create_plot(df, None, hparam, target)
    plt.show()

In [None]:

target = 'ChangePointsResiduals'

# One figure per hyper-parameter, in the order K, R, G.
for hparam in ('K', 'R', 'G'):
    plt.figure()
    ax = create_plot(df, None, hparam, target)
    plt.show()

In [None]:
# Surface spanned by the (R, G, K) combinations present in `df`.
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6);
# create the 3D axes explicitly instead.
ax = fig.add_subplot(projection='3d')
ax.plot_trisurf(df['R'], df['G'], df['K'], cmap=plt.cm.jet, linewidth=0.01)
plt.show()

# Plot of indices error

In [None]:
# Display the frame left in `df` by the most recently executed loading cell.
df

In [None]:
# Sorted 'hrfdur_stat-mean' values of the rows with R == 300.
# NOTE(review): the frame shown by the "Run for one dataset" cell has a column
# named 'hrfdur_stat', not 'hrfdur_stat-mean' - confirm this column exists in `df`.
sorted(df[df['R'] == 300]['hrfdur_stat-mean'].tolist())