In [None]:
# If plotting libraries fail to load, try resetting the LD_LIBRARY_PATH seen by Julia;
# it can be set in the kernel spec under ~/.local/share/jupyter/kernels/
# @assert ENV["LD_LIBRARY_PATH"] == ""
import Pkg
# Packages this notebook imports by name (see the loop below).
pkgs = [
    "Revise",
    "FASTX",
    "BioSequences",
    "Kmers",
    "Graphs",
    "MetaGraphs",
    "SparseArrays",
    "ProgressMeter",
    "Distributions",
    "HiddenMarkovModels",
    "BioAlignments",
    "StatsBase",
    "Random",
    "StatsPlots",
    "Statistics",
    # "GraphMakie",
    "IterTools",
    "Primes"
]
# Pkg.add(pkgs)
# `import` (rather than `using`) every listed package, so all later calls are
# module-qualified, e.g. `Graphs.a_star`. The statement is built from the package name
# string and evaluated at top level.
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
import Mycelia

In [None]:
# The parent of the current working directory is taken as the project root.
PROJECT_BASEDIR = dirname(pwd())
data_dir = joinpath(PROJECT_BASEDIR, "data")
# mkpath creates the directory (and any missing parents) and returns the path
genome_dir = mkpath(joinpath(data_dir, "genomes"))

In [None]:
# directory that receives the FASTA files written further down in this notebook
working_dir = joinpath(data_dir, "test")
mkpath(working_dir)

In [None]:
# function find_junctions_to_resample(x)
#     false_indices = findall(!, x)
#     ranges = []
#     start = false_indices[1]
#     for i in 2:length(false_indices)
#         if false_indices[i] - false_indices[i-1] > 1
#           push!(ranges, start:false_indices[i-1])
#           start = false_indices[i]
#         end
#     end
#     push!(ranges, start:false_indices[end])
#     return ranges
# end

In [None]:
"""
    find_weak_runs(bool_list)

Return the index ranges of every maximal run of consecutive `false` values in
`bool_list`, in order of appearance.

Positions are counted from 1 in iteration order, so any iterable of `Bool`s is
accepted. An input with no `false` entries (including an empty input) yields an
empty vector.

# Examples
```julia
find_weak_runs([true, false, false, true, false]) == [2:3, 5:5]
```
"""
function find_weak_runs(bool_list)
    ranges = UnitRange{Int64}[]
    run_start = 0   # 0 means "not currently inside a run of false values"
    position = 0
    for is_solid in bool_list
        position += 1
        if !is_solid
            # open a new run only on the first false after a true (or at the start)
            if run_start == 0
                run_start = position
            end
        elseif run_start != 0
            # a true value closes the run that ended on the previous position
            push!(ranges, run_start:(position - 1))
            run_start = 0
        end
    end
    # a run that reaches the end of the input is closed here
    if run_start != 0
        push!(ranges, run_start:position)
    end
    return ranges
end

In [None]:
"""
    polish_fastx(fastx; k=Mycelia.assess_dnamer_saturation([fastx], plot=false))

Rewrite the reads in `fastx` so that stretches of low-count ("weak") k-mers are
replaced by paths through high-count ("solid") k-mers.

A k-mer is solid when its canonical count is at least the floored mean of all
canonical k-mer counts. For every read:

- a weak run at the start or end of the read is trimmed off;
- an internal weak run is replaced by the `Graphs.a_star` path between the solid
  k-mers flanking it, where each edge costs `1 / (number of observations)`;
- a read whose k-mers are all solid, or that contains no solid k-mer at all, is
  returned unchanged.

# Arguments
- `fastx`: input handed to `Mycelia.count_canonical_kmers` and `Mycelia.open_fastx`.

# Keywords
- `k`: k-mer length. Defaults to the value returned by
  `Mycelia.assess_dnamer_saturation`.

# Returns
The named tuple `(; updated_records, k)`, where `updated_records` is a
`Vector{FASTX.FASTA.Record}` with one entry per input record.
"""
function polish_fastx(fastx; k=Mycelia.assess_dnamer_saturation([fastx], plot=false))
    kmer_type = Kmers.DNAKmer{k}
    canonical_kmer_counts = Mycelia.count_canonical_kmers(kmer_type, fastx)

    solid_threshold = floor(Statistics.mean(values(canonical_kmer_counts)))
    @show solid_threshold

    # give both orientations of every canonical k-mer the same count
    stranded_kmer_counts = copy(canonical_kmer_counts)
    for (canonical_kmer, kmer_count) in canonical_kmer_counts
        stranded_kmer_counts[BioSequences.reverse_complement(canonical_kmer)] = kmer_count
    end
    sort!(stranded_kmer_counts)

    records = collect(Mycelia.open_fastx(fastx))
    stranded_kmer_graph = _build_stranded_kmer_graph(stranded_kmer_counts, records, kmer_type, k)

    solid_kmers = Set(collect(keys(filter(x -> x[2] >= solid_threshold, stranded_kmer_counts))))
    solid_vertices = filter(
        v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count) >= solid_threshold,
        Graphs.vertices(stranded_kmer_graph),
    )
    solid_subgraph, _ = Graphs.induced_subgraph(stranded_kmer_graph, solid_vertices)
    MetaGraphs.set_indexing_prop!(solid_subgraph, :kmer)

    # Edge cost is the inverse of its support, so well-observed edges are cheap.
    # Only the [src, dst] entry is written: the graph is directed, and also writing
    # [dst, src] would let an edge clobber the cost of its antiparallel edge.
    distance_matrix = SparseArrays.spzeros(Graphs.nv(solid_subgraph), Graphs.nv(solid_subgraph))
    for edge in Graphs.edges(solid_subgraph)
        observations = MetaGraphs.get_prop(solid_subgraph, edge, :observations)
        distance_matrix[edge.src, edge.dst] = 1 / length(observations)
    end

    updated_records = FASTX.FASTA.Record[]
    for record in records
        record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(BioSequences.LongDNA{4}(FASTX.sequence(record)))))
        kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
        # Nothing to fix when every k-mer is solid, and nothing to anchor a fix on
        # when none is: previously the latter case never left the resampling loop.
        if all(kmer_is_solid) || !any(kmer_is_solid)
            push!(updated_records, record)
            continue
        end

        record_kmers = _resample_weak_runs(record_kmers, solid_kmers, solid_subgraph, distance_matrix)

        # rebuild the sequence: the first k-mer, then the last base of each later k-mer
        new_sequence = BioSequences.LongDNA{4}(first(record_kmers))
        for kmer in Iterators.drop(record_kmers, 1)
            push!(new_sequence, last(kmer))
        end
        push!(updated_records, FASTX.FASTA.Record(FASTX.description(record), new_sequence))
    end
    return (; updated_records, k)
end

# Build the directed k-mer graph: one vertex per stranded k-mer (properties :kmer and
# :count, indexed by :kmer) and one edge per k-mer adjacency seen in `records`, added in
# both the read orientation and the reverse-complement orientation.
function _build_stranded_kmer_graph(stranded_kmer_counts, records, kmer_type, k)
    graph = MetaGraphs.MetaDiGraph(length(stranded_kmer_counts))
    MetaGraphs.set_prop!(graph, :k, k)
    for (vertex, (stranded_kmer, kmer_count)) in enumerate(stranded_kmer_counts)
        MetaGraphs.set_prop!(graph, vertex, :kmer, stranded_kmer)
        MetaGraphs.set_prop!(graph, vertex, :count, kmer_count)
    end
    MetaGraphs.set_indexing_prop!(graph, :kmer)

    for record in records
        sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
        sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
        destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
        identifier = FASTX.description(record)
        for (i, ((_, source), (_, destination))) in enumerate(zip(sources, destinations))
            # as read: source -> destination
            _add_edge_observation!(
                graph,
                graph[source, :kmer],
                graph[destination, :kmer],
                (record_identifier = identifier, index = i, orientation = true),
            )
            # reverse complement: rc(destination) -> rc(source)
            _add_edge_observation!(
                graph,
                graph[BioSequences.reverse_complement(destination), :kmer],
                graph[BioSequences.reverse_complement(source), :kmer],
                (record_identifier = identifier, index = i, orientation = false),
            )
        end
    end
    return graph
end

# Record `observation` on the edge source_vertex -> destination_vertex, creating the edge
# (with a fresh :observations set) if it does not exist yet.
function _add_edge_observation!(graph, source_vertex, destination_vertex, observation)
    edge = Graphs.Edge(source_vertex, destination_vertex)
    if !Graphs.has_edge(graph, edge)
        Graphs.add_edge!(graph, edge)
        MetaGraphs.set_prop!(graph, edge, :observations, Set([observation]))
    else
        observations = push!(MetaGraphs.get_prop(graph, edge, :observations), observation)
        MetaGraphs.set_prop!(graph, edge, :observations, observations)
    end
    return graph
end

# Repeatedly remove the first weak run of `record_kmers` (trim at the read ends, re-route
# through `solid_subgraph` in the interior) until no weak run is left; returns the new
# k-mer vector.
function _resample_weak_runs(record_kmers, solid_kmers, solid_subgraph, distance_matrix)
    weak_runs = find_weak_runs(map(kmer -> kmer in solid_kmers, record_kmers))
    while !isempty(weak_runs)
        weak_run = first(weak_runs)
        starts_read = first(weak_run) == 1
        ends_read = last(weak_run) == length(record_kmers)
        if starts_read && ends_read
            # no solid k-mer to anchor on; stop instead of looping forever
            break
        elseif starts_read
            # leading weak run: keep everything from the first solid k-mer onwards
            record_kmers = record_kmers[(last(weak_run) + 1):end]
        elseif ends_read
            # trailing weak run: keep everything up to the last solid k-mer
            record_kmers = record_kmers[1:(first(weak_run) - 1)]
        else
            start_index = first(weak_run) - 1   # solid k-mer just before the run
            stop_index = last(weak_run) + 1     # solid k-mer just after the run
            start_vertex = solid_subgraph[record_kmers[start_index], :kmer]
            stop_vertex = solid_subgraph[record_kmers[stop_index], :kmer]
            if start_vertex == stop_vertex
                # Same solid k-mer on both sides: the replacement is that single k-mer,
                # so keep one copy rather than emitting it twice.
                record_kmers = [record_kmers[1:start_index]..., record_kmers[(stop_index + 1):end]...]
            else
                replacement_path = Graphs.a_star(solid_subgraph, start_vertex, stop_vertex, distance_matrix)
                # NOTE(review): a_star returns an empty path when stop_vertex cannot be
                # reached; the weak run is then dropped and the two flanks are joined
                # directly, as before. Decide whether such reads should be split instead.
                bridging_kmers = [solid_subgraph[edge.dst, :kmer] for edge in replacement_path[1:end-1]]
                record_kmers = [record_kmers[1:start_index]..., bridging_kmers..., record_kmers[stop_index:end]...]
            end
        end
        weak_runs = find_weak_runs(map(kmer -> kmer in solid_kmers, record_kmers))
    end
    return record_kmers
end

In [None]:
# NOTE(review): `k` and `fasta` are not assigned above this cell in file order;
# presumably they come from the polishing cells further down — run those first.
kmer_type = Kmers.DNAKmer{k}
kmer_counts = Mycelia.count_kmers(kmer_type, fasta)

In [None]:

# number the k-mers 1, 2, … in the iteration order of `kmer_counts`; the numbers are
# used below as row/column indices of the transition matrix
kmer_indices = Dict(kmer => i for (i, kmer) in enumerate(keys(kmer_counts)))

In [None]:
total_kmers = sum(values(kmer_counts))

In [None]:
# relative frequency of each k-mer among all counted k-mers
state_likelihoods = Dict(kmer => kmer_count / total_kmers for (kmer, kmer_count) in kmer_counts)

In [None]:
total_states = length(state_likelihoods)

In [None]:
records = collect(Mycelia.open_fastx(fasta))

In [None]:
# Count, for every pair of adjacent k-mers in every read, how often the first is
# followed by the second. Rows are the source k-mer, columns the destination k-mer.
transition_likelihoods = SparseArrays.spzeros(total_states, total_states)
for record in records
    dna = BioSequences.LongDNA{4}(FASTX.sequence(record))
    # pair the k-mer starting at each position with the k-mer starting one base later
    adjacent_kmers = zip(
        Kmers.EveryKmer{kmer_type}(dna[1:end-1]),
        Kmers.EveryKmer{kmer_type}(dna[2:end]),
    )
    for ((_, from_kmer), (_, to_kmer)) in adjacent_kmers
        transition_likelihoods[kmer_indices[from_kmer], kmer_indices[to_kmer]] += 1
    end
end
sum(transition_likelihoods)

In [None]:
# Turn each row of transition counts into probabilities by dividing by the row total.
# Rows with no outgoing transitions are left as all zeros.
for source in 1:total_states
    # slice the row and sum it once, instead of re-slicing for every use
    outgoing_transition_counts = transition_likelihoods[source, :]
    outgoing_total = sum(outgoing_transition_counts)
    if outgoing_total > 0
        transition_likelihoods[source, :] .= outgoing_transition_counts ./ outgoing_total
    end
end
sum(transition_likelihoods)

In [None]:


    # NOTE(review): this cell repeats, at top level, the path-finding and read-correction
    # statements that also appear inside `polish_fastx`. It reads `stranded_kmer_graph`,
    # `solid_threshold` and `stranded_kmer_counts`, which no top-level cell visible here
    # assigns, and it ends with a `return` outside of any function — presumably leftover
    # scratch/debugging code; confirm before running, or delete.

    # vertices with exactly one incoming and one outgoing edge
    unbranching_nodes = Int[]
    for node in Graphs.vertices(stranded_kmer_graph)
        if (Graphs.indegree(stranded_kmer_graph, node) == 1) && (Graphs.outdegree(stranded_kmer_graph, node) == 1)
            push!(unbranching_nodes, node)
        end
    end
    unvisited_unbranching_nodes = Set(unbranching_nodes)

    # grow a path from a randomly chosen unvisited node, first forwards then backwards,
    # while the next/previous vertex is the only neighbour in that direction
    unbranching_paths = []
    while !isempty(unvisited_unbranching_nodes)
        current_path = [rand(unvisited_unbranching_nodes)]
        delete!(unvisited_unbranching_nodes, first(current_path))

        outneighbors = Graphs.outneighbors(stranded_kmer_graph, last(current_path))

        # NOTE(review): nothing here stops the walk when it returns to a vertex already
        # on the path, so a cycle of unbranching vertices looks like it would never exit.
        while length(outneighbors) == 1
            outneighbor = first(outneighbors)
            outneighbors_inneighbors = Graphs.inneighbors(stranded_kmer_graph, outneighbor)
            if outneighbors_inneighbors == [last(current_path)]
                push!(current_path, outneighbor)
                delete!(unvisited_unbranching_nodes, outneighbor)
                outneighbors = Graphs.outneighbors(stranded_kmer_graph, outneighbor)
            else
                # the next vertex has other predecessors: include it, then stop extending
                @assert length(outneighbors_inneighbors) > 1
                push!(current_path, outneighbor)
                delete!(unvisited_unbranching_nodes, outneighbor)
                break
            end
        end
        # @show Graphs.outneighbors(stranded_kmer_graph, current_node)
        inneighbors = Graphs.inneighbors(stranded_kmer_graph, first(current_path))
        while length(inneighbors) == 1
            inneighbor = first(inneighbors)
            inneighbors_outneighbors = Graphs.outneighbors(stranded_kmer_graph, inneighbor)
            if inneighbors_outneighbors == [first(current_path)]
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                inneighbors = Graphs.inneighbors(stranded_kmer_graph, inneighbor)
            else
                # the previous vertex has other successors: include it, then stop extending
                @assert length(inneighbors_outneighbors) > 1
                pushfirst!(current_path, inneighbor)
                delete!(unvisited_unbranching_nodes, inneighbor)
                break
            end
        end
        push!(unbranching_paths, current_path)
    end
    unbranching_paths

    # per-path list of the :count property of each vertex on the path
    unbranching_path_state_weights = [map(v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count), path) for path in unbranching_paths]

    # paths whose smallest vertex count reaches the threshold; not read again in this cell
    solid_unbranching_paths = unbranching_paths[minimum.(unbranching_path_state_weights) .>= solid_threshold]

    # beginnings = filter(path -> Graphs.indegree(stranded_kmer_graph, first(path)) == 0, solid_unbranching_paths)
    # ends = filter(path -> Graphs.outdegree(stranded_kmer_graph, last(path)) == 0, solid_unbranching_paths)
    # mids = setdiff(setdiff(solid_unbranching_paths, beginnings), ends)
    # branch_points = filter(v -> Graphs.indegree(stranded_kmer_graph, v) > 1 || Graphs.outdegree(stranded_kmer_graph, v) > 1, Graphs.vertices(stranded_kmer_graph))
    # @assert all(x -> x in branch_points, first.(mids))
    # @assert all(x -> x in branch_points, last.(mids))
    # @assert all(x -> x in branch_points, first.(ends))
    # @assert all(x -> x in branch_points, last.(beginnings))

    # k-mers and vertices whose count is at least the threshold
    solid_kmers = Set(collect(keys(filter(x -> x[2] >= solid_threshold, stranded_kmer_counts))))

    solid_vertices = filter(v -> MetaGraphs.get_prop(stranded_kmer_graph, v, :count) >= solid_threshold, Graphs.vertices(stranded_kmer_graph))

    solid_subgraph, vertex_map = Graphs.induced_subgraph(stranded_kmer_graph, solid_vertices)
    MetaGraphs.set_indexing_prop!(solid_subgraph, :kmer)
    # edge cost = 1 / number of observations; written to both [src, dst] and [dst, src]
    distance_matrix = SparseArrays.spzeros(Graphs.nv(solid_subgraph), Graphs.nv(solid_subgraph))
    for edge in Graphs.edges(solid_subgraph)
        # @show edge
        observations = MetaGraphs.get_prop(solid_subgraph, edge, :observations)
        # @show length(observations)
        distance_matrix[edge.src, edge.dst] = distance_matrix[edge.dst, edge.src] = 1/length(observations)
    end
    distance_matrix

    updated_records = FASTX.FASTA.Record[]
    for record in records
        record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(BioSequences.LongDNA{4}(FASTX.sequence(record)))))
        kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
        if all(kmer_is_solid)
            push!(updated_records, record)
        else
            junctions_to_resample = find_weak_runs(kmer_is_solid)
            # handle the first weak run, recompute the runs, and repeat until none remain
            while !isempty(junctions_to_resample)
                junction_to_resample = first(junctions_to_resample)
                # widen the run by one k-mer on each side where the read allows it
                start_index = first(junction_to_resample)
                if start_index > 1
                    start_index -= 1
                end
                stop_index = last(junction_to_resample)
                if stop_index < length(record_kmers)
                    stop_index += 1
                end
                # @show start_index
                # @show stop_index

                if first(junction_to_resample) == 1
                    # run starts the read: drop it
                    record_kmers = [record_kmers[stop_index:end]...]
                elseif last(junction_to_resample) == length(record_kmers)
                    # run ends the read: drop it
                    record_kmers = [record_kmers[1:start_index]...]
                else
                    # @assert haskey(solid_kmers, start_index)
                    # @assert haskey(solid_kmers, stop_index)
                    # internal run: replace it with the a_star path between its flanks
                    start_vertex = solid_subgraph[record_kmers[start_index], :kmer]
                    stop_vertex = solid_subgraph[record_kmers[stop_index], :kmer]
                    replacement_path = Graphs.a_star(solid_subgraph, start_vertex, stop_vertex, distance_matrix)
                    record_kmers = [record_kmers[1:start_index]..., [solid_subgraph[edge.dst, :kmer] for edge in replacement_path[1:end-1]]..., record_kmers[stop_index:end]...]
                end
                kmer_is_solid = map(kmer -> kmer in solid_kmers, record_kmers)
                junctions_to_resample = find_weak_runs(kmer_is_solid)
            end
            # rebuild the sequence: first k-mer, then the last base of each later k-mer
            new_sequence = BioSequences.LongDNA{4}(first(record_kmers))
            for kmer in record_kmers[2:end]
                push!(new_sequence, last(kmer))
            end
            push!(updated_records, FASTX.FASTA.Record(FASTX.description(record), new_sequence))
        end
    end
    return (;updated_records, k)

In [None]:
# Build a small test case and run the first polishing round on it.
Random.seed!(20240209)  # fixed seed so the random genome and read names are reproducible
genome = BioSequences.randdnaseq(100)
# 100 records with random identifiers; each sequence comes from
# `Mycelia.observe(genome, error_rate=0.01)` (presumably a noisy copy of `genome` — confirm)
initial_records = [FASTX.FASTA.Record(Random.randstring(), Mycelia.observe(genome, error_rate=0.01)) for i in 1:100]
prior_records = initial_records
fasta = Mycelia.write_fasta(records = prior_records, outfile = joinpath(working_dir, Random.randstring() * ".fna"))
# k is left at polish_fastx's default for this first round
updated_records, k = polish_fastx(fasta)
# true when polishing left every record unchanged
prior_records == updated_records

In [None]:
# Further polishing rounds. Each round writes the previous round's records to a FASTA
# (no `outfile` is given, so the path is whatever `Mycelia.write_fasta` returns),
# polishes again with k raised to the next prime above the previous k, and then
# reports whether the round changed any record.
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
prior_records = updated_records
temp_fasta = Mycelia.write_fasta(records = prior_records)
updated_records, k = polish_fastx(temp_fasta, k = Primes.nextprime(k+1))
prior_records == updated_records

In [None]:
# display the value of `fasta` (assigned from the first `Mycelia.write_fasta` call)
fasta

In [None]:
# write the final polished records next to the original file
updated_fasta = Mycelia.write_fasta(records = updated_records, outfile = fasta * ".updated.fna")

In [None]:
# TODO: fix the canonical k-mer graphs and the GFA output; the rest seems correct.

In [None]:
# Build a k-mer graph from the polished reads, export it as GFA, and render it.
# k is advanced to the next prime above the value from the last polishing round.
k = Primes.nextprime(k+1)
# make my own kmer graph
assembly_graph = Mycelia.fastx_to_kmer_graph(Kmers.DNAKmer{k}, updated_fasta)
gfa_file = updated_fasta * ".gfa"
Mycelia.graph_to_gfa(graph=assembly_graph, outfile=gfa_file)
image = gfa_file * ".mycelia.gfa.jpg"
# run(`chmod +x $(homedir())/software/bin/Bandage`)
# invokes the Bandage executable expected at ~/software/bin/Bandage; `run` throws if
# the command exits with a non-zero status
run(`$(homedir())/software/bin/Bandage image $(gfa_file) $(image)`)