In [None]:
# run(`conda create --channel conda-forge --channel bioconda --channel defaults --strict-channel-priority --name blast blast`)
# run(`conda create --channel conda-forge --channel bioconda --channel defaults --strict-channel-priority --name taxonkit taxonkit`)
# run(`conda create --channel conda-forge --channel bioconda --channel defaults --strict-channel-priority --name minimap2 minimap2`)

In [None]:
# Don't try to install plotting libraries unless LD_LIBRARY_PATH is empty.
# It can be set in the kernel spec under ~/.local/share/jupyter/kernels/
# An unset variable is accepted too: indexing ENV directly would raise a KeyError
# before the check could even run.
@assert isempty(get(ENV, "LD_LIBRARY_PATH", "")) "LD_LIBRARY_PATH must be empty or unset"
import Pkg
pkgs = [
    "Revise",
    "DataFrames",
    "uCSV",
    "StatsPlots",
    "StatsBase",
    "FreqTables",
    "Conda",
    "ProgressMeter",
    "PrettyTables",
    "Distances",
    "Statistics",
    "Kmers",
    "Colors",
    "FASTX"
]
# Pkg.add(pkgs)
# Import (rather than `using`) every package so calls below stay module-qualified.
# The names come from the literal list above, never from external input.
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
import Mycelia

In [None]:
# Directory layout, rooted at the parent of the current working directory
base_dir = dirname(pwd())
data_dir = joinpath(base_dir, "data")
results_dir = joinpath(data_dir, "results")

# load in metadata
metadata_dir = joinpath(dirname(pwd()), "metadata")

# Exposome metadata table (tab-delimited, header on the first line)
exposome_environmental_data = let tsv = joinpath(metadata_dir, "metadata_exposome.rds.tsv")
    DataFrames.DataFrame(uCSV.read(tsv; delim='\t', header=1, typedetectrows=300))
end

# Per-sample metadata table (tab-delimited, header on the first line)
joint_sample_metadata = let tsv = joinpath(metadata_dir, "exposome/joint_sample_metadata.tsv")
    DataFrames.DataFrame(uCSV.read(tsv; delim='\t', header=1, typedetectrows=300))
end

# The two library-name columns must agree before one of them is used as the join key
@assert joint_sample_metadata[!, "Library Name"] == joint_sample_metadata[!, "LibraryName"]

# Keep only samples present in both tables
joint_metadata = DataFrames.innerjoin(
    joint_sample_metadata,
    exposome_environmental_data;
    on = "Library Name" => "samplenames",
)

# One directory per run under <data_dir>/SRA
sample_directories = [joinpath(data_dir, "SRA", run_id) for run_id in joint_metadata[!, "Run"]]

In [None]:
# Viral contig calls, keyed first by tool name and then by sample id; each leaf is the
# set of contig names that tool flagged for that sample. Filled in by the cells below.
viral_contigs_by_tool = Dict{String, Dict{String, Set{String}}}()

In [None]:
# Stack every sample's geNomad virus summary table into a single DataFrame,
# tagging each row with the sample directory name it came from.
joint_genomad_results = DataFrames.DataFrame()
# sample_directory = first(sample_directories)
ProgressMeter.@showprogress for sample_directory in sample_directories
    summary_tsv = joinpath(
        sample_directory,
        "genomad",
        "final.contigs.fastg.gfa_summary",
        "final.contigs.fastg.gfa_virus_summary.tsv",
    )
    sample_table = DataFrames.DataFrame(
        uCSV.read(summary_tsv; delim='\t', header=1, typedetectrows=100)...,
    )
    sample_table[!, "sample_id"] .= basename(sample_directory)
    # promote=true lets column types widen when samples disagree
    append!(joint_genomad_results, sample_table; promote=true)
end
# Normalise contig names to String regardless of what type detection produced
joint_genomad_results[!, "seq_name"] = map(string, joint_genomad_results[!, "seq_name"])
joint_genomad_results

In [None]:
# Per-sample sets of contig names listed in the geNomad virus summaries
viral_contigs_by_tool["genomad"] = Dict{String, Set{String}}()
for sample_group in DataFrames.groupby(joint_genomad_results, "sample_id")
    # every row of a group shares the same sample_id, so the first one is the key
    group_sample = first(sample_group[!, "sample_id"])
    viral_contigs_by_tool["genomad"][group_sample] = Set{String}(sample_group[!, "seq_name"])
end
viral_contigs_by_tool["genomad"]

In [None]:
# BLAST task name; it forms part of the blastn report filename that is read below
blast_task = "megablast"

In [None]:
# Database name; it forms part of the blastn report filename that is read below.
# NOTE: `db` is reassigned by the later UniRef cells, so re-run this cell before
# re-running the BLAST cell.
db = "nt_viruses"

In [None]:
# load in metadata
metadata_dir = joinpath(dirname(pwd()), "metadata")

# Exposome metadata table (tab-delimited, header on the first line)
exposome_environmental_data = let tsv_path = joinpath(metadata_dir, "metadata_exposome.rds.tsv")
    parsed = uCSV.read(tsv_path; delim='\t', header=1, typedetectrows=300)
    DataFrames.DataFrame(parsed)
end

# Per-sample metadata table (tab-delimited, header on the first line)
joint_sample_metadata = let tsv_path = joinpath(metadata_dir, "exposome/joint_sample_metadata.tsv")
    parsed = uCSV.read(tsv_path; delim='\t', header=1, typedetectrows=300)
    DataFrames.DataFrame(parsed)
end

# The two library-name columns must agree before one of them is used as the join key
@assert joint_sample_metadata[!, "Library Name"] == joint_sample_metadata[!, "LibraryName"]

# Keep only samples present in both tables
joint_metadata = DataFrames.innerjoin(
    joint_sample_metadata,
    exposome_environmental_data;
    on = "Library Name" => "samplenames",
)

# Sorted run accessions give a stable sample order for everything downstream
run_ids = sort(joint_metadata[!, "Run"])

sample_paths = [joinpath(data_dir, "SRA", run_id) for run_id in run_ids]

In [None]:
# NCBI host metadata ("true"/"false" text is decoded to Bool while reading)
ncbi_metadata_file = joinpath(dirname(pwd()), "metadata", "NCBI-virus-refseq.transformed.tsv")
ncbi_host_metadata = DataFrames.DataFrame(
    uCSV.read(
        ncbi_metadata_file;
        header=1,
        delim='\t',
        encodings=Dict("false" => false, "true" => true),
    ),
)

# ICTV host metadata
ictv_metadata_file = joinpath(dirname(pwd()), "metadata", "VMR_MSL38_v1 - VMR MSL38 v1.transformed.tsv")
ictv_host_metadata = DataFrames.DataFrame(
    uCSV.read(ictv_metadata_file; header=1, delim='\t', typedetectrows=100),
)
# Drop rows with an empty taxid, then turn the remaining taxids into integers
ictv_host_metadata = filter("taxid" => !isempty, ictv_host_metadata)
ictv_host_metadata[!, "taxid"] = [parse(Int, taxid) for taxid in ictv_host_metadata[!, "taxid"]]

# Every taxon id returned by Mycelia.list_subtaxa for taxid 10239
viral_tax_ids = Set(Mycelia.list_subtaxa(10239))

In [None]:
# Collect, for every sample, the best-scoring BLAST hit of each query contig
joint_top_hits = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    sample_name = basename(sample_path)
    blastn_directory = mkpath(joinpath(sample_path, "blastn"))
    contigs_fasta = joinpath(sample_path, "megahit", "final.contigs.fastg.gfa.fna")
    report_name = string(basename(contigs_fasta), ".blastn.", db, ".", blast_task, ".txt")
    blast_hits = Mycelia.parse_blast_report(joinpath(blastn_directory, report_name))
    # nothing to contribute from a sample without hits
    isempty(blast_hits) && continue

    blast_hits[!, "sample_id"] .= sample_name
    # Bonferroni correction on raw tests: scale each e-value by the number of hits
    n_tests = DataFrames.nrow(blast_hits)
    blast_hits[!, "evalue"] = blast_hits[!, "evalue"] .* n_tests

    # filter to top hits to avoid ballooning memory just to throw it away later
    sample_top_hits = DataFrames.DataFrame()
    for query_hits in DataFrames.groupby(blast_hits, "query id")
        ranked = sort(query_hits, "bit score", rev=true)
        push!(sample_top_hits, first(ranked))
    end
    append!(joint_top_hits, sample_top_hits)
end

In [None]:
# Attach a taxon name to every top hit
taxids = unique(joint_top_hits[!, "subject tax id"])
taxid2name_map = Dict(
    lineage_row["taxid"] => lineage_row["tax_name"]
    for lineage_row in DataFrames.eachrow(Mycelia.taxids2lineage_name_and_rank(taxids))
)
joint_top_hits[!, "subject tax name"] =
    [taxid2name_map[taxid] for taxid in joint_top_hits[!, "subject tax id"]]

# filter to good hits even after bonferroni correction
joint_top_hits = filter("evalue" => (evalue -> evalue <= 0.001), joint_top_hits)

# filter to viral only
viral_hits_df = joint_top_hits[in.(joint_top_hits[!, "subject tax id"], Ref(viral_tax_ids)), :]

# Per-sample sets of query contigs whose top hit is a viral taxon
viral_contigs_by_tool["blast"] = Dict{String, Set{String}}()
for sample_group in DataFrames.groupby(viral_hits_df, "sample_id")
    group_sample = first(sample_group[!, "sample_id"])
    viral_contigs_by_tool["blast"][group_sample] = Set{String}(sample_group[!, "query id"])
end
viral_contigs_by_tool["blast"]

In [None]:
# MMseqs2 easy-taxonomy LCA calls against UniRef50, one table per sample
db = "UniRef50"
uniref50_df = DataFrames.DataFrame()
ProgressMeter.@showprogress for sample_path in sample_paths
    lca_tsv = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv",
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_tsv)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref50_df, lca_table)
end
# keep only contigs whose assigned taxon is in the viral taxon set
uniref50_viral_df = uniref50_df[in.(uniref50_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# Per-sample sets of viral contig ids (stored as strings)
viral_contigs_by_tool[db] = Dict{String, Set{String}}()
for sample_group in DataFrames.groupby(uniref50_viral_df, "sample_id")
    group_sample = first(sample_group[!, "sample_id"])
    viral_contigs_by_tool[db][group_sample] = Set{String}(string.(sample_group[!, "contig_id"]))
end
viral_contigs_by_tool[db]

In [None]:
# MMseqs2 easy-taxonomy LCA calls against UniRef90, one table per sample
db = "UniRef90"
uniref90_df = DataFrames.DataFrame()

ProgressMeter.@showprogress for sample_path in sample_paths
    lca_tsv = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv",
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_tsv)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref90_df, lca_table)
end
# keep only contigs whose assigned taxon is in the viral taxon set
uniref90_viral_df = uniref90_df[in.(uniref90_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# Per-sample sets of viral contig ids (stored as strings)
viral_contigs_by_tool[db] = Dict{String, Set{String}}()
for sample_group in DataFrames.groupby(uniref90_viral_df, "sample_id")
    group_sample = first(sample_group[!, "sample_id"])
    viral_contigs_by_tool[db][group_sample] = Set{String}(string.(sample_group[!, "contig_id"]))
end
viral_contigs_by_tool[db]

In [None]:
# MMseqs2 easy-taxonomy LCA calls against UniRef100, one table per sample
db = "UniRef100"
uniref100_df = DataFrames.DataFrame()

ProgressMeter.@showprogress for sample_path in sample_paths
    lca_tsv = joinpath(
        sample_path,
        "mmseqs_easy_taxonomy",
        "final.contigs.fastg.gfa.fna.mmseqs_easy_taxonomy.$(db)_lca.tsv",
    )
    lca_table = Mycelia.parse_mmseqs_easy_taxonomy_lca_tsv(lca_tsv)
    lca_table[!, "sample_id"] .= basename(sample_path)
    append!(uniref100_df, lca_table)
end
# keep only contigs whose assigned taxon is in the viral taxon set
uniref100_viral_df = uniref100_df[in.(uniref100_df[!, "taxon_id"], Ref(viral_tax_ids)), :]

# Per-sample sets of viral contig ids (stored as strings)
viral_contigs_by_tool[db] = Dict{String, Set{String}}()
for sample_group in DataFrames.groupby(uniref100_viral_df, "sample_id")
    group_sample = first(sample_group[!, "sample_id"])
    viral_contigs_by_tool[db][group_sample] = Set{String}(string.(sample_group[!, "contig_id"]))
end
viral_contigs_by_tool[db]

In [None]:
# Display the per-tool, per-sample viral contig sets collected so far
viral_contigs_by_tool

In [None]:
# Keys of `viral_contigs_by_tool` that are combined in the next cell
ordered_tools = ["genomad", "blast", "UniRef50", "UniRef90", "UniRef100"]

# Sample ids in the same order as `sample_paths`
ordered_samples = map(basename, sample_paths)

In [None]:
# For each sample, keep only the contigs that EVERY tool in `ordered_tools` called viral.
#
# The previous version intersected the geNomad set with itself on every iteration
# (it never looked at `other_tool`), so the result was simply the geNomad calls.
unanimous_hits = Dict{String, Set{String}}()
for sample in ordered_samples
    # A tool has no entry for a sample when it produced no viral rows for it
    # (entries are only created from existing rows), so treat that as an empty set
    # rather than raising a KeyError. A tool that is missing altogether still errors.
    per_tool_hits = [
        get(viral_contigs_by_tool[tool], sample, Set{String}()) for tool in ordered_tools
    ]
    # intersect returns a fresh set, so the per-tool sets are never mutated
    unanimous_hits[sample] = intersect(per_tool_hits...)
end
unanimous_hits

In [None]:
# Extract the selected contigs of every sample into one collection of FASTA records.
# Each record is renamed "<sample_id>__<contig_id>" so the sample of origin stays
# recoverable from the identifier.
fasta_records = FASTX.FASTA.Record[]
# Walk the samples in sorted order so the record order (and therefore the order of
# equal-length records after sorting) does not depend on Dict iteration order.
ProgressMeter.@showprogress for sample_id in sort!(collect(keys(unanimous_hits)))
    contig_ids = unanimous_hits[sample_id]
    # nothing selected for this sample: skip reading its assembly altogether
    isempty(contig_ids) && continue
    fasta_path = joinpath(data_dir, "SRA", sample_id, "megahit", "final.contigs.fastg.gfa.fna")
    open(fasta_path) do io
        fastx_io = FASTX.FASTA.Reader(io)
        for record in fastx_io
            contig_id = FASTX.identifier(record)
            if contig_id in contig_ids
                push!(fasta_records, FASTX.FASTA.Record(sample_id * "__" * contig_id, FASTX.sequence(record)))
            end
        end
        close(fastx_io)
    end
end
fasta_records

In [None]:
# Longest contigs first
fasta_records = sort(fasta_records; by = record -> length(FASTX.sequence(record)), rev = true)

In [None]:
# Length distribution of the contigs held in `fasta_records`
p = StatsPlots.histogram(
    map(length ∘ FASTX.sequence, fasta_records);
    bins = 100,
    label = missing,
    xlabel = "contig length",
    ylabel = "# of contigs",
    title = "high confidence viral contigs",
)

In [None]:
# StatsBase.describe(length.(FASTX.sequence.(fasta_records)))
# # long_records = filter(x -> length(FASTX.sequence(x)) >= 10000, fasta_records)

# p = StatsPlots.histogram(
#     length.(FASTX.sequence.(long_records)),
#     bins=100,
#     label=missing,
#     xlabel = "contig length",
#     ylabel = "# of contigs",
#     title = "high confidence viral contigs"
# )

In [None]:
# StatsPlots.savefig(p, joinpath(results_dir, "high-confidence-viral-contigs.fna.2k-filtered.png"))

In [None]:
# Keep only contigs of at least 5000 characters (2000 and 10000 were also tried).
filtered_fasta_records = filter(x -> length(FASTX.sequence(x)) >= 5000, fasta_records)
# Fail early with a clear message: an empty set would give `haplotypes == 0` and an
# empty FASTA, and the external tools below would be launched on nothing.
isempty(filtered_fasta_records) &&
    error("no contigs of length >= 5000 in fasta_records; nothing to write or align")

# Make sure the output directory exists before writing into it
mkpath(results_dir)
high_confidence_viral_fasta = joinpath(results_dir, "high-confidence-viral-contigs.fna")
open(high_confidence_viral_fasta, "w") do io
    fastx_io = FASTX.FASTA.Writer(io)
    for record in filtered_fasta_records
        write(fastx_io, record)
    end
    close(fastx_io)
end

# Remove any output left by a previous run before bgzip is invoked again
bgzipped_high_confidence_viral_fasta = high_confidence_viral_fasta * ".gz"
rm(bgzipped_high_confidence_viral_fasta; force=true)
run(`conda run --live-stream -n samtools bgzip $(high_confidence_viral_fasta)`)
run(`conda run --live-stream -n samtools samtools faidx $(bgzipped_high_confidence_viral_fasta)`)

# Value passed to pggb's -n option: square root of the record count, rounded up
# (a fixed value of 2 was also tried).
haplotypes = ceil(Int, sqrt(length(filtered_fasta_records)))
pggb_outdir = joinpath(results_dir, "pggb_high_confidence_viral")
run(`conda run --live-stream -n pggb pggb -i $(bgzipped_high_confidence_viral_fasta) -o $(pggb_outdir) -t 4 -n $(haplotypes)`)

In [None]:
# pggb_sensitive_outdir = joinpath(results_dir, "pggb_high_confidence_viral_sensitive")
# # 1/10 the default segment length
# # 20% less stringent than default 90% identity requirement
# run(`conda run --live-stream -n pggb pggb -i $(bgzipped_high_confidence_viral_fasta) -o $(pggb_sensitive_outdir) -t 4 -n $(haplotypes) -p 70 -s 500`)