In [None]:
# run(`conda create --channel conda-forge --channel bioconda --channel defaults --strict-channel-priority --name blast blast`)
# run(`conda create --channel conda-forge --channel bioconda --channel defaults --strict-channel-priority --name taxonkit taxonkit`)

In [None]:
# Load the third-party packages used throughout this notebook.
import Pkg
pkgs = [
    "Revise",
    "DataFrames",
    "uCSV",
    "StatsPlots",
    "StatsBase",
    "FreqTables",
    "Conda",
    "ProgressMeter",
    "PrettyTables",
    "Distances",
    "Statistics",
    "Kmers"
]
# Uncomment to install the packages listed above.
# Pkg.add(pkgs)
# `import` (rather than `using`) binds only the module name, which is why every
# call below is fully qualified (e.g. `DataFrames.DataFrame`). Each statement is
# built from the package name as a string and evaluated at top level.
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
# Mycelia is imported separately; it is not in `pkgs`, so the commented-out
# `Pkg.add(pkgs)` call above would not install it.
import Mycelia

In [None]:
# Project directories, resolved relative to the notebook's working directory.
base_dir = dirname(pwd())
data_dir = joinpath(base_dir, "data")
results_dir = joinpath(data_dir, "results")

# load in metadata
metadata_dir = joinpath(base_dir, "metadata")

exposome_environmental_data = DataFrames.DataFrame(uCSV.read(
    joinpath(metadata_dir, "metadata_exposome.rds.tsv"),
    delim='\t',
    header=1,
    typedetectrows=300
))

joint_sample_metadata = DataFrames.DataFrame(uCSV.read(
    joinpath(metadata_dir, "exposome/joint_sample_metadata.tsv"),
    delim='\t',
    header=1,
    typedetectrows=300
))

# The join below keys on "Library Name", so the duplicate "LibraryName" column
# must agree with it. This validates file contents, so raise an explicit error
# rather than relying on `@assert` (intended for internal invariants only).
if joint_sample_metadata[!, "Library Name"] != joint_sample_metadata[!, "LibraryName"]
    error("joint_sample_metadata.tsv: columns \"Library Name\" and \"LibraryName\" do not match")
end

# Inner join: only samples present in both tables are kept.
joint_metadata = DataFrames.innerjoin(
    joint_sample_metadata,
    exposome_environmental_data,
    on="Library Name" => "samplenames")

# One directory per sequencing run, in the row order of `joint_metadata`.
sample_directories = joinpath.(data_dir, "SRA", joint_metadata[!, "Run"])

In [None]:
# Viral contig calls collected per tool: tool/database name => (sample id => set of contig ids).
# Later cells add one entry per tool ("genomad", "blast", and one per UniRef database).
viral_contigs_by_tool = Dict{String, Dict{String, Set{String}}}()

In [None]:
# Stack the genomad virus summary of every sample into one table,
# tagging each row with the sample (run directory name) it came from.
joint_genomad_results = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_directory in sample_directories
    summary_path = joinpath(
        sample_directory,
        "genomad",
        "final.contigs.fastg.gfa_summary",
        "final.contigs.fastg.gfa_virus_summary.tsv"
    )
    parsed_summary = uCSV.read(summary_path; delim='\t', header=1, typedetectrows=100)
    sample_table = DataFrames.DataFrame(parsed_summary...)
    sample_table[!, "sample_id"] .= basename(sample_directory)
    # promote=true lets a column's element type widen when samples disagree
    append!(joint_genomad_results, sample_table; promote=true)
end
# store contig names as strings regardless of the detected column type
joint_genomad_results[!, "seq_name"] = map(string, joint_genomad_results[!, "seq_name"])
joint_genomad_results

In [None]:
# For each sample, record the set of contig names reported by genomad.
viral_contigs_by_tool["genomad"] = Dict{String, Set{String}}(
    gdf[1, "sample_id"] => Set{String}(gdf[!, "seq_name"])
    for gdf in DataFrames.groupby(joint_genomad_results, "sample_id")
)
viral_contigs_by_tool["genomad"]

In [None]:
# BLAST task name; it is interpolated into the report file name read by the BLAST loading cell.
blast_task = "megablast"

In [None]:
# Database name interpolated into the BLAST report file name.
# NOTE(review): later cells reassign `db` ("UniRef50", "UniRef90", "UniRef100"),
# so re-run this cell before re-running the BLAST loading cell.
db = "nt_viruses"

In [None]:
# load in metadata
metadata_dir = joinpath(dirname(pwd()), "metadata")

# Both tables are tab-delimited with a header row and share the same reader settings.
exposome_environmental_data, joint_sample_metadata = map(
    ["metadata_exposome.rds.tsv", "exposome/joint_sample_metadata.tsv"]
) do relative_path
    parsed = uCSV.read(joinpath(metadata_dir, relative_path); delim='\t', header=1, typedetectrows=300)
    DataFrames.DataFrame(parsed)
end

# the join below keys on "Library Name"; its duplicate column must agree with it
@assert joint_sample_metadata[!, "Library Name"] == joint_sample_metadata[!, "LibraryName"]

joint_metadata = DataFrames.innerjoin(
    joint_sample_metadata, exposome_environmental_data;
    on = "Library Name" => "samplenames"
)

# sorted run ids give a stable sample order for everything derived from `sample_paths`
run_ids = sort(joint_metadata[!, "Run"])

sample_paths = [joinpath(data_dir, "SRA", run_id) for run_id in run_ids]

In [None]:
# NCBI host metadata
ncbi_metadata_file = joinpath(dirname(pwd()), "metadata", "NCBI-virus-refseq.transformed.tsv")
ncbi_host_metadata = DataFrames.DataFrame(uCSV.read(ncbi_metadata_file, header=1, delim='\t', encodings=Dict("false" => false, "true" => true)))

# ICTV host metadata
ictv_metadata_file = joinpath(dirname(pwd()), "metadata", "VMR_MSL38_v1 - VMR MSL38 v1.transformed.tsv")
ictv_host_metadata = DataFrames.DataFrame(uCSV.read(ictv_metadata_file, header=1, delim='\t', typedetectrows=100))
ictv_host_metadata = ictv_host_metadata[.!isempty.(ictv_host_metadata[!, "taxid"]), :]
ictv_host_metadata[!, "taxid"] = parse.(Int, ictv_host_metadata[!, "taxid"])

viral_tax_ids = Set(Mycelia.list_subtaxa(10239))

In [None]:
# Load each sample's BLAST report and keep only the best-scoring hit per query contig.
joint_top_hits = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    sample_name = basename(sample_path)
    report_directory = mkpath(joinpath(sample_path, "blastn"))
    contigs_fasta = joinpath(sample_path, "megahit", "final.contigs.fastg.gfa.fna")
    report_name = basename(contigs_fasta) * ".blastn.$(db).$(blast_task).txt"
    hits = Mycelia.parse_blast_report(joinpath(report_directory, report_name))
    # nothing to record for a sample whose report has no hits
    isempty(hits) && continue

    hits[!, "sample_id"] .= sample_name
    # Bonferroni correction on raw tests: scale each e-value by the number of hits in this report
    hits[!, "evalue"] = hits[!, "evalue"] .* DataFrames.nrow(hits)

    # filter to top hits to avoid ballooning memory just to throw it away later
    best_hits = DataFrames.DataFrame()
    for query_hits in DataFrames.groupby(hits, "query id")
        ranked = sort(query_hits, "bit score", rev=true)
        push!(best_hits, first(ranked))
    end
    append!(joint_top_hits, best_hits)
end

In [None]:
# Attach a taxon name to every top hit via a taxid => name lookup.
taxids = unique(joint_top_hits[!, "subject tax id"])
taxid2name_map = Dict(
    lineage_row["taxid"] => lineage_row["tax_name"]
    for lineage_row in DataFrames.eachrow(Mycelia.taxids2lineage_name_and_rank(taxids))
)
joint_top_hits[!, "subject tax name"] = [taxid2name_map[taxid] for taxid in joint_top_hits[!, "subject tax id"]]

# filter to good hits even after Bonferroni correction
joint_top_hits = joint_top_hits[joint_top_hits[!, "evalue"] .<= 0.001, :]

# filter to viral only
viral_hits_df = joint_top_hits[in.(joint_top_hits[!, "subject tax id"], Ref(viral_tax_ids)), :]

# # current_host = "host_is_vertebrate"
# # current_host = "host_is_mammal"
# # current_host = "host_is_primate"
# current_host = "host_is_human"
# host_viral_tax_ids = Set(ncbi_host_metadata[ncbi_host_metadata[!, current_host] .== true, "taxid"])

# # host_viral_taxids = ictv_host_metadata[map(x -> x in ["vertebrates", "invertebrates, vertebrates"], ictv_host_metadata[!, "Host source"]), "taxid"]
# host_hits_df = viral_hits_df[map(x -> x in host_viral_tax_ids, viral_hits_df[!, "subject tax id"]), :]

# blast_viral_contigs = Dict{String, Set{String}}()
# for gdf in DataFrames.groupby(viral_hits_df, "sample_id")
#     sample_id = gdf[1, "sample_id"]
#     blast_viral_contigs[sample_id] = Set()
#     for row in DataFrames.eachrow(gdf)
#         push!(blast_viral_contigs[sample_id], row["query id"])
#     end
# end
# blast_viral_contigs

# For each sample, record the set of query contigs whose top BLAST hit is viral.
viral_contigs_by_tool["blast"] = Dict{String, Set{String}}(
    gdf[1, "sample_id"] => Set{String}(gdf[!, "query id"])
    for gdf in DataFrames.groupby(viral_hits_df, "sample_id")
)
viral_contigs_by_tool["blast"]

In [None]:
# Display the viral subset of the UniRef50 table.
# NOTE(review): `uniref50_viral_df` is first assigned in the next cell, so this
# cell only works after that cell has been run.
uniref50_viral_df

In [None]:
db = "UniRef50"
# Stack the per-sample mmseqs easy-taxonomy LCA tables for this database.
uniref50_df = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    lca_path = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv"
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_path)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref50_df, lca_table)
end
# keep only the rows whose assigned taxon is in the viral taxid set
uniref50_viral_df = uniref50_df[in.(uniref50_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# per sample, the set of contig ids (as strings) with a viral assignment
viral_contigs_by_tool[db] = Dict{String, Set{String}}(
    gdf[1, "sample_id"] => Set{String}(string.(gdf[!, "contig_id"]))
    for gdf in DataFrames.groupby(uniref50_viral_df, "sample_id")
)
viral_contigs_by_tool[db]

In [None]:
db = "UniRef90"
# Stack the per-sample mmseqs easy-taxonomy LCA tables for this database.
uniref90_df = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    lca_path = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv"
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_path)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref90_df, lca_table)
end
# keep only the rows whose assigned taxon is in the viral taxid set
uniref90_viral_df = uniref90_df[in.(uniref90_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# per sample, the set of contig ids (as strings) with a viral assignment
viral_contigs_by_tool[db] = Dict{String, Set{String}}(
    gdf[1, "sample_id"] => Set{String}(string.(gdf[!, "contig_id"]))
    for gdf in DataFrames.groupby(uniref90_viral_df, "sample_id")
)
viral_contigs_by_tool[db]

In [None]:
db = "UniRef100"
# Stack the per-sample mmseqs easy-taxonomy LCA tables for this database.
uniref100_df = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    lca_path = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv"
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_path)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref100_df, lca_table)
end
# keep only the rows whose assigned taxon is in the viral taxid set
uniref100_viral_df = uniref100_df[in.(uniref100_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# per sample, the set of contig ids (as strings) with a viral assignment
viral_contigs_by_tool[db] = Dict{String, Set{String}}(
    gdf[1, "sample_id"] => Set{String}(string.(gdf[!, "contig_id"]))
    for gdf in DataFrames.groupby(uniref100_viral_df, "sample_id")
)
viral_contigs_by_tool[db]

In [None]:
# Display the per-tool, per-sample viral contig sets collected by the cells above.
viral_contigs_by_tool

In [None]:
# Scratch check of `Distances.jaccard` on two small sets.
# NOTE(review): the analysis below uses the notebook's own `jaccard_similarity`
# instead; confirm that `Distances.jaccard` has the intended meaning for `Set`
# arguments before relying on this value.
Distances.jaccard(Set([1, 2, 3]), Set([1, 3, 4]))

In [None]:
# Fixed tool order used for the rows/columns of the per-tool summaries below.
ordered_tools = ["genomad", "blast", "UniRef50", "UniRef90", "UniRef100"]

# Sample ids, in the same order as `sample_paths`.
ordered_samples = map(basename, sample_paths)

In [None]:
"""
    jaccard_similarity(a, b)

Return the Jaccard similarity of collections `a` and `b`: the number of distinct
elements they share divided by the number of distinct elements in either.

The result lies in `[0, 1]`. Two empty collections are treated as identical and
give `1.0`, rather than the `NaN` that `0 / 0` would produce.
"""
function jaccard_similarity(a, b)
    union_size = length(union(a, b))
    # guard the 0/0 case so a NaN cannot leak into downstream means/medians
    union_size == 0 && return 1.0
    return length(intersect(a, b)) / union_size
end

In [None]:
# Pairwise Jaccard similarity between tools, computed separately for every sample.
# Dimensions are (tool, tool, sample), following `ordered_tools` and `ordered_samples`.
# An entry stays `missing` when either tool has no contig set recorded for that sample.
jaccard_similarity_matrix = Array{Union{Missing, Float64}}(missing, length(ordered_tools), length(ordered_tools), length(ordered_samples))
for (x, tool1) in enumerate(ordered_tools)
    for (y, tool2) in enumerate(ordered_tools)
        # Explicit key checks replace a blanket try/catch: only the expected
        # "no entry" case is skipped, and genuine errors are no longer hidden.
        (haskey(viral_contigs_by_tool, tool1) && haskey(viral_contigs_by_tool, tool2)) || continue
        contigs_by_sample1 = viral_contigs_by_tool[tool1]
        contigs_by_sample2 = viral_contigs_by_tool[tool2]
        for (z, sample_id) in enumerate(ordered_samples)
            (haskey(contigs_by_sample1, sample_id) && haskey(contigs_by_sample2, sample_id)) || continue
            jaccard_similarity_matrix[x, y, z] = jaccard_similarity(contigs_by_sample1[sample_id], contigs_by_sample2[sample_id])
        end
    end
end
jaccard_similarity_matrix

In [None]:
# Summarise the per-sample Jaccard similarities of every tool pair.
# A tool pair with no sample in common keeps NaN in all three matrices
# (`Statistics.median` would otherwise throw on the empty slice).
jaccard_similarity_matrix_means = fill(NaN, length(ordered_tools), length(ordered_tools))
jaccard_similarity_matrix_medians = fill(NaN, length(ordered_tools), length(ordered_tools))
jaccard_similarity_matrix_stddevs = fill(NaN, length(ordered_tools), length(ordered_tools))

# Iterate each tool axis by its own dimension instead of reusing dimension 1 for both.
for x in axes(jaccard_similarity_matrix, 1)
    for y in axes(jaccard_similarity_matrix, 2)
        # plain Float64 vector of the samples where both tools have an entry
        observed = collect(skipmissing(jaccard_similarity_matrix[x, y, :]))
        isempty(observed) && continue
        jaccard_similarity_matrix_means[x, y] = Statistics.mean(observed)
        jaccard_similarity_matrix_medians[x, y] = Statistics.median(observed)
        # NaN when only a single sample is available
        jaccard_similarity_matrix_stddevs[x, y] = Statistics.std(observed)
    end
end

In [None]:
# Number of viral contigs each tool reported in each sample.
# Every vector follows `ordered_samples`, so entries are comparable across tools.
hits_by_tool_per_sample = Dict{String, Vector{Int}}()
for tool in ordered_tools
    contigs_by_sample = viral_contigs_by_tool[tool]
    # A sample only gets an entry when the tool reported at least one viral
    # contig for it, so a missing entry counts as 0 instead of being dropped
    # (dropping zero-hit samples would inflate the per-tool mean).
    hits_by_tool_per_sample[tool] = Int[haskey(contigs_by_sample, sample_id) ? length(contigs_by_sample[sample_id]) : 0 for sample_id in ordered_samples]
end

In [None]:
# Display the per-tool vectors of viral contig counts.
hits_by_tool_per_sample

In [None]:
# Scratch arithmetic (evaluates to 960.0); 960 matches the width in `size = (960, 540)` used by the plots below.
1920/2

In [None]:
# Mean number of viral contigs per tool, with the standard deviation as error bars.
let mean_hits = map(tool -> Statistics.mean(hits_by_tool_per_sample[tool]), ordered_tools),
    stddev_hits = map(tool -> Statistics.std(hits_by_tool_per_sample[tool]), ordered_tools)

    StatsPlots.plot(
        ordered_tools,
        mean_hits;
        yerror = stddev_hits,
        legend = false,
        title = "mean +/- stddev # viral contigs classified by tool\nacross samples",
        xlabel = "tool",
        ylabel = "value",
        size = (960, 540),
        margins = 5 * StatsPlots.Plots.PlotMeasures.mm
    )
end

In [None]:
hits_by_tool_per_sample

# Heatmap of the mean pairwise Jaccard similarity between tools.
let tool_ticks = (1:length(ordered_tools), ordered_tools)
    StatsPlots.heatmap(
        jaccard_similarity_matrix_means;
        xticks = tool_ticks,
        yticks = tool_ticks,
        title = "Mean Jaccard Similarity",
        xlabel = "tool/database",
        ylabel = "tool/database",
        size = (960, 540),
        margins = 5 * StatsPlots.Plots.PlotMeasures.mm
    )
end


In [None]:
# Heatmap of the mean pairwise Jaccard similarity between tools.
let tool_ticks = (1:length(ordered_tools), ordered_tools)
    StatsPlots.heatmap(
        jaccard_similarity_matrix_means;
        xticks = tool_ticks,
        yticks = tool_ticks,
        title = "Mean Jaccard Similarity",
        xlabel = "tool/database",
        ylabel = "tool/database",
        size = (960, 540),
        margins = 5 * StatsPlots.Plots.PlotMeasures.mm
    )
end

In [None]:
# Heatmap of the per-pair standard deviation of the Jaccard similarity across samples.
let tool_ticks = (1:length(ordered_tools), ordered_tools)
    StatsPlots.heatmap(
        jaccard_similarity_matrix_stddevs;
        xticks = tool_ticks,
        yticks = tool_ticks,
        title = "Standard Deviation of Jaccard Similarity",
        xlabel = "tool/database",
        ylabel = "tool/database",
        size = (960, 540),
        margins = 5 * StatsPlots.Plots.PlotMeasures.mm
    )
end

In [None]:
# Contigs that every tool in `ordered_tools` called viral, per sample.
#
# The previous loop intersected the genomad set with itself on every pass, so
# the result was just the genomad calls; each tool's own set is used here.
unanimous_hits = Dict{String, Set{String}}()
for sample in ordered_samples
    # A tool only has an entry for a sample when it reported at least one viral
    # contig there, so a missing entry is an empty set (nothing can be unanimous).
    # A tool that was never loaded still raises a KeyError, on purpose.
    per_tool_contigs = [get(viral_contigs_by_tool[tool], sample, Set{String}()) for tool in ordered_tools]
    # `intersect` returns a new set, so the per-tool sets are left untouched.
    unanimous_hits[sample] = intersect(per_tool_contigs...)
end
unanimous_hits

In [None]:
# Restrict the viral BLAST hits to contigs listed in `unanimous_hits` for their sample.
unanimous_blast_hits = viral_hits_df[
    map(hit -> hit["query id"] in unanimous_hits[hit["sample_id"]], DataFrames.eachrow(viral_hits_df)),
    :
]

# Taxids flagged in the NCBI host metadata column named by `current_host`.
current_host = "host_is_human"
host_viral_tax_ids = Set(
    ncbi_host_metadata[ncbi_host_metadata[!, current_host] .== true, "taxid"]
)

# Keep only the unanimous hits whose subject taxid is in that host set.
unanimous_blast_host_hits = unanimous_blast_hits[
    in.(unanimous_blast_hits[!, "subject tax id"], Ref(host_viral_tax_ids)),
    :
]

In [None]:
# Build a per-sample summary of the unanimous host-virus BLAST hits and attach sample metadata.

# Cross-tabulate the hits by sample id and subject taxon name.
ft = FreqTables.freqtable(unanimous_blast_host_hits, "sample_id", "subject tax name")
# Unpack the table into column vectors: the first-dimension labels, then one vector per array column.
# NOTE(review): this reads the table's `dicts`/`array`/`dimnames` fields directly and assumes
# `keys(ft.dicts[i])` iterate in the same order as the array's rows/columns — confirm.
data = [collect(keys(ft.dicts[1])), [col for col in eachcol(ft.array)]...]
# The first column is named after both dimensions; the join and reorder below refer to it by the
# literal "sample_id \ subject tax name", so the dimension names are expected to be those two strings.
header = ["$(ft.dimnames[1]) \\ $(ft.dimnames[2])", collect(keys(ft.dicts[2]))...]
summary_table = DataFrames.DataFrame(data, header)
# Attach selected metadata columns by matching the sample id against `Run`
# (inner join: only samples present in both tables are kept).
summary_table = DataFrames.innerjoin(summary_table, joint_metadata[!, ["Run", "aownership", "geo_loc_name", "date.end"]], on="sample_id \\ subject tax name" => "Run")
# Move the id and metadata columns to the front; the remaining columns follow in their existing order.
summary_table = summary_table[!, [
    ["sample_id \\ subject tax name", "aownership", "geo_loc_name", "date.end"]...,
    setdiff(names(summary_table), ["sample_id \\ subject tax name", "aownership", "geo_loc_name", "date.end"])...]]
# Rename the metadata columns in place; the renamed table is the cell's displayed value.
DataFrames.rename!(
    summary_table,
    ["aownership" => "participant", "geo_loc_name" => "location", "date.end" => "collection date"]
)

In [None]:
# Display the unanimous BLAST hits restricted to the selected host's viral taxids.
unanimous_blast_host_hits

In [None]:
# ProgressMeter.@showprogress for sample_path in sample_paths[1:end]





# qualimap_coverage_table = parse_qualimap_contig_coverage(joinpath(SRR_path, "megahit", "qualimap", "genome_results.txt"))
# mmseqs_lca_files = filter(x -> occursin("_lca.tsv", x) && occursin("final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy", x), readdir(joinpath(SRR_path, "mmseqs_easy_taxonomy"), join=true))

# # mmseqs_lca_file = first(mmseqs_lca_files)
# for mmseqs_lca_file in mmseqs_lca_files

#     parse_mmseqs_easy_taxonomy_lca_tsv(mmseqs_lca_file)
#     lca_table = parse_mmseqs_easy_taxonomy_lca_tsv(mmseqs_lca_file)


# sample = basename(sample_path)
# blastn_directory = mkpath(joinpath(sample_path, "blastn"))
# assembled_fasta = joinpath(sample_path, "megahit", "final.contigs.fastg.gfa.fna")
# blast_file = joinpath(blastn_directory, basename(assembled_fasta) * ".blastn.$(db).$(blast_task).txt")
# this_blast_table = Mycelia.parse_blast_report(blast_file)

In [None]:
# results_dir = joinpath(data_dir, "results")
# # readdir(results_dir)

# uCSV.write(joinpath(results_dir, "blast_hits_summary_table.csv"), summary_table)

In [None]:
# results_dir

In [None]:
# m = "text/plain"
# m = "text/html"
# m =  "text/latex"
# m = "text/csv"
# m = "text/tab-separated-values"

# show(stdout, MIME(m), summary_table)

In [None]:
# # show(stdout, MIME("text/html"), )
# # PrettyTables.pretty_table(summary_table, backend = Val(:markdown))
# # PrettyTables.pretty_table(summary_table, backend = Val(:latex))
# # PrettyTables.pretty_table(summary_table, backend = Val(:html))
# PrettyTables.pretty_table(summary_table, backend = Val(:text))

In [None]:
# show(stdout, "text/plain", matrix)