The objective of this notebook is to:
- join all of the Kraken reports with the sample metadata
- subset to samples that have metadata
- collapse participants P3 and beyond into a single "Others" group
- show the time series of taxonomic composition at each taxonomic level for each participant group (P1, P2, and Others)

In [None]:
import Pkg
# Third-party packages this notebook depends on.
# Uncomment the `Pkg.add` line below to install them into the active environment.
pkgs = [
    "Revise",
    "DataFrames",
    "StatsBase",
    "StatsPlots",
    "uCSV",
    "ProgressMeter",
    "Distances",
    "Colors",
    "MultivariateStats",
    "Dates",
    "CategoricalArrays",
    "Statistics"
]
# Pkg.add(pkgs)
# Import each package by name. The package name is interpolated as a Symbol into an
# `import` expression rather than parsed from a string.
for pkg_name in pkgs
    @eval import $(Symbol(pkg_name))
end
import Mycelia

In [None]:
# `data` directory that sits next to the notebook's working directory (in its parent).
data_dir = joinpath(first(splitdir(pwd())), "data")

In [None]:
# Output directory for the figures written at the end of the notebook. It is created up
# front so `savefig` does not fail only after all of the plotting work has been done.
results_dir = joinpath(data_dir, "results")
mkpath(results_dir)
results_dir

In [None]:
# load in metadata
metadata_dir = joinpath(dirname(pwd()), "metadata")

exposome_environmental_data = DataFrames.DataFrame(uCSV.read(
    joinpath(metadata_dir, "metadata_exposome.rds.tsv"),
    delim='\t',
    header=1,
    typedetectrows=300
))

joint_sample_metadata = DataFrames.DataFrame(uCSV.read(
    joinpath(metadata_dir, "exposome/joint_sample_metadata.tsv"),
    delim='\t',
    header=1,
    typedetectrows=300
))

# "Library Name" is the join key below, so it must agree with the "LibraryName" column.
# `isequal` is used (rather than `==`) so that `missing` entries compare cleanly instead
# of producing a non-boolean.
isequal(joint_sample_metadata[!, "Library Name"], joint_sample_metadata[!, "LibraryName"]) ||
    error("columns \"Library Name\" and \"LibraryName\" in joint_sample_metadata.tsv disagree")

joint_metadata = DataFrames.innerjoin(
    joint_sample_metadata,
    exposome_environmental_data,
    on="Library Name" => "samplenames");

# recode P3 and beyond to "Others", since they don't have enough samples to do much analysis on
joint_metadata[!, "aownership"] = map(x -> x in ("P1", "P2") ? x : "Others", joint_metadata[!, "aownership"])

# Group samples by owner. `sort=true` makes the group order deterministic
# (sorted by owner label) instead of implementation-defined.
metadata_by_owner = DataFrames.groupby(joint_metadata, "aownership", sort=true);

In [None]:
sample_paths = sort(joinpath.(data_dir, "SRA", joint_metadata[!, "Run"]))
kraken_db = "k2_pluspfp"
kraken_db_regex = Regex("$(kraken_db)_\\d{8}")

"""
    find_kraken_report(sample_path)

Return the path of the `kraken-report.tsv` file in `sample_path/kraken` whose name matches
`kraken_db_regex`.

Throws an error naming the searched directory when no report matches. When several
match (e.g. the same database at different dates), the first in `readdir`'s sorted order
is returned.
"""
function find_kraken_report(sample_path)
    kraken_dir = joinpath(sample_path, "kraken")
    matching = filter(readdir(kraken_dir, join=true)) do path
        occursin(kraken_db_regex, path) && occursin(r"kraken-report\.tsv$", path)
    end
    isempty(matching) && error("no kraken report for database $(kraken_db) found in $(kraken_dir)")
    return first(matching)
end

kraken_reports = map(find_kraken_report, sample_paths)

# create a full joint table so that we can subset dynamically down below without needing to re-read all of them over and over again
joint_report_table = DataFrames.DataFrame()
ProgressMeter.@showprogress for kraken_report in kraken_reports
    report_table = Mycelia.read_kraken_report(kraken_report)
    # remember which report file each row came from
    report_table[!, "report"] .= basename(kraken_report)
    append!(joint_report_table, report_table)
end
# human-readable taxon label: "<ncbi_taxonid>_<scientific_name>"
joint_report_table[!, "taxon"] = string.(joint_report_table[!, "ncbi_taxonid"]) .* "_" .* joint_report_table[!, "scientific_name"]
# sample identifier = report file name up to the first '.'
joint_report_table[!, "sample_identifier"] = string.(first.(split.(joint_report_table[!, "report"], '.')))
joint_report_table

In [None]:
# Ranks available for analysis; `rank_level` below indexes into this list.
taxon_levels = Mycelia.list_ranks()
# 10239 is the NCBI Taxonomy id for Viruses; collect its subtaxa for viral filtering below.
viral_tax_ids = Mycelia.list_subtaxa(10239)

# 1-based index into `taxon_levels` choosing the rank to analyze.
# Change this value (e.g. 1 through 8) to re-run the notebook at another rank.
rank_level = 8

taxon_index = rank_level
taxon_level = collect(taxon_levels)[rank_level]
println("$(taxon_index) - $(taxon_level)")
rank_table = Mycelia.list_rank(taxon_level)

In [None]:
# Choose the taxids to keep at the selected rank:
# - deeper ranks (rank_level > 2): only taxids that are also viral subtaxa
# - the top two ranks: every taxid listed for the rank
if rank_level > 2
    filtered_rank_table = rank_table[in.(rank_table[!, "taxid"], Ref(Set(viral_tax_ids))), :]
    filtered_tax_ids = Set(filtered_rank_table[!, "taxid"])
else
    filtered_tax_ids = Set(rank_table[!, "taxid"])
end

In [None]:
# cross_sample_taxon_report = joinpath(results_dir, "$(kraken_db).$(taxon_level).ictv.tsv")
# cross_sample_taxon_figure_png = joinpath(results_dir, "$(kraken_db).$(taxon_level).ictv.png")

In [None]:
# Keep only report rows for the selected taxids, dropping columns not used downstream.
cross_sample_taxon_report_summary = joint_report_table[
    in.(joint_report_table[!, "ncbi_taxonid"], Ref(filtered_tax_ids)),
    DataFrames.Not(["percentage_of_fragments_at_or_below_taxon", "number_of_fragments_assigned_directly_to_taxon", "rank"]),
]
# Order rows by sample, then taxon, and drop exact duplicate rows (expected to remove nothing).
DataFrames.sort!(cross_sample_taxon_report_summary, ["sample_identifier", "taxon"])
unique!(cross_sample_taxon_report_summary)
# Discard taxa with no fragments, then look at the log2 distribution of the remaining counts.
cross_sample_taxon_report_summary = cross_sample_taxon_report_summary[
    cross_sample_taxon_report_summary.number_of_fragments_at_or_below_taxon .> 0, :]
StatsPlots.histogram(log2.(cross_sample_taxon_report_summary.number_of_fragments_at_or_below_taxon))

In [None]:
# Require at least 3 fragments for a taxon to count as detected in a sample.
cross_sample_taxon_report_summary = cross_sample_taxon_report_summary[
    cross_sample_taxon_report_summary.number_of_fragments_at_or_below_taxon .>= 3, :]
# Reduce identifiers to the text before the first '.' (they are joined against `Run` later).
cross_sample_taxon_report_summary.sample_identifier =
    string.(first.(split.(cross_sample_taxon_report_summary.sample_identifier, '.')))

# Rank taxa by their mean fragment count across samples, most abundant first.
sorted_taxa_counts_table = DataFrames.combine(
    DataFrames.groupby(
        cross_sample_taxon_report_summary[!, ["number_of_fragments_at_or_below_taxon", "taxon"]],
        "taxon"),
    "number_of_fragments_at_or_below_taxon" => Statistics.mean,
)
sort!(sorted_taxa_counts_table, "number_of_fragments_at_or_below_taxon_mean", rev=true)

In [None]:
# Give every taxon one fixed color, distinguishable from each other and from the white and
# black seed colors, so a taxon has the same color in every participant plot.
unique_taxa = sorted_taxa_counts_table.taxon
colorscheme = Colors.distinguishable_colors(
    length(unique_taxa),
    [Colors.RGB(1, 1, 1), Colors.RGB(0, 0, 0)];
    dropseed = true,
)
taxa_to_color = Dict(zip(unique_taxa, colorscheme))

In [None]:
# Maximum number of taxa drawn per plot (the ones with the highest per-sample maximum).
top_n = 60

"""
    participant_count_matrix(participant_table)

Pivot `participant_table` into a samples × taxa matrix of
`number_of_fragments_at_or_below_taxon`.

Rows are samples ordered by (`date.end`, `Run`), so that they line up with the date tick
labels. Columns are taxa in sorted order. Throws an error if a (`Run`, `taxon`) pair
appears more than once.

Returns `(samples, sample_dates, taxa, counts)`.
"""
function participant_count_matrix(participant_table)
    sample_order = sort(unique(participant_table[!, ["date.end", "Run"]]))
    samples = sample_order[!, "Run"]
    sample_dates = sample_order[!, "date.end"]
    taxa = sort(unique(participant_table[!, "taxon"]))
    sample_rows = Dict(sample => i for (i, sample) in enumerate(samples))
    taxon_columns = Dict(taxon => j for (j, taxon) in enumerate(taxa))
    counts = zeros(length(samples), length(taxa))
    for group in DataFrames.groupby(participant_table, ["Run", "taxon"])
        run_id, taxon = group[1, "Run"], group[1, "taxon"]
        DataFrames.nrow(group) == 1 || error(
            "expected one row for run $(run_id) and taxon $(taxon), found $(DataFrames.nrow(group))")
        counts[sample_rows[run_id], taxon_columns[taxon]] = group[1, "number_of_fragments_at_or_below_taxon"]
    end
    return samples, sample_dates, taxa, counts
end

"""
    select_plot_taxa(counts, taxa, top_n)

Order the taxa (columns) ascending by their maximum count across samples. Drop taxa whose
column total is below 1, then keep only the last `top_n` (highest-maximum) taxa.

Returns `(counts, taxa)` with the matrix columns and the labels kept in the same order.
"""
function select_plot_taxa(counts, taxa, top_n)
    ordering = sortperm(maximum.(eachcol(counts)))
    counts, taxa = counts[:, ordering], taxa[ordering]
    is_detected = [sum(col) >= 1 for col in eachcol(counts)]
    counts, taxa = counts[:, is_detected], taxa[is_detected]
    if size(counts, 2) > top_n
        # labels must be truncated together with the columns, or they would be misassigned
        keep = (size(counts, 2) - top_n + 1):size(counts, 2)
        counts, taxa = counts[:, keep], taxa[keep]
    end
    return counts, taxa
end

"""
    save_participant_barplot(values, samples, sample_dates, taxa, participant, ylabel, file_suffix; extensions=(".png",))

Draw a stacked bar chart with one bar per sample (row of `values`) and one series per taxon
(column of `values`). Display it, then save it under `results_dir` once per extension.

Reads the notebook globals `kraken_db`, `taxon_index`, `taxon_level`, `results_dir`, and
`taxa_to_color`. Returns the plot.
"""
function save_participant_barplot(values, samples, sample_dates, taxa, participant, ylabel, file_suffix; extensions=(".png",))
    px = StatsPlots.Plots.PlotMeasures.px
    bottommargin = maximum(length.(samples)) * 5
    width = max(1920, (size(values, 1) * 12) + 300)
    height = max(1080, bottommargin + 600, size(values, 2) * 11)
    figure = StatsPlots.groupedbar(
        values,
        title = "read-classification - $(participant) - kraken - $(kraken_db) - $(taxon_level)",
        xticks = (1:length(samples), sample_dates),
        xlims = (0, length(samples)+1),
        size = (width, height),
        xrotation = 90,
        ylabel = ylabel,
        labels = permutedims(taxa),
        leftmargin = 150px,
        topmargin = 25px,
        rightmargin = 25px,
        bottommargin = bottommargin * px,
        legendmargins = 0,
        legend = :outertopright,
        bar_position = :stack,
        bar_width = 0.7,
        seriescolor = permutedims([taxa_to_color[t] for t in taxa]),
    )
    display(figure)
    for extension in extensions
        file = joinpath(results_dir, "taxonomic-breakdowns.kraken.$(kraken_db).$(taxon_index).$(taxon_level).by-participant.$(participant).$(file_suffix)") * extension
        StatsPlots.savefig(figure, file)
    end
    return figure
end

date_format = Dates.DateFormat("yyyy-mm-dd")

# One pair of plots (total reads, proportion of reads) per owner group.
for owner_metadata in metadata_by_owner
    participant_table = DataFrames.innerjoin(
        owner_metadata,
        cross_sample_taxon_report_summary,
        on="Run" => "sample_identifier"
    )
    if DataFrames.nrow(participant_table) == 0
        @warn "skipping participant group with no taxa passing the filters" participant = first(owner_metadata[!, "aownership"])
        continue
    end

    participant_table = participant_table[!, [
        "aownership",
        "season",
        "geo_loc_name",
        "weekend",
        "temperature",
        "humid",
        "particle",
        "Run",
        "date.start",
        "date.end",
        "ncbi_taxonid",
        "scientific_name",
        "taxon",
        "number_of_fragments_at_or_below_taxon",
        ]]

    participant_table[!, "date.start"] = Dates.Date.(participant_table[!, "date.start"], Ref(date_format))
    participant_table[!, "date.end"] = Dates.Date.(participant_table[!, "date.end"], Ref(date_format))

    sort!(participant_table, "date.start")

    # timing relative to the participant's first sampling start date
    first_start = first(participant_table[!, "date.start"])
    participant_table[!, "date.start_relative"] = participant_table[!, "date.start"] .- first_start
    participant_table[!, "date.end_relative"] = participant_table[!, "date.end"] .- first_start
    participant_table[!, "duration"] = participant_table[!, "date.end"] .- participant_table[!, "date.start"]

    participant = participant_table[1, "aownership"]

    samples, sample_dates, taxa, counts = participant_count_matrix(participant_table)
    counts, taxa = select_plot_taxa(counts, taxa, top_n)

    # Per-sample proportions among the plotted taxa. Counts are non-negative, so a zero row
    # total means an all-zero row; dividing it by 1 yields zeros instead of NaN.
    row_totals = sum(counts, dims=2)
    normalized_counts = counts ./ map(total -> iszero(total) ? one(total) : total, row_totals)

    save_participant_barplot(log10.(counts .+ 1), samples, sample_dates, taxa, participant,
        "log10(number of reads)", "total-reads")
    save_participant_barplot(normalized_counts, samples, sample_dates, taxa, participant,
        "proportion of reads", "normalized-reads")
end