<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Loading-Dataset" data-toc-modified-id="Loading-Dataset-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Loading Dataset</a></span></li><li><span><a href="#Reading-TFRecord-Data" data-toc-modified-id="Reading-TFRecord-Data-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Reading TFRecord Data</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Model</a></span></li></ul></div>

In [None]:
import os
# Switch the working directory to the parent of the current one.
# NOTE(review): presumably this makes the `vae` package and the relative data
# paths used by later cells resolvable — confirm against the repository layout.
# NOTE(review): the path is relative, so re-running this cell without
# restarting the kernel moves up one more directory level each time.
os.chdir("../")

In [None]:
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import IPython.display as display

from vae.config import *
from vae.data_processing import read_tfrecord

In [None]:
# Report the versions of the main libraries, one per line.
print(
    f"Tensorflow Version: {tf.__version__}",
    f"Pandas Version: {pd.__version__}",
    f"Numpy Version: {np.__version__}",
    sep="\n",
)

In [None]:
# Glob pattern that matches every `.tfrec` file directly under MONET_TFREC_PATH.
# NOTE(review): no later cell visible in this notebook reads `file_path`
# (the glob calls build the same pattern inline) — confirm whether it is still needed.
file_path = MONET_TFREC_PATH + "/*.tfrec"

In [None]:
# Collect the TFRecord shard paths for each image domain.
# `glob.glob` returns matches in a file-system dependent order, so the results
# are sorted to keep the shard order identical from run to run.
monet_file_path = sorted(glob.glob(MONET_TFREC_PATH + "/*.tfrec"))
photo_file_path = sorted(glob.glob(PHOTO_TFREC_PATH + "/*.tfrec"))

# Quick sanity check: number of shards found per domain (0 means a wrong path).
print(len(monet_file_path), len(photo_file_path))

# Loading Dataset

In [None]:
# One record-level dataset per image domain, built from the shard paths
# collected above. The two datasets are independent of each other.
monet_dataset = tf.data.TFRecordDataset(filenames=monet_file_path)
photo_dataset = tf.data.TFRecordDataset(filenames=photo_file_path)

# Reading TFRecord Data

In [None]:
# Apply the project's `read_tfrecord` parser to every raw record.
# NOTE(review): the parsed element type is defined in `vae.data_processing`
# and is not visible here — confirm it is a single image tensor.
parsed_photo_dataset = photo_dataset.map(map_func=read_tfrecord)
parsed_monet_dataset = monet_dataset.map(map_func=read_tfrecord)

In [None]:
def scale_image(data):
    """Return ``data`` divided by 255.

    NOTE(review): presumably used to bring 8-bit pixel values into the
    [0, 1] range — confirm the value range produced by ``read_tfrecord``.
    """
    scaled = data / 255
    return scaled

# Rescale every parsed element with `scale_image` (element / 255).
scaled_photo_dataset = parsed_photo_dataset.map(map_func=scale_image)
scaled_monet_dataset = parsed_monet_dataset.map(map_func=scale_image)

# Model

In [None]:
from tensorflow import keras
from tensorflow.keras import layers, Input
from vae.model import  encoder, \
                        decoder, \
                        kl_loss, \
                        mse_loss, \
                        vae_loss

In [None]:
# Dimensions of the model's input tensor: height, width and last-axis size.
# NOTE(review): a depth of 3 presumably means RGB channels — confirm.
# NOTE(review): `from vae.config import *` may already define names like
# these; the values assigned here take precedence from this cell onward.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = 256, 256, 3

**Building Model**

In [None]:
# Symbolic input with the image dimensions defined above.
vae_inputs = Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH))

# The encoder yields three tensors; only the first one is fed to the decoder.
latent_vec, mean, log_var = encoder(vae_inputs)
recons_image = decoder(latent_vec)

# Expose the reconstruction together with `mean` and `log_var` as outputs,
# since the loss computation needs all three.
vae_model = keras.Model(
    inputs=vae_inputs,
    outputs=[recons_image, mean, log_var],
)

**Training Model**

In [None]:
# Adam optimizer with a fixed learning rate for all trainable VAE weights.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

# Number of full passes over the Monet dataset.
epoch = 5

# Batching does not depend on the epoch, so build the batched view once
# and re-iterate it on every pass.
train_batches = scaled_monet_dataset.batch(32)

for i in range(epoch):
    # Running totals so the per-epoch report covers every batch of the epoch
    # instead of only the last one.
    epoch_loss_total = 0.0
    num_batches = 0

    for train_data in train_batches:

        with tf.GradientTape() as tape:
            # Forward pass. `training=True` is passed explicitly so that any
            # layers with distinct training behaviour (e.g. dropout or batch
            # normalization, if the model contains them) run in training mode.
            recons_img, mean, log_var = vae_model(train_data, training=True)
            # Loss
            model_loss = vae_loss(train_data, recons_img, mean, log_var)

        # Backward pass: differentiate the loss w.r.t. the weights and update them.
        gradient = tape.gradient(model_loss, vae_model.trainable_weights)
        optimizer.apply_gradients(zip(gradient, vae_model.trainable_weights))

        # Accumulate as a tensor; conversion to a Python float happens once per epoch.
        epoch_loss_total += tf.reduce_sum(model_loss)
        num_batches += 1

    if num_batches == 0:
        # An empty dataset would otherwise fail below with an undefined loss.
        print(f"Epoch: {i} --- no training batches found; check the TFRecord paths")
        continue

    print(f"Epoch: {i} --- Mean Loss Value: {float(epoch_loss_total) / num_batches}")

In [None]:
# Saving Trained Model
# The target is the relative path "saved_model", so it is resolved against the
# current working directory at the time this cell runs.
vae_model.save("saved_model")

In [None]:
# Loading Saved Model
# Rebinds `vae_model` to the model restored from disk, replacing the in-memory
# instance built and trained in the earlier cells.
# NOTE(review): if the encoder/decoder use custom layers or functions,
# `load_model` may need a `custom_objects` argument — confirm.
vae_model = keras.models.load_model("saved_model/")

In [None]:
# Take the first 100 scaled Monet images (fewer if the dataset is smaller),
# shuffle them and group them into batches of up to 25 for display.
# No shuffle seed is given, so the image order differs between runs.
img_batch_list = [
    img_batch
    for img_batch in scaled_monet_dataset.take(100).shuffle(101).batch(25)
]

In [None]:
def plot_img(img_batch, n_rows=5, n_cols=5):
    """Draw images from ``img_batch`` on an ``n_rows`` x ``n_cols`` grid.

    Parameters
    ----------
    img_batch : sized, indexable collection of images
        Each ``img_batch[i]`` is handed to ``Axes.imshow`` unchanged.
    n_rows, n_cols : int, optional
        Grid dimensions. The defaults give the 5 x 5 layout.

    Notes
    -----
    At most ``n_rows * n_cols`` images are drawn. A batch holding fewer
    images than grid cells is plotted without error; the remaining cells
    are simply left empty. Nothing is returned, so calling this as the last
    expression of a notebook cell does not render the figure a second time.
    """
    fig = plt.figure(figsize=(15, 10))

    # Never index past the end of the batch: a short batch has fewer images
    # than the grid has cells.
    n_images = min(len(img_batch), n_rows * n_cols)

    for i in range(n_images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.axis('off')
        ax.imshow(img_batch[i])

In [None]:
# Show the input images of the first sampled batch.
plot_img(img_batch=img_batch_list[0])

In [None]:
# Run the first sampled batch through the model; element 0 of the prediction
# is the reconstructed-image output (the model's first declared output).
reconstructed_batch = vae_model.predict(img_batch_list[0])[0]
plot_img(reconstructed_batch)

# Model

In [None]:
from vae.model import Encoder, Decoder

In [None]:
# Instantiate the class-based `Encoder` and `Decoder` imported from `vae.model`,
# using their default constructor arguments.
encoder_model = Encoder()
decoder_model = Decoder()

In [None]:
def train_step(dataset):
    