In [None]:
import Pkg

# clean environment on each run to get around unregistered package issues
# (any Project.toml / Manifest.toml in the current directory is deleted, so the
# environment activated below is rebuilt from scratch every time this cell runs)
isfile("Project.toml") && rm("Project.toml")
isfile("Manifest.toml") && rm("Manifest.toml")
Pkg.activate(".")
Pkg.add("Revise")
import Revise

# Mycelia is tracked in development mode from the directory three levels above
# this one (presumably the root of the Mycelia checkout — confirm)
Pkg.develop(url="../../..")
import Mycelia

# third-party packages used by the cells below; each one is added to the
# environment and then imported by name, so every call is module-qualified
pkgs = [
"DataFrames",
"StatsBase",
"StatsPlots",
"uCSV",
"ProgressMeter",
"Distances",
"Clustering",
"Colors",
"MultivariateStats"
]
Pkg.add(pkgs)
for pkg in pkgs
    # the module name is only known as a string here, so the import statement
    # is assembled as text and evaluated
    eval(Meta.parse("import $pkg"))
end

In [None]:
# Name of the Kraken database whose classification results are analysed.
# It is interpolated into the output file names written by this notebook.
# Alternative databases are kept commented out for quick switching.
kraken_db = "k2_pluspfp_20221209"
# kraken_db = "k2_pluspfp_08gb_20231009"
# kraken_db = "k2_pluspfp_16gb_20231009"

In [None]:
# "data" directory located next to the current working directory
# (i.e. <parent of pwd()>/data)
data_dir = joinpath(dirname(pwd()), "data")

In [None]:
# data/results: the kraken reports are read from, and all tables and figures
# are written to, this directory
results_dir = joinpath(data_dir, "results")

In [None]:
# full paths (join=true) of every entry in results/kraken; each entry is later
# read as one kraken report, and its basename becomes the sample identifier
kraken_reports = readdir(joinpath(results_dir, "kraken"), join=true)

In [None]:
# taxonomic ranks as returned by Mycelia; one of them is picked by position
# further down (presumably ordered from broadest to most specific — confirm)
taxon_levels = Mycelia.list_ranks()

In [None]:
# taxids reported by Mycelia.list_subtaxa for taxid 10239 (NCBI "Viruses"),
# stored as a Set for fast membership tests
# NOTE(review): confirm whether list_subtaxa includes the root taxid itself
viral_taxids = Set(Mycelia.list_subtaxa(10239))

In [None]:
# Position (1-based) within `taxon_levels` of the rank to analyse.
# Change it and re-run the cells below to process another rank
# (alternatives: i = 1, 2, 3, 4, 5, 6, 7, 8).
i = 9

# `taxon_index` is the position, `taxon_level` the rank found at that position
taxon_index = i
taxon_level = collect(taxon_levels)[i]
println("$(taxon_index) - $(taxon_level)")
# table of the taxa known at the selected rank, as provided by Mycelia
rank_table = Mycelia.list_rank(taxon_level)

In [None]:
# VIRAL FILTERING
# The first two ranks are kept in full; from the third rank onwards only the
# taxa whose taxid belongs to `viral_taxids` are retained.
if taxon_index > 2
    taxid_is_viral = in.(rank_table[!, "taxid"], Ref(viral_taxids))
    rank_table = rank_table[taxid_is_viral, :]
end
# show the (possibly filtered) table as the cell output
rank_table

In [None]:
# TODO: turn this cell into a function that accepts a rank table, a list of
# kraken reports, and an output path
#
# Builds one long table holding, for every kraken report, the rows whose taxid
# is part of the selected rank, tagged with the report's file name, and writes
# the result as a TSV named after the database and the rank.
rank_taxids = Set(rank_table[!, "taxid"])
cross_sample_taxon_report = joinpath(results_dir, "$(kraken_db).$(taxon_level).tsv")
# if !isfile(cross_sample_taxon_report)
cross_sample_taxon_report_table = DataFrames.DataFrame()
ProgressMeter.@showprogress for kraken_report in kraken_reports
    full_report = Mycelia.read_kraken_report(kraken_report)
    # keep only the rows that belong to the selected rank
    is_rank_taxon = in.(full_report[!, "ncbi_taxonid"], Ref(rank_taxids))
    rank_rows = full_report[is_rank_taxon, :]
    # the report's file name identifies the sample
    rank_rows[!, "sample_identifier"] .= basename(kraken_report)
    append!(cross_sample_taxon_report_table, rank_rows)
end
# drop rows whose percentage is NaN before writing the table out
has_percentage = .!isnan.(cross_sample_taxon_report_table[!, "percentage_of_fragments_at_or_below_taxon"])
cross_sample_taxon_report_table = cross_sample_taxon_report_table[has_percentage, :]
uCSV.write(cross_sample_taxon_report, cross_sample_taxon_report_table, delim='\t')
# else
# cross_sample_taxon_report_table = DataFrames.DataFrame(uCSV.read(cross_sample_taxon_report, delim='\t', header=1))
# end

In [None]:
# cross_sample_taxon_report_table = cross_sample_taxon_report_table[map(x -> !isnan(x), cross_sample_taxon_report_table[!, "percentage_of_fragments_at_or_below_taxon"]), :]
# uCSV.write(cross_sample_taxon_report, cross_sample_taxon_report_table, delim='\t')

In [None]:
# this could be a function here

In [None]:
# Reduce the long cross-sample table to three columns: the sample identifier,
# a combined "<taxid>_<scientific name>" taxon label, and the number of
# fragments at or below that taxon.
#
# The columns are COPIED (`[:, ...]`) rather than shared (`[!, ...]`): the
# summary is sorted and de-duplicated in place below, and with shared column
# vectors that would also reorder/truncate some columns of
# `cross_sample_taxon_report_table`, leaving it inconsistent (and making a
# re-run of this cell pair counts with the wrong taxa).
cross_sample_taxon_report_summary = cross_sample_taxon_report_table[:,
    DataFrames.Not([
            "percentage_of_fragments_at_or_below_taxon",
            "number_of_fragments_assigned_directly_to_taxon",
            "rank"
        ])]
cross_sample_taxon_report_summary[!, "taxon"] = map(row -> string(row["ncbi_taxonid"]) * "_" * row["scientific_name"], DataFrames.eachrow(cross_sample_taxon_report_summary))
# the id and name are now folded into "taxon"; drop them from the private copy
DataFrames.select!(cross_sample_taxon_report_summary, DataFrames.Not([
            "ncbi_taxonid",
            "scientific_name"
        ]))

# assert sortedness & uniqueness (should be a no-op)
unique!(DataFrames.sort!(cross_sample_taxon_report_summary, ["sample_identifier", "taxon"]))

In [None]:
# Pivot the long summary table into a dense samples x taxa matrix of fragment
# counts: row i is samples[i], column j is taxa[j].
#
# Cells are located through lookup tables instead of relying on the order in
# which grouped sub-tables are produced, so a taxon that is absent from some
# samples simply keeps a count of zero for them instead of breaking the
# row/column alignment.
# NOTE: the name `values` shadows Base.values in this session; it is kept
# because the following cells refer to it.
taxa = sort(unique(cross_sample_taxon_report_summary[!, "taxon"]))
samples = sort(unique(cross_sample_taxon_report_summary[!, "sample_identifier"]))
column_of_taxon = Dict(taxon => column for (column, taxon) in enumerate(taxa))
row_of_sample = Dict(sample => row for (row, sample) in enumerate(samples))
values = zeros(length(samples), length(taxa))
# tracks which cells were already written so duplicated (sample, taxon) pairs fail loudly
cell_is_filled = falses(length(samples), length(taxa))
ProgressMeter.@showprogress for record in DataFrames.eachrow(cross_sample_taxon_report_summary)
    row_index = row_of_sample[record["sample_identifier"]]
    column_index = column_of_taxon[record["taxon"]]
    if cell_is_filled[row_index, column_index]
        error("more than one row for sample $(record["sample_identifier"]) and taxon $(record["taxon"])")
    end
    cell_is_filled[row_index, column_index] = true
    values[row_index, column_index] = record["number_of_fragments_at_or_below_taxon"]
end
values

In [None]:
# Order the taxa (columns) by their largest count in any single sample.
# `sortperm` sorts ascending, so the taxon with the largest single-sample count
# ends up LAST.
# NOTE(review): if the largest taxon is meant to come first, pass `rev=true` —
# confirm against the generated plots before changing.
taxa_frequency_ordering = sortperm(map(maximum, eachcol(values)))
values = values[:, taxa_frequency_ordering]
taxa = taxa[taxa_frequency_ordering]
# drop taxa whose counts sum to zero across all samples
taxa_is_detected = map(col -> sum(col) > 0, eachcol(values))
values = values[:, taxa_is_detected]
taxa = taxa[taxa_is_detected]

In [None]:
# convert counts to per-sample proportions: every row is divided by its own total
normalized_values = values ./ sum(values; dims=2)
# shorten each sample label to the part of the name before the first '.'
samples = map(name -> string(first(split(name, '.'))), samples)

In [None]:
# one distinguishable color per taxon; white and black only act as seeds to
# stay away from and are dropped from the palette (dropseed=true)
colorscheme = Colors.distinguishable_colors(
    length(taxa),
    [Colors.RGB(1,1,1), Colors.RGB(0,0,0)],
    dropseed=true,
)

# figure margins (pixels); the bottom margin scales with the longest sample
# label, which is drawn rotated by 90 degrees under the x axis
bottommargin = max(100, 3 * maximum(length, samples))
leftmargin, rightmargin, topmargin = 150, 25, 25

# canvas size (pixels): the width grows with the number of samples
width = max(1920, 10 * length(samples) + 300)
height = max(1080, bottommargin + 600)

# top level classification rank to show absolute reads per sample
# not very helpful at lower ranks since low read depth samples are too hard to see breakdowns
if taxon_index in (1, 2)
    plot = let px = StatsPlots.Plots.PlotMeasures.px
        StatsPlots.groupedbar(
            values;
            title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
            titlefontsize = 12,
            xticks = (1:length(samples), samples),
            xlims = (0, length(samples) + 1),
            xtickfontsize = 6,
            size = (width, height),
            xrotation = 90,
            ylabel = "number of reads",
            labels = hcat(taxa...),
            leftmargin = leftmargin * px,
            topmargin = topmargin * px,
            rightmargin = rightmargin * px,
            bottommargin = bottommargin * px,
            legendmargins = 0,
            legend = :outertopright,
            legendfontsize = 6,
            bar_position = :stack,
            bar_width = 0.7,
            seriescolor = hcat(reverse(colorscheme)...),
        )
    end
    # display(plot)
    # for extension in [".png", ".svg"]
    for extension in (".png",)
        file = joinpath(
            results_dir,
            "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).total-reads" * extension,
        )
        StatsPlots.savefig(plot, file)
    end
end

In [None]:
################################################################################################
# Stacked bar chart of the per-sample read proportions, saved next to the other results.

# re-tune height to be proportional to size of taxa list (tall legends)
height = max(height, 11 * length(taxa))
plot = let px = StatsPlots.Plots.PlotMeasures.px
    StatsPlots.groupedbar(
        normalized_values;
        title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
        titlefontsize = 12,
        xticks = (1:length(samples), samples),
        xlims = (0, length(samples) + 1),
        xtickfontsize = 6,
        size = (width, height),
        xrotation = 90,
        ylabel = "proportion of reads",
        labels = hcat(taxa...),
        leftmargin = leftmargin * px,
        topmargin = topmargin * px,
        rightmargin = rightmargin * px,
        bottommargin = bottommargin * px,
        legendmargins = 0,
        legend = :outertopright,
        bar_position = :stack,
        bar_width = 0.7,
        seriescolor = hcat(reverse(colorscheme)...),
        legendfontsize = 6,
    )
end

# for extension in [".png", ".svg"]
for extension in (".png",)
    file = joinpath(
        results_dir,
        "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads" * extension,
    )
    StatsPlots.savefig(plot, file)
end

In [None]:
################################################################################################
# Hierarchically cluster the samples on their read proportions and draw the
# dendrogram above a stacked bar chart whose bars follow the dendrogram's leaf order.

# pairwise Euclidean distances between samples (rows of normalized_values)
distance_matrix = Distances.pairwise(Distances.Euclidean(), normalized_values, dims=1)
clustering = Clustering.hclust(distance_matrix, branchorder=:optimal)

# dendrogram panel: no axes or ticks, and no bottom margin so it sits directly on the bars
cluster_plot = let px = StatsPlots.Plots.PlotMeasures.px
    StatsPlots.plot(
        clustering;
        xlims = (0, length(samples) + 1),
        title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
        titlefontsize = 12,
        bottommargin = 0 * px,
        leftmargin = leftmargin * px,
        rightmargin = rightmargin * px,
        topmargin = topmargin * px,
        xticks = false,
        yticks = false,
        yaxis = false,
        xaxis = false,
    )
end

# bar panel: rows and tick labels are both reordered by the clustering
taxonomy_plot = let px = StatsPlots.Plots.PlotMeasures.px, leaf_order = clustering.order
    StatsPlots.groupedbar(
        normalized_values[leaf_order, :];
        bar_position = :stack,
        leftmargin = leftmargin * px,
        topmargin = 0 * px,
        rightmargin = rightmargin * px,
        bottommargin = bottommargin * px,
        legendmargins = 0,
        xticks = (1:length(samples), samples[leaf_order]),
        xtickfontsize = 4,
        xrotation = 90,
        xlims = (0, length(samples) + 1),
        size = (width, height),
        ylabel = "proportion of reads",
        seriescolor = hcat(reverse(colorscheme)...),
        legend = false,
        labels = hcat(taxa...),
        legendfontsize = 6,
    )
end

# stack the two panels: dendrogram on top (20 % of the height), bars below (80 %)
plot = StatsPlots.plot(
    cluster_plot,
    taxonomy_plot;
    margins = 0 * StatsPlots.Plots.PlotMeasures.px,
    layout = StatsPlots.grid(2, 1, heights=[0.2, 0.8]),
)
# display(plot)
# for extension in [".png", ".svg"]
for extension in (".png",)
    file = joinpath(
        results_dir,
        "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads.clustered" * extension,
    )
    StatsPlots.savefig(plot, file)
end

In [None]:
# PCA of the samples' read proportions, shown as a 3D scatter of the first
# three principal components and saved as a PNG.
# MultivariateStats expects one observation per column, hence the transpose
# (normalized_values has one sample per row).
fit_pca = MultivariateStats.fit(MultivariateStats.PCA, normalized_values')
transformed_observations = MultivariateStats.transform(fit_pca, normalized_values')

# the fit may keep fewer than three components; missing ones are plotted as zeros
n_components = size(transformed_observations, 1)
xs = transformed_observations[1, :]
ys = n_components >= 2 ? transformed_observations[2, :] : zeros(length(xs))
zs = n_components >= 3 ? transformed_observations[3, :] : zeros(length(xs))

plot =
StatsPlots.scatter(
    xs,
    ys,
    zs,
    xlabel = "PC1",
    ylabel = "PC2",
    zlabel = "PC3",
    legend=false,
    title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
    titlefontsize = 10,
    size = (640, 480),
    # same fully qualified pixel unit as the other plotting cells, instead of
    # relying on `px` being reachable directly through StatsPlots
    margins = 20StatsPlots.Plots.PlotMeasures.px,
)

display(plot)
for extension in [".png"]
    file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).pca") * extension
    StatsPlots.savefig(plot, file)
end

In [None]:
# - loop through the above
# - make a matrix where each row is a taxon rank and each column is a dataset
# - make stacked barplots for each sample

In [None]:
# colorscheme = Colors.distinguishable_colors(length(taxa), [Colors.RGB(1,1,1), Colors.RGB(0,0,0)], dropseed=true)
# # top level classification rank to show absolute reads per sample
# # not very helpful at lower ranks since low read depth samples are too hard to see breakdowns

# bottommargin = (maximum(length.(samples)) * 3)
# leftmargin = 150
# rightmargin = 25
# topmargin = 25

# width = max(1920, (length(samples) * 10) + 300)
# height = max(1080, bottommargin + 600)

# if taxon_index in [1, 2]
#     plot = StatsPlots.groupedbar(
#         values,
#         title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
#         titlefontsize = 12,
#         xticks = (1:length(samples), samples),
#         xlims = (0, length(samples)+1),
#         xtickfontsize = 6,
#         size= (width, height),
#         xrotation=90,
#         ylabel = "number of reads",
#         labels = hcat(taxa...),
#         leftmargin = (leftmargin)StatsPlots.Plots.PlotMeasures.px,
#         topmargin = (topmargin)StatsPlots.Plots.PlotMeasures.px,
#         rightmargin = (rightmargin)StatsPlots.Plots.PlotMeasures.px,
#         bottommargin = (bottommargin)StatsPlots.Plots.PlotMeasures.px,
#         legendmargins = 0,
#         legend = :outertopright,
#         legendfontsize = 6,
#         bar_position = :stack,
#         bar_width=0.7,
#         seriescolor = hcat(reverse(colorscheme)...)
#     )
#     # display(plot)
#     # for extension in [".png", ".svg"]
#     for extension in [".png"]
#         file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).total-reads") * extension
#         StatsPlots.savefig(plot, file)
#     end
# end

# ################################################################################################

# # re-tune height to be proportional to size of taxa list (tall legends)
# height = max(height, length(taxa)*11)
# plot = StatsPlots.groupedbar(
#     normalized_values,
#     title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
#     titlefontsize = 12,
#     xticks = (1:length(samples), samples),
#     xlims = (0, length(samples)+1),
#     xtickfontsize = 6,
#     size= (width, height),
#     xrotation=90,
#     ylabel = "proportion of reads",
#     labels = hcat(taxa...),
#     leftmargin = (leftmargin)StatsPlots.Plots.PlotMeasures.px,
#     topmargin = (topmargin)StatsPlots.Plots.PlotMeasures.px,
#     rightmargin = (rightmargin)StatsPlots.Plots.PlotMeasures.px,
#     bottommargin = (bottommargin)StatsPlots.Plots.PlotMeasures.px,
#     legendmargins = 0,
#     legend = :outertopright,
#     bar_position = :stack,
#     bar_width=0.7,
#     seriescolor = hcat(reverse(colorscheme)...),
#     legendfontsize = 6,
# )

# # for extension in [".png", ".svg"]
# for extension in [".png"]
#     file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads") * extension
#     StatsPlots.savefig(plot, file)
# end

# ################################################################################################
# distance_matrix = Distances.pairwise(Distances.Euclidean(), normalized_values, dims=1)
# clustering = Clustering.hclust(distance_matrix, branchorder=:optimal)
# cluster_plot = StatsPlots.plot(
#     clustering,
#     xlims = (0, length(samples)+1),
#     title = "read-classification - kraken - $(kraken_db) - $(taxon_level)",
#     titlefontsize = 12,
#     bottommargin = 0StatsPlots.Plots.PlotMeasures.px,
#     leftmargin = (leftmargin)StatsPlots.Plots.PlotMeasures.px,
#     rightmargin = (rightmargin)StatsPlots.Plots.PlotMeasures.px,
#     topmargin = (topmargin)StatsPlots.Plots.PlotMeasures.px,
#     xticks = false,
#     yticks = false,
#     yaxis = false,
#     xaxis = false
# )
# taxonomy_plot = StatsPlots.groupedbar(
#     normalized_values[clustering.order, :],
#     bar_position = :stack,
#     leftmargin = (leftmargin)StatsPlots.Plots.PlotMeasures.px,
#     topmargin = 0StatsPlots.Plots.PlotMeasures.px,
#     rightmargin = (rightmargin)StatsPlots.Plots.PlotMeasures.px,
#     bottommargin = (bottommargin)StatsPlots.Plots.PlotMeasures.px,
#     legendmargins = 0,
#     xticks = (1:length(samples), samples[clustering.order]),
#     xtickfontsize = 4,
#     xrotation=90,
#     xlims = (0, length(samples)+1),
#     size= (width, height),
#     ylabel = "proportion of reads",
#     seriescolor = hcat(reverse(colorscheme)...),
#     legend = false,
#     labels = hcat(taxa...),
#     legendfontsize = 6,
# )

# plot = StatsPlots.plot(
#     cluster_plot,
#     taxonomy_plot,
#     margins = 0StatsPlots.Plots.PlotMeasures.px,
#     layout=StatsPlots.grid(2,1, heights=[0.2,0.8])
# )
# # display(plot)
# # for extension in [".png", ".svg"]
# for extension in [".png"]
#     file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).normalized-reads.clustered") * extension
#     StatsPlots.savefig(plot, file)
# end


- run PCA and k-means cluster to find meaningful groups
- repeat for mmseqs protein and blast nt