# In [None]:
using DataFrames
using CSV
using ScatteredInterpolation
using CairoMakie

"""
    IsochroneInterpolator{T<:Real}

Pair an isochrone table with the grid-node coordinates used to bracket interpolation queries.

# Fields
- `data::DataFrame`: the isochrone rows. `find_nearest_isochrones` selects rows from it
  by their `logAge` and `MH` column values.
- `coords::Dict{Symbol,Vector{T}}`: grid node values keyed by coordinate name.
  `find_nearest_isochrones` reads the `:logAge` and `:MH` entries. These vectors are
  searched with `searchsortedfirst`, so they are expected to be sorted ascending.
"""
struct IsochroneInterpolator{T<:Real}
    data::DataFrame                 # isochrone table
    coords::Dict{Symbol,Vector{T}}  # coordinate name => sorted grid node values
end

"""
    find_missing_columns(df::DataFrame)

Return the names (as `String`s, in `names(df)` order) of every column of `df` that
contains at least one `missing` value.
"""
function find_missing_columns(df::DataFrame)
    # `any(ismissing, col)` stops at the first hit and allocates nothing, unlike
    # `any(ismissing.(col))`, which first materializes a full BitVector per column.
    return [col for col in names(df) if any(ismissing, df[!, col])]
end

"""
    remove_missing_columns!(df::DataFrame)

Drop from `df`, in place, every column that `find_missing_columns` reports as holding
at least one `missing` value. Return the same (mutated) `df`.
"""
function remove_missing_columns!(df::DataFrame)
    to_drop = find_missing_columns(df)
    # `select!` mutates `df` and hands the same object back, so it doubles as the return value.
    return select!(df, Not(to_drop))
end

"""
    load_isochrone_data(filepath::AbstractString)

Read a whitespace-delimited table from `filepath` into a `DataFrame`. Runs of spaces
count as a single delimiter, and lines starting with `#` are skipped. Every column
that contains at least one `missing` value is then dropped.
"""
function load_isochrone_data(filepath::AbstractString)
    # CSV.jl documents `comment` as a String prefix, so pass "#" rather than the Char '#'.
    df = DataFrame(CSV.File(filepath; delim=' ', ignorerepeated=true, comment="#"))
    remove_missing_columns!(df)
    return df
end

"""
    add_evolution_phase!(df::DataFrame)

Add (or replace) a `Float64` column `evol` in `df` and return `df`.

For each distinct value `lbl` of `df.label`, the rows carrying that label receive, in
row order, the values `lbl .+ range(0, 1 - 1/n, length=n)`. Here `n` is the number of
such rows, so each label's rows are spread evenly over `[lbl, lbl + 1)`.
"""
function add_evolution_phase!(df::DataFrame)
    df.evol = zeros(nrow(df))
    phase = df.evol   # alias of the new column: writes below land directly in `df`
    labels = df.label
    for lbl in unique(labels)
        mask = labels .== lbl
        npts = sum(mask)
        phase[mask] = lbl .+ range(0, 1 - 1/npts; length=npts)
    end
    return df
end

"""
    find_nearest_isochrones(interp, log_age, metal)

Return the rows of `interp.data` that belong to the grid isochrones bracketing the query.

The two `logAge` nodes around `log_age` and the two `MH` nodes around `metal` are chosen
by `_bracket` from `interp.coords`. A row is kept when its `logAge` is one of the age
nodes and its `MH` is one of the metallicity nodes, giving at most four isochrones.
"""
function find_nearest_isochrones(interp::IsochroneInterpolator{T}, log_age::T, metal::T) where T<:Real
    age_nodes = _bracket(interp.coords[:logAge], log_age)
    metal_nodes = _bracket(interp.coords[:MH], metal)
    # `Ref` makes each node vector act as a single collection for `in`. Without it,
    # `col .∈ nodes` broadcasts the 2-element node vector element-wise against the
    # column and throws a DimensionMismatch for any table whose length is not 2.
    keep = in.(interp.data.logAge, Ref(age_nodes)) .& in.(interp.data.MH, Ref(metal_nodes))
    return interp.data[keep, :]
end

"""
    _bracket(sorted_seq, x)

Return a 2-element vector `[lo, hi]` of neighbouring nodes of the ascending vector
`sorted_seq` around `x`.

- If `x ≤ sorted_seq[1]`, both entries are the first node.
- If `x > sorted_seq[end]`, both entries are the last node.
- Otherwise the result is `[sorted_seq[i-1], sorted_seq[i]]`, where `i` is the first
  index with `sorted_seq[i] ≥ x`.
"""
function _bracket(sorted_seq::Vector{T}, x::T) where T<:Real
    n = length(sorted_seq)
    pos = searchsortedfirst(sorted_seq, x)
    # Clamping both neighbour indices into 1:n collapses out-of-range queries onto the edge node.
    lower = clamp(pos - 1, 1, n)
    upper = clamp(pos, 1, n)
    return [sorted_seq[lower], sorted_seq[upper]]
end

"""
    interpolate_isochrone(family::Symbol, filter::AbstractString, age, metal)

Load the isochrone grid for `family`/`filter` and interpolate an isochrone at `age`
and metallicity `metal`. `age` is converted with `log10`, presumably from years;
confirm against the grid. `metal` is [M/H].

Only `family == :parsec` with `filter == "hsc"` is implemented.

# Throws
- `ArgumentError` if `family` is not `:parsec`, if `log10(age)` is outside `[5, 10.3]`,
  or if `metal` is outside `(-2.2, 0.5]`.
- `ErrorException` if `filter` is not implemented.
"""
function interpolate_isochrone(family::Symbol, filter::AbstractString, age::T, metal::T) where T<:Real
    # Input validation uses explicit throws: `@assert` is not meant for argument checks.
    family == :parsec || throw(ArgumentError("Only Parsec isochrones accepted"))
    log_age = log10(age)
    5 ≤ log_age ≤ 10.3 || throw(ArgumentError("Age range: 5 ≤ log10(age) ≤ 10.3"))
    -2.2 < metal ≤ 0.5 || throw(ArgumentError("Metallicity range: -2.2 < [M/H] ≤ 0.5"))

    if filter == "hsc"
        # NOTE(review): the file name suggests this grid only spans logAge 9.2–10.3.
        # Queries outside the grid are not rejected; the bracketing step clamps them
        # onto the edge isochrone.
        file_path = "artifacts/isochrones/parsec/$filter/family_MH_-2.2_0.5_logAge_9.2_10.3.dat"
        df = load_isochrone_data(file_path)
        # Node lookup uses `searchsortedfirst`, so the nodes must be sorted ascending.
        # `unique` alone keeps first-appearance order.
        coords = Dict(
            :logAge => sort!(unique(df.logAge)),
            :MH => sort!(unique(df.MH)),
        )
        interp = IsochroneInterpolator(df, coords)
        return interpolate_isochrone(interp, age, metal)
    end
    error("Filter $filter not implemented")
end

"""
    interpolate_isochrone(interp::IsochroneInterpolator, age, metal)

Interpolate an isochrone at `log10(age)` and `metal` from the grid isochrones in
`interp` that bracket the query.

Each bracketing isochrone gets its own evolutionary coordinate `evol` via
`add_evolution_phase!`. Every non-coordinate column is then interpolated over
`(logAge, MH, evol)` and evaluated at 10 000 evenly spaced `evol` values from 0 to
`maximum(label) + 0.99`.

Return a `DataFrame` with columns `evol`, `logAge`, `MH` and `label` (= `floor(evol)`),
plus one interpolated column for every other column of the input table.
"""
function interpolate_isochrone(interp::IsochroneInterpolator{T}, age::T, metal::T) where T<:Real
    log_age = log10(age)
    nearest = find_nearest_isochrones(interp, log_age, metal)

    # Assign `evol` separately for each bracketing isochrone. Running
    # add_evolution_phase! on the pooled rows numbers a label's points across all
    # isochrones at once. That squeezes each isochrone into a fraction of
    # [label, label + 1) and misaligns phases between neighbouring isochrones.
    nearest = reduce(vcat, [add_evolution_phase!(DataFrame(iso))
                            for iso in groupby(nearest, [:logAge, :MH])])

    points = Matrix(nearest[:, [:logAge, :MH, :evol]])
    # `names` returns Strings, so the excluded columns must be Strings as well.
    # With Symbols nothing was excluded, and logAge/MH/label/evol in `result` were
    # overwritten with interpolated values.
    mag_cols = setdiff(names(nearest), ["logAge", "MH", "label", "evol"])

    n_samples = 10_000
    evol_points = range(0, maximum(nearest.label) + 0.99, length=n_samples)
    targets = hcat(
        fill(log_age, n_samples),
        fill(metal, n_samples),
        collect(evol_points),
    )

    result = DataFrame(
        evol = evol_points,
        logAge = fill(log_age, n_samples),
        MH = fill(metal, n_samples),
        label = floor.(Int, evol_points),
    )

    # NOTE(review): confirm `Linear()` resolves to a ScatteredInterpolation method. The
    # package documents RBF kernels (e.g. Multiquadratic, ThinPlate), Shepard and
    # NearestNeighbor.
    for col in mag_cols
        # ScatteredInterpolation expects one point per column, hence the transposes.
        itp = interpolate(Linear(), points', nearest[:, col])
        result[!, col] = evaluate(itp, targets')
    end

    return dropmissing(result)
end

"""
    plot_isochrone_cmd(df::DataFrame; filter=:gaia, filename=nothing)

Draw a colour–magnitude diagram of `df` and return the `Figure`.

- With `filter == :gaia`, `BP_RP` is plotted against `Gmag`; any other `filter` plots
  `g_r` against `gmag`.
- The magnitude axis is reversed.
- The rows with `label` 0 through 4 are each drawn as a separately coloured line
  labelled "Phase k", and a legend is added.
- When `filename` is not `nothing`, the figure is also saved to that path.
"""
function plot_isochrone_cmd(df::DataFrame; filter=:gaia, filename=nothing)
    if filter == :gaia
        x_col, y_col = :BP_RP, :Gmag
    else
        x_col, y_col = :g_r, :gmag
    end

    fig = Figure()
    ax = Axis(fig[1, 1]; xlabel=string(x_col), ylabel=string(y_col), yreversed=true)

    phase_colors = (:blue, :cyan, :red, :magenta, :orange)
    for (phase, color) in zip(0:4, phase_colors)
        rows = df.label .== phase
        lines!(ax, df[rows, x_col], df[rows, y_col]; color=color, label="Phase $phase")
    end

    axislegend(ax)
    filename === nothing || save(filename, fig)
    return fig
end

# Example usage
"""
    main()

Example driver. Interpolate a PARSEC HSC isochrone at `age = 1e9` and `[M/H] = -0.5`,
plot its colour–magnitude diagram, save the plot to `isochrone_plot.png`, and return
`(iso, fig)`.
"""
function main()
    # Interpolate isochrone
    iso = interpolate_isochrone(:parsec, "hsc", 1e9, -0.5)

    # The isochrone comes from the HSC grid, so plot with the non-Gaia columns
    # (g_r / gmag). `filter=:gaia` would look for Gaia columns (BP_RP / Gmag).
    # NOTE(review): assumes the HSC table provides `g_r` and `gmag` — confirm.
    fig = plot_isochrone_cmd(iso; filter=:hsc)
    save("isochrone_plot.png", fig)
    return iso, fig
end