In [None]:
import cv2
import os
import numpy as np
import tensorflow as tf
import torch
import numpy as np
import os
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from sklearn.model_selection import train_test_split

In [None]:

def load_and_resize_images(folder, target_size=(64, 64), extensions=('.JPEG',)):
    """Load every matching image file directly inside ``folder``, resized and in RGB order.

    Files are visited in sorted name order so the returned list is reproducible
    (``os.listdir`` returns entries in arbitrary, filesystem-dependent order).
    Sub-directories, files whose names do not end with one of ``extensions``
    (case-sensitive match), and files OpenCV cannot decode are skipped.

    Args:
        folder: Directory to scan (not recursive).
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        extensions: Filename suffix or iterable of suffixes to accept. The
            default ``('.JPEG',)`` matches ImageNet file naming.

    Returns:
        A list of RGB images, each resized to ``target_size``.
    """
    # Accept a single suffix string as well as an iterable of suffixes.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    images = []
    for filename in sorted(os.listdir(folder)):
        filepath = os.path.join(folder, filename)
        if not (os.path.isfile(filepath) and filename.endswith(suffixes)):
            continue
        img = cv2.imread(filepath)
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            continue
        img = cv2.resize(img, target_size)
        # OpenCV decodes to BGR channel order; the rest of the pipeline expects RGB.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return images

main_folder = 'path to dir with imagenet images for pretraining'
# Fixed seed so the train/validation split is identical across runs.
SPLIT_SEED = 42

# Each sub-directory of main_folder is scanned for images; plain files at the
# top level are ignored. Sorted iteration keeps the image order reproducible.
X = []
for foldername in sorted(os.listdir(main_folder)):
    folderpath = os.path.join(main_folder, foldername)
    if os.path.isdir(folderpath):
        X.extend(load_and_resize_images(folderpath))

if not X:
    # Fail early with a clear message instead of sklearn's generic empty-input error.
    raise ValueError(f"No images found in the sub-directories of {main_folder!r}")

X_train, X_val = train_test_split(X, test_size=0.05, random_state=SPLIT_SEED)

# Convert to float32 and scale pixel values to [0, 1]. Dividing in place avoids
# allocating a second full-size float32 copy of each split.
X_train = np.array(X_train, dtype='float32')
X_train /= 255
X_val = np.array(X_val, dtype='float32')
X_val /= 255

In [None]:
#define vae
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()
with strategy.scope():
    input_shape = (64, 64, 3) 
    encoder_input = Input(shape=input_shape, name='encoder_input')
    x = Conv2D(64, 3, strides=1, activation = 'relu', padding = 'same')(encoder_input)
    x = MaxPooling2D()(x)
    x = Conv2D(64, 5, strides=2, activation = 'relu', padding = 'same')(x)
    x = Conv2D(128, 3, strides=1, activation = 'relu')(x)
    x = Conv2D(128, 5, strides=1, activation = 'relu')(x)
    x = Conv2D(256, 3, strides=1, activation = 'relu')(x)
    x = Conv2D(512, 3, strides=2, activation = 'relu')(x)
    x = Conv2D(64, 3, strides=2, activation = 'relu', padding = 'same')(x)
    x = Flatten()(x)

    latent_dim = 2*2*64
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)

    def sampling(args):
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.)
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    z = tf.keras.layers.Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    encoder = Model(encoder_input, [z_mean, z_log_var, z], name='encoder')
    encoder.summary()

    latent_inputs = Input(shape=(latent_dim,), name='latent_inputs')
    x = Reshape((2, 2, 64))(latent_inputs)
    x = Conv2DTranspose(512, 3, strides=2, activation='relu')(x)
    x = Conv2DTranspose(256, 5, strides=1, activation='relu')(x)
    x = Conv2DTranspose(256, 3, strides=1, activation='relu')(x)
    x = Conv2DTranspose(256, 5, strides=1, activation='relu', padding = 'same')(x)
    x = Conv2DTranspose(128, 3, strides=1, activation='relu')(x)
    x = Conv2DTranspose(128, 4, strides=1, activation='relu')(x)
    x = UpSampling2D()(x)
    x = Conv2DTranspose(64, 3, strides=2, padding = 'same', activation='relu')(x)
    decoder_output = Conv2DTranspose(3, 3, strides=1,padding = 'same', activation='sigmoid')(x)

    decoder = Model(latent_inputs, decoder_output, name='decoder')
    decoder.summary()

    vae_outputs = decoder(encoder(encoder_input)[2])
    vae = Model(encoder_input, vae_outputs, name='vae')
    #custom kl divergence + mse loss
    def custom_loss(y_true, y_pred):
        loss1 = K.mean(K.square(y_true - y_pred))
        kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5
        kl_loss = K.mean(kl_loss)
        return (loss1) + (kl_loss)

    vae.compile(optimizer='adam', loss = custom_loss)
    vae.summary()



In [None]:
#training loop
epochs = 125
batch_size = 512
for epoch in range(epochs):
    z_train = encoder(X_train)  
    mu, logvar, _ = z_train
    z_mean = mu
    z_log_var = logvar
    print("epoch: " + str(epoch+1)+"/"+str(epochs))
    vae.fit(X_train, X_train, validation_data = (X_val, X_val), batch_size=batch_size, epochs=1)

In [None]:
#save weights
# Persist the trained VAE weights to disk; the path is a placeholder to replace before running.
with strategy.scope():
    vae.save_weights(filepath='path to save vae weights.h5')
