In [1]:
import itertools
import logging
import multiprocessing
import os
import pickle
import sys 

from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
from scipy.stats import ttest_ind_from_stats as ttest

from shared.utils import config_hasher, tried_config
from cmnist import configurator

In [2]:
# Name of the experiment sweep; passed to configurator.get_sweep in the loading cell.
exp_name = "correlation"
# Root directory for results; per-config metrics are read from
# BASE_DIR/tuning/<config hash>/performance.pkl.
BASE_DIR = "/data/ddmg/slabs/cmnist"
# Model names of interest. NOTE(review): the loading cell below uses
# 'weighted_opslabs', which is not in this list — confirm whether that is intended.
MODELS = ['slabs', 'opslabs', 'simple_baseline', 'oracle_aug_0.1', 'oracle_aug_0.5']
# Number of worker processes; presumably meant for the multiprocessing pool in the
# loading cell — confirm it is actually used there.
NUM_WORKERS = 10
# Name of the column intended for the x-axis. NOTE(review): the plotting helper
# visible in this notebook plots against alpha_id instead (see its TODO).
X_AXIS_VAR = 'py1_y0_s'

# Per-model plot styling: line color and legend label.
MODEL_TO_PLOT_SPECS = {
	'slabs': {'color': '#ff7f0e', 'label': 'SLABS (ours)'},
	'opslabs': {'color': '#d62728', 'label': 'OP-SLABS (ours)'},
	'weighted_opslabs': {'color': 'black', 'label': 'W-OP-SLABS (ours)'},
	'simple_baseline': {'color': '#2ca02c', 'label': 'Simple baseline'},
	'oracle_aug_0.1': {'color': '#9467bd', 'label': 'Oracle aug (10%)'},
	'oracle_aug_0.5': {'color': '#e377c2', 'label': 'Oracle aug (50%)'},
}

In [3]:

def import_helper(args):
	"""Loads the saved performance metrics of one (model, config) run.

	The config is hashed with config_hasher and the metrics are read from
	BASE_DIR/tuning/<hash>/performance.pkl. The config entries and the model
	name are merged into the loaded metrics so each run becomes one flat row.

	Takes a single tuple argument so it can be mapped directly over a list of
	(model, config) pairs, e.g. with multiprocessing.Pool.imap_unordered.

	Args:
		args: tuple (model, config) where
			model: str, name of the model whose performance is imported.
			config: dictionary describing the run. It is passed to config_hasher
				and its entries are added to the returned row. The original
				documentation lists these expected keys:
					exp_dir: the experiment directory
					random_seed: random seed for the experiment
					py1_y0_s: probability of y1=1 | y0=1 in the shifted test
						distribution
					alpha: MMD/cross prediction penalty
					sigma: kernel bandwidth for the MMD penalty
					l2_penalty: regularization parameter
					dropout_rate: drop out rate
					embedding_dim: dimension of the final representation/embedding
					unused_kwargs: other key word args passed to xmanager but not
						needed here

	Returns:
		A one-row pandas DataFrame (index [0]) with the metrics, the config
		entries and a 'model' column, or None if performance.pkl was not found
		(an error is logged in that case).
	"""
	model, config = args
	hash_string = config_hasher(config)
	hash_dir = os.path.join(BASE_DIR, 'tuning', hash_string)
	performance_file = os.path.join(hash_dir, 'performance.pkl')

	if not os.path.exists(performance_file):
		logging.error('Couldnt find %s', performance_file)
		return None

	# Use a context manager so the file handle is closed deterministically; this
	# helper runs many times inside worker processes.
	# NOTE: pickle.load can execute arbitrary code; only load files this project
	# wrote itself.
	with open(performance_file, 'rb') as performance_handle:
		results_dict = pickle.load(performance_handle)
	results_dict.update(config)
	results_dict['model'] = model
	return pd.DataFrame(results_dict, index=[0])

In [4]:
def plot_errorbars_same_and_shifted(axis,
                                    cmap,
                                    results,
                                    group,
                                    metric):
	"""Draws shifted- and same-distribution error bars for one group.

	Two curves are added to `axis`, both against `alpha_id`:
	the shifted-distribution metric (solid, labelled with the unique sigma
	values of the group) and the same-distribution metric (dashed, unlabelled).

	Args:
		axis: matplotlib plot axis to draw on
		cmap: colormap; currently unused because the per-group coloring
			(color=cmap(group)) is commented out in this version
		results: pandas dataframe with all models' results; must contain the
			columns group, alpha_id, sigma and
			{shift,same}_distribution_<metric>_{mean,std}
		group: value of the `group` column selecting the rows to plot
		metric: metric name used to build the column names, e.g. loss or acc

	Returns:
		None. Just adds the errorbars and a legend to an existing plot.
	"""
	# TODO x-axis variable
	group_rows = results[results.group == group]
	x_positions = group_rows.alpha_id

	# (column prefix, curve-specific errorbar keyword arguments), in draw order.
	curve_specs = (
		('shift_distribution', {'label': np.unique(group_rows.sigma)}),
		('same_distribution', {'linestyle': '--'}),
	)
	for prefix, curve_kwargs in curve_specs:
		axis.errorbar(
			x_positions,
			group_rows[f'{prefix}_{metric}_mean'],
			yerr=group_rows[f'{prefix}_{metric}_std'],
			**curve_kwargs)
	axis.legend()

In [5]:
# Build the list of (model, config) pairs to load: take the sweep for this
# experiment/model and keep only the configs for which tried_config(...) is truthy.
all_config = []
model = "weighted_opslabs"
model_configs = configurator.get_sweep(exp_name, model)
available_configs = [tried_config(config, base_dir=BASE_DIR) for config
                                                in model_configs]
model_configs = list(itertools.compress(model_configs, available_configs))
all_config.extend([(model, config) for config in model_configs])

# Load the per-config results in parallel. The context manager releases the
# worker processes once every result has been consumed (the original pool was
# never closed), and the pool size comes from the NUM_WORKERS config constant.
res = []
with multiprocessing.Pool(NUM_WORKERS) as pool:
    for config_res in tqdm.tqdm(pool.imap_unordered(import_helper, all_config),
                                total=len(all_config)):
        # import_helper returns None when performance.pkl is missing; skip those
        # explicitly instead of relying on pd.concat silently dropping them.
        if config_res is not None:
            res.append(config_res)

# One row per loaded config; raises ValueError if nothing could be loaded.
res_or = pd.concat(res, axis=0, ignore_index=True, sort=False)


100%|██████████| 194/194 [00:00<00:00, 1590.29it/s]


In [6]:
# List every column of the combined results frame (metrics plus config entries).
res_or.columns

Index(['validation_accuracy', 'validation_loss', 'validation_mmd',
       'validation_global_step', 'shift_0.1_accuracy', 'shift_0.1_loss',
       'shift_0.1_mmd', 'shift_0.1_global_step', 'shift_0.2_accuracy',
       'shift_0.2_loss', 'shift_0.2_mmd', 'shift_0.2_global_step',
       'shift_0.3_accuracy', 'shift_0.3_loss', 'shift_0.3_mmd',
       'shift_0.3_global_step', 'shift_0.4_accuracy', 'shift_0.4_loss',
       'shift_0.4_mmd', 'shift_0.4_global_step', 'shift_0.5_accuracy',
       'shift_0.5_loss', 'shift_0.5_mmd', 'shift_0.5_global_step',
       'shift_0.6_accuracy', 'shift_0.6_loss', 'shift_0.6_mmd',
       'shift_0.6_global_step', 'shift_0.7_accuracy', 'shift_0.7_loss',
       'shift_0.7_mmd', 'shift_0.7_global_step', 'shift_0.8_accuracy',
       'shift_0.8_loss', 'shift_0.8_mmd', 'shift_0.8_global_step',
       'shift_0.9_accuracy', 'shift_0.9_loss', 'shift_0.9_mmd',
       'shift_0.9_global_step', 'shift_1.0_accuracy', 'shift_1.0_loss',
       'shift_1.0_mmd', 'shift_1.0_glo

In [25]:
# Keep only the runs with sigma == 1.0 and alpha == 1e10, then show how many
# runs each random seed contributes.
sigma_alpha_mask = (res_or.sigma == 1.0) & (res_or.alpha == 1e10)
res = res_or[sigma_alpha_mask]
print(res.random_seed.value_counts())

5    1
2    1
1    1
8    1
Name: random_seed, dtype: int64


In [26]:
# Columns for the shifted test distributions, restricted to loss and accuracy
# (this leaves out the other shift_* metrics such as mmd and global_step).
shift_columns = [
    col for col in res_or.columns
    if col.startswith('shift') and ('loss' in col or 'accuracy' in col)
]

In [27]:
# Display the selected shift_* loss/accuracy column names.
shift_columns

['shift_0.1_accuracy',
 'shift_0.1_loss',
 'shift_0.2_accuracy',
 'shift_0.2_loss',
 'shift_0.3_accuracy',
 'shift_0.3_loss',
 'shift_0.4_accuracy',
 'shift_0.4_loss',
 'shift_0.5_accuracy',
 'shift_0.5_loss',
 'shift_0.6_accuracy',
 'shift_0.6_loss',
 'shift_0.7_accuracy',
 'shift_0.7_loss',
 'shift_0.8_accuracy',
 'shift_0.8_loss',
 'shift_0.9_accuracy',
 'shift_0.9_loss',
 'shift_1.0_accuracy',
 'shift_1.0_loss']

In [28]:
# Mean and standard deviation of every shift_* loss/accuracy column over the
# rows currently held in `res`. Note: this rebinds `res`, so re-running the
# cell without re-running the filter cell aggregates the aggregate.
shift_stats_spec = {col: ['mean', 'std'] for col in shift_columns}
res = res[shift_columns].agg(shift_stats_spec)

In [29]:
# Flip the summary so each shift_* metric is a row and mean/std are columns.
# Note: this rebinds `res`, so running the cell twice flips it back.
res = res.T
print(res)

                        mean       std
shift_0.1_accuracy  0.923569  0.017474
shift_0.1_loss      0.282125  0.044269
shift_0.2_accuracy  0.927585  0.017613
shift_0.2_loss      0.261977  0.040947
shift_0.3_accuracy  0.932605  0.013920
shift_0.3_loss      0.258012  0.041048
shift_0.4_accuracy  0.930848  0.016253
shift_0.4_loss      0.268335  0.033820
shift_0.5_accuracy  0.936872  0.011804
shift_0.5_loss      0.247986  0.022353
shift_0.6_accuracy  0.934864  0.011093
shift_0.6_loss      0.259563  0.024655
shift_0.7_accuracy  0.939383  0.006473
shift_0.7_loss      0.245899  0.024192
shift_0.8_accuracy  0.942018  0.005862
shift_0.8_loss      0.253594  0.018711
shift_0.9_accuracy  0.945532  0.003585
shift_0.9_loss      0.242543  0.014888
shift_1.0_accuracy  0.945783  0.002750
shift_1.0_loss      0.237176  0.013678


In [36]:
# Recover the shifted-distribution probability from row labels of the form
# 'shift_<p>_<metric>' by capturing the token between the first two underscores.
# A pattern (rather than the fixed character slice 6:9) also handles values that
# are not exactly three characters wide, e.g. 'shift_0.25_loss'.
res['py1_y0_s'] = res.index.str.extract(r'^shift_([^_]+)_', expand=False).astype(float)

In [37]:
# Display the per-metric summary with the parsed py1_y0_s column.
res

Unnamed: 0,mean,std,py1_y0_s
shift_0.1_accuracy,0.923569,0.017474,0.1
shift_0.1_loss,0.282125,0.044269,0.1
shift_0.2_accuracy,0.927585,0.017613,0.2
shift_0.2_loss,0.261977,0.040947,0.2
shift_0.3_accuracy,0.932605,0.01392,0.3
shift_0.3_loss,0.258012,0.041048,0.3
shift_0.4_accuracy,0.930848,0.016253,0.4
shift_0.4_loss,0.268335,0.03382,0.4
shift_0.5_accuracy,0.936872,0.011804,0.5
shift_0.5_loss,0.247986,0.022353,0.5


In [44]:
# Accuracy rows only, with the statistic columns prefixed by 'accuracy_'
# (py1_y0_s keeps its name so it can serve as a join key).
is_accuracy_row = res.index.str.contains('accuracy')
accuracy_renames = {col: f'accuracy_{col}'
                    for col in res.columns if col != 'py1_y0_s'}
res_accuracy = res[is_accuracy_row].rename(columns=accuracy_renames)
res_accuracy.head()


Unnamed: 0,accuracy_mean,accuracy_std,py1_y0_s
shift_0.1_accuracy,0.923569,0.017474,0.1
shift_0.2_accuracy,0.927585,0.017613,0.2
shift_0.3_accuracy,0.932605,0.01392,0.3
shift_0.4_accuracy,0.930848,0.016253,0.4
shift_0.5_accuracy,0.936872,0.011804,0.5


In [40]:
# Put accuracy and loss statistics side by side, one row per py1_y0_s.
# With the default merge suffixes, *_x columns are accuracy and *_y are loss.
accuracy_rows = res[res.index.str.contains('accuracy')]
loss_rows = res[res.index.str.contains('loss')]
accuracy_rows.merge(loss_rows, on=['py1_y0_s'])

Unnamed: 0,mean_x,std_x,py1_y0_s,mean_y,std_y
0,0.923569,0.017474,0.1,0.282125,0.044269
1,0.927585,0.017613,0.2,0.261977,0.040947
2,0.932605,0.01392,0.3,0.258012,0.041048
3,0.930848,0.016253,0.4,0.268335,0.03382
4,0.936872,0.011804,0.5,0.247986,0.022353
5,0.934864,0.011093,0.6,0.259563,0.024655
6,0.939383,0.006473,0.7,0.245899,0.024192
7,0.942018,0.005862,0.8,0.253594,0.018711
8,0.945532,0.003585,0.9,0.242543,0.014888
9,0.945783,0.00275,1.0,0.237176,0.013678


In [8]:
# Names of all validation metrics present in the combined results.
validation_columns = [name for name in res_or.columns if 'validation' in name]
print(validation_columns)

['validation_accuracy', 'validation_loss', 'validation_mmd', 'validation_global_step']


In [14]:
# Drop every shift_* column (all metrics, not only loss/accuracy), leaving the
# validation metrics and the config entries. Note: this rebinds both
# `shift_columns` and `res`.
shift_columns = [name for name in res_or.columns if name.startswith('shift')]
res = res_or.drop(columns=shift_columns)

In [15]:
# Preview the validation metrics and config columns that remain.
res.head()

Unnamed: 0,validation_accuracy,validation_loss,validation_mmd,validation_global_step,alpha,dropout_rate,embedding_dim,l2_penalty,pflip0,pflip1,random_seed,sigma,weighted_mmd,model
0,0.922327,0.637696,0.016331,2621,0.0,0.0,1000,0.0,0.05,0.05,0,9.0,True,weighted_opslabs
1,0.938753,0.65019,0.014722,2621,0.0,0.0,1000,0.0,0.05,0.05,0,10.0,True,weighted_opslabs
2,0.935134,0.581684,0.020338,2621,0.0,0.0,1000,0.0,0.05,0.05,0,7.0,True,weighted_opslabs
3,0.928452,0.6497,0.017398,2621,0.0,0.0,1000,0.0,0.05,0.05,0,8.0,True,weighted_opslabs
4,0.937361,0.64337,0.026933,2621,0.0,0.0,1000,0.0,0.05,0.05,0,5.0,True,weighted_opslabs


In [None]:
# res = pd.read_csv(f'{BASE_DIR}/results/correlation_xval_results.csv')

In [None]:
# Drop hyperparameter, flip-probability and step-count columns before building
# the per-model summary.
# NOTE(review): the res_or column listing displayed earlier in this notebook
# (truncated) shows 'shift_<p>_...' columns rather than 'same_distribution_*' /
# 'shift_distribution_*'. If the listed labels are absent, this drop raises
# KeyError — confirm which results schema this cell expects.
res_model = res_or.drop(['l2_penalty', 'embedding_dim','dropout_rate', 'validation_global_step', 
         'same_distribution_global_step', 'shift_distribution_global_step', 
         'pflip0', 'pflip1'], axis=1)

In [None]:
# Count missing values per column of res_model.
res_model.isnull().sum()

In [None]:
# res[(res.validation_mmd.isnull())][['alpha', 'sigma', 'random_seed']]
# res['null_mmd'] = np.where(res.validation_mmd.isnull(), 1, 0)
# pd.crosstab(res.null_mmd, res.random_seed)

In [None]:
# Aggregate rows that share the same (model, py1_y0_s, sigma, alpha) into the
# mean and standard deviation of every metric. Note: this rebinds `res_model`,
# so the cell is not re-runnable without re-running the cell that builds it.
group_keys = ['model', 'py1_y0_s', 'sigma', 'alpha']
metric_columns = [
    'validation_accuracy',
    'validation_mmd',
    'same_distribution_accuracy',
    'same_distribution_mmd',
    'shift_distribution_accuracy',
    'shift_distribution_mmd',
    'validation_loss',
    'same_distribution_loss',
    'shift_distribution_loss',
]
res_model = (
    res_model
    .groupby(group_keys)
    .agg({metric: ['mean', 'std'] for metric in metric_columns})
    .reset_index()
)
# Flatten the two-level column index to '<metric>_<stat>'. The group keys have
# an empty second level, so skipping empty parts leaves their names unchanged.
res_model.columns = ['_'.join(part for part in col if part).strip()
                     for col in res_model.columns.values]

In [None]:
# Two-step hyperparameter selection for a single py1_y0_s value (0.3):
#   step 1 picks a kernel bandwidth sigma, step 2 picks the penalty alpha.
# for py_ind, py1_y0_s in enumerate(res_model.py1_y0_s.unique()):
py1_y0_s = 0.3
res_py = res_model[((res_model.py1_y0_s==py1_y0_s))]
print(res_py.sigma.value_counts())

# Step 1: look only at the smallest and largest alpha. For each sigma, np.ptp
# gives the range (max - min) of the mean validation MMD across those rows.
# Keep sigmas whose range is > 0 and take the smallest such sigma.
res_sigma = res_py[((res_py.alpha == res_py.alpha.min()
                       ) | (res_py.alpha == res_py.alpha.max()))]
res_sigma = res_sigma.groupby('sigma')['validation_mmd_mean'].agg(np.ptp).reset_index()
res_sigma = res_sigma[(res_sigma.validation_mmd_mean > 0)]
optimal_sigma = np.min(res_sigma.sigma)

# TODO: HERE NEED TO RUN STEP 2 USING OPTIMAL SIGMA
# Step 2: restrict to the chosen sigma and order the alphas by validation loss.
res_py = res_py[(res_py.sigma == optimal_sigma)].reset_index(drop=True)
print(res_py[['sigma', 'alpha', 'validation_loss_mean', 'same_distribution_accuracy_mean', 
             'shift_distribution_accuracy_mean']])
res_py = res_py.sort_values('validation_loss_mean').reset_index()

# Row 0 is the lowest-loss configuration after sorting. Its standard error uses
# a hard-coded sample size of 10 — presumably the number of seeds; TODO confirm.
min_val_loss = res_py.validation_loss_mean[0]
min_val_loss_ste = res_py.validation_loss_std[0]/np.sqrt(10)

# NOTE: despite its name, 'pvals' holds a 0/1 flag here: 1 when the mean
# validation loss is within one standard error of the minimum. The t-test
# version that produced real p-values is commented out below.
res_py['pvals'] = np.where(res_py.validation_loss_mean <= min_val_loss + min_val_loss_ste, 1, 0)
# pvals = [ttest(mean1 = min_val_loss, std1 = min_val_loss_std, nobs1 = 10, 
#         mean2 = res_py.validation_loss_mean[i], std2 =  res_py.validation_loss_std[i], 
#        nobs2 = 10).pvalue for i in range(res_py.shape[0])]
# res_py['pvals'] = pvals
print(res_py[['sigma', 'alpha', 'validation_loss_mean', 'same_distribution_accuracy_mean', 
             'shift_distribution_accuracy_mean', 'pvals']])

# With the 0/1 flag, `> 0.05` keeps exactly the flagged rows; among those, the
# largest alpha is selected.
res_py = res_py[(res_py.pvals > 0.05)]
res_py = res_py[(res_py.alpha == res_py.alpha.max())]

print(res_py[['sigma', 'alpha', 'validation_loss_mean', 'same_distribution_accuracy_mean', 
             'shift_distribution_accuracy_mean']])

In [None]:
# Keep only the (py1_y0_s, sigma) combinations listed in optimal_sigmas; the
# shape is printed before and after to show how many rows the merge removes.
# NOTE(review): optimal_sigmas is not assigned in any cell visible in this
# notebook (the selection cell only defines a scalar optimal_sigma) — confirm
# where it is supposed to come from.
print(res_model.shape)
res_model = res_model.merge(optimal_sigmas, on=['py1_y0_s', 'sigma'])
print(res_model.shape)

In [None]:
# Print the full res_model frame (plain-text repr).
print(res_model)

In [None]:
# Work on the py1_y0_s == 0.3 rows. The explicit .copy() makes res_s an
# independent frame, so the column assignments below do not act on a slice of
# `res` (which triggers SettingWithCopyWarning).
# NOTE(review): this expects `res` to hold the aggregated *_mean / *_std
# columns — confirm which cell is meant to produce it.
res_s = res[res.py1_y0_s == 0.3].copy()
# res_s.drop(['l2_penalty', 'dropout_rate', 'embedding_dim'], axis=1, inplace=True)
# One integer group id per distinct sigma (the plotting cells draw one panel
# per group).
res_s['group'] = res_s.groupby(['sigma']).ngroup()
res_s = res_s.drop(columns=['py1_y0_s'])
# alpha_id is the rank of each alpha among the sorted unique alphas.
all_alphas = np.unique(res_s.alpha)
vals_to_ind = dict(zip(all_alphas, range(len(all_alphas))))
res_s['alpha_id'] = res_s['alpha'].map(lambda x: vals_to_ind[x])

# Compact view with short column names, ordered by mean validation loss.
# Note that 'val_std' holds validation_mmd_std.
temp_res = res_s[['alpha', 'sigma', 'validation_loss_mean', 'validation_loss_std',
                  'validation_mmd_mean', 'validation_mmd_std', 'same_distribution_accuracy_mean',
                  'shift_distribution_accuracy_mean']].copy()
temp_res.columns = ['alpha', 'sigma', 'val_loss', 'val_loss_std', 'val_mmd', 'val_std', 'same_acc', 'shift_acc']
temp_res = temp_res.sort_values('val_loss').reset_index()

# Reference configuration: row 0 is the lowest validation loss after sorting.
min_val_loss = temp_res.val_loss[0]
min_val_loss_std = temp_res.val_loss_std[0]

# Two-sample t-test from summary statistics of every configuration against the
# reference. nobs is hard-coded to 2 — presumably the number of runs per
# configuration; TODO confirm.
pvals = [ttest(mean1=min_val_loss, std1=min_val_loss_std, nobs1=2,
               mean2=loss_mean, std2=loss_std, nobs2=2).pvalue
         for loss_mean, loss_std in zip(temp_res.val_loss, temp_res.val_loss_std)]
temp_res = temp_res.drop(columns='val_loss_std')
temp_res['pvals'] = pvals
print(temp_res[(temp_res.sigma==2.0)])
# temp_res.fillna(0, inplace=True)

In [None]:
# One stacked panel per group of res_s, plotting the 'mmd' metric columns.
all_groups = np.unique(res_s.group)
# squeeze=False keeps `axes` two-dimensional even when there is only one group;
# otherwise plt.subplots returns a bare Axes and indexing it fails.
fig, axes = plt.subplots(len(all_groups), 1, figsize=(14, len(all_groups)*5),
                         squeeze=False)
# plt.get_cmap replaces plt.cm.get_cmap, which recent matplotlib releases removed.
cmap = plt.get_cmap('gist_rainbow', len(all_groups))

for gid, group in enumerate(all_groups):
    plot_errorbars_same_and_shifted(axes[gid, 0], cmap, res_s,
        group, 'mmd')
plt.show()
# Release the figure explicitly so repeated runs do not accumulate open figures.
plt.close(fig)

In [None]:
# Flag configurations whose mean validation loss is within 0.1 of the lowest
# one (1 = within the tolerance, 0 = not), then count each flag value.
loss_cutoff = np.min(res_op.validation_loss_mean + 0.1)
res_op['within_eps'] = np.where(res_op.validation_loss_mean <= loss_cutoff, 1, 0)
print(res_op.within_eps.value_counts())

In [None]:
# Metrics of the configuration(s) in res_op with the lowest mean validation loss.
lowest_op_loss = np.min(res_op.validation_loss_mean)
res_op[res_op.validation_loss_mean == lowest_op_loss][[
    'validation_loss_mean',
    'same_distribution_accuracy_mean',
    'shift_distribution_accuracy_mean']]

In [None]:
# Among the configurations flagged within_eps, take the largest alpha and show
# the metrics of the matching row(s).
eps_rows = res_op.within_eps == 1
max_alpha_within_eps = np.max(res_op.alpha[eps_rows])
res_op[eps_rows & (res_op.alpha == max_alpha_within_eps)][[
    'validation_loss_mean',
    'same_distribution_accuracy_mean',
    'shift_distribution_accuracy_mean']]


In [None]:
# One stacked panel per group of res_op, plotting the 'loss' metric columns.
all_groups = np.unique(res_op.group)
# squeeze=False keeps `axes` two-dimensional even when there is only one group;
# otherwise plt.subplots returns a bare Axes and indexing it fails.
fig, axes = plt.subplots(len(all_groups), 1, figsize=(14, len(all_groups)*5),
                         squeeze=False)
# plt.get_cmap replaces plt.cm.get_cmap, which recent matplotlib releases removed.
cmap = plt.get_cmap('gist_rainbow', len(all_groups))

for gid, group in enumerate(all_groups):
    plot_errorbars_same_and_shifted(axes[gid, 0], cmap, res_op,
        group, 'loss')
plt.show()
# Release the figure explicitly so repeated runs do not accumulate open figures.
plt.close(fig)

In [None]:
# Group id and headline metrics of the configuration(s) in res_s with the lowest
# mean validation loss.
lowest_s_loss = np.min(res_s.validation_loss_mean)
res_s[res_s.validation_loss_mean == lowest_s_loss][[
    'group',
    'validation_loss_mean',
    'validation_accuracy_mean',
    'same_distribution_accuracy_mean',
    'shift_distribution_accuracy_mean']]

In [None]:
# Show every row of res_s in group 18 (hard-coded group id).
res_s[(res_s.group==18)]

In [None]:
# Accuracy metrics for the row(s) of res_s with group 47 and alpha_id 5
# (both ids are hard-coded).
res_s[((res_s.group==47) & (res_s.alpha_id ==5))][['validation_accuracy_mean', 
                                                   'same_distribution_accuracy_mean', 
                                                   'shift_distribution_accuracy_mean']]

In [None]:
# Compare mean validation loss across groups 78, 79 and 89 at alpha_id 3
# (hard-coded ids).
res_s[((res_s.group.isin([78,79,89])) & (res_s.alpha_id==3))][['validation_loss_mean', 'group']]

In [None]:
# Show every row of res_s in group 79 (hard-coded group id).
res_s[(res_s.group==79)]

In [None]:
# Restrict res_s to one architecture/regularization setting (embedding_dim 1000,
# no l2 penalty, no dropout) and show its lowest-validation-loss row(s).
op_setting_mask = (
    (res_s.embedding_dim==1000)
    & (res_s.l2_penalty==0.0)
    & (res_s.dropout_rate ==0)
)
res_op = res_s[op_setting_mask]
res_op[res_op.validation_loss_mean == np.min(res_op.validation_loss_mean)]

In [None]:
# Show the row(s) of res_op with group 47 and alpha_id 5 (hard-coded ids).
res_op[((res_op.group == 47) & (res_op.alpha_id ==5))]