In [None]:
# If plotting libraries fail to load, try resetting the LD path for Julia.
# It can be set in the kernel spec under ~/.local/share/jupyter/kernels/
# An unset variable is treated like an empty one, so the check does not raise a
# KeyError on machines where LD_LIBRARY_PATH was never exported.
@assert get(ENV, "LD_LIBRARY_PATH", "") == ""
import Pkg
pkgs = [
    "Revise",
    "FASTX",
    "BioSequences",
    "Kmers",
    "Graphs",
    "MetaGraphs",
    "SparseArrays",
    "ProgressMeter",
    "Distributions",
    "HiddenMarkovModels",
    "BioAlignments",
    "StatsBase",
    "Random",
    "StatsPlots",
    "Statistics",
    # "GraphMakie",
    "IterTools",
    "Primes",
    "OnlineStats",
    "IteratorSampling"
]
# Pkg.add(pkgs)
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
import Mycelia

In [None]:
# Lay out the project directories relative to the parent of the current working
# directory, and make sure the genome directory exists.
PROJECT_BASEDIR = dirname(pwd())
data_dir = joinpath(PROJECT_BASEDIR, "data")
genome_dir = joinpath(data_dir, "genomes")
mkpath(genome_dir)

In [None]:
# Directory for files written by this notebook; created on first use.
working_dir = mkpath(joinpath(data_dir, "test"))

In [None]:
# short_read_sets = unique(map(x -> match(r"^(.+\.\d+x)\.", x).captures[1], filter(x -> occursin(r"\.fna\.art", x) && occursin(r"\.fq\.gz", x) && !occursin("trimming_report", x) && !occursin("_val_", x), sort(readdir(genome_dir, join=true), by=x->filesize(x)))))
# # forward = short_read_set * ".1_val_1.fq.gz"
# # reverse = short_read_set * ".2_val_2.fq.gz"

In [None]:
# Collect the files in the genome directory whose names end in ".filtlong.fq.gz",
# ordered from smallest to largest on disk, and work with the smallest one.
long_read_fastqs = sort(
    filter(path -> occursin(r"\.filtlong\.fq\.gz$", path), readdir(genome_dir, join=true)),
    by = filesize,
)
# Fail with an actionable message instead of a bare BoundsError when nothing matches.
isempty(long_read_fastqs) && error("no *.filtlong.fq.gz files found in $(genome_dir)")
fastq = long_read_fastqs[1]

In [None]:
# Strip everything from ".badread" onward from the FASTQ path; the remainder is used
# as the path of the reference FASTA.
reference_fasta = replace(fastq, r"\.badread.*" => "")

In [None]:
# Let Mycelia choose k from the reads; the result parameterizes the k-mer type below.
k = Mycelia.assess_dnamer_saturation([fastq])

In [None]:
# K-mer type of length k used by the following cells.
# NOTE(review): the second type parameter (1) is presumably the number of storage
# words — confirm against the Kmers.jl documentation for the chosen k.
kmer_type = Kmers.DNAKmer{k, 1}

In [None]:
# K-mer counts of the reference FASTA, used below to decide which read k-mers are valid.
reference_kmer_counts = Mycelia.fasta_to_reference_kmer_counts(kmer_type=kmer_type, fasta=reference_fasta)
# Materialize all reads so they can be iterated repeatedly.
records = collect(Mycelia.open_fastx(fastq))

In [None]:
# Streaming mean over the quality scores of every record.
# NOTE(review): `chain` receives a single generator, so it yields one quality-score
# collection per record; this relies on `fit!` descending into nested iterables.
fit_mean = OnlineStats.fit!(OnlineStats.Mean(), IterTools.chain(FASTX.quality_scores(record) for record in records))

In [None]:
# Streaming extrema over the quality scores of every record.
fit_extrema = OnlineStats.fit!(OnlineStats.Extrema(), IterTools.chain(FASTX.quality_scores(record) for record in records))

In [None]:
# Streaming variance over the quality scores of every record, and the standard
# deviation derived from it.
fit_variance = OnlineStats.fit!(OnlineStats.Variance(), IterTools.chain(FASTX.quality_scores(record) for record in records))
standard_deviation = sqrt(OnlineStats.value(fit_variance))

In [None]:
# Per-record vectors of base quality scores.
read_quality_scores = [collect(FASTX.quality_scores(record)) for record in records]

# Scatter a sample drawn from all base qualities chained together (10^4 is passed to
# `itsample`, presumably as the sample size).
# NOTE(review): the x axis is the position within the sample, so the "read index"
# label may be misleading — confirm the intended label.
p = StatsPlots.scatter(
    IteratorSampling.itsample(IterTools.chain(read_quality_scores...), 10^4),
    title = "base quality scores",
    xlabel = "read index",
    ylabel = "quality score (PHRED)",
    # color = :black,
    alpha = 0.25,
    label = nothing)

# Dashed line at the overall mean quality.
StatsPlots.hline!(
    p,
    [OnlineStats.value(fit_mean)],
    labels = "mean = $(round(OnlineStats.value(fit_mean), digits=3))",
    linestyle = :dash
)

# Line at one standard deviation below the mean, rounded down.
StatsPlots.hline!(p, [floor(OnlineStats.value(fit_mean) - standard_deviation)], label = "(mean - 1σ)")

In [None]:
# Accumulate, for every k-mer, the position-wise sum of the base quality scores over
# all of its occurrences in the reads.
all_kmer_quality_support = Dict{kmer_type, Vector{Float64}}()
for record in records
    base_qualities = collect(FASTX.quality_scores(record))
    # One window of k consecutive qualities per k-mer start position.
    quality_windows = [base_qualities[start:start+k-1] for start in 1:length(base_qualities)-k+1]
    sequence = BioSequences.LongDNA{2}(FASTX.sequence(record))
    for ((_, kmer), window) in zip(Kmers.EveryKmer{kmer_type}(sequence), quality_windows)
        accumulated = get(all_kmer_quality_support, kmer, nothing)
        all_kmer_quality_support[kmer] = accumulated === nothing ? window : accumulated .+ window
    end
end

# Count stranded and canonical k-mers in the reads and give each an integer index that
# follows the iteration order of the corresponding count table.
kmer_counts = Mycelia.count_kmers(kmer_type, fastq)
kmer_indices = Dict(kmer => i for (i, kmer) in enumerate(keys(kmer_counts)))
canonical_kmer_counts = Mycelia.count_canonical_kmers(kmer_type, fastq)
canonical_kmer_indices = Dict(kmer => i for (i, kmer) in enumerate(keys(canonical_kmer_counts)))

# A k-mer is "valid" when it also occurs in the reference, matching how the quality
# totals are split later in the notebook; everything else is a sequencing artifact.
# The two filters were previously swapped.
valid_kmer_counts = [count for (kmer, count) in canonical_kmer_counts if haskey(reference_kmer_counts, kmer)]
invalid_kmer_counts = [count for (kmer, count) in canonical_kmer_counts if !haskey(reference_kmer_counts, kmer)]

reference_kmers = sort(collect(keys(reference_kmer_counts)))

# Combine the quality support of each k-mer with that of its reverse complement.
# Every entry is a fresh vector: adding in place into `support` would also modify
# `all_kmer_quality_support` while it is being iterated, so a k-mer visited after its
# reverse complement would pick up an already-combined value and be counted twice.
strand_normalized_quality_support = Dict{kmer_type, Vector{Float64}}()
for (kmer, support) in all_kmer_quality_support
    reverse_complement_support = get(all_kmer_quality_support, BioSequences.reverse_complement(kmer), nothing)
    strand_normalized_quality_support[kmer] =
        reverse_complement_support === nothing ? copy(support) : support .+ reverse_complement_support
end

In [None]:
# Split the per-k-mer total quality into k-mers present in the reference ("valid") and
# k-mers absent from it ("sequencing artifacts").
valid_total_qualities = Float64[]
invalid_total_qualities = Float64[]
for (kmer, quality_values) in strand_normalized_quality_support
    # Hash lookup in the reference table; `reference_kmers` holds the same keys as a
    # vector, which would cost a linear scan per k-mer.
    if haskey(reference_kmer_counts, kmer)
        push!(valid_total_qualities, sum(quality_values))
    else
        push!(invalid_total_qualities, sum(quality_values))
    end
end

valid_mean_total_quality = Statistics.mean(valid_total_qualities)
invalid_mean_total_quality = Statistics.mean(invalid_total_qualities)

# Valid k-mers are jittered around x = 1 and artifacts around x = 2.
p = StatsPlots.scatter(
    [Mycelia.jitter(2, length(invalid_total_qualities)), Mycelia.jitter(1, length(valid_total_qualities))],
    [invalid_total_qualities, valid_total_qualities],
    alpha=0.2,
    title = "Total adjusted joint-Q value for each Kmer",
    xticks = ((1, 2), ("valid kmers", "sequencing artifacts")),
    labels = nothing
)
# Horizontal bars marking the mean of each group.
StatsPlots.plot!(p,
    [0.75, 1.25],
    [valid_mean_total_quality, valid_mean_total_quality],
    linewidth=4,
    color=:orange,
    label = "mean = $(round(valid_mean_total_quality, digits=3))")
StatsPlots.plot!(p,
    [1.75, 2.25],
    [invalid_mean_total_quality, invalid_mean_total_quality],
    linewidth=4,
    color=:blue,
    label="mean = $(round(invalid_mean_total_quality, digits=3))")

In [None]:
# Total strand-normalized quality per k-mer.
kmer_total_quality = Dict(kmer => sum(quality_values) for (kmer, quality_values) in strand_normalized_quality_support)
# state_likelihoods = Dict(kmer => kmer_count / total_kmers for (kmer, kmer_count) in kmer_counts)
# Each k-mer's share of the total quality. The grand total is computed once; inside
# the comprehension it would be re-summed for every k-mer.
total_quality_mass = sum(values(kmer_total_quality); init = 0.0)
state_likelihoods = Dict(kmer => total_quality / total_quality_mass for (kmer, total_quality) in kmer_total_quality)

In [None]:
# Number of states; used to size the transition matrix and the graph built from it.
total_states = length(state_likelihoods)

In [None]:
# Count every observed k-mer -> next-k-mer transition across the reads, then turn each
# row with at least one transition into relative frequencies.
transition_likelihoods = SparseArrays.spzeros(total_states, total_states)
for record in records
    sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
    sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
    destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
    for ((_, source), (_, destination)) in zip(sources, destinations)
        transition_likelihoods[kmer_indices[source], kmer_indices[destination]] += 1
    end
end
for source in 1:total_states
    outgoing_transition_counts = transition_likelihoods[source, :]
    outgoing_total = sum(outgoing_transition_counts)
    if outgoing_total > 0
        transition_likelihoods[source, :] .= outgoing_transition_counts ./ outgoing_total
    end
end
transition_likelihoods

In [None]:
# Directed graph with one vertex per state and one edge per stored entry of the
# transition matrix.
g = Graphs.SimpleDiGraph(total_states)
row_indices, column_indices, cell_values = SparseArrays.findnz(transition_likelihoods)
foreach(row_indices, column_indices) do row, col
    Graphs.add_edge!(g, row, col)
end
g

In [None]:
# Vertices with at most one incoming and at most one outgoing edge.
unbranching_nodes = Set(Int[])
for node in Graphs.vertices(g)
    if (Graphs.indegree(g, node) <= 1) && (Graphs.outdegree(g, node) <= 1)
        push!(unbranching_nodes, node)
    end
end
unvisited_unbranching_nodes = Set(unbranching_nodes)

# Everything else has more than one edge on at least one side.
branching_nodes = setdiff(Graphs.vertices(g), unbranching_nodes)
for branching_node in branching_nodes
    @assert Graphs.degree(g, branching_node) >= 2
end
inbranching_nodes = filter(node -> Graphs.indegree(g, node) > 1, branching_nodes)
outbranching_nodes = filter(node -> Graphs.outdegree(g, node) > 1, branching_nodes)

# Grow a path from each not-yet-visited unbranching vertex: first forward along unique
# successors, then backward along unique predecessors. A vertex where the walk stops
# because of a branch is still included as the path's end point.
unbranching_paths = Vector{Int}[]
while !isempty(unvisited_unbranching_nodes)
    current_path = [rand(unvisited_unbranching_nodes)]
    delete!(unvisited_unbranching_nodes, first(current_path))

    # Set when the forward walk returns to its own start. Such a closed unbranching
    # cycle is already complete, and without this check both walks below would keep
    # going around it forever.
    closed_cycle = false

    outneighbors = Graphs.outneighbors(g, last(current_path))

    while length(outneighbors) == 1
        outneighbor = first(outneighbors)
        outneighbors_inneighbors = Graphs.inneighbors(g, outneighbor)
        if outneighbors_inneighbors == [last(current_path)]
            if outneighbor == first(current_path)
                closed_cycle = true
                break
            end
            push!(current_path, outneighbor)
            delete!(unvisited_unbranching_nodes, outneighbor)
            outneighbors = Graphs.outneighbors(g, outneighbor)
        else
            # The successor has other predecessors: it ends the path.
            @assert length(outneighbors_inneighbors) > 1
            push!(current_path, outneighbor)
            delete!(unvisited_unbranching_nodes, outneighbor)
            break
        end
    end

    if !closed_cycle
        inneighbors = Graphs.inneighbors(g, first(current_path))
        while length(inneighbors) == 1
            inneighbor = first(inneighbors)
            inneighbors_outneighbors = Graphs.outneighbors(g, inneighbor)
            if inneighbors_outneighbors == [first(current_path)]
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                inneighbors = Graphs.inneighbors(g, inneighbor)
            else
                # The predecessor has other successors: it starts the path.
                @assert length(inneighbors_outneighbors) > 1
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                break
            end
        end
    end
    push!(unbranching_paths, current_path)
end
unbranching_paths

In [None]:
# Mean total quality of the states along each unbranching path.
unbranching_path_scores = Float64[]
# State indices come from the key order of `kmer_counts`, so the scores are looked up
# in that same order instead of relying on a separately sorted copy of
# `kmer_total_quality` happening to line up with it.
state_scores = [kmer_total_quality[kmer] for kmer in keys(kmer_counts)]
for unbranching_path in unbranching_paths
    push!(unbranching_path_scores, Statistics.mean(state_scores[state] for state in unbranching_path))
end
StatsPlots.histogram(unbranching_path_scores)

In [None]:
# K-mers in state-index order, and the indices of those that occur in the reference.
ordered_kmers = collect(keys(kmer_counts))
# Hash lookup in the reference table; `reference_kmers` holds the same keys as a
# vector, which would cost a linear scan per k-mer.
solid_states = findall(kmer -> haskey(reference_kmer_counts, kmer), ordered_kmers)

In [None]:
# Fraction of each path's states that are solid. Membership is tested against a Set so
# each lookup is constant time rather than a scan of the `solid_states` vector.
unbranching_path_solidity = let solid_state_set = Set(solid_states)
    [count(in(solid_state_set), unbranching_path) / length(unbranching_path) for unbranching_path in unbranching_paths]
end

In [None]:
# Indices of the paths made up entirely of solid states.
solid_unbranching_paths = findall(unbranching_path_solidity .== 1.0)

In [None]:
# Partition the path scores by whether the path index is listed in
# `solid_unbranching_paths`. A boolean mask built from a Set avoids scanning that index
# vector once per path; the original order of the scores is preserved in both groups.
is_solid_path = in.(eachindex(unbranching_path_scores), Ref(Set(solid_unbranching_paths)))
valid_unbranching_path_scores = unbranching_path_scores[is_solid_path]
invalid_unbranching_path_scores = unbranching_path_scores[.!is_solid_path]

In [None]:
valid_path_mean_score = Statistics.mean(valid_unbranching_path_scores)
invalid_path_mean_score = Statistics.mean(invalid_unbranching_path_scores)

# Valid paths are jittered around x = 1 and artifact paths around x = 2.
p = StatsPlots.scatter(
    [Mycelia.jitter(2, length(invalid_unbranching_path_scores)), Mycelia.jitter(1, length(valid_unbranching_path_scores))],
    [invalid_unbranching_path_scores, valid_unbranching_path_scores],
    alpha=0.2,
    title = "mean total adjusted joint-Q value for unitigs",
    xticks = ((1, 2), ("valid paths", "sequencing artifacts")),
    labels = nothing
)

# Horizontal bars marking the mean of each group.
StatsPlots.plot!(p,
    [0.75, 1.25],
    [valid_path_mean_score, valid_path_mean_score],
    linewidth=4,
    color=:orange,
    label = "mean = $(round(valid_path_mean_score, digits=3))")
StatsPlots.plot!(p,
    [1.75, 2.25],
    [invalid_path_mean_score, invalid_path_mean_score],
    linewidth=4,
    color=:blue,
    label="mean = $(round(invalid_path_mean_score, digits=3))")

In [None]:
# Per-path vectors of state likelihoods.
path_support = Vector{Float64}[]
# Materialize the state-index -> k-mer mapping once; collecting the keys inside the
# comprehension would rebuild the whole vector for every state of every path.
state_kmers = collect(keys(kmer_counts))
for unbranching_path in unbranching_paths
    push!(path_support, [state_likelihoods[state_kmers[state]] for state in unbranching_path])
    # push!(path_support, [kmer_quality_support[collect(keys(kmer_counts))[state]] for state in unbranching_path])
end
path_support

p = StatsPlots.scatter(
    # sort(Statistics.mean.(path_support)),
    sort(Statistics.median.(path_support)),
    title = "median kmer likelihood of unbranching paths",
    legend=false,
    ylabel = ""
)

# low_support_path_unbranching_path_indices = findall(Statistics.mean.(path_support) .< Statistics.median(Statistics.mean.(path_support)))
# low_support_path_unbranching_path_indices = findall(Statistics.mean.(path_support) .< Statistics.median(Statistics.mean.(path_support)))

StatsPlots.hline!(p, [Statistics.mean(Statistics.median.(path_support))], label = "mean of the medians")

In [None]:
# now use the shortest path algorithm

In [None]:
Use Yen's k-shortest-paths algorithm to find the best replacement route for each flagged route that we are going to drop.

distances = 1 / median (so routes with a higher median are shorter)

In [None]:
distance matrix is 

In [None]:
# let's read in the original genome in the forward and backward orientations to identify which kmers are actually true

In [None]:
# which nodes are not in the unbranching paths - should only be the hub nodes that would be the basis of the simplified graph?


In [None]:
# `inneighbors` is queried for a single vertex, so ask about the vertex that starts the
# first unbranching path; passing the whole path (a vector) is not a vertex.
Graphs.inneighbors(g, first(unbranching_paths[1]))

In [None]:
unbranching_paths

In [None]:


# low_support_path_unbranching_path_indices = findall(Statistics.mean.(path_support) .< Statistics.median(Statistics.mean.(path_support)))
# low_support_path_unbranching_path_indices = findall(Statistics.mean.(path_support) .< Statistics.median(Statistics.mean.(path_support)))

# StatsPlots.hline!(p, [Statistics.mean(Statistics.median.(path_support))], label = "mean of the medians")
# StatsPlots.hline!(p, [Statistics.median(Statistics.mean.(path_support))], label = "median of the means")

In [None]:
# Per-path vectors of k-mer quality support.
# NOTE(review): `kmer_quality_support` is not defined in any visible cell of this
# notebook; it must map each k-mer to a single number for the pushes below to fit
# `Vector{Float64}`. Confirm where it is meant to come from before running this cell.
path_support = Vector{Float64}[]
# Materialize the state-index -> k-mer mapping once; collecting the keys inside the
# comprehension would rebuild the whole vector for every state of every path.
state_kmers = collect(keys(kmer_counts))
for unbranching_path in unbranching_paths
    # push!(path_support, [state_likelihoods[collect(keys(kmer_counts))[state]] for state in unbranching_path])
    push!(path_support, [kmer_quality_support[state_kmers[state]] for state in unbranching_path])
end
path_support

p = StatsPlots.scatter(
    sort(Statistics.mean.(path_support)),
    title = "average kmer quality score of unbranching paths",
    labels="mean unbranching path likelihood",
    ylabel = ""
)

# Paths whose mean support falls below the median of all path means.
low_support_path_unbranching_path_indices = findall(Statistics.mean.(path_support) .< Statistics.median(Statistics.mean.(path_support)))

# StatsPlots.hline!(p, [Statistics.mean(Statistics.median.(path_support))], label = "mean of the medians")
StatsPlots.hline!(p, [Statistics.median(Statistics.mean.(path_support))], label = "median of the means")

In [None]:
# Threshold on base quality: the smaller of the median and the mean.
# `quality_scores` was never defined in the notebook; it is taken here to be every
# base quality from every read, flattened from `read_quality_scores`.
# NOTE(review): confirm that the pooled per-base scores are the intended population.
quality_threshold = let quality_scores = reduce(vcat, read_quality_scores)
    min(Statistics.median(quality_scores), Statistics.mean(quality_scores))
end

In [None]:
# Threshold on state likelihood: the smaller of the mean and the median.
state_likelihood_threshold = let likelihoods = collect(values(state_likelihoods))
    min(Statistics.mean(likelihoods), Statistics.median(likelihoods))
end

In [None]:
# Positive transition likelihoods. Only the stored entries of the sparse matrix are
# inspected; flattening the matrix with `vec` would visit every one of its
# total_states^2 cells.
nonzero_transition_likelihoods = filter(x -> x > 0, SparseArrays.nonzeros(transition_likelihoods))

In [None]:
# Threshold on transition likelihood: the smaller of the mean and the median of the
# non-zero values. The "treshold" spelling is kept because later cells use this name.
transition_likelihood_treshold = min(Statistics.mean(nonzero_transition_likelihoods), Statistics.median(nonzero_transition_likelihoods))

In [None]:
# Threshold on k-mer count: the mean count over all observed k-mers.
kmer_count_threshold = Statistics.mean(collect(values(kmer_counts)))

In [None]:
# Inspect a single read: the first record.
record = first(records)

sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))

record_kmers = Kmers.EveryKmer{kmer_type}(sequence)

# Base qualities of the read and one window of k qualities per k-mer start position.
record_quality_scores = collect(FASTX.quality_scores(record))
record_quality_score_slices = [record_quality_scores[i:i+k-1] for i in 1:length(record_quality_scores)-k+1]

# Count of each k-mer along the read; `last.` takes the k-mer out of each element the
# iterator yields.
kmer_count_values = [kmer_counts[kmer] for kmer in last.(Kmers.EveryKmer{kmer_type}(sequence))]
p = StatsPlots.plot(
    kmer_count_values,
    title = "state frequency across fastq record",
    xlabel = "read index",
    ylabel = "total equivalent observations",
    labels = "# of equivalent observations"
)
StatsPlots.hline!(p, [kmer_count_threshold], label = "threshold")

In [None]:
# Mean base quality of each k-mer window along the read, with the quality threshold
# drawn as a horizontal line.
p = StatsPlots.plot(
    [
        Statistics.mean.(record_quality_score_slices),
    ],
    ylabel = "quality score",
    xlabel = "read index",
    title = "quality scores across fastq record",
    labels = "mean"
)
StatsPlots.hline!(p, [quality_threshold], labels = "quality threshold")

In [None]:
# Pair each k-mer of the read with the k-mer that follows it and look up the
# likelihood of that transition in the matrix.
sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
transition_likelihood_values = [transition_likelihoods[kmer_indices[source], kmer_indices[destination]] for ((source_i, source), (destination_i, destination)) in zip(sources, destinations)]

p = StatsPlots.plot(
    transition_likelihood_values,
    title = "transition likelihoods across fastq record",
    label = false,
    ylabel = "relative likelihood",
    xlabel = "read index"
)
StatsPlots.hline!(p, [transition_likelihood_treshold], label = "threshold")

In [None]:
# Positions along the read whose k-mer count is below the count threshold.
kmer_count_flags = findall(kmer_count_values .< kmer_count_threshold)

In [None]:
# Positions along the read whose k-mer window has a mean quality below the threshold.
quality_score_flags = findall(Statistics.mean.(record_quality_score_slices) .< quality_threshold)

In [None]:


# Positions along the read whose outgoing transition likelihood is below the threshold.
transition_likelihood_flags = findall(transition_likelihood_values .< transition_likelihood_treshold)
# Number of flags (count, quality, transition) raised at each position.
hit_results = StatsBase.countmap(vcat(kmer_count_flags, quality_score_flags, transition_likelihood_flags))
# Sort the (position, flag count) pairs once and reuse them for both axes.
sorted_hit_results = sort(collect(hit_results))
StatsPlots.scatter(
    first.(sorted_hit_results),
    last.(sorted_hit_results),
    legend = false,
    title = "# of flags thrown at record index",
    ylabel = "# of flags",
    xlabel = "record index"
)

In [None]:
# Positions that received the largest number of flags seen anywhere on the read.
# The maximum is computed once instead of once per entry inside the filter predicate;
# an empty flag table yields an empty selection.
max_flag_count = isempty(hit_results) ? 0 : maximum(values(hit_results))
universally_flagged_indices = sort(collect(keys(filter(hit_result -> hit_result[2] == max_flag_count, hit_results))))
universally_flagged_kmers = last.(collect(record_kmers))[universally_flagged_indices]
universally_flagged_kmer_indices = sort([kmer_indices[universally_flagged_kmer] for universally_flagged_kmer in universally_flagged_kmers])

# States that are both maximally flagged on this read and part of a low-support path.
intersect(universally_flagged_kmer_indices, reduce(vcat, unbranching_paths[low_support_path_unbranching_path_indices]))

In [None]:
universally_flagged_kmer_indices

In [None]:
reduce(vcat, unbranching_paths[low_support_path_unbranching_path_indices])

In [None]:
# low_support_unbranching_paths = Dict(x => iunbranching_paths[low_support_paths]

In [None]:
# for low_support_unbranching_paths

In [None]:
# Add `observation` to the `:observations` set stored on the edge
# `source_vertex -> destination_vertex` of `graph`, creating the edge first if needed.
function _add_edge_observation!(graph, source_vertex, destination_vertex, observation)
    edge = Graphs.Edge(source_vertex, destination_vertex)
    if Graphs.has_edge(graph, edge)
        observations = push!(MetaGraphs.get_prop(graph, edge, :observations), observation)
        MetaGraphs.set_prop!(graph, edge, :observations, observations)
    else
        Graphs.add_edge!(graph, edge)
        MetaGraphs.set_prop!(graph, edge, :observations, Set([observation]))
    end
    return graph
end

"""
    polish_fastx(fastx; k=Mycelia.assess_dnamer_saturation([fastx], plot=false))

Rewrite the reads in `fastx` so that runs of low-count k-mers are trimmed or replaced by
a path through well-supported ("solid") k-mers.

Steps:
1. Count canonical `k`-mers; the floor of the mean count becomes the solid threshold.
2. Build a stranded k-mer graph whose vertices carry `:kmer` and `:count` and whose
   edges carry the set of read positions (`:observations`) at which the transition was
   seen, in both orientations.
3. Restrict the graph to vertices whose count reaches the threshold and give every
   edge the distance `1 / number of observations`.
4. For each read, trim weak runs that touch the start or end, and replace internal weak
   runs with the `Graphs.a_star` path between the flanking k-mers.

# Keywords
- `k`: k-mer length; defaults to the value returned by
  `Mycelia.assess_dnamer_saturation`.

# Returns
The named tuple `(; updated_records, k)`, where `updated_records` is a
`Vector{FASTX.FASTA.Record}`.

# Notes
- `find_weak_runs` is called unqualified and must be defined in the enclosing scope.
  It is assumed to return index ranges of consecutive non-solid k-mers (TODO confirm).
- Reads whose k-mers are all solid, and reads without any solid k-mer, are passed
  through unchanged.
"""
function polish_fastx(fastx; k=Mycelia.assess_dnamer_saturation([fastx], plot=false))
    kmer_type = Kmers.DNAKmer{k}
    canonical_kmer_counts = Mycelia.count_canonical_kmers(kmer_type, fastx)

    # K-mers seen at least this many times are treated as solid.
    solid_threshold = floor(Statistics.mean(values(canonical_kmer_counts)))
    @show solid_threshold

    # Give the reverse complement of every canonical k-mer the canonical count.
    stranded_kmer_counts = copy(canonical_kmer_counts)
    for (canonical_kmer, kmer_count) in canonical_kmer_counts
        stranded_kmer_counts[BioSequences.reverse_complement(canonical_kmer)] = kmer_count
    end
    sort!(stranded_kmer_counts)

    # One vertex per stranded k-mer, indexed by the k-mer itself.
    stranded_kmer_graph = MetaGraphs.MetaDiGraph(length(stranded_kmer_counts))
    MetaGraphs.set_prop!(stranded_kmer_graph, :k, k)
    for (i, (stranded_kmer, kmer_count)) in enumerate(stranded_kmer_counts)
        MetaGraphs.set_prop!(stranded_kmer_graph, i, :kmer, stranded_kmer)
        MetaGraphs.set_prop!(stranded_kmer_graph, i, :count, kmer_count)
    end
    MetaGraphs.set_indexing_prop!(stranded_kmer_graph, :kmer)

    # Record every k-mer -> next-k-mer transition of every read, together with the
    # mirrored transition between the reverse complements.
    records = collect(Mycelia.open_fastx(fastx))
    for record in records
        record_identifier = FASTX.description(record)
        sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
        sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
        destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
        for (i, ((_, source), (_, destination))) in enumerate(zip(sources, destinations))
            _add_edge_observation!(
                stranded_kmer_graph,
                stranded_kmer_graph[source, :kmer],
                stranded_kmer_graph[destination, :kmer],
                (record_identifier = record_identifier, index = i, orientation = true),
            )
            _add_edge_observation!(
                stranded_kmer_graph,
                stranded_kmer_graph[BioSequences.reverse_complement(destination), :kmer],
                stranded_kmer_graph[BioSequences.reverse_complement(source), :kmer],
                (record_identifier = record_identifier, index = i, orientation = false),
            )
        end
    end

    solid_kmers = Set(collect(keys(filter(x -> x[2] >= solid_threshold, stranded_kmer_counts))))

    solid_vertices = filter(v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count) >= solid_threshold, Graphs.vertices(stranded_kmer_graph))

    solid_subgraph, _ = Graphs.induced_subgraph(stranded_kmer_graph, solid_vertices)
    MetaGraphs.set_indexing_prop!(solid_subgraph, :kmer)

    # Directed edge distances: frequently observed transitions are shorter. Only the
    # edge's own direction is written, so an edge running the opposite way between the
    # same two vertices keeps its own weight.
    distance_matrix = SparseArrays.spzeros(Graphs.nv(solid_subgraph), Graphs.nv(solid_subgraph))
    for edge in Graphs.edges(solid_subgraph)
        observations = MetaGraphs.get_prop(solid_subgraph, edge, :observations)
        distance_matrix[edge.src, edge.dst] = 1 / length(observations)
    end

    updated_records = FASTX.FASTA.Record[]
    for record in records
        record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(BioSequences.LongDNA{4}(FASTX.sequence(record)))))
        kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
        if all(kmer_is_solid)
            push!(updated_records, record)
            continue
        end

        has_solid_anchor = true
        junctions_to_resample = find_weak_runs(kmer_is_solid)
        while !isempty(junctions_to_resample)
            junction_to_resample = first(junctions_to_resample)
            weak_start = first(junction_to_resample)
            weak_stop = last(junction_to_resample)

            # The weak run covers the whole read: there is nothing solid to trim back
            # to, and trimming would leave the same single weak k-mer on every pass.
            if weak_start == 1 && weak_stop == length(record_kmers)
                has_solid_anchor = false
                break
            end

            # Indices of the k-mers flanking the weak run, clamped to the read.
            start_index = weak_start > 1 ? weak_start - 1 : weak_start
            stop_index = weak_stop < length(record_kmers) ? weak_stop + 1 : weak_stop

            if weak_start == 1
                # Weak prefix: drop it.
                record_kmers = record_kmers[stop_index:end]
            elseif weak_stop == length(record_kmers)
                # Weak suffix: drop it.
                record_kmers = record_kmers[1:start_index]
            else
                # Internal weak run: bridge the flanking k-mers through the solid graph.
                # NOTE(review): an empty `replacement_path` joins the two flanks
                # directly — confirm that this is acceptable when no route exists.
                start_vertex = solid_subgraph[record_kmers[start_index], :kmer]
                stop_vertex = solid_subgraph[record_kmers[stop_index], :kmer]
                replacement_path = Graphs.a_star(solid_subgraph, start_vertex, stop_vertex, distance_matrix)
                # The last edge ends at `stop_vertex`, which the right flank already holds.
                record_kmers = [record_kmers[1:start_index]..., [solid_subgraph[edge.dst, :kmer] for edge in replacement_path[1:end-1]]..., record_kmers[stop_index:end]...]
            end
            kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
            junctions_to_resample = find_weak_runs(kmer_is_solid)
        end

        if !has_solid_anchor
            push!(updated_records, record)
            continue
        end

        # Rebuild the sequence: the first k-mer in full, then the last base of each
        # following k-mer.
        new_sequence = BioSequences.LongDNA{4}(first(record_kmers))
        for kmer in record_kmers[2:end]
            push!(new_sequence, last(kmer))
        end
        push!(updated_records, FASTX.FASTA.Record(FASTX.description(record), new_sequence))
    end
    return (; updated_records, k)
end

In [None]:
# Rebind `kmer_type` for the current `k` and count k-mers in `fasta`.
# NOTE(review): `fasta` is assigned in a later cell of the notebook, so this cell only
# runs after that one has been executed.
kmer_type = Kmers.DNAKmer{k}
kmer_counts = Mycelia.count_kmers(kmer_type, fasta)

In [None]:


    # Top-level scratch copy of the second half of `polish_fastx`'s body.
    # It reads `stranded_kmer_graph`, `solid_threshold`, `stranded_kmer_counts`,
    # `records`, `kmer_type` and `k` from the enclosing scope and calls
    # `find_weak_runs`; in the visible notebook the first three are only assigned
    # inside `polish_fastx`.
    # NOTE(review): the cell ends in a `return` outside of any function.

    # Vertices with exactly one incoming and one outgoing edge.
    unbranching_nodes = Int[]
    for node in Graphs.vertices(stranded_kmer_graph)
        if (Graphs.indegree(stranded_kmer_graph, node) == 1) && (Graphs.outdegree(stranded_kmer_graph, node) == 1)
            push!(unbranching_nodes, node)
        end
    end
    unvisited_unbranching_nodes = Set(unbranching_nodes)

    # Grow a path from each unvisited unbranching vertex, forward and then backward.
    # NOTE(review): on a closed cycle of unbranching vertices the forward walk has no
    # stopping condition.
    unbranching_paths = []
    while !isempty(unvisited_unbranching_nodes)
        current_path = [rand(unvisited_unbranching_nodes)]
        delete!(unvisited_unbranching_nodes, first(current_path))

        outneighbors = Graphs.outneighbors(stranded_kmer_graph, last(current_path))

        while length(outneighbors) == 1
            outneighbor = first(outneighbors)
            outneighbors_inneighbors = Graphs.inneighbors(stranded_kmer_graph, outneighbor)
            if outneighbors_inneighbors == [last(current_path)]
                push!(current_path, outneighbor)
                delete!(unvisited_unbranching_nodes, outneighbor)
                outneighbors = Graphs.outneighbors(stranded_kmer_graph, outneighbor)
            else
                @assert length(outneighbors_inneighbors) > 1
                push!(current_path, outneighbor)
                delete!(unvisited_unbranching_nodes, outneighbor)
                break
            end
        end
        # @show Graphs.outneighbors(stranded_kmer_graph, current_node)
        inneighbors = Graphs.inneighbors(stranded_kmer_graph, first(current_path))
        while length(inneighbors) == 1
            inneighbor = first(inneighbors)
            inneighbors_outneighbors = Graphs.outneighbors(stranded_kmer_graph, inneighbor)
            if inneighbors_outneighbors == [first(current_path)]
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                inneighbors = Graphs.inneighbors(stranded_kmer_graph, inneighbor)
            else
                @assert length(inneighbors_outneighbors) > 1
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                break
            end
        end
        push!(unbranching_paths, current_path)
    end
    unbranching_paths

    # Counts of the vertices along each path; a path is kept as solid when its
    # smallest count reaches the threshold.
    unbranching_path_state_weights = [map(v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count), path) for path in unbranching_paths]

    solid_unbranching_paths = unbranching_paths[minimum.(unbranching_path_state_weights) .>= solid_threshold]

    # beginnings = filter(path -> Graphs.indegree(stranded_kmer_graph, first(path)) == 0, solid_unbranching_paths)
    # ends = filter(path -> Graphs.outdegree(stranded_kmer_graph, last(path)) == 0, solid_unbranching_paths)
    # mids = setdiff(setdiff(solid_unbranching_paths, beginnings), ends)
    # branch_points = filter(v -> Graphs.indegree(stranded_kmer_graph, v) > 1 || Graphs.outdegree(stranded_kmer_graph, v) > 1, Graphs.vertices(stranded_kmer_graph))
    # @assert all(x -> x in branch_points, first.(mids))
    # @assert all(x -> x in branch_points, last.(mids))
    # @assert all(x -> x in branch_points, first.(ends))
    # @assert all(x -> x in branch_points, last.(beginnings))

    # K-mers and vertices whose count reaches the threshold.
    solid_kmers = Set(collect(keys(filter(x -> x[2] >= solid_threshold, stranded_kmer_counts))))

    solid_vertices = filter(v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count) >= solid_threshold, Graphs.vertices(stranded_kmer_graph))

    # Subgraph of solid vertices, with edge distance 1 / number of observations written
    # in both directions of the matrix.
    solid_subgraph, vertex_map = Graphs.induced_subgraph(stranded_kmer_graph, solid_vertices)
    MetaGraphs.set_indexing_prop!(solid_subgraph, :kmer)
    distance_matrix = SparseArrays.spzeros(Graphs.nv(solid_subgraph), Graphs.nv(solid_subgraph))
    for edge in Graphs.edges(solid_subgraph)
        # @show edge
        observations = MetaGraphs.get_prop(solid_subgraph, edge, :observations)
        # @show length(observations)
        distance_matrix[edge.src, edge.dst] = distance_matrix[edge.dst, edge.src] = 1/length(observations)
    end
    distance_matrix

    # Rewrite each read: trim weak runs at either end and bridge internal weak runs
    # with an A* path through the solid subgraph.
    updated_records = FASTX.FASTA.Record[]
    for record in records
        record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(BioSequences.LongDNA{4}(FASTX.sequence(record)))))
        kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
        if all(kmer_is_solid)
            push!(updated_records, record)
        else
            junctions_to_resample = find_weak_runs(kmer_is_solid)
            while !isempty(junctions_to_resample)
                junction_to_resample = first(junctions_to_resample)
                start_index = first(junction_to_resample)
                if start_index > 1
                    start_index -= 1
                end
                stop_index = last(junction_to_resample)
                if stop_index < length(record_kmers)
                    stop_index += 1
                end
                # @show start_index
                # @show stop_index

                if first(junction_to_resample) == 1
                    record_kmers = [record_kmers[stop_index:end]...]
                elseif last(junction_to_resample) == length(record_kmers)
                    record_kmers = [record_kmers[1:start_index]...]
                else
                    # @assert haskey(solid_kmers, start_index)
                    # @assert haskey(solid_kmers, stop_index)
                    start_vertex = solid_subgraph[record_kmers[start_index], :kmer]
                    stop_vertex = solid_subgraph[record_kmers[stop_index], :kmer]
                    replacement_path = Graphs.a_star(solid_subgraph, start_vertex, stop_vertex, distance_matrix)
                    record_kmers = [record_kmers[1:start_index]..., [solid_subgraph[edge.dst, :kmer] for edge in replacement_path[1:end-1]]..., record_kmers[stop_index:end]...]
                end
                kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
                junctions_to_resample = find_weak_runs(kmer_is_solid)
            end
            # Rebuild the sequence from the first k-mer plus the last base of each
            # following k-mer.
            new_sequence = BioSequences.LongDNA{4}(first(record_kmers))
            for kmer in record_kmers[2:end]
                push!(new_sequence, last(kmer))
            end
            push!(updated_records, FASTX.FASTA.Record(FASTX.description(record), new_sequence))
        end
    end
    return (;updated_records, k)

In [None]:
# Simulate 100 noisy observations of a random 100-base genome, write them to a FASTA
# file in the working directory, and run one round of polishing on it.
Random.seed!(20240209)
genome = BioSequences.randdnaseq(100)
initial_records = map(1:100) do _
    FASTX.FASTA.Record(Random.randstring(), Mycelia.observe(genome, error_rate=0.01))
end
prior_records = initial_records
fasta = Mycelia.write_fasta(
    records = prior_records,
    outfile = joinpath(working_dir, Random.randstring() * ".fna"),
)
updated_records, k = polish_fastx(fasta)
# True when polishing left every record unchanged.
prior_records == updated_records

In [None]:
# Polish the previous round's output again at the next prime above the current k,
# then check whether anything changed.
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
# Another polishing round at the next prime above the current k.
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
# Another polishing round at the next prime above the current k.
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
fasta

In [None]:
# Write the polished records next to the original FASTA, with ".updated.fna" appended.
updated_fasta = Mycelia.write_fasta(records = updated_records, outfile = fasta * ".updated.fna")

In [None]:
#NEED TO FIX CANONICAL KMER GRAPHS AND GFA OUTPUT - THE REST SEEMS CORRECT

In [None]:
# Build a k-mer graph from the polished reads at the next prime above the current k,
# export it as GFA, and render it to an image with the Bandage binary expected under
# ~/software/bin.
k = Primes.nextprime(k+1)
# make my own kmer graph
assembly_graph = Mycelia.fastx_to_kmer_graph(Kmers.DNAKmer{k}, updated_fasta)
gfa_file = updated_fasta * ".gfa"
Mycelia.graph_to_gfa(graph=assembly_graph, outfile=gfa_file)
image = gfa_file * ".mycelia.gfa.jpg"
# run(`chmod +x $(homedir())/software/bin/Bandage`)
run(`$(homedir())/software/bin/Bandage image $(gfa_file) $(image)`)