# Core Proteome

In [1]:
# Analysis tag: an ISO date followed by a task slug; it also names the working directory.
DATE_TASK = "2022-03-06-ecoli-phapecoctavirus-core-genome"
DIR = mkpath("$(homedir())/workspace/$DATE_TASK")  # created on first run, reused afterwards
cd(DIR)
# Split the tag back into its date part and its task part.
DATE, TASK = match(r"^(\d{4}-\d{2}-\d{2})-(.*)$", DATE_TASK).captures

2-element Vector{Union{Nothing, SubString{String}}}:
 "2022-03-06"
 "ecoli-phapecoctavirus-core-genome"

In [7]:
import Pkg
Pkg.update()

# Packages this notebook relies on; any that fail to import are installed and retried.
pkgs = [
    "JSON",
    "HTTP",
    "Dates",
    "uCSV",
    "DelimitedFiles",
    "DataFrames",
    "ProgressMeter",
    "BioSequences",
    "FASTX",
    "Distances",
    "StatsPlots",
    "StatsBase",
    "Statistics",
    "MultivariateStats",
    "Random",
    "Primes",
    "SparseArrays",
    "SHA",
    "GenomicAnnotations",
    "Combinatorics",
    "OrderedCollections",
    "Downloads",
    "Clustering",
    "Revise",
    "Mmap",
    "LsqFit",
    "BioSymbols"
]

# Import the package called `name` into this module. If that import throws, call
# `install()` and attempt the import exactly once more; a second failure propagates.
function import_or_install(install, name)
    try
        @eval import $(Symbol(name))
    catch
        install()
        @eval import $(Symbol(name))
    end
end

foreach(pkg -> import_or_install(() -> Pkg.add(pkg), pkg), pkgs)

# Mycelia is loaded in development mode from a local checkout.
# Notes from earlier attempts:
# - "https://github.com/cjprybol/Mycelia.git#master" works, but local edits are not picked
#   up; changes must be pushed and the kernel restarted to activate them
# - "$(homedir())/git/Mycelia#master" did not work
pkg_path = "$(homedir())/git/Mycelia"
import_or_install(() -> Pkg.develop(path=pkg_path), basename(pkg_path))

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m    Updating[22m[39m git-repo `https://github.com/cjprybol/Mycelia.git#master`
[32m[1m  No Changes[22m[39m to `~/git/Mycelia/docs/Project.toml`
[32m[1m  No Changes[22m[39m to `~/git/Mycelia/docs/Manifest.toml`
[32m[1mPrecompiling[22m[39m project...
[32m  ✓ [39m[90mCodecBzip2[39m
[32m  ✓ [39m[90mStringEncodings[39m
[32m  ✓ [39m[90mMKL_jll[39m
[32m  ✓ [39m[90mGumbo_jll[39m
[32m  ✓ [39m[90mLERC_jll[39m
[32m  ✓ [39m[90mEzXML[39m
[32m  ✓ [39m[90mGumbo[39m
[32m  ✓ [39m[90mXMLDict[39m
[32m  ✓ [39m[90mYAML[39m
[32m  ✓ [39m[90mLibtiff_jll[39m
[32m  ✓ [39m[90mXLSX[39m
[32m  ✓ [39m[90mBioServices[39m
[32m  ✓ [39m[90mGR_jll[39m
[32m  ✓ [39mWeave
[33m  ✓ [39m[90mGR[39m
[33m  ✓ [39m[90mFFTW[39m
[32m  ✓ [39mBioFetch
[32m  ✓ [39m[90mStatsModels[39m
[33m  ✓ [39m[90mKernelDensity[39m
[32m  ✓ [39m[90mGLM[39m
[33m  ✓ [39mPlots
[33m 

In [8]:
"""
    wcss(clustering_result)

Return the within-cluster sum of squares (WCSS) of a k-means clustering result.

`clustering_result.costs` holds one assignment cost per point. Per the Clustering.jl
documentation, for `Clustering.kmeans` that cost is already the squared distance from the
point to its assigned center. WCSS is therefore the plain sum of the costs, which should
match `clustering_result.totalcost`. Squaring the costs again would sum fourth powers.

Returns a `Float64` (`0.0` when there are no points).
"""
function wcss(clustering_result)
    # Every point belongs to exactly one cluster, so summing per cluster and then across
    # clusters is the same as summing all costs at once.
    return sum(clustering_result.costs; init = 0.0)
end

wcss (generic function with 1 method)

In [9]:
"""
    generate_all_possible_kmers(k, alphabet)

Return a sorted vector of every length-`k` sequence over the symbols in `alphabet`.

Supported alphabets:
- element type `BioSymbols.AminoAcid`: returns `BioSequences.LongAminoAcidSeq`s. For
  `k > 1`, k-mers whose first residue is `BioSequences.AA_Term` are dropped. Stop symbols
  at later positions are kept.
- element type `BioSymbols.DNA`: returns `BioSequences.LongDNASeq`s.

The result starts from all `length(alphabet)^k` combinations, so it grows very quickly
with `k`.

Throws `ArgumentError` for any other element type. This is checked before anything is
enumerated.
"""
function generate_all_possible_kmers(k, alphabet)
    symbol_type = eltype(alphabet)
    if symbol_type == BioSymbols.AminoAcid
        sequence_constructor = BioSequences.LongAminoAcidSeq
    elseif symbol_type == BioSymbols.DNA
        sequence_constructor = BioSequences.LongDNASeq
    else
        throw(ArgumentError("unsupported alphabet element type $(symbol_type); expected BioSymbols.AminoAcid or BioSymbols.DNA"))
    end

    # Cartesian product of `k` copies of the alphabet; each tuple becomes one sequence.
    kmer_iterator = Iterators.product([alphabet for i in 1:k]...)
    kmers = sequence_constructor.(collect.(vec(collect(kmer_iterator))))
    if symbol_type == BioSymbols.AminoAcid && k > 1
        filter!(kmer -> kmer[1] != BioSequences.AA_Term, kmers)
    end
    return sort!(kmers)
end

generate_all_possible_kmers (generic function with 1 method)

In [10]:
# Run k-means with `k` clusters on `distance_matrix` and return
# `(wcss, mean silhouette score)`. The silhouettes receive `distance_matrix` as the
# pairwise distances.
function _score_clustering(distance_matrix, k)
    @info "assessing k = $(k)"
    clustering = Clustering.kmeans(distance_matrix, k)
    mean_silhouette = Statistics.mean(Clustering.silhouettes(clustering, distance_matrix))
    return wcss(clustering), mean_silhouette
end

# Evaluate candidate `k` and insert it, with its scores, at its sorted position in `ks`
# and in the aligned score vectors. Does nothing if `k` is already a candidate.
function _insert_candidate_k!(ks, wcss_scores, silhouette_scores, distance_matrix, k)
    insertion_index = searchsortedfirst(ks, k)
    if insertion_index <= length(ks) && ks[insertion_index] == k
        return nothing
    end
    candidate_wcss, candidate_silhouette = _score_clustering(distance_matrix, k)
    insert!(ks, insertion_index, k)
    insert!(wcss_scores, insertion_index, candidate_wcss)
    insert!(silhouette_scores, insertion_index, candidate_silhouette)
    return nothing
end

# Evaluate the rounded midpoints between the candidate at `index` and each of its two
# sorted neighbours. Requires 1 < index < length(ks).
function _refine_around!(ks, wcss_scores, silhouette_scores, distance_matrix, index)
    # read the window before inserting anything, since insertions shift positions
    lower, center, upper = ks[index-1], ks[index], ks[index+1]
    _insert_candidate_k!(ks, wcss_scores, silhouette_scores, distance_matrix,
                         Int(round(Statistics.mean((lower, center)))))
    _insert_candidate_k!(ks, wcss_scores, silhouette_scores, distance_matrix,
                         Int(round(Statistics.mean((center, upper)))))
    return nothing
end

"""
    fit_optimal_number_of_clusters(distance_matrix)

Pick a number of k-means clusters for `distance_matrix` by searching for the highest mean
silhouette score.

`Clustering.kmeans` clusters the columns of `distance_matrix`. `Clustering.silhouettes`
receives the same matrix as the pairwise distances between points.
NOTE(review): this presumes a square, symmetric distance matrix; confirm for new callers.

Search procedure:
1. Candidate k values are 1, 2, 4, … up to `n = size(distance_matrix, 1)`, plus `n`
   itself (without duplicates).
2. k = 1 gets a silhouette of 0, since no silhouette is computed for a single cluster.
   If k = 2 does not lower the within-cluster sum of squares (WCSS), 1 is returned.
3. Otherwise k keeps moving to the next candidate while the silhouette rises and WCSS
   falls.
4. Around the best silhouette, the rounded midpoints to both neighbouring candidates are
   evaluated and inserted. This repeats until the best k stops changing, or until it is
   the first or last candidate.

Returns `(optimal_number_of_clusters, ks_to_try, within_cluster_sum_of_squares,
silhouette_scores)`. `ks_to_try` is sorted and may end with candidates that were never
evaluated; the score vectors line up with its leading entries.
"""
function fit_optimal_number_of_clusters(distance_matrix)
    n = size(distance_matrix, 1)
    # `unique` drops the repeated n when n is itself a power of two. Previously k = n was
    # evaluated twice, and n = 1 re-ran k = 1 through the silhouette computation.
    ks_to_try = unique(vcat([2^i for i in 0:Int(floor(log2(n)))], n))
    @show ks_to_try

    within_cluster_sum_of_squares = Float64[]
    silhouette_scores = Float64[]

    # k = 1: record a neutral silhouette of 0 as the baseline
    @info "assessing k = $(ks_to_try[1])"
    push!(within_cluster_sum_of_squares, wcss(Clustering.kmeans(distance_matrix, ks_to_try[1])))
    push!(silhouette_scores, 0)
    if length(ks_to_try) == 1
        return ks_to_try[1], ks_to_try, within_cluster_sum_of_squares, silhouette_scores
    end

    candidate_wcss, candidate_silhouette = _score_clustering(distance_matrix, ks_to_try[2])
    push!(within_cluster_sum_of_squares, candidate_wcss)
    push!(silhouette_scores, candidate_silhouette)
    # splitting into two clusters did not reduce WCSS: keep everything in one cluster
    if within_cluster_sum_of_squares[2] >= within_cluster_sum_of_squares[1]
        return ks_to_try[1], ks_to_try, within_cluster_sum_of_squares, silhouette_scores
    end
    if length(ks_to_try) == 2
        return ks_to_try[2], ks_to_try, within_cluster_sum_of_squares, silhouette_scores
    end

    # coarse search: advance through the candidates while silhouette rises and WCSS falls
    current_k_index = 2
    while true
        current_k_index += 1
        candidate_wcss, candidate_silhouette = _score_clustering(distance_matrix, ks_to_try[current_k_index])
        push!(within_cluster_sum_of_squares, candidate_wcss)
        push!(silhouette_scores, candidate_silhouette)
        improving = silhouette_scores[end] > silhouette_scores[end-1] &&
                    within_cluster_sum_of_squares[end] < within_cluster_sum_of_squares[end-1]
        (improving && current_k_index < length(ks_to_try)) || break
    end

    # refinement: bisect towards the best silhouette until the best k stops moving
    optimal_silhouette, optimal_index = findmax(silhouette_scores)
    optimal_number_of_clusters = ks_to_try[optimal_index]
    @info "refining..."
    @info "current optimal number of clusters = $(optimal_number_of_clusters)"
    @info "current best silhouette score = $(optimal_silhouette)"
    # Midpoints need a neighbour on both sides. Index 1 used to raise a BoundsError.
    while 1 < optimal_index < length(ks_to_try)
        _refine_around!(ks_to_try, within_cluster_sum_of_squares, silhouette_scores,
                        distance_matrix, optimal_index)
        new_optimal_silhouette, new_optimal_index = findmax(silhouette_scores)
        new_optimal_number_of_clusters = ks_to_try[new_optimal_index]
        new_optimal_number_of_clusters == optimal_number_of_clusters && break
        # Adopt the improvement even when it sits on the last candidate. Previously that
        # case returned the older, worse k.
        optimal_silhouette, optimal_index = new_optimal_silhouette, new_optimal_index
        optimal_number_of_clusters = new_optimal_number_of_clusters
        @info "current optimal number of clusters = $(optimal_number_of_clusters)"
        @info "current best silhouette score = $(optimal_silhouette)"
    end
    return optimal_number_of_clusters, ks_to_try, within_cluster_sum_of_squares, silhouette_scores
end

fit_optimal_number_of_clusters (generic function with 1 method)

In [11]:
# Number of k-mers that `generate_all_possible_kmers(k, alphabet)` would return. It is
# computed arithmetically instead of materializing all length(alphabet)^k sequences, and
# mirrors that function's rule of dropping k-mers (k > 1) whose first residue is
# `BioSequences.AA_Term`. The result is clamped to `typemax(Int)` so it stays an `Int`.
function _count_possible_aamers(k, alphabet)
    alphabet_size = big(length(alphabet))
    if k > 1
        allowed_first_residues = big(count(!=(BioSequences.AA_Term), alphabet))
        total = allowed_first_residues * alphabet_size^(k - 1)
    else
        total = alphabet_size^k
    end
    return Int(min(total, typemax(Int)))
end

# Sampling checkpoints 0, 1, power, power^2, … that do not exceed `kmers_to_assess`.
# Stops before the next multiplication would overflow `Int`.
function _aamer_sampling_points(kmers_to_assess, power)
    # with power <= 1 the checkpoints never exceed `kmers_to_assess`, so this would not stop
    power > 1 || throw(ArgumentError("power must be greater than 1, got $(power)"))
    sampling_points = Int[0]
    point = 1
    while point <= kmers_to_assess
        push!(sampling_points, point)
        point > typemax(Int) ÷ power && break
        point *= power
    end
    return sampling_points
end

# Shared streaming loop behind the `(fasta_records, k)` and `(fastxs, k)` methods.
# `records` is any iterable of FASTX records.
function _assess_aamer_saturation(records, k, kmers_to_assess, power)
    kmers = Set{BioSequences.LongAminoAcidSeq}()
    max_possible_kmers = _count_possible_aamers(k, Mycelia.AA_ALPHABET)

    if kmers_to_assess == Inf
        kmers_to_assess = max_possible_kmers
    end

    sampling_points = _aamer_sampling_points(kmers_to_assess, power)
    unique_kmer_counts = zeros(Int, length(sampling_points))

    if length(sampling_points) < 3
        @info "increase the # of reads analyzed or decrease the power to acquire more data points"
        # always include `eof` so callers can destructure three values
        return (;sampling_points, unique_kmer_counts, eof = false)
    end

    p = ProgressMeter.Progress(kmers_to_assess, 1)

    kmers_assessed = 0
    for record in records
        # fetch the sequence once per record rather than once per k-mer
        sequence = FASTX.sequence(record)
        for start in 1:length(sequence)-k+1
            push!(kmers, sequence[start:start+k-1])
            kmers_assessed += 1
            if length(kmers) == max_possible_kmers
                # every possible k-mer has been observed: end the curve at this point
                # (`<(x)` rather than a closure, so `kmers_assessed` is not boxed)
                sampling_points = vcat(filter(<(kmers_assessed), sampling_points), [kmers_assessed])
                unique_kmer_counts = vcat(unique_kmer_counts[1:length(sampling_points)-1], length(kmers))
                return (;sampling_points, unique_kmer_counts, eof = false)
            end
            sample_index = findfirst(==(kmers_assessed), sampling_points)
            if !isnothing(sample_index)
                unique_kmer_counts[sample_index] = length(kmers)
                if sample_index == length(sampling_points)
                    return (;sampling_points, unique_kmer_counts, eof = false)
                end
            end
            ProgressMeter.next!(p)
        end
    end
    # input ran out before the last checkpoint: close the curve at the final count
    sampling_points = vcat(filter(<(kmers_assessed), sampling_points), [kmers_assessed])
    unique_kmer_counts = vcat(unique_kmer_counts[1:length(sampling_points)-1], [length(kmers)])
    return (;sampling_points, unique_kmer_counts, eof = true)
end

"""
    assess_aamer_saturation(fasta_records, k; kmers_to_assess=Inf, power=10)
    assess_aamer_saturation(fastxs::AbstractVector{String}, k; kmers_to_assess=Inf, power=10)

Stream amino-acid k-mers from FASTA records (or from each file in `fastxs`, opened with
`Mycelia.open_fastx`). Record how many distinct k-mers have been seen at the sampling
points 0, 1, power, power^2, … that do not exceed `kmers_to_assess`. By default
`kmers_to_assess` is the number of possible k-mers over `Mycelia.AA_ALPHABET`.

Returns a named tuple `(sampling_points, unique_kmer_counts, eof)`:
- Streaming stops at the last sampling point with `eof = false`.
- If every possible k-mer has been observed, streaming stops early. The sampling points
  are truncated at that moment, and `eof = false`.
- If the input runs out first, the points already passed are kept, the final k-mer count
  is appended, and `eof = true`.
- If fewer than three sampling points fit, nothing is read; zero counts are returned with
  `eof = false`.

Throws `ArgumentError` when `power <= 1`.
"""
function assess_aamer_saturation(fasta_records::AbstractVector{FASTX.FASTA.Record}, k; kmers_to_assess=Inf, power=10)
    return _assess_aamer_saturation(fasta_records, k, kmers_to_assess, power)
end

function assess_aamer_saturation(fastxs::AbstractVector{String}, k; kmers_to_assess=Inf, power=10)
    # files are opened lazily, one after another, as the stream reaches them
    records = Iterators.flatten(Mycelia.open_fastx(fastx) for fastx in fastxs)
    return _assess_aamer_saturation(records, k, kmers_to_assess, power)
end

"""
    assess_aamer_saturation(fastxs; outdir="", min_k=1, max_k=15, threshold=0.1)

Sweep `k` from `min_k` to `max_k` and return the first `k` whose predicted k-mer
saturation is below `threshold`.

For each `k`:
1. Up to 10,000,000 k-mers are streamed from `fastxs`.
2. `Mycelia.calculate_v` is fitted to the unique-k-mer curve with `LsqFit.curve_fit`.
3. Predicted saturation is the inferred maximum number of unique k-mers divided by the
   number of possible k-mers. The inferred maximum is the observed count if the input was
   exhausted.
4. A plot is displayed and saved as `<outdir>/<k>.png` and `<outdir>/<k>.svg`. `outdir`
   defaults to `./aamer-saturation`.

The chosen `k` is also written to `<outdir>/chosen_k.txt`.

If no `k` meets `threshold`, a warning names the least saturated `k` and `nothing` is
returned. A `k` for which no k-mers are observed is skipped with a warning.
"""
function assess_aamer_saturation(fastxs; outdir="", min_k=1, max_k=15, threshold=0.1)

    if isempty(outdir)
        outdir = joinpath(pwd(), "aamer-saturation")
    end
    mkpath(outdir)

    # tracked only to report the best candidate when no k meets `threshold`
    minimum_saturation = Inf
    least_saturated_k = nothing

    for k in min_k:max_k
        kmers_to_assess = 10_000_000
        sampling_points, kmer_counts, hit_eof = assess_aamer_saturation(fastxs, k, kmers_to_assess=kmers_to_assess)
        @show sampling_points, kmer_counts, hit_eof
        # first sampling point where more than half of the final unique count had been
        # seen; used as the starting guess for the fitted midpoint
        observed_midpoint_index = findfirst(i -> kmer_counts[i] > last(kmer_counts)/2, 1:length(sampling_points))
        if isnothing(observed_midpoint_index)
            @warn "no k-mers observed at k = $(k); skipping"
            continue
        end
        observed_midpoint = sampling_points[observed_midpoint_index]
        initial_parameters = Float64[maximum(kmer_counts), observed_midpoint]
        @time fit = LsqFit.curve_fit(Mycelia.calculate_v, sampling_points, kmer_counts, initial_parameters)
        if hit_eof
            # the whole input was read, so the observed count is the true maximum
            inferred_maximum = last(kmer_counts)
        else
            inferred_maximum = max(Int(ceil(fit.param[1])), last(kmer_counts))
        end

        max_possible_kmers = _count_possible_aamers(k, Mycelia.AA_ALPHABET)

        inferred_midpoint = Int(ceil(fit.param[2]))
        predicted_saturation = inferred_maximum / max_possible_kmers
        @show k, predicted_saturation

        p = StatsPlots.scatter(
            sampling_points,
            kmer_counts,
            label="observed kmer counts",
            ylabel="# unique kmers",
            xlabel="# kmers assessed",
            title = "sequencing saturation @ k = $k",
            legend=:outertopright,
            size=(800, 400),
            margins=3StatsPlots.PlotMeasures.mm
            )
        StatsPlots.hline!(p, [max_possible_kmers], label="absolute maximum")
        StatsPlots.hline!(p, [inferred_maximum], label="inferred maximum")
        StatsPlots.vline!(p, [inferred_midpoint], label="inferred midpoint")
        xs = sort([sampling_points..., inferred_midpoint])
        ys = Mycelia.calculate_v(xs, fit.param)
        StatsPlots.plot!(
            p,
            xs,
            ys,
            label="fit trendline")
        display(p)
        StatsPlots.savefig(p, joinpath(outdir, "$k.png"))
        StatsPlots.savefig(p, joinpath(outdir, "$k.svg"))

        if predicted_saturation < minimum_saturation
            minimum_saturation = predicted_saturation
            least_saturated_k = k
        end
        if predicted_saturation < threshold
            chosen_k_file = joinpath(outdir, "chosen_k.txt")
            println("chosen k = $k")
            open(chosen_k_file, "w") do io
                println(io, k)
            end
            return k
        end
    end
    @warn "no k in $(min_k):$(max_k) had predicted saturation below $(threshold)" least_saturated_k minimum_saturation
    return nothing
end

assess_aamer_saturation (generic function with 3 methods)

In [12]:
# NCBI Taxonomy Browser page for any tax id:
# https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?&id=$(tax_id)
# Root taxon for this analysis. NOTE(review): presumably the Phapecoctavirus taxon named
# in DATE_TASK; confirm on its page:
# https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?lvl=0&id=2733124
root_tax_id = 2_733_124

2733124

In [None]:
# child_tax_ids = vcat(Mycelia.taxonomic_id_to_children(root_tax_id), root_tax_id)
# # child_tax_ids = vcat(child_tax_ids, root_tax_id)

In [None]:
# TODO
# here is where we should apply a filter where host == Escherichia
# need to load host information into neo4j taxonomy

In [None]:
# # refseq_metadata = Mycelia.load_refseq_metadata()
# ncbi_metadata = Mycelia.load_genbank_metadata()

In [None]:
# show(ncbi_metadata[1:1, :], allcols=true)

In [None]:
# tax_id_filter = map(taxid -> taxid in child_tax_ids, ncbi_metadata[!, "taxid"])
# is_right_host = map(x -> occursin(r"Escherichia"i, x), ncbi_metadata[!, "organism_name"])
# not_excluded = ncbi_metadata[!, "excluded_from_refseq"] .== ""
# is_full = ncbi_metadata[!, "genome_rep"] .== "Full"
# # assembly_levels = ["Complete Genome"]
# assembly_levels = ["Complete Genome", "Chromosome"]
# # assembly_levels = ["Complete Genome", "Chromosome", "Scaffold"]
# # assembly_levels = ["Complete Genome", "Chromosome", "Scaffold", "Contig"]
# assembly_level_filter = map(x -> x in assembly_levels, ncbi_metadata[!, "assembly_level"])
# full_filter = is_full .& not_excluded .& assembly_level_filter .& tax_id_filter .& is_right_host
# count(full_filter)

In [None]:
# TODO
# here is another place we could enforce host == escherichia
# we'll use a manual filter as a temporary solution

In [None]:
# ncbi_metadata_of_interest = ncbi_metadata[full_filter, :]

In [None]:
# https://www.ncbi.nlm.nih.gov/sviewer/viewer.cgi?db=nuccore&report=genbank&id=GCA_021354775

In [None]:
# for col in names(ncbi_metadata_of_interest)
#     @show col, ncbi_metadata_of_interest[1, col]
# end

In [None]:
# GCA_002956955.1

In [None]:
# # can I also get genbank record?????
# # for extension in ["genomic.fna.gz", "protein.faa.gz"]
# for extension in ["genomic.fna.gz", "protein.faa.gz", "genomic.gbff.gz"]
#     outdir = mkpath(joinpath(DIR, extension))
#     ProgressMeter.@showprogress for row in DataFrames.eachrow(ncbi_metadata_of_interest)
#         url = Mycelia.ncbi_ftp_path_to_url(row["ftp_path"], extension)
#         outfile = joinpath(outdir, basename(url))
#         if !isfile(outfile)
#             try
#                 Downloads.download(url, outfile)
#             catch e
#                 # @show e
#                 showerror(stdout, e)
#                 # @assert extension == "protein.faa.gz"
#                 # here is where we should call prodigal to fill in protein annotations if we don't otherwise see them
#             end
#         end
#     end
# end

In [None]:
extension = "protein.faa.gz"
# downloads of this file type live in a directory named after the extension
outdir = joinpath(DIR, extension) |> mkpath

In [None]:
fastx_files = [path for path in readdir(outdir; join=true) if !occursin(".ipynb_checkpoints", path)]  # every entry of `outdir` except Jupyter checkpoints

### This section generates a distance matrix between the FASTA files, i.e., between the protein profiles of entire genomes

In [None]:
# # these are too small, all of the within vs between have some disagreement
# # dna_k = 5
# aa_k = 2
# # should use these?
# # dna_k = 7
# # aa_k = 3

In [None]:
# counts_table, outfile = Mycelia.fasta_list_to_counts_table(fasta_list=fastx_files, k=aa_k, alphabet=:AA, outfile="$(outdir).$(aa_k).counts.bin")

In [None]:
# distance_matrix = Mycelia.counts_matrix_to_distance_matrix(counts_table)

### This section generates a distance matrix for the individual proteins, so we can find clusters

In [None]:
# record_table = DataFrames.DataFrame(
#     fastx_file = String[],
#     record_identifier = String[],
#     record_description = String[]
# )
# ProgressMeter.@showprogress for fastx_file in fastx_files
#     for record in Mycelia.open_fastx(fastx_file)
#         row = (
#             fastx_file = fastx_file,
#             record_identifier = FASTX.identifier(record),
#             record_description = FASTX.description(record)
#         )
#         push!(record_table, row)
#     end
# end
# record_table

In [None]:
# alphabet = :AA
# k = aa_k
# fasta_list = fastx_files

In [None]:
# if alphabet == :AA
#     canonical_mers = Mycelia.generate_all_possible_canonical_kmers(k, Mycelia.AA_ALPHABET)
# elseif alphabet == :DNA
#     canonical_mers = Mycelia.generate_all_possible_canonical_kmers(k, Mycelia.DNA_ALPHABET)
# else
#     error("invalid alphabet")
# end

In [None]:
# # if isempty(outfile)
# outfile = joinpath(pwd(), "$(hash(fasta_list)).$(alphabet).k$(k).by-record.bin")
# # end

In [None]:
# #     # if isfile(outfile)
# # load into memory
# # mer_counts_matrix = Mmap.mmap(open(outfile), Array{Int, 2}, (length(canonical_mers), DataFrames.nrow(record_table)))
# # else
# # start from scratch
# mer_counts_matrix = Mmap.mmap(open(outfile, "w+"), Array{Int, 2}, (length(canonical_mers), DataFrames.nrow(record_table)))

In [None]:
# function count_aamers(k, fasta_protein::FASTX.FASTA.Record)
#     s = FASTX.sequence(fasta_protein)
#     these_counts = sort(StatsBase.countmap([s[i:i+k-1] for i in 1:length(s)-k-1]))
#     return these_counts    
# end

In [None]:
# i = 0
# p = ProgressMeter.Progress(DataFrames.nrow(record_table))
# # ProgressMeter.@showprogress for fastx_file in fastx_files
# for fastx_file in fastx_files
#     for record in Mycelia.open_fastx(fastx_file)
#         i += 1
#         @assert fastx_file == record_table[i, "fastx_file"]
#         @assert FASTX.identifier(record) == record_table[i, "record_identifier"]
#         @assert FASTX.description(record) == record_table[i, "record_description"]
#         ProgressMeter.next!(p)
#         # entity_mer_counts = Mycelia.count_aamers(k, record)
#         entity_mer_counts = count_aamers(k, record)
#         Mycelia.update_counts_matrix!(mer_counts_matrix, i, entity_mer_counts, canonical_mers)
#     end
# end
# mer_counts_matrix

In [None]:
# now, distance matrix
# distance_matrix = Mycelia.counts_matrix_to_distance_matrix(mer_counts_matrix)

In [None]:
# conda install -c bioconda diamond

In [None]:
# run(`diamond help`)

In [None]:
# Concatenate every per-genome protein file into one joint file.
# The inputs are copied byte-for-byte. NOTE(review): this assumes every input is
# gzip-compressed, as the ".faa.gz" names suggest; the result is then a multi-member gzip
# stream.
# Output goes to a temporary file that is renamed into place only after every input has
# been copied. Otherwise an interrupted run would leave a truncated joint file that the
# `isfile` check would accept as complete on every later run.
joint_fasta_outfile = outdir * ".joint.faa.gz"
if !isfile(joint_fasta_outfile)
    partial_outfile = joint_fasta_outfile * ".partial"
    open(partial_outfile, "w") do io
        for fastx_file in fastx_files
            write(io, read(fastx_file))
        end
    end
    mv(partial_outfile, joint_fasta_outfile)
end

In [None]:
# @time run(`diamond makedb --in $(joint_fasta_outfile) -d $(joint_fasta_outfile)`)

# N_RECORDS = DataFrames.nrow(record_table)
# # qseqid sseqid pident length mismatch gapopen qlen qstart qend slen sstart send evalue bitscore

# blastp_header = [
#     "qseqid",
#     "sseqid",
#     "pident",
#     "length",
#     "mismatch",
#     "gapopen",
#     "qlen",
#     "qstart",
#     "qend",
#     "slen",
#     "sstart",
#     "send",
#     "evalue",
#     "bitscore"
# ]

# # --fast                   enable fast mode
# # --mid-sensitive          enable mid-sensitive mode
# # --sensitive              enable sensitive mode)
# # --more-sensitive         enable more sensitive mode
# # --very-sensitive         enable very sensitive mode
# # --ultra-sensitive        enable ultra sensitive mode
# # --iterate                iterated search with increasing sensitivity

# # TODO: pairwise output is all of the alignments, super helpful!

# # running a search in blastp mode
# # ./diamond blastp -d reference -q queries.fasta -o matches.tsv
# # @time run(`diamond blastp --outfmt 0 -d $(joint_fasta_outfile).dmnd -q $(joint_fasta_outfile) -o $(joint_fasta_outfile).diamond.tsv`)
# # @time run(`diamond blastp --sensitive -d $(joint_fasta_outfile).dmnd -q $(joint_fasta_outfile) -o $(joint_fasta_outfile).diamond.tsv`)
# # @time run(`diamond blastp --iterate --id 0 --min-score 0 --max-target-seqs $(N_RECORDS) --unal 1 --outfmt 6 qseqid sseqid pident length mismatch gapopen qlen qstart qend slen sstart send evalue bitscore -d $(joint_fasta_outfile).dmnd -q $(joint_fasta_outfile) -o $(joint_fasta_outfile).diamond.tsv`)
# @time run(`diamond blastp --ultra-sensitive --id 0 --min-score 0 --max-target-seqs $(N_RECORDS) --unal 1 --outfmt 6 qseqid sseqid pident length mismatch gapopen qlen qstart qend slen sstart send evalue bitscore -d $(joint_fasta_outfile).dmnd -q $(joint_fasta_outfile) -o $(joint_fasta_outfile).diamond.tsv`)
# @time run(`diamond blastp --ultra-sensitive --id 0 --min-score 0 --max-target-seqs $(N_RECORDS) --unal 1 --outfmt 0  -d $(joint_fasta_outfile).dmnd -q $(joint_fasta_outfile) -o $(joint_fasta_outfile).diamond.pairwise.txt`)

# # iterate
# # Total time = 1.16s
# # Reported 46718 pairwise alignments, 46718 HSPs.
# # sensitive
# # Total time = 5.673s
# # Reported 49976 pairwise alignments, 49976 HSPs.
# # ultra sensitive
# # Total time = 14.939s
# # Reported 52446 pairwise alignments, 52446 HSPs.

# blastp_results = DataFrames.DataFrame(uCSV.read("$(joint_fasta_outfile).diamond.tsv", header=0, delim='\t', typedetectrows=100)[1], blastp_header)

# uCSV.write("$(joint_fasta_outfile).diamond.with_header.tsv", blastp_results, delim='\t')

In [None]:
# id_to_index_map = Dict(identifier => i for (i, identifier) in enumerate(record_table[!, "record_identifier"]))

In [None]:
# show(blastp_results, allcols=true)

In [None]:
# distance_matrix = ones(N_RECORDS, N_RECORDS)

In [None]:
# for row in DataFrames.eachrow(blastp_results)
#     row_idx = id_to_index_map[row["qseqid"]]
#     col_idx = id_to_index_map[row["sseqid"]]
#     # distance = 1 - (row["pident"] / 100)
#     sequence_identity = row["pident"] / 100
#     size_identity = row["length"] / max(row["qlen"], row["slen"])
#     overall_identity = sequence_identity * size_identity
#     distance = 1 - (overall_identity)
#     distance_matrix[row_idx, col_idx] = distance
# end
# distance_matrix

In [None]:
# just percent identity
# Summary Stats:
# Length:         13942756
# Missing Count:  0
# Mean:           0.996645
# Minimum:        0.000000
# 1st Quartile:   1.000000
# Median:         1.000000
# 3rd Quartile:   1.000000
# Maximum:        1.000000
# Type:           Float64

# percent size and percent identity
# Summary Stats:
# Length:         13942756
# Missing Count:  0
# Mean:           0.996742
# Minimum:        0.000000
# 1st Quartile:   1.000000
# Median:         1.000000
# 3rd Quartile:   1.000000
# Maximum:        1.000000
# Type:           Float64

# StatsBase.describe(vec(distance_matrix))

In [None]:
# optimal_number_of_clusters, ks_assessed, within_cluster_sum_of_squares, silhouette_scores = fit_optimal_number_of_clusters(distance_matrix)

In [None]:
# p1 = StatsPlots.plot(
#     ks_assessed[1:length(within_cluster_sum_of_squares)],
#     within_cluster_sum_of_squares,
#     ylabel = "within cluster sum of squares\n(lower is better)",
#     xlabel = "n clusters",
#     title = "Optimal n clusters = $(optimal_number_of_clusters)",
#     legend=false
# )
# StatsPlots.vline!(p1, [optimal_number_of_clusters])
# p2 = StatsPlots.plot(
#     ks_assessed[1:length(silhouette_scores)],
#     silhouette_scores,
#     ylabel = "silhouette scores\n(higher is better)",
#     xlabel = "n clusters",
#     title = "Optimal n clusters = $(optimal_number_of_clusters)",
#     legend=false
# )
# StatsPlots.vline!(p2, [optimal_number_of_clusters])
# display(p1)
# display(p2)

In [None]:
# optimal_clustering_result = Clustering.kmeans(distance_matrix, optimal_number_of_clusters)
# record_table[!, "cluster_assignments"] = optimal_clustering_result.assignments
# show(record_table, allcols=true)

In [None]:
# sorted_clusters = sort(collect(StatsBase.countmap(record_table[!, "cluster_assignments"])), by=x->x[2], rev=true)

# cluster_descriptions = DataFrames.DataFrame(
#     cluster_id = Int[],
#     cluster_count = Int[],
#     cluster_description = String[]
# )
# for cluster in first.(sorted_clusters)
#     word_cloud = Dict{String, Int}()
#     cluster_indices = findall(record_table[!, "cluster_assignments"] .== cluster)
#     for row in DataFrames.eachrow(record_table[cluster_indices, DataFrames.Not("fastx_file")])
#         # @show row["record_identifier"]
#         # @show row["record_description"]
#         filtered_description = replace(row["record_description"], r"\[.*?\]$" => "")
#         # @show filtered_description
#         merge!(+, word_cloud, StatsBase.countmap(split(lowercase(filtered_description))))
#     end
#     word_cloud

#     word_cloud = sort(collect(word_cloud), by=x->x[2], rev=true)

#     if length(word_cloud) > 1
#         word_cloud = filter(x -> x[2] > 1, word_cloud)
#     end
#     uninformative_words = [
#         "hypothetical",
#         "putative",
#         "protein",
#         "of"
#     ]
#     word_cloud = filter(x -> !(x[1] in uninformative_words), word_cloud)
#     # filter out any words that are substrings of other words (e.g. sir2 is a substring of sir2-like)
#     word_cloud = filter(x -> !any(y -> x[1] != y[1] && occursin(x[1], y[1]), word_cloud), word_cloud)

#     joint_descriptor = join(first.(word_cloud), " ")
#     if isempty(joint_descriptor)
#         joint_descriptor = "hypothetical protein of uknown function"
#     end
#     row = (
#         cluster_id = cluster,
#         cluster_count = length(cluster_indices),
#         cluster_description = joint_descriptor
#     )
#     push!(cluster_descriptions, row)
# end

# show(cluster_descriptions[cluster_descriptions[!, "cluster_description"] .!= "hypothetical protein of uknown function", :], allrows=true, allcols=true)

In [None]:
"""
    term_frequency(documents)

Placeholder for a term-frequency computation. Not implemented yet; always returns `nothing`.
"""
term_frequency(documents) = nothing

In [None]:
"""
    document_frequency(documents)

Placeholder for a document-frequency computation. Not implemented yet; always returns
`nothing`.
"""
document_frequency(documents) = nothing

In [None]:
"""
    tf_idf(document_groups)

Placeholder for a TF-IDF computation over groups of documents. Not implemented yet;
always returns `nothing`.
"""
tf_idf(document_groups) = nothing

In [None]:
# heatmap of clusters against genomes

In [None]:
# record_table[!, ["fastx_file", "cluster_assignments"]]

In [None]:
# n_fastas = length(fastx_files)
# n_clusters = optimal_number_of_clusters
# fasta_cluster_containment_matrix = falses(n_fastas, n_clusters)

# for (i, fastx_file_group) in enumerate(DataFrames.groupby(record_table, "fastx_file"))
#     clusters_contained = unique(fastx_file_group[!, "cluster_assignments"])
#     for cluster in clusters_contained
#         fasta_cluster_containment_matrix[i, cluster] = true
#     end
# end

# clusters_ordered_by_coreness = sortperm(map(col -> sum(col), eachcol(fasta_cluster_containment_matrix)), rev=true)
# StatsPlots.heatmap(
#     fasta_cluster_containment_matrix[:, clusters_ordered_by_coreness],
#     # legend = false,
#     title = "Core and accessory protein clusters",
#     ylabel = "genome index",
#     xlabel = "ordered protein clusters",
#     yticks = false,
#     xticks = false,
#     margins = 1StatsPlots.cm
# )

In [None]:
# cluster_descriptions

In [None]:
# names(record_table)

In [None]:
# record_table[!, "cluster_assignments"]

In [None]:
# joint_table = DataFrames.innerjoin(record_table, cluster_descriptions, on="cluster_assignments" => "cluster_id")

In [None]:
# joint_table[!, "cluster_count"]

In [None]:
# joint_table[!, "cluster_frequency"] = joint_table[!, "cluster_count"] ./ n_fastas

In [None]:
# uCSV.write("$(joint_fasta_outfile).protein_clusters.tsv", joint_table, delim='\t')
# Reload the saved protein-cluster table. The tuple returned by uCSV.read is splatted
# straight into the DataFrame constructor.
joint_table = let tsv_path = "$(joint_fasta_outfile).protein_clusters.tsv"
    parsed_table = uCSV.read(tsv_path; delim='\t', header=1)
    DataFrames.DataFrame(parsed_table...)
end

In [None]:
# CREATE A CONSENSUS PROTEIN FOR EACH PROTEIN CLUSTER

In [None]:
# # for cluster in DataFrames.groupby(join_table, "cluster_assignments")
# cluster = first(DataFrames.groupby(joint_table, "cluster_assignments"))
# show(cluster, allcols=true)

In [None]:
# sorted, distinct cluster ids
clusters = sort(unique(joint_table[!, "cluster_assignments"]))
# One output FASTA path per cluster id.
# Note: `replace` removes every ".faa.gz" in the path, including the one in the
# "protein.faa.gz" directory name that the joint file's path starts with.
cluster_fasta_files = map(clusters) do cluster
    replace(joint_fasta_outfile, ".faa.gz" => "") * ".cluster_$(cluster).faa"
end

In [None]:
# for each cluster, write out cluster to a specific fasta file

In [None]:
# Split the joint protein FASTA into one file per cluster, using the cluster assignment
# stored for each record identifier in `joint_table`.
# Identifiers are looked up through a Dict. As with the earlier per-record `findfirst`
# scan, the first row wins for duplicate identifiers.
cluster_by_identifier = Dict{String, Int}()
for (identifier, cluster) in zip(joint_table[!, "record_identifier"], joint_table[!, "cluster_assignments"])
    get!(cluster_by_identifier, identifier, cluster)
end
# Writers are keyed by cluster id, so ids do not have to be 1:n positions in
# `cluster_fasta_files`.
cluster_writers = Dict(
    cluster => FASTX.FASTA.Writer(open(cluster_fasta_file, "w"))
    for (cluster, cluster_fasta_file) in zip(clusters, cluster_fasta_files)
)
try
    for record in Mycelia.open_fastx(joint_fasta_outfile)
        cluster_assignment = cluster_by_identifier[FASTX.identifier(record)]
        write(cluster_writers[cluster_assignment], record)
    end
finally
    # close every writer even if a record is missing from the table
    foreach(close, values(cluster_writers))
end

In [None]:
chosen_k = assess_aamer_saturation([joint_fasta_outfile]; threshold=0.1)  # first k whose predicted saturation falls below 0.1

In [None]:
# write out clustalw alignment for each fasta

In [None]:
# Align each cluster's proteins with Clustal Omega, skipping alignments that already exist.
ProgressMeter.@showprogress for cluster_fasta_file in cluster_fasta_files
    # for outfmt in ["fasta", "clustal", "msf", "phylip", "selex", "stockholm", "vienna"]
    for outfmt in ["clustal"]
        outfile = "$(cluster_fasta_file).clustal_omega.$(outfmt)"
        isfile(outfile) && continue
        # clustalo refuses single-sequence input ("contains 1 sequence, nothing to align").
        # Skip those clusters up front rather than treating every failure as expected.
        # The cluster files are plain FASTA written with FASTX.FASTA.Writer, so headers are
        # the lines that start with '>'.
        n_sequences = count(line -> startswith(line, '>'), eachline(cluster_fasta_file))
        n_sequences < 2 && continue
        try
            run(`clustalo -i $(cluster_fasta_file) --outfmt $(outfmt) -o $(outfile)`)
        catch e
            # Surface unexpected failures (e.g. clustalo missing) and drop any partial
            # output so that a rerun retries this cluster.
            @warn "clustalo failed for $(cluster_fasta_file)" exception = e
            rm(outfile; force = true)
        end
    end
end

In [None]:
# TODO: fastx_to_aamer_graph

In [None]:
# read in the list of FASTA files

In [None]:
# count amino-acid k-mers (aamers)

In [None]:
# initialize the graph with one node per aamer, annotated with its count

In [None]:
# add edges

In [None]:
# open question: should edges carry weights too?