(c) 2024 Manuel Razo. This work is licensed under a [Creative Commons
Attribution License CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/).
All code contained herein is licensed under an [MIT
license](https://opensource.org/licenses/MIT).

In [1]:
# Project packages: antibiotic-resistance utilities and the autoencoder library
import Antibiotic, AutoEncode
import AutoEncode.diffgeo.NeuralGeodesics as NG
# NOTE(review): later cells call `git_root()`, which none of the imports in
# this cell bring into scope — confirm where it is defined.

# Tabular data handling and file globbing
import CSV
import DataFrames as DF
import Glob

# Machine learning and model (de)serialization
import Flux, JLD2

# Linear algebra, statistics and random numbers
import LinearAlgebra, StatsBase, Random
# Fix the global RNG for reproducibility
Random.seed!(42)

# Plotting stack
using CairoMakie
import ColorSchemes

# Select the CairoMakie backend and apply the project-wide Makie theme
CairoMakie.activate!()
Antibiotic.viz.theme_makie!()

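In this notebook we compare two low-dimensional representations of the Iwasawa
et al. (2022) IC50 data: the projection onto the first two principal components
(PCA) and the two-dimensional latent space of a trained RHVAE model.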

Now, let's load the data.

In [None]:
# Define data directory
# NOTE(review): `git_root()` is not brought into scope by any import visible in
# this notebook — confirm where it is defined.
data_dir = "$(git_root())/data/Iwasawa_2022"

# Load the tidy IC50 table into memory
df_ic50 = CSV.read("$(data_dir)/iwasawa_ic50_tidy.csv", DF.DataFrame)

# Locate strains with at least one missing log2(IC50) value
missing_strains = unique(df_ic50[ismissing.(df_ic50.log2ic50), :strain])

# Drop every row of those strains. A Set makes each membership test O(1)
# instead of scanning the list of missing strains for every row.
missing_set = Set(missing_strains)
df_ic50 = df_ic50[[x ∉ missing_set for x in df_ic50.strain], :]

# Group data by strain and day; each group becomes one column of the matrix
df_group = DF.groupby(df_ic50, [:strain, :day])

# Sorted list of drugs; fixes the row order of the matrix
drug = sort(unique(df_ic50.drug))

# Matrix of log2(IC50) values (rows: drugs, columns: strain/day groups)
ic50_mat = Matrix{Float32}(undef, length(drug), length(df_group))

# Fill one column per group
for (i, data) in enumerate(df_group)
    # Sort the group's rows by drug (sorts the parent rows in place)
    DF.sort!(data, :drug)
    # Each group must contain exactly the drugs in `drug`, in the same order.
    # The length check runs first so the element-wise comparison cannot throw a
    # DimensionMismatch (or silently broadcast a single row). Failing loudly
    # avoids leaving an uninitialized (`undef`) column in `ic50_mat` that
    # would contaminate the standardization below.
    if length(data.drug) != length(drug) || !all(data.drug .== drug)
        error("group $i stress does not match")
    end # if
    # Add data to matrix
    ic50_mat[:, i] = Float32.(data.log2ic50)
end # for

# Define number of environments (rows of the matrix)
n_env = size(ic50_mat, 1)
# Define number of samples (columns of the matrix)
n_samples = size(ic50_mat, 2)

# Fit a z-score model with `dims=2`: each row (environment) gets its own mean
# and scale, computed across samples
dt = StatsBase.fit(StatsBase.ZScoreTransform, ic50_mat, dims=2)

# Standardize data to mean zero and standard deviation one per environment
ic50_std = StatsBase.transform(dt, ic50_mat)

## PCA on the data  

Let's begin by performing PCA on the standardized data, which we compute via a
singular value decomposition (SVD).

In [3]:
# PCA through the singular value decomposition of the standardized matrix
# (rows: environments, columns: samples)
ic50_svd = LinearAlgebra.svd(ic50_std)

# Left singular vectors = principal directions in environment space
pcs = ic50_svd.U

# Variance along each principal direction: σᵢ² / (n - 1)
ic50_var = abs2.(ic50_svd.S) ./ (n_samples - 1)

# Fraction (0–1) of the total variance captured by each component
ic50_var_pct = ic50_var ./ sum(ic50_var);

Let's look at the percentage of variance explained by each principal component.

In [None]:
# Two panels: per-component and cumulative fraction of explained variance
fig = Figure(size=(600, 250))

# Left panel: fraction explained by each individual component
ax_frac = Axis(fig[1, 1], xlabel="component", ylabel="fraction variance explained")
scatterlines!(ax_frac, 1:n_env, ic50_var_pct)

# Right panel: running total of the explained fraction
ax_cum = Axis(
    fig[1, 2],
    xlabel="component",
    ylabel="cumulative fraction \nvariance explained",
)
scatterlines!(ax_cum, 1:n_env, cumsum(ic50_var_pct))
# Fix the y range so the approach of the cumulative curve to 1 is visible
ylims!(ax_cum, [0, 1.05])

fig

We can see that the first two principal components explain ≈ 60% of the variance
in the data. Now, let's plot the mean squared error of the data reconstruction
as a function of the number of components included.


In [None]:
# Reconstruction error of the standardized data when keeping the first k
# principal components: project onto span(U[:, 1:k]) and map back
pca_mse = map(1:n_env) do k
    U_k = pcs[:, 1:k]
    Flux.mse(ic50_std, U_k * (U_k' * ic50_std))
end

# Reconstruction MSE versus number of retained components
fig = Figure(size=(300, 250))
ax = Axis(fig[1, 1], xlabel="# components included", ylabel="mean squared error")
scatterlines!(ax, 1:n_env, pca_mse)

fig

Let's plot the data projected onto the first two principal
components.

In [None]:
# Coordinates of every sample on the first two principal components
data_pca = pcs[:, 1:2]' * ic50_std

# One row per sample, one column per component
df_pca = DF.DataFrame(data_pca', [:pc1, :pc2])

# Flip the sign of PC1 (the sign of a singular vector is arbitrary)
df_pca.pc1 = -df_pca.pc1

# Group keys, in the same order as the columns of `ic50_mat`
strains_mat = [gk.strain for gk in keys(df_group)]
day_mat = [gk.day for gk in keys(df_group)]

DF.insertcols!(df_pca, :strain => strains_mat, :day => day_mat)

# Attach each sample's `parent` and `env` metadata
df_pca = DF.leftjoin!(
    df_pca,
    unique(df_ic50[:, [:strain, :day, :parent, :env]]),
    on=[:strain, :day],
)

first(df_pca, 5)

Now, let's look at the distribution of the data in the first two principal
components.

In [None]:
# Scatter of all samples on the first two principal components
fig = Figure(size=(300, 300))
ax = Axis(fig[1, 1]; xlabel="PC1", ylabel="PC2")
scatter!(ax, df_pca.pc1, df_pca.pc2; markersize=5)
fig

Note that we flip the sign of PC1 (the sign of a principal component is
arbitrary) so that the plot matches the orientation shown in the Iwasawa et al.
paper.

Let's repeat this plot, but now, let's color the data points by the environment
in which they evolved.

In [None]:
# Initialize figure
fig = Figure(size=(375, 300))

# Add square axis
ax = Axis(fig[1, 1], xlabel="PC1", ylabel="PC2", aspect=AxisAspect(1))

# Integer code for each sample's environment: its position among the
# environments listed in order of first appearance. `unique` is evaluated once
# rather than once per row.
env_levels = unique(df_pca.env)
env_int = [findfirst(==(x), env_levels) for x in df_pca.env]

# Group PCA data by :env
df_group_pca = DF.groupby(df_pca, :env)

# Plot each evolution environment as its own labeled series
for data in df_group_pca
    scatter!(
        ax,
        data.pc1,
        data.pc2,
        markersize=5,
        label=first(data.env)
    )
end

# Add legend
Legend(fig[1, 2], ax, labelsize=10)

# Display figure
fig

## RHVAE latent space

Let's now contrast this with the resulting latent space from the RHVAE model.
First, we need to load the model.

In [None]:
# Temperature used to rebuild the RHVAE and to select matching checkpoints
T = 0.5f0

# Define model directory
# NOTE(review): `git_root()` is not imported anywhere visible — confirm source.
model_dir = "$(git_root())/code/processing/" *
            "beta-rhvae_jointlogencoder_simpledecoder_iwasawa_fitness/v02/" *
            "output"

# Load the saved model; it serves as a template whose components are reused
rhvae = JLD2.load("$(model_dir)/model.jld2")["model"]

# Rebuild the RHVAE from the template's components with temperature `T`
rhvae = AutoEncode.RHVAEs.RHVAE(
    deepcopy(rhvae.vae),
    deepcopy(rhvae.metric_chain),
    deepcopy(rhvae.centroids_data),
    deepcopy(rhvae.centroids_latent),
    deepcopy(rhvae.L),
    deepcopy(rhvae.M),
    T,
    deepcopy(rhvae.λ)
)

# Checkpoint files for this temperature. Glob.jl rejects patterns starting
# with "/", so the leading "/" of the absolute path is stripped and "/" is
# passed as the search root instead.
# NOTE(review): `sort` is lexicographic; this assumes epoch numbers in the file
# names are zero-padded so that the last file is the last epoch — confirm.
files = sort(
    Glob.glob("$(model_dir)/model_state/*$(T)temp*.jld2"[2:end], "/")
)

# Fail early with a clear message instead of a BoundsError on `files[end]`
if isempty(files)
    error("no model_state files for temperature $(T) in $(model_dir)")
end

# Read both MSE histories with a single load per file, requesting only those
# two entries rather than the whole checkpoint
mse_hist = [JLD2.load(f, "mse_train", "mse_val") for f in files]
mse_train = first.(mse_hist)
mse_val = last.(mse_hist)

# Load parameters of the last checkpoint
model_state = JLD2.load(files[end])["model_state"]

# Set model parameters
Flux.loadmodel!(rhvae, model_state)
# Update metric
AutoEncode.RHVAEs.update_metric!(rhvae)

typeof(rhvae)

Let's look at the learning curve of the model.

In [None]:
# Learning curve: reconstruction MSE per epoch on the train and val sets
fig = Figure(size=(300, 300))
ax = Axis(fig[1, 1], xlabel="epoch", ylabel="MSE", title="Training MSE")

# Train first, then validation (fixes series and legend order)
for (curve, curve_label) in ((mse_train, "train"), (mse_val, "val"))
    scatterlines!(ax, curve, markersize=5, label=curve_label)
end

axislegend(ax, position=:rt, labelsize=10)

fig

Now, we can project the data into the latent space.

In [None]:
# Encoder output field `µ` for every sample (presumably the mean of the
# approximate posterior — confirm against AutoEncode)
data_latent = rhvae.vae.encoder(ic50_std).µ

# One row per sample, one column per latent coordinate
df_latent = DF.DataFrame(data_latent', [:z1, :z2])

# Identify each column of `ic50_mat` by its (strain, day) group key
strains_mat = [gk.strain for gk in keys(df_group)]
day_mat = [gk.day for gk in keys(df_group)]
DF.insertcols!(df_latent, :strain => strains_mat, :day => day_mat)

# Attach each sample's `parent` and `env` metadata
df_latent = DF.leftjoin!(
    df_latent,
    unique(df_ic50[:, [:strain, :day, :parent, :env]]),
    on=[:strain, :day],
)

first(df_latent, 5)

Let's now plot side by side the PCA and RHVAE latent space.

In [None]:
# Merge PCA and latent coordinates on the shared sample identifiers
df_red = DF.leftjoin(df_pca, df_latent, on=[:strain, :day, :parent, :env])

# Initialize figure
fig = Figure(size=(700, 300))

# One square axis per projection. Stored as `axs` rather than `axes` so the
# notebook does not shadow `Base.axes`: assigning to `axes` in `Main` fails
# once `Base.axes` has been referenced there, and shadowing it breaks later
# calls such as `axes(A, 1)`.
axs = [
    Axis(
        fig[1, 1],
        xlabel="PC1",
        ylabel="PC2",
        aspect=AxisAspect(1),
        title="PCA"
    ),
    Axis(
        fig[1, 2],
        xlabel="latent dimension 1",
        ylabel="latent dimension 2",
        aspect=AxisAspect(1),
        title="RHVAE"
    ),
]

# Group data by evolution environment
df_red_group = DF.groupby(df_red, :env)

# Plot every environment on both panels, in the same group order
for data in df_red_group
    scatter!(
        axs[1],
        data.pc1,
        data.pc2,
        markersize=5,
        label=first(data.env)
    )
    scatter!(
        axs[2],
        data.z1,
        data.z2,
        markersize=5,
        label=first(data.env)
    )
end # for

# Add legend
leg = Legend(fig[1, 3], axs[2], "evolution\ncondition", labelsize=12, titlesize=12)

fig

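Points in the RHVAE latent space seem to be more separated than in the PCA
space.

For the RHVAE plot, it is useful to visualize the geometry the model learned for
its latent space. We evaluate the learned metric tensor G on a grid of latent
points and compute log√det(G). This quantity measures how strongly the metric
stretches volumes locally, and so gives a sense of the local structure of the
latent space.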

In [44]:
# Number of grid points per latent axis
n_points = 300

# Latent-space window over which the metric is evaluated
latent_range_z1 = Float32.(range(-3.2, 3.2, length=n_points))
latent_range_z2 = Float32.(range(-3.5, 3, length=n_points))

# 2 × n_points² matrix of grid points, with z1 varying fastest. `vec` turns the
# n_points × n_points matrix of 2-vectors into a Vector. That way
# `reduce(hcat, ...)` uses Base's specialized single-allocation method instead
# of repeatedly concatenating partial matrices.
z_mat = reduce(hcat, vec([[x, y] for x in latent_range_z1, y in latent_range_z2]))

# Compute inverse metric tensor at every grid point
Ginv = AutoEncode.RHVAEs.G_inv(z_mat, rhvae)

# log√det(G) = -½ log det(G⁻¹). `slogdet` presumably returns one log-determinant
# per grid point (confirm). Reshaping makes entry [i, j] correspond to
# (latent_range_z1[i], latent_range_z2[j]).
logdetG = reshape(-1 / 2 * AutoEncode.utils.slogdet(Ginv), n_points, n_points);

Now, let's repeat the plot, but this time showing log√det(G) of the learned
metric as the background of the RHVAE latent space.

In [None]:
# Merge PCA and latent coordinates on the shared sample identifiers
df_red = DF.leftjoin(df_pca, df_latent, on=[:strain, :day, :parent, :env])

# Initialize figure
fig = Figure(size=(700, 300))

# One square axis per projection. Stored as `axs` rather than `axes` so the
# notebook does not shadow `Base.axes`: assigning to `axes` in `Main` fails
# once `Base.axes` has been referenced there.
axs = [
    Axis(
        fig[1, 1],
        xlabel="PC1",
        ylabel="PC2",
        aspect=AxisAspect(1),
        title="PCA"
    ),
    Axis(
        fig[1, 2],
        xlabel="latent dimension 1",
        ylabel="latent dimension 2",
        aspect=AxisAspect(1),
        title="RHVAE"
    ),
]

# Heatmap of log√det(G) as the background of the latent-space panel
hm = heatmap!(
    axs[2], latent_range_z1, latent_range_z2, logdetG, colormap=:tokyo
)

# Add colorbar
cb = Colorbar(
    fig[1, 3],
    hm,
    size=8,
    label="log√det(G)",
    labelsize=12,
    labelpadding=0.0,
    ticklabelsize=12,
    ticksvisible=false
)

# Group data by evolution environment
df_red_group = DF.groupby(df_red, :env)

# Plot every environment on both panels, in the same group order
for data in df_red_group
    scatter!(
        axs[1],
        data.pc1,
        data.pc2,
        markersize=5,
        label=first(data.env)
    )
    scatter!(
        axs[2],
        data.z1,
        data.z2,
        markersize=5,
        label=first(data.env)
    )
end # for

# Add legend
leg = Legend(
    fig[1, 4], axs[2], "evolution\ncondition", labelsize=12, titlesize=12
)

fig

Let's now look at the trajectories of different strains in the latent space.

In [None]:
# List unique environments
envs = unique(df_red.env)

# Panels are laid out two per row. The number of rows follows the number of
# environments (2 × 2 for four), so more than four environments no longer
# index past the end of `gls`.
n_cols = 2
n_rows = cld(length(envs), n_cols)

# Initialize figure (250 px of height per row of panels)
fig = Figure(size=(1_000, 250 * n_rows))

# One GridLayout per environment, filled row by row
gls = [
    GridLayout(fig[i, j])
    for i in 1:n_rows for j in 1:n_cols
]

# Mark a trajectory's first point with an "×" and its last point with a "▲".
# Each is a black marker drawn over a larger white one so it stands out on any
# background.
function mark_endpoints!(ax, x, y)
    for (idx, marker) in ((firstindex(x), :xcross), (lastindex(x), :utriangle))
        scatter!(ax, [x[idx]], [y[idx]], color=:white, markersize=12, marker=marker)
        scatter!(ax, [x[idx]], [y[idx]], color=:black, markersize=8, marker=marker)
    end
    return nothing
end

# Loop through environments
for (i, env) in enumerate(envs)
    # PCA (left) and RHVAE latent space (right) for this environment. Named
    # `axs` rather than `axes` to avoid shadowing `Base.axes`.
    axs = [
        Axis(gls[i][1, 1], xlabel="PC1", ylabel="PC2", aspect=AxisAspect(1)),
        Axis(
            gls[i][1, 2],
            xlabel="latent dimension 1",
            ylabel="latent dimension 2",
            aspect=AxisAspect(1),
        ),
    ]

    # Environment name as the panel title
    Label(
        gls[i][1, :, Top()],
        env,
        valign=:bottom,
        font=:bold,
        padding=(0, 0, 5, 0)
    )

    # Plot PCA points as gray background
    scatter!(axs[1], df_red.pc1, df_red.pc2, markersize=5, color=(:gray, 0.5))
    # Heatmap of log√det(G) as the latent-space background
    hm = heatmap!(
        axs[2], latent_range_z1, latent_range_z2, logdetG, colormap=:tokyo
    )

    # Add colorbar
    cb = Colorbar(
        gls[i][1, 3],
        hm,
        size=8,
        label="log√det(G)",
        labelsize=12,
        labelpadding=0.0,
        ticklabelsize=12,
        ticksvisible=false
    )

    # Extract data for environment
    data = df_red[df_red.env.==env, :]

    # One trajectory per strain evolved in this environment
    for (j, d) in enumerate(DF.groupby(data, :strain))
        # Order the trajectory in time
        DF.sort!(d, :day)
        # Same color for a strain in both panels
        strain_color = ColorSchemes.glasbey_bw_minc_20_hue_330_100_n256[j]

        # Trajectory in PCA space with start/end markers
        scatterlines!(axs[1], d.pc1, d.pc2, markersize=5, color=strain_color)
        mark_endpoints!(axs[1], d.pc1, d.pc2)

        # Trajectory in RHVAE latent space with start/end markers
        scatterlines!(axs[2], d.z1, d.z2, markersize=5, color=strain_color)
        mark_endpoints!(axs[2], d.z1, d.z2)
    end # for
end # for

# Save a copy of the figure to the user's Downloads folder. The folder is found
# via `homedir()` rather than a hard-coded user-specific absolute path, and is
# created if it does not exist.
fig_path = joinpath(homedir(), "Downloads", "fig.pdf")
mkpath(dirname(fig_path))
save(fig_path, fig)
fig

## Mean squared error for the reconstruction of out-of-sample data

To compare the ability of PCA vs RHVAE to reconstruct the data, we will split
the data into a training and a validation set. We will fit PCA on the training
set, reconstruct the validation set, and calculate the mean squared error of the
reconstruction for each number of principal components included.

We will repeat this random split many times to get a distribution of the mean
squared error on in-sample and out-of-sample data. We will then compare it with
the final training and validation mean squared error recorded for the RHVAE.

In [None]:
# Fraction of samples assigned to the training set in each split
split_frac = 0.85
# Number of random train/validation splits. Named `n_boots` so the global
# data-sample count `n_samples` (used by the SVD cell) is not overwritten.
n_boots = 1_000

# Initialize dataframe to collect results of every split
df_boots = DF.DataFrame()

# Set random seed so the splits are reproducible
Random.seed!(42)

# Repeated random sub-sampling: each split shuffles the samples and partitions
# them into disjoint training and validation sets.
# NOTE(review): `ic50_std` was standardized with statistics from all samples,
# so validation samples influence the z-scoring. Fitting the transform on the
# training columns only would give a stricter out-of-sample estimate.
for s in 1:n_boots
    # Split data into training and validation (along columns = samples)
    train_data, val_data = Flux.splitobs(ic50_std, at=split_frac, shuffle=true)

    # Principal directions learned from the training columns only. Named
    # `train_pcs` because, in a notebook's soft global scope, assigning `pcs`
    # here would overwrite the full-data principal components.
    train_pcs = LinearAlgebra.svd(train_data).U

    # Reconstruction MSE of the training data for the first k components
    train_pca_mse = [
        Flux.mse(train_data, train_pcs[:, 1:k] * (train_pcs[:, 1:k]' * train_data))
        for k in 1:n_env
    ]

    # Reconstruction MSE of the held-out data using the training components
    val_pca_mse = [
        Flux.mse(val_data, train_pcs[:, 1:k] * (train_pcs[:, 1:k]' * val_data))
        for k in 1:n_env
    ]

    # Append one row per number of components to the results table
    DF.append!(
        df_boots,
        DF.DataFrame(
            mse_train=train_pca_mse,
            mse_val=val_pca_mse,
            pcs=1:n_env,
            sample=s
        )
    )
end # for

first(df_boots, 5)

Let's now plot the distribution of the mean squared error for the in-sample and
out-of-sample data.

In [None]:
# PCA reconstruction error versus number of retained components, compared to
# the last recorded MSE of the 2D RHVAE
fig = Figure(size=(400, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="# principal components included",
    ylabel="reconstruction\nmean squared error",
    xticks=1:n_env,
)

# Train/validation boxplots side by side, offset ±0.2 around each x position
boxplot!(
    ax, df_boots.pcs .- 0.2, df_boots.mse_train, label="PCA train", width=0.5
)
boxplot!(
    ax, df_boots.pcs .+ 0.2, df_boots.mse_val, label="PCA validation", width=0.5
)

# One horizontal reference line per RHVAE metric (last recorded epoch)
# NOTE(review): this comparison assumes the RHVAE MSE is on the same
# standardized scale as `ic50_std` — confirm.
hlines!(
    ax,
    mse_train[end],
    linestyle=:dash,
    linewidth=2.5,
    label="2D RHVAE train",
)
hlines!(
    ax,
    mse_val[end],
    linestyle=:dot,
    linewidth=2.5,
    label="2D RHVAE validation",
)
# Add legend
axislegend(ax, position=:rt, labelsize=11)

fig