# In [None]:  (cell marker left over from the Jupyter notebook export)
ENV["PYTHON"] = "" 
using Pkg
Pkg.build("PyCall")  # Rebuild PyCall to use the internal Python
using Revise, Genie, DelimitedFiles, DCAUtils, JLD2, PyPlot, Statistics, LinearAlgebra
using Random: shuffle
import KitMSA: fasta2matrix, matrix2fasta


# Load the model parameters from "../model_natural_PSE.dat" and post-process them.
# `read_par_BM_lettersave`, `set_max_field_to_0save` and `symmetrize_Jsave` are not
# defined in the visible part of this file.
# NOTE(review): presumably `h` holds the fields and `J` the couplings of the model — confirm.
h_tmp, J_tmp = read_par_BM_lettersave("../model_natural_PSE.dat");
h = set_max_field_to_0save(h_tmp);
J_tmp2 = symmetrize_Jsave(J_tmp); 
# Swap the 2nd and 3rd dimensions of the 4-dimensional array.
J = permutedims(J_tmp2, [1,3,2,4]);


# Read the natural alignment, transpose it, convert it to Int8 and keep rows 3:201.
# NOTE(review): presumably rows are alignment positions and columns are sequences after
# the transpose, and 3:201 selects the positions covered by the model — confirm.
nat_msa = Int8.(fasta2matrix("../Gen.jl/data/alignments/natural/PF13354_noinsert_max19gaps_nodupl_noclose.faa")')[3:201,:]

# L = number of rows, M = number of columns of the trimmed alignment matrix.
L, M = size(nat_msa)

# One scalar per column of `nat_msa`: the mean of `cont_dep_entr` evaluated on that column.
cdes = [mean(cont_dep_entr(nat_msa[:,i], h, J)) for i in 1:M];

# Colour palette; only referenced by the commented-out `color = cc[i]` plot options.
cc = ["orange", "red", "maroon", "yellow", "limegreen", "darkgreen", "cyan", "dodgerblue", "darkblue"];

"""
    read_fasta_headers(filepath::String; prefix_length::Integer=4)

Return the header lines of the FASTA file at `filepath`, in file order.

A header line is any line starting with `>`. The first `prefix_length` characters of
each header (the `>` marker included) are dropped and surrounding whitespace is
stripped from the remainder. With the default of 4 this removes `>` plus the next three
characters; pass `prefix_length=1` to drop only the `>`. Headers with at most
`prefix_length` characters yield `""`.

# Throws
- `ArgumentError` if `prefix_length` is negative.
"""
function read_fasta_headers(filepath::String; prefix_length::Integer=4)
    prefix_length >= 0 ||
        throw(ArgumentError("prefix_length must be non-negative, got $prefix_length"))
    headers = String[]
    open(filepath, "r") do io
        for line in eachline(io)
            startswith(line, '>') || continue
            # `chop` counts characters rather than bytes, so multi-byte text near the
            # cut cannot raise a StringIndexError the way a byte-range slice can.
            trimmed = strip(chop(line; head=prefix_length, tail=0))
            push!(headers, String(trimmed))
        end
    end
    return headers
end

"""
    sample_low_mid_high_indices(values::Vector{<:Real}, n_each::Int=1)

Pick up to `n_each` distinct random indices of `values` from each of three rank bands.

With `unit = div(length(values), 10)` and the indices ordered by increasing value, the
bands are:
- low: ranks `1:unit` (smallest values),
- mid: ranks `(4unit + 1):(5unit)`,
- high: ranks `(9unit + 1):end` (largest values, including the `length(values) % 10`
  leftover ranks).

Indices are drawn without replacement, so each returned vector holds
`min(n_each, band size)` distinct entries. When `length(values) < 10` the low and mid
bands are empty and yield empty vectors.

# Returns
A tuple `(sampled_low, sampled_mid, sampled_high)` of index vectors.

# Throws
- `ArgumentError` if `n_each` is negative.
"""
function sample_low_mid_high_indices(values::Vector{<:Real}, n_each::Int=1)
    n_each >= 0 || throw(ArgumentError("n_each must be non-negative, got $n_each"))

    # Indices of `values` ordered from smallest to largest value.
    sorted_indices = sortperm(values)
    unit = div(length(values), 10)

    lows = sorted_indices[1:unit]
    mids = sorted_indices[(4 * unit + 1):(5 * unit)]
    # The open-ended range already contains the leftover top ranks, so they must not be
    # appended to any other band.
    highs = sorted_indices[(9 * unit + 1):end]

    # Shuffle-and-truncate draws without replacement and is well defined for empty bands.
    draw(band) = shuffle(band)[1:min(n_each, length(band))]

    return draw(lows), draw(mids), draw(highs)
end



# Headers of the natural alignment, one entry per `>` line, in file order.
all_headers = read_fasta_headers("../Gen.jl/data/alignments/natural/PF13354_noinsert_max19gaps_nodupl_noclose.faa");

# Subset of headers containing "_ECOLX".
ecolx_headers = filter(h -> occursin(r"_ECOLX", h), all_headers);

# NOTE(review): presumably the identifiers of the hand-picked sequences below; the list
# is not referenced anywhere else in the visible code — confirm.
fragments = [
    "H6V563/57-263",
    "I3VNV9/57-263",
    "J7KCB4/57-263",
    "Q8KSA6/60-263",
    "Q58G80/60-263",
    "Q9EXV5/57-263",
    "E0XN37/54-269",
    "A0A2U9GMV9/56-262",
    "A0A0H3YEU2/56-262",
    "D9IQH0/60-263",
    "O69395/57-261"
];

# Random low/mid/high picks by `cdes`; currently superseded by the fixed list below.
low_idx, mid_idx, high_idx = sample_low_mid_high_indices(cdes, 3);
#imp_idxs = vcat(low_idx, mid_idx,high_idx);

# Hard-coded column indices of `nat_msa` used as starting sequences.
imp_idxs = [203,240,264,357,449,617,3051,3182,6826,8616,9740];
imp_seqs = nat_msa[:, imp_idxs]
N_start_seq_imp = size(imp_seqs, 2);

# Same per-column quantity as `cdes`, restricted to the selected columns.
cdes_imp = [mean(cont_dep_entr(imp_seqs[:,i], h, J)) for i in 1:N_start_seq_imp];
# Reuse the headers read at the top of this cell instead of parsing the FASTA file again.
# NOTE(review): `names` shares its name with `Base.names`; consider renaming it together
# with its later use in the final `writedlm` call.
names = all_headers[imp_idxs];

# Density histogram of `cdes`, with one vertical segment (from 0 to 1) drawn at the
# value of each entry of `cdes_imp`; saved to "../CDEs_beta.png".
close("all")
plt.hist(
    cdes;
    histtype="step",
    color="grey",
    linewidth=3.0,
    label="nat",
    density=true,
)
for cde in cdes_imp
    plt.plot([cde, cde], [0.0, 1.0])  # color = cc[i]
end
plt.legend()
plt.xlabel("CDE")
savefig("../CDEs_beta.png")


# Write every header of the natural alignment, one per line. `all_headers` already holds
# them, so the FASTA file is not parsed again.
writedlm("../beta_headers.txt", all_headers)


# Simulation sizes.
N_steps = 10^4
N_chains = 500
NN_points = 30  # requested number of log-spaced points

# Log-spaced step counts from 1 to N_steps, truncated to integers. Truncation produces
# repeated small values, so after `unique` the grid can hold fewer than NN_points entries.
steps = unique([
    trunc(Int, 10^y) for y in range(log10(1), log10(N_steps), length=NN_points)
])
sweeps = steps ./ L
N_points = length(steps);  # actual number of distinct points



# Containers: the raw result for each start sequence, and an array indexed by
# (start sequence, point, chain) that is filled from `ham_dist`.
res_all_imp = []
hams_all_imp = zeros(N_start_seq_imp, N_points, N_chains)
start_seq_imp = [imp_seqs[:, i] for i in 1:N_start_seq_imp];

# NOTE(review): `run_evolution` receives NN_points while `hams_all_imp` is sized with
# N_points = length(steps); confirm that `ham_dist(res.step_msa)` has N_points rows.
@time for n in 1:N_start_seq_imp
    # All chains start from the same sequence: N_chains identical columns.
    res = run_evolution(
        Int8.(repeat(start_seq_imp[n], 1, N_chains)),
        h,
        J;
        p=0.5,
        temp=1.0,
        N_points=NN_points,
        N_steps=N_steps,
    )
    println("Initial seq $(n)")
    push!(res_all_imp, res)
    hams_all_imp[n, :, :] .= ham_dist(res.step_msa)
end

# For every start sequence and every saved point: mean and variance of
# `ham_dist(first saved MSA, MSA at that point)`.
hams_single_imp = zeros(N_start_seq_imp, N_points)
chi_dyn_single_imp = zeros(N_start_seq_imp, N_points);
for i in 1:N_start_seq_imp
    # `local` keeps these helpers from overwriting same-named globals in a notebook.
    local saved_msas = res_all_imp[i].step_msa
    for n in 1:N_points
        # Evaluate the distances once and reuse them for both statistics.
        local dists = ham_dist(saved_msas[1], saved_msas[n])
        hams_single_imp[i, n] = mean(dists)
        chi_dyn_single_imp[i, n] = var(dists)
    end
end

# Average of `hams_all_imp` over its third dimension, as a 2-D matrix.
hams_single_all_imp = mean(hams_all_imp, dims=3)[:, :, 1];



# One curve per start sequence: `chi_dyn_single_imp` divided by L^2 against the step
# grid, on a logarithmic x axis; saved to "../beta_single_wt_chi_dyn_imp.png".
close("all")

for chi_row in eachrow(chi_dyn_single_imp)
    plt.plot(steps, chi_row ./ (L^2); linewidth=4.0)  # color = cc[i]
end

#plt.legend()
plt.xlabel("MCMC steps")
plt.xscale("log")
plt.ylabel("chi_dyn_A")
plt.savefig("../beta_single_wt_chi_dyn_imp.png")



# One curve per start sequence: `hams_single_all_imp` divided by L against the step
# grid, on a logarithmic x axis; saved to "../beta_all_single_wt_mean_ham_dist_imp.png".
close("all")
plt.plot()

for ham_row in eachrow(hams_single_all_imp)
    plt.plot(steps, ham_row ./ L; linewidth=4.0)  # color = cc[i]
end

#plt.legend()
plt.xlabel("MCMC steps")
plt.xscale("log")
plt.ylabel("[H_A]")
plt.savefig("../beta_all_single_wt_mean_ham_dist_imp.png")

# Save the selected columns of `nat_msa` (transposed, so each selected column becomes a
# row) as FASTA, and write the matching headers to a text file.
matrix2fasta("../betalact_seqs_diff_CDE.fa", Int8.(imp_seqs'))
writedlm("../betalact_names_seqs_diff_cde.txt", names)



