In [None]:
# Create (if it does not exist yet) and enter a dated, task-specific working
# directory under the user's home directory. Every relative path used by the
# cells below resolves against this directory.
TODAY = "2021-11-13"
TASK = "phylogenetic-determination"
DIR = "$(homedir())/$(TODAY)-$(TASK)"
mkpath(DIR)
cd(DIR)

In [None]:
# alternative workflow would be to utilize ezAAI http://leb.snu.ac.kr/ezaai
# or recreate ezAAI with diamond

# use the genus assignment of the phage, and then assume the family level assignment from the genus classification
# require that the scale-dependent, correlation-based, and probability-based genus calls are consistent; otherwise flag

In [None]:
import Pkg

pkgs = [
"JSON",
"HTTP",
"Dates",
"uCSV",
"DelimitedFiles",
"DataFrames",
"ProgressMeter",
"BioSequences",
"FASTX",
"Distances",
"Plots",
"StatsPlots",
"StatsBase",
"Statistics",
"Mmap",
"MultivariateStats",
"PyCall",
"Random",
"Primes",
"Revise",
"SparseArrays",
"SHA",
"Mycelia",
"GenomicAnnotations",
"BioFetch",
"Combinatorics",
"StaticArrays",
"BioSymbols",
"RollingFunctions",
"OrderedCollections"
]

# Import every package in `pkgs`, installing only those that cannot be located
# on the current load path. Checking with `Base.find_package` (instead of
# catching every exception) means a genuine load failure, such as a
# precompilation error, surfaces directly rather than triggering a reinstall.
for pkg in pkgs
    if Base.find_package(pkg) === nothing
        Pkg.add(pkg)
    end
    # `import` needs a literal module name, so splice the name into the expression
    @eval import $(Symbol(pkg))
end

In [None]:
"""
    generate_all_possible_kmers(k, alphabet)

Return a sorted vector containing every length-`k` sequence that can be built
from the symbols in `alphabet`.

# Arguments
- `k`: length of each k-mer.
- `alphabet`: collection of symbols whose `eltype` is either
  `BioSymbols.AminoAcid` or `BioSymbols.DNA`.

# Returns
A sorted `Vector` of `BioSequences.LongAminoAcidSeq` or
`BioSequences.LongDNASeq`, depending on the alphabet element type.

# Throws
- `ErrorException` if the alphabet element type is neither of the two
  supported symbol types.
"""
function generate_all_possible_kmers(k, alphabet)
    symbol_type = eltype(alphabet)
    # Validate the alphabet before enumerating the (potentially large) product
    if symbol_type == BioSymbols.AminoAcid
        sequence_type = BioSequences.LongAminoAcidSeq
    elseif symbol_type == BioSymbols.DNA
        sequence_type = BioSequences.LongDNASeq
    else
        error("unsupported alphabet element type $(symbol_type); expected BioSymbols.AminoAcid or BioSymbols.DNA")
    end
    # Cartesian product of `k` copies of the alphabet: one tuple per k-mer
    kmer_iterator = Iterators.product(ntuple(_ -> alphabet, k)...)
    kmers = vec([sequence_type(collect(symbols)) for symbols in kmer_iterator])
    return sort!(kmers)
end

"""
    generate_all_possible_canonical_kmers(k, alphabet)

Return every canonical length-`k` sequence over `alphabet`.

For an amino acid alphabet the full k-mer list from
`generate_all_possible_kmers` is returned unchanged. For a DNA alphabet each
k-mer is mapped through `BioSequences.canonical`, duplicates are removed, and
the survivors are converted to `BioSequences.DNAMer`.

# Throws
- `ErrorException` if the alphabet element type is neither
  `BioSymbols.AminoAcid` nor `BioSymbols.DNA`.
"""
function generate_all_possible_canonical_kmers(k, alphabet)
    symbol_type = eltype(alphabet)
    if symbol_type == BioSymbols.AminoAcid
        return generate_all_possible_kmers(k, alphabet)
    elseif symbol_type == BioSymbols.DNA
        kmers = generate_all_possible_kmers(k, alphabet)
        return BioSequences.DNAMer.(unique!(BioSequences.canonical.(kmers)))
    else
        error("unsupported alphabet element type $(symbol_type); expected BioSymbols.AminoAcid or BioSymbols.DNA")
    end
end

In [None]:
"""
    count_canonical_aamers(k, fasta_proteins)

Count every length-`k` window across all records in `fasta_proteins`.

# Arguments
- `k`: window (k-mer) length.
- `fasta_proteins`: iterable of FASTA records; each record's sequence is
  obtained with `FASTX.sequence`.

# Returns
An `OrderedCollections.OrderedDict` mapping each observed
`BioSequences.LongAminoAcidSeq` k-mer to its total count, sorted by key.
Records shorter than `k` contribute nothing.
"""
function count_canonical_aamers(k, fasta_proteins)
    aamer_counts = OrderedCollections.OrderedDict{BioSequences.LongAminoAcidSeq, Int64}()
    for protein in fasta_proteins
        s = FASTX.sequence(protein)
        # A sequence of length n has n - k + 1 windows of length k; the last
        # window starts at n - k + 1 and ends at n.
        for i in 1:(length(s) - k + 1)
            aamer = s[i:i+k-1]
            aamer_counts[aamer] = get(aamer_counts, aamer, 0) + 1
        end
    end
    return sort(aamer_counts)
end

In [None]:
"""
    update_counts_matrix!(matrix, sample_index, countmap, sorted_kmers)

Overwrite column `sample_index` of `matrix` with the counts of `sorted_kmers`
looked up in `countmap`; row `r` receives the count of the `r`-th k-mer, and
k-mers absent from `countmap` are written as `0`. Returns `matrix`.
"""
function update_counts_matrix!(matrix, sample_index, countmap, sorted_kmers)
    row = 0
    for kmer in sorted_kmers
        row += 1
        matrix[row, sample_index] = haskey(countmap, kmer) ? countmap[kmer] : 0
    end
    return matrix
end

In [None]:
"""
    accession_list_to_aamer_counts_table(accession_list, k, AA_ALPHABET; outfile="")

Build (or reload) a memory-mapped matrix of amino acid k-mer counts with one
row per k-mer from `generate_all_possible_canonical_kmers(k, AA_ALPHABET)` and
one column per accession.

For each accession the nucleotide record is fetched with
`Mycelia.get_sequence` into `<accession>.fna` (if that file is missing),
proteins are predicted with the external `prodigal` executable into
`<accession>.fna.faa` (if that file is missing), and the protein k-mers are
counted with `count_canonical_aamers`.

# Keywords
- `outfile`: path of the backing file. When empty, a name derived from
  `hash(accession_list)` and `k` is created in the current directory.

# Returns
The memory-mapped `Array{Int, 2}`. If `outfile` already exists it is mapped
as-is and nothing is recomputed.

NOTE(review): because an existing `outfile` is trusted, a run interrupted
part-way through leaves a partially filled file that later calls will reuse;
delete the file to force recomputation.
"""
function accession_list_to_aamer_counts_table(accession_list, k, AA_ALPHABET; outfile="")
    if isempty(outfile)
        outfile = joinpath(pwd(), "$(hash(accession_list)).AA.k$(k).bin")
    end
    @show outfile
    canonical_aamers = generate_all_possible_canonical_kmers(k, AA_ALPHABET)
    matrix_dims = (length(canonical_aamers), length(accession_list))
    # Must be checked before mapping: mapping a missing file creates it
    already_computed = isfile(outfile)
    aamer_counts_matrix = Mmap.mmap(outfile, Array{Int, 2}, matrix_dims)
    if already_computed
        return aamer_counts_matrix
    end
    aamer_counts_matrix .= 0
    ProgressMeter.@showprogress for (entity_index, accession) in enumerate(accession_list)
        fna_file = "$(accession).fna"
        if !isfile(fna_file)
            open(FASTX.FASTA.Writer, fna_file) do fastx_io
                for record in Mycelia.get_sequence(db="nuccore", accession = accession)
                    write(fastx_io, record)
                end
            end
        end
        faa_file = "$(accession).fna.faa"
        if !isfile(faa_file)
            run(pipeline(`prodigal -i $(fna_file) -o $(fna_file).genes -a $(faa_file) -p meta`, stderr="$(fna_file).prodigal.stderr"))
        end
        # The do-block form closes the reader (and its file handle) when done
        protein_records = open(FASTX.FASTA.Reader, faa_file) do reader
            collect(reader)
        end
        # Count with the `k` passed to this function, not a notebook global
        entity_aamer_counts = count_canonical_aamers(k, protein_records)
        update_counts_matrix!(aamer_counts_matrix, entity_index, entity_aamer_counts, canonical_aamers)
    end
    return aamer_counts_matrix
end

In [None]:
"""
    accession_list_to_dnamer_counts_table(accession_list, k)

Build an in-memory matrix of canonical DNA k-mer counts with one row per k-mer
from `generate_all_possible_canonical_kmers(k, DNA_ALPHABET)` and one column
per accession. Sequences are fetched with `Mycelia.get_sequence` and counted
with `Mycelia.count_canonical_kmers`.

Relies on the notebook-level `DNA_ALPHABET` being defined before the call.

# Returns
A `Float64` matrix (created with `zeros`) of k-mer counts.
"""
function accession_list_to_dnamer_counts_table(accession_list, k)
    canonical_dnamers = generate_all_possible_canonical_kmers(k, DNA_ALPHABET)
    dnamer_counts_matrix = zeros(length(canonical_dnamers), length(accession_list))

    ProgressMeter.@showprogress for (entity_index, accession) in enumerate(accession_list)
        fasta_dna_sequences = collect(Mycelia.get_sequence(db="nuccore", accession = accession))
        # Use the `k` passed to this function so rows and counted k-mers always agree
        entity_dnamer_counts = Mycelia.count_canonical_kmers(BioSequences.DNAMer{k}, fasta_dna_sequences)
        update_counts_matrix!(dnamer_counts_matrix, entity_index, entity_dnamer_counts, canonical_dnamers)
    end
    return dnamer_counts_matrix
end

In [None]:
"""
    normalize_distance_matrix(distance_matrix)

Return `distance_matrix` divided element-wise by its largest usable entry.

Entries that are `missing`, `nothing`, or `NaN` are ignored when searching for
the maximum. `missing` and `NaN` entries stay as they are in the result.

# Throws
Propagates the error raised by `maximum` when no usable entry exists.
"""
function normalize_distance_matrix(distance_matrix)
    # Test for missing/nothing first: `isnan(missing)` is `missing`, which
    # cannot be used as a condition, and `isnan(nothing)` has no method.
    is_usable(x) = !ismissing(x) && !isnothing(x) && !isnan(x)
    max_usable_value = maximum(filter(is_usable, vec(distance_matrix)))
    return distance_matrix ./ max_usable_value
end

In [None]:
"""
    count_matrix_to_probability_matrix(counts_matrix)

Return a new floating-point matrix in which every column of `counts_matrix`
has been divided by that column's sum. The input is not modified.

Integer count matrices are accepted: the result is allocated with a
floating-point element type so the fractional values can be stored. A column
whose sum is zero yields `NaN` entries.
"""
function count_matrix_to_probability_matrix(counts_matrix)
    # `float.` allocates a fresh floating-point copy even for integer input;
    # a plain `copy` of an Int matrix cannot hold fractions.
    probability_matrix = float.(counts_matrix)
    for col in eachcol(probability_matrix)
        col ./= sum(col)
    end
    return probability_matrix
end

In [None]:
# Path of the `metadata` folder located at the root of the loaded Mycelia package.
MYCELIA_METADATA = joinpath(pkgdir(Mycelia), "metadata")

In [None]:
# Notebook-level aliases for the alphabets exported by Mycelia. They are passed
# to (or read by) the k-mer counting functions defined above.
# Alternative kept for reference: the same amino acid alphabet without the
# terminator symbol.
# AA_ALPHABET = collect(filter(x -> x != BioSequences.AA_Term, Mycelia.AA_ALPHABET))
AA_ALPHABET = Mycelia.AA_ALPHABET
DNA_ALPHABET = Mycelia.DNA_ALPHABET

In [None]:
# Source of the CSV:
# https://www.ncbi.nlm.nih.gov/labs/virus/vssi/#/virus?SeqType_s=Nucleotide&VirusLineage_ss=Bacteriophage,%20all%20taxids&Completeness_s=complete
# Load the table, order it by accession, and keep only complete genomes.
entity_metadata = DataFrames.DataFrame(uCSV.read("$(MYCELIA_METADATA)/2021-11-13-ncbi-complete-bacteriophage.csv", header=1, quotes='"')...)
sort!(entity_metadata, "Accession")
entity_metadata = filter("Nuc_Completeness" => ==("complete"), entity_metadata)
# Drop rows without a genus label.
entity_metadata = filter("Genus" => !isempty, entity_metadata)
# Keep only genera represented by at least `min_threshold` genomes.
genera_counts = sort(collect(StatsBase.countmap(entity_metadata[!, "Genus"])), by=last, rev=true)
min_threshold = 2
repeat_genera = Set(first.(filter(genus_count -> last(genus_count) >= min_threshold, genera_counts)))
entity_metadata = filter("Genus" => in(repeat_genera), entity_metadata)

In [None]:
# Accession column of the filtered metadata table, in the table's row order.
accession_list = entity_metadata[!, "Accession"]

In [None]:
# k-mer lengths used for the DNA and amino acid count tables below.
# The smaller values that follow were tried first and rejected: they are too
# small, all of the within vs between comparisons have some disagreement.
# dna_k = 5
# aa_k = 2
dna_k = 7
aa_k = 3

In [None]:
# run(`sudo conda install -c bioconda prodigal`)

In [None]:
# Amino acid k-mer counts: one row per k-mer, one column per accession.
aamer_counts_matrix = accession_list_to_aamer_counts_table(accession_list, aa_k, AA_ALPHABET)

In [None]:
# Canonical DNA k-mer counts: one row per k-mer, one column per accession.
dnamer_counts_matrix = accession_list_to_dnamer_counts_table(accession_list, dna_k)

In [None]:
# Sorted, de-duplicated, non-empty taxon labels present in the metadata table.
unique_genera = sort!(unique(filter(!isempty, entity_metadata[!, "Genus"])))
unique_families = sort!(unique(filter(!isempty, entity_metadata[!, "Family"])))

In [None]:
# Column-normalised versions of the count matrices (each column divided by its sum).
# NOTE(review): `dnamer_probility_matrix` is misspelled, but later cells use the
# same spelling, so renaming it requires updating every use together.
aamer_probability_matrix = count_matrix_to_probability_matrix(aamer_counts_matrix)
dnamer_probility_matrix = count_matrix_to_probability_matrix(dnamer_counts_matrix)

In [None]:
# Euclidean
# Correlation
# Total Variation
# assign at the genus level according to best average match

In [None]:
# One entry per (distance measure, sequence grammar) combination. Each entry is
# a tuple of:
#   1. the pairwise distance matrix between samples (columns of the input,
#      hence `dims=2`), rescaled by `normalize_distance_matrix`,
#   2. the name of the distance measure,
#   3. the grammar ("AA" or "DNA") the k-mers were drawn from.
# The first two groups are computed on raw counts; the last group is computed
# on the column-normalised probability matrices. The trailing notes
# (good / meh / bad) are the author's qualitative assessment of each measure.
matrix_metric_grammar_groups = [
        (normalize_distance_matrix(Distances.pairwise(Distances.euclidean, aamer_counts_matrix, dims=2)), "euclidean", "AA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.euclidean, dnamer_counts_matrix, dims=2)), "euclidean", "DNA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.cityblock, aamer_counts_matrix, dims=2)), "cityblock", "AA"), # redundant with above
        (normalize_distance_matrix(Distances.pairwise(Distances.cityblock, dnamer_counts_matrix, dims=2)), "cityblock", "DNA"), # redundant with above
        
        (normalize_distance_matrix(Distances.pairwise(Distances.corr_dist, aamer_counts_matrix, dims=2)), "corr_dist", "AA"), # meh
        (normalize_distance_matrix(Distances.pairwise(Distances.corr_dist, dnamer_counts_matrix, dims=2)), "corr_dist", "DNA"), # meh
        (normalize_distance_matrix(Distances.pairwise(Distances.cosine_dist, aamer_counts_matrix, dims=2)), "cosine_dist", "AA"), # meh
        (normalize_distance_matrix(Distances.pairwise(Distances.cosine_dist, dnamer_counts_matrix, dims=2)), "cosine_dist", "DNA"), # very bad
        
        (normalize_distance_matrix(Distances.pairwise(Distances.totalvariation, aamer_probability_matrix, dims=2)), "totalvariation", "AA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.totalvariation, dnamer_probility_matrix, dims=2)), "totalvariation", "DNA"), # bad
        (normalize_distance_matrix(Distances.pairwise(Distances.js_divergence, aamer_probability_matrix, dims=2)), "js_divergence", "AA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.js_divergence, dnamer_probility_matrix, dims=2)), "js_divergence", "DNA"), # bad
        (normalize_distance_matrix(Distances.pairwise(Distances.bhattacharyya, aamer_probability_matrix, dims=2)), "bhattacharyya", "AA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.bhattacharyya, dnamer_probility_matrix, dims=2)), "bhattacharyya", "DNA"), # bad
        (normalize_distance_matrix(Distances.pairwise(Distances.hellinger, aamer_probability_matrix, dims=2)), "hellinger", "AA"), # good
        (normalize_distance_matrix(Distances.pairwise(Distances.hellinger, dnamer_probility_matrix, dims=2)), "hellinger", "DNA"), # bad
    ]

In [None]:
# How accurate is the best hit?
# Leave-one-out nearest-neighbour check: for every sample, find the closest
# other sample under each distance matrix and record whether the two share a
# species, genus, and family label. The printed values are the fraction of all
# samples whose nearest neighbour carries the same label.
for (distance_matrix, distance_metric, grammar) in matrix_metric_grammar_groups

    samples = 1:size(entity_metadata, 1)

    correct_species_hits = 0
    correct_genus_hits = 0
    correct_family_hits = 0

    ProgressMeter.@showprogress for i in samples
        # `findmin` returns NaN whenever any candidate is NaN, which would make
        # an undefined distance look like the best hit, so NaN candidates are
        # excluded up front (along with the sample itself).
        candidates = [j for j in samples if j != i && !isnan(distance_matrix[i, j])]
        # A sample with no defined distance to any other sample has no best
        # hit; it stays in the denominator and is scored as a miss.
        if !isempty(candidates)
            _, nearest_position = findmin(distance_matrix[i, candidates])
            # Map the position within `candidates` back to a sample index
            nearest = candidates[nearest_position]
            if entity_metadata[i, "Species"] == entity_metadata[nearest, "Species"]
                correct_species_hits += 1
            end
            if entity_metadata[i, "Genus"] == entity_metadata[nearest, "Genus"]
                correct_genus_hits += 1
            end
            if entity_metadata[i, "Family"] == entity_metadata[nearest, "Family"]
                correct_family_hits += 1
            end
        end
    end
    @show distance_metric, grammar
    @show correct_species_hits / length(samples)
    @show correct_genus_hits / length(samples)
    @show correct_family_hits / length(samples)
end

In [None]:
10	AA	measure	family	value
10	AA	measure	genus	value
10	AA	measure	species	value
10	DNA	measure	family	value
10	DNA	measure	genus	value
10	DNA	measure	species	value

In [None]:
# Within-taxon versus between-taxon distance comparison.
# For every taxon level / distance matrix / averaging function combination:
#   1. for each sample, average its distances to the other members of its own
#      taxon ("within") and to every sample outside that taxon ("between");
#   2. drop pairs that contain NaN;
#   3. report the fraction of samples whose within-taxon average is not
#      smaller than the between-taxon average (shown as the
#      "misclassification rate" in the plot title);
#   4. draw one line per sample from its within value (x = 1) to its between
#      value (x = 2), save the figure as PNG and SVG under DIR, and display it.
# Other averaging functions and taxon levels that were considered are kept
# below as commented-out alternatives.
# average = Statistics.median
# average = Statistics.mean
# for (taxon_level, unique_taxa) in ("Family" => unique_families, "Genus" => unique_genera, "Species" => unique_species)
for (taxon_level, unique_taxa) in ["Genus" => unique_genera]
    for (distance_matrix, distance_metric, grammar) in matrix_metric_grammar_groups
#         for average in [Statistics.median, Statistics.mean]
        for average in [Statistics.median]
            # Collected `within => between` pairs, one per sample
            within_vs_between_distances = []
            ProgressMeter.@showprogress  for taxon in unique_taxa
                taxa_indices = findall(entity_metadata[!, taxon_level] .== taxon)
                other_taxa_indices = setdiff(1:1:size(distance_matrix, 1), taxa_indices)
                # Nothing to compare against when every sample belongs to this taxon
                if isempty(other_taxa_indices)
                    continue
                end
                for index in taxa_indices
                    # Members of the same taxon other than the sample itself
                    other_indices = filter(i -> i != index, taxa_indices)
                    if isempty(other_indices)
                        continue
                    end
                    avg_within_taxa_distance = average(vec(distance_matrix[index, other_indices]))
                    avg_between_taxa_distance = average(vec(distance_matrix[index, other_taxa_indices]))
                    push!(within_vs_between_distances, avg_within_taxa_distance => avg_between_taxa_distance)
#                     if avg_within_taxa_distance > avg_between_taxa_distance
#                         @show index
#                     end
                end
            end
#             @show within_vs_between_distances
            # Discard pairs in which either average is NaN
            within_vs_between_distances = filter(d -> !any(map(x1 -> isnan(x1), collect(d))), within_vs_between_distances)

            # Each pair becomes a two-element [within, between] vector: the y values of one line
            ys = collect.(within_vs_between_distances)
        
#             assess correlation
#         assess average slope
#             assess % wrong slope
            # Fraction of samples that are at least as far from their own taxon as from the others
            confusion = round(count(x -> x[1] >= x[2], within_vs_between_distances) / length(within_vs_between_distances), digits=3)

            # Matching x positions: 1 = within, 2 = between
            xs = [[1, 2] for x in ys]

            # k-mer length that produced this matrix, used only for labelling
            if grammar == "AA"
                k = aa_k
            elseif grammar == "DNA"
                k = dna_k
            end

            # Number of samples in the distance matrix, used only for labelling
            n = size(distance_matrix, 1)

            p = StatsPlots.plot(
                xs,
                ys,
                xticks = ([1, 2], ["within $(taxon_level)", "between $(taxon_level)"]),
                legend = false,
                xlims = (0.75, 2.25),
                alpha = 0.1,
                title = "$(distance_metric) @ $grammar k=$k $(average)\n(n=$(n)) (misclassification rate = $(confusion))",
                ylabel = "normalized distance",
                marker = :circle
            )
            StatsPlots.savefig(p, "$DIR/$(distance_metric)-distance-$(taxon_level)-k$k-$(grammar)-n$n-avg-$(average).png")
            StatsPlots.savefig(p, "$DIR/$(distance_metric)-distance-$(taxon_level)-k$k-$(grammar)-n$n-avg-$(average).svg")
            display(p)
        end
    end
end

In [None]:
# Do I median all of the genera, and then compare the distance to that median?
# Or do I compare distances to all of the phage, and then find the genus with the lowest median value?

In [None]:
output table format
identifier/accession
euclidean aa
euclidean dna
corr_dist aa
corr_dist dna
totalvariation aa
totalvariation dna

In [None]:
# average = Statistics.median
# for (distance_matrix, distance_metric, grammar) in matrix_metric_grammar_groups

#     correct_species_hits = 0 
#     correct_genus_hits = 0
#     correct_family_hits = 0

#     samples = 1:size(entity_metadata, 1)

#     ProgressMeter.@showprogress for i in samples
#         other_indices = setdiff(1:size(entity_metadata, 1), i)
#         value, index = findmin(distance_matrix[i, other_indices])
#         if entity_metadata[i, "Species"] == entity_metadata[index, "Species"]
#             correct_species_hits += 1
#         end
#         if entity_metadata[i, "Genus"] == entity_metadata[index, "Genus"]
#             correct_genus_hits += 1
#         end
#         if entity_metadata[i, "Family"] == entity_metadata[index, "Family"]
#             correct_family_hits += 1
#         end
#     end
#     @show distance_metric, grammar
#     @show correct_species_hits / length(samples)
#     @show correct_genus_hits / length(samples)
#     @show correct_family_hits / length(samples)
# end

In [None]:
# Spot check: show the genus labels of the rows grouped under the 15th genus.
# NOTE(review): `genera_indices` is assigned in the cell that follows this one,
# so that cell has to be run first.
entity_metadata[genera_indices[15], "Genus"]

In [None]:
# For each genus in `unique_genera` (same order), the row indices of its members.
genera_indices = map(unique_genera) do genus
    findall(==(genus), entity_metadata[!, "Genus"])
end

In [None]:
# Genus assignment by lowest median distance.
# For every sample and every distance matrix, compute the median distance from
# the sample to the members of each genus (leaving the sample itself out of its
# own genus when that genus has other members) and vote for the genus with the
# smallest median. A sample counts as correct when ANY of the votes matches its
# labelled genus; the printed value is the fraction of correct samples.
correct = 0
# indices = 1:10
indices = 1:size(entity_metadata, 1)
# Iterate over the row indices themselves so that `i` is always a valid index
# into both `entity_metadata` and the distance matrices, whatever range
# `indices` is set to.
ProgressMeter.@showprogress for i in indices
    true_genus = entity_metadata[i, "Genus"]
    if isempty(true_genus)
        continue
    end
    votes = eltype(unique_genera)[]
    for (distance_matrix, distance_metric, grammar) in matrix_metric_grammar_groups
        genus_medians = zeros(length(unique_genera))
        for (i2, genus_indices) in enumerate(genera_indices)
            # Exclude the sample from its own genus, unless it is the only member
            if length(genus_indices) > 1
                genus_indices = setdiff(genus_indices, i)
            end
            genus_medians[i2] = Statistics.median(distance_matrix[i, genus_indices])
        end
        min_value, min_value_index = findmin(genus_medians)
        push!(votes, unique_genera[min_value_index])
    end
    # Alternative scoring that was considered: require the majority vote to match.
    if true_genus in votes
        # `correct` is a notebook-level variable; declare that explicitly so the
        # update also works when this cell is run as a plain script.
        global correct += 1
    end
end
@show correct/length(indices)

In [None]:
# 0.8067410811993871 across all types
# 0.807835412562924
# 0.8108995403808273
# 0.7695338148391333
# 0.7743488728386956
# 0.7680017509301816
# 0.7861676515648939

# 2, 4, 6
# 0.7929525060188225

# 0.8330050339242723 for any vote being correct

In [None]:
# add ANI + AAI

In [None]:
# ./minimap2 -a test/MT-human.fa test/MT-orang.fa > test.sam

In [None]:
# just do whole genome alignments???

In [None]:
# Try looking only at other phage that have the same host to see if that improves the calls

In [None]:
# sudo conda install -c bioconda comparem

In [None]:
#     Common workflows:
#      aai_wf      -> Calculate AAI between all pairs of genomes
#                     (runs call_genes => similarity => aai)
#      classify_wf -> Identify similar genomes based on AAI values
#                     (runs call_genes => similarity => classify)
                     
#     Gene homology and genome similarity:
#      similarity -> Perform reciprocal sequence similarity search between proteins
#      aai        -> Calculate AAI between all pairs of genomes
#      classify   -> Identify similar genomes based on AAI value

In [None]:
# run(`comparem classify -h`)

# usage: comparem classify_wf [-h] [-k NUM_TOP_TARGETS] [-t TAXONOMY_FILE]
#                             [-e EVALUE] [-p PER_IDENTITY] [-a PER_ALN_LEN]
#                             [-x FILE_EXT] [--proteins]
#                             [--force_table FORCE_TABLE] [--blastp]
#                             [--sensitive] [--keep_headers] [--keep_rbhs]
#                             [--tmp_dir TMP_DIR] [-c CPUS] [--silent]
#                             query_files target_files output_dir

#   -k, --num_top_targets NUM_TOP_TARGETS
#                         number of top scoring target genomes to report per
#                         query genome (default: 1)

#   -p, --per_identity PER_IDENTITY
#                         percent identity for defining homology (default: 30.0)
#   -a, --per_aln_len PER_ALN_LEN
#                         percent alignment length of query sequence for
#                         defining homology (default: 70.0)

#   -c, --cpus CPUS       number of CPUs to use (default: 1)

In [None]:
# https://manual.microbial-genomes.org/part5/workflow