In [None]:
import os
# choose a particular GPU
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model

import tensorflow_probability as tfp

import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

#physical_devices = tf.config.list_physical_devices('GPU') 
#tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [None]:
class IMLESubsetkLayer(tf.keras.layers.Layer):
    """Keras layer that returns a perturbed top-k 0/1 mask with a custom gradient.

    Forward pass: random noise is added to the logits and, per row, every entry
    that is at least as large as the k-th largest perturbed value is set to 1.0
    (all others to 0.0).

    Backward pass: the same noise is reused to compute a second mask from
    ``logits - _lambda * dy``; the gradient handed back for the logits is
    ``(forward_mask - second_mask) / _lambda``.

    Args:
        _k: number of entries selected per row; also used as the shape
            parameter divisor in ``sample_gumbel_k``.
        _tau: scale applied to the noise produced by ``sample_gumbel_k``.
        _lambda: perturbation strength used in the backward pass.
    """
  
    def __init__(self, _k=10, _tau=30.0, _lambda=10.0):
        super(IMLESubsetkLayer, self).__init__()
        
        self.k = _k
        self._tau = _tau
        self._lambda = _lambda
        # Initialised here but not read or written anywhere else in this class.
        self.samples = None
        self.gumbel_dist = tfp.distributions.Gumbel(loc=0.0, scale=1.0)
        
    @tf.function
    def sample_gumbel(self, shape, eps=1e-20):
        """Draw Gumbel(0, 1) noise of the given shape.

        ``eps`` is accepted but not used by this implementation.
        """
        return self.gumbel_dist.sample(shape)
    
    @tf.function
    def sample_gumbel_k(self, shape):
        """Draw noise of the given shape as a scaled sum of ten Gamma samples.

        For each t in 1..10 a ``tf.random.gamma`` sample is drawn with
        alpha = 1/k and beta = k/t; the ten samples are summed, log(10) is
        subtracted, and the result is multiplied by ``_tau / k``.
        The number of terms (10) is hard-coded in both the ``elems`` constant
        and the ``log(10.0)`` correction, so the two must be changed together.
        """
        
        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0/self.k,  self.k/t), 
                  elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))
        # sum over the ten Gamma draws (leading axis created by map_fn)
        s = tf.reduce_sum(s, 0)
        # the log(m) term, with m = 10 terms in the sum above
        s = s - tf.math.log(10.0)
        # divide by k --> each s[c] has k samples whose sum is distributed as Gumbel(0, 1)
        s = self._tau * (s / self.k)

        return s
       
    @tf.custom_gradient
    def imle_layer(self, logits, hard=False):
        """Perturbed top-k selection with a hand-written gradient.

        Args:
            logits: 2-D tensor (the ``[:, -1]`` indexing below fixes the rank);
                the top-k selection is done independently per row.
            hard: not used in the computation; it is only passed back as the
                second element of the gradient tuple.

        Returns:
            A ``(y_map, custom_grad)`` pair as required by ``tf.custom_gradient``:
            the float32 0/1 mask and the function computing the gradients.
        """

        # toggle between Gumbel(0, 1) and sum-of-Gamma
        #forward_sample = self.sample_gumbel(tf.shape(logits))
        forward_sample = self.sample_gumbel_k(tf.shape(logits))
        gumbel_softmax_sample = logits + forward_sample
        # k-th largest perturbed value of each row, kept as a column for broadcasting
        threshold = tf.expand_dims(tf.nn.top_k(gumbel_softmax_sample, self.k, sorted=True)[0][:,-1], -1)
        # ties with the threshold are all selected, so a row can contain more than k ones
        y_map = tf.cast(tf.greater_equal(gumbel_softmax_sample, threshold), tf.float32)

        def custom_grad(dy):
            
            # perturb and MAP on the target distribution parameters
            # (reuses the forward-pass noise `forward_sample`)
            gumbel_softmax_sample = logits - (self._lambda*dy) + forward_sample
            threshold = tf.expand_dims(tf.nn.top_k(gumbel_softmax_sample, self.k, sorted=True)[0][:,-1], -1)
            map_y = tf.cast(tf.greater_equal(gumbel_softmax_sample, threshold), tf.float32)    
            # now compute the gradient of the conditional log-likelihood
            grad = (1.0 / self._lambda) * tf.math.subtract(y_map, map_y)

            # one entry per input of imle_layer: the gradient for `logits`, and
            # `hard` itself returned in the slot belonging to the `hard` argument
            # NOTE(review): returning None for `hard` may be cleaner — confirm against tf.custom_gradient
            return grad, hard

        return y_map, custom_grad

    def call(self, logits, hard=False):
        """Keras entry point; forwards both arguments to ``imle_layer``."""
        return self.imle_layer(logits, hard)

In [None]:
# Hyper-parameters shared by the model, the loss and the training loop.
PARAMS = {
    "batch_size": 100,  # DiscreteVAE.call asserts that its input batch has exactly this size
    "data_dim": 784,  # width of the decoder output layer (dec_dense3)
    "M": 20,  # row width of the encoder logits after reshaping to (-1, M)
    "N": 20,  # rows per example; the encoder emits N*M logits
    "nb_epoch": 100,  # number of training epochs
    "epsilon_std": 0.01,  # not read anywhere in the visible cells
    "anneal_rate": 0.0003,  # rate of the exponential temperature decay in the training loop
    "init_temperature": 1.0,  # starting value of tau
    "min_temperature": 0.5,  # lower bound for tau
    "learning_rate": 1e-3,  # initial value of the `LR` variable and default `init` of get_learning_rate
    "hard": False,  # forwarded as `hard` during training; also selects the weights file name
}

class DiscreteVAE(tf.keras.Model):
    """Discrete VAE with an MLP encoder, an N x M discrete code and an MLP decoder.

    The encoder maps a flat input to ``N*M`` logits which are reshaped to
    ``(batch*N, M)``. A discrete sample is drawn per row (by default with the
    IMLE subset-k layer; the Gumbel-softmax helpers are kept as an alternative),
    reshaped to ``(batch, N, M)`` and decoded back to ``data_dim`` logits.

    Args:
        params: dict providing at least ``"N"``, ``"M"`` and ``"data_dim"``.
    """

    def __init__(self, params):
        super(DiscreteVAE, self).__init__()

        self.params = params

        # encoder
        self.enc_dense1 = tf.keras.layers.Dense(512, activation='relu')
        self.enc_dense2 = tf.keras.layers.Dense(256, activation='relu')
        self.enc_dense3 = tf.keras.layers.Dense(params["N"]*params["M"])

        # this is our new Gumbel layer
        self.imleLayer = IMLESubsetkLayer(_k=10, _tau=10.0, _lambda=10.0)

        # decoder
        self.flatten = Flatten()
        self.dec_dense1 = tf.keras.layers.Dense(256, activation='relu')
        self.dec_dense2 = tf.keras.layers.Dense(512, activation='relu')
        self.dec_dense3 = tf.keras.layers.Dense(params["data_dim"])

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1) via the inverse-CDF transform of uniform noise.

        ``eps`` guards both logarithms against log(0).
        """
        U = tf.random.uniform(shape, minval=0, maxval=1)
        return -tf.math.log(-tf.math.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a sample from the Gumbel-Softmax distribution.

        Args:
            logits: [batch_size, n_class] unnormalized log-probs.
            temperature: positive scalar dividing the perturbed logits.
        """
        # Fix: the temperature used to be passed as the second positional
        # argument of sample_gumbel, i.e. as `eps`, which distorted the noise.
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, logits, temperature, hard=True):
        """Gumbel-Softmax sample, optionally discretised with a straight-through estimator.

        Args:
            logits: [batch_size, n_class] unnormalized log-probs.
            temperature: non-negative scalar.
            hard: if True, take argmax, but differentiate w.r.t. soft sample y.
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            # one-hot of the row maximum in the forward pass, gradient of the soft sample
            y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        return y

    def decoder(self, x):
        """Flatten the code and map it through the three decoder layers."""
        h = self.flatten(x)
        h = self.dec_dense1(h)
        h = self.dec_dense2(h)
        h = self.dec_dense3(h)
        return h

    def call(self, x, tau, hard=False):
        """Encode ``x``, sample a discrete code and decode it.

        Args:
            x: batch of flat inputs.
            tau: temperature; only used by the (commented-out) Gumbel-softmax path.
            hard: accepted for interface compatibility; the active IMLE path
                always passes ``hard=True`` to the layer.

        Returns:
            ``(logits_y, logits_x)``: encoder logits of shape (batch*N, M) and
            decoder output logits. The sampled code is also stored on
            ``self.sample_y`` with shape (batch, N, M).
        """
        N = self.params["N"]
        M = self.params["M"]

        # encoder
        x = self.enc_dense1(x)
        x = self.enc_dense2(x)
        x = self.enc_dense3(x)   # (batch, N*M)
        logits_y = tf.reshape(x, [-1, M])   # (batch*N, M)

        ###################################################################
        ## here we toggle between methods #################################
        # here we can switch between traditional and our method
        # "traditional" Gumbel Softmax trick
        #y = self.gumbel_softmax(logits=logits_y, temperature=tau, hard=False)
        # IMLE approach -- note: we don't anneal so set temperature once at init
        y = self.imleLayer(logits=logits_y, hard=True)
        ###################################################################

        # Only the category dimension is checked: the number of rows depends on
        # the actual batch size, which need not equal params["batch_size"]
        # (e.g. a final partial batch or an ad-hoc evaluation slice).
        assert y.shape[-1] == M
        y = tf.reshape(y, [-1, N, M])
        self.sample_y = y

        # decoder
        logits_x = self.decoder(y)
        return logits_y, logits_x


def gumbel_loss(model, x, tau, hard=True):
    """Negative ELBO-style loss: reconstruction cross-entropy plus KL to a uniform prior.

    Args:
        model: callable as ``model(x, tau, hard)`` returning
            ``(logits_y, logits_x)``; its ``params`` dict (falling back to the
            module-level ``PARAMS``) supplies ``M`` and ``N``.
        x: batch of inputs, also used as the reconstruction target.
        tau: temperature forwarded to the model.
        hard: forwarded to the model.

    Returns:
        Scalar tensor: batch-mean summed sigmoid cross-entropy plus batch-mean KL.
    """
    # Read the latent layout from the model's own configuration instead of
    # hard-coding 20/20, so the KL term stays correct if M or N are changed.
    params = getattr(model, "params", PARAMS)
    M = params["M"]
    N = params["N"]
    logits_y, logits_x = model(x, tau, hard)

    # cross-entropy: summed over the data dimension, averaged over the batch
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=logits_x)
    cross_ent = tf.math.reduce_sum(cross_ent, 1)
    cross_ent = tf.math.reduce_mean(cross_ent, 0)

    # KL loss against the uniform distribution over M categories
    q_y = tf.nn.softmax(logits_y)   # (batch_size*N, M)
    log_q_y = tf.math.log(q_y + 1e-20)   # (batch_size*N, M); epsilon avoids log(0)
    kl_tmp = tf.reshape(q_y*(log_q_y-tf.math.log(1.0/M)), [-1,N,M])  # (batch_size, N, M)
    KL = tf.math.reduce_sum(kl_tmp, [1, 2])    # (batch_size,)

    KL_mean = tf.math.reduce_mean(KL)
    return cross_ent + KL_mean


def compute_gradients(model, x, tau, hard):
    """Evaluate ``gumbel_loss`` under a gradient tape.

    Returns:
        ``(gradients, loss)`` where ``gradients`` is taken with respect to
        ``model.trainable_variables``.
    """
    with tf.GradientTape() as grad_tape:
        batch_loss = gumbel_loss(model, x, tau, hard)
    grads = grad_tape.gradient(batch_loss, model.trainable_variables)
    return grads, batch_loss


def apply_gradients(optimizer, gradients, variables):
    """Pair each gradient with its variable and hand the pairs to the optimizer."""
    grads_and_vars = zip(gradients, variables)
    optimizer.apply_gradients(grads_and_vars)


def get_learning_rate(step, init=PARAMS["learning_rate"]):
    """Exponentially decayed learning rate as a float32 tensor.

    The rate starts at ``init`` and is multiplied by 0.95 for every 1000 steps
    (fractional steps decay proportionally).
    """
    decay_factor = 0.95 ** (step / 1000.)
    return tf.convert_to_tensor(init * decay_factor, dtype=tf.float32)

In [None]:
%%time

tf.random.set_seed(1234)

model = DiscreteVAE(PARAMS)
plot_model(model, to_file='model_plot.pdf', show_shapes=True, show_layer_names=True)
learning_rate = tf.Variable(PARAMS["learning_rate"], trainable=False, name="LR")

optimizer = tf.keras.optimizers.Adam()

# data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

TRAIN_BUF = 60000
BATCH_SIZE = 100
TEST_BUF = 10000

train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)

# temperature
tau = PARAMS["init_temperature"]
anneal_rate = PARAMS["anneal_rate"]
min_temperature = PARAMS["min_temperature"]

results = []

# Train
for epoch in range(1, PARAMS["nb_epoch"] + 1):
    
    # this is only needed for the standard Gumbel softmax trick
    tau = np.maximum(tau * np.exp(-anneal_rate*epoch), min_temperature)

    for train_x in train_dataset:
        gradients, loss = compute_gradients(model, train_x, tau, hard=PARAMS["hard"])
        apply_gradients(optimizer, gradients, model.trainable_variables)

    print("Epoch:", epoch, ", TRAIN loss:", loss.numpy(), ", Temperature:", tau)

    if epoch % 1 == 0:
        losses = []
        for test_x in test_dataset:
            losses.append(gumbel_loss(model, test_x, tau, hard=True))
        eval_loss = np.mean(losses)
        results.append(eval_loss)
        print("Eval Loss:", eval_loss, "\n")

    if PARAMS['hard'] == True:
        model.save_weights("model.h5")
    else:
        model.save_weights("model_hard.h5")

In [None]:
# Display the per-epoch mean test losses appended to `results` by the training loop.
results

In [None]:
# Reload MNIST, scale pixels to [0, 1] and flatten every image to a vector.
(x_train, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype('float32') / 255.).reshape((len(x_train), -1))
x_test = (x_test.astype('float32') / 255.).reshape((len(x_test), -1))

# Model Forward: three passes over the same 100 test images; each pass draws
# fresh noise inside the model, so the three reconstructions can differ.
eval_batch = x_test[10:110]
decoder_logits = []
for _pass in range(3):
    logits_y, logits_x = model(eval_batch, tau=1.0, hard=True)
    decoder_logits.append(logits_x)

# The stored code belongs to the most recent (third) forward pass.
sample_y = model.sample_y.numpy()

# Squash decoder logits to pixel intensities in [0, 1].
logits_x1, logits_x2, logits_x3 = (tf.sigmoid(lx).numpy() for lx in decoder_logits)

print("sample_y: ", sample_y.shape, ", logits_x.shape:", logits_x1.shape)
code = model.sample_y


def save_plt(original_img, construct_img1, construct_img2, construct_img3, code):
    """Draw a 5x5 grid comparing inputs, codes and three reconstructions.

    Each of the five rows shows, for image index 0, 5, 10, 15 and 20
    respectively: the input image, its 2-D code, and the three reconstructions.
    Image vectors are reshaped to 28x28. The figure is only drawn, not written
    to disk (the savefig call at the end is disabled).
    """
    plt.figure(figsize=(10, 10))
    for first_cell in range(0, 25, 5):
        # `first_cell` is both the index of the image shown in this row and
        # the zero-based subplot position of the row's first panel.
        row_panels = (
            original_img[first_cell, :].reshape(28, 28),        # input img
            code[first_cell, :, :],                             # code
            construct_img1[first_cell, :,].reshape((28, 28)),   # output img 1
            construct_img2[first_cell, :,].reshape((28, 28)),   # output img 2
            construct_img3[first_cell, :,].reshape((28, 28)),   # output img 3
        )
        for column, panel in enumerate(row_panels, start=1):
            plt.subplot(5, 5, first_cell + column)
            plt.imshow(panel, cmap='gray')
            plt.axis('off')

    #plt.savefig('vae-pic/vae_rebuilt.png')

# Plot inputs, the stored latent code and the three sigmoid-activated reconstructions
# for the 100-image test slice used in the forward passes above.
save_plt(x_test[10:110], logits_x1,  logits_x2,  logits_x3, code)

In [None]:
def make_squares(images, nr_images_per_side, image_shape=(28, 28)):
    """Tile the first ``nr_images_per_side**2`` images into one square mosaic.

    Images are laid out row-major: image ``j*nr_images_per_side + i`` ends up
    in tile row ``j``, tile column ``i``.

    Args:
        images: indexable collection of arrays, each reshapeable to ``image_shape``.
        nr_images_per_side: number of tiles per row and per column.
        image_shape: (height, width) of a single tile; defaults to 28x28 so
            existing callers are unaffected.

    Returns:
        2-D array of shape
        ``(nr_images_per_side*height, nr_images_per_side*width)``.

    Raises:
        ValueError: if ``nr_images_per_side`` is 0 (nothing to concatenate).
        IndexError: if fewer than ``nr_images_per_side**2`` images are given.
    """
    tile_rows = []
    for row in range(nr_images_per_side):
        first = row * nr_images_per_side
        tiles = [images[first + col].reshape(image_shape)
                 for col in range(nr_images_per_side)]
        # place the tiles of one mosaic row side by side
        tile_rows.append(np.concatenate(tiles, axis=1))
    # stack the mosaic rows on top of each other
    return np.concatenate(tile_rows, axis=0)

def plot_squares(originals, reconstructs, nr_images_per_side):
    """Save image mosaics of originals, reconstructions and both side by side.

    Writes three PDF files to the working directory: ``original.pdf``,
    ``recons.pdf`` and ``combined.pdf`` (originals on the left,
    reconstructions on the right), all rendered with the viridis colormap.
    """
    save_options = dict(cmap='viridis', format='pdf')

    grid_in = make_squares(originals, nr_images_per_side)
    plt.imsave('original.pdf', grid_in, **save_options)

    grid_out = make_squares(reconstructs, nr_images_per_side)
    plt.imsave('recons.pdf', grid_out, **save_options)

    side_by_side = np.concatenate([grid_in, grid_out], axis=1)
    plt.imsave('combined.pdf', side_by_side, **save_options)

In [None]:
# Write 8x8 mosaics (the first 64 images of the slice) of the inputs and of the
# first set of reconstructions to original.pdf, recons.pdf and combined.pdf.
plot_squares(x_test[10:110], logits_x1, 8)