# For CME, plot one image per mapper: a grid with one panel per subject

In [11]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Mapper results are read from main_path; composed figures go below res_path.
mapper_config = 'mappers_cmev4_euc.json'
main_path = '/scratch/groups/saggar/demapper-cme/{}'.format(mapper_config)
res_path = '/scratch/groups/saggar/demapper-cme/analysis/ch8_{}/plot_task-grids'.format(mapper_config)
os.makedirs(res_path, exist_ok=True)

# File name of the image looked up as <main_path>/<subject>/<mapper>/<fname>.
fname = 'plot_task-CME.png'
# Subjects are the entries of main_path whose name starts with 'SBJ'.
sbjs = sorted(s for s in os.listdir(main_path) if s.startswith('SBJ'))
# Mapper names are taken from the first subject's directory only.
mappers = sorted(m for m in os.listdir(os.path.join(main_path, sbjs[0])) if 'Mapper' in m)



In [12]:
import math
from PIL import Image


def plot_image(img_path, ax):
    """Draw the image stored at ``img_path`` onto the axes ``ax``.

    Args:
        img_path: Path of an image file that PIL can open.
        ax: Object exposing ``imshow`` (a matplotlib axes in this notebook).

    Returns:
        The same ``ax`` that was passed in.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of relying on ``del`` and the
    # garbage collector to do so.
    with Image.open(img_path) as im:
        img = np.array(im)
    ax.imshow(img)
    return ax


def process_mapper(mapper):
    """Save one figure for ``mapper`` showing every subject's image in a grid.

    For each entry of the module-level ``sbjs`` list the image
    ``<main_path>/<sbj>/<mapper>/<fname>`` is drawn in its own panel, five
    panels per row. The figure is written to ``<res_path>/<mapper>-<fname>``;
    the function returns immediately when that file already exists.

    Args:
        mapper: Name of the mapper directory found under each subject.

    Raises:
        ValueError: If ``sbjs`` is empty, because no grid can be laid out.
    """
    savepath = os.path.join(res_path, '{}-{}'.format(mapper, fname))
    if os.path.isfile(savepath):
        return
    if not sbjs:
        raise ValueError('no subjects found under {}'.format(main_path))

    ncols = 5
    nrows = math.ceil(len(sbjs) / ncols)
    fsize = 20

    # squeeze=False keeps ``axr`` two-dimensional even when all subjects fit
    # on a single row; a squeezed 1-D array could not be walked as a grid.
    fig, axr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                            figsize=(fsize * (ncols / nrows) * 1.1, fsize))
    try:
        # ``axr.flat`` walks the grid row by row, so panel ``index`` shows
        # subject ``index``.
        for index, ax in enumerate(axr.flat):
            # Disable ax defaults
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_visible(False)
            ax.grid(False)

            # Panels beyond the last subject stay blank.
            if index >= len(sbjs):
                continue
            sbj = sbjs[index]
            plot_image(os.path.join(main_path, sbj, mapper, fname), ax)
            ax.set_title(sbj)

        fig.suptitle(mapper)
        # Leave room at the top of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(savepath, dpi=100)
    finally:
        # Close the figure even if an image fails to load or save, so that
        # figures do not accumulate across the mapper loop.
        plt.close(fig)


In [13]:
# Run a single mapper: the first one in the sorted ``mappers`` list.
mapper = mappers[0]
process_mapper(mapper)

In [14]:
from tqdm import tqdm

# Build the subject grid for every mapper, with a tqdm progress bar.
# process_mapper returns early when its output file already exists, so
# re-running this loop only produces the figures that are still missing.
for mapper in tqdm(mappers):
    process_mapper(mapper)

100%|██████████| 539/539 [24:57<00:00,  2.78s/it]  


# For CME, for selected mappers and subjects, plot a grid

In [3]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Mapper results are read from main_path; composed figures go below res_path.
mapper_config = 'mappers_cmev3_disp.json'
main_path = '/scratch/groups/saggar/demapper-cme/{}'.format(mapper_config)
res_path = '/scratch/groups/saggar/demapper-cme/analysis/{}/plot_task-grids'.format(mapper_config)
os.makedirs(res_path, exist_ok=True)

# File name of the image looked up as <main_path>/<subject>/<mapper>/<fname>.
fname = 'plot_task-CME.png'
# Subjects are the entries of main_path whose name starts with 'SBJ'.
sbjs = sorted(s for s in os.listdir(main_path) if s.startswith('SBJ'))
# Mapper names are taken from the first subject's directory only.
mappers = sorted(m for m in os.listdir(os.path.join(main_path, sbjs[0])) if 'Mapper' in m)


In [46]:
import math
from PIL import Image


def plot_image(img_path, ax):
    """Draw the image stored at ``img_path`` onto the axes ``ax``.

    Args:
        img_path: Path of an image file that PIL can open.
        ax: Object exposing ``imshow`` (a matplotlib axes in this notebook).

    Returns:
        The same ``ax`` that was passed in.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of relying on ``del`` and the
    # garbage collector to do so.
    with Image.open(img_path) as im:
        img = np.array(im)
    ax.imshow(img)
    return ax


def process_mapper(sel_sbjs, sel_mappers, index=0, show_grid=False):
    """Save a subjects-by-mappers grid of images as a single figure.

    Row ``r`` shows subject ``sel_sbjs[r]`` and column ``c`` shows mapper
    ``sel_mappers[c]``; each panel displays
    ``<main_path>/<sbj>/<mapper>/<fname>``. The figure is written below
    ``res_path`` as
    ``<sbj1>_<sbj2>_...-<number of mappers>Mappers_v<index>-<fname>`` and
    overwrites any existing file of that name.

    Args:
        sel_sbjs: Subject names, one per grid row.
        sel_mappers: Mapper names, one per grid column.
        index: Version number embedded in the output file name.
        show_grid: When true, keep each panel's border and draw it with a
            sparse dotted line; otherwise hide the borders.

    Raises:
        ValueError: If ``sel_sbjs`` or ``sel_mappers`` is empty.
    """
    if len(sel_sbjs) == 0 or len(sel_mappers) == 0:
        raise ValueError('sel_sbjs and sel_mappers must both be non-empty')

    savepath = os.path.join(res_path, '{}-{}Mappers_v{}-{}'.format(
        '_'.join(sel_sbjs), len(sel_mappers), index, fname))

    ncols = len(sel_mappers)
    nrows = len(sel_sbjs)
    fsize = 20

    # squeeze=False keeps ``axr`` two-dimensional for a single subject or a
    # single mapper, so the row/column iteration below always works.
    fig, axr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                            figsize=(fsize * (ncols / nrows) * 1.35, fsize))

    linestyle = (0, (1, 10))  # 1pt dots separated by 10pt gaps
    sides = ('top', 'right', 'bottom', 'left')
    try:
        for sbj, axc in zip(sel_sbjs, axr):
            for mapper, ax in zip(sel_mappers, axc):
                # Disable ax defaults
                ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                for side in sides:
                    if show_grid:
                        ax.spines[side].set_linestyle(linestyle)
                    else:
                        ax.spines[side].set_visible(False)
                ax.grid(False)

                # Every (subject, mapper) pair is drawn. The previous guard
                # compared a running image counter with the total number of
                # subjects and silently left later panels empty.
                plot_image(os.path.join(main_path, sbj, mapper, fname), ax)

        fig.tight_layout()
        fig.savefig(savepath, dpi=150)
    finally:
        # Close the figure even if an image fails to load or save.
        plt.close(fig)


In [47]:
# First numeric field substituted into every mapper-name template below.
K = '12'

# Mapper-name templates; '{}' is replaced by K on the following statement.
sel_mappers = ['BDLMapper_{}_20_40', 'BDLMapper_{}_30_40', 'BDLMapper_{}_30_60',
               'BDLMapper_{}_10_60', 'BDLMapper_{}_10_90']
sel_mappers = [m.format(K) for m in sel_mappers]
# Subjects shown as grid rows; the mappers above become the columns.
sel_sbjs = ['SBJ01', 'SBJ07', 'SBJ14']

process_mapper(sel_sbjs, sel_mappers)

### For W3C

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Alternative (cluster) configuration, kept for reference:
# mapper_config = 'mappers_w3cv4_reg.json'
# main_path = '/scratch/groups/saggar/demapper-w3c/{}'.format(mapper_config)
# res_path = '/scratch/groups/saggar/demapper-w3c/analysis/{}/plot_task-grids'.format(mapper_config)

# Mapper results are read from main_path; composed figures go below res_path.
mapper_config = 'mappers_w3cv8embed_disp'
main_path = '/Users/dh/workspace/BDL/demapper/results/w3c_hightr3/{}/'.format(mapper_config)
res_path = '/Users/dh/workspace/BDL/demapper/results/w3c_hightr3/analysis/{}/'.format(mapper_config)
os.makedirs(res_path, exist_ok=True)

# File name of the image looked up as <main_path>/<subject>/<mapper>/<fname>.
fname = 'plot_task-G.png'
# Subjects are the entries of main_path whose name starts with 'SBJ'.
sbjs = sorted(s for s in os.listdir(main_path) if s.startswith('SBJ'))
# One record per mapper directory of the first subject: the directory name
# followed by the integers parsed from its underscore-separated fields after
# the first one. Records sort by directory name first.
mappers = sorted(
    [m, *(int(k) for k in m.split('_')[1:])]
    for m in os.listdir(os.path.join(main_path, sbjs[0]))
    if 'Mapper' in m
)



In [2]:
import pandas as pd

# Tabulate the mapper records: the directory name plus its parsed integer
# fields. Assumes every name carries exactly three numeric fields (K, R, G);
# otherwise the column count will not match -- TODO confirm.
df = pd.DataFrame(data=mappers, columns=['mapper', 'K', 'R', 'G'])

In [9]:
import math
from PIL import Image
import seaborn as sns

# Use seaborn's "whitegrid" style for the figures created below; the plotting
# code in this cell still switches the grid off on every axes it draws.
sns.set(style = "whitegrid")

def plot_image(img_path, ax):
    """Draw the image stored at ``img_path`` onto the axes ``ax``.

    Args:
        img_path: Path of an image file that PIL can open.
        ax: Object exposing ``imshow`` (a matplotlib axes in this notebook).

    Returns:
        The same ``ax`` that was passed in.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of relying on ``del`` and the
    # garbage collector to do so.
    with Image.open(img_path) as im:
        img = np.array(im)
    ax.imshow(img)
    return ax

def filter_dataframe(df, filters):
    """Return the rows of ``df`` that satisfy every column filter.

    A row is kept when, for each ``key`` in ``filters``, its value in column
    ``key`` is one of ``filters[key]`` (OR within a column, AND across
    columns).

    Args:
        df: DataFrame to filter.
        filters: Mapping of column name to a list of accepted values. An
            empty mapping imposes no constraint; an empty list of values
            matches no row.

    Returns:
        A new DataFrame with the matching rows; ``df`` is not modified.

    Note:
        Membership is tested with ``Series.isin``, so a NaN listed among the
        accepted values matches NaN cells (a plain ``==`` would not).
    """
    df_filter = None
    for key, vals in filters.items():
        key_mask = df[key].isin(list(vals))
        df_filter = key_mask if df_filter is None else (key_mask & df_filter)

    if df_filter is None:
        # No filters given: every row qualifies.
        return df.copy()
    return df[df_filter]

def mappers_grid(df, sbj, fname, x_field, x_vals, y_field, y_vals, prefix=''):
    """Save a grid of one subject's mapper images arranged by two parameters.

    Column ``c`` corresponds to ``x_vals[c]`` and row ``r`` to ``y_vals[r]``.
    For each panel the single row of ``df`` whose ``x_field`` and ``y_field``
    columns equal those values selects the mapper, and the image
    ``<main_path>/<sbj>/<mapper>/<fname>`` is drawn. The figure is written to
    ``<res_path>/<sbj>-<prefix><fname>``, overwriting any existing file.

    Args:
        df: DataFrame with a ``'mapper'`` column plus the two field columns.
        sbj: Subject directory name.
        fname: Image file name inside each mapper directory.
        x_field: Column of ``df`` varied along the grid columns.
        x_vals: Values of ``x_field``, one per grid column.
        y_field: Column of ``df`` varied along the grid rows.
        y_vals: Values of ``y_field``, one per grid row.
        prefix: Text inserted before ``fname`` in the output file name.

    Raises:
        ValueError: If ``x_vals`` or ``y_vals`` is empty, or if a panel's
            field values do not select exactly one row of ``df``.
    """
    # ``len`` instead of truthiness: the callers pass pandas Series, whose
    # truth value is ambiguous.
    if len(x_vals) == 0 or len(y_vals) == 0:
        raise ValueError('x_vals and y_vals must both be non-empty')

    savepath = os.path.join(res_path, '{}-{}{}'.format(sbj, prefix, fname))

    ncols = len(x_vals)
    nrows = len(y_vals)
    fsize = 20

    # squeeze=False keeps ``axr`` two-dimensional for single-row or
    # single-column grids, so the iteration below always works.
    fig, axr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                            figsize=(fsize * (ncols / nrows) * 1.1, fsize))
    try:
        for r, (yval, axc) in enumerate(zip(y_vals, axr)):
            for c, (xval, ax) in enumerate(zip(x_vals, axc)):
                # Disable ax defaults
                ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                for side in ('top', 'right', 'bottom', 'left'):
                    ax.spines[side].set_visible(False)
                ax.grid(False)

                fdf = filter_dataframe(df, {x_field: [xval], y_field: [yval]})
                # Raised explicitly rather than asserted, so the check also
                # runs when Python is started with -O.
                if len(fdf) != 1:
                    raise ValueError(
                        'expected exactly one mapper for {}={}, {}={}; found {}'.format(
                            x_field, xval, y_field, yval, len(fdf)))
                mapper = fdf['mapper'].to_numpy()[0]

                plot_image(os.path.join(main_path, sbj, mapper, fname), ax)
                # Label only the outer edge: y values on the first column,
                # x values on the last row.
                if c == 0:
                    ax.set_ylabel(str(yval))
                if r == nrows - 1:
                    ax.set_xlabel(str(xval))

        # NOTE(review): the title is the mapper of the last (bottom-right)
        # panel only, although the grid spans many mappers -- confirm this
        # is the intended title.
        fig.suptitle(mapper)
        # Leave room at the top of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(savepath, dpi=100)
    finally:
        # Close the figure even if a lookup, load or save fails.
        plt.close(fig)


In [13]:

# One grid per K value: G along the columns (ascending), R along the rows
# (descending from top to bottom).
for KVal in [4, 6]:
    df1 = df[df['K'] == KVal]
    # Sort numerically instead of relying on first-appearance order: the rows
    # of ``df`` follow the string-sorted directory names, in which '10' comes
    # before '5', so the axes could otherwise end up out of numeric order.
    RVals = sorted(df1['R'].unique(), reverse=True)
    GVals = sorted(df1['G'].unique())

    mappers_grid(df1, sbjs[0], fname, 'G', GVals, 'R', RVals, prefix='K{}-'.format(KVal))

### Old code

In [4]:
import math
from PIL import Image


def plot_image(img_path, ax):
    """Draw the image stored at ``img_path`` onto the axes ``ax``.

    Args:
        img_path: Path of an image file that PIL can open.
        ax: Object exposing ``imshow`` (a matplotlib axes in this notebook).

    Returns:
        The same ``ax`` that was passed in.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the array, instead of relying on ``del`` and the
    # garbage collector to do so.
    with Image.open(img_path) as im:
        img = np.array(im)
    ax.imshow(img)
    return ax


def process_mappers(picked_mappers):
    """Save one subject-grid figure for each mapper in ``picked_mappers``.

    For every mapper, the image ``<main_path>/<sbj>/<mapper>/<fname>`` of
    each entry of the module-level ``sbjs`` list is drawn in its own panel,
    five panels per row, and the figure is written to
    ``<res_path>/<mapper>-<fname>``. Mappers whose output file already
    exists are skipped.

    Args:
        picked_mappers: A mapper name, or an iterable of mapper names.
            (Previously this argument was ignored and the notebook-global
            ``mapper`` was used instead.)

    Raises:
        ValueError: If ``sbjs`` is empty, because no grid can be laid out.
    """
    # Accept a bare name as well, so a single mapper can be passed directly.
    if isinstance(picked_mappers, str):
        mapper_names = [picked_mappers]
    else:
        mapper_names = picked_mappers
    if not sbjs:
        raise ValueError('no subjects found under {}'.format(main_path))

    ncols = 5
    nrows = math.ceil(len(sbjs) / ncols)
    fsize = 20

    for mapper in mapper_names:
        savepath = os.path.join(res_path, '{}-{}'.format(mapper, fname))
        if os.path.isfile(savepath):
            continue

        # squeeze=False keeps ``axr`` two-dimensional even when all subjects
        # fit on a single row.
        fig, axr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                                figsize=(fsize * (ncols / nrows) * 1.1, fsize))
        try:
            # ``axr.flat`` walks the grid row by row, so panel ``index``
            # shows subject ``index``.
            for index, ax in enumerate(axr.flat):
                # Disable ax defaults
                ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                for side in ('top', 'right', 'bottom', 'left'):
                    ax.spines[side].set_visible(False)
                ax.grid(False)

                # Panels beyond the last subject stay blank.
                if index >= len(sbjs):
                    continue
                sbj = sbjs[index]
                plot_image(os.path.join(main_path, sbj, mapper, fname), ax)
                ax.set_title(sbj)

            fig.suptitle(mapper)
            # Leave room at the top of the figure for the suptitle.
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            fig.savefig(savepath, dpi=100)
        finally:
            # Close the figure even if an image fails to load or save.
            plt.close(fig)


In [18]:
# Run a single mapper: the first entry of ``mappers``.
# NOTE(review): this calls process_mapper, not the process_mappers function
# defined in this section -- confirm which definition is meant to run here.
mapper = mappers[0]
process_mapper(mapper)

In [19]:
from tqdm import tqdm

# Run every entry of ``mappers`` with a tqdm progress bar.
# NOTE(review): this calls process_mapper, not the process_mappers function
# defined in this section -- confirm which definition is meant to run here.
for mapper in tqdm(mappers):
    process_mapper(mapper)