In [None]:
# If the plotting libraries fail to load, try clearing LD_LIBRARY_PATH for the
# Julia kernel; kernel environments can be configured in ~/.local/share/jupyter/kernels/.
if haskey(ENV, "LD_LIBRARY_PATH")
    @assert ENV["LD_LIBRARY_PATH"] == ""
end
import Pkg
# Registered packages used by this notebook: added to the active environment,
# then imported by name so they are referenced as `Package.name`.
pkgs = [
    "Revise",
    "FASTX",
    "ProgressMeter",
    "DataFrames",
    "StatsBase",
    "StatsPlots",
    "OrderedCollections",
    "ColorSchemes",
    "uCSV",
]
Pkg.add(pkgs)
for pkgname in pkgs
    @eval import $(Symbol(pkgname))
end
# Mycelia can alternatively be used from a local checkout, e.g.:
# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
import Mycelia

In [None]:
# The project root is taken to be the parent of the notebook's working directory.
PROJECT_BASEDIR = dirname(pwd())
data_dir = joinpath(PROJECT_BASEDIR, "data")
# Genome FASTAs, read sets and assembler outputs are looked up here by later cells.
genome_dir = mkpath(joinpath(data_dir, "genomes"))
# Figures are written here by later cells via `savefig`; create the directory so
# saving does not fail on a fresh checkout (mkpath is a no-op if it already exists).
results_dir = mkpath(joinpath(PROJECT_BASEDIR, "results"))

In [None]:
# Original reference genomes: `.fna` files that are not normalized derivatives,
# ordered from smallest to largest file.
original_fastas = let candidate_paths = readdir(genome_dir, join=true)
    kept_paths = filter(p -> occursin(r"\.fna$", p) && !occursin("normalized", p), candidate_paths)
    sort(kept_paths, by=filesize)
end

# For each original genome, the variant-bearing reference expected next to it.
reference_variant_fastas = [fasta * ".normalized.vcf.fna" for fasta in original_fastas]

@assert all(isfile.(reference_variant_fastas))

# Gzipped FASTQ read sets, split by file-name convention into forward (`1_val_1`),
# reverse (`2_val_2`) and filtlong-filtered long-read files.
fastqs = filter(Base.Fix1(occursin, r"\.fq\.gz$"), readdir(genome_dir, join=true))

forward_fastqs = filter(path -> contains(path, "1_val_1.fq.gz"), fastqs)

reverse_fastqs = filter(path -> contains(path, "2_val_2.fq.gz"), fastqs)

long_read_fastqs = filter(path -> contains(path, "filtlong.fq.gz"), fastqs)

# Empty inventory of assemblies: one row per (genome, coverage, assembler) with
# the read files associated with it and the path of the assembled FASTA.
assembly_table = DataFrames.DataFrame(
    "original_fasta" => String[],
    "reference_variant_fasta" => String[],
    "coverage" => String[],
    "fastqs" => Vector{String}[],
    "assembler" => String[],
    "assembled_variant_fasta" => String[],
)

# Assemblers whose outputs are searched for on disk.
assemblers = String[
    "megahit", "spades_isolate", "flye", "raven", "hifiasm", "mycelia-lr", "mycelia-sr",
]
# Coverage labels exactly as they appear in file names.
coverages = ["$(depth)x" for depth in (10, 100, 1000)]

# Inventory the assemblies on disk: for every (genome, coverage, assembler)
# combination, locate the assembled FASTA (if any) and the read files associated
# with it, and append one row to `assembly_table`.
#
# Files are associated with a variant genome / coverage by substring match on their
# paths: the path must contain the `reference_variant_fasta` path and the coverage
# label (e.g. "100x"). The genome directory is listed once up front; the only file
# this loop may create (hifiasm's `assembly.fasta`) lives inside a subdirectory,
# so it does not affect that listing.
let genome_dir_listing = readdir(genome_dir, join=true), short_read_fastqs = vcat(forward_fastqs, reverse_fastqs)
    for (original_fasta, reference_variant_fasta) in zip(original_fastas, reference_variant_fastas)
        reference_variant_fasta_matches = filter(x -> occursin(reference_variant_fasta, x), genome_dir_listing)
        for coverage in coverages
            coverage_matches = filter(x -> occursin(coverage, x), reference_variant_fasta_matches)
            # Reads from `pool` belonging to this variant genome at this coverage.
            select_reads = pool -> filter(x -> occursin(coverage, x) && occursin(reference_variant_fasta, x), pool)
            for assembler in assemblers
                if assembler == "mycelia-lr"
                    # A `.fna` derived from badread reads, excluding joint FASTAs.
                    assembly_matches = filter(x -> occursin(r"\.fna$", x) && occursin("badread", x) && !occursin("joint", x), coverage_matches)
                    @assert 0 <= length(assembly_matches) <= 1
                    isempty(assembly_matches) && continue
                    assembled_variant_fasta = first(assembly_matches)
                    matching_fastqs = select_reads(long_read_fastqs)
                elseif assembler == "mycelia-sr"
                    # A `.fna` whose name contains "art", excluding `*.joint.fna`.
                    assembly_matches = filter(x -> occursin(r"\.fna$", x) && occursin("art", x) && !occursin(r"\.joint\.fna$", x), coverage_matches)
                    isempty(assembly_matches) && continue
                    assembled_variant_fasta = first(assembly_matches)
                    matching_fastqs = select_reads(short_read_fastqs)
                else
                    # The remaining assemblers write into an output directory whose
                    # name contains the assembler name.
                    directories = filter(x -> isdir(x) && occursin(assembler, x), coverage_matches)
                    isempty(directories) && continue
                    @assert length(directories) == 1 "$(reference_variant_fasta) $(assembler) $(coverage)"
                    directory = first(directories)
                    if assembler == "megahit"
                        assembled_variant_fasta = joinpath(directory, "final.contigs.fa")
                        matching_fastqs = select_reads(short_read_fastqs)
                    elseif assembler == "spades_isolate"
                        assembled_variant_fasta = joinpath(directory, "scaffolds.fasta")
                        matching_fastqs = select_reads(short_read_fastqs)
                    elseif assembler == "flye" || assembler == "raven"
                        assembled_variant_fasta = joinpath(directory, "assembly.fasta")
                        matching_fastqs = select_reads(long_read_fastqs)
                    elseif assembler == "hifiasm"
                        # hifiasm's `*.hifiasm.p_ctg.gfa.fna` (presumably primary contigs
                        # converted from GFA) is copied to `assembly.fasta` once.
                        directory_fastas = filter(x -> occursin(r"\.hifiasm\.p_ctg\.gfa\.fna$", x), readdir(directory))
                        @assert length(directory_fastas) <= 1
                        isempty(directory_fastas) && continue
                        assembled_variant_fasta = joinpath(directory, "assembly.fasta")
                        isfile(assembled_variant_fasta) || cp(joinpath(directory, first(directory_fastas)), assembled_variant_fasta)
                        # Match reads by the variant genome like flye/raven do. Matching on
                        # `original_fasta` (a prefix of the variant path) could also pick up
                        # reads that belong to the original genome.
                        matching_fastqs = select_reads(long_read_fastqs)
                    else
                        error("no assembly-discovery rule for assembler $(assembler)")
                    end
                end
                push!(assembly_table, (original_fasta=original_fasta, reference_variant_fasta=reference_variant_fasta, assembler=assembler, coverage=coverage, fastqs=matching_fastqs, assembled_variant_fasta=assembled_variant_fasta))
            end
        end
    end
end
assembly_table

# Keep only assemblies whose FASTA exists and is non-empty.
assembly_table = assembly_table[map(x -> isfile(x) && filesize(x) > 0, assembly_table[!, "assembled_variant_fasta"]), :]

# Reference variant calls for each variant genome: the variant FASTA path with
# its trailing `.fna` replaced by `.gz`.
assembly_table[!, "reference_variant_calls"] .= replace.(assembly_table[!, "reference_variant_fasta"], r"\.fna$" => ".gz")

# Restrict to (coverage, read set, genome) keys that a Mycelia assembler also used,
# so all assemblers are compared on the same reads. `unique` keeps the inner join
# from duplicating rows when two Mycelia rows share a key (e.g. both with no reads).
mycelia_subset = unique(assembly_table[map(x -> occursin("mycelia", x), assembly_table[!, "assembler"]), ["coverage", "fastqs", "original_fasta"]])

assembly_table = DataFrames.innerjoin(assembly_table, mycelia_subset, on=["coverage", "fastqs", "original_fasta"])

# Visualize only (genome, coverage) pairs that have more than one distinct read set
# (e.g. both long and short reads).
genome_coverage_read_sets = unique(assembly_table[!, ["original_fasta", "coverage", "fastqs"]])
genome_coverages_to_visualize = [
    (first(group[!, "original_fasta"]), first(group[!, "coverage"])) for group in DataFrames.groupby(genome_coverage_read_sets, ["original_fasta", "coverage"]) if DataFrames.nrow(group) > 1
]
assembly_table_to_visualize = DataFrames.innerjoin(
    assembly_table,
    DataFrames.DataFrame(original_fasta = [x[1] for x in genome_coverages_to_visualize], coverage = [x[2] for x in genome_coverages_to_visualize]),
    on=["original_fasta", "coverage"]
)

# Assess every selected assembly against its associated reads. The per-assembly
# result tables are stacked and joined back onto the assembly metadata; the `k`
# and `qv` columns are used by the heatmaps below.
assembly_qv_results = DataFrames.DataFrame()
for row in DataFrames.eachrow(assembly_table_to_visualize)
    assembly_result_table = Mycelia.assess_assembly_quality(assembly=row["assembled_variant_fasta"], observations=row["fastqs"])
    assembly_result_table[!, "assembled_variant_fasta"] .= row["assembled_variant_fasta"]
    append!(assembly_qv_results, assembly_result_table)
end
assembly_qv_results_to_visualize = DataFrames.innerjoin(assembly_table_to_visualize, assembly_qv_results, on="assembled_variant_fasta")

# Infinite QV values (presumably no observed errors) cannot be drawn, so they are
# capped at a fixed ceiling. (A data-driven alternative would be ten above the
# largest finite QV, rounded up to a multiple of ten.) `replace` promotes the
# replacement to the column's element type, keeping the column Float64; mixing an
# Int literal and Float64 values in a `map` could yield an abstract `Vector{Real}`.
max_qv_value = 80
assembly_qv_results_to_visualize[!, "qv"] = replace(assembly_qv_results_to_visualize[!, "qv"], Inf => float(max_qv_value))
assembly_qv_results_to_visualize

ks = sort(unique(assembly_qv_results_to_visualize[!, "k"]))
# hifiasm and raven always get a heatmap column, even with no results; `union`
# avoids listing them twice when they do have results.
assemblers = sort(union(assembly_qv_results_to_visualize[!, "assembler"], ["hifiasm", "raven"]))
# `sort=true` makes the group order deterministic, because the following cells
# pick groups by position.
reference_variant_fasta_coverage_groups = DataFrames.groupby(assembly_qv_results_to_visualize, ["reference_variant_fasta", "coverage"], sort=true)

In [None]:
# QV heatmap (k x assembler) for the first genome/coverage group.
reference_variant_fasta_coverage_group = reference_variant_fasta_coverage_groups[1]
this_fasta = basename(reference_variant_fasta_coverage_group[1, "original_fasta"])
this_coverage = reference_variant_fasta_coverage_group[1, "coverage"]
# One score vector per assembler, indexed like `ks`; an (assembler, k) pair with
# no result in this group stays at 0.
group_assembly_qv_scores = OrderedCollections.OrderedDict(assembler => zeros(length(ks)) for assembler in assemblers)
for row in DataFrames.eachrow(reference_variant_fasta_coverage_group)
    # Place each score by its k value instead of relying on row order, so a
    # missing k leaves a gap rather than shifting or truncating the column.
    k_index = findfirst(==(row["k"]), ks)
    @assert !isnothing(k_index) "k=$(row["k"]) not found in ks"
    group_assembly_qv_scores[row["assembler"]][k_index] = row["qv"]
end
p = StatsPlots.heatmap(
    reduce(hcat, [group_assembly_qv_scores[assembler] for assembler in assemblers]),
    xlabel = "Assembler",
    ylabel = "k-length",
    xticks = (1:length(assemblers), assemblers),
    yticks = (1:length(ks), ks),
    title = "Merqury QV values\n$(this_fasta) @ $(this_coverage)",
    colorbar_title = "Quality Score",
    size = (600, 400),
    clims=(0, max_qv_value),
    topmargin=5StatsPlots.Plots.PlotMeasures.mm
)
display(p)
StatsPlots.savefig(p, joinpath(results_dir, "merqury-qv.$(this_fasta).$(this_coverage).svg"))

In [None]:
# QV heatmap (k x assembler) for the second genome/coverage group.
reference_variant_fasta_coverage_group = reference_variant_fasta_coverage_groups[2]
this_fasta = basename(reference_variant_fasta_coverage_group[1, "original_fasta"])
this_coverage = reference_variant_fasta_coverage_group[1, "coverage"]
# One score vector per assembler, indexed like `ks`; an (assembler, k) pair with
# no result in this group stays at 0.
group_assembly_qv_scores = OrderedCollections.OrderedDict(assembler => zeros(length(ks)) for assembler in assemblers)
for row in DataFrames.eachrow(reference_variant_fasta_coverage_group)
    # Place each score by its k value instead of relying on row order, so a
    # missing k leaves a gap rather than shifting or truncating the column.
    k_index = findfirst(==(row["k"]), ks)
    @assert !isnothing(k_index) "k=$(row["k"]) not found in ks"
    group_assembly_qv_scores[row["assembler"]][k_index] = row["qv"]
end
p = StatsPlots.heatmap(
    reduce(hcat, [group_assembly_qv_scores[assembler] for assembler in assemblers]),
    xlabel = "Assembler",
    ylabel = "k-length",
    xticks = (1:length(assemblers), assemblers),
    yticks = (1:length(ks), ks),
    title = "Merqury QV values\n$(this_fasta) @ $(this_coverage)",
    colorbar_title = "Quality Score",
    size = (600, 400),
    clims=(0, max_qv_value),
    topmargin=5StatsPlots.Plots.PlotMeasures.mm
)
display(p)
StatsPlots.savefig(p, joinpath(results_dir, "merqury-qv.$(this_fasta).$(this_coverage).svg"))

In [None]:
# QV heatmap (k x assembler) for the fourth genome/coverage group.
reference_variant_fasta_coverage_group = reference_variant_fasta_coverage_groups[4]
this_fasta = basename(reference_variant_fasta_coverage_group[1, "original_fasta"])
this_coverage = reference_variant_fasta_coverage_group[1, "coverage"]
# One score vector per assembler, indexed like `ks`; an (assembler, k) pair with
# no result in this group stays at 0.
group_assembly_qv_scores = OrderedCollections.OrderedDict(assembler => zeros(length(ks)) for assembler in assemblers)
for row in DataFrames.eachrow(reference_variant_fasta_coverage_group)
    # Place each score by its k value instead of relying on row order, so a
    # missing k leaves a gap rather than shifting or truncating the column.
    k_index = findfirst(==(row["k"]), ks)
    @assert !isnothing(k_index) "k=$(row["k"]) not found in ks"
    group_assembly_qv_scores[row["assembler"]][k_index] = row["qv"]
end
p = StatsPlots.heatmap(
    reduce(hcat, [group_assembly_qv_scores[assembler] for assembler in assemblers]),
    xlabel = "Assembler",
    ylabel = "k-length",
    xticks = (1:length(assemblers), assemblers),
    yticks = (1:length(ks), ks),
    title = "Merqury QV values\n$(this_fasta) @ $(this_coverage)",
    colorbar_title = "Quality Score",
    size = (600, 400),
    clims=(0, max_qv_value),
    topmargin=5StatsPlots.Plots.PlotMeasures.mm
)
display(p)
StatsPlots.savefig(p, joinpath(results_dir, "merqury-qv.$(this_fasta).$(this_coverage).svg"))

In [None]:
# PGGB VCF paths: for each assembly, the PGGB output lives in
# `<assembly>.joint.fna__PGGB/` (the joint FASTA presumably combines the original
# and assembled genomes — confirm), and the VCF is expected as `<graph>.gfa.vcf`.
# Missing outputs are skipped with a warning; the VCF table built later already
# drops VCF paths that do not exist.
pggb_vcfs = String[]
for row in DataFrames.eachrow(assembly_table_to_visualize)
    joint_fasta = row["assembled_variant_fasta"] * "." * "joint.fna"
    @show joint_fasta
    outdir = joint_fasta * "__PGGB"
    if !isdir(outdir)
        @warn "PGGB output directory not found; skipping" outdir
        continue
    end
    # Anchor on the extension so derived files such as `*.gfa.vcf` are not
    # mistaken for the graph itself.
    gfa_files = filter(x -> occursin(r"\.gfa$", x), readdir(outdir, join=true))
    if isempty(gfa_files)
        @warn "no GFA file found in PGGB output directory; skipping" outdir
        continue
    end
    push!(pggb_vcfs, first(gfa_files) * ".vcf")
end
pggb_vcfs

# Cactus VCF path for each assembly: inside `<assembly>-cactus/`, named after that
# directory's path relative to `genome_dir`. If the assembled path is absolute,
# `joinpath(genome_dir, ...)` returns it unchanged.
cactus_vcfs = String[]
ProgressMeter.@showprogress for row in DataFrames.eachrow(assembly_table_to_visualize)
    cactus_outdir = joinpath(genome_dir, string(row["assembled_variant_fasta"], "-cactus"))
    cactus_relative_name = replace(cactus_outdir, "$(genome_dir)/" => "")
    push!(cactus_vcfs, joinpath(cactus_outdir, string(cactus_relative_name, ".vcf.gz")))
end
cactus_vcfs

# One row per candidate VCF from either caller, tagged with the caller and with the
# original genome whose path it contains; VCFs that do not exist are dropped.
vcf_table = DataFrames.DataFrame(
    vcf = [pggb_vcfs; cactus_vcfs],
    caller = [fill("pggb", length(pggb_vcfs)); fill("cactus", length(cactus_vcfs))])
vcf_table[!, "original_fasta"] = [first(filter(fasta -> occursin(fasta, vcf), original_fastas)) for vcf in vcf_table[!, "vcf"]]
vcf_table = filter("vcf" => isfile, vcf_table)
# Normalize every remaining VCF against its original genome.
normalized_vcfs = String[Mycelia.normalize_vcf(reference_fasta=row["original_fasta"], vcf_file=row["vcf"]) for row in DataFrames.eachrow(vcf_table)]
vcf_table[!, "normalized_vcf"] = normalized_vcfs
# Attach the reference variant calls for each original genome.
vcf_table = DataFrames.innerjoin(vcf_table, unique(assembly_table_to_visualize[!, ["original_fasta", "reference_variant_calls"]]), on="original_fasta")

# Locate finished RTG evaluations: results for each normalized VCF are read from
# `<normalized_vcf>_RTG_eval/` once that directory contains a `done` entry.
# Columns are declared up front so the join below still works when no evaluation
# has finished yet (a column-less DataFrame cannot be joined on "normalized_vcf").
pangenomic_variant_call_accuracy_table = DataFrames.DataFrame(
    weighted_roc = String[],
    snp_roc = String[],
    non_snp_roc = String[],
    normalized_vcf = String[],
    outdir = String[],
)
for vcf_row in DataFrames.eachrow(vcf_table)
    normalized_vcf = vcf_row["normalized_vcf"]
    outdir = normalized_vcf * "_RTG_eval"
    (isdir(outdir) && ("done" in readdir(outdir))) || continue
    push!(pangenomic_variant_call_accuracy_table, (
        weighted_roc = joinpath(outdir, "weighted_roc.tsv.gz"),
        snp_roc = joinpath(outdir, "snp_roc.tsv.gz"),
        non_snp_roc = joinpath(outdir, "non_snp_roc.tsv.gz"),
        normalized_vcf = normalized_vcf,
        outdir = outdir,
    ))
end
pangenomic_variant_call_accuracy_table = DataFrames.innerjoin(vcf_table, pangenomic_variant_call_accuracy_table, on="normalized_vcf")

# Each VCF path contains the path of the assembly it was derived from (the caller is
# already a column); recover that assembly so the evaluation rows can be joined back
# onto the assembly metadata.

assembled_fastas_to_visualize = assembly_table_to_visualize[!, "assembled_variant_fasta"]

pangenomic_variant_call_accuracy_table[!, "assembled_variant_fasta"] = String[
    let matching_fasta_assemblies = filter(x -> occursin(x, vcf), assembled_fastas_to_visualize)
        @assert length(matching_fasta_assemblies) == 1
        first(matching_fasta_assemblies)
    end for vcf in pangenomic_variant_call_accuracy_table[!, "vcf"]
]
pangenomic_variant_call_accuracy_table[!, "assembled_variant_fasta"]

pangenomic_variant_call_accuracy_table = DataFrames.innerjoin(pangenomic_variant_call_accuracy_table, assembly_table_to_visualize, on=["assembled_variant_fasta", "original_fasta", "reference_variant_calls"], makeunique=true)

# Parse each weighted-ROC output and tag its rows with the source file so they can
# be joined back onto the evaluation metadata.
weighted_roc_table = DataFrames.DataFrame()
for row in DataFrames.eachrow(pangenomic_variant_call_accuracy_table)
    result_table = Mycelia.parse_rtg_eval_output(row["weighted_roc"])
    # Outputs that parse to no rows contribute nothing.
    isempty(result_table) && continue
    result_table[!, "weighted_roc"] .= row["weighted_roc"]
    append!(weighted_roc_table, result_table)
end
weighted_roc_table = DataFrames.innerjoin(pangenomic_variant_call_accuracy_table, weighted_roc_table, on="weighted_roc")
weighted_roc_table[!, "assembler-caller"] .= weighted_roc_table[!, "assembler"] .* " - " .* weighted_roc_table[!, "caller"]
# `sort=true` gives a deterministic group order; the next cell selects a group by position.
variant_call_accuracy_visualization_groups = DataFrames.groupby(weighted_roc_table, ["original_fasta", "coverage"], sort=true)

In [None]:
# Precision, sensitivity and F1 per assembler-caller row for one genome/coverage
# group, drawn as three horizontally offset scatter series.
# `sort` returns a sorted copy; `sort!` on the grouped view would reorder the
# shared parent `weighted_roc_table` in place as a side effect.
variant_call_accuracy_visualization_group = sort(variant_call_accuracy_visualization_groups[2], "assembler-caller")
this_fasta = basename(first(variant_call_accuracy_visualization_group[!, "original_fasta"]))
this_coverage = first(variant_call_accuracy_visualization_group[!, "coverage"])
title = "$(this_fasta) @ $(this_coverage)"
# One x position per row; labels are taken as-is (not deduplicated).
unique_assembler_callers = variant_call_accuracy_visualization_group[!, "assembler-caller"]
x_positions = 1:length(unique_assembler_callers)
# Precision plot
p = StatsPlots.scatter(
    x_positions .- 0.1,
    variant_call_accuracy_visualization_group.precision,
    xticks = (x_positions, unique_assembler_callers),
    xlabel = "Assembler-Caller",
    title = title,
    label = "Precision",
    ylims = (-0.1, 1.1),
    margins = 10StatsPlots.Plots.PlotMeasures.mm,
    xrotation=45,
    legend=:bottomleft
)

# Sensitivity plot
StatsPlots.scatter!(p,
    x_positions,
    variant_call_accuracy_visualization_group.sensitivity,
    label = "Sensitivity"
)

# F1 score plot
StatsPlots.scatter!(p,
    x_positions .+ 0.1,
    variant_call_accuracy_visualization_group.f_measure,
    label = "F1 Score"
)

display(p)
StatsPlots.savefig(p, joinpath(results_dir, "rtg-eval.$(this_fasta).$(this_coverage).svg"))