In [None]:
import os
import sys
# Add the directory two levels above the working directory to sys.path
# (presumably the repository root, so the `lmmvae` package imported in the
# next cell can be found -- confirm against the repo layout).
module_path = os.path.abspath('../..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

from lmmvae.dim_reduction_images import run_dim_reduction_images
from lmmvae.simulation import Count
from lmmvae.utils_images import get_generators, sample_split

In [None]:
# Directory passed to get_generators as the image location, and the CSV
# table describing the images (it is expected to contain at least the
# 'img_file' and 'celeb' columns used by the cells below).
images_dir = '../../data/img_align_celeba_png/'
images_df = pd.read_csv(filepath_or_buffer='../../data/celeba_small.csv')

# Preview the first five rows.
images_df.head(n=5)

In [None]:
# Number of distinct values in the 'celeb' column (a missing value, if any,
# is counted as one more distinct value, exactly like len(unique())).
n_cats_celebs = images_df['celeb'].nunique(dropna=False)
print(f'no. of sources: {n_cats_celebs}')

In [None]:
# Experiment settings used by the cells below (most are forwarded verbatim to
# run_dim_reduction_images / get_generators). Values that this use-case does
# not need are left as None. The grouping is by name only; the precise meaning
# of each value is defined by the lmmvae functions that receive it.

# -- image input --
img_height, img_width, channels = 72, 60, 3
img_file_col = 'img_file'

# -- random-effects (RE) column --
RE_col = 'celeb'
RE_inputs = [images_df[RE_col].values]
RE_cols_prefix = 'z'
mode = 'categorical'
qs = [n_cats_celebs]
q_spatial = None
n_sig2bs = 1
n_sig2bs_spatial = 0
est_cors = []
max_spatial_locs = 100

# -- network / training --
n_neurons = [32, 16]
n_neurons_re = n_neurons  # same list object as n_neurons, not a copy
dropout = None
activation = 'relu'
epochs = 200
batch_size = 1000
patience = None

# -- remaining positional arguments, not needed here --
thresh = None
kernel_root = None
U = None
B_list = None
time2measure_dict = None

# -- train/test split strategy --
pred_unknown_clusters = False

# When pred_unknown_clusters is set, filtering is done on the RE column itself.
# Otherwise a helper column holding the row number 0..n-1 is added to
# images_df and used as the filter column.
filter_col = RE_col if pred_unknown_clusters else 'filter_col'
if not pred_unknown_clusters:
    images_df[filter_col] = np.arange(images_df.shape[0])

In [None]:
# Empty results table; iterate_reg_types appends one row per fitted model.
res = pd.DataFrame(columns=(
    # run configuration and headline metrics
    ['d', 'beta', 're_prior', 'experiment', 'exp_type', 'mse_X', 'sigma_b0_est', 'n_epoch', 'time']
    # loss columns with the '_tr' suffix
    + ['total_loss_tr', 'recon_loss_tr', 'kl_loss_tr', 're_kl_loss_tr']
    # loss columns with the '_te' suffix
    + ['total_loss_te', 'recon_loss_te', 'kl_loss_te', 're_kl_loss_te']
))

# Shuffled 5-fold splitter with a fixed seed.
kf = KFold(n_splits=5, shuffle=True, random_state=40)

# Iterator supplying the row label for each new entry in `res`.
counter = Count().gen()

In [None]:
def iterate_reg_types(train_generator, valid_generator, test_generator, train_RE_inputs, counter, d, beta, re_prior, i, verbose):
    """Fit the three generator-based models for one configuration and log the results.

    Calls ``run_dim_reduction_images`` with the method names
    'lmmvae-cnn-gen', 'vae-ignore-cnn-gen' and 'vae-embed-cnn-gen', in that
    order, printing the returned MSE after each one. Once all three have
    finished, one row per model is appended to the notebook-level ``res``
    DataFrame and the whole table is written to 'res_celeba.csv'.

    All other settings (image size, epochs, batch size, ...) are read from
    notebook-level variables.

    Args:
        train_generator, valid_generator, test_generator: forwarded to every
            ``run_dim_reduction_images`` call as keyword arguments.
        train_RE_inputs: forwarded only to the 'lmmvae-cnn-gen' call.
        counter: iterator; ``next(counter)`` supplies the row label for each
            new row of ``res``.
        d, beta, re_prior: forwarded to each call and stored in the result rows.
        i: stored in the 'experiment' column of the result rows.
        verbose: forwarded to each call.

    Returns:
        None. ``res`` is modified in place and the CSV file is overwritten.

    Note:
        The 'pca-ignore' and 'pca-ohe' baselines that used to sit here as
        commented-out code are not run; those calls took X_train / X_test /
        Z_train / Z_test arrays, which this function does not receive.
    """
    def fit(method, done_msg, **extra_kwargs):
        # Every method gets the same positional settings; only the method
        # name and any extra keyword arguments differ.
        mse, sigmas, _, n_epochs, elapsed, losses = run_dim_reduction_images(
            None, None, None, None,
            img_height, img_width, channels, d, method,
            thresh, epochs, qs, q_spatial, n_sig2bs, n_sig2bs_spatial, est_cors,
            batch_size, patience, n_neurons, n_neurons_re, dropout,
            activation, mode, beta, re_prior, kernel_root, pred_unknown_clusters,
            max_spatial_locs, time2measure_dict, verbose, U, B_list,
            train_generator=train_generator, valid_generator=valid_generator,
            test_generator=test_generator, **extra_kwargs)
        print(done_msg % mse)
        return mse, sigmas, n_epochs, elapsed, losses

    # Run all three models before anything is recorded.
    lmmvae = fit('lmmvae-cnn-gen', '   finished lmmvae-cnn, mse: %.3f',
                 train_RE_inputs=train_RE_inputs)
    vae_ignore = fit('vae-ignore-cnn-gen', '   finished vae-ignore, mse: %.3f')
    vae_embed = fit('vae-embed-cnn-gen', '   finished vae-embed, mse: %.3f')

    # (exp_type label, fit output, value for the 'sigma_b0_est' column).
    # Only the LMMVAE run reports a sigma estimate; the others store NaN.
    outcomes = (
        ('lmmvae', lmmvae, lmmvae[1][1][0]),
        ('vae-ignore', vae_ignore, np.nan),
        ('vae-embed', vae_embed, np.nan),
    )
    for exp_type, (mse, _, n_epochs, elapsed, losses), sigma_b0 in outcomes:
        res.loc[next(counter)] = [d, beta, re_prior, i, exp_type, mse, sigma_b0, n_epochs, elapsed] + losses

    # Rewrite the full table after every call so partial results are on disk.
    res.to_csv('res_celeba.csv')

In [None]:
# Value lists swept by the experiment loop: every (beta, d, re_prior)
# combination is run once per cross-validation fold.
betas = [0.01]
ds = [100, 200, 500]
re_priors = [0.001]

In [None]:
# Run every (beta, d, re_prior) combination over the K-fold splits.
for beta in betas:
    for d in ds:
        for re_prior in re_priors:
            print(f'beta: {beta}, d: {d}, re_prior: {re_prior}:')

            # With pred_unknown_clusters the folds are taken over the
            # 0..n_cats_celebs-1 range and matched against RE_col values;
            # otherwise they are taken over row positions and matched
            # against the DataFrame index.
            if pred_unknown_clusters:
                split_over = range(n_cats_celebs)
                split_col = RE_col
            else:
                split_over = range(images_df.shape[0])
                split_col = filter_col

            for i, (train_ids, test_ids) in enumerate(kf.split(split_over)):
                print('  iteration %d' % i)
                # Split the fold's training part again into train / validation.
                train_ids, valid_ids = sample_split(i, train_ids)
                train_generator, valid_generator, test_generator = get_generators(
                    images_df, images_dir, train_ids, valid_ids, test_ids, batch_size,
                    img_file_col, RE_col, split_col, img_height, img_width)

                # Boolean mask selecting the training rows of each RE input.
                id_source = images_df[RE_col] if pred_unknown_clusters else images_df.index
                train_mask = id_source.isin(train_ids)
                train_RE_inputs = [RE_input[train_mask] for RE_input in RE_inputs]

                iterate_reg_types(train_generator, valid_generator, test_generator, train_RE_inputs,
                                  counter, d, beta, re_prior, i, verbose=True)