(c) 2024 Manuel Razo. This work is licensed under a [Creative Commons
Attribution License CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/).
All code contained herein is licensed under an [MIT
license](https://opensource.org/licenses/MIT).

In [30]:
# Import project package
import Antibiotic

# Import AutoEncoderToolkit and its neural-geodesic submodule
import AutoEncoderToolkit as AET
import AutoEncoderToolkit.diffgeo.NeuralGeodesics as NG

# Import libraries to handle data
import CSV
import DataFrames as DF
import Glob

# Import ML libraries
import Flux

# Import library to save and load models
import JLD2

# Import basic math
import Distances
import LinearAlgebra
import StatsBase
import Random
# Seed the global RNG so any randomness in the notebook is reproducible
Random.seed!(42)

# Import Plotting libraries
using CairoMakie
import ColorSchemes

# Activate backend
CairoMakie.activate!()

# Set Plotting style
Antibiotic.viz.theme_makie!()

# Geodesic curves in latent space

In this notebook, we will explore how well geodesic curves in latent space match
observed evolutionary trajectories.

To begin this exploration, we first define the paths where the relevant files
are stored.

In [None]:
# RHVAE version whose outputs are analyzed in this notebook
version = "v05"

# NOTE(review): `git_root` is not imported by any cell visible here; it is
# presumably provided by the project environment — confirm before running.

# Directory with the inferred log IC50 values
data_dir = "$(git_root())/output/mcmc_iwasawa_logistic"

# Output prefix of this RHVAE version, relative to `<repo>/output`
out_prefix = "beta-rhvae_jointlogencoder_simpledecoder_iwasawa_mcmc/$(version)"

# Root output directory of this model version ...
out_dir = "$(git_root())/output/$(out_prefix)"
# ... its per-epoch model parameter files ...
model_dir = "$(out_dir)/model_state"
# ... and the trained geodesic curve states (note the trailing "/")
geodesic_dir = "$(out_dir)/geodesic_state/"

Next, we list all of the files for the trained geodesic curves.

In [None]:
# List every geodesic state file. `Glob.glob` takes a pattern relative to a
# root directory, so the leading "/" of the absolute pattern is stripped and
# "/" is passed as the root instead.
geodesic_files = Glob.glob("$(geodesic_dir)*.jld2"[2:end], "/")

# One metadata row per file, parsed from the fields encoded in the file name:
# initial and final day, evolution environment, strain number, RHVAE epoch and
# geodesic-training epoch. The path is kept to load the geodesic state later.
df_meta = DF.DataFrame([
    (
        day_init=parse(Int, match(r"dayinit(\d+)", gf).captures[1]),
        day_final=parse(Int, match(r"dayfinal(\d+)", gf).captures[1]),
        env=match(r"evoenv(\w+)_id", gf).captures[1],
        strain_num=parse(Int, match(r"id(\d+)", gf).captures[1]),
        rhvae_epoch=parse(Int, match(r"rhvaeepoch(\d+)", gf).captures[1]),
        geodesic_epoch=parse(Int, match(r"geoepoch(\d+)", gf).captures[1]),
        geodesic_state=gf,
    )
    for gf in geodesic_files
])

# Group rows by evolution environment
DF.sort!(df_meta, :env)

first(df_meta, 4)

We must also load the `NeuralGeodesic` model template.

In [None]:
println("Loading NeuralGeodesic template...")
# Only the `mlp` field of the saved NeuralGeodesic is kept; it is the network
# template from which every trained geodesic is rebuilt below
nng_template = JLD2.load("$(out_dir)/geodesic.jld2")["model"].mlp

# Each curve is evaluated at `n_time` evenly spaced values of its parameter
# t ∈ [0, 1], stored as a Float32 vector
n_time = 500
t_array = Float32.(collect(range(0, 1; length=n_time)));

We also load the trained `RHVAE` model.

In [None]:
println("Loading trained RHVAE model...")

# RHVAE model object saved with this run
rhvae = JLD2.load("$(out_dir)/model.jld2")["model"]

# Per-epoch parameter files, sorted lexicographically by path. The leading "/"
# is stripped from the absolute pattern because "/" is passed as Glob's root.
# NOTE(review): the lexicographically last file is the final epoch only if
# epoch numbers in the file names are zero-padded — confirm.
param_files = Glob.glob("$(model_dir)/*.jld2"[2:end], "/") |> sort

# Load the parameters of the last file into the model
Flux.loadmodel!(rhvae, JLD2.load(last(param_files))["model_state"])

# Update the RHVAE metric so it reflects the freshly loaded parameters
# (presumably required after `loadmodel!`)
AET.RHVAEs.update_metric!(rhvae);

Finally, we load the inferred $IC_{50}$ profiles and map them to the `RHVAE`
latent space.

In [None]:
println("Loading IC50 data...")

df_logic50 = CSV.read("$(data_dir)/logic50_ci.csv", DF.DataFrame)

println("Map data to latent space...")

# One group per (day, strain, environment) measurement: one IC50 profile each
df_group = DF.groupby(df_logic50, [:day, :strain_num, :env])

# Accumulates one row of latent coordinates per group
df_latent = DF.DataFrame()

for data in df_group
    # Sort the group's rows by drug in place, so the encoder always receives
    # the profile in the same drug order
    DF.sort!(data, :drug)

    # `.µ` field of the encoder output for the standardized log IC50 profile
    # (presumably the mean of the latent distribution)
    latent = rhvae.vae.encoder(Float32.(data.logic50_mean_std)).µ

    # `env` holds "<strain>_<environment>"; keep it whole and split it up
    meta = first(data.env)
    meta_parts = split(meta, "_")

    DF.append!(
        df_latent,
        DF.DataFrame(
            :day .=> first(data.day),
            :strain_num .=> first(data.strain_num),
            :meta .=> meta,
            :env .=> meta_parts[end],
            :strain .=> meta_parts[1],
            :latent1 => latent[1, :],
            :latent2 => latent[2, :],
        )
    )
end # for data in df_group

first(df_latent, 5)

The `RHVAE` model contains a metric tensor network that captures the geometry of
the learned latent space. We evaluate this metric on a grid of latent points and
display $\log\sqrt{\det G}$ as a background heatmap that conveys this geometric
information in our plots.

In [None]:
println("Compute Riemannian metric for latent space...")

# Number of grid points along each latent axis
n_points = 200

# Grid range per latent axis: the span of the encoded data, padded by 2 units
# on both sides
latent1_range, latent2_range = map((df_latent.latent1, df_latent.latent2)) do coord
    coord_min, coord_max = extrema(coord)
    range(coord_min - 2, coord_max + 2, length=n_points)
end

# All grid points as columns of a 2 × n_points² matrix. latent1 varies fastest,
# so a column-major reshape places latent1 along the first axis.
z_mat = reduce(hcat, [[x, y] for x in latent1_range, y in latent2_range])

# Inverse metric tensor at every grid point
Ginv = AET.RHVAEs.G_inv(z_mat, rhvae)

# -½·slogdet(G⁻¹) on the grid, as an n_points × n_points matrix. Assuming
# `slogdet` returns log|det|, this is ½·log det(G) = log √det(G) rather than the
# full log det(G) the variable name suggests — confirm against AET docs.
logdetG = reshape(-1 / 2 * AET.utils.slogdet(Ginv), n_points, n_points);

Let's now plot a few example trajectories and the corresponding geodesic curves
to make sure we have everything in place.

In [None]:
println("Plotting specific geodesic trajectories at higher resolution...")

# Figure with a grid layout holding one panel per lineage
fig = Figure(size=(900, 600))
gl = fig[1, 1] = GridLayout()

# (:env, :strain_num) pairs to plot. Deliberately not named `pairs`: a global of
# that name would shadow `Base.pairs` in the notebook's global scope.
env_strain_pairs = [
    ("KM", 16), ("NFLX", 33), ("TET", 8), ("KM", 28), ("NFLX", 35), ("TET", 3)
]

# Number of panel columns; panels are filled row by row
cols = 3

for (i, p) in enumerate(env_strain_pairs)
    println("env: $(p[1]) | strain: $(p[2])")
    # Geodesic metadata for this (env, strain) pair
    data_meta = df_meta[
        (df_meta.env.==p[1]).&(df_meta.strain_num.==p[2]), :
    ]
    # Panel position in the grid
    row = (i - 1) ÷ cols + 1
    col = (i - 1) % cols + 1
    ax = Axis(
        gl[row, col],
        aspect=AxisAspect(1),
        title="evolution antibiotic: $(p[1])",
        xticksvisible=false,
        yticksvisible=false,
    )
    # Hide axis labels, ticks and grid
    hidedecorations!(ax)

    # Background heatmap of -½·slogdet(G⁻¹) over the latent grid
    hm = heatmap!(
        ax,
        latent1_range,
        latent2_range,
        logdetG,
        colormap=ColorSchemes.tokyo,
    )

    # All encoded points, drawn faintly for context
    scatter!(
        ax,
        df_latent.latent1,
        df_latent.latent2,
        markersize=5,
        color=(:gray, 0.25),
        marker=:circle,
    )

    # Lineage trajectory, explicitly ordered by day so that the first and last
    # rows are the initial and final observations (the row order of `df_latent`
    # depends on the grouping used to build it)
    lineage = DF.sort(df_latent[df_latent.strain_num.==p[2], :], :day)

    scatterlines!(
        ax,
        lineage.latent1,
        lineage.latent2,
        markersize=8,
        linewidth=2,
    )

    # Rebuild the trained geodesic from its saved state
    # NOTE(review): if several state files match this pair (e.g. different
    # epochs), only the first row of `data_meta` is used — confirm intended.
    geo_state = JLD2.load(first(data_meta.geodesic_state))
    nng = NG.NeuralGeodesic(
        nng_template,
        geo_state["latent_init"],
        geo_state["latent_end"],
    )
    Flux.loadmodel!(nng, geo_state["model_state"])
    # Evaluate the curve along t_array; each row is one latent coordinate
    curve = nng(t_array)
    lines!(
        ax,
        eachrow(curve)...,
        linewidth=2,
        linestyle=(:dot, :dense),
        color=:white,
    )

    # Mark the first observation (white outline behind a black cross)
    scatter!(
        ax,
        [lineage.latent1[1]],
        [lineage.latent2[1]],
        color=:white,
        markersize=18,
        marker=:xcross
    )
    scatter!(
        ax,
        [lineage.latent1[1]],
        [lineage.latent2[1]],
        color=:black,
        markersize=12,
        marker=:xcross
    )

    # Mark the last observation (white outline behind a black triangle)
    scatter!(
        ax,
        [lineage.latent1[end]],
        [lineage.latent2[end]],
        color=:white,
        markersize=18,
        marker=:utriangle
    )
    scatter!(
        ax,
        [lineage.latent1[end]],
        [lineage.latent2[end]],
        color=:black,
        markersize=12,
        marker=:utriangle
    )

    # After the last panel, add a shared colorbar in a fourth column, padded
    # above and below by empty layouts
    if i == length(env_strain_pairs)
        gc = gl[1:2, 4] = GridLayout()
        ge_top = gc[1, :] = GridLayout()
        ge_bottom = gc[4, :] = GridLayout()
        # The heatmap shows -½·slogdet(G⁻¹) = ½·log det(G), i.e. log √det(G)
        Colorbar(gc[2:3, :], hm, label="log√det(G)")
    end
end # for p in env_strain_pairs

fig

## Quantifying the match between evolutionary trajectories and geodesic curves

To quantify how well the geodesic curves match the observed evolutionary
trajectories, we perform the following steps:
1. Evaluate the geodesic curve at a dense set of points.
2. Compute the pairwise distance between each observed point and each point
   along the geodesic.
3. Take the minimum of these distances as the "true" distance between the
   observed point and the geodesic.
4. Sum these "true" distances over all observed points.
5. Optionally, normalize this sum by the length of the geodesic curve.

Below we report two variants: the sum of *squared* minimum distances, and the
length-normalized sum of (unsquared) minimum distances.

Let's define a function that performs this set of steps.

In [None]:
"""
    geodesic_match(data, geodesic; metric=Distances.Euclidean(), norm=false,
                   norm_fn=<polyline length>)

Quantify how closely the observed points in `data` follow the discretized curve
`geodesic`.

For every column of `data`, compute its distance under `metric` to every column
of `geodesic`. Keep the smallest of these distances, and sum the minima over all
observed points.

# Arguments
- `data`: matrix whose columns are observed points (here a 2 × n_obs matrix of
  latent coordinates).
- `geodesic`: matrix whose columns are consecutive points along the curve.

# Keywords
- `metric`: `Distances.jl` metric used for the point-to-curve distances.
- `norm`: if `true`, divide the sum by `norm_fn(geodesic)`.
- `norm_fn`: function mapping the curve to its normalization constant. The
  default is the Euclidean length of the polyline through consecutive columns,
  Σᵢ ‖geodesic[:, i+1] - geodesic[:, i]‖.

# Returns
The sum of minimum distances, divided by `norm_fn(geodesic)` when `norm=true`.
"""
function geodesic_match(
    data, geodesic;
    metric=Distances.Euclidean(),
    norm::Bool=false,
    norm_fn=x -> sum(sqrt, sum(abs2, diff(x; dims=2); dims=1)),
)
    # Distance between every observed point and every curve point.
    # dims=2: points are columns. Rows of `dist` index observed points.
    dist = Distances.pairwise(metric, data, geodesic; dims=2)

    # Distance from each observed point to its nearest curve point
    dist_min = minimum(dist; dims=2)

    # Total deviation of the observations from the curve
    total = sum(dist_min)

    # Optionally normalize, by default by the curve's arc length
    return norm ? total / norm_fn(geodesic) : total
end # function

To test the function, let's use an example trajectory and the corresponding
geodesic curve.

In [None]:
# Example lineage, given as (evolution environment, strain number)
p = ("KM", 16)

# Geodesic metadata for this lineage
data_meta = df_meta[(df_meta.env.==p[1]).&(df_meta.strain_num.==p[2]), :]

# Observed latent trajectory of this lineage
lineage = df_latent[df_latent.strain_num.==p[2], :]

# Rebuild the trained geodesic from its saved state and evaluate it on t_array
geo_state = JLD2.load(first(data_meta.geodesic_state))
nng = NG.NeuralGeodesic(nng_template, geo_state["latent_init"], geo_state["latent_end"])
Flux.loadmodel!(nng, geo_state["model_state"])
curve = nng(t_array)

# Residuals of the lineage with respect to the curve:
#   res      — sum of squared Euclidean distances to the nearest curve point
#   norm_res — sum of Euclidean distances, normalized by geodesic_match's norm_fn
res, norm_res = let lineage_mat = Matrix(lineage[:, [:latent1, :latent2]])'
    (
        geodesic_match(lineage_mat, curve; norm=false, metric=Distances.SqEuclidean()),
        geodesic_match(lineage_mat, curve; norm=true),
    )
end

println("Residuals: $(res)")
println("Normalized residuals: $(norm_res)")

With this function in hand, we can compute both residual measures for every lineage.

In [None]:
# Residuals of every lineage with respect to its own trained geodesic:
#   res      — sum of squared Euclidean distances to the nearest curve point
#   norm_res — sum of Euclidean distances, normalized by geodesic_match's norm_fn
norm_res = zeros(Float32, DF.nrow(df_meta))
res = similar(norm_res)

for (i, data_meta) in enumerate(eachrow(df_meta))
    # Observed latent trajectory of this lineage as a 2 × n_obs matrix
    lineage = df_latent[df_latent.strain_num.==data_meta.strain_num, :]
    lineage_mat = Matrix(lineage[:, [:latent1, :latent2]])'

    # Rebuild the trained geodesic and evaluate it along t_array
    geo_state = JLD2.load(data_meta.geodesic_state)
    nng = NG.NeuralGeodesic(
        nng_template, geo_state["latent_init"], geo_state["latent_end"]
    )
    Flux.loadmodel!(nng, geo_state["model_state"])
    geodesic_curve = nng(t_array)

    norm_res[i] = geodesic_match(lineage_mat, geodesic_curve; norm=true)
    res[i] = geodesic_match(
        lineage_mat, geodesic_curve; norm=false, metric=Distances.SqEuclidean()
    )
end # for

# Store both residual measures alongside the metadata
DF.insertcols!(df_meta, :geo_res => res, :geo_norm_res => norm_res)

first(df_meta[:, [:env, :strain_num, :geo_res, :geo_norm_res]], 5)

To put these results in context, we compare the residuals with respect to the
geodesic to the residuals with respect to a naive straight line between the same
initial and final points.

In [None]:
# Baseline residuals of every lineage with respect to the straight line joining
# the same initial and final latent points used to train its geodesic
norm_res = zeros(Float32, DF.nrow(df_meta))
res = similar(norm_res)

for (i, data_meta) in enumerate(eachrow(df_meta))
    # Observed latent trajectory of this lineage as a 2 × n_obs matrix
    lineage = df_latent[df_latent.strain_num.==data_meta.strain_num, :]
    lineage_mat = Matrix(lineage[:, [:latent1, :latent2]])'

    # Endpoints of the geodesic
    geo_state = JLD2.load(data_meta.geodesic_state)
    latent_init = geo_state["latent_init"]
    latent_final = geo_state["latent_end"]

    # Straight segment sampled at as many points as the geodesic: a 2 × n_pts
    # matrix whose rows are the evenly spaced x and y coordinates
    n_pts = length(t_array)
    affine_curve = hcat(
        (
            Float32.(collect(LinRange(latent_init[d], latent_final[d], n_pts)))
            for d in 1:2
        )...
    )'

    norm_res[i] = geodesic_match(lineage_mat, affine_curve; norm=true)
    res[i] = geodesic_match(
        lineage_mat, affine_curve; norm=false, metric=Distances.SqEuclidean()
    )
end # for

# Store both baseline residual measures alongside the metadata
DF.insertcols!(df_meta, :lin_res => res, :lin_norm_res => norm_res)

first(df_meta[:, [:env, :strain_num, :geo_res, :lin_res]], 5)

In [None]:
# ECDF of the summed squared residuals: geodesic vs straight-line baseline
fig = Figure(size=(450, 300))
ax = Axis(fig[1, 1], xlabel="∑residuals²", ylabel="ECDF")

# One ECDF per curve type, coloured by its position in the list
for (k, (column, curve_label)) in enumerate(
    [(:geo_res, "geodesic"), (:lin_res, "straight line")]
)
    ecdfplot!(
        ax,
        df_meta[!, column],
        label=curve_label,
        color=ColorSchemes.seaborn_colorblind[k],
        linewidth=2,
    )
end

Legend(fig[1, 2], ax)

fig

In [None]:
# ECDF of the length-normalized residuals: geodesic vs straight-line baseline
fig = Figure(size=(450, 300))
ax = Axis(fig[1, 1], xlabel="normalized residuals", ylabel="ECDF")

# One ECDF per curve type, coloured by its position in the list
for (k, (column, curve_label)) in enumerate(
    [(:geo_norm_res, "geodesic"), (:lin_norm_res, "straight line")]
)
    ecdfplot!(
        ax,
        df_meta[!, column],
        label=curve_label,
        color=ColorSchemes.seaborn_colorblind[k],
        linewidth=2,
    )
end

Legend(fig[1, 2], ax)

fig

In [None]:
# Per-lineage comparison of the summed squared residuals. Points above the
# y = x line are lineages whose observations lie closer to their geodesic than
# to the straight line between the same endpoints.
fig = Figure(size=(300, 300))

ax = Axis(
    fig[1, 1],
    xlabel="geodesic curve residuals",
    ylabel="affine curve residuals",
    aspect=AxisAspect(1)
)

# Identity line spanning the whole visible axis, whatever the data range
# (a fixed segment would stop short if any residual exceeded its end point)
ablines!(ax, 0, 1, color=:black, linestyle=:dash)
# One point per lineage
scatter!(ax, df_meta.geo_res, df_meta.lin_res)

fig

In [None]:
"""
    geodesic_mse_recon(data, geodesic, decoder; metric=Distances.SqEuclidean())

Compare decoded observations with the decoded points of a curve nearest to them.

Each column of `data` is matched to the column of `geodesic` closest to it under
`metric`. Both the observed points and their matched curve points are passed
through `decoder`. The function returns the mean squared error between the two
decoder outputs (their `µ` field).

# Arguments
- `data`: matrix whose columns are observed latent points.
- `geodesic`: matrix whose columns are consecutive points along the curve.
- `decoder`: callable whose output has a `µ` field.

# Keywords
- `metric`: `Distances.jl` metric used to find the nearest curve point.

# Returns
`Flux.mse` between the decoded observations and the decoded matched curve points.
"""
function geodesic_mse_recon(
    data, geodesic, decoder;
    metric=Distances.SqEuclidean(),
)
    # Distance between every observed point and every curve point (points are
    # columns). Rows of `dist` index observed points.
    dist = Distances.pairwise(metric, data, geodesic; dims=2)

    # Column index of the nearest curve point for each observed point. A plain
    # Vector{Int} keeps `geodesic[:, idx_min]` a 2-D matrix; an n×1 index matrix
    # would make it a 3-D array.
    idx_min = [argmin(dist_row) for dist_row in eachrow(dist)]

    # Decode the observed points and their nearest curve points
    data_recon = decoder(data).µ
    geo_recon = decoder(geodesic[:, idx_min]).µ

    return Flux.mse(data_recon, geo_recon)
end # function

In [None]:
# Reconstruction-space comparison. For each lineage, decode its observed latent
# points and their nearest points on (a) the trained geodesic and (b) the
# straight line between the same endpoints, and record the MSE between them.
geo_recon = zeros(Float32, DF.nrow(df_meta))
lin_recon = similar(geo_recon)

for (i, data_meta) in enumerate(eachrow(df_meta))
    # Observed latent trajectory of this lineage as a 2 × n_obs matrix
    lineage = df_latent[df_latent.strain_num.==data_meta.strain_num, :]
    lineage_mat = Matrix(lineage[:, [:latent1, :latent2]])'

    # Endpoints of the geodesic
    geo_state = JLD2.load(data_meta.geodesic_state)
    latent_init = geo_state["latent_init"]
    latent_final = geo_state["latent_end"]

    # Straight segment between the endpoints, sampled at length(t_array) points
    n_pts = length(t_array)
    affine_curve = hcat(
        (
            Float32.(collect(LinRange(latent_init[d], latent_final[d], n_pts)))
            for d in 1:2
        )...
    )'

    # Trained geodesic between the same endpoints, evaluated on t_array
    nng = NG.NeuralGeodesic(nng_template, latent_init, latent_final)
    Flux.loadmodel!(nng, geo_state["model_state"])
    geodesic_curve = nng(t_array)

    geo_recon[i] = geodesic_mse_recon(lineage_mat, geodesic_curve, rhvae.vae.decoder)
    lin_recon[i] = geodesic_mse_recon(lineage_mat, affine_curve, rhvae.vae.decoder)
end # for

# Store both reconstruction errors alongside the metadata
DF.insertcols!(df_meta, :geo_mse_recon => geo_recon, :lin_mse_recon => lin_recon)

first(df_meta[:, [:env, :strain_num, :geo_mse_recon, :lin_mse_recon]], 5)

In [None]:
# ECDF of the reconstruction MSE: geodesic vs straight-line baseline
fig = Figure(size=(450, 300))
ax = Axis(fig[1, 1], xlabel="reconstruction MSE", ylabel="ECDF")

# One ECDF per curve type, coloured by its position in the list
for (k, (column, curve_label)) in enumerate(
    [(:geo_mse_recon, "geodesic"), (:lin_mse_recon, "straight line")]
)
    ecdfplot!(
        ax,
        df_meta[!, column],
        label=curve_label,
        color=ColorSchemes.seaborn_colorblind[k],
        linewidth=2,
    )
end

Legend(fig[1, 2], ax)

fig

In [None]:
# Per-lineage comparison of reconstruction MSE. Points above the y = x line are
# lineages whose decoded observations are closer to the decoded geodesic than to
# the decoded straight line.
fig = Figure(size=(300, 300))

ax = Axis(
    fig[1, 1],
    xlabel="geodesic curve\nreconstruction MSE",
    ylabel="affine curve\nreconstruction MSE",
    aspect=AxisAspect(1)
)

# Identity line spanning the whole visible axis, whatever the data range
# (a fixed segment would stop short if any MSE exceeded its end point)
ablines!(ax, 0, 1, color=:black, linestyle=:dash)
# One point per lineage
scatter!(ax, df_meta.geo_mse_recon, df_meta.lin_mse_recon)

fig