In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn

# One-dimensional vs. Multidimensional

## Held-out log-likelihoods

In [19]:
def cross_val_log_lik(fname):
    """Return the grand mean of every value stored in a CSV table.

    Parameters
    ----------
    fname : str or path-like
        CSV file to read. Its first column is consumed as the row index and
        therefore does not take part in the average.

    Returns
    -------
    numpy scalar
        Mean taken over both axes (rows and columns) of the remaining cells.
    """
    table = pd.read_csv(fname, index_col=0)
    values = np.array(table)
    # Reduce over rows and columns at once to get a single number.
    return values.mean(axis=(0, 1))

In [196]:
# Report the grand-mean held-out log-likelihood for each combination of
# model tag ('1d' / 'md') and condition tag ('a' / 'h').
for i in ('1d', 'md'):
    for j in ('a', 'h'):
        fname = "modeling_results/{}_other_{}_fb_CV.csv".format(i, j)
        print(i, j)
        print(cross_val_log_lik(fname))

## WAIC and LOO scores

In [14]:
def compute_avg_score(fname):
    """Average the WAIC and LOO scores stored in a two-line metrics file.

    The file must contain two non-blank lines. The first is read as the WAIC
    scores and the second as the LOO scores. On each line the first and last
    characters are discarded (they are expected to be enclosing brackets) and
    the remainder is split on commas, e.g. ``[-101.2, -99.8, -100.5]``.

    Blank lines, such as an extra newline at the end of the file, are
    ignored so they do not break the two-line unpacking.

    Parameters
    ----------
    fname : str or path-like
        Path to the metrics file.

    Returns
    -------
    tuple
        ``(mean_waic, mean_loo)``.

    Raises
    ------
    ValueError
        If the file does not contain exactly two non-blank lines, or if an
        entry cannot be converted to ``float``.
    """
    with open(fname, "r") as f:
        rows = [line.strip() for line in f if line.strip()]

    if len(rows) != 2:
        raise ValueError(
            "expected exactly 2 non-blank lines (WAIC, LOO) in {!r}, found {}".format(
                fname, len(rows)
            )
        )

    # Drop the enclosing bracket characters, then parse the comma-separated numbers.
    waics, loos = (
        [float(token) for token in row[1:-1].split(",")] for row in rows
    )
    return np.mean(waics), np.mean(loos)

In [196]:
# Report the mean (WAIC, LOO) pair for each combination of model tag
# ('1d' / 'md') and condition tag ('a' / 'h').
for i in ('1d', 'md'):
    for j in ('a', 'h'):
        fname = "modeling_results/{}_other_{}_fb_CV_metrics.csv".format(i, j)
        print(i, j)
        print(compute_avg_score(fname))

## Latent correlations

In [25]:
def get_corrs(df_corr, n_dims=4):
    """Average the ``Omega.i.j`` columns into an ``n_dims`` x ``n_dims`` matrix.

    Parameters
    ----------
    df_corr : pandas.DataFrame or mapping of column name -> values
        Must provide a column named ``'Omega.<i>.<j>'`` for every ``i`` and
        ``j`` in ``1..n_dims``. Presumably each row is one sampled draw of a
        latent correlation matrix -- TODO confirm against the fitting code.
    n_dims : int, optional
        Number of rows/columns of the matrix to collect. Defaults to 4, the
        size that was previously hard-coded.

    Returns
    -------
    list of list
        ``result[i - 1][j - 1]`` is the mean of column ``'Omega.<i>.<j>'``.

    Raises
    ------
    KeyError
        If an expected ``Omega.<i>.<j>`` column is missing.
    """
    indices = range(1, n_dims + 1)
    return [
        [np.mean(df_corr['Omega.{}.{}'.format(i, j)]) for j in indices]
        for i in indices
    ]

In [26]:
def plot_corrs(corrs, labsx, labsy):
    """Draw one annotated heatmap per matrix, all sharing one colour scale.

    The colour limits are the smallest and largest value found across all
    matrices, so colours are comparable between the figures. The limits come
    from the data alone. They used to be seeded with the constants 1 and 0,
    which capped the scale incorrectly whenever every value lay on one side
    of those seeds (e.g. all-negative matrices got ``vmax=0``).

    Parameters
    ----------
    corrs : iterable of 2-D array-like
        Matrices to plot, one figure each.
    labsx : list of str
        Tick labels for the x-axis.
    labsy : list of str
        Tick labels for the y-axis.

    Returns
    -------
    None
        Figures are shown with ``plt.show()``. Nothing is drawn when
        ``corrs`` is empty.
    """
    # Materialise once: the matrices are scanned for the limits and then plotted.
    corrs = list(corrs)
    if not corrs:
        return

    vmin = min(np.min(corr) for corr in corrs)
    vmax = max(np.max(corr) for corr in corrs)

    for corr in corrs:
        _fig, ax = plt.subplots(figsize=(7, 7))
        sn.heatmap(corr, annot=True, xticklabels=labsx, yticklabels=labsy,
                   cmap='crest', cbar=False, vmin=vmin, vmax=vmax, ax=ax)
        plt.show()

In [28]:
# Heatmap tick labels: full topic names (passed as y-axis labels) and their
# initials in the same order (passed as x-axis labels).
labsy = [
    'History of Art',
    'Video Games',
    'Cities',
    'Math',
]
labsx = ['HA', 'VG', 'C', 'M']

In [34]:
# Enlarge all seaborn text for the figures below. `set_theme` is the preferred
# spelling of the `sn.set` alias and takes the same arguments.
sn.set_theme(font_scale=1.9)

In [27]:
# Load the 'md_other_h' results table (first column used as the index) and
# average its Omega.i.j columns into a matrix.
human_md = pd.read_csv(
    'modeling_results/md_other_h.csv',
    index_col=0,
)
human_corrs = get_corrs(df_corr=human_md)

In [32]:
# Load the 'md_other_a' results table (first column used as the index) and
# average its Omega.i.j columns into a matrix.
ai_md = pd.read_csv(
    'modeling_results/md_other_a.csv',
    index_col=0,
)
ai_corrs = get_corrs(df_corr=ai_md)

In [None]:
# One heatmap per matrix, in this order: AI first, then human.
plot_corrs(
    corrs=[ai_corrs, human_corrs],
    labsx=labsx,
    labsy=labsy,
)