In [None]:
# Notebook environment setup: activate a temporary project, register the local
# Mycelia checkout, then install and import every package listed in `pkgs`.

# if hit plotting library issues, try resetting LD path for julia
# can set in ~/.local/share/jupyter/kernels/
# Checked with an explicit error instead of `@assert`, because assertions may be
# disabled at some optimization levels and this is an environment precondition.
if !isempty(get(ENV, "LD_LIBRARY_PATH", ""))
    error("LD_LIBRARY_PATH must be unset or empty, got: $(ENV["LD_LIBRARY_PATH"])")
end
import Pkg
Pkg.activate(;temp=true)
Pkg.add("Revise")
import Revise

# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
Pkg.develop(path="$(homedir())/workspace/Mycelia")
import Mycelia

# Packages that are installed into the temporary environment and then imported
# (qualified access only, e.g. `DataFrames.groupby`).
pkgs = String[
    "DataFrames",
    "FASTX",
    "XAM",
    "uCSV",
    "CodecZlib",
    "ProgressMeter",
    "StatsBase",
    "Statistics",
    "CSV",
    "Random",
    "Distributions",
    "Plots",
    "OrderedCollections",
    "StatsPlots",
    "Colors",
    "Clustering"
]
Pkg.add(pkgs)
for pkg in pkgs
    # Build the import expression from a Symbol rather than parsing source text.
    @eval import $(Symbol(pkg))
end

In [None]:
# Project root: the parent of the current working directory
# (presumably the notebook is run from a subdirectory of the project — confirm).
basedir = dirname(pwd())
# Input data directory under the project root.
data_dir = joinpath(basedir, "data")

In [None]:
# Directory under data_dir that holds the locus C-strain files; mkpath creates
# it (and any missing parents) if needed and returns the path.
locus_c_strain_directory = mkpath(joinpath(data_dir, "locus-c-strains"))

In [None]:
# Full paths of the entries in the C-strain directory that match Mycelia's FASTA regex.
in_fastas = filter(readdir(locus_c_strain_directory; join = true)) do path
    occursin(Mycelia.FASTA_REGEX, path)
end
# locus_c_strain_fasta = joinpath(data_dir, "locus-c-strains.fna")
# FastANI output table under <basedir>/results.
outfile = joinpath(basedir, "results", "20240702.c-strain-ani-analysis.txt")
# Text file listing the FASTA paths, one per line (see the commented-out writer below).
fasta_list_file = joinpath(data_dir, "locus-c-strain-file-list.txt")
# open(fasta_list_file, "w") do io
#     for f in in_fastas
#         println(io, f)
#     end
# end
# readlines(fasta_list_file)

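Here we cluster the C-strains by FastANI average nucleotide identity (cutting the tree at 99.5% ANI) so that the urine read mapping can be summarized per cluster instead of per individual strain.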

In [None]:
# defaults to using all cores in the system
# Mycelia.fastani_list(query_list = fasta_list_file, reference_list = fasta_list_file, threads=8, outfile = outfile)

# Read the FastANI table previously written to `outfile`.
fastani_results = Mycelia.read_fastani(outfile)

# Derive a strain label for both sides of every comparison: the lowercased
# "c" + six digits identifier found (case-insensitively) in the file name.
for (path_column, strain_column) in ("query" => "query_strain", "reference" => "reference_strain")
    fastani_results[!, strain_column] = map(fastani_results[!, path_column]) do path
        lowercase(match(r"(c\d{6})"i, basename(path)).captures[1])
    end
end

# Every strain seen as query or reference, in sorted order.
unique_strains = sort!(union(fastani_results[!, "query_strain"], fastani_results[!, "reference_strain"]))

# strain label => its position in `unique_strains` (used as matrix row/column index)
strain_to_index_map = Dict(strain => index for (index, strain) in pairs(unique_strains))

# Square strain-by-strain matrix. It first holds mean % identity and is converted
# to a distance in place further down. Inf marks pairs with no FastANI record.
ani_distance_matrix = fill(Inf, length(unique_strains), length(unique_strains))
for pair_rows in DataFrames.groupby(fastani_results, ["query_strain", "reference_strain"])
    query_index = strain_to_index_map[pair_rows[1, "query_strain"]]
    reference_index = strain_to_index_map[pair_rows[1, "reference_strain"]]
    # several rows can share a strain pair; store their mean identity
    ani_distance_matrix[query_index, reference_index] = Statistics.mean(pair_rows[!, "%_identity"])
end

# average across the diagonals to ensure they are symmetric
# NOTE(review): if only one direction of a pair was reported, the Inf on the other
# side makes the average Inf, so the pair ends up at distance 1 — confirm intended.
for i in axes(ani_distance_matrix, 1), j in (i + 1):size(ani_distance_matrix, 1)
    symmetric_identity = (ani_distance_matrix[i, j] + ani_distance_matrix[j, i]) / 2
    ani_distance_matrix[i, j] = symmetric_identity
    ani_distance_matrix[j, i] = symmetric_identity
end
ani_distance_matrix

# convert % ani into a distance
for i in eachindex(ani_distance_matrix)
    distance = 1 - (ani_distance_matrix[i] / 100)
    # an Inf identity (no record) becomes -Inf here; map it to the maximum distance 1
    ani_distance_matrix[i] = distance == -Inf ? one(distance) : distance
end
ani_distance_matrix

# ?Clustering.hclust
# Average-linkage hierarchical clustering of the strains on the ANI distances.
hclust_result = Clustering.hclust(ani_distance_matrix; linkage = :average)

# Cut the tree at height 0.005 (1 - 99.5/100 on the distance scale).
clusters_995 = Clustering.cutree(hclust_result; h = 0.005)

# cluster ID => member strains, appended in `unique_strains` order
cluster_dict = Dict{Int, Vector{String}}()
for (strain, cluster_id) in zip(unique_strains, clusters_995)
    push!(get!(cluster_dict, cluster_id, String[]), strain)
end

# Reorder by cluster ID (`sort` on a Dict is presumably the method that
# OrderedCollections provides — confirm against the installed version).
cluster_dict = sort(cluster_dict)

# Inverse lookup: strain label => cluster ID
cstrain_to_cluster_map = Dict{String, Int}(
    member => cluster_id for (cluster_id, members) in cluster_dict for member in members
)
cstrain_to_cluster_map

Now we load the urine data and reanalyze it, reporting strain-cluster identities rather than individual strains.

In [None]:
# Identifier of the sequencing run; used as the run's directory name under data_dir.
RUN_ID = "r64342e_20240621_140056"

In [None]:
# First entry of the run's 1_A01 directory whose path contains ".run.metadata.xml".
xml = first(
    filter(
        contains(r"\.run\.metadata\.xml"),
        readdir(joinpath(data_dir, RUN_ID, "1_A01"); join = true),
    ),
)

In [None]:
# Sample/barcode table extracted from the run metadata XML.
sample_to_barcode_table = Mycelia.extract_pacbiosample_information(xml)
# Keep only the rows whose BioSampleName contains "urine" (case-insensitive).
sample_to_barcode_table = sample_to_barcode_table[
    [occursin(r"urine"i, sample_name) for sample_name in sample_to_barcode_table[!, "BioSampleName"]],
    :,
]

In [None]:
# Directories of the run's 1_A01 folder that are named like a barcode ("bc" + digits)
# and whose name is one of the BarcodeName values kept in sample_to_barcode_table.
# The barcode Set is built once up front instead of once per directory entry.
barcode_directories = let urine_barcodes = Set(sample_to_barcode_table[!, "BarcodeName"])
    filter(readdir(joinpath(data_dir, RUN_ID, "1_A01"), join=true)) do path
        name = basename(path)
        occursin(r"^bc\d+", name) && (name in urine_barcodes)
    end
end

In [None]:
# map to blast NT
# Database name, the directory holding local BLAST databases, and the combined path.
blast_db = "nt"
blast_dbs_dir = joinpath(homedir(), "workspace", "blastdb")
blast_db_path = joinpath(blast_dbs_dir, blast_db)

In [None]:
# Path of the "locus-c-strains.fna" FASTA under data_dir; the same file name is
# used below to pick out the matching alignment files.
locus_c_strain_fasta = joinpath(data_dir, "locus-c-strains.fna")

In [None]:
# Map every barcode directory name to its alignment file against locus-c-strains.fna,
# then attach that path to the sample table as the "xam" column.
barcode_to_xam = Dict{String, String}()
for barcode_directory in barcode_directories
    barcode = basename(barcode_directory)
    # Named `candidate_xams` rather than `xams`: in a notebook's soft scope, assigning
    # `xams` here would overwrite the notebook-level `xams` once that cell has run.
    candidate_xams = filter(
        x -> occursin(Mycelia.XAM_REGEX, x) && occursin("locus-c-strains.fna", x),
        readdir(barcode_directory, join=true),
    )
    # Fail with the offending directory instead of an opaque error from `first`.
    if isempty(candidate_xams)
        error("no file matching Mycelia.XAM_REGEX and \"locus-c-strains.fna\" in $(barcode_directory)")
    end
    # If several files match, the first one in directory order is used.
    # println("$(barcode)\t$(xam)")
    barcode_to_xam[barcode] = first(candidate_xams)
end
barcode_to_xam
sample_to_barcode_table[!, "xam"] = [barcode_to_xam[barcode] for barcode in sample_to_barcode_table[!, "BarcodeName"]]
sample_to_barcode_table

In [None]:
# Alignment file paths, one per row of the sample table, in table order.
xams = sample_to_barcode_table[!, "xam"]

In [None]:
# taxa_level = "species"
# when using cnumbers
# file_to_taxa_relative_abundances = OrderedCollections.OrderedDict{String, Dict{Union{Missing, String}, Float64}}()
# when using cluster IDs
# alignment file => (strain cluster ID => normalized count), in `xams` order
file_to_taxa_relative_abundances = OrderedCollections.OrderedDict{String, Dict{Union{Missing, Int}, Float64}}()
ProgressMeter.@showprogress for xam in xams
    @time record_table = Mycelia.parse_xam_to_mapped_records_table(xam)
    # keep only the rows flagged as primary
    primary_records = record_table[record_table[!, "isprimary"], :]
    # strain label = lowercased "c" + six digits found in the reference name
    primary_records[!, "top_hit_strain"] = map(primary_records[!, "reference"]) do reference
        lowercase(match(r"(c\d{6})"i, reference).captures[1])
    end
    # NEW! HERE WE CONVERT FROM STRAIN TO CLUSTER
    primary_records[!, "strain_cluster"] = [
        cstrain_to_cluster_map[strain] for strain in primary_records[!, "top_hit_strain"]
    ]
    cluster_counts = StatsBase.countmap(primary_records[!, "strain_cluster"])
    file_to_taxa_relative_abundances[xam] = Mycelia.normalize_countmap(cluster_counts)
end
file_to_taxa_relative_abundances

In [None]:
# Sorted list of every taxon (cluster ID) observed in at least one sample.
unique_sorted_taxa = unique!(
    sort!(collect(reduce(union, keys.(values(file_to_taxa_relative_abundances))))),
)

In [None]:
# Number of samples: one entry per alignment file in the abundance dictionary.
n_samples = length(file_to_taxa_relative_abundances)

In [None]:
# Taxa-by-sample matrix of relative abundances; taxa absent from a sample stay 0.
abundance_matrix = zeros(length(unique_sorted_taxa), n_samples)
# taxon => row index, following the order of `unique_sorted_taxa`
taxa_names_to_indices = Dict(taxon => row for (row, taxon) in pairs(unique_sorted_taxa))
# Columns follow the insertion order of the ordered dictionary.
for (column, abundances) in enumerate(values(file_to_taxa_relative_abundances))
    # @show column, sample
    for (taxon, relative_abundance) in abundances
        abundance_matrix[taxa_names_to_indices[taxon], column] = relative_abundance
    end
end
abundance_matrix

In [None]:
# Sort permutation computed along dimension 1 of the matrix, i.e. within each
# sample column (rows are taxa, columns are samples).
abundance_sort_perm = sortperm(abundance_matrix, dims=1)

In [None]:
# Display the abundances reordered by the permutation above (inspection only;
# the result is not stored).
abundance_matrix[abundance_sort_perm]

In [None]:
# alignment file path => BioSampleName, used to label the samples in the plot
file_to_identifier = Dict(
    row.xam => row.BioSampleName for row in DataFrames.eachrow(sample_to_barcode_table)
)

In [None]:
# One distinguishable color per taxon. White and black are given as seed colors
# and, with dropseed=true, are left out of the returned palette.
colorscheme = Colors.distinguishable_colors(
    length(unique_sorted_taxa),
    [Colors.RGB(1, 1, 1), Colors.RGB(0, 0, 0)];
    dropseed = true,
)

In [None]:
# ## BASE - INCLUDES EVERYTHING
# # Find the sort permutation of the row means vector
# sort_perm = sortperm(vec(Statistics.mean(abundance_matrix, dims=2)))
# file_labels = [file_to_identifier[k] for k in keys(file_to_taxa_relative_abundances)]
# StatsPlots.groupedbar(
#     abundance_matrix[sort_perm, :]',
#     bar_position = :stack,
#     bar_width=0.7, 
#     # label = permutedims(unique_sorted_taxa[sort_perm]),
#     label = false,
#     xticks = (1:size(abundance_matrix, 2), file_labels), 
#     xrotation = 45,
#     ylabel = "proportion of reads", 
#     xlabel = "Subsampling proportion",
#     title = "Species relative abundance",
#     legend = :outertopright,
#     size = (1000, 500),
#     margins = 10StatsPlots.Plots.PlotMeasures.mm,
#     seriescolor = hcat(reverse(colorscheme)...)
# )

In [None]:
# Stacked bar chart of the strain clusters with the highest mean relative
# abundance, one bar per sample.
# Clamp to the number of taxa so the `end-(top_N-1):end` slices below stay in
# bounds when fewer than 10 clusters were observed.
top_N = min(10, length(unique_sorted_taxa))
# Find the sort permutation of the row means vector
# (ascending, so the most abundant taxa are the last `top_N` entries)
sort_perm = sortperm(vec(Statistics.mean(abundance_matrix, dims=2)))
# Sample names in the same order as the columns of `abundance_matrix`.
file_labels = [file_to_identifier[k] for k in keys(file_to_taxa_relative_abundances)]
StatsPlots.groupedbar(
    abundance_matrix[sort_perm, :]'[:, end-(top_N-1):end],
    bar_position = :stack,
    bar_width=0.7,
    label = permutedims(unique_sorted_taxa[sort_perm])[:, end-(top_N-1):end], 
    xticks = (1:size(abundance_matrix, 2), file_labels), 
    xrotation = 45,
    ylabel = "proportion of reads", 
    # the x ticks are the sample names, so label the axis accordingly
    xlabel = "Sample",
    title = "Strain relative abundance",
    legend = :outertopright,
    legendtitle = "cluster ID",
    size = (1000, 500),
    margins = 10StatsPlots.Plots.PlotMeasures.mm,
    # 1×n row of colors built without splatting; same slice as the data and labels
    seriescolor = permutedims(reverse(colorscheme))[:, end-(top_N-1):end]
)

In [None]:
# Write the strain => cluster assignment (99.5% ANI cut) as a tab-separated table.
uCSV.write(
    joinpath(data_dir, "c-strain-clusterings-99.5ANI.tsv"),
    DataFrames.DataFrame(strain = unique_strains, cluster_995 = clusters_995);
    delim = '\t',
)