In [None]:
import os

import numpy as np
import tensorflow as tf

from batcher import Batcher
from fc import FullyConnected
from ae import AutoEncoder, VAE
from conv import Convolutional

In [None]:
# Hyper-parameters shared by every model; each entry in `models` adds only
# its architecture-specific depth setting on top of these.
base_params = {'batch_size': 64, 'learning_rate': 1e-3}

# (model class, constructor params) pairs; every pair is trained once per
# filter per repetition of the sweep.
models = [
    (FullyConnected, dict(base_params, num_hid=0)),
    (FullyConnected, dict(base_params, num_hid=2)),
    (Convolutional, dict(base_params, num_conv=0)),
    (Convolutional, dict(base_params, num_conv=3)),
    (AutoEncoder, dict(base_params, num_hid=2)),
    (VAE, dict(base_params, num_hid=2)),
]

# (name, digit list) pairs; the list is passed to model.train as `to_filter`
# and the name goes into output file names. Presumably the listed digits are
# held out from training -- confirm in the model classes.
digit_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
filters = (
    [(word, [digit]) for digit, word in enumerate(digit_words, start=1)]
    + [('evens', list(range(0, 10, 2))), ('odds', list(range(1, 10, 2)))]
)

# Directory receiving one .npy logits file per run plus the saved test split.
save_dir = 'logit_output'

In [None]:
# One Batcher over the local MNIST directory is shared by every experiment;
# its test_img / test_lbl are saved at the end of the notebook.
# NOTE(review): whether Batcher downloads missing data -- see batcher.py.
mnist_data_dir = 'MNIST_data'
batcher = Batcher(mnist_data_dir)

In [None]:
def run_single_experiment(model_class, batcher, params, epochs, to_filter):
    """Train one freshly built model and return its predictions.

    The default TensorFlow graph is cleared first so repeated calls in the
    same kernel start from an empty graph, and the session is closed before
    the result reaches the caller.

    Args:
        model_class: class constructed as ``model_class(session, batcher, params)``;
            must provide ``initialize_variables``, ``train`` and ``predict``.
        batcher: data source handed straight to the model constructor.
        params: hyper-parameter dict handed straight to the model constructor.
        epochs: forwarded as the first argument of ``model.train``.
        to_filter: forwarded as the second argument of ``model.train``
            (presumably digit labels to hold out -- confirm in the model code).

    Returns:
        Whatever ``model.predict()`` returns (the caller saves it as logits).
    """
    tf.reset_default_graph()

    # Let TF claim GPU memory incrementally rather than reserving it all up front.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        net = model_class(session, batcher, params)
        net.initialize_variables()
        net.train(epochs, to_filter)
        return net.predict()

In [None]:
# Full sweep: every model x every filter, repeated NUM_EXPERIMENTS times.
# No seed is set in this notebook, so repetitions differ run to run.
NUM_EXPERIMENTS = 5
EPOCHS = 10


def _model_tag(model_class, params):
    """Return the file-name prefix for a model: its class name, plus depth where it varies.

    FullyConnected and Convolutional appear in `models` with several depths,
    so their depth parameter is appended to keep output file names distinct.
    """
    name = model_class.__name__
    if name == 'FullyConnected':
        return name + str(params['num_hid'])
    if name == 'Convolutional':
        return name + str(params['num_conv'])
    return name


# np.save does not create parent directories; without this the first save
# raises FileNotFoundError on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)

for experiment in range(NUM_EXPERIMENTS):
    for filter_name, filter_list in filters:
        for model_class, params in models:
            # Derive the name before training so a malformed params dict fails
            # immediately instead of after EPOCHS of training.
            model_name = _model_tag(model_class, params)
            logits = run_single_experiment(model_class, batcher, params, EPOCHS, filter_list)
            save_name = '_'.join([model_name, filter_name, str(experiment)])
            # np.save appends the .npy extension to save_name.
            np.save(os.path.join(save_dir, save_name), logits)

In [None]:
# Save the shared test split next to the logits so downstream analysis can
# pair saved predictions with images and labels.
# NOTE(review): assumes model.predict() returns rows in test-set order -- confirm.
# makedirs keeps this cell runnable on its own; np.save won't create save_dir.
os.makedirs(save_dir, exist_ok=True)
np.save(os.path.join(save_dir, 'test_images'), batcher.test_img)
np.save(os.path.join(save_dir, 'test_labels'), batcher.test_lbl)