# View Variational Autoencoder and Latent Diffusion Images

In [1]:
import json
import os
import numpy as np
import tensorflow as tf
from train import load_dataset
import matplotlib.pyplot as plt
import glob
import variational_autoencoder as vae
import latent_diffusion
from datetime import datetime
from pathlib import Path

# Variational Autoencoder 

In [9]:
# Variational Autoencoder
# NOTE(review): absolute, machine-specific path — consider making this configurable.
MODEL_DIR = Path("/Users/lucky/GitHub/PokeGenerator/model")

# The training config tells train.load_dataset how to build (or load from cache) the dataset.
with open(MODEL_DIR / "training_config.json", "r") as f:
    config = json.load(f)
images = load_dataset(config)

# A pinned checkpoint: despite the variable name, this is not discovered dynamically.
latest_model_file = str(MODEL_DIR / "checkpoints" / "model_2024-03-18-22-44-30.keras")
variational_autoencoder = tf.keras.models.load_model(latest_model_file)
variational_autoencoder.summary()

# Pull the encoder and decoder sub-models out by layer name so they can be used standalone.
encoder, decoder = (variational_autoencoder.get_layer(name) for name in ("encoder", "decoder"))

# Run every image through the full VAE to obtain its reconstruction.
reconstructions = variational_autoencoder.predict(images)

Dataset cache exists, loading from cache








Model: "vae_mlp"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 autoencoder_input (InputLayer)  [(None, 128, 128, 3  0          []                               
                                )]                                                                
                                                                                                  
 encoder (Functional)           [(None, 1024),       46304576    ['autoencoder_input[0][0]']      
                                 (None, 1024)]                                                    
                                                                                                  
 z (Lambda)                     (None, 1024)         0           ['encoder[0][0]',                
                                                                  'encoder[0][1]']          

## Interpolation Visualizations

In [10]:
def interpolate_vectors(v1, v2, num_steps=10):
    """Return `num_steps` evenly spaced points on the straight line from v1 to v2.

    Both endpoints are included (the first row equals v1, the last equals v2).
    The points are stacked along a new leading axis, so the result has shape
    ``(num_steps, *v1.shape)``.
    """
    steps = []
    for t in np.linspace(0, 1, num_steps):
        # Convex combination: t == 0 gives v1, t == 1 gives v2.
        steps.append((1 - t) * v1 + t * v2)
    return np.array(steps)

def get_interpolated_images(decoder, latent_vectors, index1, index2, num_steps=10):
    """Decode a straight-line walk through latent space between two dataset samples.

    Args:
        decoder: Model exposing ``predict`` that maps latent vectors to images.
        latent_vectors: Indexable collection of latent vectors (e.g. ``z_mean``).
        index1: Index of the starting latent vector.
        index2: Index of the ending latent vector.
        num_steps: Number of points along the path, endpoints included.

    Returns:
        Whatever ``decoder.predict`` returns for the ``num_steps`` interpolated
        latent vectors, in order from ``index1`` to ``index2``.
    """
    start, stop = latent_vectors[index1], latent_vectors[index2]
    path = interpolate_vectors(start, stop, num_steps=num_steps)
    return decoder.predict(path)


In [14]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
import numpy as np

def interactive_image_gallery(images, images_per_page=10, figsize_per_image=(1, 1)):
    """Create a slider-driven gallery that shows `images` one single-row page at a time.

    NOTE: a later cell in this notebook redefines ``interactive_image_gallery``
    with a different signature (no ``figsize_per_image``). Whichever cell ran
    last determines which version a call resolves to.

    Args:
        images: Sized, indexable sequence of images accepted by ``plt.imshow``.
        images_per_page: Number of images per page; each page is one row.
        figsize_per_image: ``(width, height)`` in inches allotted to each image.

    Raises:
        ValueError: If ``images_per_page`` is less than 1.
    """
    if images_per_page < 1:
        raise ValueError(f"images_per_page must be >= 1, got {images_per_page}")

    total_images = len(images)
    # Ceiling division, clamped to 1 so an empty input still yields a slider
    # whose max is not below its min (min is 1).
    max_pages = max(1, (total_images + images_per_page - 1) // images_per_page)

    def show_images(page=1):
        start = (page - 1) * images_per_page
        page_images = images[start:start + images_per_page]

        # Every page is one row with a fixed number of slots, so the figure
        # keeps the same size even when the final page is short.
        rows, cols = 1, images_per_page
        plt.figure(figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows))

        for i, image in enumerate(page_images):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(image)
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    interact(show_images, page=IntSlider(min=1, max=max_pages, step=1, value=1, description='Page'))

# Interpolate in latent space between dataset samples 6 and 20.
# Only the encoder's z_mean output is used as the endpoints; z_log_var is unpacked but unused here.
z_mean, z_log_var = encoder.predict(images)
interpolated_images = get_interpolated_images(
    decoder,
    z_mean,
    index1=6,
    index2=20,
    num_steps=10,
)

# Show all 10 decoded frames on a single page of 1x1-inch thumbnails.
interactive_image_gallery(
    interpolated_images,
    images_per_page=10,
    figsize_per_image=(1, 1),
)



interactive(children=(IntSlider(value=1, description='Page', max=1, min=1), Output()), _dom_classes=('widget-i…

## VAE Visualizations: Originals vs. Reconstructions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

def display_image_pairs(originals, reconstructions, page=1, pairs_per_page=4):
    """Show one page of (original, reconstruction) image pairs side by side.

    Pairs are laid out two per row. Columns 1-2 hold one pair and columns 3-4
    hold the next, each pair drawn as "Original" followed by "Reconstructed".

    Args:
        originals: Indexable sequence of input images.
        reconstructions: Indexable sequence of model outputs aligned
            index-for-index with ``originals``.
        page: 1-based page number.
        pairs_per_page: Number of pairs on a full page.
    """
    start = (page - 1) * pairs_per_page
    end = start + pairs_per_page
    page_originals = originals[start:end]
    page_reconstructions = reconstructions[start:end]

    # Each pair occupies two adjacent columns, so a 4-column grid holds two pairs
    # per row. Size the grid from the number of slots actually needed (two per
    # pair) instead of one row per pair, which would leave half the grid blank.
    # The grid is sized for a full page so every page's figure has the same size.
    cols = 4
    slots = 2 * pairs_per_page
    rows = max(1, (slots + cols - 1) // cols)
    plt.figure(figsize=(2.5 * cols, 2.5 * rows))

    # zip stops at the shorter slice, which also handles a short final page.
    for i, (original, reconstruction) in enumerate(zip(page_originals, page_reconstructions)):
        # The original goes in an odd (1-based) slot and the reconstruction in the
        # slot right after it; because cols is even, a pair never wraps across rows.
        plt.subplot(rows, cols, 2 * i + 1)
        plt.imshow(original, cmap='gray')
        plt.title('Original')
        plt.axis('off')

        plt.subplot(rows, cols, 2 * i + 2)
        plt.imshow(reconstruction, cmap='gray')
        plt.title('Reconstructed')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def interactive_gallery(originals, reconstructions, pairs_per_page=4):
    """Create a slider-driven gallery that pages through (original, reconstruction) pairs.

    Args:
        originals: Sized, indexable sequence of input images.
        reconstructions: Model outputs aligned index-for-index with ``originals``.
        pairs_per_page: Number of pairs shown per page.
    """
    # Only indices present in both sequences form a complete pair.
    total_pairs = min(len(originals), len(reconstructions))
    # Ceiling division, clamped to 1 so an empty input still produces a slider
    # whose max is not below its min (min is 1).
    max_pages = max(1, (total_pairs + pairs_per_page - 1) // pairs_per_page)

    def _show(page):
        display_image_pairs(originals, reconstructions, page, pairs_per_page)

    interact(_show, page=IntSlider(min=1, max=max_pages, step=1, value=1, description='Page'))

# Page through each dataset image next to its VAE reconstruction.
interactive_gallery(originals=images, reconstructions=reconstructions)


# Latent Diffusion 

In [3]:
# Latent Diffusion
# Encode the whole dataset; the encoder returns a (z_mean, z_log_var) pair.
latent_vectors = encoder.predict(images)
z_mean, z_log_var = latent_vectors

# Draw one latent sample per image with the VAE's sampling function. These samples
# are only used below to read off the sample count and the latent dimensionality.
vae_latent_samples = vae.sampling(latent_vectors)

# Ensure the shape is what your model expects
print("Sampled latent vectors shape:", vae_latent_samples.shape)
# Dimension of the latent space
print("Latent Space Dimension:", vae_latent_samples.shape[1])
num_samples, latent_dim = vae_latent_samples.shape[0], vae_latent_samples.shape[1]

# Noise schedule: T steps with betas rising linearly from 1e-4 to 0.02 (the DDPM
# defaults). NOTE(review): this must match the schedule the latent model was trained with.
T = 1000
betas = np.linspace(1e-4, .02, T)
sigmas = np.sqrt(betas)  # per-step noise std passed to the sampler (sigma_t^2 = beta_t)
alphas = 1 - betas
alphas_cumprod = np.cumprod(alphas, axis=-1)

ld_model = tf.keras.models.load_model("/Users/lucky/GitHub/PokeGenerator/model/checkpoints/latent_model_2024-03-19-17:41:00.keras")
ld_model.summary()

# Fail fast if the VAE checkpoint and the latent-diffusion checkpoint disagree on the
# latent size (e.g. after swapping one checkpoint but not the other).
ld_input_dims = {inp.shape[-1] for inp in ld_model.inputs}
assert latent_dim in ld_input_dims, (
    f"VAE latent dim {latent_dim} does not match any latent-diffusion model input dim {ld_input_dims}"
)

# Generate one latent vector per dataset image with the latent diffusion model.
# NOTE(review): with num_samples equal to the dataset size this is expensive;
# the progress bar suggests the sampler runs all T steps.
sampled_latent_vectors = latent_diffusion.sample(
    model=ld_model,
    num_samples=num_samples,
    latent_dim=latent_dim,
    T=T,
    sigmas=sigmas,
    alphas=alphas,
    alphas_cumprod=alphas_cumprod,
)

# Decode the generated latent vectors into images.
decoded_images = decoder.predict(sampled_latent_vectors)



Sampled latent vectors shape: (13334, 512)
Latent Space Dimension: 512




Model: "reverse_process_mlp_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 1)]          0           []                               
                                                                                                  
 embedding (Embedding)          (None, 1, 512)       512000      ['input_2[0][0]']                
                                                                                                  
 input_1 (InputLayer)           [(None, 512)]        0           []                               
                                                                                                  
 flatten (Flatten)              (None, 512)          0           ['embedding[0][0]']              
                                                                          

  0%|          | 0/1000 [00:00<?, ?it/s]



## Latent Diffusion Visualizations

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

def display_images(images, page=1, images_per_page=8):
    """Render one page of `images` as a 4-column grid of 3x3-inch cells.

    Args:
        images: Indexable sequence of images accepted by ``plt.imshow``.
        page: 1-based page number.
        images_per_page: Number of images on a full page.
    """
    first = (page - 1) * images_per_page
    chunk = images[first:first + images_per_page]

    n_cols = 4
    # Ceiling division: only as many rows as this page actually needs.
    n_rows = -(-len(chunk) // n_cols)
    plt.figure(figsize=(n_cols * 3, n_rows * 3))

    for slot, picture in enumerate(chunk, start=1):
        plt.subplot(n_rows, n_cols, slot)
        plt.imshow(picture)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def interactive_image_gallery(images, images_per_page=8):
    """Create a slider-driven gallery that pages through `images` using display_images.

    NOTE: this redefines the single-row ``interactive_image_gallery`` from the
    interpolation cell, which also accepts ``figsize_per_image``. After this
    cell runs, calls resolve to this version.

    Args:
        images: Sized, indexable sequence of images accepted by ``plt.imshow``.
        images_per_page: Number of images per page.

    Raises:
        ValueError: If ``images_per_page`` is less than 1.
    """
    if images_per_page < 1:
        raise ValueError(f"images_per_page must be >= 1, got {images_per_page}")

    total_images = len(images)
    # Ceiling division, clamped to 1 so an empty input still yields a slider
    # whose max is not below its min (min is 1).
    max_pages = max(1, (total_images + images_per_page - 1) // images_per_page)

    def _show(page):
        display_images(images, page, images_per_page)

    interact(_show, page=IntSlider(min=1, max=max_pages, step=1, value=1, description='Page'))

# Call the interactive gallery with your decoded images
# Page through the images decoded from diffusion-generated latents, 8 per page.
interactive_image_gallery(decoded_images, images_per_page=8)

interactive(children=(IntSlider(value=1, description='Page', max=1667, min=1), Output()), _dom_classes=('widge…