In [1]:
import os
import numpy as np
import tensorflow as tf
from keras.optimizers import schedules
from dataset import get_data, normalize
from matplotlib import pyplot as plt
from tensorflow.python.keras.layers import Input, Dense, Lambda, Conv2D, Conv2DTranspose, Flatten, Reshape
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras import backend as K
from skimage.transform import resize
import random

In [2]:
def encoder(input_shape, latent_dim):
    inputs = Input(shape=input_shape)
    x = Conv2D(filters=32, kernel_size=3, strides=2, activation='relu', padding='same')(inputs)
    x = Conv2D(filters=64, kernel_size=3, strides=2, activation='relu', padding='same')(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)

    z_mean = Dense(latent_dim)(x)
    z_log_var = Dense(latent_dim)(x)

    return Model(inputs, [z_mean, z_log_var], name='encoder')

In [3]:
def decoder(latent_dim):
    inputs = Input(shape=latent_dim)
    x = Dense(8 * 8 * 64, activation='relu')(inputs)
    x = Reshape((8, 8, 64))(x)
    x = Conv2DTranspose(filters=64, kernel_size=3, strides=2, activation='relu', padding='same')(x)
    x = Conv2DTranspose(filters=32, kernel_size=3, strides=2, activation='relu', padding='same')(x)
    outputs = Conv2DTranspose(filters=3, kernel_size=3, strides=1, activation='sigmoid', padding='same')(x)

    return Model(inputs, outputs, name='decoder')

In [4]:
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsl = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsl

In [5]:
def vae_loss(x, x_decode_mean, z_mean, z_log_var):
    rec_loss = binary_crossentropy(K.flatten(x), K.flatten(x_decode_mean)) * 32 * 32 * 3
    kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(rec_loss + kl_loss)

In [6]:
if __name__ == '__main__':
    ######################## Get train dataset ########################
    X_train = get_data('dataset')

In [7]:
    input_shape = (32, 32, 3)
    latent_dim = 64

In [8]:
    encoder = encoder(input_shape, latent_dim)
    decoder = decoder(latent_dim)

In [9]:
    inputs = Input(shape=input_shape)
    z_mean, z_log_var = encoder(inputs)
    z = Lambda(sampling)([z_mean, z_log_var])
    outputs = decoder(z)

    vae = Model(inputs, outputs, name='vae')

    loss = vae_loss(inputs, outputs, z_mean, z_log_var)
    vae.add_loss(loss)
    vae_optimizer = adam_v2.Adam(learning_rate=0.001)

In [10]:
    X_train = normalize(X_train)
    X_train = np.transpose(X_train, (0, 2, 3, 1)) #使shape匹配
    # print(X_train.shape)

In [11]:
    epochs = 100
    batch_size = 32
    decay_steps = 10000
    decay_rate = 0.9

In [12]:
    for epoch in range(epochs):
        for i in range(0, len(X_train), batch_size):
            x_batch = X_train[i:i + batch_size]
            with tf.GradientTape() as tape:
                z_mean, z_log_var = encoder(x_batch)
                z = sampling([z_mean, z_log_var])
                x_decoded = decoder(z)
                loss_value = vae_loss(x_batch, x_decoded, z_mean, z_log_var)

            gradients = tape.gradient(loss_value, vae.trainable_variables)
            vae_optimizer.apply_gradients(zip(gradients, vae.trainable_variables))

            current_step = epoch * len(X_train) // batch_size + i // batch_size
            current_learning_rate = 0.001 * decay_rate ** (current_step / decay_steps)
            vae_optimizer.learning_rate.assign(current_learning_rate)

        print('Epoch:', epoch + 1, 'Loss:', loss_value.numpy())

Epoch: 1 Loss: 1926.367


KeyboardInterrupt: 

In [None]:
    idx = 999

In [None]:
    path = "./reconstructed/"
    if not os.path.exists(path):
        os.mkdir(path)
    x = np.expand_dims(X_train[idx], axis=0)
    x_reconstructed = vae.predict(x)
    x_reconstructed = np.clip(x_reconstructed, 0.0, 1.0)

    plt.subplot(1, 2, 1)
    plt.imshow(x[0])
    plt.axis('off')
    plt.title('Original Image')

    plt.subplot(1, 2, 2)
    plt.imshow(x_reconstructed[0])
    plt.axis('off')
    plt.title('Reconstructed Image')
    plt.savefig(path + "rec.png")

    plt.show()

In [None]:

    def interpolate_images(image1, image2, num_interpolations):
        z1, _ = encoder.predict(np.expand_dims(image1, axis=0))
        z2, _ = encoder.predict(np.expand_dims(image2, axis=0))

        interpolated_images = []
        for i in range(num_interpolations + 1):
            alpha = i / num_interpolations
            z_interpolated = alpha * z1 + (1 - alpha) * z2
            img_interpolated = decoder.predict(z_interpolated)
            interpolated_images.append(img_interpolated)
            # if i == 0 or i == int(num_interpolations / 2) or i == num_interpolations:
            #     interpolated_images.append(img_interpolated)

        return interpolated_images

In [None]:
    images = []
    path = "./random_6x6/"
    if not os.path.exists(path):
        os.mkdir(path)
    for cnt in range(10):
        current_images = []

        for _ in range(6):
            i1 = random.randint(0, 999)
            i2 = random.randint(0, 999)
            image1 = X_train[i1]
            image2 = X_train[i2]

            interpolated_images = interpolate_images(image1, image2, 5)
            current_images.extend(interpolated_images)

        images.append(current_images)

    for cnt, image_set in enumerate(images):
        fig, axes = plt.subplots(6, 6, figsize=(20, 20))

        for i, img in enumerate(image_set):
            ax = axes[i // 6, i % 6]
            ax.imshow(np.squeeze(img))
            ax.axis('off')

        plt.savefig("./random_6x6/" + str(cnt) + ".png")
        plt.show()