Note: all code is written in Julia (https://julialang.org/) and should be compatible with Julia v1.9 and later versions.

Code author: Janko Tackmann (jtackm@github)

# Alpha diversity

In [None]:
###################################################################
# Functions to estimate species coverage and unobserved diversity #
###################################################################

# Formulas are taken from: Chao, A. & Jost, L. Coverage‐based rarefaction and
# extrapolation: standardizing samples by completeness rather than size. Ecology 93, 2533–2547 (2012).

using AxisArrays, SparseArrays

"""
    estimate_unobserved_species(abunds_nz)

Estimate the number of unobserved species `f0` (Chao1-type estimator as used in
Chao & Jost 2012) from the abundances in `abunds_nz`.

`abunds_nz` may be any iterable of abundance counts, e.g. a sparse vector or the generator
returned by `get_abundance_itr`. Zero entries are ignored: they are neither singletons nor
doubletons and add nothing to the total.

With `n` total abundance, `f1` singletons and `f2` doubletons the estimate is
`(n - 1) / n * f1^2 / (2 * f2)` if `f2 > 0`, else `(n - 1) / n * f1 * (f1 - 1) / 2`.
Returns `NaN` for an empty sample (`n == 0`).
"""
function estimate_unobserved_species(abunds_nz)
    # `init` / `count` keep this working for empty generators, where a plain `sum`
    # over a generator throws ("reducing over an empty collection").
    n = sum(abunds_nz; init=0)
    f1 = count(==(1), abunds_nz)
    f2 = count(==(2), abunds_nz)

    if f2 > 0
        return ((n - 1) * f1^2) / (n * 2 * f2)
    else
        # No doubletons: use f1 * (f1 - 1) instead of f1^2 (here f2 + 1 == 1).
        return ((n - 1) * f1 * (f1 - 1)) / (n * 2 * (f2 + 1))
    end
end


"""
    estimate_coverage(abunds_nz)

Estimate the sample coverage (Chao & Jost 2012) from the abundances in `abunds_nz`:
`1 - f1/n * ((n - 1) * f1) / ((n - 1) * f1 + 2 * f2)`, with `n` the total abundance,
`f1` the number of singletons and `f2` the number of doubletons.

`abunds_nz` may be any iterable of abundance counts (sparse vector, or the generator from
`get_abundance_itr`); zero entries are ignored.

Returns `1.0` when there are no singletons. The formula's correction factor is 0/0 when
`f1 == f2 == 0`, but without singletons the estimated coverage is 1. Returns `NaN` for an
empty sample.
"""
function estimate_coverage(abunds_nz)
    n = sum(abunds_nz; init=0)
    f1 = count(==(1), abunds_nz)
    f2 = count(==(2), abunds_nz)
    # Guard the 0/0 case (f1 == f2 == 0), which would otherwise yield NaN.
    f1 == 0 && return n > 0 ? 1.0 : NaN
    return 1 - (f1 / n) * (((n - 1) * f1) / (((n - 1) * f1) + 2 * f2))
end


"""
    interpolate_coverage(abunds_nz, m::Integer)

Interpolate the species coverage at sampling effort `m` (`0 <= m < n`, where `n` is the
total abundance) from the abundances in `abunds_nz` (Chao & Jost 2012):

    C(m) = 1 - sum_i (X_i / n) * binomial(n - X_i, m) / binomial(n - 1, m)

over all species with `X_i > 0`. Zero entries of `abunds_nz` are skipped.

Throws `DomainError` if `m` is outside `0:n-1`. For `m >= n` use `estimate_coverage` /
`extrapolate_coverage` instead.
"""
function interpolate_coverage(abunds_nz, m::Integer)
    # binomial(n - Xi, m) / binomial(n - 1, m), computed as the running product
    # prod_{k=0}^{m-1} (n - Xi - k) / (n - 1 - k) in floating point. Calling `binomial`
    # directly raises OverflowError for realistic sequencing depths.
    function binom_ratio(Xi, n, m)
        r = 1.0
        for k in 0:(m - 1)
            num = n - Xi - k
            num <= 0 && return 0.0  # binomial(n - Xi, m) == 0 once m > n - Xi
            r *= num / (n - 1 - k)
        end
        return r
    end

    n = sum(abunds_nz; init=0)
    0 <= m < n || throw(DomainError(m, "interpolation requires 0 <= m < n (n = $n)"))
    return 1 - sum((Xi / n) * binom_ratio(Xi, n, m) for Xi in abunds_nz if Xi > 0; init=0.0)
end

            
"""
    extrapolate_coverage(abunds_nz, m_star::Integer)

Extrapolate the species coverage to `m_star` additional individuals beyond the observed
total abundance `n` (Chao & Jost 2012):

    C(n + m_star) = 1 - f1/n * ((n - 1) * f1 / ((n - 1) * f1 + 2 * f2))^(m_star + 1)

`m_star == 0` reproduces `estimate_coverage`. `abunds_nz` may be any iterable of
abundance counts; zero entries are ignored.

Returns `1.0` when there are no singletons (the base is 0/0 if additionally `f2 == 0`), and
`NaN` for an empty sample.
"""
function extrapolate_coverage(abunds_nz, m_star::Integer)
    n = sum(abunds_nz; init=0)
    f1 = count(==(1), abunds_nz)
    f2 = count(==(2), abunds_nz)
    # Without singletons nothing is predicted to be missing: coverage is 1.
    f1 == 0 && return n > 0 ? 1.0 : NaN
    return 1 - f1 / n * (((n - 1) * f1) / ((n - 1) * f1 + 2 * f2))^(m_star + 1)
end
            
"""
    extrapolate_unobserved_species(abunds_nz, m_star::Integer)

Extrapolate species richness to `m_star` additional individuals beyond the observed total
abundance `n` (Chao & Jost 2012):

    S(n + m_star) = S_obs + f0 * (1 - (1 - f1 / (n * f0 + f1))^m_star)

where `S_obs` is the number of species with abundance > 0, `f1` the number of singletons
and `f0` the estimate from `estimate_unobserved_species`. Despite the function name, the
returned value includes `S_obs`. `m_star == 0` returns `S_obs`.

`abunds_nz` may be any iterable of abundance counts; zero entries are not counted as
observed species.
"""
function extrapolate_unobserved_species(abunds_nz, m_star::Integer)
    n = sum(abunds_nz; init=0)
    # Count observed species explicitly; `length` would also count the zeros of a
    # sparse vector.
    Sobs = count(>(0), abunds_nz)
    f1 = count(==(1), abunds_nz)
    f0_est = estimate_unobserved_species(abunds_nz)
    # f0 == 0 (no singletons): no unseen species, and f1 / (n * f0 + f1) would be 0/0.
    iszero(f0_est) && return float(Sobs)
    return Sobs + f0_est * (1 - (1 - (f1 / (n * f0_est + f1)))^m_star)
end

"""
    get_abundance_itr(otu_mat_T, sample_i)

Return a lazy iterator over the stored values of column `sample_i` of the sparse matrix
`otu_mat_T`, in storage order. The values are read from the matrix's own `nonzeros`
array, so nothing is copied.
"""
function get_abundance_itr(otu_mat_T, sample_i)
    stored_vals = nonzeros(otu_mat_T)
    col_positions = nzrange(otu_mat_T, sample_i)
    return Iterators.map(k -> stored_vals[k], col_positions)
end

"""
    estimate_diversity_fun_threads(otu_mat_T::AbstractMatrix, div_fun::Function)

Apply the per-sample statistic `div_fun` to every sample (column) of the OTU table
`otu_mat_T`, spreading columns across threads, and return the results as a
`Vector{Float64}`.

`div_fun` receives an iterator over the stored values of one column (see
`get_abundance_itr`). It must return a single number convertible to `Float64`, e.g.
`estimate_coverage` or `estimate_unobserved_species`.
"""
function estimate_diversity_fun_threads(otu_mat_T::AbstractMatrix, div_fun::Function)
    n_samples = size(otu_mat_T, 2)
    out = fill(0.0, n_samples)
    # Each iteration writes only its own slot of `out`, so columns are independent.
    Threads.@threads for col in 1:n_samples
        col_abunds = get_abundance_itr(otu_mat_T, col)
        out[col] = div_fun(col_abunds)
    end
    return out
end

"""
    estimate_unobserved_diversity(otu_mat_T::AxisArray; verbose=true)

Estimate sample coverage and observed/unobserved OTU counts for every sample (column) of
the OTU table `otu_mat_T`. This is an `AxisArray` whose `.data` is a sparse matrix
(presumably a `SparseMatrixCSC`, as required by `nzrange`) with samples as columns.

Returns a `DataFrame` with columns:
- `sample_id`: values of the second axis;
- `coverage`: `estimate_coverage` per sample;
- `unobs_OTUs`: `estimate_unobserved_species` per sample;
- `obs_OTUs`: number of stored entries per sample.

With `verbose=true` each step is announced and timed with `@time`. With `verbose=false`
nothing is printed.
"""
function estimate_unobserved_diversity(otu_mat_T::AxisArray; verbose=true)
    A = otu_mat_T.data
    obs_otus = [length(nzrange(A, i)) for i in 1:size(A, 2)]

    # Announce and time a step only in verbose mode; previously `@time` printed
    # unconditionally, even with verbose=false.
    function run_step(msg, div_fun)
        verbose || return estimate_diversity_fun_threads(A, div_fun)
        println(msg)
        return @time estimate_diversity_fun_threads(A, div_fun)
    end

    samp_covs = run_step("Estimating coverage", estimate_coverage)
    unobs_otus = run_step("Estimating unobserved OTUs", estimate_unobserved_species)
    return DataFrame(sample_id=axisvalues(otu_mat_T)[2], coverage=samp_covs,
                     unobs_OTUs=unobs_otus, obs_OTUs=obs_otus)
end

# Beta diversity

In [None]:
##############################################################
# Functions to compute beta diversity indices and statistics #
##############################################################

using HDF5, SparseArrays, AxisKeys, DataFrames, PyCall, Distances
import Distances:pairwise

## Unifrac (weighted & unweighted) ##

# Python `unifrac` package (biocore-unifrac), imported via PyCall. `unifrac_biocore` looks
# up its functions (`ssu_fast` or the per-method functions) on this object.
unifrac_py = pyimport("unifrac")

"""
    unifrac_biocore(data::KeyedArray, tree_path; weighted=true, normalized=true,
                    unifrac_mode=nothing, generalized=false, alpha=1.0,
                    variance_adjusted=false, strided=true, prec=32, kws...)

Run biocore-unifrac (Python package `unifrac`) from Julia. The OTU table `data` (rows =
samples, columns = OTUs, see `save_biom`) is written to a temporary BIOM file. Pairwise
Unifrac distances are then computed against the newick tree at `tree_path`.

Returns a `KeyedArray` of distances whose two axes are the sample ids reported by
biocore-unifrac.

# Keywords
- `weighted`, `normalized`, `generalized`, `prec`: build the method name when
  `unifrac_mode === nothing`, e.g. `:weighted_normalized_fp32`, `:unweighted_fp32` or
  `:generalized_fp64`. `normalized` only matters for weighted Unifrac.
- `unifrac_mode`: explicit biocore-unifrac method name; overrides the keywords above.
- `alpha`: generalized-Unifrac parameter. In non-strided mode it is only passed when
  `generalized=true`; in strided mode it is always passed to `ssu_fast`.
- `variance_adjusted`: forwarded to biocore-unifrac.
- `strided`: use the faster strided implementation (`unifrac.ssu_fast`) with
  `unifrac_mode` as the method name.
- `kws...`: forwarded to the non-strided biocore-unifrac function.

The temporary BIOM file is deleted before returning, including when an error is thrown.
"""
function unifrac_biocore(data::KeyedArray, tree_path; weighted=true, normalized=true, unifrac_mode::Union{Nothing,Symbol}=nothing,
    generalized=false, alpha=1.0, variance_adjusted=false, strided=true, prec=32, kws...)
    if isnothing(unifrac_mode)
        if generalized
            unifrac_mode = Symbol("generalized_fp$(prec)")
        else
            w_str = weighted ? "weighted" : "unweighted"
            norm_str = weighted ? (normalized ? "normalized_" : "unnormalized_") : ""
            unifrac_mode = Symbol("$(w_str)_$(norm_str)fp$(prec)")
        end
    end

    unifrac_fun = getproperty(unifrac_py, strided ? :ssu_fast : unifrac_mode)
    # Was an unconditional `@show`; enable with JULIA_DEBUG to see it.
    @debug "running biocore-unifrac" unifrac_mode unifrac_fun

    # Unique temp path with the '.biom' suffix that `save_biom` requires. The previous
    # `mktemp()` left an open IOStream and two files behind; this one is always removed.
    biom_path = tempname() * ".biom"
    unifrac_res = try
        save_biom(biom_path, data)
        if strided
            ids = h5read(biom_path, "sample/ids")
            unifrac_fun(biom_path, tree_path, ids, unifrac_mode, variance_adjusted, alpha, false, 1)
        elseif generalized
            unifrac_fun(biom_path, tree_path; alpha, variance_adjusted, kws...)
        else
            unifrac_fun(biom_path, tree_path; variance_adjusted, kws...)
        end
    finally
        rm(biom_path; force=true)
    end

    unifrac_df = unifrac_res.to_data_frame()
    idx = String.(unifrac_df.columns)
    return KeyedArray(Matrix(unifrac_df.values), ax1=idx, ax2=idx)
end

"""
    Unifrac(weighted::Bool, normalized::Bool, tree_file::AbstractString)

Distances.jl `PreMetric` that computes Unifrac distances through `unifrac_biocore`
(biocore-unifrac). Intended for `pairwise(Unifrac(...), data; dims=1)`, which compares the
rows of `data`.
"""
struct Unifrac{T<:AbstractString} <: PreMetric
    weighted::Bool     # weighted (true) or unweighted (false) Unifrac
    normalized::Bool   # normalized variant; unifrac_biocore only uses it for weighted Unifrac
    tree_file::T       # path to the phylogenetic tree, passed to unifrac_biocore as `tree_path`
end

# wrap biocore-unifrac with Distances.jl
"""
    pairwise(d::Unifrac, data::KeyedArray; dims=1, kws...)

Make biocore-unifrac available through Distances.jl. Computes pairwise Unifrac distances
between the rows of `data` (samples; columns are OTUs, see `save_biom`), using the tree
file, weighting and normalization stored in `d`. Extra keywords are forwarded to
`unifrac_biocore`.

Returns a plain `Matrix` in the order reported by biocore-unifrac (presumably the row order
of `data`; TODO confirm). Throws `ArgumentError` if `dims != 1`.
"""
function Distances.pairwise(d::Unifrac, data::KeyedArray; dims=1, kws...)
    # Input validation must not rely on `@assert`, which may be compiled out.
    dims == 1 || throw(ArgumentError("can only compute Unifrac between rows (dims=1), got dims=$dims"))
    return Matrix(unifrac_biocore(data, d.tree_file; weighted=d.weighted,
                                  normalized=d.normalized, kws...))
end

"""
    save_biom(out_path, otu_table::KeyedArray)

Save an OTU table to `out_path` in the HDF5-based BIOM v2.1 format and return `out_path`.

`otu_table` has samples on its first axis and OTUs on its second (`axiskeys` gives
`(sample_ids, otu_ids)`). The table is written twice as compressed sparse data:
- OTU-major under `observation/matrix`;
- sample-major under `sample/matrix`.

Values are stored as `Float64`, indices as 0-based `Int32`, and ids as strings.

Throws `ArgumentError` if `out_path` does not end in `.biom`.
"""
function save_biom(out_path, otu_table::KeyedArray)
    endswith(out_path, ".biom") ||
        throw(ArgumentError("output path must end with '.biom', got \"$out_path\""))
    sids, oids = axiskeys(otu_table)
    # CSC of the samples × OTUs table == CSR of OTUs × samples (one compressed row per OTU)
    otu_table_csc = sparse(otu_table)
    # CSC of OTUs × samples == CSR of samples × OTUs (one compressed row per sample)
    otu_table_csr = sparse(transpose(otu_table_csc))
    h5open(out_path, "w") do f
        attrs = HDF5.attributes(f)
        attrs["id"] = ""
        attrs["type"] = "OTU table"
        attrs["format-url"] = "http://biom-format.org"
        attrs["format-version"] = [2, 1]
        attrs["generated-by"] = ""
        attrs["creation-date"] = ""
        attrs["shape"] = collect(size(otu_table_csr)) # rows are OTUs, columns are samples
        attrs["nnz"] = nnz(otu_table_csr)

        for (main_group, A, idx) in (("observation", otu_table_csc, oids),
                                     ("sample", otu_table_csr, sids))
            # BIOM stores ids as strings (e.g. integer sample ids are converted)
            f["$(main_group)/ids"] = string.(idx)

            f["$(main_group)/matrix/data"] = Float64.(nonzeros(A))
            # Account for the 0-based index offset in python. The shift is computed on
            # copies, so the sparse matrices' own index arrays stay intact. (The former
            # `else+` typo silently skipped the Int32 conversion.)
            f["$(main_group)/matrix/indices"] = Int32.(rowvals(A) .- 1)
            f["$(main_group)/matrix/indptr"] = Int32.(A.colptr .- 1)

            create_group(f, "$(main_group)/metadata")
            create_group(f, "$(main_group)/group-metadata")
        end
    end

    return out_path
end


## Bray-Curtis ##

using SparseArrays
import Base:iterate

"""
    NonzeroPairIteratorCSC(A::SparseMatrixCSC, i::Integer, j::Integer)

Iterate over the union of the stored entries of columns `i` and `j` of `A`.

Both columns are walked together in increasing row order, relying on CSC storage keeping
row indices sorted within each column. For every row stored in at least one of the two
columns the iterator yields a tuple `(vi, vj)`. The side without a stored entry in that
row contributes `zero(eltype(A))`. Rows stored in neither column are skipped, so the
number of elements is not known in advance.
"""
struct NonzeroPairIteratorCSC{T1<:Real, T2<:Integer, S<:Integer}
    A::SparseMatrixCSC{T1,T2}
    i_itr::UnitRange{S}  # storage positions of column i's entries (from `nzrange`)
    j_itr::UnitRange{S}  # storage positions of column j's entries (from `nzrange`)
end

NonzeroPairIteratorCSC(A::SparseMatrixCSC, i::Integer, j::Integer) =
    NonzeroPairIteratorCSC(A, nzrange(A, i), nzrange(A, j))

# Cursor into one column: value and row of the current (not yet yielded) entry, the
# `UnitRange` iteration state needed to advance, and whether the column is exhausted.
struct NonzeroIteratorCSCState{T1 <: Real, T2 <: Integer, S <: Integer}
    val::T1
    row::T2
    state::S
    done::Bool
end

# Iteration state of `NonzeroPairIteratorCSC`: one cursor per column.
struct NonzeroPairIteratorCSCState{T1 <: Real, T2 <: Integer, S <: Integer}
    i_state::NonzeroIteratorCSCState{T1,T2,S}
    j_state::NonzeroIteratorCSCState{T1,T2,S}
end

# Defined on the type (not only on instances) so that wrappers such as generators, which
# query `IteratorSize(typeof(itr))`, also see the unknown length.
Base.IteratorSize(::Type{<:NonzeroPairIteratorCSC}) = Base.SizeUnknown()
Base.eltype(::Type{<:NonzeroPairIteratorCSC{T1}}) where {T1} = Tuple{T1,T1}

# Both cursors active: advance the one on the smaller row, or both if the rows coincide.
function _determine_steps_inner(itr::NonzeroPairIteratorCSC, state::NonzeroPairIteratorCSCState)
    ri, rj = state.i_state.row, state.j_state.row
    return ri <= rj, rj <= ri
end

# Decide which cursors to advance; an exhausted cursor is never advanced.
function _determine_steps(itr::NonzeroPairIteratorCSC, state::NonzeroPairIteratorCSCState)
    if !state.i_state.done && !state.j_state.done
        return _determine_steps_inner(itr, state)
    end
    return !state.i_state.done, !state.j_state.done
end

# Move one column cursor to its next stored entry (to its first one if `state === nothing`).
# Returns a `done` cursor once the column is exhausted.
function _iterate_inner_single(A::SparseMatrixCSC{T1,T2}, sub_itr::UnitRange{S},
        state=nothing) where {T1<:Real, T2<:Integer, S<:Integer}
    itr_next = state === nothing ? iterate(sub_itr) : iterate(sub_itr, state)
    if itr_next === nothing
        # Placeholder fields are never read for a done cursor. They are typed `T2`/`S` so
        # the state type matches the matrix's index type; a plain `-1` is an `Int` and
        # broke Int32-indexed matrices.
        return NonzeroIteratorCSCState(zero(T1), zero(T2), zero(S), true)
    end
    k, state_next = itr_next
    return NonzeroIteratorCSCState(@inbounds(nonzeros(A)[k]), @inbounds(rowvals(A)[k]),
                                   state_next, false)
end

function Base.iterate(itr::NonzeroPairIteratorCSC{T1,T2,S},
        state::NonzeroPairIteratorCSCState{T1,T2,S}) where {T1<:Real, T2<:Integer, S<:Integer}
    step_i, step_j = _determine_steps(itr, state)
    (step_i || step_j) || return nothing

    i_state, j_state = state.i_state, state.j_state
    # Emit the current value of every advanced cursor; the other side is an implicit zero.
    vi = step_i ? i_state.val : zero(T1)
    vj = step_j ? j_state.val : zero(T1)
    if step_i
        i_state = _iterate_inner_single(itr.A, itr.i_itr, i_state.state)
    end
    if step_j
        j_state = _iterate_inner_single(itr.A, itr.j_itr, j_state.state)
    end
    return (vi, vj), NonzeroPairIteratorCSCState(i_state, j_state)
end

function Base.iterate(itr::NonzeroPairIteratorCSC)
    i_state = _iterate_inner_single(itr.A, itr.i_itr)
    j_state = _iterate_inner_single(itr.A, itr.j_itr)
    return iterate(itr, NonzeroPairIteratorCSCState(i_state, j_state))
end


"""
    braycurtis_sparse_iter(A, i, j)

Bray-Curtis dissimilarity between columns `i` and `j` of the sparse OTU table `A`:
`sum(abs(vi - vj)) / sum(abs(vi + vj))`. Both sums run only over rows stored in at least
one of the two columns (via `NonzeroPairIteratorCSC`), so the work is proportional to the
columns' stored entries. Returns `NaN` when both columns are empty.
"""
function braycurtis_sparse_iter(A, i, j)
    numer = 0.0
    denom = 0.0
    for pair in NonzeroPairIteratorCSC(A, i, j)
        a, b = pair
        numer += abs(a - b)
        denom += abs(a + b)
    end
    return numer / denom
end

"""
    order_by_nnz(X::SparseMatrixCSC)

Return the column indices of `X` sorted by their number of stored entries, in ascending
order; ties keep their original order. The counts come from `nzrange`, which takes O(1)
per column, instead of densifying every column with `mapslices`.
"""
order_by_nnz(X::SparseMatrixCSC) = sortperm([length(nzrange(X, j)) for j in axes(X, 2)])

"""
    _pairwise_threads_inner!(f, i, j_s, X, A_out)

Write `d = f(X, i, j)` to both `A_out[i, j]` and `A_out[j, i]` for every `j` in `j_s`.
Bounds checks are skipped; callers must make sure `A_out` is large enough.
"""
function _pairwise_threads_inner!(f, i, j_s, X, A_out)
    for j in j_s
        d = f(X, i, j)
        @inbounds A_out[i, j] = d
        @inbounds A_out[j, i] = d
    end
    return nothing
end

"""
    pairwise_threads!(f, X::AbstractMatrix{<:Real}, A_out; work_order_fun=order_by_nnz)

Fill the symmetric matrix `A_out` with `f(X, i, j)` for all column pairs `i < j` of `X`,
running one task per column `i`.

- `f`: distance kernel `f(X, i, j)`; may be specialized for sparse `X`
  (e.g. `braycurtis_sparse_iter`).
- `X`: OTU table (samples as columns); should be sparse if `f` expects sparse data.
- `A_out`: `ncol × ncol` output. The diagonal is never written, so pre-fill it
  (e.g. with `zeros`).
- `work_order_fun(X)`: order in which the column tasks are spawned.

Throws `DimensionMismatch` if `A_out` is not `ncol × ncol`.
"""
function pairwise_threads!(f, X::AbstractMatrix{T}, A_out; work_order_fun=order_by_nnz) where T<:Real
    ncol = size(X, 2)
    # Both dimensions must be checked: the kernel writes A_out[j, i] with @inbounds.
    size(A_out) == (ncol, ncol) || throw(DimensionMismatch(
        "output matrix must be $(ncol)×$(ncol), got $(size(A_out))"))
    # When `work_order_fun` lists each column once, every entry is written by exactly one
    # task, so no locking is needed. `@sync` waits for all tasks and rethrows failures.
    @sync for i in work_order_fun(X)
        Threads.@spawn _pairwise_threads_inner!(f, i, (i + 1):ncol, X, A_out)
    end
    return nothing
end

## PERMANOVA ##

# Python scikit-bio package, imported via PyCall and bound to the name `skb`;
# `permanova_skbio` calls `skb.stats.distance.permanova` on it.
skb = pyimport("skbio")

"""
    compute_r2(pseudo_f, n_groups, n_samples)

Convert a PERMANOVA pseudo-F statistic into R² (the between-group share of the total sum
of squares):

    R² = F * (g - 1) / (F * (g - 1) + N - g)

with `g = n_groups` and `N = n_samples`. Results were checked by the author to match
`adonis2` (R package vegan).
"""
function compute_r2(pseudo_f, n_groups, n_samples)
    between = pseudo_f * (n_groups - 1)
    total = between + n_samples - n_groups
    return between / total
end

"""
    permanova_skbio(A_dist::PyObject, groups::AbstractVector; permutations=999)

Run scikit-bio's PERMANOVA on a PyCall-wrapped scikit-bio `DistanceMatrix` with grouping
vector `groups`, using `permutations` permutations.

Returns the result of scikit-bio's `.to_dict()`, presumably converted to a Julia `Dict` by
PyCall, with an extra `"R2"` entry. `R2` is computed by `compute_r2` from the reported
test statistic (pseudo-F), number of groups and sample size.
"""
function permanova_skbio(A_dist::PyObject, groups::AbstractVector; permutations=999)
    # Forward `permutations` explicitly; it was previously accepted but silently ignored.
    test_res = skb.stats.distance.permanova(A_dist, groups; permutations=permutations).to_dict()
    test_res["R2"] = compute_r2(test_res["test statistic"], test_res["number of groups"],
                                test_res["sample size"])
    return test_res
end

"""
    permanova_skbio(A_dist::AbstractMatrix, args...; kws...)

Make scikit-bio's PERMANOVA available for a Julia distance matrix. `A_dist` is wrapped
in a scikit-bio `DistanceMatrix`, and the call is forwarded to the `PyObject` method
together with `args...` (the group vector) and `kws...` (e.g. `permutations`).
"""
function permanova_skbio(A_dist::AbstractMatrix, args...; kws...)
    # The scikit-bio module is bound to `skb`; the former `skbio.` reference was undefined.
    A_dist_skb = skb.DistanceMatrix(A_dist)
    return permanova_skbio(A_dist_skb, args...; kws...)
end

# Rarefaction

In [None]:
############################################################################
# Functions to compute community cluster- and OTU-based rarefaction curves #
############################################################################

using DataFrames, StatsBase, ProgressMeter, SparseArrays

"""
    cclust_rarefaction_samplebased(sid_cclust_dict::Dict{Int,Int}, m)

One step of the community-cluster rarefaction curve. Draws `m` sample IDs from the keys of
`sid_cclust_dict` (sample ID => community cluster ID) and returns the number of distinct
community clusters among them.

Sampling uses `replace=true`, so a sample can be drawn more than once. Returns `0` if
`m` exceeds the number of samples, mirroring `rarefaction`, which returns all zeros when
the requested depth cannot be reached.
"""
function cclust_rarefaction_samplebased(sid_cclust_dict::Dict{Int,Int}, m)
    # The former guard referenced an undefined `abunds` (UndefVarError); return a
    # type-stable Int instead.
    length(sid_cclust_dict) < m && return 0
    rf_sids = sample(collect(keys(sid_cclust_dict)), m, replace=true)
    rf_ccs = Set{Int}(sid_cclust_dict[sid] for sid in rf_sids)
    return length(rf_ccs)
end

"""
    run_cclust_rarefaction_samplebased(sid_cclust_dict::Dict{Int,Int}; nrep=10, step=100_000)

Compute the rarefaction curve of observed community clusters. For each subset size
`1:step:length(sid_cclust_dict)`, `nrep` replicate draws are made with
`cclust_rarefaction_samplebased`.

`sid_cclust_dict` maps sample IDs to community cluster IDs. Returns a `DataFrame` with one
row per draw and columns `n_samples`, `n_ccs` and `rep`.
"""
function run_cclust_rarefaction_samplebased(sid_cclust_dict::Dict{Int,Int}; nrep=10, step=100_000)
    subset_sizes = 1:step:length(sid_cclust_dict)
    rows = []
    @showprogress for n_samp in subset_sizes
        for rep in 1:nrep
            n_found = cclust_rarefaction_samplebased(sid_cclust_dict, n_samp)
            push!(rows, (n_samples=n_samp, n_ccs=n_found, rep=rep))
        end
    end
    return DataFrame(rows)
end

"""
    total_otus_in_subsample(otu_matn, m; min_abund=0)

One step of the total-OTU rarefaction curve. Draws `m` distinct samples (rows of
`otu_matn.data`, `replace=false`) and counts the OTUs (columns) that have at least one
stored abundance `>= min_abund` among the drawn samples.

`otu_matn.data` is expected to be a `SparseMatrixCSC` with samples as rows and OTUs as
columns (presumably the matrix inside a keyed OTU table; TODO confirm). With the default
`min_abund=0`, every stored entry counts.
"""
function total_otus_in_subsample(otu_matn, m; min_abund=0)
    A = otu_matn.data
    rvs = rowvals(A)
    vals = nonzeros(A)
    sampled_rows = Set(sample(1:size(A, 1), m, replace=false))

    # Every column is checked directly; an empty column simply has an empty `nzrange`.
    # The former colptr comparison skipped the column *after* each empty column.
    is_observed(j) = any(k -> rvs[k] in sampled_rows && vals[k] >= min_abund, nzrange(A, j))
    return count(is_observed, axes(A, 2))
end

"""
    run_total_otus_rarefaction(otu_matn; nrep=10, step=100_000,
                               max_samples=size(otu_matn, 1), kws...)

Compute the rarefaction curve of total observed OTUs. For each subset size
`1:step:max_samples`, `nrep` replicate draws are made with `total_otus_in_subsample`
(`kws...` are forwarded to it).

Returns a `DataFrame` with one row per draw and columns `n_samples`, `n_otus_total` and
`rep`.
"""
function run_total_otus_rarefaction(otu_matn; nrep=10, step=100_000, max_samples=size(otu_matn, 1),
        kws...)
    subset_sizes = 1:step:max_samples
    rows = []
    @showprogress for n_samp in subset_sizes
        for rep in 1:nrep
            n_total = total_otus_in_subsample(otu_matn, n_samp; kws...)
            push!(rows, (n_samples=n_samp, n_otus_total=n_total, rep=rep))
        end
    end
    return DataFrame(rows)
end

#########################################
# Functions to rarefy sparse OTU tables #
#########################################

"""
    rarefaction(abunds::SparseVector, m)

Rarefy the sparse abundance vector `abunds` to sequencing depth `m`. Draws `m` indices
with probabilities proportional to the abundances and returns their counts as a
`SparseVector` of the same length.

If `sum(abunds) < m`, an all-zero `SparseVector` with the element and index type of
`abunds` is returned; previously this returned a dense `Vector`.

NOTE(review): StatsBase's weighted `sample(a, wv, n)` draws with replacement by default (a
multinomial draw). Classical rarefaction subsamples reads without replacement; confirm
that this is intended.
"""
function rarefaction(abunds::SparseVector, m)
    total = sum(abunds)
    total < m && return spzeros(eltype(abunds), eltype(abunds.nzind), length(abunds))
    freqs = Vector(abunds) ./ total
    rf_samp = sample(1:length(abunds), Weights(freqs), m)
    cmap = countmap(rf_samp)
    # Build the sparse result directly from the (index => count) map, without a dense
    # intermediate vector.
    return sparsevec(collect(keys(cmap)), collect(values(cmap)), length(abunds))
end

"""
    rarefaction(data::SparseMatrixCSC, m; verbose=true)

Rarefy every column (sample) of `data` to sequencing depth `m` with the `SparseVector`
method. All-zero rows (OTUs) and all-zero columns (samples) are then dropped from the
result.

Returns `(; data, filter_mask_obs, filter_mask_otus)`, where `filter_mask_obs` marks the
retained columns and `filter_mask_otus` the retained rows of the input. A progress bar is
always shown. With `verbose=true` the retained fraction of stored entries is printed.
"""
function rarefaction(data::SparseMatrixCSC{Tv,Ti}, m; verbose=true) where {Tv,Ti<:Integer}
    rarefied_cols = SparseVector{Tv,Ti}[]
    @showprogress for col in 1:size(data, 2)
        push!(rarefied_cols, rarefaction(data[:, col], m))
    end

    data_rf = build_csc(rarefied_cols, size(data, 1))

    # Keep only rows/columns that still have a positive total after rarefaction.
    nonempty_mask(d) = (sum(data_rf, dims=d) .> 0)[:]
    filter_mask_otus = nonempty_mask(2)
    filter_mask_obs = nonempty_mask(1)

    verbose && println("retained non-zero fraction: ", nnz(data_rf) / nnz(data))
    return (; data=data_rf[filter_mask_otus, filter_mask_obs], filter_mask_obs, filter_mask_otus)
end

"""
    rarefaction(datan::KeyedArray, m; kws...)

Rarefy the keyed OTU table `datan` to sequencing depth `m`. Keywords are forwarded to the
sparse-matrix method.

Returns `(; data, filter_mask_obs, filter_mask_otus)`, where `data` is a `KeyedArray` with
axes `sid` and `oid` restricted by the two masks.
"""
function rarefaction(datan::KeyedArray, m; kws...)
    sample_keys, otu_keys = axiskeys(datan)
    # NOTE(review): `aku` is not defined in this part of the file; presumably it returns
    # the unkeyed sparse matrix. The sparse method rarefies *columns*, and its
    # `filter_mask_obs` is a column mask, yet that mask is applied to the first-axis keys
    # (sample ids) below, and the result is labelled (sid, oid). Confirm which orientation
    # `aku` produces.
    rarefied, keep_obs, keep_otus = rarefaction(aku(datan), m; kws...)
    keyed = KeyedArray(rarefied, sid=sample_keys[keep_obs], oid=otu_keys[keep_otus])
    return (; data=keyed, filter_mask_obs=keep_obs, filter_mask_otus=keep_otus)
end

"""
    build_csc(vecs::AbstractVector{SparseVector{Tv,Ti}}, m::Integer)

Assemble the sparse vectors `vecs` as the columns of an `m × length(vecs)`
`SparseMatrixCSC{Tv,Ti}`, copying their stored entries directly into the CSC arrays
without intermediate dense storage.

Throws `DimensionMismatch` if any vector's length differs from `m`; otherwise an invalid
matrix would be produced silently.
"""
function build_csc(vecs::AbstractVector{SparseVector{Tv,Ti}}, m::Integer) where {Tv,Ti<:Integer}
    n = length(vecs)

    # Column pointers: cumulative stored-entry counts, 1-based.
    colptr = Vector{Ti}(undef, n + 1)
    colptr[1] = one(Ti)
    for (j, sv) in enumerate(vecs)
        length(sv) == m ||
            throw(DimensionMismatch("vector $j has length $(length(sv)), expected $m"))
        colptr[j + 1] = colptr[j] + nnz(sv)
    end

    total_nnz = colptr[end] - one(Ti)
    rowval = Vector{Ti}(undef, total_nnz)
    nzval = Vector{Tv}(undef, total_nnz)

    # Copy each vector's indices/values into its column's slot.
    for (j, sv) in enumerate(vecs)
        slot = colptr[j]:(colptr[j + 1] - one(Ti))
        rowval[slot] = sv.nzind
        nzval[slot] = sv.nzval
    end

    return SparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval)
end