In [None]:
# If plotting libraries fail to load, try clearing LD_LIBRARY_PATH for the Julia kernel;
# it can be set in the kernel spec under ~/.local/share/jupyter/kernels/
haskey(ENV, "LD_LIBRARY_PATH") && @assert ENV["LD_LIBRARY_PATH"] == ""
import Pkg
# Packages this notebook uses, each listed exactly once.
pkgs = [
    "Revise",
    "FASTX",
    "BioSequences",
    "Kmers",
    "Graphs",
    "MetaGraphs",
    "SparseArrays",
    "ProgressMeter",
    "Distributions",
    "HiddenMarkovModels",
    "BioAlignments",
    "StatsBase",
    "Random",
    "StatsPlots",
    "Statistics",
    # "GraphMakie",
    "IterTools",
    "Primes",
    "OnlineStats",
    "IteratorSampling",
    "HypothesisTests",
    "Clustering",
    "Distances",
]
# Pkg.add(pkgs)
for pkg in pkgs
    # `import` is only valid at top level, so evaluate it there; building the
    # expression directly avoids round-tripping through source text.
    eval(Expr(:import, Expr(:., Symbol(pkg))))
end
# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
import Mycelia

In [None]:
# Project layout: <parent of the notebook's working directory>/data/genomes
# (the genomes directory is created when missing).
PROJECT_BASEDIR = pwd() |> dirname
data_dir = joinpath(PROJECT_BASEDIR, "data")
genome_dir = joinpath(data_dir, "genomes") |> mkpath

In [None]:
# Scratch directory for this test run, created when missing.
working_dir = joinpath(data_dir, "test")
working_dir |> mkpath

In [None]:
# short_read_sets = unique(map(x -> match(r"^(.+\.\d+x)\.", x).captures[1], filter(x -> occursin(r"\.fna\.art", x) && occursin(r"\.fq\.gz", x) && !occursin("trimming_report", x) && !occursin("_val_", x), sort(readdir(genome_dir, join=true), by=x->filesize(x)))))
# # forward = short_read_set * ".1_val_1.fq.gz"
# # reverse = short_read_set * ".2_val_2.fq.gz"

In [None]:
# Filtlong-filtered long-read FASTQs in the genome directory, smallest file first;
# the smallest one is used below.
long_read_fastqs = filter(path -> occursin(r"\.filtlong\.fq\.gz$", path), readdir(genome_dir; join = true))
sort!(long_read_fastqs; by = filesize)
fastq = long_read_fastqs[1]

In [None]:
# Reference FASTA path = the read file path with everything from the first ".badread"
# onward removed. NOTE(review): presumably the reads were simulated from this reference
# and named "<reference>.badread..." — confirm against the simulation step.
reference_fasta = replace(fastq, r"\.badread.*" => "")

In [None]:
# k-mer length used for the rest of the notebook, chosen from the reads by Mycelia's
# k-mer saturation assessment (NOTE(review): selection criterion lives in Mycelia — confirm).
assembly_k = Mycelia.assess_dnamer_saturation([fastq])

In [None]:
# DNA k-mer type of length `assembly_k` used for every k-mer container below.
# NOTE(review): the second type parameter is presumably the number of storage words in
# this Kmers.jl version; if so, one word only holds k <= 32 at 2 bits/base — confirm.
kmer_type = Kmers.DNAKmer{assembly_k, 1}

In [None]:
# k-mer counts of the reference FASTA (presumably ground truth for later accuracy
# checks — confirm), and every read loaded into memory for the passes that follow.
reference_kmer_counts = Mycelia.fasta_to_reference_kmer_counts(kmer_type=kmer_type, fasta=reference_fasta)
records = collect(Mycelia.open_fastx(fastq))

In [None]:
# Mean of every per-base quality score across all reads (one flat stream of scores).
fit_mean = OnlineStats.fit!(OnlineStats.Mean(), Iterators.flatten(FASTX.quality_scores(r) for r in records))

In [None]:
# Minimum and maximum per-base quality score across all reads.
fit_extrema = OnlineStats.fit!(OnlineStats.Extrema(), Iterators.flatten(FASTX.quality_scores(r) for r in records))

In [None]:
# Variance (and its square root) of the per-base quality scores across all reads.
fit_variance = OnlineStats.fit!(OnlineStats.Variance(), Iterators.flatten(FASTX.quality_scores(r) for r in records))
standard_deviation = OnlineStats.value(fit_variance) |> sqrt

In [None]:
using FASTX
using Base.Threads
using CSV
using DataFrames

# Requires the FASTQ index produced by `samtools fqidx <fastq>` next to the reads: it has
# one row per record, so its row count is the number of reads.
println("Reading read FASTQ Index for preallocating dictionary size.")
fastq_index = fastq * ".fai"
total_records = DataFrames.nrow(
    CSV.read(
        fastq_index,
        DataFrame;
        delim = '\t',
        header = false,
        types = [String, Int64, Int64, Int64, Int64, Int64],
    ),
)

# Accumulate, for every k-mer seen in the reads, the element-wise sum of the per-base
# quality scores of all its occurrences (one Float64 per position within the k-mer).
all_kmer_quality_support = Dict{kmer_type, Vector{Float64}}()

reader = open(FASTQ.Reader, fastq)
counter = Threads.Atomic{Int}(0)
reader_lock = ReentrantLock()
all_kmer_quality_support_lock = ReentrantLock()

println("Reading read FASTQ records for adding quality values and k-mers.")
try
    Threads.@threads for _ in 1:total_records
        local_record = FASTQ.Record()
        # The reader is a single shared stream, so record reads must be serialized.
        lock(reader_lock) do
            read!(reader, local_record)
        end
        # Everything up to the dictionary update is thread-local and needs no lock.
        record_quality_scores = collect(FASTX.quality_scores(local_record))
        record_quality_score_slices = [
            view(record_quality_scores, i:i+assembly_k-1)
            for i in 1:length(record_quality_scores)-assembly_k+1
        ]
        sequence = BioSequences.LongDNA{2}(FASTX.sequence(local_record))
        # Both the lookup and the insert must happen under the lock: reading a Dict while
        # another thread inserts (and possibly rehashes) is a data race, and a
        # check-then-insert done outside the lock can lose observations.
        lock(all_kmer_quality_support_lock) do
            for ((_, kmer), kmer_base_qualities) in zip(Kmers.EveryKmer{kmer_type}(sequence), record_quality_score_slices)
                existing_support = get(all_kmer_quality_support, kmer, nothing)
                if existing_support === nothing
                    all_kmer_quality_support[kmer] = Float64.(kmer_base_qualities)
                else
                    existing_support .+= kmer_base_qualities
                end
            end
        end
        Threads.atomic_add!(counter, 1)
        print("\rRead $(counter[]) FASTQ records.")
        flush(stdout)
    end
finally
    close(reader)
end


# Strand-specific and canonical k-mer counts over the reads. Each is paired with a lookup
# from k-mer to its 1-based position in that dictionary's key iteration order.
kmer_counts = Mycelia.count_kmers(kmer_type, fastq)
kmer_indices = Dict(zip(keys(kmer_counts), 1:length(kmer_counts)))
canonical_kmer_counts = Mycelia.count_canonical_kmers(kmer_type, fastq)
canonical_kmer_indices = Dict(zip(keys(canonical_kmer_counts), 1:length(canonical_kmer_counts)))

# Pool each k-mer's per-position quality support with that of its reverse complement, so
# both orientations of the same genomic k-mer carry the combined evidence.
strand_normalized_quality_support = Dict{kmer_type, Vector{Float64}}()
for (kmer, support) in all_kmer_quality_support
    # Copy: storing `support` itself would alias the array inside
    # `all_kmer_quality_support`. The in-place add below would then corrupt the value
    # that is read again when the reverse complement is visited, double-counting it.
    pooled_support = copy(support)
    reverse_complement_support = get(all_kmer_quality_support, BioSequences.reverse_complement(kmer), nothing)
    if reverse_complement_support !== nothing
        # Position j of the reverse complement covers the same base as position k-j+1 of
        # `kmer`, so reverse it to keep positions aligned (totals are unaffected).
        pooled_support .+= reverse(reverse_complement_support)
    end
    strand_normalized_quality_support[kmer] = pooled_support
end
strand_normalized_quality_support

In [None]:
# Observed k-mers in `kmer_counts` key order; position i matches the numbering in
# `kmer_indices`, which enumerates the same keys.
ordered_kmers = [kmer for kmer in keys(kmer_counts)]

In [None]:
# Per-position mean quality of each k-mer: the strand-pooled quality sums divided by the
# number of observations of the k-mer's canonical form (i.e. of either orientation).
kmer_mean_quality = Dict(map(ordered_kmers) do kmer
    pooled_support = strand_normalized_quality_support[kmer]
    canonical_count = canonical_kmer_counts[BioSequences.canonical(kmer)]
    kmer => pooled_support ./ canonical_count
end)

In [None]:
# strand_normalized_quality_support 

In [None]:
# Scalar quality evidence per k-mer: the sum of its strand-pooled per-position support.
kmer_total_quality = Dict(k => sum(per_position) for (k, per_position) in pairs(strand_normalized_quality_support))


"""
    calculate_state_likelihoods(kmer_total_quality)

Normalize per-k-mer total quality into likelihoods that sum to one.

# Arguments
- `kmer_total_quality`: dictionary mapping each k-mer (state) to its total quality.

# Returns
A `Dict{K, Float64}` (with `K` the key type of `kmer_total_quality`) that maps each key to
`total_quality / sum(values(kmer_total_quality))`. A zero grand total yields non-finite
ratios.
"""
function calculate_state_likelihoods(kmer_total_quality)
    state_likelihoods = Dict{eltype(keys(kmer_total_quality)), Float64}()
    sizehint!(state_likelihoods, length(kmer_total_quality))
    total_sum = sum(values(kmer_total_quality))
    # A single division per entry: a serial loop beats paying for thread start-up, a lock
    # around every insertion and a progress print for every k-mer.
    for (kmer, total_quality) in kmer_total_quality
        state_likelihoods[kmer] = total_quality / total_sum
    end
    return state_likelihoods
end

# Likelihood of each k-mer state = its share of the total quality over all k-mers.
state_likelihoods = calculate_state_likelihoods(kmer_total_quality)

# One state per k-mer. Transition counts are accumulated afterwards into this initially
# all-zero sparse matrix (row = source state, column = destination state).
total_states = length(state_likelihoods)
transition_likelihoods = SparseArrays.spzeros(Float64, total_states, total_states)


# Count every observed k-mer -> next-k-mer transition in the reads and add the counts to
# `transition_likelihoods` (row = source k-mer index, column = destination k-mer index,
# both taken from `kmer_indices`).
reader = open(FASTQ.Reader, fastq)
counter = Threads.Atomic{Int}(0)
reader_lock = ReentrantLock()
transition_edges_lock = ReentrantLock()
transition_sources = Int[]
transition_destinations = Int[]

try
    Threads.@threads for _ in 1:total_records
        local_record = FASTQ.Record()
        # The reader is a single shared stream, so record reads must be serialized.
        lock(reader_lock) do
            read!(reader, local_record)
        end
        # Everything below is thread-local; `kmer_indices` is only read here (no thread
        # mutates it), so lookups need no lock.
        sequence = BioSequences.LongDNA{4}(FASTX.sequence(local_record))
        # k-mers starting at 1..L-k paired with those starting at 2..L-k+1, i.e. each
        # k-mer with its successor in the read.
        sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
        destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
        record_sources = Int[]
        record_destinations = Int[]
        for ((_, source), (_, destination)) in zip(sources, destinations)
            push!(record_sources, kmer_indices[source])
            push!(record_destinations, kmer_indices[destination])
        end
        lock(transition_edges_lock) do
            append!(transition_sources, record_sources)
            append!(transition_destinations, record_destinations)
        end
        Threads.atomic_add!(counter, 1)
        print("\rRead $(counter[]) FASTQ records.")
        flush(stdout)
    end
finally
    close(reader)
end

# Build all counts in one step: `sparse` sums repeated (source, destination) pairs. By
# contrast, incrementing entries of a SparseMatrixCSC one at a time costs O(nnz) per new
# entry.
transition_likelihoods = transition_likelihoods + SparseArrays.sparse(
    transition_sources,
    transition_destinations,
    ones(Float64, length(transition_sources)),
    total_states,
    total_states,
)


"""
    normalize_transition_rows!(matrix::SparseArrays.SparseMatrixCSC)

Divide every stored entry of `matrix` by the sum of its row, in place. This turns transition
counts into per-source transition probabilities. Rows whose total is not positive are left
unchanged. Returns `matrix`.

Runs in O(nnz) by walking the CSC storage directly, instead of slicing each sparse row.
"""
function normalize_transition_rows!(matrix::SparseArrays.SparseMatrixCSC)
    stored_rows = SparseArrays.rowvals(matrix)
    stored_values = SparseArrays.nonzeros(matrix)
    row_totals = zeros(eltype(matrix), size(matrix, 1))
    for column in axes(matrix, 2), entry in SparseArrays.nzrange(matrix, column)
        row_totals[stored_rows[entry]] += stored_values[entry]
    end
    for column in axes(matrix, 2), entry in SparseArrays.nzrange(matrix, column)
        row_total = row_totals[stored_rows[entry]]
        if row_total > 0
            stored_values[entry] /= row_total
        end
    end
    return matrix
end

normalize_transition_rows!(transition_likelihoods)

# Directed graph over the k-mer states with an edge wherever a transition was observed
# (i.e. wherever the transition matrix stores an entry).
g = Graphs.SimpleDiGraph(total_states)
row_indices, column_indices, cell_values = SparseArrays.findnz(transition_likelihoods)
foreach((source, destination) -> Graphs.add_edge!(g, source, destination), row_indices, column_indices)
g

# A node is unbranching when it has at most one predecessor and at most one successor;
# every other node is a branch point.
unbranching_nodes = Set{Int}(
    node for node in Graphs.vertices(g)
    if Graphs.indegree(g, node) <= 1 && Graphs.outdegree(g, node) <= 1
)
branching_nodes = setdiff(Graphs.vertices(g), unbranching_nodes)
branching_nodes_set = Set(branching_nodes)

In [None]:
# Total strand-pooled quality per k-mer.
total_strand_normalized_quality_support = sum.(collect(values(strand_normalized_quality_support)))
mean_total_support = Statistics.mean(total_strand_normalized_quality_support)
std_total_support = Statistics.std(total_strand_normalized_quality_support)
# Compare against a normal distribution with the data's own mean and spread. Comparing
# against the standard Normal(0, 1) rejects for any data on the quality-sum scale. Because
# the parameters are estimated from the same data, the p-value is conservative
# (Lilliefors caveat).
test_is_single_distribution = HypothesisTests.ExactOneSampleKSTest(
    total_strand_normalized_quality_support,
    Distributions.Normal(mean_total_support, std_total_support),
)
single_distribution_pvalue = HypothesisTests.pvalue(test_is_single_distribution)
if single_distribution_pvalue < 1e-3
    println("p = $(single_distribution_pvalue) rejecting error-free hypothesis & entering error correction")
else
    println("single distribution detected, this data may be error-free")
end

In [None]:
# NOTE(review): Base Julia defines no `sort(::Dict)`, so this relies on a method supplied by a
# loaded package (presumably OrderedCollections, whose version returns an OrderedDict sorted
# by key) — confirm which method is dispatched. `sorted_kmer_total_quality_values` follows
# this dict's iteration order, and positions in it refer to this dict's keys.
sorted_kmer_total_quality = sort(kmer_total_quality)
sorted_kmer_total_quality_values = collect(values(sorted_kmer_total_quality))

In [None]:
# BEING FLAGGED DOESN'T AUTOMATICALLY MEAN THAT WE WILL DROP IT, IT JUST MEANS THAT WE WILL ATTEMPT TO RESAMPLE IT
# 1-D k-means on per-k-mer total quality; the cluster whose centroid is lowest is treated as
# likely sequencing artifacts.
clustering_k = 2
clustering_result = Clustering.kmeans(permutedims(sorted_kmer_total_quality_values), clustering_k)

assignments = Clustering.assignments(clustering_result)
centroids = clustering_result.centers

# `centroids` is a 1×k matrix, so the column of its minimum is the cluster id.
min_cluster_result = findmin(centroids)
smaller_cluster = min_cluster_result[2][2]

# Per cluster: positions (x) in `sorted_kmer_total_quality_values` and the values (y) there.
xs = [findall(==(cluster), assignments) for cluster in 1:clustering_k]
ys = [Float64[sorted_kmer_total_quality_values[p] for p in positions] for positions in xs]

label = if smaller_cluster == 1
    ["likely sequencing artifacts" "likely valid kmers"]
else
    ["likely valid kmers" "likely sequencing artifacts"]
end

StatsPlots.scatter(
    xs,
    ys,
    title = "kmeans error separation",
    ylabel = "canonical kmer cumulative QUAL value",
    label = label,
    legend = :outertopright,
    size = (900, 500),
    margins=10StatsPlots.Plots.PlotMeasures.mm,
    xticks = false
)

In [None]:
# `xs` holds positions in `sorted_kmer_total_quality_values`. Map them back to k-mers through
# the keys of `sorted_kmer_total_quality`, which iterate in the same order as its values,
# rather than relying on `ordered_kmers` happening to share that order.
sorted_kmers = collect(keys(sorted_kmer_total_quality))
likely_sequencing_artifact_indices = xs[smaller_cluster]
likely_sequencing_artifact_kmers = Set(sorted_kmers[likely_sequencing_artifact_indices])
# Every cluster other than the lowest-centroid one counts as likely valid.
likely_valid_kmer_indices = reduce(vcat, (xs[cluster] for cluster in eachindex(xs) if cluster != smaller_cluster); init = Int[])
likely_valid_kmers = Set(sorted_kmers[likely_valid_kmer_indices])
# k-mer -> state index, matching the numbering used for the graph and transition matrix.
kmer_to_index_map = Dict(kmer => i for (i, kmer) in enumerate(ordered_kmers))

In [None]:
"""
    find_resampling_stretches(; record_kmer_solidity, solid_branching_kmer_indices)

Find the k-mer index ranges of a read that should be re-sampled.

Consecutive non-solid k-mers (`false` in `record_kmer_solidity`) are grouped into maximal
runs. Each run is widened to span from the nearest index in `solid_branching_kmer_indices`
before it to the nearest one after it. When no such index exists before (after) a run, the
first (last) index of `record_kmer_solidity` is used instead.

# Keywords
- `record_kmer_solidity`: one `Bool` per k-mer of the read, `true` for trusted k-mers.
- `solid_branching_kmer_indices`: indices into `record_kmer_solidity` usable as anchors
  (need not be sorted).

# Returns
A `Vector{UnitRange{Int}}` of distinct stretches in read order. It is empty when every
k-mer is solid.
"""
function find_resampling_stretches(; record_kmer_solidity, solid_branching_kmer_indices)
    resampling_stretches = UnitRange{Int}[]
    weak_indices = findall(.!record_kmer_solidity)
    isempty(weak_indices) && return resampling_stretches

    # Maximal runs of consecutive non-solid k-mer indices, as (first, last) pairs.
    low_quality_runs = Tuple{Int, Int}[]
    run_start = first(weak_indices)
    for (previous, current) in zip(weak_indices, @view weak_indices[2:end])
        if current != previous + 1
            push!(low_quality_runs, (run_start, previous))
            run_start = current
        end
    end
    push!(low_quality_runs, (run_start, last(weak_indices)))

    anchors = sort(solid_branching_kmer_indices)
    for (run_first, run_last) in low_quality_runs
        # Binary search instead of filtering every anchor for every run.
        before = searchsortedlast(anchors, run_first - 1)   # last anchor < run_first
        after = searchsortedfirst(anchors, run_last + 1)    # first anchor > run_last
        nearest_under = before < firstindex(anchors) ? firstindex(record_kmer_solidity) : anchors[before]
        nearest_over = after > lastindex(anchors) ? lastindex(record_kmer_solidity) : anchors[after]
        push!(resampling_stretches, nearest_under:nearest_over)
    end
    # Neighbouring runs with no anchor between them produce the same stretch.
    return unique!(resampling_stretches)
end

In [None]:
# Append one segment of k-mers (and their per-k-mer quality vectors) to the corrected read.
# The segment's first element is skipped when it is the k-mer already emitted last, since
# adjacent segments share their boundary k-mer.
function _append_kmer_segment!(new_kmers, new_qualities, segment_kmers, segment_qualities, skip_first::Bool)
    start = skip_first ? 2 : 1
    append!(new_kmers, @view segment_kmers[start:end])
    append!(new_qualities, @view segment_qualities[start:end])
    return nothing
end

"""
    process_fastq_record(; record, likely_valid_kmers, kmer_to_index_map, branching_nodes_set,
                         assembly_k, transition_likelihoods, kmer_mean_quality,
                         yen_k_shortest_paths_and_weights, yen_k=7)

Attempt to error-correct one FASTQ read by re-sampling its unreliable k-mer stretches
through the k-mer graph.

The read is first trimmed so it starts and ends on a k-mer in `likely_valid_kmers`. Each
stretch returned by `find_resampling_stretches` is then replaced in one of two ways:
- When more than one path is found, by a path sampled from up to `yen_k` shortest graph
  paths between the stretch's anchor k-mers. Paths are weighted by mean k-mer total
  quality × product of transition likelihoods, halved for every k-mer their length differs
  from the stretch.
- Otherwise, by the read's own k-mers.
Re-sampled k-mers get qualities from `kmer_mean_quality`. Solid k-mers between stretches
keep the read's own qualities.

Candidate paths are cached in `yen_k_shortest_paths_and_weights` (mutated), keyed by
`source_index => destination_index`. Also reads the globals `kmer_type`, `g`,
`kmer_total_quality` and `ordered_kmers`.

# Returns
A new `FASTX.FASTQRecord` whose identifier is the original one with `.k<assembly_k>`
appended. Returns `nothing` when the read contains no k-mer in `likely_valid_kmers`, which
includes reads shorter than `assembly_k`.
"""
function process_fastq_record(;record, likely_valid_kmers, kmer_to_index_map, branching_nodes_set, assembly_k, transition_likelihoods, kmer_mean_quality, yen_k_shortest_paths_and_weights, yen_k=7)
    new_record_identifier = FASTX.identifier(record) * ".k$(assembly_k)"
    record_sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
    record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(record_sequence)))
    record_quality_scores = collect(FASTX.quality_scores(record))
    record_kmer_quality_scores = [record_quality_scores[i:i+assembly_k-1] for i in 1:length(record_quality_scores)-assembly_k+1]

    # Trim the read so it starts and ends on a solid (likely valid) k-mer.
    kmer_is_solid = map(kmer -> kmer in likely_valid_kmers, record_kmers)
    first_solid = findfirst(kmer_is_solid)
    if first_solid === nothing
        # No trustworthy k-mer to anchor any correction.
        return nothing
    end
    kept = first_solid:findlast(kmer_is_solid)
    record_kmers = record_kmers[kept]
    record_kmer_quality_scores = record_kmer_quality_scores[kept]
    record_kmer_solidity = kmer_is_solid[kept]
    record_branching_kmers = [kmer_to_index_map[kmer] in branching_nodes_set for kmer in record_kmers]
    record_solid_branching_kmers = record_kmer_solidity .& record_branching_kmers

    # Low-quality runs, widened to the nearest solid branching k-mers, are re-sampled.
    solid_branching_kmer_indices = findall(record_solid_branching_kmers)
    resampling_stretches = find_resampling_stretches(; record_kmer_solidity, solid_branching_kmer_indices)

    new_record_kmers = eltype(record_kmers)[]
    new_record_kmer_qualities = eltype(record_kmer_quality_scores)[]
    last_kmer_index = lastindex(record_kmers)

    if isempty(resampling_stretches)
        # Every k-mer between the trimmed ends is solid: keep them unchanged.
        _append_kmer_segment!(new_record_kmers, new_record_kmer_qualities, record_kmers, record_kmer_quality_scores, false)
    else
        # Trusted prefix before the first stretch (at least the first k-mer).
        trusted_range = 1:max(first(first(resampling_stretches)) - 1, 1)
        _append_kmer_segment!(new_record_kmers, new_record_kmer_qualities, record_kmers[trusted_range], record_kmer_quality_scores[trusted_range], false)
        # Index into `record_kmers` of the last k-mer emitted. Dedup is by position, not
        # by k-mer value, so legitimately repeated k-mers (homopolymers) are kept.
        last_emitted = last(trusted_range)

        for (i, resampling_stretch) in enumerate(resampling_stretches)
            u = kmer_to_index_map[record_kmers[first(resampling_stretch)]]
            v = kmer_to_index_map[record_kmers[last(resampling_stretch)]]
            yen_k_path_weights = get!(yen_k_shortest_paths_and_weights, u => v) do
                yen_k_result = Graphs.yen_k_shortest_paths(g, u, v, Graphs.weights(g), yen_k)
                weighted_paths = Pair{Vector{Int}, Float64}[]
                for path in yen_k_result.paths
                    path_weight = Statistics.mean([kmer_total_quality[ordered_kmers[node]] for node in path])
                    path_transition_likelihoods = 1.0
                    for (a, b) in zip(path[1:end-1], path[2:end])
                        path_transition_likelihoods *= transition_likelihoods[a, b]
                    end
                    push!(weighted_paths, path => path_weight * path_transition_likelihoods)
                end
                weighted_paths
            end

            if length(yen_k_path_weights) > 1
                current_distance = length(resampling_stretch)
                path_lengths = [length(path) for (path, _) in yen_k_path_weights]
                # Halve a path's weight for every k-mer its length differs from the stretch.
                adjusted_weights = last.(yen_k_path_weights) .* exp2.(-abs.(path_lengths .- current_distance))
                selected_path_index = StatsBase.sample(StatsBase.weights(adjusted_weights))
                segment_kmers = ordered_kmers[first(yen_k_path_weights[selected_path_index])]
            else
                # Fewer than two alternatives: keep the read's own k-mers.
                segment_kmers = record_kmers[resampling_stretch]
            end
            segment_qualities = [floor.(UInt8, kmer_mean_quality[kmer]) for kmer in segment_kmers]
            # A stretch can start on the anchor k-mer emitted just before it.
            _append_kmer_segment!(new_record_kmers, new_record_kmer_qualities, segment_kmers, segment_qualities, first(resampling_stretch) <= last_emitted)
            last_emitted = last(resampling_stretch)

            # Copy the solid k-mers up to the next stretch, or to the end of the read
            # (including the final k-mer).
            gap_stop = i < length(resampling_stretches) ? first(resampling_stretches[i + 1]) - 1 : last_kmer_index
            gap = (last_emitted + 1):gap_stop
            if !isempty(gap)
                _append_kmer_segment!(new_record_kmers, new_record_kmer_qualities, record_kmers[gap], record_kmer_quality_scores[gap], false)
                last_emitted = last(gap)
            end
        end
    end

    new_record_sequence = Mycelia.kmer_path_to_sequence(new_record_kmers)
    # The first k-mer contributes all k qualities; each later k-mer adds its last base's.
    new_record_quality_scores = copy(first(new_record_kmer_qualities))
    for kmer_qualities in @view new_record_kmer_qualities[2:end]
        push!(new_record_quality_scores, last(kmer_qualities))
    end
    # FASTX won't parse quality scores above 93.
    new_record_quality_scores = min.(new_record_quality_scores, 93)
    new_record_string = join(["@" * new_record_identifier, new_record_sequence, "+", join([Char(x+33) for x in new_record_quality_scores])], "\n")
    new_record = FASTX.parse(FASTX.FASTQRecord, new_record_string)
    @assert FASTX.sequence(new_record) == string(new_record_sequence)
    @assert collect(FASTX.quality_scores(new_record)) == new_record_quality_scores
    return new_record
end

# 2:24 renewing yenk weights each time
# 2:22 sharing yenk weights
ProgressMeter.@showprogress for current_record in records
    # Fresh path cache per read; the timings above show sharing one barely helps.
    path_cache = Dict{Pair{Int, Int}, Vector{Pair{Vector{Int}, Float64}}}()
    revised_record = process_fastq_record(
        record = current_record,
        likely_valid_kmers = likely_valid_kmers,
        kmer_to_index_map = kmer_to_index_map,
        branching_nodes_set = branching_nodes_set,
        assembly_k = assembly_k,
        kmer_mean_quality = kmer_mean_quality,
        transition_likelihoods = transition_likelihoods,
        yen_k_shortest_paths_and_weights = path_cache,
    )
end

In [None]:
# finish updating all reads for all k rounds

In [None]:
# write out final assembly

In [None]:
# call variants

In [None]:
# assess accuracy