Note: all code is written in Julia (https://julialang.org/) and should be compatible with Julia v1.9 and later.

Code author: Janko Tackmann (jtackm@github)

In [None]:
###########################################################################
# Functions to aggregate abundances and compute habitat generalism scores #
###########################################################################

using Statistics, StatsBase, SparseArrays, AxisArrays, ProgressMeter

"""
    make_sgabund_matrix(sg_oabund_dict, sg_header, oid_header)

Assemble a dense `Float32` matrix from the nested dictionary `sg_oabund_dict`.

Rows follow the order of `sg_header`, columns the order of `oid_header`. For every `sg`
in `sg_header`, the `(oid, abund)` pairs of `sg_oabund_dict[sg]` are written into the
matching cells. Pairs whose `oid` does not occur in `oid_header` are skipped, and cells
that never receive a value stay zero.

Returns an `AxisArray` wrapping the matrix with `sg_header` and `oid_header` as axes.
"""
function make_sgabund_matrix(sg_oabund_dict, sg_header, oid_header)
    abund_mat = zeros(Float32, length(sg_header), length(oid_header))
    # Identifier -> column position in `oid_header`.
    col_of = Dict(oid => col for (col, oid) in enumerate(oid_header))

    @showprogress for (row, sg) in enumerate(sg_header)
        for (oid, abund) in sg_oabund_dict[sg]
            col = get(col_of, oid, nothing)
            # Identifiers outside `oid_header` have no column and are dropped.
            if col !== nothing
                abund_mat[row, col] = abund
            end
        end
    end

    return AxisArray(abund_mat, sg_header, oid_header)
end

"""
    sum_abunds_per_otu_and_sgroup(otu_matn_T::AxisArray{T}, sgroup_map::Dict) where T<:Real

Sum the stored entries of each row over all columns that belong to the same group.

# Arguments
- `otu_matn_T`: `AxisArray` whose `.data` is traversed with `nonzeros`, `rowvals` and
  `nzrange`, so it has to support the `SparseArrays` column-storage interface
  (presumably OTUs as rows and samples as columns -- confirm against the caller).
- `sgroup_map`: maps a group key to a collection of column identifiers. Identifiers that
  do not occur on the second axis of `otu_matn_T` are ignored. If an identifier is listed
  under several groups, the group visited last while iterating `sgroup_map` wins.

# Returns
A nested dictionary `group key => (row identifier => summed value)`. Only explicitly
stored entries contribute, and only groups owning at least one column of the matrix
appear in the result.

# Throws
- `AssertionError` if any column of `otu_matn_T` is not assigned to a group.
"""
function sum_abunds_per_otu_and_sgroup(otu_matn_T::AxisArray{T}, sgroup_map::Dict) where T<:Real
    oid_names, sid_names = axisvalues(otu_matn_T)
    n_samps = size(otu_matn_T, 2)

    # Fix an ordering of the group keys so that groups can be addressed by integer index.
    sg_names = collect(keys(sgroup_map))
    sid_pos = Dict(sid => i for (i, sid) in enumerate(sid_names))

    # Group index of every column; 0 marks a column that no group has claimed.
    sg_idx = zeros(Int, n_samps)
    for (sg_i, sg) in enumerate(sg_names)
        for sid in sgroup_map[sg]
            s_i = get(sid_pos, sid, nothing)
            if !isnothing(s_i)
                sg_idx[s_i] = sg_i
            end
        end
    end

    unassigned = findall(iszero, sg_idx)
    @assert isempty(unassigned) string(
        length(unassigned), " sample(s) are not assigned to any group in `sgroup_map`, ",
        "e.g. ", sid_names[first(unassigned)],
    )

    # group index => (row index => running sum)
    accum_dict = Dict{Int,Dict{Int,T}}()
    A = otu_matn_T.data
    nzv = nonzeros(A)
    rvs = rowvals(A)

    @showprogress for s_i in 1:n_samps
        sub_acc_dict = get!(() -> Dict{Int,T}(), accum_dict, sg_idx[s_i])
        # Walk only the stored entries of column `s_i`.
        for j in nzrange(A, s_i)
            oid = rvs[j]
            abund = nzv[j]
            sub_acc_dict[oid] = haskey(sub_acc_dict, oid) ? sub_acc_dict[oid] + abund : abund
        end
    end

    # Translate integer indices back into the original group keys and row identifiers.
    return Dict(
        sg_names[sg_i] => Dict(oid_names[oid_i] => abund for (oid_i, abund) in sub_d)
        for (sg_i, sub_d) in accum_dict
    )
end

"""
    convert_to_mean_abunds(accum_dict, sgroup_map; nz=false, prev_accum_dict=nothing)

Turn per-group sums into per-group means.

# Arguments
- `accum_dict`: nested dictionary `group => (identifier => summed value)`.
- `sgroup_map`: `length(sgroup_map[group])` is the divisor when `nz` is `false`.

# Keywords
- `nz`: when `true`, each sum is divided by `prev_accum_dict[group][identifier]` instead
  of the group size.
- `prev_accum_dict`: nested dictionary of divisors with the same layout as `accum_dict`;
  required when `nz=true`.

# Returns
A `Dict{String,Dict{String,Float64}}` with the same keys as `accum_dict`, so group keys
and identifiers must be convertible to `String`.

# Throws
- `AssertionError` if `nz=true` but no `prev_accum_dict` was supplied.
- `KeyError` if a group or identifier is missing from the dictionary providing the divisor.
"""
function convert_to_mean_abunds(accum_dict, sgroup_map; nz=false, prev_accum_dict=nothing)
    @assert !nz || !isnothing(prev_accum_dict) "provide a prevalence count dict when choosing 'nz=true'"
    mean_accum_dict = Dict{String,Dict{String,Float64}}()
    @showprogress for (sg, sub_acc_dict) in accum_dict
        if nz
            # Divide by the per-identifier count supplied by the caller.
            sub_prev_dict = prev_accum_dict[sg]
            mean_accum_dict[sg] = Dict(
                oid => abund / sub_prev_dict[oid] for (oid, abund) in sub_acc_dict
            )
        else
            # Divide by the number of members listed for the group.
            sg_size = length(sgroup_map[sg])
            mean_accum_dict[sg] = Dict(oid => abund / sg_size for (oid, abund) in sub_acc_dict)
        end
    end
    return mean_accum_dict
end

"""
    mean_abunds_per_habitat(A_sg_mean_abunds::AxisArray, env_map::AbstractVector)

Average the rows of `A_sg_mean_abunds` within each habitat.

`env_map[i]` is the habitat label of row `i`. For every distinct label, taken in order
of first appearance, the rows carrying that label are averaged column-wise.

# Returns
An `AxisArray` with one row per distinct label of `env_map` (first axis) and the second
axis of `A_sg_mean_abunds` carried over unchanged.

# Throws
- `AssertionError` if the number of rows differs from `length(env_map)`.
"""
function mean_abunds_per_habitat(A_sg_mean_abunds::AxisArray, env_map::AbstractVector)
    @assert size(A_sg_mean_abunds, 1) == length(env_map)
    data = A_sg_mean_abunds.data
    envs = unique(env_map)
    # One 1 x n_columns row of column means per habitat; the comprehension keeps the
    # container concretely typed.
    env_means = [mean(data[env_map .== env, :], dims=1) for env in envs]
    # `reduce(vcat, ...)` stacks the rows without splatting a collection of unknown length.
    return AxisArray(reduce(vcat, env_means), envs, axisvalues(A_sg_mean_abunds)[2])
end

"""
    env_entropy_generalism(mean_abunds::AbstractVector)

Return the `entropy` of `mean_abunds` after rescaling its entries to sum to one.

There is no guard against a zero sum: the entries are divided by `sum(mean_abunds)` as is.
"""
function env_entropy_generalism(mean_abunds::AbstractVector)
    total = sum(mean_abunds)
    proportions = mean_abunds ./ total
    return entropy(proportions)
end

"""
    env_entropy_generalism(A_sg_mean_abunds::AxisArray, env_map::AbstractVector)

Compute one entropy-based score per column of `A_sg_mean_abunds`.

The rows are first averaged per habitat with `mean_abunds_per_habitat`, using `env_map`
as the row labels. The single-vector method of `env_entropy_generalism` is then applied
to every column of the habitat-level matrix.

Returns a named tuple with
- `gen_scores`: `AxisArray` vector of scores, indexed by the second axis of the input;
- `env_abunds`: the habitat-level `AxisArray` the scores were derived from.
"""
function env_entropy_generalism(A_sg_mean_abunds::AxisArray, env_map::AbstractVector)
    @assert size(A_sg_mean_abunds, 1) == length(env_map)
    env_abunds = mean_abunds_per_habitat(A_sg_mean_abunds, env_map)
    # `mapslices` over dims=1 yields a 1 x n_columns matrix; flatten it to a vector.
    score_row = mapslices(env_entropy_generalism, env_abunds.data; dims=1)
    col_labels = axisvalues(env_abunds)[2]
    gen_scores = AxisArray(vec(score_row), col_labels)
    return (; gen_scores, env_abunds)
end