# Implement and test iterative extinction correction for Gaia DR3

In [None]:

def compute_km(A0, X, coeffs):
    """
    Evaluate the extinction-coefficient polynomial k_m(A0, X) for one band.

    The polynomial is cubic in each variable plus three mixed terms:

        k_m = a1 + a2*X + a3*X^2 + a4*X^3
                 + a5*A0 + a6*A0^2 + a7*A0^3
                 + a8*A0*X + a9*A0*X^2 + a10*X*A0^2

    where a1..a10 are read from ``coeffs`` under the keys "Intercept", "X",
    "X2", "X3", "A", "A2", "A3", "XA", "AX2" and "XA2" respectively.

    Parameters:
        A0 (float): Extinction at 550 nm.
        X (float): Colour or effective-temperature variable the coefficients
            are expressed in.
        coeffs (dict): The ten polynomial coefficients for one band / X pair.

    Returns:
        float: The extinction coefficient k_m.
    """
    c = coeffs
    km = c["Intercept"]
    # Accumulate left to right in formula order, so the floating-point
    # rounding is exactly that of the written-out sum.
    for term in (
        c["X"] * X,
        c["X2"] * X**2,
        c["X3"] * X**3,
        c["A"] * A0,
        c["A2"] * A0**2,
        c["A3"] * A0**3,
        c["XA"] * A0 * X,
        c["AX2"] * A0 * X**2,
        c["XA2"] * X * A0**2,
    ):
        km = km + term
    return km

def iterative_extinction_correction(X, ebv, band_coeffs, max_iter=10, tol=1e-6,
                                    bp_coeffs=None, rp_coeffs=None):
    """
    Compute the extinction A_m in one band as A_m = k_m(A0, X) * A0.

    A0 = 3.1 * E(B-V) is the extinction at 550 nm. It is an *input* of the
    coefficient polynomial and stays fixed.

    The coefficients are parametrised in the intrinsic colour. If only the
    observed BP-RP colour is known, pass ``bp_coeffs`` and ``rp_coeffs``
    (the BP and RP coefficient sets for the same X variable). X is then
    treated as the observed colour and de-reddened iteratively,
    X <- X_obs - (A_BP - A_RP), until it changes by less than ``tol`` or
    ``max_iter`` passes have run. A_m is evaluated at the resulting X.

    Parameters:
        X (float): Colour or temperature variable the coefficients are
            expressed in. It is taken as-is (assumed intrinsic) unless
            ``bp_coeffs``/``rp_coeffs`` are given, in which case it must be
            the observed BP-RP colour.
        ebv (float): Reddening value E(B-V).
        band_coeffs (dict): Coefficients for the target band and X.
        max_iter (int): Maximum colour-iteration passes. Only used when
            ``bp_coeffs``/``rp_coeffs`` are given.
        tol (float): Convergence tolerance on X. Only used when
            ``bp_coeffs``/``rp_coeffs`` are given.
        bp_coeffs (dict | None): BP-band coefficients used to de-redden X.
        rp_coeffs (dict | None): RP-band coefficients used to de-redden X.

    Returns:
        float: Extinction A_m in the band described by ``band_coeffs``.

    Raises:
        ValueError: If only one of ``bp_coeffs`` / ``rp_coeffs`` is given.
    """
    if (bp_coeffs is None) != (rp_coeffs is None):
        raise ValueError("bp_coeffs and rp_coeffs must be given together")

    # A0 is an argument of k_m, not the quantity being solved for, so it is
    # never overwritten by k_m * A0.
    A0 = 3.1 * ebv

    if bp_coeffs is not None:
        X_obs = X
        for _ in range(max_iter):
            A_bp = compute_km(A0, X, bp_coeffs) * A0
            A_rp = compute_km(A0, X, rp_coeffs) * A0
            new_X = X_obs - (A_bp - A_rp)
            converged = abs(new_X - X) < tol
            X = new_X
            if converged:
                break

    return compute_km(A0, X, band_coeffs) * A0

def extinction_correction_iterative(catalog, coeffs, max_iter=10, tol=1e-6):
    """
    Compute the extinction correction in the Gaia G, BP and RP passbands for
    every star in the catalog, iteratively de-reddening the BP-RP colour that
    the extinction coefficients depend on.

    For each star:
      1. A0 = 3.1 * E(B-V) is fixed, with E(B-V) taken from the SFD map at
         the star's position.
      2. The coefficients k_m(A0, X) are parametrised in the intrinsic colour
         X = (BP-RP)_0, but only the observed colour is available. Starting
         from X = (BP-RP)_obs, each pass computes A_BP and A_RP and sets
         X = (BP-RP)_obs - (A_BP - A_RP).
      3. Iteration stops once X changes by less than ``tol`` for every star
         (stars with NaN inputs are ignored for this test), or after
         ``max_iter`` passes.
      4. A_G, A_BP and A_RP are then evaluated at the final X as
         A_m = k_m(A0, X) * A0.

    Parameters:
        catalog (pd.DataFrame): Gaia catalog with columns 'ra', 'dec',
            'phot_g_mean_mag', 'phot_bp_mean_mag' and 'phot_rp_mean_mag'.
            Any index is supported; rows are matched to E(B-V) by position.
        coeffs (dict): Extinction coefficients indexed as
            coeffs[band][xname], with bands "kG", "kBP", "kRP" and
            xname "BPRP" used here.
        max_iter (int): Maximum number of colour-iteration passes.
        tol (float): Convergence tolerance on the intrinsic BP-RP colour.

    Returns:
        pd.DataFrame: The same ``catalog`` object, modified in place, with
        added columns 'A_G', 'A_bp', 'A_rp', 'G_corr',
        'phot_bp_mean_mag_corr', 'phot_rp_mean_mag_corr' and 'bp_rp_corr'.
    """
    import numpy as np

    # Define SkyCoord object
    coords = SkyCoord(ra=catalog['ra'], dec=catalog['dec'], unit=(u.degree, u.degree), frame='icrs')

    # Retrieve E(B-V) from SFD as a plain positional array aligned with the
    # catalog rows. The catalog's index labels are never used to look it up,
    # so filtered or re-indexed catalogs stay correctly matched.
    ebv = np.asarray(sfd(coords), dtype=float)
    A0 = 3.1 * ebv

    k_g = coeffs["kG"]["BPRP"]
    k_bp = coeffs["kBP"]["BPRP"]
    k_rp = coeffs["kRP"]["BPRP"]

    # Observed (reddened) colour. Missing photometry becomes NaN and simply
    # propagates to NaN extinctions.
    bp_rp_obs = (catalog['phot_bp_mean_mag'] - catalog['phot_rp_mean_mag']).to_numpy(
        dtype=float, na_value=np.nan
    )

    # Fixed-point iteration for the intrinsic colour, vectorised over all stars.
    X = bp_rp_obs
    for _ in range(max_iter):
        A_bp = compute_km(A0, X, k_bp) * A0
        A_rp = compute_km(A0, X, k_rp) * A0
        X_new = bp_rp_obs - (A_bp - A_rp)
        delta = np.abs(X_new - X)
        X = X_new
        # NaN deltas can never fall below tol; exclude them so they don't
        # force every remaining pass.
        if np.all(np.isnan(delta) | (delta < tol)):
            break

    # Final extinctions, all evaluated at the same converged intrinsic colour.
    catalog['A_G'] = compute_km(A0, X, k_g) * A0
    catalog['A_bp'] = compute_km(A0, X, k_bp) * A0
    catalog['A_rp'] = compute_km(A0, X, k_rp) * A0
    catalog['G_corr'] = catalog['phot_g_mean_mag'] - catalog['A_G']
    catalog['phot_bp_mean_mag_corr'] = catalog['phot_bp_mean_mag'] - catalog['A_bp']
    catalog['phot_rp_mean_mag_corr'] = catalog['phot_rp_mean_mag'] - catalog['A_rp']
    catalog['bp_rp_corr'] = catalog['phot_bp_mean_mag_corr'] - catalog['phot_rp_mean_mag_corr']

    return catalog

# Load the coefficients from the CSV file
def load_coefficients(filepath):
    """
    Read the extinction-coefficient table at ``filepath`` into a nested dict.

    Each CSV row describes one (band, X-variable) pair:
      - the ``Kname`` column names the band (e.g. "kG");
      - the ``Xname`` column names the variable the polynomial is written in
        (e.g. "BPRP");
      - the remaining columns hold the ten polynomial coefficients.

    The result is indexed as ``coefficients[Kname][Xname][term]``, where
    ``term`` is one of the polynomial coefficient names used by
    ``compute_km``. A later row with the same (Kname, Xname) pair replaces an
    earlier one.

    Parameters:
        filepath (str): Path to the CSV file containing extinction
            coefficients (anything ``pd.read_csv`` accepts).

    Returns:
        dict: Structured coefficients for extinction correction.
    """
    term_names = ("Intercept", "X", "X2", "X3", "A", "A2", "A3", "XA", "AX2", "XA2")

    table = pd.read_csv(filepath)
    result = {}
    for _, record in table.iterrows():
        per_band = result.setdefault(record["Kname"], {})
        x_name = record["Xname"]
        per_band[x_name] = {name: record[name] for name in term_names}
    return result