In [None]:
# Load Revise before the packages it should track (it cannot track ones loaded earlier).
using Revise
using DCAUtils
# Make PyCall use its internal Python and rebuild it once, before PyPlot is loaded.
# The rebuild is slow, so it is done a single time per cell.
ENV["PYTHON"] = ""
using Pkg
Pkg.build("PyCall")  # Rebuild PyCall to use the internal Python
# NOTE(review): Genie is not used by any visible cell — confirm it is still needed.
using Genie, Statistics, PyPlot



In [None]:
using PyPlot  # `using` must be at top level: inside a `for` loop it is a syntax error

nat_msa = read_fasta_alignment("../synthetic/DBD_alignment.uniref90.cov80.a2m", 0.9);

L, M = size(nat_msa)

wts = [2748, 13202];
depths = [1, 2, 3, 4, 6, 12];

# Results per (wt index, depth index); every stored vector is ordered by increasing μ.
d_wts          = Array{Vector{Float64}}(undef, length(wts), length(depths))
d_pairs        = Array{Vector{Float64}}(undef, length(wts), length(depths))
mutation_rates = Array{Vector{Float64}}(undef, length(wts), length(depths))

# Set to `true` to also save one μ-vs-distance figure per (wt, depth) pair.
plot_each_condition = false

# μ is encoded in the file name as "mu<number>". The number never ends in '.',
# so an extension directly after it (e.g. "mu0.01.a2m") is not swallowed.
mu_pattern = r"mu([0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)"

folder = "../synthetic/"
all_files = readdir(folder; join=true)  # loop-invariant: list the folder only once

for (i, wt) in pairs(wts)
    # "(?![0-9])" stops e.g. wt2748 from also matching wt27480.
    wt_pattern = Regex("wt$(wt)(?![0-9])")

    for (j, depth) in pairs(depths)
        # The trailing "." stops e.g. depth1 from also matching depth12.
        files = filter(all_files) do f
            name = basename(f)
            occursin(wt_pattern, name) && occursin("depth$(depth).", name)
        end
        isempty(files) && @warn "No alignment files for wt=$(wt), depth=$(depth) in $(folder)"

        msas = [read_fasta_alignment(x, 0.9) for x in files]
        @time d_pair = [pairwise_ham_dist(x, n_seq = 4096) for x in msas]
        @time d_wt = [mean(ham_dist(nat_msa[:, wt], x)) for x in msas]

        mus = map(files) do f
            m = match(mu_pattern, basename(f))
            m === nothing && error("Cannot parse μ from file name: $(f)")
            parse(Float64, m.captures[1])
        end

        # Sort results by mu for clean plotting
        order = sortperm(mus)
        mus_sorted    = mus[order]
        d_wt_sorted   = d_wt[order]
        d_pair_sorted = d_pair[order]

        mutation_rates[i, j] = mus_sorted
        d_wts[i, j]          = d_wt_sorted
        d_pairs[i, j]        = d_pair_sorted

        if plot_each_condition
            # --- Plot 1: mu vs Hamming distance from WT ---
            figure()
            plot(mus_sorted, d_wt_sorted ./ L, "o-", linewidth=2)
            xlabel("μ")
            ylabel("Hamming distance from WT / L")
            title("WT = $(wt), depth = $(depth) : Hamming distance from WT")
            xscale("log")  # log scale makes sense for mu
            tight_layout()
            savefig("../mu_vs_wt_hamdist_wt$(wt)_depth$(depth).png")

            # --- Plot 2: mu vs Pairwise Hamming distance ---
            figure()
            plot(mus_sorted, d_pair_sorted ./ L, "o-", linewidth=2)
            xlabel("μ")
            ylabel("Pairwise Hamming distance / L")
            title("WT = $(wt), depth = $(depth) : Pairwise Hamming distance")
            xscale("log")
            tight_layout()
            savefig("../mu_vs_pairwise_hamdist_wt$(wt)_depth$(depth).png")

            close("all")  # clean up figures so they do not accumulate across iterations
        end
    end
end

using PyPlot

"""
    plot_across_depths(values, mus, wts, depths, L; ylab, ttl, outfile)

For each WT `wts[i]`, draw one figure that overlays `values[i, j] ./ L` against
`mus[i, j]` (log-scaled x axis) for every depth `depths[j]`, then save it to
`outfile(wt)` and close all figures. `ttl(wt)` builds the per-WT title.
"""
function plot_across_depths(values, mus, wts, depths, L; ylab, ttl, outfile)
    for (i, wt) in pairs(wts)
        figure()
        for (j, depth) in pairs(depths)
            plot(mus[i, j], values[i, j] ./ L, "o-";
                 linewidth=2, label="depth = $(depth)")
        end
        xlabel("μ")
        ylabel(ylab)
        title(ttl(wt))
        xscale("log")
        legend()
        tight_layout()
        savefig(outfile(wt))
        close("all")
    end
    return nothing
end

# Pairwise Hamming distance (normalised by L) vs μ, one figure per WT.
plot_across_depths(d_pairs, mutation_rates, wts, depths, L;
                   ylab = "Pairwise Hamming distance / L",
                   ttl = wt -> "WT = $(wt) : Pairwise Hamming distance across depths",
                   outfile = wt -> "../pairwise_hamdist_wt$(wt)_alldepths.png")

# Hamming distance from WT (normalised by L) vs μ, one figure per WT.
plot_across_depths(d_wts, mutation_rates, wts, depths, L;
                   ylab = "Hamming distance from WT / L",
                   ttl = wt -> "WT = $(wt) : Distance from WT across depths",
                   outfile = wt -> "../wt_hamdist_wt$(wt)_alldepths.png")


In [None]:
# Rebuild PyCall against its internal Python (empty ENV["PYTHON"]) before PyPlot is loaded.
# NOTE(review): if PyCall was already loaded in this session, the rebuild presumably only
# takes effect after a kernel restart — confirm.
ENV["PYTHON"] = "" 
using Pkg
Pkg.build("PyCall") 

using Revise, PhyloTools, DelimitedFiles, DCAUtils, JLD2, PyPlot, Statistics, LinearAlgebra
import KitMSA: fasta2matrix, matrix2fasta
using TreeTools # to handle phylogenetic trees (presumably provides read_tree / leaves / label / distance used below)

"""
    pairwise_tree_distances(tree)

Return a `Vector{Float64}` holding `distance(tree, a, b)` for every unordered pair
of distinct leaves `a`, `b` of `tree`: n*(n-1)/2 entries for n leaves, ordered as
(i, j > i) over the leaf list. Kept in a function so the hot double loop is compiled
with concrete local types instead of running in global scope.
"""
function pairwise_tree_distances(tree)
    leaf_labels = map(label, leaves(tree))
    n = length(leaf_labels)
    dists = Float64[]
    sizehint!(dists, n * (n - 1) ÷ 2)  # 4096 leaves -> ~8.4M pairs; avoid regrowth
    for i in 1:n-1, j in i+1:n
        push!(dists, distance(tree, leaf_labels[i], leaf_labels[j]))
    end
    return dists
end

depths = [1, 2, 3, 4, 6, 12];
tree_files = ["../data_ASR_synthetic_trees/synthetic_tree_pipeline/artificial_tree_4096leaves_depth$(i).nwk" for i in depths];

trees = [read_tree(tree_file, node_data_type = Seq) for tree_file in tree_files];

# One vector of pairwise leaf distances per tree, in the same order as `depths`.
d_pair_trees = Vector{Vector{Float64}}()

for tree in trees
    push!(d_pair_trees, pairwise_tree_distances(tree))
    println(tree)
end

# Overlay one step histogram of pairwise leaf distances per tree depth.
# All depths share the same bin edges, so their counts are directly comparable.

# Distinguish depths by colour; wraps around if there are more depths than colours.
colors = ["C0", "C1", "C2", "C3", "C4", "C5"]

nbins = 30
nonempty = filter(!isempty, d_pair_trees)
bins = nbins  # fallback: let matplotlib pick edges when no common range exists
if !isempty(nonempty)
    lo = minimum(minimum, nonempty)
    hi = maximum(maximum, nonempty)
    if hi > lo
        bins = collect(range(lo, hi; length = nbins + 1))  # nbins shared bins
    end
end

figure()

for (k, dists) in enumerate(d_pair_trees)
    depth = depths[k]

    hist(dists;
         bins=bins,
         histtype="step",    # step-type histogram
         linewidth=3,
         color=colors[mod1(k, length(colors))],
         label="depth = $(depth)")
end

xlabel("Pairwise distance")
ylabel("Count (log scale)")
yscale("log")
legend()
tight_layout()
savefig("../trees_pairwise_divergence_length.png", dpi = 300)