In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint, 
                                       TensorBoard, Callback)
import datetime
from time import time
import json

import tensorflow as tf

physical_devices = tf.config.experimental.list_physical_devices('GPU')

# Turn on memory growth for every visible GPU instead of indexing the first
# one: indexing raises IndexError on a CPU-only machine, and TensorFlow
# requires the setting to be the same on all GPUs when several are present.
# Background on the API:
# https://github.com/tensorflow/tensorflow/blob/6e559b96c8146ce15c7c03f66e515e31a6b0aa00/tensorflow/python/framework/config.py#L443
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Conv2D, Input, Reshape, 
                                     Lambda, Dense, Conv2DTranspose)

# Input image dimensions in pixels.
IMG_HEIGHT = 64
IMG_WIDTH  = 64

# Dimensionality of the VAE latent vector z.
LATENT_SIZE = 32
# Number of samples per training batch.
BATCH_SIZE  = 128
# Scales the per-sample lower bound applied to the KL term of the loss.
KL_TOLERANCE = 0.5
# Intended optimizer learning rate.
LEARNING_RATE = 1e-4

In [None]:
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.

    Computes ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with
    ``epsilon ~ N(0, I)``, so the sample stays differentiable with respect
    to both the mean and the log-variance.

    # Arguments
        args (tensor): mean and log of variance of Q(z|X)

    # Returns
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    # Take the batch size from the tensor at run time so that a partial final
    # batch (or any batch size used at predict time) still works.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    # The noise scales only the standard deviation; the mean is added
    # unscaled.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


## ENCODER

# Encoder: a stack of stride-2 convolutions feeding two dense heads
# (mean and log-variance) plus the sampled latent vector.
inputs = Input(shape=(64, 64, 3), name='encoder_input')

h = inputs
for n_filters, layer_name in ((32, "enc_conv1"),
                              (64, "enc_conv2"),
                              (128, "enc_conv3"),
                              (256, "enc_conv4")):
    h = Conv2D(n_filters, 4, strides=2, activation="relu", name=layer_name)(h)

# Flatten the final 2x2x256 feature map before the dense heads.
h = Reshape([2*2*256])(h)
z_mean = Dense(LATENT_SIZE, name='z_mean')(h)
z_log_var = Dense(LATENT_SIZE, name='z_log_var')(h)
z = Lambda(sampling, output_shape=(LATENT_SIZE,), name='z')([z_mean, z_log_var])

encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

## DECODER

# Decoder: project the latent vector to a 1x1x1024 map, then upsample it
# with transposed convolutions; the last layer emits 3 sigmoid channels.
latent_inputs = Input(shape=(LATENT_SIZE,), name='decoder_input')

h = Dense(4*256, name="dec_fc")(latent_inputs)
h = Reshape([1, 1, 4*256])(h)
for n_filters, kernel_size, layer_name in ((128, 5, "dec_deconv1"),
                                           (64, 5, "dec_deconv2"),
                                           (32, 6, "dec_deconv3")):
    h = Conv2DTranspose(n_filters, kernel_size, strides=2,
                        activation="relu", name=layer_name)(h)
outputs = Conv2DTranspose(3, 6, strides=2, activation='sigmoid', name="dec_deconv4")(h)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

In [None]:
def sample(decoder, n=128):
    """Decode `n` latent vectors drawn from a standard normal distribution.

    # Arguments
        decoder: model whose `predict` maps latent vectors to outputs
        n (int): number of latent vectors to draw

    # Returns
        whatever `decoder.predict` returns for the drawn latent batch
    """
    latent_shape = (n, LATENT_SIZE)
    prior_draws = tf.random.normal(shape=latent_shape)
    decoded = decoder.predict(prior_draws)
    return decoded

In [None]:
# End-to-end model: encode the input, decode the sampled latent vector
# (third encoder output).
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='SpatialWorldModel')

## Loss stuff

# loss
eps = 1e-6 # avoid taking log of zero (kept for reference; not used below)

# reconstruction loss: squared error summed over height, width and channels,
# then averaged over the batch
r_loss = tf.reduce_sum(
  tf.square(inputs - outputs),
  axis = [1,2,3]
)
r_loss = tf.reduce_mean(r_loss)

# KL divergence between Q(z|X) and the unit Gaussian, summed over the latent
# dimensions (axis 1)
kl_loss = - 0.5 * tf.reduce_sum(
  (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)),
  axis = 1
)
# Clamp the per-sample KL from below at KL_TOLERANCE * LATENT_SIZE: once a
# sample's KL is under that floor the term is constant and contributes no
# gradient, so the optimizer stops pushing it further towards zero.
kl_loss = tf.maximum(kl_loss, KL_TOLERANCE * LATENT_SIZE)
kl_loss = tf.reduce_mean(kl_loss)

loss = r_loss + kl_loss

vae.add_loss(loss)
# Use the configured LEARNING_RATE; the string 'adam' would silently fall
# back to the optimizer's default learning rate.
vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE))

In [None]:
## Small-data fit: can we do well with 5 gigs train, 1 gig val, 1 gig test?

# This way we safely avoid data loaders

data_dir = "./data/kaggle/"

# Rough estimate of the total dataset size, assuming every file holds an
# array of the same size as the first one listed.
alldata = os.listdir(data_dir)
n_elems = len(alldata)

first_path = os.path.join(data_dir, alldata[0])
elem = np.load(first_path)
elem_size = elem.__sizeof__()

bytes_to_gigs = 1e-9
total_data_gigs = elem_size * n_elems * bytes_to_gigs

total_data_gigs

# !du -h /kaggle/input/screenshots/kaggle 

#n_elems_for_m_gigs = lambda m_gigs: m_gigs // (elem_size * bytes_to_gigs)
#n_elems_for_7_gigs = n_elems_for_m_gigs(7)

# Make the train/test split reproducible: os.listdir returns entries in
# arbitrary order, so sort first, then shuffle with a fixed seed.
SPLIT_SEED = 0
alldata = sorted(os.listdir(data_dir))
np.random.RandomState(SPLIT_SEED).shuffle(alldata)

# Fraction of the files held out for testing.
test = 1/8
n_test = int(len(alldata) * test)

# The first n_test shuffled files form the test set, the rest the training set.
train_ids, test_ids = alldata[n_test:], alldata[:n_test]

def read_data(data_dir, file_IDs):
    """Load the listed array files into one float32 batch.

    # Arguments
        data_dir (str): directory containing the files
        file_IDs (sequence of str): file names, relative to `data_dir`

    # Returns
        float32 array of shape (len(file_IDs), 64, 64, 3); row i holds the
        contents of file_IDs[i]
    """
    images = np.zeros((len(file_IDs), 64, 64, 3), dtype=np.float32)
    # Each `slot` is a view into `images`, so writing through it fills the batch.
    for slot, file_name in zip(images, file_IDs):
        slot[...] = np.load(os.path.join(data_dir, file_name))
    return images

In [None]:
!free -m 

In [None]:
# Load every training file into a single in-memory float32 array.
train_images = read_data(data_dir, train_ids)

In [None]:
# Load the held-out test files into a single in-memory float32 array.
test_images = read_data(data_dir, test_ids)

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Hold out 256*20 = 5120 images for validation; a fixed random_state keeps
# the validation set identical across kernel restarts.
train, val = train_test_split(train_images, test_size=256*20, random_state=0)

In [None]:
## Real boii fitting

class TrainTimeCallback(Callback):
    """Callback that measures the wall-clock duration of a training run.

    The timer starts in `on_train_begin` and stops in `on_train_end`; the
    elapsed time is then available as a formatted string via `train_time`.
    """

    def __init__(self):
        super(TrainTimeCallback, self).__init__()

    def on_train_begin(self, logs=None):
        """Record the timestamp at which training starts."""
        self._start_time = time()

    def on_train_end(self, logs=None):
        """Store the number of seconds elapsed since `on_train_begin`."""
        self._train_time = time() - self._start_time

    @property
    def train_time(self):
        """Elapsed training time formatted as HH:MM:SS.

        Only available once `on_train_end` has run.
        """
        # Truncate to whole seconds before splitting. Formatting the float
        # parts with `.0f` rounds each field independently and can produce
        # out-of-range fields such as "00:00:60".
        total_seconds = int(self._train_time)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        return f'{hours:02d}:{minutes:02d}:{seconds:02d}'

    def print_train_time(self):
        """Print the formatted training time."""
        print(f'Train time for model: {self.train_time}')

# One TensorBoard log directory per run, named by start time.
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Select the best checkpoint and stop early on the validation loss: the
# training loss keeps improving while the model overfits, so it is a poor
# criterion for either decision.
monitor = 'val_loss'
callbacks = [
    # Kept at index 0: the training cell reads the timer via callbacks[0].
    TrainTimeCallback(),
    ModelCheckpoint('./best_VAE.h5', save_best_only=True,
                    monitor=monitor),
    EarlyStopping(monitor, patience=10,
                  mode='min',
                  restore_best_weights=True),
    TensorBoard(log_dir=log_dir, histogram_freq=1)
]

In [None]:
!free -m 

In [None]:
# Train on `train` only. `train_images` still contains the rows that were
# split off into `val`, so fitting on it would leak the validation data into
# training and make val_loss meaningless.
# The loss is attached through add_loss, hence no targets (None) are passed.
history = vae.fit(train,
                  epochs=100,
                  batch_size=BATCH_SIZE,
                  callbacks=callbacks,
                  validation_data=(val, None))

# callbacks[0] is the TrainTimeCallback instance.
callbacks[0].print_train_time()

# Persist the per-epoch metrics for later inspection.
with open('history.json', 'w') as f:
    json.dump(history.history, f)

In [None]:
# Training loss recorded by fit, one value per epoch.
plt.plot(history.history['loss'])

In [None]:
# Validation loss recorded by fit, one value per epoch.
plt.plot(history.history['val_loss'])

In [None]:
history.history.keys()

In [None]:
# Pass the decoder model directly rather than looking it up by position in
# vae.layers, which breaks silently if the layer order ever changes.
im = sample(decoder)

In [None]:
im.shape

In [None]:
# Display one of the generated images (index 8 of the sampled batch).
plt.imshow(im[8])