In [1]:
import numpy as np
import tensorflow as tf
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tensorflow.contrib import slim
from tensorflow.contrib import distributions

from tqdm import tqdm, tqdm_notebook
import os

from sgp_utils import nlog_fitc, nlog_vfe, fitc_pred, vfe_pred

class FLAGS:
    def __getattr__(self, name):
        self.__dict__[name] = FLAGS()
        return self.__dict__[name]

FLAGS = FLAGS()

FLAGS.img_path = './img'
FLAGS.niter = 10000
FLAGS.burn_in = 100
FLAGS.print_freq = 500
FLAGS.g_lr = 2e-5
FLAGS.d_lr = 1e-4
FLAGS.batch_size = 200
FLAGS.gen_type = 1
FLAGS.dataset = 'kin40k'
FLAGS.val_prc = .2
FLAGS.n_clusters = 100
FLAGS.n_z = 300
FLAGS.num_refs = 5
FLAGS.num_sgp_samples = 20
FLAGS.sgp_approx = 'vfe'

In [2]:
def load_kin40k(val_prc=0):
    
    import pandas as pd

    kin40k_dir = '/Users/mme/Dropbox/_projects/gp_gan/data/kin40k/'

    # Read KIN40K Data
    train_x = pd.read_csv(kin40k_dir+'kin40k_train_data.asc',sep=' ',header=None,skipinitialspace=True).values
    train_y = pd.read_csv(kin40k_dir+'kin40k_train_labels.asc',sep=' ',header=None,skipinitialspace=True).values
    test_x = pd.read_csv(kin40k_dir+'kin40k_test_data.asc',sep=' ',header=None,skipinitialspace=True).values
    test_y = pd.read_csv(kin40k_dir+'kin40k_test_labels.asc',sep=' ',header=None,skipinitialspace=True).values
    train_y = np.squeeze(train_y)
    test_y = np.squeeze(test_y)

    # Normalize KIN40K Data

    x_var = np.var(train_x,axis=0)
    x_mean = np.mean(train_x,axis=0)
    y_var = np.var(train_y)
    y_mean = np.mean(train_y)

    train_x = (train_x-x_mean)/x_var
    test_x = (test_x-x_mean)/x_var
    train_y = (train_y-y_mean)/y_var
    test_y = (test_y-y_mean)/y_var

    # Create a validation set (20% of training data)

    val_indices = np.random.permutation(np.arange(len(train_x)))<int(val_prc*len(train_x))
    val_x = train_x[val_indices,:]
    val_y = train_y[val_indices]
    train_x = train_x[~val_indices,:]
    train_y = train_y[~val_indices]

    return train_x, train_y, val_x, val_y, test_x, test_y

def load_otherdata():
    return None

In [3]:
def kmeans_mixture_model(d,n_clusters=100,random_state=0):
    """
    Cluster x and return cluster centers and cluster widths (variances)
    """
    from sklearn.cluster import MiniBatchKMeans
    km = MiniBatchKMeans(n_clusters=n_clusters,random_state=random_state).fit(d)

    lbls = km.labels_
    cc = km.cluster_centers_

    d_centered = np.array([x - cc[y] for x,y in zip(d,lbls)])
    c_widths = np.array([np.sum(d_centered[lbls==c,:]**2,axis=0)/np.sum(lbls==c) 
        for c in range(n_clusters)])

    weights = np.array([np.sum(lbls==c) for c in range(n_clusters)])
    weights = weights/np.sum(weights)

    mixture_model = {
        'n_components':n_clusters,
        'n_dims':np.shape(d)[1],
        'weights':weights,
        'means':cc,
        'sds':np.sqrt(c_widths)
        }

    return mixture_model

In [4]:
def logdet(matrix):
    chol = tf.cholesky(matrix)
    return 2.0 * tf.reduce_sum(tf.log(tf.real(tf.matrix_diag_part(chol))),reduction_indices=[-1])

def pairwise_distance(x1,x2):
    r1 = tf.reduce_sum(x1**2,axis=2)[:,:,tf.newaxis]
    r2 = tf.reduce_sum(x2**2,axis=2)[:,tf.newaxis,:]
    r12 = tf.matmul(x1,tf.matrix_transpose(x2))
    return r1+r2-2*r12

def nlog_vfe(x,y,m,sls,sfs,noise):

    P = tf.shape(m)[0] # number of pseudo-input samples
    y = tf.tile(tf.expand_dims(tf.cast(y,dtype=tf.float32),0),(P,1))
    x = tf.tile(tf.expand_dims(tf.cast(x,dtype=tf.float32),0),(P,1,1))

    D = tf.shape(x)[2]
    X = tf.shape(x)[1]

    logtp = tf.constant(np.log(2.*np.pi),dtype=tf.float32)

    #m = m + 1e-6*tf.random_normal(tf.shape(m),dtype=tf.float32)

    M = tf.shape(m)[1]

    jitter = 1e-6*tf.eye(M,dtype=tf.float32)[tf.newaxis,:,:]

    xm = pairwise_distance(x,m)
    mm = pairwise_distance(m,m)

    kxm = sfs[:,tf.newaxis,tf.newaxis]*tf.exp(-.5*xm/sls[:,tf.newaxis,tf.newaxis])
    kmm = sfs[:,tf.newaxis,tf.newaxis]*tf.exp(-.5*mm/sls[:,tf.newaxis,tf.newaxis])
    
    kmx = tf.matrix_transpose(kxm)
    kmmi = tf.matrix_inverse(kmm+jitter)

    qmm_diag = tf.reduce_sum(tf.matmul(kxm,kmmi)*kxm,axis=2)
    
    gd = noise[:,tf.newaxis]
    gid = 1/gd

    tr = sfs*tf.cast(X,tf.float32)-tf.reduce_sum(qmm_diag,axis=1)
    
    kmx_gi_kxm = tf.matmul(kmx,kxm)/noise[:,tf.newaxis,tf.newaxis]

    giy = gid*y
    kgiy = tf.reduce_sum(kmx*giy[:,tf.newaxis,:],axis=2)

    inner = kmmi+kmx_gi_kxm+jitter

    covd = logdet(inner)+logdet(kmm+jitter)+tf.log(noise)*tf.cast(X,dtype=tf.float32)
    
    t1 = .5*tf.cast(X,tf.float32)*logtp
    t2 = .5*tf.reduce_sum(y*y*gid,axis=1) - .5*tf.reduce_sum(
        tf.reduce_sum(kgiy[:,:,tf.newaxis]*tf.matrix_inverse(inner),axis=1)*kgiy,axis=1)
    t3 = .5*covd
    t4 = .5*tf.div(tr,noise)
    
    return t1+t2+t3+t4

In [5]:
# Code to make and evaluate predictions

#def get_mse(pred,real):
#    se = (pred-real)**2
#    ms = np.mean(se)
#    return ms, ms/np.var(real)

def get_mse(pred,real): # check axes
    se = (pred-real)**2
    ms = tf.reduce_mean(se)
    _, var = tf.nn.moments(real, axes=[0])
    return ms, ms/var

def vfe_pred(x,y,t,m,sls,sfs,noise):
    
    P = tf.shape(m)[0] # number of pseudo-input samples
    y = tf.tile(tf.expand_dims(tf.cast(y,dtype=tf.float32),0),(P,1))
    x = tf.tile(tf.expand_dims(tf.cast(x,dtype=tf.float32),0),(P,1,1))
    t = tf.tile(tf.expand_dims(tf.cast(t,dtype=tf.float32),0),(P,1,1))

    D = tf.shape(x)[2]
    X = tf.shape(x)[1]

    logtp = tf.constant(np.log(2.*np.pi),dtype=tf.float32)

    M = tf.shape(m)[1]

    jitter = 1e-6*tf.eye(M,dtype=tf.float32)[tf.newaxis,:,:]

    xm = pairwise_distance(x,m)
    mm = pairwise_distance(m,m)

    kxm = sfs[:,tf.newaxis,tf.newaxis]*tf.exp(-.5*xm/sls[:,tf.newaxis,tf.newaxis])
    kmm = sfs[:,tf.newaxis,tf.newaxis]*tf.exp(-.5*mm/sls[:,tf.newaxis,tf.newaxis])
    kmx = tf.matrix_transpose(kxm)

    tm = pairwise_distance(t,m)

    ktm = sfs[:,tf.newaxis,tf.newaxis]*tf.exp(-.5*tm/sls[:,tf.newaxis,tf.newaxis])

    giy = y/noise[:,tf.newaxis]
    kgiy = tf.reduce_sum(kmx*giy[:,tf.newaxis,:],axis=2)
    kmx_gi_kxm = tf.matmul(kmx,kxm)/noise[:,tf.newaxis,tf.newaxis]
    qm = kmm+kmx_gi_kxm+jitter

    ph = tf.reduce_sum(tf.matrix_inverse(qm)*kgiy[:,tf.newaxis,:],axis=2)
    mean = tf.reduce_sum(ktm*ph[:,tf.newaxis,:],axis=2)

    return mean

In [16]:
### MAIN ###

# Load data

if FLAGS.dataset == 'kin40k':
    train_x, train_y, val_x, val_y, test_x, test_y = load_kin40k(val_prc=FLAGS.val_prc)

# PSEUDOCODE HERE FOR THE MOMENT

initial_model = kmeans_mixture_model(train_x,n_clusters=FLAGS.n_clusters)

sls = tf.expand_dims(tf.Variable(1.,name='sls',dtype=tf.float32),axis=0)
sfs = tf.expand_dims(tf.Variable(1.,name='sfs',dtype=tf.float32),axis=0)
noise = tf.expand_dims(tf.Variable(1e-3,name='noise',dtype=tf.float32),axis=0)

z = tf.expand_dims(tf.Variable(initial_model['means'],name='z',dtype=tf.float32),axis=0)

# Get SGP stuff
sgp_nlogprob = nlog_vfe(train_x,train_y,z,sls,sfs,noise)
   
train_step = tf.train.AdamOptimizer(1e-1).minimize(sgp_nlogprob)

pred = vfe_pred(train_x,train_y,test_x,z,sls,sfs,noise)
mse, nmse = get_mse(tf.squeeze(pred),tf.constant(test_y,dtype=tf.float32))

# Get MSE / normalized MSE

outdir = FLAGS.img_path
if not os.path.exists(outdir):
    os.makedirs(outdir)

In [17]:
def run_training(sess, niter=FLAGS.niter, printfreq=100):

    progress = tqdm_notebook(range(niter))
    error = []

    for i in progress:

        sgp_nlogprob_,sls_,sfs_,noise_,_,nmse_ = sess.run([sgp_nlogprob,sls,sfs,noise,train_step,nmse])

        #print(np.shape(y_),np.shape(refsample_))

        progress.set_description("nlp=%.0f,l=%.3f,s=%.3f,n=%.3f"  % (sgp_nlogprob_,sls_,sfs_,noise_))
        
        error.append(nmse_)
        
        if i%printfreq == (printfreq-1):
            print('normalized MSE is %.3f' % nmse_)
        
    # if it's time, check the error

In [18]:
try:
    s.close()
except NameError:
    pass
s = tf.InteractiveSession()
s.run(tf.global_variables_initializer())
run_training(s,niter=FLAGS.niter)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

normalized MSE is 0.323
normalized MSE is 0.301
normalized MSE is 0.288
normalized MSE is 0.282
normalized MSE is 0.274
normalized MSE is 0.264
normalized MSE is 0.257
normalized MSE is 0.253
normalized MSE is 0.248
normalized MSE is 0.242
normalized MSE is 0.239


KeyboardInterrupt: 