(c) 2024 Manuel Razo. This work is licensed under a [Creative Commons
Attribution License CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/).
All code contained herein is licensed under an [MIT
license](https://opensource.org/licenses/MIT).

In [None]:
# Import project package
import Antibiotic

# Import package to handle DataFrames
import DataFrames as DF
import CSV

# Import library for Bayesian inference
import Turing

# Import library to list files
import Glob

# Load CairoMakie for plotting
using CairoMakie
import PairPlots
import ColorSchemes

# Import basic math libraries
import StatsBase
import LinearAlgebra
import Random

# Activate backend
CairoMakie.activate!()

# Set PBoC Plotting style
Antibiotic.viz.theme_makie!()

# Exploratory data analysis of MCMC samples

In this notebook, we will perform an exploratory data analysis (EDA) of the
Bayesian model fit to the Iwasawa et al. data.

Let's begin by loading the raw data.

In [None]:
# Load data into a DataFrame
df = CSV.read(
    "$(git_root())/data/Iwasawa_2022/iwasawa_tidy.csv", DF.DataFrame
)

# Remove blank measurements
df = df[.!df.blank, :]
# Remove zero concentrations
df = df[df.concentration_ugmL.>0, :]

first(df, 5)

Next, let's define the model fit to the data.

In [None]:
@doc raw"""
    logistic(logx, a, b, c, logic50)

Compute the logistic function used to model the relationship between antibiotic
concentration and bacterial growth.

This function implements the following equation:

f(x) = a / (1 + (x / ic50)^b) + c

evaluated on log-transformed inputs as

f(logx) = a / (1 + exp(b * (logx - logic50))) + c

# Arguments
- `logx`: Natural logarithm of the antibiotic concentration (input variable)
- `a`: Maximum effect parameter (difference between upper and lower asymptotes)
- `b`: Slope parameter (steepness of the curve)
- `c`: Minimum effect parameter (lower asymptote)
- `logic50`: Natural logarithm of the IC₅₀ parameter (concentration at which
  the effect is halfway between the minimum and maximum)

# Returns
The computed effect (e.g., optical density) for the given antibiotic
concentration and parameters.

Note: This function is vectorized and can handle array inputs for `logx`.
"""
function logistic(logx, a, b, c, logic50)
    return @. a / (1.0 + exp(b * (logx - logic50))) + c
end
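
As a quick sanity check on the parameterization, the sketch below compares the log-space form used in `logistic` with the concentration-space equation from the docstring, and confirms that the curve sits halfway between its asymptotes at the IC₅₀. The function names `logistic_log` and `logistic_linear` and all parameter values are hypothetical and chosen only for illustration.

```julia
# Log-space form, mirroring the `logistic` function defined in this notebook
logistic_log(logx, a, b, c, logic50) = @. a / (1.0 + exp(b * (logx - logic50))) + c

# Concentration-space form from the docstring equation
logistic_linear(x, a, b, c, ic50) = @. a / (1.0 + (x / ic50)^b) + c

# Hypothetical parameter values, for illustration only
a, b, c, ic50 = 1.0, 3.0, 0.1, 2.0
x = [0.5, 1.0, 2.0, 4.0, 8.0]

y_log = logistic_log(log.(x), a, b, c, log(ic50))
y_lin = logistic_linear(x, a, b, c, ic50)

# Both forms agree because (x / ic50)^b = exp(b * (log(x) - log(ic50)))
println(y_log ≈ y_lin)

# At x = ic50 the response is halfway between the asymptotes: a / 2 + c
println(logistic_log(log(ic50), a, b, c, log(ic50)) ≈ a / 2 + c)
```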

Now, let's list the MCMC chains stored as dataframes.

In [None]:
# Define output directory
out_dir = "$(git_root())/data/Iwasawa_2022/mcmc_nonnegative"

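# List all files in the output directory. Glob patterns cannot start with "/",
# so we drop the leading slash and search relative to the root directory.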
files = sort(Glob.glob("$(out_dir)/*.csv"[2:end], "/"))

# Initialize empty dataframe to store metadata
df_meta = DF.DataFrame()

# Loop over each file and extract metadata
for file in files
    # Extract antibiotic from filename using regular expressions
    antibiotic = match(r"_(\w+)antibiotic", file).captures[1]

    # Extract day from filename using regular expressions
    day = parse(Int, match(r"_(\d+)day", file).captures[1])

    # Extract strain from filename using regular expressions
    strain_num = parse(Int, match(r"_(\d+)strain", file).captures[1])

    # Extract design from filename using regular expressions
    design = parse(Int, match(r"_(\d+)design", file).captures[1])

    # Extract environment from filename using regular expressions
    env = match(r"design_(\w+)env", file).captures[1]

    # Create a new row with the extracted metadata
    DF.append!(df_meta, DF.DataFrame(
        antibiotic=antibiotic,
        day=day,
        strain_num=strain_num,
        design=design,
        env=env,
        file=file
    ))
end

println("Number of MCMC chains: $(size(df_meta, 1))")
first(df_meta, 5)
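
To make the metadata extraction concrete, here is a minimal sketch that applies the same regular expressions to a single file name. The file name below is hypothetical; it only assumes the value-then-field naming pattern that the regular expressions above imply.

```julia
# Hypothetical file name following the pattern assumed by the regular expressions
file = "chain_KMantibiotic_10day_5strain_1design_TETenv.csv"

# Each capture group grabs the value that precedes the field name
antibiotic = match(r"_(\w+)antibiotic", file).captures[1]
day = parse(Int, match(r"_(\d+)day", file).captures[1])
strain_num = parse(Int, match(r"_(\d+)strain", file).captures[1])
design = parse(Int, match(r"_(\d+)design", file).captures[1])
env = match(r"design_(\w+)env", file).captures[1]

println((antibiotic, day, strain_num, design, env))
```

Because `\w` also matches digits and underscores, these patterns rely on each field name appearing only once in the file name.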

Let's load a random selection of chains and look at the posterior predictive
samples.
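
As a minimal sketch of what a posterior predictive sample means here: for each posterior draw we evaluate the logistic curve and then add Gaussian noise with variance σ² in log space, which is the noise model the plotting code in this notebook uses. The draws, the concentrations, and the names `logistic_sketch`, `draws`, and `y_rep` are hypothetical stand-ins for the values stored in the chain files.

```julia
import Random

# Same functional form as `logistic` in this notebook, redefined so the sketch
# is self-contained
logistic_sketch(logx, a, b, c, logic50) = @. a / (1.0 + exp(b * (logx - logic50))) + c

# Hypothetical posterior draws; real values come from the chain files
draws = [
    (a=1.0, b=2.0, c=0.05, logic50=0.0, σ²=0.01),
    (a=0.9, b=2.5, c=0.04, logic50=0.2, σ²=0.02),
]

# Hypothetical antibiotic concentrations
concentrations = [0.1, 1.0, 10.0]

# Seeded random number generator for reproducibility
rng = Random.MersenneTwister(42)

# One column of simulated optical densities per posterior draw
y_rep = Matrix{Float64}(undef, length(concentrations), length(draws))
for (j, d) in enumerate(draws)
    # Mean of the log optical density under this draw
    log_mean = log.(logistic_sketch(log.(concentrations), d.a, d.b, d.c, d.logic50))
    # Add Gaussian noise in log space and transform back
    y_rep[:, j] = exp.(log_mean .+ sqrt(d.σ²) .* randn(rng, length(concentrations)))
end

println(size(y_rep))
```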

In [None]:
Random.seed!(42)

# Initialize figure
fig = Figure(size=(900, 900))

# Define number of rows and columns
rows = 4
cols = 4

# Select random indexes
idxs = Random.randperm(size(df_meta, 1))[1:rows*cols]

# Loop through each plot
for i in 1:rows*cols
    # Locate row and column index
    row_idx = (i - 1) ÷ cols + 1
    col_idx = (i - 1) % cols + 1

    # Extract row to analyze
    row = df_meta[idxs[i], :]

    # Add axis
    ax = Axis(
        fig[row_idx, col_idx],
        xlabel="antibiotic concentration",
        ylabel="optical density",
        title="$(row.env) | day $(row.day)",
        xscale=log10,
        yscale=log10,
    )

    # Load the file
    chain = CSV.read(row.file, DF.DataFrame)

    # Extract corresponding data from the raw data
    data = df[
        (df.antibiotic.==row.antibiotic).&(df.day.==row.day).&(df.env.==row.env).&(df.strain_num.==row.strain_num),
        :]
    # Sort data by concentration
    sort!(data, :concentration_ugmL)

    # Locate unique concentrations
    unique_concentrations = unique(data.concentration_ugmL)

    # Initialize matrix to store samples
    y_samples = Array{Float64}(
        undef, length(unique_concentrations), size(chain, 1)
    )

    # Loop through samples
    for i in 1:size(chain, 1)
        logy_samples = log.(logistic(
            log.(unique_concentrations),
            chain[i, :a],
            chain[i, :b],
            chain[i, :c],
            chain[i, :logic50]
        ))
        # Add noise
        logy_samples .+= randn(length(unique_concentrations)) * √(chain[i, :σ²])
        y_samples[:, i] = exp.(logy_samples)

        # Plot samples
        lines!(
            ax,
            unique_concentrations,
            y_samples[:, i],
            color=(ColorSchemes.Paired_12[1], 0.5)
        )
    end # for

    # Plot data
    scatter!(
        ax, data.concentration_ugmL, data.OD, color=ColorSchemes.Paired_12[2]
    )

end # for

fig

The fits look decent on these random samples.

## Comparison with point estimates

Let's compare the $\mathrm{IC}_{50}$ parameter with the point estimate reported
in Iwasawa et al. (2022).

To do this, we will compute the posterior mean and 95% credible interval for
the logarithm of the $\mathrm{IC}_{50}$ parameter.
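
For reference, the posterior mean and an equal-tailed 95% credible interval can be computed directly from a vector of samples. The sketch below uses the `Statistics` standard library and a hypothetical, evenly spaced vector in place of real MCMC samples.

```julia
using Statistics: mean, quantile

# Hypothetical stand-in for the posterior samples of log(IC₅₀)
samples = collect(0.0:0.01:1.0)

# Posterior mean
sample_mean = mean(samples)

# Equal-tailed 95% credible interval: 2.5th and 97.5th percentiles
ci_lower, ci_upper = quantile(samples, [0.025, 0.975])

# Width of the interval, used as a measure of uncertainty
ci_width = ci_upper - ci_lower

println((sample_mean, ci_lower, ci_upper, ci_width))
```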

In [None]:
# Initialize dataframe to store results
df_ic50 = DF.DataFrame()

# Loop over each file
for row in eachrow(df_meta)
    # Read file
    chain = CSV.read(row.file, DF.DataFrame)

    # Compute mean and 95% credible interval
    logic50_mean = StatsBase.mean(chain[:, :logic50])
    logic50_ci = StatsBase.quantile(chain[:, :logic50], [0.025, 0.975])

    # Add to dataframe
    DF.append!(df_ic50, DF.DataFrame(
        drug=row.antibiotic,
        day=row.day,
        strain_num=row.strain_num,
        design=row.design,
        env=row.env,
        logic50_mean=logic50_mean,
        logic50_ci_lower=logic50_ci[1],
        logic50_ci_upper=logic50_ci[2],
        logic50_ci_width=logic50_ci[2] - logic50_ci[1]
    ))
end # for

first(df_ic50, 5)

Next, let's load the point estimates reported in Iwasawa et al. (2022).

In [None]:
df_point = CSV.read(
    "$(git_root())/data/Iwasawa_2022/iwasawa_ic50_tidy.csv",
    DF.DataFrame
)

println("Number of samples: $(size(df_point, 1))")
first(df_point, 5)

To match the rows of these dataframes, both need a common `:strain` column. In
the `df_point` dataframe, the last character of the `:strain` column is the
replicate number for the evolutionary trajectory. The `df_ic50` dataframe does
not have this column. However, after grouping the data by `:env`, we can sort
the unique values of the `strain_num` column and recover the replicate number
from their rank.
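
The ranking step can be sketched on a toy example. The environment label and strain numbers below are hypothetical; the point is only that sorting the unique strain numbers within one environment and enumerating them yields a replicate index.

```julia
# Hypothetical environment label and strain numbers (with repeats)
env = "TET"
strain_nums_observed = [12, 7, 30, 7, 12]

# Sort the unique strain numbers and map each one to its rank
strain_nums = sort(unique(strain_nums_observed))
replicate = Dict(strain_nums .=> 1:length(strain_nums))

# Build the strain label as the environment plus the replicate number
labels = ["$(env)_$(replicate[s])" for s in strain_nums_observed]

println(labels)
```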

In [None]:
# Group df_ic50 by :env 
df_ic50_grouped = DF.groupby(df_ic50, :env)

# Initialize dictionary to store the number of replicates per environment
n_replicates = Dict{Any,Any}()

# Loop through each group
for data in df_ic50_grouped
    # Get unique strain numbers sorted
    strain_nums = sort(unique(data.strain_num))
    # Add to dictionary
    n_replicates[first(data.env)] = Dict(strain_nums .=> 1:length(strain_nums))
end

# Add :strain column to df_ic50 as the :env column plus the replicate number
df_ic50.strain = [
    "$(data.env)_$(n_replicates[data.env][data.strain_num])"
    for data in eachrow(df_ic50)
]

first(df_ic50, 5)

Next, we replace the spaces in the `:strain` column of the `df_point` dataframe
with underscores. This makes it equivalent to the `:strain` column we just
added to the `df_ic50` dataframe.

One thing to note is that the reported point estimates include one estimate per
experimental replicate. However, in my Bayesian analysis, I am jointly
estimating the parameters for all replicates.

In [None]:
# Concatenate string in the `:strain` column with underscore
df_point.strain = [
    join(split(strain, " "), "_") for strain in df_point.strain
]

# Merge dataframes
df_merge = DF.leftjoin(
    df_ic50, df_point[:, DF.Not(:env)], on=[:drug, :day, :strain]
)

first(df_merge, 5)

With this merged dataframe in hand, let's plot the relationship between the
$\mathrm{IC}_{50}$ parameter estimated from the Bayesian model and the point
estimate reported in Iwasawa et al. (2022).
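
One detail in the plotting code is the factor `log(2)`. The Bayesian model is parameterized with the natural logarithm of the IC₅₀, while the `log2ic50` column name suggests base-2 logarithms, so the point estimates are rescaled with the change-of-base identity ln(x) = log₂(x) · ln(2). A quick numerical check with a hypothetical value:

```julia
# Hypothetical IC₅₀ value, for illustration only
ic50 = 8.0

# Base-2 logarithm, as assumed for the `log2ic50` column
log2_ic50 = log2(ic50)

# Change of base: ln(x) = log2(x) * ln(2)
ln_ic50 = log2_ic50 * log(2)

println(ln_ic50 ≈ log(ic50))
```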

In [None]:
# Plot
fig = Figure(size=(300, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="log(IC₅₀) (point estimate)",
    ylabel="log(IC₅₀) (Bayesian model)",
    title="IC₅₀ comparison"
)

# Scatter plot
scatter!(
    ax,
    df_merge.log2ic50 .* log(2),
    df_merge.logic50_mean,
    color=ColorSchemes.Paired_12[2],
    markersize=4
)

# Add 1:1 line 
lines!(ax, [-6, 6], [-6, 6], color="black", linestyle=:dash)

fig

There is a high degree of agreement between the Bayesian model and the point
estimate, except for a few outliers.

Let's now plot the relationship between the $\mathrm{IC}_{50}$ parameter but
this time include the 95% credible interval.

In [None]:
# Plot
fig = Figure(size=(300, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="log(IC₅₀) (point estimate)",
    ylabel="log(IC₅₀) (Bayesian model)",
    title="IC₅₀ comparison"
)

# Scatter plot
scatter!(
    ax,
    df_merge.log2ic50 .* log(2),
    df_merge.logic50_mean,
    color=ColorSchemes.Paired_12[2],
    markersize=4
)

# Add error bars
errorbars!(
    ax,
    df_merge.log2ic50 .* log(2),
    df_merge.logic50_mean,
    df_merge.logic50_mean .- df_merge.logic50_ci_lower,
    df_merge.logic50_ci_upper .- df_merge.logic50_mean,
    color=ColorSchemes.Paired_12[2]
)

# Add 1:1 line 
lines!(ax, [-6, 6], [-6, 6], color="black", linestyle=:dash)

fig

Notably, the outliers are those with the highest uncertainty in the Bayesian
model. Let's look at some of these examples. First, we sort the dataframe by
the width of the credible interval, which we computed earlier.

In [None]:
# Sort by error bars
sort!(df_merge, :logic50_ci_width, rev=true)

first(df_merge, 5)

Next, let's compute the posterior predictive samples for the 16 widest
intervals.

In [None]:
# Initialize figure
fig = Figure(size=(900, 900))

# Define number of rows and columns
rows = 4
cols = 4

# Loop through each plot
for i in 1:rows*cols
    # Locate row and column index
    row_idx = (i - 1) ÷ cols + 1
    col_idx = (i - 1) % cols + 1

    # Extract row to analyze
    row_ic50 = df_merge[i, :]
    # Extract corresponding row from df_meta
    row_meta = df_meta[
        (df_meta.antibiotic.==row_ic50.drug).&(df_meta.day.==row_ic50.day).&(df_meta.env.==row_ic50.env).&(df_meta.strain_num.==row_ic50.strain_num),
        :
    ]

    # Add axis
    ax = Axis(
        fig[row_idx, col_idx],
        xlabel="antibiotic concentration",
        ylabel="optical density",
        title="$(first(row_meta.env)) | day $(first(row_meta.day))",
        xscale=log10,
        yscale=log10,
    )

    # Load the file
    chain = CSV.read(row_meta.file, DF.DataFrame)

    # Extract corresponding data from the raw data
    data = df[
        (df.antibiotic.==row_meta.antibiotic).&(df.day.==row_meta.day).&(df.env.==row_meta.env).&(df.strain_num.==row_meta.strain_num),
        :]
    # Sort data by concentration
    sort!(data, :concentration_ugmL)

    # Locate unique concentrations
    unique_concentrations = unique(data.concentration_ugmL)

    # Initialize matrix to store samples
    y_samples = Array{Float64}(
        undef, length(unique_concentrations), size(chain, 1)
    )

    # Loop through samples
    for i in 1:size(chain, 1)
        logy_samples = log.(logistic(
            log.(unique_concentrations),
            chain[i, :a],
            chain[i, :b],
            chain[i, :c],
            chain[i, :logic50]
        ))
        # Add noise
        logy_samples .+= randn(length(unique_concentrations)) * √(chain[i, :σ²])
        y_samples[:, i] = exp.(logy_samples)

        # Plot samples
        lines!(
            ax,
            unique_concentrations,
            y_samples[:, i],
            color=(ColorSchemes.Paired_12[1], 0.5)
        )
    end # for

    # Plot data
    scatter!(
        ax, data.concentration_ugmL, data.OD, color=ColorSchemes.Paired_12[2]
    )

end # for

fig

As expected, fitting a Bayesian model helps us find pathological cases like
these, which we would have otherwise missed with a point estimate.

Let's plot the distribution of the width of the credible interval.

In [None]:
# Initialize figure
fig = Figure(size=(350, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="width of the 95% credible interval",
    ylabel="counts",
    title="Distribution of the width of\nthe 95% credible interval",
    yscale=log10
)

# Plot histogram
hist!(
    ax,
    df_merge.logic50_ci_width,
    color=ColorSchemes.Paired_12[2],
    bins=50,
)

fig

This histogram reveals a long tail that represents these pathological cases.
For downstream analysis, we will remove these samples from the dataset.

### Cleaning the dataset

Let's apply an arbitrary cut-off on the width of the credible interval. Since
the dataframe is still sorted by decreasing width, the first rows are the most
uncertain cases that survive the cut. Let's look at them.

In [None]:
# Define threshold
thresh = 1.0

# Filter dataframe
df_clean = df_merge[df_merge.logic50_ci_width.≤thresh, :]

# Initialize figure
fig = Figure(size=(900, 900))

# Define number of rows and columns
rows = 4
cols = 4

# Loop through each plot
for i in 1:rows*cols
    # Locate row and column index
    row_idx = (i - 1) ÷ cols + 1
    col_idx = (i - 1) % cols + 1

    # Extract row to analyze
    row_ic50 = df_clean[i, :]
    # Extract corresponding row from df_meta
    row_meta = df_meta[
        (df_meta.antibiotic.==row_ic50.drug).&(df_meta.day.==row_ic50.day).&(df_meta.env.==row_ic50.env).&(df_meta.strain_num.==row_ic50.strain_num),
        :
    ]

    # Add axis
    ax = Axis(
        fig[row_idx, col_idx],
        xlabel="antibiotic concentration",
        ylabel="optical density",
        title="$(first(row_meta.env)) | day $(first(row_meta.day))",
        xscale=log10,
        yscale=log10,
    )

    # Load the file
    chain = CSV.read(row_meta.file, DF.DataFrame)

    # Extract corresponding data from the raw data
    data = df[
        (df.antibiotic.==row_meta.antibiotic).&(df.day.==row_meta.day).&(df.env.==row_meta.env).&(df.strain_num.==row_meta.strain_num),
        :]
    # Sort data by concentration
    sort!(data, :concentration_ugmL)

    # Locate unique concentrations
    unique_concentrations = unique(data.concentration_ugmL)

    # Initialize matrix to store samples
    y_samples = Array{Float64}(
        undef, length(unique_concentrations), size(chain, 1)
    )

    # Loop through samples
    for i in 1:size(chain, 1)
        logy_samples = log.(logistic(
            log.(unique_concentrations),
            chain[i, :a],
            chain[i, :b],
            chain[i, :c],
            chain[i, :logic50]
        ))
        # Add noise
        logy_samples .+= randn(length(unique_concentrations)) * √(chain[i, :σ²])
        y_samples[:, i] = exp.(logy_samples)

        # Plot samples
        lines!(
            ax,
            unique_concentrations,
            y_samples[:, i],
            color=(ColorSchemes.Paired_12[1], 0.5)
        )
    end # for

    # Plot data
    scatter!(
        ax, data.concentration_ugmL, data.OD, color=ColorSchemes.Paired_12[2]
    )

end # for

fig

## Locating complete profiles with clean data

Having set this arbitrary cut-off on the width of the credible interval, we can
now proceed to analyze the complete profiles with clean data. By this we mean to
find the strains with complete $\mathrm{IC}_{50}$ measurements across all
antibiotics.

In principle, when grouping the data by `:strain` and `:day`, the complete 
profiles are those that contain all eight antibiotics. Let's look at this count.

In [None]:
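# Group by strain and day, keeping only inferences whose credible interval
# width is at most 1.5 (note: looser than the 1.0 cut-off used above)
df_group = DF.groupby(
    df_merge[df_merge.logic50_ci_width.≤1.5, :], [:strain, :day]
)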

# Count number of antibiotics per strain and day
n_drugs = StatsBase.countmap([length(unique(group.drug)) for group in df_group])

n_drugs

We see that we have on the order of 1,400 complete profiles. Let's plot
this number as a function of the threshold on the width of the credible
interval.
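
The counting logic can be sketched without any dataframe machinery. The toy records, the drug names, and the helper `count_complete` below are hypothetical, and the sketch uses three drugs instead of eight to keep it short. Note that a strict threshold can leave no complete profile at all, so the count must be allowed to be zero.

```julia
# Hypothetical toy records: (strain, day, drug, width of the credible interval)
records = [
    ("A_1", 1, "TET", 0.4), ("A_1", 1, "KM", 0.6), ("A_1", 1, "NFLX", 2.0),
    ("B_1", 1, "TET", 0.3), ("B_1", 1, "KM", 0.5), ("B_1", 1, "NFLX", 0.9),
]

# Number of drugs that defines a complete profile in this toy example
n_total = 3

# Count the (strain, day) groups that keep all drugs after thresholding
function count_complete(records, thresh, n_total)
    # Map each (strain, day) pair to the set of drugs that pass the threshold
    drugs = Dict{Tuple{String,Int},Set{String}}()
    for (strain, day, drug, width) in records
        # Skip inferences whose credible interval is too wide
        width ≤ thresh || continue
        push!(get!(drugs, (strain, day), Set{String}()), drug)
    end
    # A profile is complete when every drug is present
    return count(s -> length(s) == n_total, values(drugs))
end

# Evaluate the number of complete profiles at three thresholds
counts = [count_complete(records, t, n_total) for t in (0.2, 1.0, 3.0)]

println(counts)
```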

In [None]:
# Define range of thresholds
thresh = 0.5:0.01:3.0

# Initialize array to store the number of complete profiles
n_complete = Int[]

# Loop through each threshold
for t in thresh
    # Filter dataframe
    df_clean = df_merge[df_merge.logic50_ci_width .≤ t, :]
    # Group by strain and day
    df_group = DF.groupby(df_clean, [:strain, :day])
    # Count number of antibiotics per strain and day
    n_drugs = StatsBase.countmap([length(unique(group.drug)) for group in df_group])
    # Append to array (zero if no profile has all eight antibiotics)
    push!(n_complete, get(n_drugs, 8, 0))
end

# Initialize figure
fig = Figure(size=(350, 300))

# Add axis
ax = Axis(
    fig[1, 1],
    xlabel="threshold on 95% credible interval width",
    ylabel="counts",
    title="Number of complete profiles\nwith eight antibiotics",
)

scatterlines!(ax, thresh, n_complete)

fig