In [None]:
from jax import config
#config.update("jax_platform_name", "cpu")
config.update("jax_enable_x64", True)

import jax 
from jax import jit,vmap
import jax.numpy as jnp
import glob 
import os 
import pandas as pd 
import matplotlib.pyplot as plt 
from typing import Callable
from jax import value_and_grad
from pathlib import Path

In [None]:
import sheap
from sheap.spectra_readers import parallel_reader 
from sheap.MainSheap import Sheapectral 
from sheap.RegionHandler.RegionBuilder import RegionBuilder
from sheap.RegionFitting.RegionFitting import RegionFitting
from sheap.Plotting.SheapPlot import SheapPlot
from sheap.Posterior.ParameterEstimation import ParameterEstimation
from sheap.Tools.paths_func import cross_pandas_spectra
from sheap.HostSubtraction import HostSubtraction
from sheap.FunctionsMinimize.utils import combine_auto

In [None]:
from errors import batch_error_covariance_in_chunks

In [None]:
%load_ext autoreload
%autoreload 2

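### A: Fit the bundled example spectra (region 6000–8000, one broad component) and estimate parameter uncertainties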

In [None]:
spectrum_dic = Path(sheap.__file__).resolve().parent / "SuportData" / "Spectrum"
files = glob.glob(f"{spectrum_dic}/*")


In [None]:
coords,spectra,_ = parallel_reader(files) 
z = [0.161769,0.184366]

In [None]:
# Wrap spectra, redshifts and sky coordinates in the main sheap object.
sheapspectral = Sheapectral(spectra,z=z,coords=coords)

In [None]:
# Define the fitting region between 6000 and 8000 with one broad component.
sheapspectral.build_region(6000,8000,n_broad=1)

In [None]:
# Fit the region; num_steps_list presumably gives optimiser steps per fitting stage — TODO confirm.
sheapspectral.fit_region(num_steps_list=[1000,1000])

In [None]:
# Combine the fitted profile functions into one callable; later cells call it as model(xs, params).
model = combine_auto(sheapspectral.profile_functions)

In [None]:
def residuals(func: Callable, params: jnp.ndarray, xs: jnp.ndarray, y: jnp.ndarray, y_uncertainties: jnp.ndarray) -> jnp.ndarray:
    """Return |y - func(xs, params)| divided element-wise by ``y_uncertainties``.

    The model is evaluated as ``func(xs, params)``. Because of the absolute value,
    the result carries no sign information.
    """
    misfit = y - func(xs, params)
    return jnp.abs(misfit) / y_uncertainties

In [None]:
from sheap.RegionFitting.uncertainty_functions import error_covariance_matrix_single

def error_for_loop(model, params, spectra, free_params=1):
    """Per-spectrum parameter uncertainties, computed one spectrum at a time.

    Loops in Python (no vmap) over the spectra. For each one it calls the imported
    ``error_covariance_matrix_single``.

    Parameters
    ----------
    model : callable
        Model passed through to ``error_covariance_matrix_single``.
    params : array
        Fitted parameters, one row per spectrum.
    spectra : array
        ``x``, ``y`` and ``error`` stacked along axis 1. Presumably the shape is
        (n_spectra, 3, n_pix) — TODO confirm.
    free_params : int, default 1
        Forwarded to ``error_covariance_matrix_single``. This was previously
        hard-coded to 1.

    Returns
    -------
    array
        One row of uncertainties per spectrum, with the same dtype as ``params``.
        An empty ``params`` gives an empty result.
    """
    x, y, error = jnp.moveaxis(spectra, 0, 1)
    # Collect the rows and stack once; `.at[n].set` inside the loop would copy the
    # whole array on every iteration.
    rows = [
        error_covariance_matrix_single(model, params_i, x_i, y_i, error_i, free_params=free_params)
        for params_i, x_i, y_i, error_i in zip(params, x, y, error)
    ]
    if not rows:
        return jnp.zeros_like(params)
    return jnp.stack(rows).astype(params.dtype)

In [None]:
errors_sigmas

In [None]:
errors_sigmas = error_for_loop(model,sheapspectral.params,sheapspectral.spectra)

In [None]:
errors_sigmas

In [None]:
yerr_i =sheapspectral.mask[0],1e31,yerr_i

In [None]:
params_i = sheapspectral.params[0]
xs_i,y_i,yerr_i = sheapspectral.spectra[0]

In [None]:
# Normalised absolute residuals |y - model| / yerr for the first spectrum.
residual =  residuals(model, params_i, xs_i, y_i, yerr_i)
#jnp.where(sheapspectral.mask[0],0.0001,residuals(model, params_i, xs_i, y_i, yerr_i))

In [None]:
residual

In [None]:
# Residuals as a function of the parameters only, so jax.jacobian can differentiate them.
jac_fn = lambda p: residuals(model, p, xs_i, y_i, yerr_i)

In [None]:
# Gauss-Newton covariance: inv(J^T J + 1e-6 I) scaled by chi-square / dof, where dof is
# the number of residuals minus 0 (no parameter count is subtracted here). The 1e-6
# ridge keeps the matrix invertible when J^T J is singular.
jacobian = jax.jacobian(jac_fn)(params_i)
JTJ = jacobian.T @ jacobian
dof = residual.shape[0] - 0
s_sq = jnp.sum(residual ** 2) / dof
cov = jnp.linalg.inv(JTJ + 1e-6 * jnp.eye(params_i.shape[0])) * s_sq
# 1-sigma parameter uncertainties.
errors = jnp.sqrt(jnp.diag(cov))

In [None]:
errors

In [None]:
jacobian.shape

In [None]:



# Same Gauss-Newton covariance as the cell above, with explicit degrees-of-freedom
# bookkeeping. `func` and `free_params` are not defined at notebook level (NameError
# on a fresh kernel). Use the combined `model`, and subtract no parameters, matching
# the `- 0` used above.
free_params = 0
jac_fn = lambda p: residuals(model, p, xs_i, y_i, yerr_i)
jacobian = jax.jacobian(jac_fn)(params_i)
JTJ = jacobian.T @ jacobian
dof = residual.shape[0] - free_params
s_sq = jnp.sum(residual ** 2) / dof
cov = jnp.linalg.inv(JTJ + 1e-6 * jnp.eye(params_i.shape[0])) * s_sq

### B: Reload a saved model (`model_mode.pkl`), refit, and derive physical parameters and uncertainties

In [None]:
# Restore a previously saved Sheapectral instance.
# NOTE(review): presumably this unpickles the file — only load pickles you created yourself.
sheapclass = Sheapectral.from_pickle("model_mode.pkl")

In [None]:
#sheapclass.complex_region_class

In [None]:
# Redefine the region between 6000 and 8000, now with two broad components.
sheapclass.build_region(6000,8000,n_broad=2)

In [None]:
# Short refit (100 steps per entry of num_steps_list).
sheapclass.fit_region(num_steps_list=[100,100])

In [None]:
parameterestimation = ParameterEstimation(sheapclass)

In [None]:
dd=parameterestimation.compute_Luminosity_w()

In [None]:
parameterestimation.compute_params_wu()

In [None]:
parameterestimation.dict_params["broad"]

In [None]:
parameterestimation.compute_bolometric_luminosity()

In [None]:
# NOTE(review): `l` is not defined in any visible cell — NameError on a fresh kernel.
l.value.shape

In [None]:
#grad_f = grad(lambda p: jnp.sum(profile_func(jnp.array([w]), p), axis=-1))(params)  # shape (225, n_params)

In [None]:
# NOTE(review): `L_w` is not defined in any visible cell.
L_w

In [None]:
# NOTE(review): `grad_f` only exists in commented-out code, and `uncertainty_params` is not a notebook-level name.
grad_f(uncertainty_params)

In [None]:
# NOTE(review): `profile_func` and `params` are not defined in any visible cell; `w` is only assigned in a later cell.
profile_func(jnp.array([w]),params).shape

In [None]:
profile_func(jnp.array([1350.]),params.T)

In [None]:
def scalar_flux(params, w, profile=None):
    """Total model flux at the single point ``w``, summed over the profile's last axis.

    Parameters
    ----------
    params : array
        Passed straight to the profile. The notebook's comments suggest one row per
        spectrum, i.e. (N_spectra, n_params) — TODO confirm.
    w : float
        Single evaluation point; wrapped as ``jnp.array([w])``.
    profile : callable, optional
        Evaluated as ``profile(x, params)``. It defaults to the notebook-level
        ``profile_func``, which is looked up at call time.

    Returns
    -------
    array
        ``profile(jnp.array([w]), params)`` summed over axis -1. Per the original
        comment, this has shape (N_spectra,).
    """
    profile = profile_func if profile is None else profile
    return jnp.sum(profile(jnp.array([w]), params), axis=-1)


# Gradient of scalar flux per spectrum
def grad_flux(p, w, profile=None):
    """Gradient of each spectrum's ``scalar_flux`` with respect to its own parameter row.

    ``jax.grad`` requires a scalar output, while ``scalar_flux`` returns one value
    per spectrum. The original call therefore failed, and bare ``grad`` was never
    imported. This version differentiates the total over all spectra instead.

    Provided spectrum i's flux depends only on ``p[i]`` (assumed — TODO confirm for
    the profile in use), row i of the result is d flux_i / d p[i]. The result has
    the same shape as ``p``.

    ``profile`` is forwarded to ``scalar_flux``.
    """
    return jax.grad(lambda p_: jnp.sum(scalar_flux(p_, w, profile)))(p)

# Evaluate flux and gradients
# Evaluate the profile at the single point w.
# NOTE(review): `profile_func` and `params` are not defined in any visible cell.
w = 5100.  # example
flux = profile_func(jnp.array([w]), params)  # shape (225, 2)
#flux_scalar = jnp.sum(flux, axis=-1)         # shape (225,)

#grad_f = grad(lambda p: jnp.sum(profile_func(jnp.array([w]), p), axis=-1))(params)  # shape (225, n_params)
#sigma_f = jnp.sqrt(jnp.sum((grad_f * uncertainty_params)**2, axis=-1))             # shape (225,)


In [None]:
flux

In [None]:
# NOTE(review): `profile_func` and `params` are not defined in any visible cell.
profile_func(jnp.array([w]),params)

In [None]:
parameterestimation.dict_basic_params['broad']

In [None]:
parameterestimation.dict_flux

In [None]:
parameterestimation.complex_region_class.to_dict().keys()

In [None]:
# Print each distinct `kind` among the entries of complex_region (set order is arbitrary).
for i in  set([i.kind for i in parameterestimation.complex_region]):
    print(i)


In [None]:
parameterestimation.compute_flux_wu()

In [None]:
parameterestimation.params_dict

In [None]:
# Amplitude attribute of each line in the broad-component map.
amplitudes = [line.amplitude for line in parameterestimation.broad_map.lines]

In [None]:
from sheap.SuportFunctions.functions import mapping_params

In [None]:
type(parameterestimation.broad_map.params_names)

In [None]:
broad_map = parameterestimation.broad_map

In [None]:
# Look up the "amplitude" entries among the broad map's parameter names.
idx_amplitude = mapping_params(broad_map.params_names,"amplitude")



In [None]:
idx_amplitude

In [None]:
parameterestimation.RegionMap.complex_region

In [None]:
parameterestimation.complex_region_class.to_dict()['complex_region']

In [None]:
# Same expression as two cells above.
parameterestimation.RegionMap.complex_region

In [None]:
from sheap.MainSheap import ComplexRegion

In [None]:
type(sheapclass.complex_region_class)

In [None]:
# complexregionclass= ComplexRegion(complex_region=sheapclass.complex_region,
#                                 params=sheapclass.params,
#                                 uncertainty_params=sheapclass.uncertainty_params,
#                                 profile_functions=sheapclass.profile_functions,
#                                 params_dict=sheapclass.params_dict,
#                                 profile_names=sheapclass.profile_names,
#                                 profile_params_index_list=sheapclass.profile_params_index_list)

In [None]:
from sheap.SuportFunctions.functions import LineMapper

In [None]:
# Build a LineMapper from the fitted region, profiles, parameters and their uncertainties.
maps=LineMapper(sheapclass.complex_region,sheapclass.profile_functions,sheapclass.params,sheapclass.uncertainty_params,
           sheapclass.profile_params_index_list,sheapclass.params_dict,sheapclass.profile_names)

In [None]:
sheapclass.profile_names

In [None]:
# Select the entries whose "kind" is "broad" (via the mapper's private `_get`).
broads_lines=maps._get("kind","broad")

In [None]:
sheapclass.profile_functions

In [None]:
broads_lines.lines

In [None]:
# Same call as an earlier cell.
parameterestimation.compute_flux_wu()

In [None]:
parameterestimation.compute_fwhm_wu()

In [None]:
# `np` is otherwise first imported in a later cell, so import it here to keep the
# notebook runnable top-to-bottom.
# NOTE(review): confirm `parameterestimation.d` is the intended attribute.
import numpy as np
np.array(parameterestimation.d)

In [None]:
from auto_uncertainties import Uncertainty
# First spectrum's sigma values and their uncertainties.
sigma=parameterestimation.sigma[0]
sigma_u=parameterestimation.sigma_u[0]

In [None]:
# Toy Uncertainty example: 5 values between 0 and 10, each with error 0.1.
from auto_uncertainties import Uncertainty
import numpy as np
value = np.linspace(start=0, stop=10, num=5)
error = np.ones_like(value)*0.1
u = Uncertainty(value, error)

In [None]:
Uncertainty(np.array(sigma),np.array(sigma_u))

In [None]:
# Replace NaN uncertainties by 1e8 and show the last entry.
jnp.nan_to_num(parameterestimation.sigma_u,nan=1e8)[-1]

In [None]:
# Flux = sqrt(2*pi) * amplitude * sigma (presumably a Gaussian line's integral).
# NOTE(review): `self` is not defined at notebook level — this looks copied from a
# method body and raises NameError. It also overwrites `sigma` from above.
norm_amplitude = Uncertainty(self.norm_amplitude, self.norm_amplitude_u)
sigma = Uncertainty(self.sigma, self.sigma_u)
flux =  jnp.sqrt(2. * jnp.pi) * norm_amplitude * sigma 

In [None]:
sp = SheapPlot(sheapclass)

In [None]:
# Plot the fit for spectrum index 10.
sp.plot(10)

In [None]:
sheap

In [None]:
model = combine_auto(sheapclass.profile_functions)

# Rescale y and y-uncertainty (indices 1 and 2 along axis 1) of every spectrum by
# 10**(-spectra_exp); the original note says this brings them to a 0-20 max range.
# The result is stored under a new name so the raw `spectra` read at the top of the
# notebook is not clobbered.
spectra_scaled = sheapclass.spectra.at[:,[1,2],:].multiply(10 ** (-1 * sheapclass.spectra_exp[:,jnp.newaxis,jnp.newaxis]))
# Unpack into the per-spectrum arrays used by the uncertainty cells below. This line
# was commented out, which left x / y / y_uncertainties undefined.
x, y, y_uncertainties = jnp.moveaxis(spectra_scaled, 0, 1)

In [None]:
y_uncertainties

In [None]:
# Full-Jacobian uncertainties for every spectrum, using the function imported from the
# local `errors` module, called here as (model, params, x, y, yerr).
# NOTE(review): the next cell redefines `batch_error_covariance_in_chunks` with
# `params` as its first argument. If that cell has run first, this call passes
# `model` as the parameters.
errrs_array_fulljacobian = batch_error_covariance_in_chunks(model,sheapclass.params, x, y, y_uncertainties)

In [None]:
def residuals(func,params: jnp.ndarray, xs, y: jnp.ndarray, y_uncertainties: jnp.ndarray):
    """Absolute residuals |y - func(xs, params)| divided by ``y_uncertainties``."""
    predictions = func(xs, params)
    return jnp.abs(y - predictions) / y_uncertainties


def error_covariance_matrix_single(
    func: Callable,
    params_i: jnp.ndarray,
    xs_i: jnp.ndarray,
    y_i: jnp.ndarray,
    yerr_i: jnp.ndarray,
    free_params: int
) -> jnp.ndarray:
    """1-sigma parameter uncertainties for one spectrum from the Gauss-Newton covariance.

    cov = inv(J^T J + 1e-6 I) * chi2 / dof, where J is the Jacobian of the normalised
    residuals with respect to ``params_i`` and dof = n_residuals - free_params.

    Returns sqrt(diag(cov)), with shape (n_params,).
    """
    residual = residuals(func, params_i, xs_i, y_i, yerr_i)

    # Jacobian with respect to params: shape (n_residuals, n_params).
    jac_fn = lambda p: residuals(func, p, xs_i, y_i, yerr_i)
    jacobian = jax.jacobian(jac_fn)(params_i)

    JTJ = jacobian.T @ jacobian
    chi_square = jnp.sum(residual ** 2)
    dof = residual.shape[0] - free_params
    s_sq = chi_square / dof

    # A small ridge term keeps the matrix invertible when J^T J is (near-)singular.
    cov = jnp.linalg.inv(JTJ + 1e-6 * jnp.eye(params_i.shape[0])) * s_sq
    return jnp.sqrt(jnp.diag(cov))


def batch_error_covariance_in_chunks(params, xs, y, yerr, batch_size=30, func=None, free_params=0):
    """Per-spectrum uncertainties from ``error_covariance_matrix_single``, vmapped in chunks.

    Parameters
    ----------
    params, xs, y, yerr : arrays sharing a leading (spectrum) axis
        Sliced together into chunks.
    batch_size : int, default 30
        Number of spectra per vmapped call (presumably chosen to bound peak memory).
    func : callable, optional
        Model evaluated as ``func(xs_i, params_i)``. It defaults to the
        notebook-level ``model``, which is looked up at call time.
    free_params : int, default 0
        Forwarded to ``error_covariance_matrix_single``. This was previously
        hard-coded to 0.

    Returns
    -------
    array
        Per-spectrum results concatenated along axis 0. Empty input gives shape
        ``(0,) + params.shape[1:]``.

    Raises
    ------
    ValueError
        If ``batch_size`` < 1.

    NOTE(review): this shadows the imported ``errors.batch_error_covariance_in_chunks``,
    which an earlier cell calls with the model as its first argument. That call
    order does not fit this definition.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    fn = model if func is None else func
    n = params.shape[0]
    if n == 0:
        return jnp.zeros((0,) + params.shape[1:], dtype=params.dtype)

    # The vmapped closure does not depend on the chunk, so build it once.
    batch_fn = vmap(
        lambda p, x, y_, ye: error_covariance_matrix_single(fn, p, x, y_, ye, free_params),
        in_axes=(0, 0, 0, 0),
    )
    results = [
        batch_fn(
            params[i:i + batch_size],
            xs[i:i + batch_size],
            y[i:i + batch_size],
            yerr[i:i + batch_size],
        )
        for i in range(0, n, batch_size)
    ]
    return jnp.concatenate(results, axis=0)

In [None]:
# import traceback
# import gc

# def try_batched_error_covariance(model, params, xs, y, yerr, free_params=0, max_batch=30):
#     n_total = params.shape[0]
#     batch_size = max_batch
#     results = []
#     i = 0
    
#     while i < n_total:
#         try:
#             current_batch_size = min(batch_size, n_total - i)
#             #print(current_batch_size)
#             print("from to:",i,i+current_batch_size)
#             batch_res = params[i:i+current_batch_size]
#             #i += current_batch_size
#             batch_fn = vmap(
#             lambda p, x, y_, ye: error_covariance_matrix_single(model, p, x, y_, ye, 0),
#             in_axes=(0, 0, 0, 0)
#         )
#             batch_res = batch_fn(
#                  params[i:i+current_batch_size],
#                  xs[i:i+current_batch_size],
#                  y[i:i+current_batch_size],
#                  yerr[i:i+current_batch_size]
#              )
#             results.append(batch_res)
#             i += current_batch_size
#             batch_size = min(current_batch_size + 5, max_batch)
#             del batch_res, batch_fn
#             jax.clear_backends()
#         except RuntimeError as e:
#             #del batch_res, batch_fn
#             gc.collect()
#             jax.clear_backends()
#             if "RESOURCE_EXHAUSTED" in str(e) or "out of memory" in str(e).lower():
#                 batch_size = max(1, batch_size // 2)
#                 print(f"DOOM at batch {i}-{i+current_batch_size}, reducing to batch_size={batch_size}")
#                 if batch_size == 1:
#                     print("Cannot reduce batch size further. Exiting.")
#                     raise e
#             else:
#                 print("nhandled error:")
#                 traceback.print_exc()
#                 raise e
#     return jnp.concatenate(results, axis=0)     

In [None]:
# First spectrum's fitted parameters ...
sheapclass.params[0]

In [None]:
# ... and their full-Jacobian uncertainties.
errrs_array_fulljacobian[0]

In [None]:
def fisher_error_estimate(
    func: Callable,
    params_i: jnp.ndarray,
    xs_i: jnp.ndarray,
    y_i: jnp.ndarray,
    yerr_i: jnp.ndarray,
    free_params: int
) -> jnp.ndarray:
    """1-sigma parameter uncertainties for one spectrum from the Fisher matrix J^T J.

    J is the Jacobian of the normalised residuals with respect to ``params_i``.

    The previous version used ``outer(grad chi2, grad chi2)``, which is a rank-1
    matrix. After adding the 1e-6 ridge, its inverse has eigenvalue 1e6 along every
    direction orthogonal to the gradient, so the returned errors were set by the
    ridge rather than by the data.

    Here J^T J is assembled one column at a time, as J^T (J e_k) via a JVP followed
    by a VJP, inside ``lax.map``. This means the full (n_residuals, n_params)
    Jacobian is never materialised at once.

    cov = inv(J^T J + 1e-6 I) * chi2 / dof, with dof = len(y_i) - free_params.
    Returns sqrt(diag(cov)), with shape (n_params,).

    NaNs in the residuals or their derivatives now propagate to the result instead
    of being zeroed.
    """
    def resid_fn(p):
        return residuals(func, p, xs_i, y_i, yerr_i)

    # linearize returns the residuals and v -> J v; vjp returns u -> J^T u.
    residual, jvp_fn = jax.linearize(resid_fn, params_i)
    _, vjp_fn = jax.vjp(resid_fn, params_i)

    n_params = params_i.shape[0]
    basis = jnp.eye(n_params, dtype=params_i.dtype)
    # Row k is J^T J e_k. J^T J is symmetric, so stacking the rows gives the matrix itself.
    JTJ = jax.lax.map(lambda e_k: vjp_fn(jvp_fn(e_k))[0], basis)

    chi_square = jnp.sum(residual ** 2)
    dof = y_i.shape[0] - free_params
    s_sq = chi_square / dof

    cov = jnp.linalg.inv(JTJ + 1e-6 * jnp.eye(n_params)) * s_sq
    return jnp.sqrt(jnp.diag(cov))

def batch_fisher_errors(func, params, xs, y, yerr, free_params=0):
    """Apply ``fisher_error_estimate`` to every spectrum via vmap over the leading axis."""
    return vmap(
        lambda p, x, y_, ye: fisher_error_estimate(func, p, x, y_, ye, free_params),
        in_axes=(0, 0, 0, 0)
    )(params, xs, y, yerr)

In [None]:
# NOTE(review): placeholder — this assigns the parameters themselves to `resid_array`;
# the next cell overwrites it.
resid_array = sheapclass.params #batch_error_covariance_in_chunks(params, xs, y, yerr, batch_size=30)

In [None]:
# Fisher-matrix uncertainties for all spectra.
resid_array = batch_fisher_errors(model,sheapclass.params, x, y, y_uncertainties)


In [None]:
sheapclass.params_dict

In [None]:
errrs_array_fulljacobian[0]

In [None]:
# Run Python's garbage collector.
import gc
gc.collect()

In [None]:
# import jax
# import jax.numpy as jnp
import traceback

def try_batched_error_covariance(model, params, xs, y, yerr, free_params=0, max_batch=64, error_fn=None):
    """Per-spectrum uncertainties in vmapped chunks, shrinking the chunk on out-of-memory.

    ``params``/``xs``/``y``/``yerr`` are processed along their leading (spectrum) axis
    in chunks of at most ``max_batch``. If a chunk raises a RuntimeError whose message
    mentions ``RESOURCE_EXHAUSTED`` or ``out of memory``, the size of that failed
    chunk is halved and the same spectra are retried. After each success, the chunk
    size grows by 5, capped at ``max_batch``.

    Parameters
    ----------
    model : callable
        Forwarded to ``error_fn`` as its first argument.
    params, xs, y, yerr : arrays sharing a leading axis
        Per-spectrum inputs, sliced together.
    free_params : int, default 0
        Forwarded to ``error_fn``.
    max_batch : int, default 64
        Largest chunk size tried.
    error_fn : callable, optional
        Per-spectrum estimator called as
        ``error_fn(model, params_i, xs_i, y_i, yerr_i, free_params)``. Defaults to
        ``error_covariance_matrix_single``.

    Returns
    -------
    array
        Per-spectrum results concatenated along axis 0. Empty input gives shape
        ``(0,) + params.shape[1:]``.

    Raises
    ------
    RuntimeError
        Re-raised if a single-spectrum chunk still runs out of memory, or for any
        other RuntimeError.
    ValueError
        If ``max_batch`` < 1.
    """
    if max_batch < 1:
        raise ValueError(f"max_batch must be >= 1, got {max_batch}")
    estimator = error_covariance_matrix_single if error_fn is None else error_fn
    n_total = params.shape[0]
    if n_total == 0:
        return jnp.zeros((0,) + params.shape[1:], dtype=params.dtype)

    # The vmapped closure does not depend on the chunk, so build it once.
    batch_fn = jax.vmap(
        lambda p, x, y_, ye: estimator(model, p, x, y_, ye, free_params),
        in_axes=(0, 0, 0, 0),
    )

    batch_size = max_batch
    results = []
    i = 0
    while i < n_total:
        current_batch_size = min(batch_size, n_total - i)
        stop = i + current_batch_size
        try:
            batch_res = batch_fn(params[i:stop], xs[i:stop], y[i:stop], yerr[i:stop])
            # JAX dispatches asynchronously. Block here so an OOM raised while the
            # chunk executes surfaces inside this try, not at some later use.
            batch_res.block_until_ready()
        except RuntimeError as e:
            message = str(e)
            if "RESOURCE_EXHAUSTED" in message or "out of memory" in message.lower():
                # Give up only if the chunk that failed was already a single spectrum.
                # Otherwise halve the failed chunk and retry the same spectra.
                if current_batch_size == 1:
                    print("Cannot reduce batch size further. Exiting.")
                    raise
                batch_size = max(1, current_batch_size // 2)
                print(f"DOOM at batch {i}-{stop}, reducing to batch_size={batch_size}")
                continue
            print("Unhandled error:")
            traceback.print_exc()
            raise
        results.append(batch_res)
        i = stop
        # The chunk worked: try a slightly larger one next time.
        batch_size = min(batch_size + 5, max_batch)

    return jnp.concatenate(results, axis=0)



In [None]:
#7.minutes cpu 
#30, 1.46 #max ->
# (Author's timing notes from previous runs.)
resid_array = try_batched_error_covariance(model,sheapclass.params, x, y, y_uncertainties)

In [None]:
# Single vmap over all spectra at once: every per-spectrum Jacobian is traced in one
# batched call (no chunking), so this is likely the most memory-hungry variant.
batched_error_covariance = vmap(
    lambda params_i, xs_i, y_i, yerr_i: error_covariance_matrix_single(
        model, params_i, xs_i, y_i, yerr_i, 0
    ),
    in_axes=(0, 0, 0, 0)
)
resid_array = batched_error_covariance(sheapclass.params, x, y, y_uncertainties)

In [None]:
region_plot = SheapPlot(sheapclass)
region_plot.plot(3,add_name=False)#207

In [None]:
# New region between 4400 and 5600 with Fe emission ("sum" mode), two broad
# components, outflow and Balmer-continuum options enabled.
sheapclass.build_region(4400, 5600,fe_mode="sum",n_broad=2,add_outflow=True,by_region=True,force_linear=False,add_balmercontiniumm=True)

In [None]:
sheapclass.builded_region.complex_region

In [None]:
sheapclass.fit_region([100,100])

In [None]:
sheapclass.complex_region

In [None]:
sheapclass.complex_region

In [None]:
region_plot = SheapPlot(sheapclass)
region_plot.plot(20,add_name=False)#207

In [None]:
# Row 3 of the transposed loss history (column 3 of `sheapclass.loss`).
plt.plot(jnp.array(sheapclass.loss).T[3])

In [None]:

from jax import vmap 

In [None]:
# Must be first, before jax.numpy or anything else
# NOTE(review): at this position JAX has already been imported and used by earlier
# cells, so the condition above does not hold and the platform switch may not take
# effect. TODO: move this to the first cell (cf. the commented-out line at the top
# of the notebook) or restart the kernel before running it.
from jax import config
config.update("jax_platform_name", "cpu")

# Now safe to import
import jax
import jax.numpy as jnp
from jax import jit, vmap




In [None]:
#from jax import config
#config.update("jax_platform_name", "cpu")

In [None]:
# `resid_array` holds whatever the last estimator cell that ran assigned to it.
resid_array[1]

In [None]:
resid_array

In [None]:
sheapclass.params[0:2]

In [None]:
def _error_covariance_matrix_core(
    func: Callable,
    optimized_params_flat: jnp.ndarray,
    xs: jnp.ndarray,
    y: jnp.ndarray,
    y_uncertainties: jnp.ndarray,
    free_params: int
) -> jnp.ndarray:
    """Jitted kernel of ``error_covariance_matrix_method``; ``func`` is a static argument."""
    def resid(p):
        # Same normalised absolute residual as the notebook's `residuals`. It is inlined
        # so this cell does not depend on whichever `residuals` definition ran last.
        return jnp.abs(y - func(xs, p)) / y_uncertainties

    residual = resid(optimized_params_flat)
    # Differentiate with respect to the parameters only. The old code differentiated
    # `residuals` with respect to its first positional argument.
    jacobian = jax.jacobian(resid)(optimized_params_flat)
    JTJ = jacobian.T @ jacobian
    chi_square = jnp.sum(residual ** 2)
    degrees_of_freedom = len(residual) - free_params
    s_sq = chi_square / degrees_of_freedom
    covariance_matrix = jnp.linalg.inv(JTJ + 1e-6 * jnp.eye(len(optimized_params_flat))) * s_sq
    return jnp.sqrt(jnp.diag(covariance_matrix))


_error_covariance_matrix_jit = jit(_error_covariance_matrix_core, static_argnums=0)


def error_covariance_matrix_method(
    optimized_params_flat: jnp.ndarray,
    xs: jnp.ndarray,
    y: jnp.ndarray,
    y_uncertainties: jnp.ndarray,
    free_params: int,
    func: Callable = None,
) -> jnp.ndarray:
    """
    Calculate standard errors from the residual-based error covariance matrix.

    The previous version could not be defined or called: ``List`` was never
    imported, and ``residuals`` was called without its model argument.

    Parameters:
    - optimized_params_flat: Optimized parameters as a flat array.
    - xs: Model input, passed as ``func(xs, params)``.
    - y: Target data.
    - y_uncertainties: Uncertainties in target data.
    - free_params: Number subtracted from the residual count to get the degrees of freedom.
    - func: Model callable; defaults to the notebook-level ``model``, resolved at
      call time. It is resolved outside jit so a redefined ``model`` is never served
      from a stale compiled trace.

    Returns:
    - Standard errors for each parameter: sqrt(diag(inv(J^T J + 1e-6 I) * chi2 / dof)).
    """
    fn = model if func is None else func
    return _error_covariance_matrix_jit(fn, optimized_params_flat, xs, y, y_uncertainties, free_params)

In [None]:
from jax import vmap 
# Per-spectrum residuals, vmapped over the leading (spectrum) axis of every argument.
# The argument order is (params, xs, y, yerr), matching the call
# `batched_residuals(sheapclass.params, x, y, y_uncertainties)`. The lambda previously
# took xs first, which swapped params and x when called that way.
batched_residuals = vmap(
    lambda params_i, xs_i, y_i, yerr_i: residuals(model, params_i, xs_i, y_i, yerr_i),
    in_axes=(0, 0, 0, 0)
)

In [None]:
from jax import vmap

def residuals(func, params: jnp.ndarray, xs, y: jnp.ndarray, y_uncertainties: jnp.ndarray):
    """Element-wise |y - func(xs, params)| / y_uncertainties (broadcast of all three operands)."""
    model_at_params = func(xs, params)
    abs_diff = jnp.abs(y - model_at_params)
    return abs_diff / y_uncertainties

# vmap over the 225 spectra
# (225 is the spectrum count the author observed; vmap itself works for any count.)
batched_residuals = vmap(
    lambda params_i, xs_i, y_i, yerr_i: residuals(model, params_i, xs_i, y_i, yerr_i),
    in_axes=(0, 0, 0, 0)
)

# usage:
# NOTE(review): this overwrites `resid_array`, which earlier cells used for uncertainties.
resid_array = batched_residuals(sheapclass.params, x, y, y_uncertainties)


In [None]:
def error_covariance_matrix_single(
    func: Callable,
    params_i: jnp.ndarray,
    xs_i: jnp.ndarray,
    y_i: jnp.ndarray,
    yerr_i: jnp.ndarray,
    free_params: int
) -> jnp.ndarray:
    """Gauss-Newton 1-sigma uncertainties for a single spectrum.

    The covariance is inv(J^T J + 1e-6 I) scaled by the reduced chi-square, where J
    is the Jacobian of ``residuals`` with respect to ``params_i`` and the degrees of
    freedom are (number of residuals - free_params).

    Returns the square root of the covariance diagonal, one entry per parameter.
    """
    def scaled_resid(theta):
        return residuals(func, theta, xs_i, y_i, yerr_i)

    r = scaled_resid(params_i)
    J = jax.jacobian(scaled_resid)(params_i)
    n_par = params_i.shape[0]

    reduced_chi2 = jnp.sum(r ** 2) / (r.shape[0] - free_params)
    # The ridge term keeps the normal matrix invertible.
    regularised = J.T @ J + 1e-6 * jnp.eye(n_par)
    covariance = jnp.linalg.inv(regularised) * reduced_chi2
    return jnp.sqrt(jnp.diag(covariance))


In [None]:
# Per-spectrum Jacobian of |y - model(x, p)| with respect to p. The result is
# presumably (n_spectra, n_residuals, n_params) — TODO confirm the model output shape.
# jax.grad only accepts scalar-output functions, and this residual is a vector per
# spectrum, so use jax.jacrev (reverse-mode Jacobian with respect to p).
jacobian = jax.vmap(jax.jacrev(lambda p, xi, yi: jnp.abs(yi - model(xi, p))))(sheapclass.params, x, y)