In [None]:
import Pkg
pkgs = [
    "Revise",
    "DataFrames",
    "StatsBase",
    "StatsPlots",
    "uCSV",
    "ProgressMeter",
    "Distances",
    "Clustering",
    "Colors",
    "MultivariateStats"
]
# Pkg.add(pkgs)
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
import Mycelia

In [None]:
# Input data lives in a `data` directory next to the notebook's working directory.
data_dir = joinpath(pwd() |> dirname, "data")

In [None]:
# One entry per sample directory under data/SRA (full paths); Jupyter checkpoint
# directories are skipped.
sample_paths = [
    path for path in readdir(joinpath(data_dir, "SRA"), join=true)
    if !occursin(".ipynb_checkpoints", path)
]

In [None]:
# Cross-sample tables and figures are written here. Create the directory up front so the
# `savefig` calls in later cells do not fail when it does not exist yet.
results_dir = joinpath(data_dir, "results")
mkpath(results_dir)
results_dir

In [None]:
# Taxonomic ranks as returned by `Mycelia.list_ranks()`; the next cell selects one by position.
taxon_levels = Mycelia.list_ranks()

In [None]:
# 1-based position of the rank to analyse within `taxon_levels`.
# NOTE(review): hard-coded; which rank position 9 refers to depends on Mycelia.list_ranks() — confirm.
i = 9

taxon_index = i
taxon_level = collect(taxon_levels)[taxon_index]
println(taxon_index, " - ", taxon_level)
# All taxa at the selected rank; later cells filter it on its "taxid" column.
rank_table = Mycelia.list_rank(taxon_level)

In [None]:
# NCBI host metadata ("true"/"false" cells are decoded to Bool)
ncbi_metadata_file = joinpath(dirname(pwd()), "metadata", "NCBI-virus-refseq.transformed.tsv")
ncbi_host_metadata = DataFrames.DataFrame(uCSV.read(ncbi_metadata_file, header=1, delim='\t', encodings=Dict("false" => false, "true" => true)))

# ICTV host metadata
ictv_metadata_file = joinpath(dirname(pwd()), "metadata", "VMR_MSL38_v1 - VMR MSL38 v1.transformed.tsv")
ictv_host_metadata = DataFrames.DataFrame(uCSV.read(ictv_metadata_file, header=1, delim='\t', typedetectrows=100))
# drop rows with an empty taxid, then convert the remaining taxids to integers
ictv_host_metadata = ictv_host_metadata[.!isempty.(ictv_host_metadata[!, "taxid"]), :]
ictv_host_metadata[!, "taxid"] = parse.(Int, ictv_host_metadata[!, "taxid"])

# VirusHostDB metadata ("missing" cells are decoded to `missing`, "true"/"false" to Bool)
virushostdb_metadata_file = joinpath(dirname(pwd()), "metadata", "virushostdb.transformed.tsv")
virushostdb_metadata = DataFrames.DataFrame(uCSV.read(virushostdb_metadata_file, header=1, delim='\t', typedetectrows=1086, encodings=Dict("missing" => missing, "false" => false, "true" => true)))

# vertebrate_taxids = Set(union(
#     ictv_host_metadata[ictv_host_metadata[!, "Host source"] .== "vertebrates", "taxid"],
#     ncbi_host_metadata[ncbi_host_metadata[!, "host_is_vertebrate"], "taxid"],
#     virushostdb_metadata[virushostdb_metadata[!, "host_is_vertebrate"], "virus_taxid"]
# ))

# Virus taxids that NCBI or VirusHostDB flag as having a human host. The next cell filters
# the rank table with this set, so it must be defined here.
# `coalesce(..., false)` treats missing host flags as "not human"; a Boolean row mask
# containing `missing` would otherwise be rejected by DataFrames indexing.
human_taxids = Set(union(
    ncbi_host_metadata[coalesce.(ncbi_host_metadata[!, "host_is_human"], false), "taxid"],
    virushostdb_metadata[coalesce.(virushostdb_metadata[!, "host_is_human"], false), "virus_taxid"]
))

In [None]:
# Keep only the taxa at the selected rank whose taxid is in `human_taxids`.
filtered_rank_table = rank_table[in.(rank_table[!, "taxid"], Ref(human_taxids)), :]

In [None]:
# Set of taxids (selected rank, human-host only) used to pick rows out of each kraken report.
rank_taxids = Set(filtered_rank_table.taxid)

In [None]:
# Kraken2 database name; report files are matched on "<db>_" followed by eight digits
# (presumably the database build date — confirm).
kraken_db = "k2_pluspfp"
kraken_db_regex = Regex(kraken_db * raw"_\d{8}")

In [None]:
# Output path for the cross-sample taxon table at the selected rank.
cross_sample_taxon_report = joinpath(results_dir, string(kraken_db, ".", taxon_level, ".tsv"))

In [None]:
# Output path for a PNG figure at the selected rank: "<db>.<rank>.png".
cross_sample_taxon_figure_png = joinpath(results_dir, join([kraken_db, taxon_level, "png"], '.'))

In [None]:
# For every sample, read its kraken report for `kraken_db`, keep the rows whose taxid is in
# `rank_taxids`, tag them with the sample identifier, and stack them into one long table.
cross_sample_taxon_report_table = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    sample = basename(sample_path)
    kraken_dir = joinpath(sample_path, "kraken")
    # Match on the file name only, so a directory name in the path cannot satisfy the regex.
    report_files = filter(readdir(kraken_dir, join=true)) do path
        name = basename(path)
        occursin(kraken_db_regex, name) && occursin(r"kraken-report\.tsv$", name)
    end
    if isempty(report_files)
        error("no $(kraken_db) kraken report found in $(kraken_dir)")
    end
    # readdir returns sorted entries, so with several matches the lexicographically first is used.
    report_file = first(report_files)
    report_table = Mycelia.read_kraken_report(report_file)
    taxon_level_report = report_table[map(x -> x in rank_taxids, report_table[!, "ncbi_taxonid"]), :]
    taxon_level_report[!, "sample_identifier"] .= sample
    append!(cross_sample_taxon_report_table, taxon_level_report)
end
# Drop the percentage, direct-assignment count and rank columns; the summary keeps the
# cumulative count, taxid, name and sample identifier.
cross_sample_taxon_report_summary = cross_sample_taxon_report_table[!, DataFrames.Not(["percentage_of_fragments_at_or_below_taxon", "number_of_fragments_assigned_directly_to_taxon", "rank"])]
# uCSV.write(cross_sample_taxon_report, cross_sample_taxon_report_table, delim='\t')

In [None]:
# Build a single "taxid_name" label per row, then drop the two columns it was built from.
let summary = cross_sample_taxon_report_summary
    summary[!, "taxon"] = string.(summary[!, "ncbi_taxonid"]) .* "_" .* summary[!, "scientific_name"]
end
cross_sample_taxon_report_summary = cross_sample_taxon_report_summary[!, DataFrames.Not(["ncbi_taxonid", "scientific_name"])]

In [None]:
# Sort rows by sample then taxon, and drop exact duplicate rows in place. This is expected
# to leave the row count unchanged; it enforces order and uniqueness rather than asserting them.
DataFrames.sort!(cross_sample_taxon_report_summary, ["sample_identifier", "taxon"])
unique!(cross_sample_taxon_report_summary)

In [None]:
# Build a samples × taxa matrix `values` from the long-format summary table.
# Rows follow `samples` and columns follow `taxa` (both sorted unique values). Each cell holds
# `number_of_fragments_at_or_below_taxon`, or 0.0 when the taxon was not reported for that sample.
# NOTE(review): `values` shadows `Base.values` in Main; the assignment errors if `Base.values`
# was already referenced in this session.
taxa = sort(unique(cross_sample_taxon_report_summary[!, "taxon"]))
taxa_map = Dict(taxon => i for (i, taxon) in enumerate(taxa))
samples = sort(unique(cross_sample_taxon_report_summary[!, "sample_identifier"]))
samples_map = Dict(sample => i for (i, sample) in enumerate(samples))
values = zeros(length(samples), length(taxa))
ProgressMeter.@showprogress for taxon_table in DataFrames.groupby(cross_sample_taxon_report_summary, "taxon")
    taxon = taxon_table[1, "taxon"]
    column_index = taxa_map[taxon]
    for sample_table in DataFrames.groupby(taxon_table, "sample_identifier")
        # each (taxon, sample) pair must occur exactly once in the summary table
        @assert DataFrames.nrow(sample_table) == 1
        sample = sample_table[1, "sample_identifier"]
        row_index = samples_map[sample]
        value = sample_table[1, "number_of_fragments_at_or_below_taxon"]
        values[row_index, column_index] = value
    end
end
values

# Order taxa by ASCENDING maximum single-sample count. The taxon with the largest count in
# any one sample ends up LAST; the bar plots pair this order with `reverse(colorscheme)`.
taxa_frequency_ordering = sortperm(maximum.(eachcol(values)))
values = values[:, taxa_frequency_ordering]
taxa = taxa[taxa_frequency_ordering]

# Drop weakly detected taxa: keep a taxon only if its total count summed over all samples
# reaches this threshold.
minimum_total_taxon_count = 3
taxa_is_detected = [sum(col) >= minimum_total_taxon_count for col in eachcol(values)]
values = values[:, taxa_is_detected]
taxa = taxa[taxa_is_detected]

# Drop samples left with no counts, so every remaining row has a positive total.
sample_has_classifications = [sum(row) > 0 for row in eachrow(values)]
values = values[sample_has_classifications, :]
samples = samples[sample_has_classifications]
taxa

In [None]:
# mmseqs_hits =  
# [
# 61673,
#  1862825,
#    45617,
#  1647924,
#    10566,
#   931209,
#   493803,
#    11676,
#    28312,
#    37955,
#   765052,
#   463676,
#  3052230,
#  1891726
#     ]

In [None]:
# Print the taxids of the retained taxa (the numeric prefix of each "taxid_name" label)
# as a bracketed list with one comma-terminated entry per line.
let taxids = [parse(Int, first(split(label, '_'))) for label in taxa]
    println("[")
    foreach(taxid -> println(taxid, ","), taxids)
    println("]")
end

In [None]:
# One distinguishable color per taxon. White and black are passed as seeds and then dropped,
# so neither is assigned to a taxon.
colorscheme = let excluded_seeds = [Colors.RGB(1, 1, 1), Colors.RGB(0, 0, 0)]
    Colors.distinguishable_colors(length(taxa), excluded_seeds; dropseed = true)
end

In [None]:
# Per-sample proportions: each row of `values` is divided by that row's total.
normalized_values = let row_totals = sum(values; dims = 2)
    values ./ row_totals
end

In [None]:
# Euclidean distances between samples (rows of `normalized_values`), then hierarchical
# clustering (Clustering.hclust default linkage) with optimal leaf ordering.
let metric = Distances.Euclidean()
    global distance_matrix = Distances.pairwise(metric, normalized_values; dims = 1)
end
clustering = Clustering.hclust(distance_matrix; branchorder = :optimal)

In [None]:
# Figure geometry in pixels, shared by the bar-plot cells. The bottom margin scales with the
# longest sample name (x tick labels are rotated 90°). Width and height grow with the number
# of samples and taxa.
bottommargin = (maximum(length.(samples)) * 5)
leftmargin = 150
rightmargin = 25
topmargin = 25

width = max(1920, (length(samples) * 10) + 300)
height = max(1080, bottommargin + 600)
height = max(height, length(taxa)*11)

# Stacked bars of log10 read counts per sample, one segment per taxon.
# Counts are clamped to at least 1 before the log, so a taxon absent from a sample gets a
# zero-height segment (log10(1) == 0). Otherwise log10(0) == -Inf would propagate through
# the stack and corrupt that sample's whole bar. Fragment counts are assumed to be whole
# numbers, so only zeros are affected.
# NOTE(review): segments are log10 of each taxon's count, so a bar's total height is a sum of
# logs, not log10 of the sample's total reads.
plot = StatsPlots.groupedbar(
    log10.(max.(values, 1)),
    title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
    titlefontsize = 12,
    xticks = (1:length(samples), samples),
    xlims = (0, length(samples)+1),
    xtickfontsize = 6,
    size= (width, height),
    xrotation=90,
    ylabel = "log10(number of reads)",
    labels = hcat(taxa...),
    leftmargin = (leftmargin)StatsPlots.Plots.PlotMeasures.px,
    topmargin = (topmargin)StatsPlots.Plots.PlotMeasures.px,
    rightmargin = (rightmargin)StatsPlots.Plots.PlotMeasures.px,
    bottommargin = (bottommargin)StatsPlots.Plots.PlotMeasures.px,
    legendmargins = 0,
    legend = :outertopright,
    legendfontsize = 6,
    bar_position = :stack,
    bar_width=0.7,
    seriescolor = hcat(reverse(colorscheme)...),
)
for extension in [".png"]
    file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).total-reads") * extension
    StatsPlots.savefig(plot, file)
end

In [None]:
# Stacked bars of per-sample read proportions (each bar sums to 1), one segment per taxon.
plot = let px = StatsPlots.Plots.PlotMeasures.px
    StatsPlots.groupedbar(
        normalized_values;
        title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
        titlefontsize = 12,
        xticks = (1:length(samples), samples),
        xlims = (0, length(samples) + 1),
        xtickfontsize = 6,
        size = (width, height),
        xrotation = 90,
        ylabel = "proportion of reads",
        labels = hcat(taxa...),
        leftmargin = leftmargin * px,
        topmargin = topmargin * px,
        rightmargin = rightmargin * px,
        bottommargin = bottommargin * px,
        legendmargins = 0,
        legend = :outertopright,
        bar_position = :stack,
        bar_width = 0.7,
        seriescolor = hcat(reverse(colorscheme)...),
        legendfontsize = 6,
    )
end

# Only PNG is exported; add ".svg" to the list below to also write a vector version.
let stem = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads")
    for extension in [".png"]
        StatsPlots.savefig(plot, stem * extension)
    end
end

In [None]:
# Clustered view: a sample dendrogram stacked above the normalized bars, with the bars
# reordered to follow the dendrogram's leaf order. The distances and clustering are
# recomputed here from `normalized_values`.
distance_matrix = Distances.pairwise(Distances.Euclidean(), normalized_values; dims = 1)
clustering = Clustering.hclust(distance_matrix; branchorder = :optimal)

cluster_plot = let px = StatsPlots.Plots.PlotMeasures.px
    StatsPlots.plot(
        clustering;
        xlims = (0, length(samples) + 1),
        title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
        titlefontsize = 12,
        bottommargin = 0px,
        leftmargin = leftmargin * px,
        rightmargin = rightmargin * px,
        topmargin = topmargin * px,
        xticks = false,
        yticks = false,
        yaxis = false,
        xaxis = false,
    )
end

taxonomy_plot = let px = StatsPlots.Plots.PlotMeasures.px, leaf_order = clustering.order
    StatsPlots.groupedbar(
        normalized_values[leaf_order, :];
        bar_position = :stack,
        leftmargin = leftmargin * px,
        topmargin = 0px,
        rightmargin = rightmargin * px,
        bottommargin = bottommargin * px,
        legendmargins = 0,
        xticks = (1:length(samples), samples[leaf_order]),
        xtickfontsize = 4,
        xrotation = 90,
        xlims = (0, length(samples) + 1),
        size = (width, height),
        ylabel = "proportion of reads",
        seriescolor = hcat(reverse(colorscheme)...),
        legend = false,
        labels = hcat(taxa...),
        legendfontsize = 6,
    )
end

# Dendrogram takes the top 20% of the figure, bars the remaining 80%.
plot = StatsPlots.plot(
    cluster_plot,
    taxonomy_plot;
    margins = 0 * StatsPlots.Plots.PlotMeasures.px,
    layout = StatsPlots.grid(2, 1, heights = [0.2, 0.8]),
)

# Only PNG is exported; add ".svg" to the list below to also write a vector version.
let stem = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads.clustered")
    for extension in [".png"]
        StatsPlots.savefig(plot, stem * extension)
    end
end

In [None]:
# fit_pca = MultivariateStats.fit(MultivariateStats.PCA, normalized_values')
# transformed_observations = MultivariateStats.transform(fit_pca, normalized_values')

# # x = transformed_observations[1, :]
# # y = transformed_observations[2, :]

# # pc_plot = 
# # StatsPlots.scatter(
# #     x,
# #     y,
# #     # [z[control_indices], z[case_indices]],
# #     xlabel = "PC1",
# #     ylabel = "PC2",
# #     # zlabel = "PC3",
# #     # labels = hcat(["sample ID goes here"]...),
# #     # title = "Case vs. Control",
# #     legend = :outertopright,
# #     size = (1000, 500),
# #     # margins = 5StatsPlots.mm,
# # )

# # display(pc_plot)

In [None]:
# # top level classification rank to show absolute reads per sample
# # not very helpful at lower ranks since low read depth samples are too hard to see breakdowns
# if taxon_index in [1, 2]
#     aspect_ratio = [9, 2]
#     scale = 500
#     plot = StatsPlots.groupedbar(
#         values,
#         title = "read-classification - $(taxon_level)",
#         xticks = (1:length(samples), samples),
#         xlims = (0, length(samples)+1),
#         xtickfontsize = 6,
#         size= aspect_ratio .* scale,
#         xrotation=90,
#         ylabel = "number of reads",
#         labels = hcat(taxa...),
#         margins = 100StatsPlots.Plots.PlotMeasures.px,
#         legendmargins = 0,
#         legend = :outertopright,
#         legendfontsize = 6,
#         bar_position = :stack,
#         bar_width=0.7,
#         seriescolor = hcat(reverse(colorscheme)...)
#     )
#     # display(plot)
#     for extension in [".png", ".svg"]
#         file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).total-reads") * extension
#         StatsPlots.savefig(plot, file)
#     end
# end

In [None]:
# PCA on the per-sample relative-abundance profiles. `normalized_values` is samples × taxa;
# it is transposed so each column is one sample. The fit was previously computed twice.
fit_pca = MultivariateStats.fit(MultivariateStats.PCA, normalized_values')
transformed_observations = MultivariateStats.transform(fit_pca, normalized_values')

# Samples are grouped by the first two '_'-separated fields of their identifier, or by the
# whole identifier when it has fewer than two fields.
sample_group_keys = [join(split(sample, '_')[1:min(2, end)], '_') for sample in samples]
top_level_groups = unique(sample_group_keys)
# Group colors get their own name, so the per-taxon `colorscheme` used by the bar plots is
# not overwritten. White and black are passed as seeds and dropped.
group_colorscheme = Colors.distinguishable_colors(length(top_level_groups), [Colors.RGB(1,1,1), Colors.RGB(0,0,0)], dropseed=true)
xs = [Float64[] for group in top_level_groups]
ys = [Float64[] for group in top_level_groups]
zs = [Float64[] for group in top_level_groups]

# PC1 always exists; PC2/PC3 fall back to zeros when PCA keeps fewer components.
raw_xs = transformed_observations[1, :]
raw_ys = size(transformed_observations, 1) >= 2 ? transformed_observations[2, :] : zeros(length(raw_xs))
raw_zs = size(transformed_observations, 1) >= 3 ? transformed_observations[3, :] : zeros(length(raw_xs))

# Route each sample's coordinates to its group by exact key lookup. A substring test would put
# e.g. "A_10_x" into both the "A_1" and "A_10" groups.
let group_index = Dict(group => i for (i, group) in enumerate(top_level_groups))
    for (group_key, x, y, z) in zip(sample_group_keys, raw_xs, raw_ys, raw_zs)
        g = group_index[group_key]
        push!(xs[g], x)
        push!(ys[g], y)
        push!(zs[g], z)
    end
end

# 2-D scatter of PC1 vs PC2 with one series per sample group (`zs` is kept for a 3-D variant).
plot = StatsPlots.scatter(
    xs,
    ys,
    # zs,
    xlabel = "PC1",
    ylabel = "PC2",
    # zlabel = "PC3",
    labels = hcat(top_level_groups...),
    title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
    titlefontsize = 10,
    legend = :outertopright,
    size = (640, 480),
    margins = 20StatsPlots.px,
    seriescolor = hcat(group_colorscheme...)
)

display(plot)
for extension in [".png"]
    file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).pca") * extension
    StatsPlots.savefig(plot, file)
end

In [None]:
# # need to tune this to be proportional to the # of taxa in the list
# aspect_ratio = [9, 40]
# scale = 500
# plot = StatsPlots.groupedbar(
#     normalized_values,
#     title = "read-classification - $(taxon_level)",
#     xticks = (1:length(samples), samples),
#     xlims = (0, length(samples)+1),
#     xtickfontsize = 6,
#     size= aspect_ratio .* scale,
#     xrotation=90,
#     ylabel = "proportion of reads",
#     labels = hcat(taxa...),
#     margins = 100StatsPlots.Plots.PlotMeasures.px,
#     legend = :outertopright,
#     bar_position = :stack,
#     bar_width=0.7,
#     seriescolor = hcat(reverse(colorscheme)...),
#     legendfontsize = 6,
# )

# for extension in [".png", ".svg"]
#     file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).normalized-reads") * extension
#     StatsPlots.savefig(plot, file)
# end

In [None]:
# aspect_ratio = [2, 1]
# scale = 1000

# cluster_plot = StatsPlots.plot(
#     clustering,
#     xlims = (1, length(samples)),
#     title = "read-classification - $(taxon_level)",
#     xticks = false,
#     yticks = false,
#     yaxis = false
# )
# taxonomy_plot = StatsPlots.groupedbar(
#     normalized_values[clustering.order, :],
#     bar_position = :stack,
#     bottommargin = 50StatsPlots.Plots.PlotMeasures.px,
#     leftmargin = 50StatsPlots.Plots.PlotMeasures.px,
#     xticks = (1:length(samples), samples[clustering.order]),
#     xtickfontsize = 4,
#     xrotation=90,
#     xlims = (0, length(samples)+1),
#     size= aspect_ratio .* scale,
#     ylabel = "proportion of reads",
#     seriescolor = hcat(reverse(colorscheme)...),
#     # bar_width=0.7,
#     legend = false,
#     labels = hcat(taxa...),
#     legendfontsize = 6,
# )

# # Create a separate subplot as a legend
# # legend_plot = StatsPlots.groupedbar(normalized_values[clustering.order, :], legend=true)
# # plot!(legend_subplot, label="Series 1", legend=:best, linecolor=1)
# # plot!(legend_subplot, label="Series 2", legend=:best, linecolor=2)

# # Layout the main plot and the legend subplot
# # l = @layout [a{0.7w}; b{0.3w}]
# # plot(p, legend_subplot, layout=l)

# plot = StatsPlots.plot(
#     cluster_plot, 
#     taxonomy_plot,
#     layout=StatsPlots.grid(2,1, heights=[0.2,0.8])
# )
# # display(plot)
# for extension in [".png", ".svg"]
#     file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).normalized-reads.clustered") * extension
#     StatsPlots.savefig(plot, file)
# end
# # end

In [None]:
# for g in DataFrames.groupby(cross_sample_taxon_report_table, "sample_identifier")
#     # @show sum(g[!, "percentage_of_fragments_at_or_below_taxon"])
#     if sum(g[!, "percentage_of_fragments_at_or_below_taxon"]) == 100
#     else
#         identifier = g[1, "sample_identifier"]
#         println("\"$(identifier)\",")
#     end
# end

In [None]:
# Columns of the cross-sample summary before the "taxon" label is built: number_of_fragments_at_or_below_taxon, ncbi_taxonid, scientific_name, sample_identifier

In [None]:

# Mycelia.list_rank("kingdom")
# Mycelia.list_rank("phylum")
# Mycelia.list_rank("class")
# Mycelia.list_rank("order")
# Mycelia.list_rank("family")
# Mycelia.list_rank("genus")
# Mycelia.list_rank("species")

- Loop through each of the ranks listed above.
- Build a matrix with one row per taxon at that rank and one column per dataset (sample).
- Make stacked bar plots for each sample.
- Run PCA and k-means clustering to find meaningful groups of samples.
- Repeat for mmseqs protein and BLAST nt classifications.