In [None]:
# Notebook setup: check the dynamic-library path, then import every dependency.
# if hit plotting library issues, try resetting LD path for julia
# can set in ~/.local/share/jupyter/kernels/
# Stop immediately when LD_LIBRARY_PATH is present with a non-empty value.
haskey(ENV, "LD_LIBRARY_PATH") && @assert ENV["LD_LIBRARY_PATH"] == ""
import Pkg
# Package names imported by the loop below (one `import` statement per entry).
pkgs = [
    "Revise",
    "FASTX",
    "BioSequences",
    "Kmers",
    "Graphs",
    "MetaGraphs",
    "SparseArrays",
    "ProgressMeter",
    "Distributions",
    "HiddenMarkovModels",
    "BioAlignments",
    "StatsBase",
    "Random",
    "StatsPlots",
    "Statistics",
    # "GraphMakie",
    "IterTools",
    "Primes",
    "OnlineStats",
    "IteratorSampling",
    "HypothesisTests",
    "Clustering",
    "Distances"
]
# Uncomment to install the packages listed above.
# Pkg.add(pkgs)
# The package name is only available as a string here, so each `import`
# statement is assembled as text, parsed and evaluated.
for pkg in pkgs
    eval(Meta.parse("import $pkg"))
end
# Uncomment one of the following to use a local development checkout of Mycelia.
# Pkg.develop(path="/global/cfs/projectdirs/m4269/cjprybol/Mycelia")
# Pkg.develop(path="../../..")
import Mycelia

In [None]:
# Project layout: the project root is the parent of the current working directory.
PROJECT_BASEDIR = dirname(pwd())
data_dir = joinpath(PROJECT_BASEDIR, "data")
# `mkpath` creates the directory when it is missing; its return value is kept as `genome_dir`.
genome_dir = mkpath(joinpath(data_dir, "genomes"))

In [None]:
# `<data_dir>/test`, created when it is missing.
working_dir = joinpath(data_dir, "test")
mkpath(working_dir)

In [None]:
# short_read_sets = unique(map(x -> match(r"^(.+\.\d+x)\.", x).captures[1], filter(x -> occursin(r"\.fna\.art", x) && occursin(r"\.fq\.gz", x) && !occursin("trimming_report", x) && !occursin("_val_", x), sort(readdir(genome_dir, join=true), by=x->filesize(x)))))
# # forward = short_read_set * ".1_val_1.fq.gz"
# # reverse = short_read_set * ".2_val_2.fq.gz"

In [None]:
# Entries of `genome_dir` whose path matches `.filtlong.fq.gz` at the end,
# ordered from smallest to largest file size.
long_read_fastqs = sort(
    filter(
        path -> occursin(r"\.filtlong\.fq\.gz$", path),
        readdir(genome_dir; join=true),
    );
    by=filesize,
)
# Continue with the smallest file.
fastq = long_read_fastqs[1]

In [None]:
# Reference path: the read file path with everything from `.badread` onward removed.
# NOTE(review): presumably the reads were simulated from this reference — confirm.
reference_fasta = replace(fastq, r"\.badread.*" => "")

In [None]:
# `k` is used below as the kmer length (it parameterises `Kmers.DNAKmer{k, 1}`).
k = Mycelia.assess_dnamer_saturation([fastq])

In [None]:
# Concrete kmer type used for every kmer container in the rest of the notebook.
kmer_type = Kmers.DNAKmer{k, 1}

In [None]:
# Kmer counts of the reference FASTA, keyed by `kmer_type`.
reference_kmer_counts = Mycelia.fasta_to_reference_kmer_counts(kmer_type=kmer_type, fasta=reference_fasta)
# All records of the selected read file, held in memory as a vector.
records = collect(Mycelia.open_fastx(fastq))

In [None]:
# Streaming mean over the quality scores of all records.
# NOTE(review): `IterTools.chain` receives a single generator (one element per record) here;
# confirm that `fit!` descends into each record's scores as intended.
fit_mean = OnlineStats.fit!(OnlineStats.Mean(), IterTools.chain(FASTX.quality_scores(record) for record in records))

In [None]:
# Streaming extrema over the same quality scores.
fit_extrema = OnlineStats.fit!(OnlineStats.Extrema(), IterTools.chain(FASTX.quality_scores(record) for record in records))

In [None]:
# Streaming variance over the same quality scores, and its square root.
fit_variance = OnlineStats.fit!(OnlineStats.Variance(), IterTools.chain(FASTX.quality_scores(record) for record in records))
standard_deviation = sqrt(OnlineStats.value(fit_variance))

In [None]:
# One vector of quality scores per read.
read_quality_scores = [collect(FASTX.quality_scores(read)) for read in records]

# For every observed kmer, accumulate the position-wise sum of the base
# qualities of each length-`k` window in which that kmer was seen.
all_kmer_quality_support = Dict{kmer_type, Vector{Float64}}()
for read in records
    base_qualities = collect(FASTX.quality_scores(read))
    window_count = length(base_qualities) - k + 1
    quality_windows = [base_qualities[start:(start + k - 1)] for start in 1:window_count]
    read_sequence = BioSequences.LongDNA{2}(FASTX.sequence(read))
    kmer_iterator = Kmers.EveryKmer{kmer_type}(read_sequence)
    # kmers and quality windows are paired in iteration order
    for ((_, kmer), window) in zip(kmer_iterator, quality_windows)
        accumulated = get(all_kmer_quality_support, kmer, nothing)
        all_kmer_quality_support[kmer] = accumulated === nothing ? window : accumulated .+ window
    end
end

# Stranded kmer counts for the read file, plus a kmer => position lookup that
# follows the iteration order of the counts' keys.
kmer_counts = Mycelia.count_kmers(kmer_type, fastq)
kmer_indices = Dict(
    observed_kmer => position for (position, observed_kmer) in enumerate(keys(kmer_counts))
)
# Same pair of structures for canonical kmers.
canonical_kmer_counts = Mycelia.count_canonical_kmers(kmer_type, fastq)
canonical_kmer_indices = Dict(
    observed_kmer => position for
    (position, observed_kmer) in enumerate(keys(canonical_kmer_counts))
)
# Reference kmers in sorted order.
reference_kmers = sort!(collect(keys(reference_kmer_counts)))

# Combine the quality support of each kmer with that of its reverse complement
# (when the reverse complement was observed too).
strand_normalized_quality_support = Dict{kmer_type, Vector{Float64}}()
for (kmer, support) in all_kmer_quality_support
    # Work on a copy: `.+=` updates in place, and storing `support` itself would
    # also modify `all_kmer_quality_support[kmer]`, so the reverse complement
    # would later add an already-combined vector and be double counted.
    normalized_support = copy(support)
    reverse_complement_kmer = BioSequences.reverse_complement(kmer)
    if haskey(all_kmer_quality_support, reverse_complement_kmer)
        # NOTE(review): positions are added without reversing the reverse
        # complement's vector; only per-kmer totals are used below — confirm
        # before relying on per-position values.
        normalized_support .+= all_kmer_quality_support[reverse_complement_kmer]
    end
    strand_normalized_quality_support[kmer] = normalized_support
end

# Total strand-normalised quality per kmer.
kmer_total_quality = Dict(kmer => sum(quality_values) for (kmer, quality_values) in strand_normalized_quality_support)
# state_likelihoods = Dict(kmer => kmer_count / total_kmers for (kmer, kmer_count) in kmer_counts)
# Compute the grand total once rather than once per kmer inside the comprehension.
total_quality_across_kmers = sum(values(kmer_total_quality); init=0.0)
# Each kmer's share of the total quality.
state_likelihoods = Dict(kmer => total_quality / total_quality_across_kmers for (kmer, total_quality) in kmer_total_quality)

total_states = length(state_likelihoods)

# Row-normalised kmer-to-kmer transition matrix (rows: source kmer, columns:
# destination kmer), indexed through `kmer_indices`.
#
# Transitions are first collected as coordinate lists and assembled in one call:
# `sparse(I, J, V, m, n)` sums duplicate (i, j) entries, so each cell ends up
# holding the number of times that transition was observed. This avoids
# inserting into, and slicing rows out of, a column-compressed matrix one
# element at a time.
transition_sources = Int[]
transition_destinations = Int[]
for record in records
    sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
    # the i-th kmer of `sequence[2:end]` is the successor of the i-th kmer of `sequence[1:end-1]`
    sources = Kmers.EveryKmer{kmer_type}(sequence[1:end-1])
    destinations = Kmers.EveryKmer{kmer_type}(sequence[2:end])
    for ((_, source), (_, destination)) in zip(sources, destinations)
        push!(transition_sources, kmer_indices[source])
        push!(transition_destinations, kmer_indices[destination])
    end
end
transition_likelihoods = SparseArrays.sparse(
    transition_sources,
    transition_destinations,
    ones(Float64, length(transition_sources)),
    total_states,
    total_states,
)
# Divide every stored count by the total of its row. Rows without outgoing
# transitions have no stored entries and therefore stay all-zero.
outgoing_totals = vec(sum(transition_likelihoods; dims=2))
stored_rows = SparseArrays.rowvals(transition_likelihoods)
stored_values = SparseArrays.nonzeros(transition_likelihoods)
for entry in eachindex(stored_values)
    stored_values[entry] /= outgoing_totals[stored_rows[entry]]
end
transition_likelihoods

# Directed graph with one vertex per state and one edge per stored entry of
# the transition matrix.
g = Graphs.SimpleDiGraph(total_states)
row_indices, column_indices, cell_values = SparseArrays.findnz(transition_likelihoods)
for entry in eachindex(row_indices)
    Graphs.add_edge!(g, row_indices[entry], column_indices[entry])
end
g

# Vertices with at most one incoming and at most one outgoing edge.
unbranching_nodes = Set{Int}(
    Iterators.filter(Graphs.vertices(g)) do node
        Graphs.indegree(g, node) <= 1 && Graphs.outdegree(g, node) <= 1
    end
)
# Every other vertex, as a vector and as a set for membership tests.
branching_nodes = setdiff(Graphs.vertices(g), unbranching_nodes)
branching_nodes_set = Set(branching_nodes)

In [None]:
# Kmers in the iteration order of `kmer_counts`; position i here is the same
# position that `kmer_indices` (and therefore graph vertex i) assigns to that kmer.
# NOTE(review): later cells index this vector with positions taken from
# `sort(kmer_total_quality)`; that assumes both collections list the same kmers
# in the same order — confirm.
ordered_kmers = collect(keys(kmer_counts))

In [None]:
# Decide whether the per-kmer total quality support looks like one normal
# distribution (data may be error-free) or not (enter error correction).
total_strand_normalized_quality_support = sum.(collect(values(strand_normalized_quality_support)))
# minimum_average = min(Statistics.mean(total_strand_normalized_quality_support), Statistics.median(total_strand_normalized_quality_support))
mean_total_support = Statistics.mean(total_strand_normalized_quality_support)
std_total_support = Statistics.std(total_strand_normalized_quality_support)
# Compare against a normal distribution with the sample's own mean and standard
# deviation. A bare `Distributions.Normal()` is the standard normal (mean 0,
# sd 1), which raw quality totals can never resemble, so the test would always
# reject.
test_is_single_distribution = HypothesisTests.ExactOneSampleKSTest(
    total_strand_normalized_quality_support,
    Distributions.Normal(mean_total_support, std_total_support),
)
single_distribution_pvalue = HypothesisTests.pvalue(test_is_single_distribution)
if single_distribution_pvalue < 1e-3
    println("p = $(single_distribution_pvalue) rejecting error-free hypothesis & entering error correction")
else
    println("single distribution detected, this data may be error-free")
end

In [None]:
# Split kmers into two groups by total quality with k-means; the group with the
# smaller centroid is treated as likely sequencing artifacts.
sorted_kmer_total_quality = sort(kmer_total_quality)

sorted_kmer_total_quality_values = collect(values(sorted_kmer_total_quality))

# StatsPlots.scatter(collect(values(kmer_total_quality)))

# BEING FLAGGED DOESN'T AUTOMATICALLY MEAN THAT WE WILL DROP IT, IT JUST MEANS THAT WE WILL ATTEMPT TO RESAMPLE IT

# The cluster count has its own name: assigning it to `k` would overwrite the
# kmer length that other cells still read from `k`.
n_clusters = 2
# `kmeans` expects one column per observation, hence the 1 x n matrix.
results = Clustering.kmeans(permutedims(sorted_kmer_total_quality_values), n_clusters)

assignments = Clustering.assignments(results)
centroids = results.centers

# println("Cluster assignments: ", assignments)
println("Cluster centroids: ", centroids)
# `centroids` is a matrix, so `findmin` yields a CartesianIndex; its last
# component is the column, i.e. the cluster number.
min_cluster_result = findmin(centroids)
smaller_cluster = last(last(min_cluster_result).I)

# Per-cluster plot coordinates: x is the position in the sorted kmer list,
# y is that kmer's total quality.
ys = [Float64[] for i in 1:n_clusters]
xs = [Int[] for i in 1:n_clusters]
for (i, (assignment, value)) in enumerate(zip(assignments, sorted_kmer_total_quality_values))
    # if assignment == 1
    push!(ys[assignment], value)
    push!(xs[assignment], i)
end
# group_values
label = smaller_cluster == 1 ? ["likely sequencing artifacts" "likely valid kmers"] : ["likely valid kmers" "likely sequencing artifacts"]

StatsPlots.scatter(
    xs,
    ys,
    title = "kmeans error separation",
    ylabel = "canonical kmer cumulative QUAL value",
    label = label,
    legend = :outertopright,
    size = (900, 500),
    margins=10StatsPlots.Plots.PlotMeasures.mm,
    xticks = false
)

In [None]:
# Positions (in the sorted kmer list) of the low-centroid and high-centroid
# clusters; with two clusters the "other" cluster is simply the remaining number.
likely_sequencing_artifact_indices = xs[smaller_cluster]
likely_valid_kmer_indices = xs[smaller_cluster == 1 ? 2 : 1]
# Translate positions into kmers through `ordered_kmers`.
likely_sequencing_artifact_kmers = Set(ordered_kmers[likely_sequencing_artifact_indices])
likely_valid_kmers = Set(ordered_kmers[likely_valid_kmer_indices])
# kmer => position in `ordered_kmers`
kmer_to_index_map = Dict(
    ordered_kmer => position for (position, ordered_kmer) in enumerate(ordered_kmers)
)

In [None]:
# Take the first read, list its kmers, and trim the kmer list so that it starts
# and ends on a likely-valid ("solid") kmer.
record = first(records)
record_sequence = BioSequences.LongDNA{4}(FASTX.sequence(record))
record_kmers = last.(collect(Kmers.EveryKmer{kmer_type}(record_sequence)))
record_kmer_solidity = map(kmer -> kmer in likely_valid_kmers, record_kmers)

# Locate the solid span once, on the untrimmed read.
first_solid_position = findfirst(record_kmer_solidity)
last_solid_position = findlast(record_kmer_solidity)
# `findfirst` returns `nothing` when no kmer is solid; report that explicitly
# instead of failing on a comparison with `nothing`.
if first_solid_position === nothing
    throw(ArgumentError("record contains no likely-valid kmers; cannot trim it to solid ends"))
end

# trim beginning and end in a single slice
record_kmers = record_kmers[first_solid_position:last_solid_position]
record_kmer_solidity = record_kmer_solidity[first_solid_position:last_solid_position]
# true where the kmer's graph vertex is a branching node
record_branching_kmers = [kmer_to_index_map[kmer] in branching_nodes_set for kmer in record_kmers]
record_solid_branching_kmers = record_kmer_solidity .& record_branching_kmers

# After trimming, the solid span covers the whole kmer list.
initial_solid_kmer = 1
last_solid_kmer = length(record_kmers)
record_kmer_solidity

In [None]:
"""
    find_false_ranges(vec::AbstractVector{Bool})

Return the maximal runs of consecutive `false` entries in `vec` as a vector of
`(start, stop)` index tuples, in increasing order. Both ends are inclusive.

The result is always a `Vector{Tuple{Int, Int}}`; it is empty when `vec`
contains no `false` entry (including when `vec` itself is empty).
"""
function find_false_ranges(vec::AbstractVector{Bool})
    ranges = Tuple{Int, Int}[]
    false_indices = findall(!, vec)
    isempty(false_indices) && return ranges

    run_start = first(false_indices)
    previous = run_start
    for index in Iterators.drop(false_indices, 1)
        # a gap between consecutive false indices closes the current run
        if index != previous + 1
            push!(ranges, (run_start, previous))
            run_start = index
        end
        previous = index
    end
    # close the final run
    push!(ranges, (run_start, previous))

    return ranges
end

# For every run of non-solid kmers in the trimmed read, look up alternative
# paths through the kmer graph between the nearest solid branching kmers on
# either side, and sample one of them weighted by quality and length similarity.
low_quality_runs = find_false_ranges(record_kmer_solidity)
solid_branching_kmer_indices = findall(record_solid_branching_kmers)
YEN_K = 3
# Cache: (source vertex => target vertex) => list of (path => per-node total quality),
# so a repeated query is not recomputed.
yen_k_shortest_paths_and_weights = Dict{Pair{Int, Int}, Vector{Pair{Vector{Int}, Vector{Float64}}}}()
for low_quality_run in low_quality_runs
    unders = filter(<(first(low_quality_run)), solid_branching_kmer_indices)
    overs = filter(>(last(low_quality_run)), solid_branching_kmer_indices)
    # The read is trimmed to solid kmers, not to solid *branching* kmers, so a
    # run may lack an anchor on one side; `maximum`/`minimum` would throw on the
    # empty collection.
    if isempty(unders) || isempty(overs)
        @warn "low-quality run is not flanked by solid branching kmers on both sides; skipping" low_quality_run
        continue
    end
    nearest_under = maximum(unders)
    nearest_over = minimum(overs)

    u = kmer_to_index_map[record_kmers[nearest_under]]
    v = kmer_to_index_map[record_kmers[nearest_over]]
    if !haskey(yen_k_shortest_paths_and_weights, u => v)
        yen_k_result = Graphs.yen_k_shortest_paths(g, u, v, Graphs.weights(g), YEN_K)
        # element type matches the cache's declared value type
        yen_k_shortest_paths_and_weights[u => v] = Pair{Vector{Int}, Vector{Float64}}[
            path => [kmer_total_quality[ordered_kmers[node]] for node in path] for
            path in yen_k_result.paths
        ]
    end
    yen_k_path_weights = yen_k_shortest_paths_and_weights[u => v]
    if length(yen_k_path_weights) > 1
        # Number of kmers on the read's current path, anchors included, so it is
        # comparable with the node counts of the candidate paths.
        current_path_length = nearest_over - nearest_under + 1
        initial_weights = Statistics.mean.(last.(yen_k_path_weights))
        path_lengths = length.(first.(yen_k_path_weights))
        deltas = abs.(path_lengths .- current_path_length)
        # halve a candidate's weight for every kmer of length difference
        adjusted_weights = initial_weights .* exp.(-deltas .* log(2))
        selected_path_index = StatsBase.sample(StatsBase.weights(adjusted_weights))
        selected_path, selected_path_weights = yen_k_path_weights[selected_path_index]
        selected_path_kmers = [ordered_kmers[kmer_index] for kmer_index in selected_path]
        selected_path_kmer_counts = [canonical_kmer_counts[BioSequences.canonical(kmer)] for kmer in selected_path_kmers]
        # Take the kmer length from the kmers themselves rather than from the
        # global `k`, which the clustering cell reassigns.
        kmer_length = length(first(selected_path_kmers))
        # total quality / observations / bases per kmer, capped at 60
        @show selected_path_average_quality_scores = min.([weight / count for (weight, count) in zip(selected_path_weights, selected_path_kmer_counts)] ./ kmer_length, 60.0)
    end
end

In [None]:
# rather than pre-computing all shortest paths between all pairs of nodes, which may require us to solve paths unnecessarily,
# cache the results as we compute them so that we can avoid recomputing if we hit the same query multiple times
# adding this as a hunch, may be worth benchmarking to see if it's unnecessary

# NOTE: the fragment below is the tail of a loop whose opening lines are no
# longer in this cell. The unmatched `end` keywords are commented out so the
# cell parses; the notes are kept as design reminders.

        # else
        #     @info "no alternate paths found"
        # end
        # 
        # if we only have our one path, keep it
        # else, compare the relative weights of the paths and choose at random
        # slice out the kmers we want to get rid of and put these in
    # end
# end

In [None]:
# yen_state.dists

In [None]:
# yen_state.paths