In [1]:
import sys, os, time, shutil
from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:98% !important; }</style>"))
os.environ['FOUNDATION_RUN_MODE'] = 'jupyter'
# %load_ext autoreload
# %autoreload 2
# %pdb
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distrib
import torch.multiprocessing as mp
import torchvision.models
from torch.utils.data import Dataset, DataLoader, TensorDataset

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.preprocessing import normalize

import gym
import inspect
import yaml
import numpy as np
import pickle

import humpack as hp

#%matplotlib tk
import matplotlib.pyplot as plt
#plt.switch_backend('Qt5Agg') #('Qt5Agg')
import foundation as fd
from foundation import models
from foundation import util
from foundation import train as trn
#from foundation.util import replicate, Cloner

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.cm
from matplotlib import animation
import matplotlib as mpl
from matplotlib import cm
# mpl.rc('image', cmap='gray')

# import gpumap
# import umap, shap
# import umap.plot
from adjustText import adjust_text

from matplotlib.patches import Patch
from matplotlib.lines import Line2D

from sklearn.decomposition import PCA
import sklearn.datasets

# import gpumap


import visualizations as viz_util
import project as proj
# import pointnets as ptn

%matplotlib notebook

np.set_printoptions(linewidth=120)



In [2]:
%matplotlib notebook

In [3]:
root = os.path.join(os.environ['FOUNDATION_SAVE_DIR'], '..', 'results', 'metrics')
fnames = set(os.listdir(root))
len(fnames)

86

In [4]:
def organize(raw, runs=None, **common):
    raw = raw.split('\n')
    if runs is None:
        runs = util.Table()
    for line in raw:
        n,t,r = line.split('\t')
        runs.new(name=n, arch=t, rname=r, **common)
    return runs

def save_fig(name):
    plt.savefig(f'figures/{name}.png')
    plt.savefig(f'figures/{name}.pdf')
    print(name, 'saved')
    
def preset_figure_info(info):
    plt.rc('font', size=info.small_font_size)          # controls default text sizes
    plt.rc('axes', titlesize=info.small_font_size)     # fontsize of the axes title
    plt.rc('axes', labelsize=info.medium_font_size)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=info.small_font_size)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=info.small_font_size)    # fontsize of the tick labels
    plt.rc('legend', fontsize=info.small_font_size)    # legend fontsize
    plt.rc('figure', titlesize=info.large_font_size)  # fontsize of the figure title
    
fixed_name_codes = {
    '12 Branch': 'K12',
    '6 Branch': 'K6',
    '4 Branch': 'K4',
    
    '12 Head': 'H12',
    '6 Head': 'H6',
    '4 Head': 'H4',
    
    '12 LVAE': 'L12',
    '6 LVAE': 'L6',
    '4 LVAE': 'L4',
}
fixed_names = {
    '12 Branch': '12-SAE',
    '6 Branch': '6-SAE',
    '4 Branch': '4-SAE',
    
    '12 Head': '12-ASAE',
    '6 Head': '6-ASAE',
    '4 Head': '4-ASAE',
    
    '12 LVAE': '12-VLAE',
    '6 LVAE': '6-VLAE',
    '4 LVAE': '4-VLAE',
}

def fix_names(run):
    
    if run.name in fixed_name_codes:
        run.code = fixed_name_codes[run.name]
    else:
        run.code = run.name
        
    if run.name in fixed_names:
        run.name = fixed_names[run.name]

def encode(model, X):
    with torch.no_grad():
        q = model.encode(X.cuda())
    if isinstance(q, distrib.Distribution):
        q = q.mean
    return q.cpu()

In [5]:
info = util.tdict()

In [6]:
info.results_name = 'final'
# info.results_name = 'results'

info.small_font_size = 14
info.medium_font_size = 16
info.large_font_size = 18

info.colors = {
    'Branch':'C0', 
          'Baseline':'C1', 
          'Conv':'C2', 
          'Attention':'C3',
        'Ladder': 'C4',
         }
info.tnames = {'Branch':'Structural', 'Baseline':'Baseline', 'Conv':'Simple', 'Attention':'Attention', 'Ladder': 'Ladder'}

preset_figure_info(info)
print(info)

t{results_name, small_font_size, medium_font_size, large_font_size, colors, tnames}


In [7]:
model_names = [
    ('AE', 'Conv'),
    ('VAE', 'Conv'),
    ('4VAE', 'Conv'),
    ('16VAE', 'Conv'),
    ('WAE', 'Conv'),
    ('12 Branch', 'Branch'),
    ('6 Branch', 'Branch'),
    ('4 Branch', 'Branch'),
    ('12 Head', 'Attention'),
    ('6 Head', 'Attention'),
    ('4 Head', 'Attention'),
    ('4 LVAE', 'Ladder'),
    ('6 LVAE', 'Ladder'),
    ('12 LVAE', 'Ladder'),
    ('AE', 'Baseline'),
    ('VAE', 'Baseline'),
    ('4VAE', 'Baseline'),
    ('16VAE', 'Baseline'),
    ('WAE', 'Baseline'),
]
len(model_names)

19

In [67]:
raw = '''AE	Conv	3ds-ae-conv_0002-6337051-00_200508-232722
AE-L2	Conv	3ds-ae-b1-conv_0002-6337051-01_200508-232718
VAE	Conv	3ds-vae-b1-conv_0002-6337051-03_200508-232715
2VAE	Conv	3ds-vae-b2-conv_0002-6337051-04_200508-232711
4VAE	Conv	3ds-vae-b4-conv_0002-6337051-05_200508-232713
8VAE	Conv	3ds-vae-b8-conv_0002-6337051-06_200508-232748
16VAE	Conv	3ds-vae-b16-conv_0002-6337051-07_200508-232723
32VAE	Conv	3ds-vae-b32-conv_0002-6337051-08_200508-232717
WAE	Conv	3ds-wae-conv_0002-6337051-02_200508-232718
12 Branch	Branch	3ds-ae-12b1-conv_0005-6337809-01_200509-120419
6 Branch	Branch	3ds-ae-6b2-conv_0034-6373616-01_200518-163704
4 Branch	Branch	3ds-ae-4b3-conv_0034-6373616-00_200518-163709
12 Head	Attention	3ds-ae-12b1-12h1k32v32_0001-6337050-04_200508-232612
6 Head	Attention	3ds-ae-12b1-6h1k32v32_0001-6337050-08_200508-232652
4 Head	Attention	3ds-ae-12b1-4h1k32v32_0001-6337050-07_200508-232614
12 LVAE	Ladder	3ds-vae-lddrev12_0042-6464006-07_200522-152910
12 Ladder-AE	Ladder	3ds-ae-lddrev12_0042-6464006-04_200522-152948
6 LVAE	Ladder	3ds-vae-lddrev6_0042-6464006-06_200522-152917
6 Ladder-AE	Ladder	3ds-ae-lddrev6_0042-6464006-03_200522-152912
4 LVAE	Ladder	3ds-vae-lddrev4_0042-6464006-05_200522-152911
4 Ladder-AE	Ladder	3ds-ae-lddrev4_0042-6464006-02_200522-152915
AE	Baseline	3ds-ae-dislib_0003-6337807-00_200509-120112
AE-L2	Baseline	3ds-ae-b1-dislib_0003-6337807-01_200509-120113
VAE	Baseline	3ds-vae-b1-dislib_0003-6337807-03_200509-120115
2VAE	Baseline	3ds-vae-b2-dislib_0003-6337807-04_200509-120115
4VAE	Baseline	3ds-vae-b4-dislib_0003-6337807-05_200509-120125
8VAE	Baseline	3ds-vae-b8-dislib_0003-6337807-06_200509-120109
16VAE	Baseline	3ds-vae-b16-dislib_0003-6337807-07_200509-120113
32VAE	Baseline	3ds-vae-b32-dislib_0003-6337807-08_200509-120119
WAE	Baseline	3ds-wae-dislib_0003-6337807-02_200509-120136'''



info.ds = '3ds'
info.fsize = (6,4)
info.title = '3D-Shapes'

info.bidx = 9
info.bidx = 2
# info.highlight = [-1,0,-1,1,2,-1,3,3,4,4,5,5,]

In [18]:
raw = '''AE	Conv	mpi3d-ae-conv_0011-6341579-00_200511-035009
AE-L2	Conv	mpi3d-ae-b1-conv_0011-6341579-01_200511-035004
VAE	Conv	mpi3d-vae-b1-conv_0011-6341579-03_200511-035007
2VAE	Conv	mpi3d-vae-b2-conv_0011-6341579-04_200511-035008
4VAE	Conv	mpi3d-vae-b4-conv_0011-6341579-05_200511-035016
8VAE	Conv	mpi3d-vae-b8-conv_0011-6341579-06_200511-035007
16VAE	Conv	mpi3d-vae-b16-conv_0011-6341579-07_200511-035003
32VAE	Conv	
WAE	Conv	mpi3d-wae-conv_0011-6341579-02_200511-035010
12 Branch	Branch	mpi3d-ae-12b1-conv_0011-6341579-09_200511-035023
6 Branch	Branch	mpi3d-ae-toy-6b2_0056-6468034-01_200525-033418
4 Branch	Branch	mpi3d-ae-toy-4b3-conv_0071-6472326-00_200526-191939
12 Head	Attention	mpi3d-ae-12b1-12h1k32v32_0013-6343740-04_200512-014614
6 Head	Attention	mpi3d-ae-12b1-6h1k32v32_0013-6343740-08_200512-014603
4 Head	Attention	mpi3d-ae-12b1-4h1k32v32_0013-6343740-07_200512-014605
12 LVAE	Ladder	mpi3d-vae-lddrev12_0042-6464006-13_200522-152922
12 Ladder-AE	Ladder	mpi3d-ae-lddrev12_0042-6464006-10_200522-152920
6 LVAE	Ladder	mpi3d-vae-toy-lddrev6_0071-6472326-01_200526-191942
6 Ladder-AE	Ladder	mpi3d-ae-lddrev6_0042-6464006-09_200522-152920
4 LVAE	Ladder	mpi3d-vae-lddrev4_0042-6464006-11_200522-152927
4 Ladder-AE	Ladder	mpi3d-ae-lddrev4_0042-6464006-08_200522-152926'''

if ('16VAE', 'Baseline') in model_names:
    model_names.remove(('16VAE', 'Baseline'))

info.ds = 'toy'
info.fsize = (6,4)
info.title= 'MPI3D Toy'
info.bidx = 13 #24
# info.highlight = [0,0,0,1,2,3,4,4,5,-1,6,5,] # toy
info.ridx = 2
info.baselines = {}


In [27]:
raw = '''AE	Conv	mpi3d-ae-conv_0012-6341580-00_200511-035101
AE-L2	Conv	mpi3d-ae-b1-conv_0012-6341580-01_200511-035039
VAE	Conv	mpi3d-vae-b1-conv_0012-6341580-03_200511-035054
2VAE	Conv	mpi3d-vae-b2-conv_0012-6341580-04_200511-035041
4VAE	Conv	mpi3d-vae-b4-conv_0012-6341580-05_200511-035040
8VAE	Conv	mpi3d-vae-b8-conv_0012-6341580-06_200511-035042
16VAE	Conv	mpi3d-vae-b16-conv_0012-6341580-07_200511-035036
32VAE	Conv	
WAE	Conv	mpi3d-wae-conv_0012-6341580-02_200511-035101
12 Branch	Branch	mpi3d-ae-12b1-conv_0012-6341580-09_200511-035033
6 Branch	Branch	mpi3d-ae-real-6b2_0056-6468034-03_200525-033420
4 Branch	Branch	mpi3d-ae-real-4b3_0056-6468034-02_200525-033422
12 Head	Attention	mpi3d-ae-12b1-12h1k32v32_0015-6346974-04_200513-020758
6 Head	Attention	mpi3d-ae-12b1-6h1k32v32_0015-6346974-08_200513-020756
4 Head	Attention	mpi3d-ae-12b1-4h1k32v32_0015-6346974-07_200513-020753
12 LVAE	Ladder	mpi3d-vae-real-lddrev12_0048-6464516-02_200522-212011
12 Ladder-AE	Ladder	mpi3d-ae-real-lddrev12-s2_0071-6472326-07_200526-191933
6 LVAE	Ladder	mpi3d-vae-real-lddrev6_0048-6464516-01_200522-212003
6 Ladder-AE	Ladder	mpi3d-ae-real-lddrev6_0071-6472326-03_200526-191932
4 LVAE	Ladder	mpi3d-vae-real-lddrev4_0048-6464516-00_200522-212001
4 Ladder-AE	Ladder	mpi3d-ae-real-lddrev4_0071-6472326-04_200526-191934'''

if ('16VAE', 'Conv') in model_names:
    model_names.remove(('16VAE', 'Conv'))

info.ds = 'real'
info.fsize = (6,4)
info.title= 'MPI3D Real'
info.bidx = 24 #48
info.bidx = 13
# info.highlight = [0,1,0,1,0,2,-1,3,4,5,5,-1]
info.ridx = 2
info.baselines = {}


In [54]:
from collections import OrderedDict

In [56]:
names = OrderedDict({
    'AE': ['3ds-ae-conv_0002-6337051-00_200508-232722', 'mpi3d-ae-conv_0011-6341579-00_200511-035009', 'mpi3d-ae-conv_0012-6341580-00_200511-035101'],
    'VAE': ['3ds-vae-b1-conv_0002-6337051-03_200508-232715', 'mpi3d-vae-b1-conv_0011-6341579-03_200511-035007', 'mpi3d-vae-b1-conv_0012-6341580-03_200511-035054'],
    r'$\beta$VAE': ['3ds-vae-b16-conv_0002-6337051-07_200508-232723', 'mpi3d-vae-b4-conv_0011-6341579-05_200511-035016', 'mpi3d-vae-b4-conv_0012-6341580-05_200511-035040'],
    
    '12-SAE': ['3ds-ae-12b1-conv_0005-6337809-01_200509-120419', 'mpi3d-ae-12b1-conv_0011-6341579-09_200511-035023', 'mpi3d-ae-12b1-conv_0012-6341580-09_200511-035033'],
    '6-SAE': ['3ds-ae-6b2-conv_0034-6373616-01_200518-163704', 'mpi3d-ae-toy-6b2_0056-6468034-01_200525-033418', 'mpi3d-ae-real-6b2_0056-6468034-03_200525-033420'],
    '4-SAE': ['3ds-ae-4b3-conv_0034-6373616-00_200518-163709', 'mpi3d-ae-toy-4b3-conv_0071-6472326-00_200526-191939', 'mpi3d-ae-real-4b3_0056-6468034-02_200525-033422'],
    
    '12-ASAE': ['3ds-ae-12b1-12h1k32v32_0001-6337050-04_200508-232612', 'mpi3d-ae-12b1-12h1k32v32_0013-6343740-04_200512-014614', 'mpi3d-ae-12b1-12h1k32v32_0015-6346974-04_200513-020758'],
    '6-ASAE': ['3ds-ae-12b1-6h1k32v32_0001-6337050-08_200508-232652', 'mpi3d-ae-12b1-6h1k32v32_0013-6343740-08_200512-014603', 'mpi3d-ae-12b1-6h1k32v32_0015-6346974-08_200513-020756'],
    '4-ASAE': ['3ds-ae-12b1-4h1k32v32_0001-6337050-07_200508-232614', 'mpi3d-ae-12b1-4h1k32v32_0013-6343740-07_200512-014605', 'mpi3d-ae-12b1-4h1k32v32_0015-6346974-07_200513-020753'],
    
    '12-VLAE': ['3ds-vae-lddrev12_0042-6464006-07_200522-152910', 'mpi3d-vae-lddrev12_0042-6464006-13_200522-152922', 'mpi3d-vae-real-lddrev12_0048-6464516-02_200522-212011'],
    '6-VLAE': ['3ds-vae-lddrev6_0042-6464006-06_200522-152917', 'mpi3d-vae-toy-lddrev6_0071-6472326-01_200526-191942', 'mpi3d-vae-real-lddrev6_0048-6464516-01_200522-212003'],
    '4-VLAE': ['3ds-vae-lddrev4_0042-6464006-05_200522-152911', 'mpi3d-vae-lddrev4_0042-6464006-11_200522-152927', 'mpi3d-vae-real-lddrev4_0048-6464516-00_200522-212001'],
})
names.keys()

odict_keys(['AE', 'VAE', '$\\beta$VAE', '12-SAE', '6-SAE', '4-SAE', '12-ASAE', '6-ASAE', '4-ASAE', '12-VLAE', '6-VLAE', '4-VLAE'])

In [57]:
runs = util.Table()
for n, rs in names.items():
    runs.new(name=n, rnames=rs)
len(runs)

12

In [60]:
def _update_data(run):
    run.raws = []
    for rname in run.rnames:
        rname = f'{rname}.txt'
        run.raws.append(hp.containerify(yaml.load(open(os.path.join(root, rname), 'r').read().replace('\t', '  '))))
runs.map(_update_data)
pass

calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.


In [63]:
def extract_metrics(run):
    line = []
    for raw in run.raws:
        line.append(raw['DCI']['disentanglement'])
        line.append(raw['IRS']['IRS'])
        line.append(-np.log(raw['Fairness']['mean_fairness:mean_pred:mean_sens']))
    run.line = line
runs.map(extract_metrics)
pass

In [68]:
runs = organize(raw, dataset=info.ds)
runs.filter_(lambda r: (r.name, r.arch) in model_names)
# runs.filter_(lambda r: r.arch == 'Attention')

runs.map(fix_names)
len(runs)

17

In [69]:
def load_data(run):
    rname = f'{run.rname}.txt'
    if rname in fnames:
        run.data = hp.containerify(yaml.load(open(os.path.join(root, rname), 'r').read().replace('\t', '  ')))
    else:
        print(f'Missing: {run.rname}')
        run.data = None
#     if info.results_name not in run:
#         full = torch.load(os.path.join(root, rname, f'{info.results_name}.pth.tar'))
#         run[info.results_name] = full

In [70]:
runs.map(load_data, pbar=tqdm)
pass

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))




calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.


In [71]:
run = runs[0]
run.rname

'3ds-ae-conv_0002-6337051-00_200508-232722'

In [79]:
print(run.data.keys())
print(run.data['Factor VAE Score'])

odict_keys(['Beta VAE Score', 'Factor VAE Score', 'MIG', 'DCI', 'IRS', 'SAP', 'Modularity Explicitness', 'Unsupervised Metrics', 'Fairness'])
t{train_accuracy, eval_accuracy, num_active_dims}


In [91]:
run.data

t{'Beta VAE Score':t{'train_accuracy':0.8654, 'eval_accuracy':0.8598}, 'Factor VAE Score':t{'train_accuracy':0.6748, 'eval_accuracy':0.6696, 'num_active_dims':12}, 'MIG':t{'discrete_mig':0.0392523447710372}, 'DCI':t{'informativeness_train':0.8895000000000001, 'informativeness_test':0.7043333333333334, 'disentanglement':0.19285710199444991, 'completeness':0.14664690605207303}, 'IRS':t{'IRS':0.5062780736485408, 'num_active_dims':219.39516}, 'SAP':t{'SAP_score':0.016533333333333337}, 'Modularity Explicitness':t{'modularity_score':0.8768743865137764, 'explicitness_score_train':0.9122557483371718, 'explicitness_score_test':0.9099392116334312}, 'Unsupervised Metrics':t{'gaussian_total_correlation':1.1783814439264315, 'gaussian_wasserstein_correlation':12.25666447281759, 'gaussian_wasserstein_correlation_norm':0.09062213432414572, 'mutual_info_score':0.043318961004037364}, 'Fairness':t{'mean_fairness:mean_pred:mean_sens':0.0729711582325399}}

In [94]:
def extract_metrics(run):
    if run.data is not None:
        run.fairness = -np.log(run.data['Fairness']['mean_fairness:mean_pred:mean_sens'])
        run.dci = run.data['DCI']['disentanglement']
        run.irs = run.data['IRS']['IRS']
        run.mig = run.data['MIG']['discrete_mig']
        run.fvae = -np.log(1+1e-5-run.data['Factor VAE Score']['eval_accuracy'])
        run.mutinf = run.data['Unsupervised Metrics']['mutual_info_score']
runs.map(extract_metrics)
pass

In [95]:
list(runs.selects('name', 'arch'))

In [96]:
fig, ax = plt.subplots(figsize=(8,5))
# texts = []

prop = 'dci'
prop = 'mig'
prop = 'fvae'
prop = 'mutinf'
# prop = 'irs'
# prop = 'fairness'

tps = set()
names = []
x = np.arange(len(names))
for i,(n,t,v) in enumerate(runs.selects('name', 'arch', prop)):
    names.append(n)
    tps.add(t)
    plt.bar(i, v, color=info.colors[t], edgecolor='k')#linewidth=2)
plt.xticks(np.arange(len(names)), names, rotation='vertical')
#     plt.scatter(r.loss, r.fid, color=colors[r.type])
#     texts.append(plt.text(r.loss, r.fid, r.name))
#     tps.add(r.type)
#     plt.scatter(r.fid, r.loss, color=colors[r.type])
#     texts.append(plt.text(r.fid-(fidrg/10 if r.type == 'attn' else 0), r.loss, r.name))
legend_elements = [Patch(facecolor=info.colors[name], label=info.tnames[name]) for name in set(tps)]
# legend_elements1 = [Patch(facecolor='w', hatch='xx', label='Hybrid'), Patch(facecolor='w', hatch='oo', label='Prior')]
legend = plt.legend(handles=legend_elements, )
# plt.xlabel('Reconstruction Loss')
# plt.ylabel('FID Score')
# plt.title(f'{info.title} Generated Samples')
plt.title(prop)
# plt.yscale('log')
# s = [20,30,40,50,60,70, 80,90,100]
# plt.ylim(0,90)
# plt.yticks(s,map(str,s))
# plt.grid(linestyle='--',axis='y')

# legend1 = plt.legend(legend_elements1, ['Hybrid', 'Prior'], loc=4, title='Sampling Method')
# ax.add_artist(legend1, )
# ax.add_artist(legend)
plt.tight_layout()
pass

<IPython.core.display.Javascript object>

In [50]:
from tabulate import tabulate

In [64]:
lines = [[r.name, *r.line] for r in runs]
len(lines)

12

In [66]:
print(tabulate(lines, floatfmt='.2f', tablefmt='latex'))

\begin{tabular}{lrrrrrrrrr}
\hline
 AE         & 0.19 & 0.51 & 2.62 & 0.07 & 0.63 & 2.35 & 0.07 & 0.61 & 2.19 \\
 VAE        & 0.89 & 0.74 & 5.19 & 0.34 & 0.62 & 2.50 & 0.22 & 0.64 & 2.33 \\
 \$\textbackslash{}beta\$VAE & 0.54 & 0.64 & 4.11 & 0.25 & 0.66 & 2.35 & 0.27 & 0.73 & 2.29 \\
 12-SAE     & 0.97 & 0.89 & 7.72 & 0.57 & 0.63 & 2.81 & 0.30 & 0.59 & 2.32 \\
 6-SAE      & 0.80 & 0.81 & 6.62 & 0.50 & 0.63 & 2.80 & 0.33 & 0.63 & 2.37 \\
 4-SAE      & 0.70 & 0.68 & 5.38 & 0.39 & 0.63 & 2.59 & 0.16 & 0.60 & 2.28 \\
 12-ASAE    & 0.91 & 0.82 & 6.57 & 0.46 & 0.70 & 2.59 & 0.38 & 0.79 & 2.48 \\
 6-ASAE     & 0.59 & 0.72 & 4.34 & 0.50 & 0.71 & 2.58 & 0.27 & 0.71 & 2.35 \\
 4-ASAE     & 0.66 & 0.76 & 4.40 & 0.46 & 0.72 & 2.42 & 0.29 & 0.75 & 2.45 \\
 12-VLAE    & 0.85 & 0.83 & 7.90 & 0.44 & 0.60 & 2.62 & 0.20 & 0.58 & 2.23 \\
 6-VLAE     & 0.70 & 0.64 & 7.10 & 0.36 & 0.61 & 2.52 & 0.34 & 0.60 & 2.32 \\
 4-VLAE     & 0.53 & 0.73 & 3.58 & 0.40 & 0.69 & 2.53 & 0.40 & 0.66 & 2.33 \\
\hline
\end{

In [40]:
tabulate?

In [None]:
cols = ['dci', 'irs', 'fairness']

In [None]:
'''\begin{center}
\begin{tabular}{ |c|c|c|c| } 
\hline
col1 & col2 & col3 \\
\hline
\multirow{3}{4em}{Multiple row} & cell2 & cell3 \\ 
& cell5 & cell6 \\ 
& cell8 & cell9 \\ 
\hline
\end{tabular}
\end{center}'''