(c) 2024 Manuel Razo. This work is licensed under a [Creative Commons
Attribution License CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/).
All code contained herein is licensed under an [MIT
license](https://opensource.org/licenses/MIT).

In [1]:
# Import project package
import Antibiotic

# Import autoencoder package and its neural-geodesics submodule (aliased as NG)
import AutoEncode
import AutoEncode.diffgeo.NeuralGeodesics as NG

# Import libraries to handle data
import CSV
import DataFrames as DF
import Glob

# Import ML libraries
import Flux

# Import library to save and load models
import JLD2

# Import basic math
import StatsBase
import Random
# Seed the global random number generator so random draws are reproducible
Random.seed!(42)

# Import Plotting libraries
using CairoMakie
import ColorSchemes

# Activate backend
CairoMakie.activate!()

# Set Plotting style
Antibiotic.viz.theme_makie!()

# Exploratory data analysis of the $\beta$-RHVAE results

`[explanation here]`

Let's begin by listing the files with the saved model.

In [None]:
# Collect the paths of all JLD2 files in the model-state output directory
# (path is relative to the current working directory)
files = Glob.glob("output/model_state/*.jld2")

# Preview at most the first five paths
first(files, 5)

Now, let's loop through each file, and load the values of the loss function and
the mean squared error for each model.

In [None]:
# Training metrics stored in each checkpoint file that we want to collect
fields = ["mse_train", "mse_val", "loss_train", "loss_val"]

# Initialize empty dataframe; it accumulates one entry per checkpoint file
df_train = DF.DataFrame()

# Loop through checkpoint files
for f in files
    # The last "_"-separated token of the file name holds the epoch number
    # together with the literal "epoch" tag and the ".jld2" extension. Strip
    # both before parsing the number.
    epoch_str = replace(split(f, "_")[end], ".jld2" => "", "epoch" => "")
    epoch = parse(Int, epoch_str)

    # Read ONLY the metric datasets. Opening the file and indexing by name
    # avoids deserializing everything else stored in the checkpoint (e.g. the
    # model state), which is not needed here.
    metrics = JLD2.jldopen(f, "r") do file
        [file[field] for field in fields]
    end

    # Convert to dataframe (one column per metric)
    df_tmp = DF.DataFrame(Dict(zip(fields, metrics)))
    # Add epoch to dataframe
    df_tmp[!, :epoch] .= epoch

    # Append to main dataframe
    DF.append!(df_train, df_tmp)
end # for

# Sort by epoch so that the training curves are plotted in order
DF.sort!(df_train, :epoch)

first(df_train, 5)

With this information we can then plot the loss function and the mean squared
error for each model.

In [None]:
# Canvas for a 2×2 grid of square panels, one per tracked metric
fig = Figure(size=(550, 550))

# Build the grid of axes; entry [row, col] of the matrix sits at fig[row, col].
# NOTE(review): the global name `axes` shadows `Base.axes` in this session.
axes = [
    Axis(fig[row, col]; aspect=AxisAspect(1), xlabel="epoch #")
    for row in 1:2, col in 1:2
]

# Pair each metric with one panel. Iterating the matrix of axes walks it in
# linear (column-major) order, so the first two metrics fill the first column
# and the last two fill the second column.
for (panel, metric) in zip(axes, fields)
    # Metric value as a function of the epoch number
    scatterlines!(panel, df_train.epoch, df_train[:, metric]; markersize=5)
    # Panel title: metric name with underscores replaced by spaces
    panel.title = replace(metric, "_" => " ")
end # for

fig

## Mapping points to latent space

In [None]:
# Load the model object stored under the "model" key
infomaxvae = JLD2.load("./output/model.jld2")["model"]

# Search for model-state (checkpoint) files
model_files = Glob.glob("./output/model_state/*.jld2")
# Fail early with a clear message instead of a BoundsError further down
isempty(model_files) &&
    error("No model state files found in ./output/model_state/")

# Extract the epoch number from each file name: the last "_"-separated token
# holds the number together with the "epoch" tag and the ".jld2" extension.
# `tryparse` returns `nothing` for names that do not follow this convention.
epochs = [
    tryparse(
        Int, replace(split(file, "_")[end], ".jld2" => "", "epoch" => "")
    ) for file in model_files
]

# Select the checkpoint with the highest epoch number. The glob result is
# ordered by name, which only matches the epoch order when the numbers are
# zero-padded. If any name cannot be parsed, fall back to the last file in
# name order.
latest_file = any(isnothing, epochs) ?
              model_files[end] : model_files[argmax(epochs)]
println("$(latest_file)")

# Load parameters
model_state = JLD2.load(latest_file)["model_state"]
# Set model parameters
Flux.loadmodel!(infomaxvae, model_state)

Now, we can generate the synthetic data.

In [9]:
# Dimensionality of each data point
n_input = 3
# Number of synthetic data points to draw
n_data = 1_000

# Gaussian bump evaluated on the plane; gives the third coordinate of a point
function f(x₁, x₂)
    return 10.0f0 * exp(-(x₁^2 + x₂^2))
end

# Radius of the disk from which the planar coordinates are sampled
radius = 3

# Random radii: scaling the square root of uniform draws by the disk radius
r_rand = radius * sqrt.(Random.rand(n_data))

# Random angles in [0, 2π)
θ_rand = 2π * Random.rand(n_data)

# Convert from polar to cartesian coordinates (stored as Float32)
x_rand = @. Float32(r_rand * cos(θ_rand))
y_rand = @. Float32(r_rand * sin(θ_rand))
# Evaluate the function at every sampled point
z_rand = map(f, x_rand, y_rand)

# Arrange the data as a matrix with one row per coordinate and one column per
# data point
data = permutedims(hcat(x_rand, y_rand, z_rand))

# Fit a z-score transform along the second dimension, i.e., each row
# (coordinate) gets its own mean and standard deviation
dt = StatsBase.fit(StatsBase.ZScoreTransform, data; dims=2)

# Standardize the data so each coordinate has mean zero and standard deviation
# one
data_std = StatsBase.transform(dt, data);

Let's look at the data.

In [None]:
# Canvas for two side-by-side panels
fig = Figure(size=(2 * 300, 300))

# Nested layout, so that the gap between the panels can be adjusted
gl = GridLayout(fig[1, 1])

# Left panel: 2D axis showing the first two coordinates
ax_input = Axis(gl[1, 1]; xlabel="z₁", ylabel="z₂", title="input space")

# Right panel: 3D axis with light-gray panels and white grid lines
ax_output = Axis3(
    gl[1, 2];
    title="output space",
    xlabel="x₁(z₁,z₂)",
    ylabel="x₂(z₁,z₂)",
    zlabel="x₃(z₁,z₂)",
    xypanelcolor="#E6E6EF",
    xzpanelcolor="#E6E6EF",
    yzpanelcolor="#E6E6EF",
    xgridcolor=:white,
    ygridcolor=:white,
    zgridcolor=:white,
)

# 2D scatter of the first two standardized coordinates, colored by the third
scatter!(
    ax_input, data_std[1, :], data_std[2, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

# 3D scatter of all three standardized coordinates, same color coding
scatter!(
    ax_output, data_std[1, :], data_std[2, :], data_std[3, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

# Widen the gap between the two panels
colgap!(gl, 50)

fig

Let's now encode the data using the model encoder to obtain the latent space
representation of the data.

In [None]:
# Encode data to latent space: pass the standardized data through the model's
# encoder and keep the `µ` field of the encoder output.
# NOTE(review): presumably `µ` is the latent mean, with one column per data
# point — confirm against the AutoEncode encoder API.
latent = infomaxvae.vae.encoder(data_std).µ

Let's plot the latent space representation of the data.

In [None]:
# Single-panel canvas
fig = Figure(size=(300, 300))

# Axis for the two latent coordinates
ax = Axis(
    fig[1, 1]; xlabel="latent dimension 1", ylabel="latent dimension 2"
)

# Latent coordinates of every data point, colored by the third standardized
# input coordinate
scatter!(
    ax, latent[1, :], latent[2, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

fig

Let's look at the resemblance between the input and the output of this
autoencoder. For this, we will plot the true value of the function and the
autoencoder-reconstructed value side by side. To make sure that the
reconstruction keeps the structure of the data, we will color the points by
their true $x_3$-value dictated by our function as well as by their angle on the
$x_1$-$x_2$ plane to make sure the autoencoder learned to distinguish the points
despite the circular symmetry.

In [None]:
# Reset the seed so that this cell is reproducible on its own.
# NOTE(review): presumably the forward pass through the model below draws
# random samples — confirm against the AutoEncode API.
Random.seed!(42)

# Initialize figure
fig = Figure(size=(2 * 300, 2 * 300))

# Add GridLayout
gl = GridLayout(fig[1, 1])

# 2×2 matrix of 3D axes; ax[i, j] sits at gl[i, j].
# Left column: encoder input; right column: decoder output.
# Top row: color by value; bottom row: color by angle.
ax = [
    Axis3(
        gl[i, j];
        xlabel="x₁(z₁,z₂)",
        ylabel="x₂(z₁,z₂)",
        zlabel="x₃(z₁,z₂)",
        xypanelcolor="#E6E6EF",
        xzpanelcolor="#E6E6EF",
        yzpanelcolor="#E6E6EF",
        xgridcolor=:white,
        ygridcolor=:white,
        zgridcolor=:white,
    ) for i in 1:2, j in 1:2
]

# Polar angle of each point on the x₁-x₂ plane. The two-argument form
# atan(y, x) returns the angle in [-π, π], so points in different quadrants
# get different values. Squaring the coordinates would fold every point into
# the first quadrant and hide exactly the symmetry we want to inspect.
θ_std = atan.(data_std[2, :], data_std[1, :])

# Plot input space (color by x₃-value)
scatter!(
    ax[1, 1],
    data_std[1, :],
    data_std[2, :],
    data_std[3, :],
    markersize=5,
    color=data_std[3, :],
    colormap=:viridis,
)

# Add title
ax[1, 1].title = "encoder input space (color by value)"

# Plot input space (color by angle)
scatter!(
    ax[2, 1],
    data_std[1, :],
    data_std[2, :],
    data_std[3, :],
    markersize=5,
    color=θ_std,
    colormap=:inferno,
)

# Add title
ax[2, 1].title = "encoder input space (color by angle)"

# Pass data through autoencoder and keep the `µ` field of its output
data_vae = infomaxvae(data_std).µ

# Plot reconstruction (color by the x₃-value of the corresponding input)
scatter!(
    ax[1, 2],
    data_vae[1, :],
    data_vae[2, :],
    data_vae[3, :],
    markersize=5,
    color=data_std[3, :],
    colormap=:viridis,
)

# Add title
ax[1, 2].title = "decoder output space (color by value)"

# Plot reconstruction (color by the angle of the corresponding input)
scatter!(
    ax[2, 2],
    data_vae[1, :],
    data_vae[2, :],
    data_vae[3, :],
    markersize=5,
    color=θ_std,
    colormap=:inferno,
)

# Add title
ax[2, 2].title = "decoder output space (color by angle)"

fig

Next, let's map the input data (3D points) into the latent space (2D points) via
the encoder. Again, we color each point by its input $x_3$-value, as determined
by our function.

In [None]:
# Encode the standardized data; keep the `µ` field of the encoder output
data_latent = infomaxvae.vae.encoder(data_std).µ

# Canvas for three side-by-side panels: input, latent space, reconstruction
fig = Figure(size=(3 * 300, 300))

# Nested layout, so that the gap between the panels can be adjusted
gl = GridLayout(fig[1, 1])

# Left panel: 3D axis for the data fed to the encoder
ax_input = Axis3(
    gl[1, 1];
    title="encoder input space",
    xlabel="f₁(x₁,x₂)",
    ylabel="f₂(x₁,x₂)",
    zlabel="f₃(x₁,x₂)",
    aspect=(1, 1, 1),
    xypanelcolor="#E6E6EF",
    xzpanelcolor="#E6E6EF",
    yzpanelcolor="#E6E6EF",
    xgridcolor=:white,
    ygridcolor=:white,
    zgridcolor=:white,
)

# Middle panel: square 2D axis for the latent coordinates
ax_latent = Axis(
    gl[1, 2];
    title="latent space",
    xlabel="z₁",
    ylabel="z₂",
    aspect=AxisAspect(1),
)

# Right panel: 3D axis for the data produced by the decoder
ax_output = Axis3(
    gl[1, 3];
    title="decoder output space",
    xlabel="f₁(x₁,x₂)",
    ylabel="f₂(x₁,x₂)",
    zlabel="f₃(x₁,x₂)",
    aspect=(1, 1, 1),
    xypanelcolor="#E6E6EF",
    xzpanelcolor="#E6E6EF",
    yzpanelcolor="#E6E6EF",
    xgridcolor=:white,
    ygridcolor=:white,
    zgridcolor=:white,
)

# Every panel uses the same color code: the third standardized input
# coordinate of each data point.

# Encoder input
scatter!(
    ax_input, data_std[1, :], data_std[2, :], data_std[3, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

# Latent representation
scatter!(
    ax_latent, data_latent[1, :], data_latent[2, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

# Decoder output
scatter!(
    ax_output, data_vae[1, :], data_vae[2, :], data_vae[3, :];
    color=data_std[3, :], colormap=:viridis, markersize=5,
)

# Set the gap between panels
colgap!(gl, 20)

fig