In [23]:
from specutils import Spectrum
import astropy.units as u
from specutils.analysis import template_correlate
import pandas as pd
import xarray as xr
from specutils.fitting import fit_generic_continuum
from specutils.manipulation import FluxConservingResampler
import numpy as np

In [None]:
def trim_wavelength(ds, wmin, wmax):
    return ds.sel(wavelength=(ds.wavelength >= wmin) & (ds.wavelength <= wmax))

# Extract observed spectrum at time index i
def xarray_to_spectrum(xds, i):
    wave = xds.wavelength.values * u.AA
    flux = xds.flux.isel(time=i).values * u.Unit("erg cm-2 s-1 AA-1")  # Adjust unit as needed
    return Spectrum(spectral_axis=wave, flux=flux)

def cross_correlate_spectrum(data_spec, model_spec):
    result = template_correlate(data_spec, model_spec)
    return result
def normalize_with_continuum(spectrum):
    cont_model = fit_generic_continuum(spectrum)
    continuum = cont_model(spectrum.spectral_axis)
    normalized_flux = spectrum.flux / continuum
    return Spectrum(spectral_axis=spectrum.spectral_axis, flux=normalized_flux)

def resample_model(model_spec, target_spec):
    resampler = FluxConservingResampler()
    return resampler(model_spec, target_spec.spectral_axis)

def trim_model_to_spectrum(model_cube, spectrum, margin=20):
    wmin = spectrum.spectral_axis.min().to_value('angstrom') - margin
    wmax = spectrum.spectral_axis.max().to_value('angstrom') + margin
    return model_cube.sel(wavelength=(model_cube.wavelength >= wmin) & (model_cube.wavelength <= wmax))

def correlate_model_grid(model_cube, model_name, observed_spectrum):
    results = []
    trimmed_model = model_cube
    trimmed_model = trim_model_to_spectrum(model_cube, observed_spectrum, margin=20)
    observed_spectrum = normalize_with_continuum(observed_spectrum)

    for teff in trimmed_model.temperature.values:
        for logg in trimmed_model.gravity.values:
            model_flux = trimmed_model["flux"].sel(temperature=teff, gravity=logg)
            model_spec = Spectrum(
                spectral_axis=trimmed_model.wavelength.values * observed_spectrum.spectral_axis.unit,
                flux=model_flux.values * observed_spectrum.flux.unit
            )
            #model_spec = normalize_with_continuum(model_spec)
            model_spec = resample_model(model_spec, observed_spectrum)

            lags, corr = template_correlate(observed_spectrum, model_spec)
            corr /= np.max(np.abs(corr))

            max_corr = np.max(corr)
            best_lag = lags[np.argmax(corr)]
            lambda0 = observed_spectrum.spectral_axis.mean().to_value('angstrom')
            velocity_shift = (best_lag / lambda0) * 3e5  # km/s

            results.append({
                "model": model_name,
                "teff": teff,
                "logg": logg,
                "correlation": max_corr,
                "velocity_shift_kms": velocity_shift
            })
    return results

In [35]:
# Load the actual datasets from file
red_spectra = xr.open_dataset("red_spectra_again!.h5")
blue_spectra = xr.open_dataset("blue_spectra_fixed.h5")
model_A = xr.open_dataset("Flux 3D Models/flux_model_3D_A.h5")
model_B = xr.open_dataset("Flux 3D Models/flux_model_3D_B.h5")

red_spectra = trim_wavelength(red_spectra, 6000, 9000)
blue_spectra = trim_wavelength(blue_spectra, 3500, 5500)

In [36]:
all_fits = []

for i in range(1, 2):
    obs_spec_non = xarray_to_spectrum(blue_spectra, i)
    obs_spec = normalize_with_continuum(obs_spec_non)

    results_A = correlate_model_grid(model_A, "A", obs_spec)
    results_B = correlate_model_grid(model_B, "B", obs_spec)
    
    combined = pd.DataFrame(results_A + results_B)
    best = combined.loc[combined['correlation'].idxmax()].copy()
    best['time_index'] = i

    
    all_fits.append(best)

df_all_fits = pd.DataFrame(all_fits)

best_fit = df_all_fits.loc[df_all_fits['correlation'].idxmax()]

print(f"Best fit (time {best['time_index']}):")
print(f"Model: {best['model']}, Teff: {best['teff']}, logg: {best['logg'] / 100:.2f}")
print(f"Correlation: {best['correlation']:.2f}, Velocity Shift: {best['velocity_shift_kms']:.2f} km/s")



Best fit (time 1):
Model: A, Teff: 5000.0, logg: 5.00
Correlation: 1.00, Velocity Shift: 0.00 km/s
