Note: all code is written in Julia (https://julialang.org/) and should be compatible with Julia v1.9 and later.

Code author: Janko Tackmann (jtackm@github)

In [None]:
##########################################################
# Functions to compute and quality-score UMAP embeddings #
##########################################################

using Distances, Statistics, StatsBase, Seaborn, PyCall, AxisKeys
import Pandas
sns = Seaborn.seaborn
umap_py = pyimport("umap")

# Apply correlation function `f` to every pair of matching columns of `A` and `B`.
# With `threads=true` each column pair is evaluated in its own spawned task.
function _columnwise_scores(f, A, B; threads=false)
    if threads
        tasks = [Threads.@spawn f(view(A, :, i), view(B, :, i)) for i in axes(A, 2)]
        return fetch.(tasks)
    end
    return map(f, eachcol(A), eachcol(B))
end

"""
    embedding_quality(embd, dist_orig; metrics=[cor, corspearman], threads=false,
                      detailed=false, verbose=true)

Compute the quality of embedding `embd`, defined as the correlation between the euclidean
distances in `embd` and the original distances `dist_orig`, evaluated separately for each
sample (i.e. for each column of the two distance matrices).

# Arguments
- `embd`: embedding matrix, with samples as rows, UMAP axes as columns.
- `dist_orig`: matrix with the original distances that `embd` was computed on. Must have
  the same size as the pairwise distance matrix of `embd`.

# Keywords
- `metrics`: correlation functions (e.g. `cor`, `corspearman`) used for scoring. Each one
  yields a result column named `<function name>_qual`.
- `threads`: if true, the per-sample correlations are computed in spawned tasks.
- `detailed`: if true, additionally compute one global correlation per metric over all
  unique sample pairs (strict upper triangle of the distance matrices).
- `verbose`: if true, print summary statistics for each quality column.

# Returns
A named tuple `(; score_df, detailed_scores)`. `score_df` is a `DataFrame` with one row
per sample and one column per metric. `detailed_scores` is a named tuple keyed like the
columns of `score_df`, or `nothing` if `detailed` is false.
"""
function embedding_quality(embd, dist_orig; metrics=[cor, corspearman],
        threads=false, detailed=false, verbose=true)
    dist_embd = pairwise(Euclidean(), embd; dims=1)

    s1, s2 = size(dist_embd), size(dist_orig)
    @assert s1 == s2 "dimensions of dist_embd don't match dist_orig ($(s1) != $(s2))"

    col_names = [Symbol("$(f)_qual") for f in metrics]
    score_cols = [cn => _columnwise_scores(f, dist_embd, dist_orig; threads)
                  for (cn, f) in zip(col_names, metrics)]
    score_df = DataFrame(score_cols...)

    if verbose
        for col in names(score_df)
            println("Quality stats ($col)")
            describe(score_df[!, col])
        end
    end

    detailed_scores = if detailed
        # Strict upper triangle only: every unordered sample pair exactly once and no
        # self-distances.
        upper_triangle(A) = [A[ind] for ind in CartesianIndices(A) if ind[1] < ind[2]]
        vec_embd, vec_orig = upper_triangle(dist_embd), upper_triangle(dist_orig)
        scores = Float64[f(vec_embd, vec_orig) for f in metrics]
        (; zip(col_names, scores)...)
    else
        nothing
    end

    return (; score_df, detailed_scores)
end

"""
    compute_full_umap(; data=nothing, metric=nothing, A_dist=nothing, embd_mat=nothing,
                      compute_quality=true, qual_cor_funs=(cor,), umap_kws...)

Compute UMAP results on OTU table `data`. Pairwise distances are either computed from
scratch (using metric `metric`) or taken as precomputed (if distance matrix `A_dist` is
provided).

# Keywords
- `data`: OTU table, with samples as rows and OTUs as columns.
- `metric`: Distances.jl metric to compute pairwise distances with.
- `A_dist`: alternative to `metric`, precomputed pairwise distances for all samples.
- `embd_mat`: precomputed embedding (samples as rows, UMAP axes as columns); provide it if
  only quality scores should be computed. Quality scoring then requires `A_dist`.
- `compute_quality`: if true, compute embedding quality scores (see `embedding_quality`).
- `qual_cor_funs`: correlation functions (e.g. `cor`, `corspearman`) used for the quality
  scores.
- `umap_kws...`: any further keywords are forwarded to the Python `umap.UMAP` constructor.

# Returns
A named tuple `(; embd_df, qual_df, umap_obj, A_dist)`:
- `embd_df`: one row per sample with the sample index column, `UMAP1`, `UMAP2` and the
  columns `UMAP1_std`, `UMAP2_std` produced by `standardize_to_interval(col, -1.0, 1.0)`.
- `qual_df`: per-sample quality scores with the sample index as first column, or `nothing`
  if `compute_quality` is false.
- `umap_obj`: the fitted Python UMAP object, or `nothing` if `embd_mat` was provided.
- `A_dist`: the `A_dist` argument exactly as passed in (so `nothing` when the distances
  were computed from `data`).
"""
function compute_full_umap(;data::Union{KeyedArray,Nothing}=nothing, metric=nothing, A_dist::Union{KeyedArray,Nothing}=nothing, 
        embd_mat::Union{KeyedArray,Nothing}=nothing, compute_quality=true, qual_cor_funs=(cor,), umap_kws...)
    @assert xor((!isnothing(data) && !isnothing(metric)), (!isnothing(A_dist) || !isnothing(embd_mat))) "provide either data+metric, or A_dist and/or embd_mat"
    @assert !(!isnothing(embd_mat) && compute_quality && isnothing(A_dist)) "'compute_quality' requires A_dist to be provided"

    # Whenever two keyed inputs are given, their sample keys must agree.
    if !isnothing(data) && !isnothing(A_dist)
        @assert axiskeys(data, 1) == axiskeys(A_dist, 1)
    end
    if !isnothing(embd_mat) && !isnothing(A_dist)
        @assert axiskeys(embd_mat, 1) == axiskeys(A_dist, 1)
    end

    # Sample ids and the name of the sample dimension come from the first available input.
    ref_mat = !isnothing(data) ? data : (!isnothing(A_dist) ? A_dist : embd_mat)
    sids = axiskeys(ref_mat, 1)
    index_col = dimnames(ref_mat)[1]

    # Plain distance matrix; stays `nothing` only if an embedding was supplied without
    # distances (in which case it is not needed below).
    A_dist_raw = nothing
    if !isnothing(A_dist)
        A_dist_raw = parent(A_dist).data
    elseif isnothing(embd_mat)
        @time A_dist_raw = pairwise(metric, data, dims=1)
    end

    # Use a separate local for the embedding instead of rebinding the keyword argument.
    umap_obj = nothing
    embd = embd_mat
    if isnothing(embd)
        umap_obj = umap_py.UMAP(n_components=2, metric="precomputed"; umap_kws...)
        @time embd = umap_obj.fit_transform(A_dist_raw)
    end

    embd_df = DataFrame(;index_col=>sids, UMAP1=embd[:, 1], UMAP2=embd[:, 2])
    for c in (:UMAP1, :UMAP2)
        embd_df[!, "$(c)_std"] = standardize_to_interval(embd_df[!, c], -1.0, 1.0)
    end

    qual_df = nothing
    if compute_quality
        @time qual_df = embedding_quality(embd, A_dist_raw; metrics=qual_cor_funs).score_df
        qual_df[!, index_col] = sids
        # move the sample index to the front
        qual_df = select(qual_df, index_col, :)
    end

    return (;embd_df, qual_df, umap_obj, A_dist)
end

"""Convenience function to compute a full UMAP from a OTU table ('data') using a distance metric 'metric'."""
function compute_full_umap(data::KeyedArray, metric; kws...)
    # Forward the two positional arguments as keywords to the keyword-only method.
    return compute_full_umap(; data=data, metric=metric, kws...)
end

"""
    plot_umap(umap_res::NamedTuple, groups=nothing; plot_quality=false,
              qual_palette="icefire", index_col=:sid, group_label="group", plot_kws...)

Plot previously computed UMAP results.

# Arguments
- `umap_res`: result object of `compute_full_umap`.
- `groups`: group of each sample in `umap_res` (same length and order as the rows of
  `umap_res.embd_df`). If provided, the data points are colored by group.

# Keywords
- `plot_quality`: if true, a second scatter plot colored by the `cor_qual` quality score
  is drawn. A warning is emitted and the plot skipped if `umap_res.qual_df` is `nothing`.
- `qual_palette`: seaborn palette used for the quality plot.
- `index_col`: column that specifies the sample index (used to join embedding and quality
  results).
- `group_label`: column name / legend label for the group variable.
- `plot_kws...`: further keywords forwarded to `sns.relplot` for the embedding plot.
"""
function plot_umap(umap_res::NamedTuple, groups=nothing; plot_quality=false, qual_palette="icefire",
    index_col=:sid, group_label="group", plot_kws...)
    plt.figure()

    # Without groups the embedding table is plotted as-is and no hue is used.
    plot_df = umap_res.embd_df
    hue_var = nothing
    if !isnothing(groups)
        @assert length(groups) == nrow(umap_res.embd_df)
        # Work on a copy so the caller's embedding table is not modified.
        plot_df = copy(umap_res.embd_df)
        plot_df[!, group_label] = groups
        hue_var = group_label
    end

    # Plot the table that actually contains the group column; otherwise `hue` would
    # reference a column missing from the data.
    sns.relplot(x=:UMAP1, y=:UMAP2, data=Pandas.DataFrame(plot_df), s=1, hue=hue_var; plot_kws...)

    if plot_quality
        if isnothing(umap_res.qual_df)
            @warn "plot_quality set to true, but quality scores are missing from input. Skipping."
        else
            plt.figure()
            qual_plot_df = coalesce.(leftjoin(umap_res.embd_df, umap_res.qual_df, on=index_col=>index_col))
            sns.relplot(x=:UMAP1, y=:UMAP2, data=Pandas.DataFrame(qual_plot_df), s=1, hue=:cor_qual, palette=qual_palette)
        end
    end;
end

"""
    plot_umap(umap_args...; plot_quality=false, qual_palette="icefire", umap_kws...)

Perform a full UMAP projection and plot the results.

# Arguments
- `umap_args...`: positional arguments forwarded to `compute_full_umap` (see there).

# Keywords
- `plot_quality`: if true, the quality scores of the computed result are also plotted.
- `qual_palette`: seaborn palette used for the quality plot.
- `umap_kws...`: further keywords forwarded to `compute_full_umap`.

# Returns
The result object of `compute_full_umap`.
"""
function plot_umap(umap_args...; plot_quality=false, qual_palette="icefire", umap_kws...)
    umap_res = compute_full_umap(umap_args...; umap_kws...)
    # Dispatches to the `NamedTuple` method, which does the actual plotting.
    plot_umap(umap_res; plot_quality=plot_quality, qual_palette=qual_palette)
    return umap_res
end