Note: all code is written in Julia (https://julialang.org/) and should be compatible with Julia v1.9 and later.

Code author: Janko Tackmann (jtackm@github)

In [None]:
############################################################
# Functions to compute and filter latitudinal associations #
############################################################

using MultipleTesting, SparseArrays, Statistics, FlashWeave, ProgressMeter

"""
    _populate_caches!(val_vec, group_vec, sum_cache, prev_cache)

Generic fallback: accumulate per-group sums and prevalences of `val_vec`.

For every index `i` of `val_vec`, `val_vec[i]` is added to `sum_cache[group_vec[i]]`.
`prev_cache[group_vec[i]]` is incremented only when `val_vec[i]` is non-zero, so that
prevalence means "number of non-zero observations", consistent with the sparse method
of this function. Both caches are accumulated into, not reset; returns `nothing`.
"""
function _populate_caches!(val_vec, group_vec, sum_cache, prev_cache)
    for i in eachindex(val_vec)
        val = val_vec[i]
        group = group_vec[i]
        sum_cache[group] += val

        # Zeros contribute to the sum (trivially) but are not observations
        if !iszero(val)
            prev_cache[group] += 1
        end
    end

    return nothing
end

"""
    _populate_caches_logged!(val_vec, group_vec, sum_cache, prev_cache, group_count_cache)

Accumulate per-group sums of log10-transformed values; optimized for (views of) sparse
vectors.

A pseudo-count equal to one tenth of the smallest non-zero value of `val_vec` is added
to every value before taking `log10`. Non-zero entries are visited directly and counted
in `prev_cache`; the contribution of zero entries is added per group afterwards, using
`group_count_cache[g] - prev_cache[g]` as the number of zeros in group `g`.

Explicitly stored zeros are treated like any other zero (they are neither counted as
observations nor used to derive the pseudo-count). `sum_cache` and `prev_cache` are
accumulated into, so callers are expected to zero them first. Returns `nothing`.

# Throws
- `ArgumentError` if `val_vec` contains no non-zero value (no pseudo-count can be derived).
"""
function _populate_caches_logged!(val_vec::Union{SparseVector,SubArray{<:AbstractFloat, 1, <:SparseMatrixCSC}}, group_vec, sum_cache,
    prev_cache, group_count_cache)
    inds, vals = rowvals(val_vec), nonzeros(val_vec)

    any(!iszero, vals) ||
        throw(ArgumentError("val_vec must contain at least one non-zero value"))

    # Derive the pseudo-count from true non-zeros only; a stored zero would make it 0
    # and turn every zero contribution into log10(0) = -Inf
    pcount = minimum(Iterators.filter(!iszero, vals)) / 10

    for (i, val) in zip(inds, vals)
        # Stored zeros are handled together with the structural zeros below
        iszero(val) && continue
        group = group_vec[i]
        sum_cache[group] += log10(val + pcount)
        prev_cache[group] += 1
    end

    # Add pseudo-counts for zero entries by group
    pcount_logged = log10(pcount)
    for i in eachindex(sum_cache)
        # We know the number of zeros by subtracting the observed non-zeros (captured in
        # prev_cache) from total group counts
        n_zeros = group_count_cache[i] - prev_cache[i]
        sum_cache[i] += pcount_logged * n_zeros
    end

    return nothing
end

"""
    _populate_caches!(val_vec::Union{SparseVector,SubArray}, group_vec, sum_cache, prev_cache)

Accumulate per-group sums and prevalences; optimized for (views of) sparse vectors.

Only the stored entries of `val_vec` are visited. Each non-zero stored value is added to
`sum_cache[group_vec[i]]` and increments `prev_cache[group_vec[i]]`, where `i` is the
entry's row index. Explicitly stored zeros are skipped so that prevalence counts true
non-zero observations only. Both caches are accumulated into; returns `nothing`.
"""
function _populate_caches!(val_vec::Union{SparseVector,SubArray{<:AbstractFloat, 1, <:SparseMatrixCSC}}, group_vec, sum_cache,
    prev_cache)
    inds, vals = rowvals(val_vec), nonzeros(val_vec)

    for (i, val) in zip(inds, vals)
        # A stored zero is a structural entry, not an observation
        iszero(val) && continue
        group = group_vec[i]
        sum_cache[group] += val
        prev_cache[group] += 1
    end

    return nothing
end

"""
    groupby_mean!(val_vec, group_vec, caches::NamedTuple; log_vals=false)

Write the per-group mean of `val_vec` (grouped by `group_vec`) into `caches.mean_cache`.

`caches.sum_cache` and `caches.prev_cache` are reset and refilled as a side effect.
Assumes `caches.group_count_cache` was pre-populated with counts per group. With
`log_vals=true` the sums are built by `_populate_caches_logged!`, otherwise by
`_populate_caches!`. Returns `nothing`.
"""
function groupby_mean!(val_vec, group_vec, caches::NamedTuple; log_vals=false)
    (; mean_cache, sum_cache, group_count_cache, prev_cache) = caches

    # Reset the reusable buffers before accumulating
    fill!(mean_cache, 0.0)
    fill!(sum_cache, 0.0)
    fill!(prev_cache, 0)

    log_vals ?
        _populate_caches_logged!(val_vec, group_vec, sum_cache, prev_cache, group_count_cache) :
        _populate_caches!(val_vec, group_vec, sum_cache, prev_cache)

    # Mean = group sum divided by the total number of samples in the group
    mean_cache .= sum_cache ./ group_count_cache

    return nothing
end

"""
    _fast_cor_binned(val_vec, v2_binned, bins_srt, caches; log_means=false, log_vals=false)

Correlate the per-bin means of `val_vec` with the bin values in `bins_srt`.

Assumes the values of `v2_binned` are equal to the index range of the vectors in
`caches`; otherwise, remap the bins to this range first. `log_means` and `log_vals`
must not both be set. With `log_means=true` the bin means are log10-transformed; if any
mean is zero, a pseudo-count of one tenth of the smallest non-zero mean is added first.

Returns a named tuple `(cor, n_bins_obs, prev)`: the correlation, the number of bins
with a non-zero entry in `caches.prev_cache`, and the sum of `caches.prev_cache`. An
all-zero `val_vec` yields `(cor=NaN, n_bins_obs=0, prev=0)` without touching `caches`.
"""
function _fast_cor_binned(val_vec, v2_binned::Vector, bins_srt, caches::NamedTuple; log_means=false, log_vals=false)
    @assert !(log_means && log_vals) "'log_means' and 'log_vals' shouldn't be used together"

    # Guard clause: nothing to correlate when every value is zero
    all(iszero, val_vec) && return (; cor=NaN, n_bins_obs=0, prev=0)

    groupby_mean!(val_vec, v2_binned, caches; log_vals)
    bin_means, bin_prev = caches.mean_cache, caches.prev_cache
    n_bins_obs = count(!iszero, bin_prev)

    if log_means
        # Only shift the means when a zero would otherwise hit log10
        pcount = any(iszero, bin_means) ? minimum(filter(!iszero, bin_means)) / 10 : 0.0
        bin_means .= log10.(bin_means .+ pcount)
    end

    stat = cor(bin_means, bins_srt)
    prev = sum(bin_prev)

    return (; cor=stat, n_bins_obs, prev)
end

"""
    compute_latitudinal_abundance_cors(otu_matn, lats_binned; prev_min=250,
        bins_obs_min_factor=-1, log_means=false, do_FDR=true, log_vals=false)

For every column (OTU) of `otu_matn`, correlate its per-bin means with the sorted
latitude bin values, then separate reliable from unreliable OTUs and compute p-values
for the reliable ones.

The first axis keys of `otu_matn` and `lats_binned` must be identical, and all binned
latitudes must be non-negative.

# Keywords
- `prev_min`: OTUs with `prev` below this value are considered unreliable.
- `bins_obs_min_factor`: factor defining the number of bins an OTU must have been
  observed in (e.g. factor 2 = observed in 50% of bins); `-1` disables this filter.
- `log_means`, `log_vals`: forwarded to `_fast_cor_binned`.
- `do_FDR`: when true, add a Benjamini-Hochberg adjusted `pval_adj` column.

# Returns
Named tuple `(test_df, unrel_df)`: the reliable OTUs with columns `oid`, `cor`, the
p-value column(s), `n_obs`, `prev`, `n_bins_obs`; and the filtered-out OTUs.
"""
function compute_latitudinal_abundance_cors(otu_matn::KeyedArray, lats_binned::KeyedArray; prev_min=250, bins_obs_min_factor=-1, log_means=false, do_FDR=true, log_vals=false)
    @assert axiskeys(otu_matn, 1) == axiskeys(lats_binned, 1)
    @assert all(>=(0), lats_binned) "all latitudes must be positive / absolute"
    n_samp = size(otu_matn, 1)

    # Sorted unique bin values; a bin's position in this vector is its cache index
    bins_srt = sort(unique(lats_binned))
    n_bins = length(bins_srt)
    n_bins_obs_min = bins_obs_min_factor == -1 ? 0 : div(n_bins, bins_obs_min_factor)

    bin_to_ind = Dict(bin => ind for (ind, bin) in enumerate(bins_srt))
    lats_binned_indmap = [bin_to_ind[lat] for lat in Vector(lats_binned)]

    # Buffers reused across all OTUs
    caches = (
        abunds_cache=zeros(Float64, n_samp),
        mean_cache=zeros(Float64, n_bins),
        sum_cache=zeros(Float64, n_bins),
        group_count_cache=zeros(Int, n_bins),
        prev_cache=zeros(Int, n_bins),
    )

    # Number of samples falling into each bin
    for ind in lats_binned_indmap
        caches.group_count_cache[ind] += 1
    end

    otu_mat = AxisKeys.keyless_unname(otu_matn)
    cor_res = @showprogress map(enumerate(axiskeys(otu_matn, 2))) do (col, oid)
        col_res = _fast_cor_binned(view(otu_mat, :, col), lats_binned_indmap, bins_srt, caches; log_means, log_vals)
        (; oid, col_res..., n_obs=n_bins)
    end

    all_df = DataFrame(cor_res)

    # Split off unreliable OTUs (defined by the filter criteria)
    is_unreliable = (all_df.prev .< prev_min) .| (all_df.n_bins_obs .< n_bins_obs_min)
    unrel_df = all_df[is_unreliable, :]
    test_df = all_df[.!is_unreliable, :]

    # P-values are only computed for the reliable OTUs
    test_df[!, :pval] = [FlashWeave.fz_pval(row.cor, row.n_obs, 0) for row in eachrow(test_df)]

    if do_FDR
        test_df[!, :pval_adj] = MultipleTesting.adjust(test_df.pval, BenjaminiHochberg())
    end

    return (test_df=select(test_df, :oid, :cor, r"pval", :n_obs, :prev, :n_bins_obs), unrel_df=unrel_df)
end

"""Bin elements in vector using standard Histogram parameters, return
bin for each element (same order, lower bound)
"""
function make_bins(x::AbstractVector; n_bins=nothing, bin_kws...)
    # assure interface compatibility with make_variable_bins:
    # the bin number may arrive either as `n_bins` or as `nbins` inside `bin_kws`
    if isnothing(n_bins) && !haskey(bin_kws, :nbins)
        hist = fit(Histogram, x; bin_kws...)
    else
        nbins_eff = isnothing(n_bins) ? bin_kws[:nbins] : n_bins
        hist = fit(Histogram, x; nbins=nbins_eff, bin_kws...)
    end

    edges = first(hist.edges)

    # The first edge strictly above xi is its bin's upper bound; one step back is the lower bound
    return [edges[findfirst(>(xi), edges) - 1] for xi in x]
end

"""
    compute_latitudinal_abundance_cors_by_env(otu_matn_mp_lat, env_vec, lats_vec, n_bins,
        prev_min, bins_obs_min_factor, log_vals, log_means)

Run `compute_latitudinal_abundance_cors` separately for each environment in `env_vec`
(skipping `"unknown"`), restricted to the samples of that environment and to the OTUs
whose preferred habitat equals it, then concatenate the per-environment results and
apply a single Benjamini-Hochberg correction across all of them.

# Arguments
- `otu_matn_mp_lat`: keyed samples × OTUs abundance matrix.
- `env_vec`: environment label per sample (same order as the matrix rows).
- `lats_vec`: latitude per sample; `compute_latitudinal_abundance_cors` asserts that
  binned latitudes are non-negative, so pass absolute (mirrored) latitudes.
- `n_bins`: number of latitude bins, forwarded to `make_bins`.
- `prev_min`, `bins_obs_min_factor`, `log_vals`, `log_means`: forwarded to
  `compute_latitudinal_abundance_cors`.

NOTE(review): `pref_habs` is read from global scope, not passed in; its first axis keys
must match the OTU axis keys of `otu_matn_mp_lat`.

# Returns
The concatenated table of reliable OTUs with added columns `env`, `pval_adj` and
`is_sig` (`pval_adj < 0.05`). Progress and summary counts are printed along the way.
"""
function compute_latitudinal_abundance_cors_by_env(otu_matn_mp_lat, env_vec, lats_vec, n_bins, prev_min, bins_obs_min_factor,
    log_vals, log_means)
    dfs = []
    unrel_dfs = []

    for env in unique(env_vec)
        env == "unknown" && continue
        println(env)
        env_mask = env_vec .== env
        # Use the latitudes handed in by the caller rather than a global variable
        lats_vec_env = lats_vec[env_mask]
        otu_matn_mp_lat_env = otu_matn_mp_lat[findall(env_mask), :]
        lats_binned_env = make_bins(lats_vec_env, n_bins=n_bins)

        # Remove OTUs whose preferred environment is not the current one
        @assert axiskeys(pref_habs, 1) == axiskeys(otu_matn_mp_lat_env, 2)
        oid_mask = findall(pref_habs .== env)
        otu_matn_mp_lat_env = otu_matn_mp_lat_env[:, oid_mask]

        # Compute correlations; FDR is deferred until all environments are pooled
        curr_cor_df, curr_unrel_df = compute_latitudinal_abundance_cors(otu_matn_mp_lat_env, lats_binned_env; prev_min, bins_obs_min_factor, log_vals, log_means, do_FDR=false)
        curr_cor_df[!, :env] .= env
        curr_unrel_df[!, :env] .= env
        println("reliable / unreliable OTUs: ", nrow.([curr_cor_df, curr_unrel_df]))
        println("significant OTUs (fraction):")
        sig_mask = curr_cor_df.pval .< 0.05

        for (desc, mask) in [("total", trues(nrow(curr_cor_df))), ("positive", curr_cor_df.cor .> 0), ("negative", curr_cor_df.cor .< 0)]
            println("\t", desc, ": ", sum(sig_mask[mask]), " ($(mean(sig_mask .& mask)))")
        end
        println()

        push!(dfs, curr_cor_df)
        push!(unrel_dfs, curr_unrel_df)
    end

    cor_df = vcat(dfs...)

    # Do final FDR across all environments at once
    cor_df[!, :pval_adj] = MultipleTesting.adjust(cor_df.pval, BenjaminiHochberg())
    cor_df[!, :is_sig] = cor_df.pval_adj .< 0.05

    # Unreliable OTUs are only pooled for the summary printed below
    unrel_df = vcat(unrel_dfs...)

    @show nrow.([cor_df, unrel_df])

    return cor_df
end