In [None]:
# If you hit plotting-library issues, try resetting LD_LIBRARY_PATH for the Julia kernel;
# it can be set in the kernel spec under ~/.local/share/jupyter/kernels/
# An unset variable counts as empty (indexing ENV directly would throw a KeyError).
@assert get(ENV, "LD_LIBRARY_PATH", "") == "" "LD_LIBRARY_PATH must be empty or unset for this kernel"
import Pkg
# Pkg.activate(;temp=true)
Pkg.activate("20240909.mapping-vs-assembly")
Pkg.add("Revise")
import Revise

# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
Pkg.develop(path="$(homedir())/workspace/Mycelia")
import Mycelia

pkgs = String[
    "DataFrames",
    "uCSV",
    "OrderedCollections",
    "CSV",
    "ProgressMeter",
    "StatsBase"
    # "XAM",
    # "CodecZlib"
]
Pkg.add(pkgs)
# Build each `import <Package>` as an expression instead of parsing source text.
for pkg in pkgs
    @eval import $(Symbol(pkg))
end

In [None]:
# Project root is the parent of the notebook's working directory.
project_dir = pwd() |> dirname
# Ensure <project_root>/data exists; mkpath returns the path it was given.
data_dir = joinpath(project_dir, "data") |> mkpath

In [None]:
# Local copy of the BLAST "nt" database and its gzipped FASTA export.
db = "nt"
path_to_db = joinpath(homedir(), "workspace", "blastdb", db)
# To (re)create these instead of reusing the existing local copy:
# path_to_db = Mycelia.download_blast_db(db=db, source="ncbi")
# compressed_fasta_export = Mycelia.export_blast_db(path_to_db = path_to_db)
# NOTE(review): assumes the export already exists next to the db at this path — confirm.
compressed_fasta_export = string(path_to_db, ".fna.gz")

In [None]:
# Full paths of every entry directly under <data_dir>/SRA (one per SRA run, presumably).
sra_dirs = let sra_root = joinpath(data_dir, "SRA")
    readdir(sra_root; join = true)
end

In [None]:
# sort SRA dirs by size so that smallest jobs will run first

In [None]:
# Pair each SRA directory with the combined on-disk size of its trimmed read pair
# (trim_galore outputs `*_1_val_1.fq.gz` and `*_2_val_2.fq.gz`), then order the
# directories smallest-first so the quickest jobs run first.
# `first` throws if a directory is missing either trimmed read file.
sra_filesizes = Pair{String, Int64}[]
for sra_dir in sra_dirs
    trim_galore_dir_contents = readdir(joinpath(sra_dir, "trim_galore"), join=true)
    # Deliberately not named `reverse`: assigning that name in a top-level notebook loop
    # can collide with `Base.reverse`, which the plotting cell calls later.
    forward_reads = first(filter(f -> occursin(r"_1_val_1\.fq\.gz$", f), trim_galore_dir_contents))
    reverse_reads = first(filter(f -> occursin(r"_2_val_2\.fq\.gz$", f), trim_galore_dir_contents))
    push!(sra_filesizes, sra_dir => filesize(forward_reads) + filesize(reverse_reads))
end
size_sorted_sra_directories = first.(sort(sra_filesizes, by=last))

In [None]:
# Keep only the directories whose read mappings are complete, i.e. whose trim_galore/
# folder contains a minimap2 SAM produced against the NT index (size order is preserved).
sam_extension_regex = r"\.nt\.fna\.gz\.xsr\.I51G\.mmi\.minimap2\.sam\.gz$"
sra_directories_subset = filter(size_sorted_sra_directories) do d
    any(name -> occursin(sam_extension_regex, name), readdir(joinpath(d, "trim_galore")))
end

In [None]:
# visualize these in graphs
# for friday - compare to assembly outputs

In [None]:
# blastn has the original calls

In [None]:
# Full path of the first NT-mapped SAM file found in each selected directory's trim_galore/.
xams = map(sra_directories_subset) do d
    sam_paths = filter(x -> occursin(sam_extension_regex, x), readdir(joinpath(d, "trim_galore"), join=true))
    first(sam_paths)
end

In [None]:
# Sanity check: every selected SAM path must be a regular file on disk.
@assert all(isfile, xams)

In [None]:
# map to blast NT
blast_db = "nt"
blast_dbs_dir = joinpath([homedir(), "workspace", "blastdb"])
# path_to_db = joinpath(homedir(), "workspace", "blastdb", blast_db)
blast_db_path = joinpath(blast_dbs_dir, blast_db)

In [None]:
# write blast db taxonomy table to disk

In [None]:
# Export the sequence-id → taxid table for the BLAST db, then load it (timed as one step).
# NOTE(review): the export's return value is passed straight to the loader — presumably a file path.
@time blast_db_taxonomy_table = let exported = Mycelia.export_blast_db_taxonomy_table(path_to_db = blast_db_path)
    Mycelia.load_blast_db_taxonomy_table(exported)
end

In [None]:
# do a disk-based join to get an updated record table with taxids
# write that to disk

In [None]:
# Extract the taxid column from that data and get the unique set
# feed the unique set of taxa ids into getting a summarized lineage table from taxonkit
# write that to disk

# do a disk-based join of the record table (now with taxids) and the taxonkit taxonomy table to get taxa level table
# get a disk-based count of taxa by extracting that column by name and counting the unique values
# read that into a dictionary and continue the function as normal

In [None]:
taxa_level = "superkingdom"
# taxa_level = "family"
# taxa_level = "genus"
# taxa_level = "species"
# For each SAM file, the normalized per-taxon count of primary read mappings at `taxa_level`.
# Keys may be `missing` (the Dict key type allows it), presumably for records with no
# assignment at that rank.
file_to_taxa_relative_abundances = OrderedCollections.OrderedDict{String, Dict{Union{Missing, String}, Float64}}()
ProgressMeter.@showprogress for xam in xams
    # @time record_table = Mycelia.parse_xam_to_primary_mapping_table(xam)
    # # @time record_table = Mycelia.parse_xam_to_mapped_records_table(xam, primary_only=true)
    # # @time record_table = Mycelia.parse_xam_to_summary_table(xam)
    # # record_table = record_table[record_table[!, "isprimary"], :]
    # record_table = DataFrames.innerjoin(record_table, blast_db_taxonomy_table, on="reference" => "sequence_id")
    # unique_taxids = sort(unique(record_table[!, "taxid"]))
    # record_table = DataFrames.innerjoin(record_table, Mycelia.taxids2taxonkit_summarized_lineage_table(unique_taxids), on="taxid")

    # Step 1 (cached on disk): stream the SAM through samtools without its header, dropping
    # records flagged unmapped (4), secondary (256) or supplementary (2048) — 2308 in total —
    # and keep only SAM columns 1 (query name) and 3 (reference name) as gzipped TSV.
    query_ref_table_file = replace(xam, r"\.gz$" => ".query-ref.tsv.gz")
    if !isfile(query_ref_table_file) || (filesize(query_ref_table_file) == 0)
        # Write under a temporary name and rename only after `run` succeeds, so an
        # interrupted or failed pipeline cannot leave a truncated (non-empty) cache file
        # that the check above would then accept on the next run.
        tmp_query_ref_table_file = query_ref_table_file * ".tmp"
        p = pipeline(
                `gzip -dc $(xam)`,
                `$(Mycelia.CONDA_RUNNER) run --live-stream -n samtools samtools view --no-header --exclude-flags 2308 -`,
                `awk '{OFS="\t"}{print $1, $3}'`,
                `gzip`)
        # very long
        @time run(pipeline(p, tmp_query_ref_table_file))
        mv(tmp_query_ref_table_file, query_ref_table_file, force=true)
    end
    # kinda long; the do-block form closes the gzip stream and waits for the process,
    # throwing if gzip exits with an error
    @time query_ref_table = open(`gzip -dc $(query_ref_table_file)`) do io
        CSV.read(io, DataFrames.DataFrame, header = ["query", "reference"], delim='\t')
    end
    # long
    # attach a taxid to each mapping via its reference sequence id
    @time query_ref_table = DataFrames.innerjoin(query_ref_table, blast_db_taxonomy_table, on="reference" => "sequence_id")
    @time unique_taxids = sort(unique(query_ref_table[!, "taxid"]))
    # long
    # get just the two columns of interest to save memory
    @time taxa_summary_lineage_table = Mycelia.taxids2taxonkit_summarized_lineage_table(unique_taxids)[!, ["taxid", taxa_level]]
    @time query_ref_table = DataFrames.innerjoin(query_ref_table, taxa_summary_lineage_table, on="taxid")
    # count mappings per taxon at `taxa_level`, then normalize the counts
    @time file_to_taxa_relative_abundances[xam] = Mycelia.normalize_countmap(StatsBase.countmap(query_ref_table[!, taxa_level]))
    # report any globals of 1 GB or more, to keep an eye on memory use
    display(InteractiveUtils.varinfo(sortby=:size, minsize=Int(1e9)))
end
file_to_taxa_relative_abundances

In [None]:
# Taxa × samples matrix of relative abundances.
# Rows: every taxon label seen in any sample, sorted (`missing`, if present, sorts last).
# Columns: samples, in the insertion order of `file_to_taxa_relative_abundances`.
unique_sorted_taxa = let taxa_sets = (keys(abundances) for abundances in values(file_to_taxa_relative_abundances))
    unique(sort(collect(reduce(union, taxa_sets))))
end

n_samples = length(file_to_taxa_relative_abundances)

# row index of each taxon label
taxa_names_to_indices = Dict(taxon => row for (row, taxon) in pairs(unique_sorted_taxa))
# taxa absent from a sample stay at 0.0
abundance_matrix = zeros(length(unique_sorted_taxa), n_samples)
for (column, abundances) in enumerate(values(file_to_taxa_relative_abundances))
    for (taxon, fraction) in abundances
        abundance_matrix[taxa_names_to_indices[taxon], column] = fraction
    end
end
abundance_matrix

In [None]:


# Relabel samples by name, order them alphabetically, drop unwanted taxa, and
# renormalize each sample (column) so its remaining taxa sum to 1.
# NOTE(review): `sample_to_barcode_table` (with columns "xam" and "BioSampleName") is not
# defined in this notebook as shown — it must be loaded before running this cell.
file_to_identifier = Dict(row["xam"] => row["BioSampleName"] for row in DataFrames.eachrow(sample_to_barcode_table))
file_labels = [file_to_identifier[k] for k in keys(file_to_taxa_relative_abundances)]
# compute the label ordering once and apply it to both the columns and the labels
label_order = sortperm(file_labels)
abundance_matrix = abundance_matrix[:, label_order]
file_labels = file_labels[label_order]

# drop human and missing
# NOTE(review): "Homo" looks like a genus-level name; at other `taxa_level`s human reads
# presumably won't match this filter — confirm that is intended.
excluded_taxa = Set{Union{Missing, String}}(["Homo", missing])
filtered_indices = findall(x -> !(x in excluded_taxa), vec(unique_sorted_taxa))
unique_sorted_taxa = unique_sorted_taxa[filtered_indices]
abundance_matrix = abundance_matrix[filtered_indices, :]

# Calculate the sum of each column (sample)
col_sums = sum(abundance_matrix, dims=1)
# Normalize each element by dividing by its column sum. A sample with nothing left after
# filtering keeps all-zero proportions instead of turning into NaN (0/0).
abundance_matrix = abundance_matrix ./ ifelse.(col_sums .== 0, one(eltype(col_sums)), col_sums)

In [None]:
# # vaginal_indices
# indices = [1, 2, 5, 6]
# label = "vaginal"
# top_N = 10

# perianal_indices
indices = [3, 4, 7, 8]
label = "perianal"
top_N = 25

file_labels_subset = file_labels[indices]
abundance_matrix_subset = abundance_matrix[:, indices]

# order taxa by total abundance across the subset (ascending), so the most abundant are last
sort_perm = sortperm(vec(sum(abundance_matrix_subset, dims=2)))
unique_sorted_taxa_subset = unique_sorted_taxa[sort_perm]
abundance_matrix_subset = abundance_matrix_subset[sort_perm, :]
# drop taxa with no reads in any sample of this subset
non_zero_indices = findall(vec(sum(abundance_matrix_subset, dims=2)) .> 0.0)
unique_sorted_taxa_subset = unique_sorted_taxa_subset[non_zero_indices]
abundance_matrix_subset = abundance_matrix_subset[non_zero_indices, :]
colorscheme = Colors.distinguishable_colors(length(unique_sorted_taxa_subset), [Colors.RGB(1,1,1), Colors.RGB(0,0,0)], dropseed=true)

# Show at most top_N taxa, but never more than exist: coarse levels such as
# superkingdom have only a handful, and indexing `end-(top_N-1)` would go out of bounds.
n_taxa = length(unique_sorted_taxa_subset)
n_shown = min(top_N, n_taxa)
shown = (n_taxa - n_shown + 1):n_taxa

StatsPlots.groupedbar(
    abundance_matrix_subset'[:, shown],
    bar_position = :stack,
    bar_width=0.7,
    label = permutedims(string.(unique_sorted_taxa_subset[shown])),
    # labels in column order, so each tick matches its bar
    xticks = (1:length(file_labels_subset), file_labels_subset),
    xrotation = 45,
    ylabel = "proportion of reads",
    xlabel = "$(label) sample",
    # title = "$(taxa_level) relative abundance (top $(top_N-2))",
    title = "$(taxa_level) relative abundance (top $(n_shown) classified and non-human)",
    legend = :outertopright,
    # legend = false,
    size = (1000, 500),
    margins = 15StatsPlots.Plots.PlotMeasures.mm,
    # Base-qualified so a stray global `reverse` cannot shadow it
    seriescolor = permutedims(Base.reverse(colorscheme)[shown])
)