(c) 2024 Manuel Razo. This work is licensed under a [Creative Commons
Attribution License CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/).
All code contained herein is licensed under an [MIT
license](https://opensource.org/licenses/MIT).

In [None]:
# Import project package
import Antibiotic

# Import ML libraries
import Flux
import AutoEncoderToolkit as AET
import AutoEncoderToolkit.diffgeo.NeuralGeodesics as NG
import Zygote

# Import libraries to handle data
import CSV
import DataFrames as DF
import Glob


# Import library to save models
import JLD2

# Import basic math
import LinearAlgebra
import StatsBase
import Random
Random.seed!(42)

# Import Plotting libraries
using CairoMakie
import ColorSchemes

# Activate backend
CairoMakie.activate!()

# Set Plotting style
Antibiotic.viz.theme_makie!()

# Gradient ascent on latent space

In this notebook we will explore the idea of performing a gradient ascent on the
fitness landscape built in the RHVAE latent space.

## Fitness landscape in latent space

 
When we trained the RHVAE, we learned a map from the data to latent variables
and back to the data, where the data consisted of fitness values in different
environments for a set of genotypes. This learned map included the Riemannian
metric on the latent space, giving us information about distances between points
in the latent space and the data space.

In this low-dimensional latent space, we can map every point to the
corresponding fitness in a particular environment, thereby building a fitness
landscape in latent space.

Let's begin by loading one of the trained RHVAE models.

In [None]:
# Define directory where models are saved
model_dir = "$(git_root())/output/beta-rhvae_jointlogencoder_simpledecoder_iwasawa_mcmc/v05/"

# List model file
model_file = first(Glob.glob("$(model_dir)/model.jld2"[2:end], "/"))

# List last saved state
state_file = last(Glob.glob("$(model_dir)/model_state/*.jld2"[2:end], "/"))

# Load model and state
rhvae = JLD2.load(model_file)["model"]
Flux.loadmodel!(rhvae, JLD2.load(state_file)["model_state"])

# Update RHVAE metric
AET.RHVAEs.update_metric!(rhvae)

typeof(rhvae)

Next, let's load the fitness data.

In [None]:
# Define data directory
data_dir = "$(git_root())/output/mcmc_iwasawa_logistic"

# Load data
df_logic50 = CSV.read("$(data_dir)/logic50_ci.csv", DF.DataFrame)
# Extract strain and evolution condition from :env by splitting by _
DF.insertcols!(
    df_logic50,
    :strain => getindex.(split.(df_logic50.env, "_"), 1),
    :evo => getindex.(split.(df_logic50.env, "_"), 3),
)

first(df_logic50, 5)

Next, we will map the fitness data to the latent space using the RHVAE.

In [None]:
println("Map data to latent space...")

# Group dataframe by :day, :strain_num, and :env
df_group = DF.groupby(df_logic50, [:day, :strain_num, :env])
# Initialize empty dataframe to store latent coordinates
df_latent = DF.DataFrame()
# Loop over groups
for data in df_group
    # Sort data by drug
    DF.sort!(data, :drug)
    # Run :logic50_mean_std through encoder
    latent = rhvae.vae.encoder(Float32.(data.logic50_mean_std)).µ
    # Append latent coordinates to dataframe
    DF.append!(
        df_latent,
        DF.DataFrame(
            :day .=> first(data.day),
            :strain_num .=> first(data.strain_num),
            :meta .=> first(data.env),
            :env .=> split(first(data.env), "_")[end],
            :strain .=> split(first(data.env), "_")[1],
            :latent1 => latent[1, :],
            :latent2 => latent[2, :],
        )
    )
end # for 

first(df_latent, 5)

Let's take a quick look at the latent space structure by plotting the latent
coordinates of the genotypes.

In [None]:
# Initialize figure
fig = Figure(size=(300, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)

# Add scatter plot
scatter!(
    ax,
    df_latent.latent1,
    df_latent.latent2,
    markersize=6,
)

fig

This plot doesn't convey any information about the curvature of the latent
space. To get a better sense of the latent space structure, we can plot the
log determinant of the metric tensor as the background color, which indicates
the degree of deformation in the latent space. In other words, given the local
matrix representation of the Riemannian metric tensor,
$\underline{\underline{G}}(\underline{z})$, we compute
$$
\mathcal{L}(\underline{z}) = \log \sqrt{
    \det \underline{\underline{G}}(\underline{z})
},
\tag{1}
$$
and plot this as the background color of the latent space.

Since the `RHVAE` object comes with a learned metric tensor, we can use this
to compute this quantity.
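
Since the model exposes the inverse metric rather than the metric itself, it helps to see that Eq. 1 can be evaluated from the inverse alone. The following is a minimal sketch on a hypothetical 2×2 matrix (`G` here is made up for illustration and is not the RHVAE metric); it only checks the identity $\log\sqrt{\det G} = -\tfrac{1}{2}\log\det G^{-1}$.

```julia
using LinearAlgebra

# Hypothetical 2×2 symmetric positive-definite metric at some point z
G = [2.0 0.5; 0.5 1.0]
# The inverse metric, standing in for what the model returns
G_inv = inv(G)

# Eq. 1 evaluated directly from the metric
log_sqrt_det = log(sqrt(det(G)))
# The same quantity computed from the inverse metric only
from_inverse = -0.5 * logdet(G_inv)

println(log_sqrt_det ≈ from_inverse)
```

This is why a factor of `-1 / 2` multiplies the log determinant of the inverse metric in the next cell.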

In [None]:
# Define the number of points to evaluate
n_points = 200

# Extract latent space ranges
latent1_range = Float32.(range(
    minimum(df_latent.latent1) - 1.5,
    maximum(df_latent.latent1) + 1.5,
    length=n_points
))
latent2_range = Float32.(range(
    minimum(df_latent.latent2) - 1.5,
    maximum(df_latent.latent2) + 1.5,
    length=n_points
))
# Define latent points to evaluate
z_mat = reduce(hcat, [[x, y] for x in latent1_range, y in latent2_range])

# Compute inverse metric tensor
Ginv = AET.RHVAEs.G_inv(z_mat, rhvae)

# Compute log determinant of metric tensor
log_metric = [
    latent1_range,
    latent2_range,
    Float32.(reshape(
        -1 / 2 * AET.utils.slogdet(Ginv), n_points, n_points
    ))
]

Now, let's repeat the previous plot but with the log determinant of the metric
as a background color.

In [None]:
# Initialize figure
fig = Figure(size=(400, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="latent dimension 1",
    ylabel="latent dimension 2",
    aspect=AxisAspect(1)
)

hm = heatmap!(
    ax,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)

# Add colorbar
Colorbar(fig[1, 2], hm, label="log(√det(G))")

# Add scatter plot
scatter!(
    ax,
    df_latent.latent1,
    df_latent.latent2,
    markersize=4,
    color=(:white, 0.3)
)

fig

In the same way that we mapped latent coordinates to the corresponding log 
metric, we can map the latent coordinates to the corresponding fitness values
by simply running the latent coordinates through the decoder. We'll store this
output as a 3D array where each slice represents the fitness landscape in each
of the environments.
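
The index bookkeeping behind this 3D array is easy to get wrong, so here is a small sketch of one way to do it in a single pass. It assumes a hypothetical `toy_decoder` with three outputs in place of the RHVAE decoder; the names `xs`, `ys`, `z_grid` and `landscape` are illustrative only.

```julia
# Hypothetical stand-in for the decoder: maps a 2D latent point to 3 outputs
toy_decoder(z) = [z[1] + z[2], z[1] * z[2], z[1] - z[2]]

xs = Float32.(range(-1, 1, length=4))
ys = Float32.(range(0, 2, length=5))

# Each column of z_grid is one latent point; the x index varies fastest
z_grid = reduce(hcat, [[x, y] for x in xs, y in ys])

# Evaluate the decoder column by column, giving an (outputs × points) matrix
out = reduce(hcat, toy_decoder.(eachcol(z_grid)))

# Transpose to (points × outputs), then reshape to (x, y, output)
landscape = reshape(permutedims(out), length(xs), length(ys), :)
```

With this ordering, `landscape[i, j, k]` holds output `k` evaluated at `(xs[i], ys[j])`, which matches the layout produced by the explicit double loop in the next cell.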

In [None]:
# Map latent coordinates to fitness value with decoder. Note that the transpose
# is required to match the shape of the fitness matrix when reshaped to 3D array
# fitness_mat = rhvae.vae.decoder(z_mat).μ'
# Reshape fitness matrix to 3D array
# fitness_mat = reshape(fitness_mat, n_points, n_points, :);

# Extract drug names
drugs = sort(unique(df_logic50.drug))

# Initialize array to store fitness values
fitness_mat = zeros(Float32, n_points, n_points, length(drugs))

# Loop through each latent point
for (i, x) in enumerate(latent1_range)
    for (j, y) in enumerate(latent2_range)
        fitness_mat[i, j, :] = rhvae.vae.decoder(Float32.([x, y])).μ
    end # for
end # for

Let's plot all 8 environments in a 3x3 grid, using the last panel for the log
metric.

In [None]:
# Initialize figure
fig = Figure(size=(900, 900))
# Add grid layout
gl = fig[1, 1] = GridLayout()

# Define number of rows and columns
rows = 3
cols = 3

# Extract unique environments
drugs = sort(unique(df_logic50.drug))

# Loop through each environment
for (i, drug) in enumerate(drugs)
    # Define row and column index
    row = (i - 1) ÷ cols + 1
    col = (i - 1) % cols + 1
    # Add axis
    ax = Axis(gl[row, col], aspect=AxisAspect(1), title="antibiotic $(drug)")
    # Hide axis
    hidedecorations!(ax)
    # Plot heatmap
    hm = heatmap!(
        ax,
        latent1_range,
        latent2_range,
        fitness_mat[:, :, i],
        colormap=ColorSchemes.viridis,
    )
end # for

# Add axis for log metric
ax = Axis(gl[end, end], aspect=AxisAspect(1), title="log(√det(G))")
# Hide axis
hidedecorations!(ax)
# Plot heatmap
hm = heatmap!(
    ax,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)

# Add global x-axis label
Label(gl[end, :, Bottom()], "latent dimension 1", fontsize=24)
# Add global y-axis label
Label(gl[:, 1, Left()], "latent dimension 2", fontsize=24, rotation=π / 2)

fig

These fitness landscapes extend beyond the region where we observed data. We can
use the log metric to tell us where these fitness landscapes are meaningful
given the data. For that, we simply define a "mask" that flags all grid points
where the log metric exceeds 92% of its maximum value, and we set the fitness at
those points to the minimum of the landscape.
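
The masking step can be sketched on toy arrays as follows. The values of `toy_log_metric` and `toy_fitness` are made up, and flagged entries are replaced with the minimum fitness value, as the cell below does for the real landscapes.

```julia
# Hypothetical log-metric and fitness values on a 2×3 grid
toy_log_metric = [0.1 0.5 0.95; 0.2 1.0 0.93]
toy_fitness = [3.0 1.0 7.0; 2.0 9.0 4.0]

# Flag grid points whose log metric exceeds 92% of the maximum
toy_mask = toy_log_metric .> 0.92 * maximum(toy_log_metric)

# Replace flagged points with the minimum fitness value and keep the rest
toy_masked = ifelse.(toy_mask, minimum(toy_fitness), toy_fitness)
```

Note that a relative threshold of this kind implicitly assumes the maximum log metric is positive: if it were negative, 92% of the maximum would be larger than the maximum itself and no point would be flagged.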

In [None]:
# Define mask for fitness landscape
mask = (
    maximum(log_metric[end]) * 0.92 .<
    log_metric[end] .≤
    maximum(log_metric[end])
)

# Initialize figure
fig = Figure(size=(1_000, 1_000))
# Add grid layout
gl = fig[1, 1] = GridLayout()

# Define number of rows and columns
rows = 3
cols = 3

# Loop through each drug
for (i, drug) in enumerate(drugs)
    # Define row and column index
    row = (i - 1) ÷ cols + 1
    col = (i - 1) % cols + 1
    # Add axis
    ax = Axis(gl[row, col], aspect=AxisAspect(1), title="antibiotic $(drug)")
    # Hide axis
    hidedecorations!(ax)
    # Mask fitness landscape
    fit_landscape_masked = (mask .* minimum(fitness_mat[:, :, i])) .+
                           (fitness_mat[:, :, i] .* .!mask)
    # Plot heatmap
    hm = heatmap!(
        ax,
        latent1_range,
        latent2_range,
        fit_landscape_masked,
        colormap=ColorSchemes.viridis,
    )
end # for

# Add axis for log metric
ax = Axis(gl[end, end], aspect=AxisAspect(1), title="log(√det(G))")
# Hide axis
hidedecorations!(ax)
# Plot heatmap
hm = heatmap!(
    ax,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)

# Add global x-axis label
Label(gl[end, :, Bottom()], "latent dimension 1", fontsize=24)
# Add global y-axis label
Label(gl[:, 1, Left()], "latent dimension 2", fontsize=24, rotation=π / 2)

fig

## Gradient ascent on the fitness landscape

Under the assumption that these low-dimensional fitness landscapes in latent
space are meaningful, we could simulate evolutionary dynamics under strong
selection by performing a gradient ascent on these landscapes. In other words,
let us define a scalar function $f_i(\underline{z})$ as the $i$-th output of the
decoder defining the fitness value of a genotype with latent coordinates
$\underline{z}$ in environment $i$. Naively, if the evolutionary trajectory
follows the gradient of the fitness landscape, we could approximate the dynamics
of how the latent coordinate evolves over time as
$$
\underline{z}_{t+1} \approx \underline{z}_t + \eta \nabla f_i(\underline{z}_t),
\tag{2}
$$
where $\underline{z}_t$ is the latent coordinate at time $t$, $\eta$ is a step
size parameter in this latent space, and $\nabla f_i(\underline{z}_t)$ is the
gradient of the fitness landscape at the current latent coordinate. However,
Eq. 2 is not correct because it doesn't take into account the curvature of the
latent space. Instead, the gradient vector should be corrected to become the
so-called Riemannian gradient $\nabla^{\mathcal{M}}$, where $\mathcal{M}$ is the
Riemannian manifold. In other words, we should use the Riemannian metric tensor
to correct the gradient vector. The corrected gradient vector is of the form
$$
\nabla^{\mathcal{M}} f_i(\underline{z}) = 
\underline{\underline{G}}^{-1}(\underline{z}) \nabla f_i(\underline{z}),
\tag{3}
$$
where $\underline{\underline{G}}^{-1}(\underline{z})$ is the inverse of the 
metric tensor at the latent coordinate $\underline{z}$. Thus, Eq. 2 should be
corrected to
$$
\underline{z}_{t+1} \approx \underline{z}_t + 
\eta \nabla^{\mathcal{M}} f_i(\underline{z}_t).
\tag{4}
$$
Luckily, the `RHVAE` object comes with a method to compute the inverse 
Riemannian metric directly.

Let's begin by defining a function that computes the Riemannian gradient. For
this, we will use `Zygote.jl` autodiff to compute the gradient of the fitness
landscape.
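
Before working with the RHVAE, Eq. 3 can be illustrated on a toy problem where everything is known in closed form. This sketch assumes a hypothetical fitness `toy_f` and a hypothetical metric `toy_G`, and writes the Euclidean gradient by hand instead of using autodiff; none of these names come from `AutoEncoderToolkit`.

```julia
using LinearAlgebra

# Hypothetical scalar "fitness" with a single peak at the origin
toy_f(z) = -sum(abs2, z) / 2
# Its Euclidean gradient, written by hand instead of using autodiff
toy_grad_f(z) = -z
# Hypothetical position-dependent metric (symmetric positive definite)
toy_G(z) = Diagonal([1 + z[1]^2, 1.0])

# Eq. 3: solve G x = ∇f rather than forming the inverse explicitly
toy_riemannian_gradient(z) = toy_G(z) \ toy_grad_f(z)

z = [1.0, 2.0]
euclidean = toy_grad_f(z)
riemannian = toy_riemannian_gradient(z)
```

At this point the metric stretches the first coordinate by a factor of two, so the first component of the Riemannian gradient is half the Euclidean one while the second is unchanged.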

In [None]:
@doc raw"""
    riemannian_gradient(
        z::AbstractArray, f::Function, rhvae::AET.RHVAEs.RHVAE
    )

Compute the Riemannian gradient of a given function `f` with respect to the
latent variable `z` in the context of a Riemannian Hamiltonian Variational
Autoencoder (RHVAE).

# Arguments
- `z::AbstractArray`: The latent variable for which the Riemannian gradient is
  computed. This is typically a vector of latent space coordinates.
- `f::Function`: The function whose gradient is to be computed. This function
  should take `z` as input and return a scalar value. Usually, this is an
  anonymous function indexing one of the outputs of the decoder of the RHVAE
  model.
- `rhvae::AET.RHVAEs.RHVAE`: An instance of the RHVAE model, which contains the
  necessary methods and parameters for computing the inverse metric tensor.

# Returns
- `AbstractArray`: The Riemannian gradient of the function `f` with respect to
  `z`. This is computed by multiplying the inverse metric tensor by the
  Euclidean gradient of `f`.

# Details
The function performs the following steps:
1. Computes the Euclidean gradient of the function `f` with respect to the
   latent variable `z` using automatic differentiation.
2. Computes the inverse of the metric tensor `G` at `z` using the method `G_inv`
   from the RHVAE model.
3. Multiplies the inverse metric tensor by the Euclidean gradient to obtain the
   Riemannian gradient.

The Riemannian gradient takes into account the geometry of the latent space
defined by the RHVAE model, providing a more accurate direction for optimization
in this space.
"""
function riemannian_gradient(
    z::AbstractArray,
    f::Function,
    rhvae::AET.RHVAEs.RHVAE
)
    # Compute gradient of function
    ∇f = Zygote.gradient(f, z)[1]
    # Compute inverse metric tensor
    G_inv = AET.RHVAEs.G_inv(z, rhvae)
    # Return Riemannian gradient
    return G_inv * ∇f
end # riemannian_gradient

# Define anonymous function to compute fitness
f(z) = rhvae.vae.decoder(z).μ[1]

# Test Riemannian gradient
riemannian_gradient(zeros(Float32, 2), f, rhvae)

Let's now test the Riemannian gradient on the fitness landscape of a particular
environment.

In [None]:
# Define evolutionary stress
evo_stress = findfirst(x -> x == "TET", drugs)
# Define anonymous function to compute fitness
f(z) = rhvae.vae.decoder(z).μ[evo_stress]

# Define the number of points to evaluate
n_points = 20

# Extract latent space ranges
latent1_range_vec = Float32.(range(
    minimum(df_latent.latent1) - 1.5,
    maximum(df_latent.latent1) + 1.5,
    length=n_points
))
latent2_range_vec = Float32.(range(
    minimum(df_latent.latent2) - 1.5,
    maximum(df_latent.latent2) + 1.5,
    length=n_points
))

# Define latent points to evaluate
z_vec = reduce(
    hcat, [[x, y] for x in latent1_range_vec, y in latent2_range_vec]
)

# Compute inverse metric tensor at each of these points. Note: This will be used
# to filter out points outside the data manifold
Ginv = AET.RHVAEs.G_inv(z_vec, rhvae)
# Compute log determinant of metric tensor
log_m = -1 / 2 * AET.utils.slogdet(Ginv)
# Find index of points within the data manifold 
z_idx = findall(
    .!(maximum(log_metric[end]) * 0.92 .< log_m .≤ maximum(log_metric[end]))
)
# Index latent points within the data manifold
z_vec = z_vec[:, z_idx]

# Evaluate gradient at each point
∇f_vec = riemannian_gradient.(eachcol(z_vec), Ref(f), Ref(rhvae))

# Compute the magnitude of each gradient vector
∇f_mag = LinearAlgebra.norm.(∇f_vec);

# Compute normalized gradient vectors
∇f_vec_norm = ∇f_vec ./ ∇f_mag;

Now, let's plot the vector field of the Riemannian gradient on the fitness
landscape.

In [None]:
# Initialize figure
fig = Figure(size=(800, 500))

# Add grid layout
gl = fig[1, 1] = GridLayout()

# Add axis for fitness landscape
ax1 = Axis(
    gl[1, 1],
    aspect=AxisAspect(1),
    title="Fitness landscape"
)
# hide decorations
hidedecorations!(ax1)
# Add axis for log metric
ax2 = Axis(gl[1, 2], aspect=AxisAspect(1), title="log(√det(G))")
# hide decorations
hidedecorations!(ax2)

# Mask fitness landscape
fit_landscape_masked = (mask .* minimum(fitness_mat[:, :, evo_stress])) .+
                       (fitness_mat[:, :, evo_stress] .* .!mask)


# Plot heatmap for fitness landscape
hm1 = heatmap!(
    ax1,
    latent1_range,
    latent2_range,
    fit_landscape_masked,
    colormap=ColorSchemes.viridis,
)

# Plot heatmap for log metric
hm2 = heatmap!(
    ax2,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)
# Add colorbars
Colorbar(gl[2, 1], hm1, label="log₁₀(IC₅₀)", vertical=false)
Colorbar(gl[2, 2], hm2, label="log(√det(G))", vertical=false)


# Plot gradient vectors
arrows!.(
    [ax1; ax2],
    Ref(z_vec[1, :]),
    Ref(z_vec[2, :]),
    Ref(getindex.(∇f_vec_norm, 1)),
    Ref(getindex.(∇f_vec_norm, 2)),
    arrowsize=7.5,
    lengthscale=0.2,
    arrowcolor=∇f_mag,
    linecolor=∇f_mag,
    colormap=ColorSchemes.magma,
)

# Add global title
Label(
    gl[1, :, Top()],
    "Riemannian Gradient Vectors | Antibiotic $(drugs[evo_stress])",
    fontsize=24,
    padding=(0, 0, 30, 0)
)

fig

These plots seem perfectly reasonable. We can see that the gradient arrows
are indeed pointing towards regions of higher fitness.

Next, we can define a function that performs the gradient ascent on the fitness
landscape in latent space. This function will take as input the initial latent
coordinates, the fitness function, an `RHVAE` object, the step size and the
number of steps to take.
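
The update rule in Eq. 4 can be sketched on a toy problem first. This example assumes a hypothetical quadratic landscape with its peak at the origin and a hypothetical diagonal metric; `toy_ascent` is an illustrative name and is not the function defined in the next cell.

```julia
using LinearAlgebra

# Hypothetical landscape and metric (not the RHVAE ones)
toy_grad_f(z) = -z                        # gradient of f(z) = -‖z‖² / 2
toy_G(z) = Diagonal([1 + z[1]^2, 1.0])    # position-dependent metric

function toy_ascent(z0::AbstractVector; n_steps::Int=200, step_size::Real=0.1)
    # One column per visited point, starting with z0
    z_traj = Matrix{Float64}(undef, length(z0), n_steps + 1)
    z_traj[:, 1] = z0
    for i in 1:n_steps
        z = z_traj[:, i]
        # Eq. 4: step along the metric-corrected gradient
        z_traj[:, i+1] = z + step_size * (toy_G(z) \ toy_grad_f(z))
    end
    return z_traj
end

traj = toy_ascent([1.0, 2.0])
final_distance = norm(traj[:, end])  # distance to the peak at the origin
```

Storing the trajectory as a matrix with one column per step makes it straightforward to plot both coordinates against each other afterwards.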

In [None]:
@doc raw"""
    riemannian_gradient_ascent(
        z::AbstractVector,
        f::Function,
        rhvae::AET.RHVAEs.RHVAE;
        n_steps::Int=100,
        step_size::Float32=Float32(1E-3)
    ) -> Matrix

Perform Riemannian gradient ascent on a given function `f` with respect to the
latent variable `z` in the context of a Riemannian Hamiltonian Variational
Autoencoder (RHVAE).

# Arguments
- `z::AbstractVector`: The initial latent variable for which the gradient
  ascent is performed. This is typically a vector of latent space coordinates.
- `f::Function`: The function to be maximized. This function should take `z` as
  input and return a scalar value. Usually, this is an anonymous function that
  indexes one of the outputs of the decoder of the RHVAE model.
- `rhvae::AET.RHVAEs.RHVAE`: An instance of the RHVAE model, which contains the
  necessary methods and parameters for computing the inverse metric tensor and
  the Riemannian gradient.
- `n_steps::Int=100`: The number of gradient ascent steps to perform. Default is
  100.
- `step_size::Float32=Float32(1E-3)`: The step size for the gradient ascent
  updates. Default is 0.001.

# Returns
- `Matrix`: A matrix containing the trajectory of the latent variable `z` over
  the course of the gradient ascent. The first column is the initial `z`, and
  each subsequent column is the updated `z` after each gradient ascent step.

# Details
The function performs Riemannian gradient ascent to maximize the function `f`
with respect to the latent variable `z`. The steps are as follows:
1. Initialize a trajectory matrix `z_traj` with `n_steps + 1` columns to store
   the latent variable at each step.
2. Set the first column of `z_traj` to the initial `z`.
3. For each step from 1 to `n_steps`:
    - Compute the Riemannian gradient of `f` at the current `z` using the
      `riemannian_gradient` function.
    - Update the latent variable `z` by adding the product of the step size and
      the Riemannian gradient to the current `z`.
4. Return the trajectory matrix `z_traj`, which contains the sequence of `z`
   values over the gradient ascent steps.

This method leverages the geometry of the latent space defined by the RHVAE
model to perform more effective optimization in this space.
"""
function riemannian_gradient_ascent(
    z::AbstractVector,
    f::Function,
    rhvae::AET.RHVAEs.RHVAE;
    n_steps::Int=100,
    step_size::Float32=Float32(1E-3)
)
    # Initialize trajectory
    z_traj = Matrix{eltype(z)}(undef, length(z), n_steps + 1)
    # Add initial point
    z_traj[:, 1] = z
    # Perform gradient ascent
    for i in 1:n_steps
        # Compute riemannian gradient ∇ᴹf 
        ∇ᴹf = riemannian_gradient(z_traj[:, i], f, rhvae)
        # Update latent coordinates
        z_traj[:, i+1] = z_traj[:, i] + step_size * ∇ᴹf
    end # for
    return z_traj
end # riemannian_gradient_ascent

Let's test this function by performing a gradient ascent on the fitness
landscape of a particular environment, starting from the initial position of
one of the genotypes that evolved in that environment.

First let's extract the lineage information for the most fit individual in a
particular environment.

In [None]:
# Define environment to focus on
evo_stress = findfirst(x -> x == "TET", drugs)
# Filter data from this stress
df_drug = df_logic50[(df_logic50.evo.==drugs[evo_stress]), :]
# Find row index of the best individual
max_id = findmax(df_drug[!, "logic50_mean_std"])[2]
# Extract latent coordinates of one lineage. NOTE: the strain number is
# hard-coded here; `df_drug.strain_num[max_id]` would select the best individual
df_lin = df_latent[df_latent.strain_num.==8, :]
# Sort by day
DF.sort!(df_lin, :day)

first(df_lin, 5)

Next, starting from the initial generation for this lineage, let's perform the
gradient ascent on the fitness landscape of this environment.

In [None]:
# Define anonymous function to compute fitness
f_env(z) = rhvae.vae.decoder(z).μ[evo_stress]
# Define step-sizes to evaluate
step_size = Float32.([1E-3, 1E-2])
# Locate latent coordinates of initial individual
zₒ = Float32.(vec(Matrix(df_lin[1:1, [:latent1, :latent2]])))
# Perform gradient ascent
z_traj = [
    riemannian_gradient_ascent(
        zₒ, f_env, rhvae; n_steps=4_000, step_size=s
    ) for s in step_size
]

In [None]:
Random.seed!(42)
# Define the stride used to subsample the trajectory when plotting
step = 1

# Initialize figure
fig = Figure(size=(800, 500))
# Add grid layout
gl = fig[1, 1] = GridLayout()

# Add axis for fitness landscape
ax1 = Axis(gl[1, 1], aspect=AxisAspect(1), title="Fitness landscape")
# hide decorations
hidedecorations!(ax1)
# Add axis for log metric
ax2 = Axis(gl[1, 2], aspect=AxisAspect(1), title="log(√det(G))")
# hide decorations
hidedecorations!(ax2)

# Mask fitness landscape
fit_landscape_masked = (mask .* minimum(fitness_mat[:, :, evo_stress])) .+
                       (fitness_mat[:, :, evo_stress] .* .!mask)

# Plot heatmap for fitness landscape
hm1 = heatmap!(
    ax1,
    latent1_range,
    latent2_range,
    fit_landscape_masked,
    colormap=ColorSchemes.viridis,
)
# Plot heatmap for log metric
hm2 = heatmap!(
    ax2,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)

# Add colorbars
Colorbar(gl[2, 1], hm1, label="Fitness", vertical=false)
Colorbar(gl[2, 2], hm2, label="Curvature", vertical=false)

# Define colors for each step size
colors = get(ColorSchemes.inferno, range(0.5, 1, length=length(z_traj)))

# Loop through each step size
for (i, z) in enumerate(z_traj)
    lines!.(
        [ax1; ax2],
        Ref(z[1, 1:step:end]),
        Ref(z[2, 1:step:end]),
        linewidth=2.5,
        color=colors[i],
        # linestyle=:dash,
    )
end # for

# Plot evolutionary trajectory in latent space
scatter!.(
    [ax1; ax2],
    Ref(df_lin.latent1),
    Ref(df_lin.latent2),
    markersize=11,
    color=:black
)
scatterlines!.(
    [ax1; ax2],
    Ref(df_lin.latent1),
    Ref(df_lin.latent2),
    markersize=6,
    color=ColorSchemes.seaborn_colorblind[1],
)

# Add first point 
scatter!(
    ax1,
    [df_lin.latent1[1]],
    [df_lin.latent2[1]],
    color=:white,
    markersize=18,
    marker=:xcross
)
scatter!(
    ax1,
    [df_lin.latent1[1]],
    [df_lin.latent2[1]],
    color=:black,
    markersize=10,
    marker=:xcross
)

# Add last point
scatter!(
    ax1,
    [df_lin.latent1[end]],
    [df_lin.latent2[end]],
    color=:white,
    markersize=18,
    marker=:utriangle
)
scatter!(
    ax1,
    [df_lin.latent1[end]],
    [df_lin.latent2[end]],
    color=:black,
    markersize=10,
    marker=:utriangle
)

# Set x and y limits
xlims!.(
    [ax1; ax2],
    minimum(latent1_range), maximum(latent1_range)
)
ylims!.(
    [ax1; ax2],
    minimum(latent2_range), maximum(latent2_range)
)

# Add global title
Label(
    gl[1, :, Top()],
    "gradient ascent vs. evolutionary trajectory\nenvironment $(drugs[evo_stress])",
    fontsize=24,
    padding=(0, 0, 30, 0)
)

fig

We can see that the gradient ascent trajectory does not match the evolutionary
trajectory in the latent space.


## Mutational landscape

Here's a crazy idea: What if the dynamics of the evolutionary trajectory in
latent space are not solely a function of the fitness landscape, but also of
what we will call the mutational landscape for lack of a better term? What this
second landscape represents is the likelihood that a genotype can encode a 
phenotype that corresponds to some coordinates in the latent space. In this way,
we include the idea that the dynamics of the evolutionary trajectory are also
influenced by how likely it is to find the encoding that corresponds to a
particular phenotype.

We can model these dynamics as follows:
$$
\frac{d\underline{z}}{dt} = \alpha \nabla F(\underline{z}) +
\beta \nabla M(\underline{z}),
\tag{5}
$$
where $F(\underline{z})$ is the fitness landscape and $M(\underline{z})$ is
the mutational landscape. The coefficients $\alpha$ and $\beta$ determine the
relative influence of the landscapes.

We can obtain the gradient for the fitness landscape directly from the output of
the RHVAE model, as we have seen previously. But, how do we obtain the gradient
for a mutational landscape if we don't even have a model for this landscape?

### The metric as a mutational landscape

From the construction of the Riemannian metric tensor, we know that the
curvature of a region in latent space is inversely related to the density of
points in that region. Therefore, let's try using the inverse of the metric as
a proxy for the mutational landscape. This is not a perfect proxy, but it is a
start with what we have.
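
To make the proxy concrete, suppose we take the mutational landscape to be the log determinant of the inverse metric. This is an assumption made here for illustration, chosen because it is the quantity the function below differentiates by default. It is then tied directly to Eq. 1:
$$
M(\underline{z}) \equiv
\log \det \underline{\underline{G}}^{-1}(\underline{z}) =
-\log \det \underline{\underline{G}}(\underline{z}) =
-2 \mathcal{L}(\underline{z}),
$$
so that $\nabla M(\underline{z}) = -2 \nabla \mathcal{L}(\underline{z})$. Under this choice, climbing the mutational landscape amounts to moving toward regions where the log metric of Eq. 1 is smaller.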

Let's define a function that computes the gradient of the (log) determinant of
the inverse metric.

In [None]:
@doc raw"""
    mut_gradient(z::AbstractArray, rhvae::AET.RHVAEs.RHVAE; log::Bool=true)

Compute the gradient of the (log) determinant of the inverse metric tensor with
respect to the latent variable `z` in the context of a Riemannian Hamiltonian
Variational Autoencoder (RHVAE).

# Arguments
- `z::AbstractArray`: The latent variable for which the gradient is computed.
  This is typically a vector of latent space coordinates.
- `rhvae::AET.RHVAEs.RHVAE`: An instance of the RHVAE model, which contains the
  necessary methods and parameters for computing the inverse metric tensor.

# Optional Keyword Arguments
- `log::Bool`: If true (default), compute the gradient of the log determinant.
  If false, compute the gradient of the determinant without taking the log.

# Returns
- `AbstractVector`: The gradient of the (log) determinant of the inverse metric
  tensor with respect to `z`.

# Details
The function performs the following steps:
1. Computes the inverse of the metric tensor `G` at `z` using the method `G_inv`
   from the RHVAE model.
2. Computes the gradient of the (log) determinant of the inverse metric tensor
   with respect to `z` using automatic differentiation.

This gradient can be interpreted as a proxy for the "mutational landscape" in
the latent space. It provides information about how the local geometry of the
latent space changes with respect to `z`, which may influence the evolutionary
dynamics in this space.

# Note
The use of the inverse metric tensor as a proxy for the mutational landscape is
an experimental approach and may not perfectly capture all aspects of mutational
dynamics.
"""
function mut_gradient(
    z::AbstractArray,
    rhvae::AET.RHVAEs.RHVAE;
    log::Bool=true
)
    # Compute gradient of (log) determinant of inverse metric tensor
    Zygote.gradient(
        x -> log ? LinearAlgebra.logdet(AET.RHVAEs.G_inv(x, rhvae)) :
             LinearAlgebra.det(AET.RHVAEs.G_inv(x, rhvae)),
        z
    )[1]
end # mut_gradient

# Test mut_gradient
mut_gradient(zeros(Float32, 2), rhvae)

Let's test this function by computing the gradient on a grid of points in the
latent space.

In [None]:
# Define the number of points to evaluate
n_points = 20

# Extract latent space ranges
latent1_range_vec = Float32.(range(
    minimum(df_latent.latent1) - 1.5,
    maximum(df_latent.latent1) + 1.5,
    length=n_points
))
latent2_range_vec = Float32.(range(
    minimum(df_latent.latent2) - 1.5,
    maximum(df_latent.latent2) + 1.5,
    length=n_points
))

# Define latent points to evaluate
z_vec = reduce(
    hcat, [[x, y] for x in latent1_range_vec, y in latent2_range_vec]
)
# Find index of points within the data manifold 
z_idx = findall(
    .!(maximum(log_metric[end]) * 0.92 .< log_m .≤ maximum(log_metric[end]))
)
# Index latent points within the data manifold
z_vec = z_vec[:, z_idx]

# Compute gradient on grid of points
∇mut_vec = mut_gradient.(eachcol(z_vec), Ref(rhvae); log=true)

# Compute the magnitude of the gradient
∇mut_mag = LinearAlgebra.norm.(∇mut_vec)

# Compute the normalized gradient vectors
∇mut_norm = ∇mut_vec ./ ∇mut_mag;

Now, let's plot the gradient vectors on the latent space.

In [None]:
# Initialize figure
fig = Figure(size=(800, 500))

# Add grid layout
gl = fig[1, 1] = GridLayout()

# Add axis for fitness landscape
ax1 = Axis(
    gl[1, 1],
    aspect=AxisAspect(1),
    title="Fitness landscape"
)
# hide decorations
hidedecorations!(ax1)
# Add axis for log metric
ax2 = Axis(gl[1, 2], aspect=AxisAspect(1), title="Mutational landscape")
# hide decorations
hidedecorations!(ax2)

# Mask fitness landscape
fit_landscape_masked = (mask .* minimum(fitness_mat[:, :, evo_stress])) .+
                       (fitness_mat[:, :, evo_stress] .* .!mask)


# Plot heatmap for fitness landscape
hm1 = heatmap!(
    ax1,
    latent1_range,
    latent2_range,
    fit_landscape_masked,
    colormap=ColorSchemes.viridis,
)

# Plot heatmap for log metric
hm2 = heatmap!(
    ax2,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)
# Add colorbars
Colorbar(gl[2, 1], hm1, label="log₁₀(IC₅₀)", vertical=false)
Colorbar(gl[2, 2], hm2, label="log(√det(G))", vertical=false)


# Plot gradient vectors for fitness landscape
arrows!(
    ax1,
    z_vec[1, :],
    z_vec[2, :],
    getindex.(∇f_vec_norm, 1),
    getindex.(∇f_vec_norm, 2),
    arrowsize=7.5,
    lengthscale=0.2,
    arrowcolor=∇f_mag,
    linecolor=∇f_mag,
    colormap=ColorSchemes.magma,
)

# Plot gradient vectors for mutational landscape
arrows!(
    ax2,
    z_vec[1, :],
    z_vec[2, :],
    getindex.(∇mut_vec, 1),
    getindex.(∇mut_vec, 2),
    arrowsize=7.5,
    lengthscale=0.05,
    arrowcolor=∇mut_mag,
    linecolor=∇mut_mag,
    colormap=ColorSchemes.magma,
)

# Add global title
Label(
    gl[1, :, Top()],
    "Gradient Vectors | Antibiotic $(drugs[evo_stress])",
    fontsize=24,
    padding=(0, 0, 30, 0)
)

fig

With these gradients in hand, we can now approximate the dynamics with a 
discrete approximation of Eq. 5. This is
$$
\underline{z}_{t+1} \approx \underline{z}_t + \eta \left[
    \alpha \nabla F(\underline{z}_t) + \beta \nabla M(\underline{z}_t)
\right],
\tag{6}
$$
where $\eta$ is the step size.
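
As a toy illustration of Eq. 6, the sketch below uses two hypothetical quadratic landscapes with analytic gradients, one peaked at $(1, 0)$ and the other at $(-1, 0)$. The names `toy_grad_F`, `toy_grad_M` and `toy_dynamics` are made up for this example and do not involve the RHVAE.

```julia
# Hypothetical landscapes with analytic gradients (not the RHVAE ones):
# F peaks at (1, 0) and M peaks at (-1, 0)
toy_grad_F(z) = -(z .- [1.0, 0.0])
toy_grad_M(z) = -(z .- [-1.0, 0.0])

function toy_dynamics(z0; α=1.0, β=1.0, n_steps=500, step_size=0.05)
    # Work on a floating-point copy of the starting point
    z = float.(z0)
    for _ in 1:n_steps
        # Eq. 6: step along the weighted sum of both gradients
        z = z + step_size * (α * toy_grad_F(z) + β * toy_grad_M(z))
    end
    return z
end

z_fitness_heavy = toy_dynamics([0.0, 2.0]; α=3.0, β=1.0)
z_balanced = toy_dynamics([0.0, 2.0]; α=1.0, β=1.0)
```

For these quadratic landscapes the trajectory settles at the weighted average of the two peaks, $(\alpha - \beta)/(\alpha + \beta)$ along the first coordinate, which shows how the ratio of $\alpha$ to $\beta$ shifts the resting point between the two landscapes.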

Let's write a function that performs this discrete approximation.

In [None]:
"""
    mut_dynamics(z::AbstractVector, f::Function, rhvae::AET.RHVAEs.RHVAE; 
                 α::Float32=one(Float32), β::Float32=one(Float32), 
                 n_steps::Int=100, step_size::Float32=Float32(1E-3), log::Bool=true)

Simulate mutational dynamics in the latent space of a Riemannian Hamiltonian
Variational Autoencoder (RHVAE).

This function performs a discrete approximation of the mutational dynamics,
combining the gradients of a fitness function and a mutational landscape.

# Arguments
- `z::AbstractVector`: Initial point in the latent space.
- `f::Function`: Fitness function that takes a point in the latent space and
  returns a scalar fitness value.
- `rhvae::AET.RHVAEs.RHVAE`: The RHVAE model used for computing the mutational
  gradient.

# Optional Keyword Arguments
- `α::Float32=one(Float32)`: Weight for the fitness gradient.
- `β::Float32=one(Float32)`: Weight for the mutational gradient.
- `n_steps::Int=100`: Number of steps to simulate.
- `step_size::Float32=Float32(1E-3)`: Step size for each iteration.
- `log::Bool=true`: Whether to use logarithmic scaling for the mutational
  gradient.

# Returns
- `z_traj::Matrix`: A matrix of size `(length(z), n_steps + 1)` where each
  column is the position in latent space at one step. The first column is the
  initial point `z`.

# Notes
- The fitness gradient is the Euclidean gradient computed using automatic
  differentiation (Zygote). It is not the Riemannian gradient.
- The mutational gradient is computed using the `mut_gradient` function.
- The trajectory is updated using a weighted sum of both gradients.
"""
function mut_dynamics(
    z::AbstractVector,
    f::Function,
    rhvae::AET.RHVAEs.RHVAE;
    α::Float32=one(Float32),
    β::Float32=one(Float32),
    n_steps::Int=100,
    step_size::Float32=Float32(1E-3),
    log::Bool=true
)
    # Initialize trajectory
    z_traj = Matrix{eltype(z)}(undef, length(z), n_steps + 1)
    # Add initial point
    z_traj[:, 1] = z
    # Perform dynamics
    for i in 1:n_steps
        # Compute gradients
        # NOTE: The gradient for the fitness landscape is not the Riemannian
        # gradient.
        ∇f = Zygote.gradient(f, z_traj[:, i])[1]
        ∇mut = mut_gradient(z_traj[:, i], rhvae; log=log)
        # Update position
        z_traj[:, i+1] = z_traj[:, i] + step_size * (α * ∇f + β * ∇mut)
    end # for
    return z_traj
end # mut_dynamics
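
The note inside `mut_dynamics` points out that the fitness gradient is not the
Riemannian gradient. For reference, on a manifold with metric tensor $G(z)$ the
Riemannian gradient is $G(z)^{-1} \nabla f(z)$. The sketch below shows the
difference on a toy example; `f_toy`, `∇f_toy`, `G_toy`, and
`riemannian_gradient` are hypothetical names, and the metric is made up rather
than taken from the RHVAE.

```julia
import LinearAlgebra

# Hypothetical fitness function and its Euclidean gradient (assumptions)
f_toy(z) = -sum(abs2, z) / 2
∇f_toy(z) = -z

# Hypothetical position-dependent metric; symmetric positive definite
G_toy(z) = LinearAlgebra.Diagonal([1 + 4 * z[1]^2, 1.0])

# Riemannian gradient: solve G(z) g = ∇f(z) instead of forming the inverse
riemannian_gradient(z) = G_toy(z) \ ∇f_toy(z)

z = [1.0, 1.0]
g_euclid = ∇f_toy(z)
g_riemann = riemannian_gradient(z)
```

Where the metric is large along a direction, the Riemannian gradient shrinks the
step along that direction. `mut_dynamics` as written keeps the Euclidean
gradient for the fitness term.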


Let's test `mut_dynamics` by running the dynamics on the fitness landscape,
starting from the first time point of a single lineage.


In [None]:
# Define environment to focus on
evo_stress = findfirst(x -> x == "TET", drugs)
# Filter data from this stress
df_drug = df_logic50[(df_logic50.evo.==drugs[evo_stress]), :]
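# Find index of the individual with the largest `logic50_mean_std`
max_id = findmax(df_drug[!, "logic50_mean_std"])[2]
# Extract latent coordinates of the lineage to follow.
# NOTE: The strain is hard-coded to 5. To follow the best individual instead,
# filter with df_drug.strain_num[max_id].
df_lin = df_latent[df_latent.strain_num.==5, :]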
# Sort by day
DF.sort!(df_lin, :day)

first(df_lin, 5)

In [None]:
# Define function to compute fitness in the selected environment
f_env(z) = rhvae.vae.decoder(z).μ[evo_stress]
# Define step-sizes to evaluate
step_size = Float32.([1E-3, 1E-2])
# Locate latent coordinates of initial individual
zₒ = Float32.(vec(Matrix(df_lin[1:1, [:latent1, :latent2]])))
# Perform dynamics
z_traj = [
    mut_dynamics(
        zₒ, f_env, rhvae; n_steps=1_000, step_size=s, α=1.0f0, β=0.1f0, log=false
    ) for s in step_size
]
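
Before plotting, it can help to compare the trajectories numerically, for
example to see whether a larger step size still climbs the landscape at every
step. The helper below is a minimal sketch under the assumption that a
trajectory is stored column-wise, as `mut_dynamics` returns it.
`trajectory_summary` and the toy inputs are hypothetical and not part of the
notebook.

```julia
import LinearAlgebra

# Summarize a trajectory stored column-wise (same layout as `z_traj` entries)
function trajectory_summary(f, traj::AbstractMatrix)
    # Fitness at every step
    f_vals = [f(col) for col in eachcol(traj)]
    # Displacement between consecutive steps
    steps = diff(traj, dims=2)
    return (
        Δf=f_vals[end] - f_vals[1],
        monotonic=all(diff(f_vals) .>= 0),
        path_length=sum(LinearAlgebra.norm, eachcol(steps)),
    )
end

# Toy check: straight-line path toward the maximum of f(z) = -|z|² / 2
f_toy(z) = -sum(abs2, z) / 2
traj_toy = Float32[2.0 1.0 0.5 0.0; 0.0 0.0 0.0 0.0]
summary_toy = trajectory_summary(f_toy, traj_toy)
```

In this notebook the same helper could presumably be applied as
`trajectory_summary.(Ref(f_env), z_traj)`, assuming the decoder accepts column
views. Note that with `β > 0` the fitness does not have to increase at every
step, since the mutational term also moves the state.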

In [None]:
Random.seed!(42)
# Define the plotting stride: every `step`-th point of each simulated
# trajectory is drawn (not to be confused with the integration `step_size`)
step = 1

# Initialize figure
fig = Figure(size=(800, 500))
# Add grid layout
gl = fig[1, 1] = GridLayout()

# Add axis for fitness landscape
ax1 = Axis(gl[1, 1], aspect=AxisAspect(1), title="Fitness landscape")
# hide decorations
hidedecorations!(ax1)
# Add axis for log metric
ax2 = Axis(gl[1, 2], aspect=AxisAspect(1), title="log(√det(G))")
# hide decorations
hidedecorations!(ax2)

# Mask fitness landscape
fit_landscape_masked = (mask .* minimum(fitness_mat[:, :, evo_stress])) .+
                       (fitness_mat[:, :, evo_stress] .* .!mask)

# Plot heatmap for fitness landscape
hm1 = heatmap!(
    ax1,
    latent1_range,
    latent2_range,
    fit_landscape_masked,
    colormap=ColorSchemes.viridis,
)
# Plot heatmap for log metric
hm2 = heatmap!(
    ax2,
    log_metric...,
    colormap=ColorSchemes.tokyo,
)

# Add colorbars
Colorbar(gl[2, 1], hm1, label="Fitness", vertical=false)
Colorbar(gl[2, 2], hm2, label="log(√det(G))", vertical=false)

# Define colors for each step size
colors = get(ColorSchemes.inferno, range(0.5, 1, length=length(z_traj)))

# Loop through each step size
for (i, z) in enumerate(z_traj)
    lines!.(
        [ax1; ax2],
        Ref(z[1, 1:step:end]),
        Ref(z[2, 1:step:end]),
        linewidth=2.5,
        color=colors[i],
        # linestyle=:dash,
    )
end # for

# Plot evolutionary trajectory in latent space
scatter!.(
    [ax1; ax2],
    Ref(df_lin.latent1),
    Ref(df_lin.latent2),
    markersize=11,
    color=:black
)
scatterlines!.(
    [ax1; ax2],
    Ref(df_lin.latent1),
    Ref(df_lin.latent2),
    markersize=6,
    color=ColorSchemes.seaborn_colorblind[1],
)

# Add first point 
scatter!(
    ax1,
    [df_lin.latent1[1]],
    [df_lin.latent2[1]],
    color=:white,
    markersize=18,
    marker=:xcross
)
scatter!(
    ax1,
    [df_lin.latent1[1]],
    [df_lin.latent2[1]],
    color=:black,
    markersize=10,
    marker=:xcross
)

# Add last point
scatter!(
    ax1,
    [df_lin.latent1[end]],
    [df_lin.latent2[end]],
    color=:white,
    markersize=18,
    marker=:utriangle
)
scatter!(
    ax1,
    [df_lin.latent1[end]],
    [df_lin.latent2[end]],
    color=:black,
    markersize=10,
    marker=:utriangle
)

# Set x and y limits
xlims!.(
    [ax1; ax2],
    minimum(latent1_range), maximum(latent1_range)
)
ylims!.(
    [ax1; ax2],
    minimum(latent2_range), maximum(latent2_range)
)

# Add global title
Label(
    gl[1, :, Top()],
    "gradient ascent vs. evolutionary trajectory\nenvironment $(drugs[evo_stress])",
    fontsize=24,
    padding=(0, 0, 30, 0)
)

fig