# In [None]:
# Point PyCall at its own Conda-managed Python (empty PYTHON) and rebuild it once,
# so that PyPlot below uses that interpreter.
# NOTE(review): the rebuild only needs to happen once per environment; if PyCall was
# already loaded in this Julia session, a restart is needed for the change to apply.
ENV["PYTHON"] = ""
using Pkg
Pkg.build("PyCall")  # Rebuild PyCall to use the internal Python
using Revise, Genie, DelimitedFiles, DCAUtils, JLD2, PyPlot

# Load the pre-trained DBD model parameters. The analysis cell below reads `h_dbd`
# and `J_dbd` — NOTE(review): confirm those are the variable names stored in this file.
@load "../data_Genie/pars_dbd.jld2"


# In [None]:
using PyPlot, Statistics, LinearAlgebra

# ---------- PARAMETERS ----------
# (d, q) pairs substituted into the `_d$(d)_qq$(q)_` part of the file names;
# append more tuples to analyse further models.
models = Tuple{Int,Int}[(8, 4), (8, 8), (16, 4)]
# Values of N substituted into the `_gmm$(N)` part of the file names: 2, 4, 8, 16.
N_comp_values = [2^k for k in 1:4]

# Directories holding the BASE and STE FASTA alignments (same file names in each).
base_dir = "../data_msa_autoencoder/debug/"
ste_dir  = "../data_msa_autoencoder/debug_STE/"

# The pre-trained fields/couplings `h_dbd` and `J_dbd` are not defined in this cell;
# they must already exist in the session (the setup cell's `@load` is expected to provide them).

# ---------- HELPER ----------
"""
    compute_energy_from_fasta(path, h, J; max_gap_fraction=0.9)

Read the FASTA alignment at `path` with `read_fasta_alignment` and return
`energy(aln, h, J)` evaluated on it.

`max_gap_fraction` is forwarded to `read_fasta_alignment` (default `0.9`, the value
previously hard-coded). NOTE(review): per DCAUtils' documentation, sequences whose gap
fraction exceeds this threshold are discarded, so the result may cover fewer sequences
than the file contains — relevant when pairing energies from two files by index.
"""
function compute_energy_from_fasta(path::AbstractString, h, J;
                                   max_gap_fraction::Real=0.9)
    aln = read_fasta_alignment(path, max_gap_fraction)
    return energy(aln, h, J)
end

"""
    pearson(x, y)

Return the Pearson correlation coefficient between the real-valued vectors `x` and `y`
(mathematically the same as `Statistics.cor(x, y)`).

Accepts any `AbstractVector{<:Real}` (e.g. `Float32` energies, integers, views).
Returns `NaN` when either input has zero variance (including length 0 or 1).

# Throws
- `DimensionMismatch` if `x` and `y` have different lengths.
"""
function pearson(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    length(x) == length(y) || throw(DimensionMismatch(
        "pearson: inputs have lengths $(length(x)) and $(length(y))"))
    dx = x .- mean(x)
    dy = y .- mean(y)
    num = sum(dx .* dy)
    den = sqrt(sum(abs2, dx) * sum(abs2, dy))
    return num / den
end

# ---------- ANALYSIS ----------
# For each (d, q) model, draw one figure with one column per N: the top row overlays
# the BASE and STE energy histograms, the bottom row scatters STE vs BASE energies
# index-by-index together with a y = x line and their Pearson correlation.
for (d, q) in models
    figure(figsize=(12, 6))
    ncols = length(N_comp_values)

    for (i, N) in enumerate(N_comp_values)
        # Both variants use the same file name; only the directory differs.
        fname = "DBD_shuffle_d$(d)_qq$(q)_gmm$(N).00.fasta"
        base_path = base_dir * fname
        ste_path  = ste_dir * fname

        # compute energies
        E_base = compute_energy_from_fasta(base_path, h_dbd, J_dbd)
        E_ste  = compute_energy_from_fasta(ste_path,  h_dbd, J_dbd)

        # The scatter and correlation pair energies by index, so both files must yield
        # the same number of sequences; fail here with a clear message rather than
        # deep inside matplotlib.
        length(E_base) == length(E_ste) || throw(DimensionMismatch(
            "d=$d, q=$q, N=$N: $(length(E_base)) BASE vs $(length(E_ste)) STE energies"))

        # Common range of both variants, used for the shared bins and the y = x line.
        lo, hi = extrema(vcat(E_base, E_ste))
        # Shared bin edges keep the two histograms directly comparable; fall back to a
        # plain bin count when all energies coincide (edges must be increasing).
        bins = lo < hi ? collect(range(lo, hi; length=31)) : 30

        # ---- Histograms ----
        subplot(2, ncols, i)
        hist(E_base, bins=bins, alpha=0.5, label="BASE")
        hist(E_ste, bins=bins, alpha=0.5, label="STE")
        title("d=$d, q=$q, N=$N")
        legend()

        # ---- Scatter with Pearson ----
        subplot(2, ncols, ncols + i)
        scatter(E_base, E_ste, alpha=0.5, s=10)
        r = pearson(E_base, E_ste)
        plot([lo, hi], [lo, hi], "k--", lw=1)
        xlabel("BASE energy")
        ylabel("STE energy")
        title("ρ = $(round(r, digits=3))")
    end

    tight_layout()
    savefig("energy_comparison_d$(d)_q$(q).png", dpi=200)
end
