In [1]:
import torch
import numpy as np
import time

from math import pi
from scipy.special import logsumexp

import PCAE_pytorch

In [None]:
def zo_to_pm(data):
    """Affinely rescale `data` so that 0 maps to -1 and 1 maps to +1.

    Works on anything supporting scalar multiplication and subtraction
    (plain numbers, NumPy arrays); the arithmetic is applied as 2*x - 1.
    """
    doubled = 2.0 * data
    return doubled - 1.0
def binarize_stoch(dataset, rng=None):
    """Draw a stochastic binary version of `dataset`.

    Every entry is used as the success probability of an independent
    single-trial binomial (Bernoulli) draw, so entries are assumed to lie
    in [0, 1].

    Args:
        dataset: Array-like of probabilities (ndarray, nested list, ...).
        rng: Optional random source exposing ``binomial(n, p, size)``, e.g.
            ``np.random.RandomState`` or ``np.random.Generator``. When left
            as None, NumPy's global random state is used (original behavior).

    Returns:
        Integer ndarray of 0/1 draws with the same shape as `dataset`.
    """
    # Coerce first so inputs without a `.shape` attribute (e.g. lists) work.
    probs = np.asarray(dataset)
    sampler = np.random if rng is None else rng
    return sampler.binomial(1, probs, probs.shape)

"""
Example usage below: autoencoding MNIST.
"""
if __name__ == "__main__":
    # Example driver: fits the autoencoder on stochastically binarized MNIST,
    # then encodes the test set one minibatch at a time.
    import scipy.io
    from keras.datasets import mnist

    # Load MNIST, flatten every image into one row, and divide pixels by 255.
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    digits_train_all = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))).astype('float32') / 255.
    digits_test_all = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))).astype('float32') / 255.
    train_size = digits_train_all.shape[0]
    test_size = digits_test_all.shape[0]
    digits_train = digits_train_all[0:train_size, :]
    digits_test = digits_test_all[0:test_size, :]
    data_train = zo_to_pm(binarize_stoch(digits_train))     # Binarize stochastically
    data_test = zo_to_pm(binarize_stoch(digits_test))
    data_test = np.random.permutation(data_test)    # Shuffle test data
    dim_input = data_train.shape[1]

    # Initialize autoencoder
    dim_hidden = 100
    batch_size = 250
    L1_eps = 0.0
    enc_binary = True
    # NOTE(review): neither `tf` nor `PCAE` is bound by the imports visible in
    # this notebook (only `PCAE_pytorch` is imported) -- confirm where they are
    # expected to come from before running this on a fresh kernel.
    tf.reset_default_graph()
    pc_ae = PCAE(dim_input, dim_hidden, enc_binary=enc_binary, batch_size=batch_size, L1_eps=L1_eps, 
                 optimizer=tf.train.AdagradOptimizer(0.3))
    num_epochs = 60
    losses_list = []
    slacks_list = []
    max_W_list = []
    inittime = time.time()
    for epoch_ctr in range(num_epochs):
        # Each "epoch" here trains on one freshly sampled minibatch (no repeats within it).
        data_mbatch = data_train[np.random.choice(data_train.shape[0], batch_size, replace=False)]
        print('Epoch: \t ' + str(epoch_ctr))
        # Train more accurately as we get closer to the optimum and W, B get better. This works fine but other settings work too.
        iters_encode = 35*(epoch_ctr + 1)
        iters_decode = 35*(epoch_ctr + 1)
        encs, corrs, _, _ = pc_ae.encode(data_mbatch, iters_encode=iters_encode, display_step=100)
        losslist, slacklist, maxWlist = pc_ae.decode_fit(
            data_mbatch, encodings=encs.T, iters_decode=iters_decode, display_step=100)
        # Accumulate the per-iteration training histories across epochs.
        losses_list.extend(losslist)
        slacks_list.extend(slacklist)
        max_W_list.extend(maxWlist)
    print('Total time taken: \t' + str(time.time() - inittime))

    # Now compute test set encodings.
    encs_list = []
    corrs_list = []
    celosses = []
    inittime = time.time()
    # Floor division keeps the batch count an int on Python 3; a trailing
    # partial batch (if any) is skipped, exactly as integer `/` did on Python 2.
    for step in range(test_size // batch_size):
        print('-- Epoch %02d --' % (step + 1))
        offset = step * batch_size
        data_mb = data_test[offset:(offset + batch_size)]
        # Use dedicated names so the training histories in `losses_list` /
        # `slacks_list` are not overwritten by per-batch test results.
        encs, corrs, enc_losses, enc_slacks = pc_ae.encode(data_mb, iters_encode=1000, display_step=200)
        encs_list.append(encs)
        corrs_list.append(corrs)
        celosses.append(enc_losses[-1])     # Keep only the final loss of this batch
    test_encs = np.concatenate(tuple(encs_list), axis=1)      # Test set encodings
    test_corrs = np.mean(corrs_list, axis=0)                  # B matrix associated with test encodings
    print('Mean loss: \t' + str(np.mean(celosses)))
    print('Total time taken: \t' + str(time.time() - inittime))