In [None]:
import Pkg
# Pkg.activate(".")
# Pkg.add("Revise")
# import Revise

# Pkg.develop(path="../../..")
# import Mycelia

# Third-party packages this notebook depends on.
# NOTE(review): later cells call `Mycelia.*`, but `import Mycelia` is commented out at the top
# of this cell — confirm that Mycelia is loaded some other way before running them.
pkgs = [
    "DataFrames",
    "StatsBase",
    "StatsPlots",
    "uCSV",
    "ProgressMeter",
    "Distances",
    "Clustering",
    "Colors",
    "MultivariateStats",
]
# Add the packages to the active environment.
Pkg.add(pkgs)
# `import` (not `using`) keeps every call below module-qualified. The import expression is
# built from a Symbol rather than by parsing an interpolated source string.
for pkg in pkgs
    @eval import $(Symbol(pkg))
end

In [None]:
# Data root: the `data` directory located in the parent of the current working directory.
data_dir = joinpath(dirname(pwd()), "data")

In [None]:
# Directory that receives the cross-sample report table and the figures produced below.
results_dir = joinpath(data_dir, "results")

In [None]:
# Identifier of the Kraken database; it is embedded in the per-sample report file names that
# are read below and in the name of the cached cross-sample table.
kraken_db = "k2_pluspfp_20221209"

In [None]:
# All subtaxa of taxid 10239, stored as a Set for membership tests in later cells.
# (10239 is presumably the NCBI taxonomy node for viruses, going by the variable name — confirm.)
viral_taxids = Set(Mycelia.list_subtaxa(10239))
# Rank names as returned by Mycelia; a later cell picks one of them by position.
taxon_levels = Mycelia.list_ranks()

In [None]:
# Position in `taxon_levels` of the rank analyzed by this run. Change the value and re-run
# the cells below to process another rank.
i = 8

# Pair the chosen rank name with its position; both are reused by later cells (the position
# gates the viral filter and the absolute-count plot, and both appear in output file names).
taxon_index, taxon_level = collect(enumerate(taxon_levels))[i]
println(taxon_index, " - ", taxon_level)
rank_table = Mycelia.list_rank(taxon_level)

In [None]:
# The first two ranks are kept unfiltered; for every later rank only the taxa whose taxid is
# in `viral_taxids` are retained.
if taxon_index > 2
    taxid_is_viral = [taxid in viral_taxids for taxid in rank_table[!, "taxid"]]
    rank_table = rank_table[taxid_is_viral, :]
end
rank_table

In [None]:
# Build (or reload) one long-format table that holds, for every sample, the Kraken report rows
# whose taxid belongs to the selected rank. The table is cached as a TSV in `results_dir`:
# when the file already exists it is read back instead of re-reading each per-sample report.
rank_taxids = Set(rank_table[!, "taxid"])

cross_sample_taxon_report = joinpath(results_dir, "$(kraken_db).$(taxon_level).tsv")
cross_sample_taxon_figure_png = joinpath(results_dir, "$(kraken_db).$(taxon_level).png")
cross_sample_taxon_figure_skip_unclassified_png = joinpath(results_dir, "$(kraken_db).$(taxon_level).skip_unclassified.png")
if !isfile(cross_sample_taxon_report)
    cross_sample_taxon_report_table = DataFrames.DataFrame()
    # NOTE(review): `SRR_paths` is not defined anywhere in this notebook — presumably a
    # collection of per-sample directories named by SRR accession; confirm where it is set.
    ProgressMeter.@showprogress for SRR_path in SRR_paths
        SRR = basename(SRR_path)
        # This step only reads existing reports, so the kraken directory is not created here;
        # a missing report surfaces as a read error instead of leaving an empty directory.
        report_file = joinpath(SRR_path, "kraken", "$(SRR).$(kraken_db).kraken-report.tsv")
        report_table = Mycelia.read_kraken_report(report_file)
        # Keep only the report rows whose taxid is at the selected rank.
        is_selected_rank = map(in(rank_taxids), report_table[!, "ncbi_taxonid"])
        taxon_level_report = report_table[is_selected_rank, :]
        # Tag the rows with their sample before stacking them onto the combined table.
        taxon_level_report[!, "sample_identifier"] .= SRR
        append!(cross_sample_taxon_report_table, taxon_level_report)
    end
    # Make sure the destination directory exists before the cache file is written.
    mkpath(dirname(cross_sample_taxon_report))
    uCSV.write(cross_sample_taxon_report, cross_sample_taxon_report_table, delim='\t')
else
    cross_sample_taxon_report_table = DataFrames.DataFrame(uCSV.read(cross_sample_taxon_report, delim='\t', header=1))
end

In [None]:
# this could be a function here 

In [None]:
# Reduce the long table to three columns: read count, sample identifier and a combined taxon
# label.
#
# `[:, cols]` copies the selected columns. Selecting with `[!, cols]` would hand back the very
# same column vectors that `cross_sample_taxon_report_table` holds, and the in-place
# `sort!`/`unique!` at the end of this cell would then reorder or shorten those shared
# columns inside the parent table while its other columns stay untouched.
cross_sample_taxon_report_summary = cross_sample_taxon_report_table[:,
    DataFrames.Not([
            "percentage_of_fragments_at_or_below_taxon",
            "number_of_fragments_assigned_directly_to_taxon",
            "rank"
        ])]
# Taxon label of the form "<taxid>_<scientific name>".
cross_sample_taxon_report_summary[!, "taxon"] = map(
    row -> string(row["ncbi_taxonid"]) * "_" * row["scientific_name"],
    DataFrames.eachrow(cross_sample_taxon_report_summary),
)
# The two source columns are now redundant with the label.
DataFrames.select!(
    cross_sample_taxon_report_summary,
    DataFrames.Not([
        "ncbi_taxonid",
        "scientific_name"
    ]),
)

# Order rows by sample, then taxon, and drop exact duplicate rows (both in place; expected to
# be a no-op when the table is already ordered and duplicate-free).
unique!(DataFrames.sort!(cross_sample_taxon_report_summary, ["sample_identifier", "taxon"]))

In [None]:
# Pivot the long summary table into a dense count matrix:
#   rows    -> `samples` (sorted unique sample identifiers)
#   columns -> `taxa`    (sorted unique taxon labels)
# Every table row is placed by looking up its own sample and taxon, so a (sample, taxon) pair
# that does not occur in the table simply stays 0. Matching groups by position (nested
# `groupby` + `enumerate`) only lines up when every sample reports every taxon.
taxa = sort(unique(cross_sample_taxon_report_summary[!, "taxon"]))
samples = sort(unique(cross_sample_taxon_report_summary[!, "sample_identifier"]))
values = zeros(length(samples), length(taxa))
let sample_row = Dict(sample => r for (r, sample) in enumerate(samples)),
    taxon_column = Dict(taxon => c for (c, taxon) in enumerate(taxa)),
    is_filled = falses(size(values))

    ProgressMeter.@showprogress for entry in DataFrames.eachrow(cross_sample_taxon_report_summary)
        r = sample_row[entry["sample_identifier"]]
        c = taxon_column[entry["taxon"]]
        # Each (sample, taxon) pair may contribute one row at most; a second row would
        # silently overwrite the first count, so stop instead.
        if is_filled[r, c]
            error(
                "more than one row for sample ", entry["sample_identifier"],
                " and taxon ", entry["taxon"],
            )
        end
        is_filled[r, c] = true
        values[r, c] = entry["number_of_fragments_at_or_below_taxon"]
    end
end
values

In [None]:
# Spot check: taxon labels that contain "Enterovirus" (case-insensitive match).
[taxon for taxon in taxa if occursin(r"Enterovirus"i, taxon)]

In [None]:
# Reorder the taxon columns by each taxon's largest single-sample count. `sortperm` sorts in
# ascending order, so the taxon with the highest peak ends up in the LAST column.
# NOTE(review): the previous comment said the largest taxon comes first — confirm which order
# the plots below are meant to use.
taxa_frequency_ordering = sortperm(map(maximum, eachcol(values)))
values = values[:, taxa_frequency_ordering]
taxa = taxa[taxa_frequency_ordering]
# Drop taxa whose counts sum to zero across all samples.
taxa_is_detected = map(column -> sum(column) > 0, eachcol(values))
values = values[:, taxa_is_detected]
taxa = taxa[taxa_is_detected]

In [None]:
# One color per remaining taxon. White and black are supplied as seed colors so the generated
# palette stays distinguishable from them; `dropseed=true` leaves the seeds themselves out of
# the returned colors.
colorscheme = Colors.distinguishable_colors(length(taxa), [Colors.RGB(1,1,1), Colors.RGB(0,0,0)], dropseed=true)

In [None]:
# Convert counts to per-sample proportions: every row (one sample) is divided by its row total.
# NOTE(review): a sample whose counts sum to zero would produce NaN here (0/0), which would
# carry into the distance matrix below — confirm that no such sample exists.
normalized_values = values ./ sum(values, dims=2)

In [None]:
# Pairwise Euclidean distance between samples; `dims=1` compares the rows of the matrix.
distance_matrix = Distances.pairwise(Distances.Euclidean(), normalized_values, dims=1)
# Hierarchical clustering of the samples with optimal leaf ordering requested; the linkage is
# left at `hclust`'s default.
clustering = Clustering.hclust(distance_matrix, branchorder=:optimal)

In [None]:
# fit_pca = MultivariateStats.fit(MultivariateStats.PCA, normalized_values')
# transformed_observations = MultivariateStats.transform(fit_pca, normalized_values')

# # x = transformed_observations[1, :]
# # y = transformed_observations[2, :]

# # pc_plot = 
# # StatsPlots.scatter(
# #     x,
# #     y,
# #     # [z[control_indices], z[case_indices]],
# #     xlabel = "PC1",
# #     ylabel = "PC2",
# #     # zlabel = "PC3",
# #     # labels = hcat(["sample ID goes here"]...),
# #     # title = "Case vs. Control",
# #     legend = :outertopright,
# #     size = (1000, 500),
# #     # margins = 5StatsPlots.mm,
# # )

# # display(pc_plot)

In [None]:
# Stacked bars of absolute read counts per sample. Only drawn for the first two ranks: at
# lower ranks the breakdown of samples with low read depth is too hard to see.
if taxon_index in (1, 2)
    aspect_ratio = [9, 2]
    scale = 500
    plot = StatsPlots.groupedbar(
        values;
        title = "read-classification - $(taxon_level)",
        ylabel = "number of reads",
        # one bar per sample, labelled with the sample identifier
        xticks = (1:length(samples), samples),
        xlims = (0, length(samples) + 1),
        xtickfontsize = 6,
        xrotation = 90,
        size = aspect_ratio .* scale,
        margins = 100 * StatsPlots.Plots.PlotMeasures.px,
        # one series per taxon; labels and colors are passed as 1×n row matrices
        labels = permutedims(taxa),
        seriescolor = permutedims(reverse(colorscheme)),
        bar_position = :stack,
        bar_width = 0.7,
        legend = :outertopright,
        legendmargins = 0,
        legendfontsize = 6,
    )
    # Save the figure once per output format.
    for extension in (".png", ".svg")
        file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).total-reads") * extension
        StatsPlots.savefig(plot, file)
    end
end

In [None]:
# Stacked bars of per-sample read proportions, one bar per sample.
# TODO: make the figure height proportional to the number of taxa in the legend instead of
# using a fixed aspect ratio.
aspect_ratio = [9, 40]
scale = 500
plot = StatsPlots.groupedbar(
    normalized_values;
    title = "read-classification - $(taxon_level)",
    ylabel = "proportion of reads",
    xticks = (1:length(samples), samples),
    xlims = (0, length(samples) + 1),
    xtickfontsize = 6,
    xrotation = 90,
    size = aspect_ratio .* scale,
    margins = 100 * StatsPlots.Plots.PlotMeasures.px,
    # one series per taxon; labels and colors are passed as 1×n row matrices
    labels = permutedims(taxa),
    seriescolor = permutedims(reverse(colorscheme)),
    bar_position = :stack,
    bar_width = 0.7,
    legend = :outertopright,
    legendfontsize = 6,
)

# Save the figure once per output format.
for extension in (".png", ".svg")
    file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).normalized-reads") * extension
    StatsPlots.savefig(plot, file)
end

In [None]:
# Two-panel figure: the sample dendrogram on top and, underneath, the stacked proportion bars
# with samples arranged in the dendrogram's leaf order.
aspect_ratio = [2, 1]
scale = 1000

cluster_plot = StatsPlots.plot(
    clustering,
    # Same x range as the bar panel below so that each dendrogram leaf sits directly above
    # its sample's bar; with (1, n) here and (0, n+1) below, the two panels are on
    # different x scales and leaves drift away from their bars.
    xlims = (0, length(samples) + 1),
    title = "read-classification - $(taxon_level)",
    xticks = false,
    yticks = false,
    yaxis = false,
)
taxonomy_plot = StatsPlots.groupedbar(
    # rows (samples) reordered to follow the dendrogram leaves
    normalized_values[clustering.order, :];
    bar_position = :stack,
    bottommargin = 50 * StatsPlots.Plots.PlotMeasures.px,
    leftmargin = 50 * StatsPlots.Plots.PlotMeasures.px,
    xticks = (1:length(samples), samples[clustering.order]),
    xtickfontsize = 4,
    xrotation = 90,
    xlims = (0, length(samples) + 1),
    size = aspect_ratio .* scale,
    ylabel = "proportion of reads",
    seriescolor = permutedims(reverse(colorscheme)),
    # the legend is switched off in this panel
    legend = false,
    labels = permutedims(taxa),
    legendfontsize = 6,
)

# TODO: add a separate subplot that acts as the legend and lay it out beside the main plot.

# Stack the two panels vertically; the dendrogram takes the top 20% of the height.
plot = StatsPlots.plot(
    cluster_plot,
    taxonomy_plot,
    layout = StatsPlots.grid(2, 1, heights=[0.2, 0.8]),
)
# Save the figure once per output format.
for extension in (".png", ".svg")
    file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(taxon_index).$(taxon_level).normalized-reads.clustered") * extension
    StatsPlots.savefig(plot, file)
end
# end

In [None]:
# for g in DataFrames.groupby(cross_sample_taxon_report_table, "sample_identifier")
#     # @show sum(g[!, "percentage_of_fragments_at_or_below_taxon"])
#     if sum(g[!, "percentage_of_fragments_at_or_below_taxon"]) == 100
#     else
#         identifier = g[1, "sample_identifier"]
#         println("\"$(identifier)\",")
#     end
# end

In [None]:
# [number_of_fragments_at_or_below_taxon	ncbi_taxonid	scientific_name	sample_identifier

In [None]:

# Mycelia.list_rank("kingdom")
# Mycelia.list_rank("phylum")
# Mycelia.list_rank("class")
# Mycelia.list_rank("order")
# Mycelia.list_rank("family")
# Mycelia.list_rank("genus")
# Mycelia.list_rank("species")

- loop through the above
- make a matrix where each row is a taxon rank and each column is a dataset
- make stacked barplots for each sample
- run PCA and k-means cluster to find meaningful groups
- repeat for mmseqs protein and blast nt