In [None]:
import sys
import os

# Make the parent of the notebook's working directory importable so that the
# `src.*` packages resolve. The membership guard keeps re-runs of this cell
# from appending the same entry to sys.path more than once.
notebook_cwd = os.getcwd()
src_path = os.path.join(notebook_cwd, os.pardir)
if not src_path in sys.path:
    sys.path.append(src_path)


# Imports

In [None]:
from dataclasses import dataclass
from typing import Optional

import numpy as np

from src.common_utils.interferogram import Interferogram
from src.common_utils.light_wave import Spectrum
from src.common_utils.utils import calculate_rmse
from src.interface.configuration import load_config
from src.direct_model.interferometer import FabryPerotInterferometer
from src.common_utils.utils import oversample
from dataclasses import replace
from src.inverse_model.analytical_inverter import HaarInverter
from src.common_utils.custom_vars import InversionProtocolType


# Experiment Classes

In [None]:
@dataclass(frozen=True)
class ExperimentMetadata:
    """Descriptive (non-numerical) information about an experiment.

    Instances are produced by ``ExperimentOptionsSchema.metadata``; the two list
    fields hold one entry per inversion protocol, in the same order as the
    experiment's ``protocols_options``.
    """

    experiment_name: str  # name of the experiment as given in the options
    device_label: str  # label of the simulated interferometer
    protocol_label: list[str]  # one display label per protocol
    protocol_color: list[str]  # one color string per protocol (e.g. "green")


@dataclass(frozen=True)
class Experiment:
    """One simulated acquisition followed by one reconstruction per inversion protocol.

    Attributes:
        spectrum: Ground-truth spectrum. It is acquired by ``device`` and is also
            the RMSE reference used to select the regularization parameter.
        device: Interferometer used to simulate the interferogram and to derive
            the transfer matrix (regularized protocols) or the mean
            transmittance/reflectance (Haar protocol).
        noise: SNR in dB passed to ``Interferogram.add_noise``; ``None`` disables noise.
        protocols_id_list: Database ids of the inversion protocols to run, in order.
    """

    spectrum: Spectrum
    device: FabryPerotInterferometer
    noise: Optional[float]
    protocols_id_list: list[int]

    # Expansion order passed to HaarInverter (unannotated -> class constant, not a field).
    _HAAR_ORDER = 20

    def interferogram(self) -> Interferogram:
        """Simulate the interferogram of ``spectrum`` acquired by ``device``.

        No noise is added here; see :meth:`reconstruct`.
        """
        return self.device.acquire_interferogram(spectrum=self.spectrum)

    def reconstruct(self) -> tuple[list[Spectrum], list[Optional[int]]]:
        """Reconstruct the spectrum with every configured inversion protocol.

        The interferogram is simulated once (optionally corrupted with noise using a
        fixed seed so the noise realization is reproducible) and shared by all
        protocols.

        Returns:
            A pair ``(spectrum_rec_list, argmin_rmse_list)``. Each list has one entry
            per id in ``protocols_id_list``, in the same order.

            - ``spectrum_rec_list``: the reconstructed spectrum. For regularized
              protocols this is the reconstruction whose lambdaa gives the lowest
              RMSE against ``spectrum``.
            - ``argmin_rmse_list``: index of that lambdaa in the protocol's lambdaa
              grid, or ``None`` for the Haar protocol (no lambdaa is selected).
        """
        interferogram = self.interferogram()
        if self.noise is not None:
            # NOTE: this seeds NumPy's *global* RNG, which also affects any later
            # unseeded np.random calls in the session.
            np.random.seed(0)
            interferogram = interferogram.add_noise(snr_db=self.noise)

        transfer_matrix = self.device.transmittance_response(wavenumbers=self.spectrum.wavenumbers)
        transfer_matrix = transfer_matrix.rescale(new_max=1., axis=None)
        db = load_config().database()

        spectrum_rec_list = []
        argmin_rmse_list = []
        for protocol_id in self.protocols_id_list:
            if db.inversion_protocols[protocol_id].type == InversionProtocolType.HAAR:
                spectrum_rec = self._reconstruct_haar(interferogram=interferogram)
                # Previously this slot silently reused the index of the preceding
                # protocol (or raised if Haar came first).
                argmin_rmse = None
            else:
                spectrum_rec, argmin_rmse = self._reconstruct_regularized(
                    db=db,
                    protocol_id=protocol_id,
                    interferogram=interferogram,
                    transfer_matrix=transfer_matrix,
                )
            spectrum_rec_list.append(spectrum_rec)
            argmin_rmse_list.append(argmin_rmse)

        return spectrum_rec_list, argmin_rmse_list

    def _reconstruct_haar(self, interferogram: Interferogram) -> Spectrum:
        """Invert ``interferogram`` with the analytical Haar inverter.

        The inverter is parameterized by the device's transmittance and reflectance,
        each averaged with ``.mean()`` over the spectrum's wavenumbers.
        """
        wavenumbers = self.spectrum.wavenumbers
        transmittance = self.device.transmittance(wavenumbers=wavenumbers).mean()
        reflectance = self.device.reflectance(wavenumbers=wavenumbers).mean()
        haar = HaarInverter(
            transmittance=transmittance,
            wavenumbers=wavenumbers,
            reflectance=reflectance,
            order=self._HAAR_ORDER,
            is_mean_center=True,
        )
        return haar.reconstruct_spectrum(interferogram=interferogram)

    def _reconstruct_regularized(self, db, protocol_id: int, interferogram: Interferogram, transfer_matrix):
        """Run a protocol over its whole lambdaa grid and keep the lowest-RMSE result.

        Returns:
            ``(spectrum_rec, argmin_rmse)``. ``spectrum_rec`` is a copy of
            ``self.spectrum`` whose ``data`` is the best reconstruction.
            ``argmin_rmse`` is that reconstruction's index in the lambdaa grid.
        """
        lambdaas = db.inversion_protocol_lambdaas(inv_protocol_id=protocol_id)
        spectrum_rec_lambdaas = np.zeros(shape=(lambdaas.size, *self.spectrum.data.shape))
        for i_lmd, lambdaa in enumerate(lambdaas):
            inverter = db.inversion_protocol(inv_protocol_id=protocol_id, lambdaa=lambdaa)
            spectrum_rec_lambdaa = inverter.reconstruct_spectrum(
                interferogram=interferogram, transmittance_response=transfer_matrix
            )
            spectrum_rec_lambdaas[i_lmd] = spectrum_rec_lambdaa.data
        rmse_lambdaas = calculate_rmse(
            array=spectrum_rec_lambdaas,
            reference=self.spectrum.data,
            is_match_axis=-2,
            is_match_stats=True,
            is_rescale_reference=True,
        )
        # NOTE(review): np.argmin without `axis` returns a flat index; this assumes
        # rmse_lambdaas has one value per lambdaa — confirm calculate_rmse's output shape.
        argmin_rmse = int(np.argmin(rmse_lambdaas))
        spectrum_rec = replace(self.spectrum, data=spectrum_rec_lambdaas[argmin_rmse])
        return spectrum_rec, argmin_rmse
            

# Experiment Options Schemas

In [None]:
@dataclass(frozen=True)
class ProtocolOptionsSchema:
    """Options for one inversion protocol to run in an experiment.

    ``id_database`` is collected into ``Experiment.protocols_id_list``.
    ``label`` and ``color`` are forwarded to ``ExperimentMetadata``.
    """

    id_database: int  # id of the inversion protocol in the configured database
    label: str  # display label, e.g. "IDCT", "HAAR", "PINV"
    color: str  # color name used for this protocol (e.g. "green") — presumably for plots, confirm consumer


@dataclass(frozen=True)
class InterferometerOptionsSchema:
    """Options describing the simulated Fabry-Perot interferometer.

    ``transmittance`` and ``reflectivity`` are passed unchanged to
    ``FabryPerotInterferometer`` as ``transmittance_coefficients`` and
    ``reflectance_coefficients``. ``label`` becomes
    ``ExperimentMetadata.device_label``.

    NOTE(review): the dataclass-generated ``__eq__`` compares the numpy array
    fields with ``==``. For distinct multi-element arrays, the truth value of that
    comparison is ambiguous and raises, so avoid comparing instances.
    """

    label: str  # short identifier of the device configuration
    transmittance: np.ndarray  # coefficient array; the options cell uses shape (1, 6) — TODO confirm expected layout
    reflectivity: np.ndarray  # coefficient array with the same layout as `transmittance` (assumed)


@dataclass(frozen=True)
class ExperimentOptionsSchema:
    """High-level, user-editable description of a simulated reconstruction experiment.

    The options are turned into runnable objects by :meth:`create_experiment`
    (an ``Experiment``) and :meth:`metadata` (an ``ExperimentMetadata``).

    Attributes:
        experiment_name: Name of the experiment, forwarded to the metadata.
        dataset_label: Dataset to load; one of ``"solar"`` or ``"specim"``.
        opds_sampling: OPD grid to use; ``"regular"`` or ``"irregular"``.
        interferometer_options: Coefficients and label of the simulated device.
        noise: SNR in dB for the simulated interferogram, or ``None`` for no noise.
        protocols_options: Inversion protocols to run, in order.
    """

    experiment_name: str
    dataset_label: str
    opds_sampling: str
    interferometer_options: InterferometerOptionsSchema
    noise: Optional[float]
    protocols_options: list[ProtocolOptionsSchema]

    # Unannotated class attributes are constants, not dataclass fields.
    _DATASET_IDS = {"solar": 0, "specim": 1}
    # Regular OPD grid: _REGULAR_OPD_NUM samples, spaced by _REGULAR_OPD_STEP, starting at 0.
    _REGULAR_OPD_STEP = 0.175
    _REGULAR_OPD_NUM = 319
    # Database characterization whose measured OPDs define the irregular grid.
    _IRREGULAR_CHARACTERIZATION_ID = 0

    def dataset_id(self) -> int:
        """Return the database id of the dataset named by ``dataset_label``.

        Raises:
            ValueError: If ``dataset_label`` is not a supported dataset.
        """
        try:
            return self._DATASET_IDS[self.dataset_label]
        except KeyError:
            raise ValueError(f"Dataset label \"{self.dataset_label}\" is not supported.") from None

    def spectrum(self) -> Spectrum:
        """Load the ground-truth spectrum of the selected dataset from the database."""
        return load_config().database().dataset(dataset_id=self.dataset_id())

    def opds(self) -> np.ndarray:
        """Build the OPD grid selected by ``opds_sampling``.

        - ``"regular"``: ``_REGULAR_OPD_NUM`` values ``0, step, 2*step, ...``.
        - ``"irregular"``: the sorted OPDs of a database characterization,
          prefixed with values spaced by the mean measured step, from 0 up to
          (excluding) the smallest measured OPD.

        Raises:
            ValueError: If ``opds_sampling`` is neither ``"regular"`` nor ``"irregular"``.
        """
        if self.opds_sampling == "regular":
            # Scale an integer range instead of using np.arange with a float step
            # and stop. This gives exactly _REGULAR_OPD_NUM samples; a float
            # step/stop can add one point through rounding. Sample values are unchanged.
            return self._REGULAR_OPD_STEP * np.arange(self._REGULAR_OPD_NUM)
        if self.opds_sampling == "irregular":
            characterization = load_config().database().characterization(
                characterization_id=self._IRREGULAR_CHARACTERIZATION_ID
            )
            measured_opds = np.sort(characterization.opds)
            opd_mean_step = np.mean(np.diff(measured_opds))
            # Fill the gap between OPD = 0 and the first measured OPD.
            lowest_missing_opds = np.arange(start=0., stop=measured_opds.min(), step=opd_mean_step)
            return np.concatenate((lowest_missing_opds, measured_opds))
        raise ValueError("OPDs sampling option must either be \"regular\" or \"irregular\".")

    def device(self, opds: Optional[np.ndarray] = None) -> FabryPerotInterferometer:
        """Build the simulated Fabry-Perot interferometer.

        Args:
            opds: OPD grid to use; defaults to :meth:`opds`. Passing a precomputed
                grid avoids rebuilding it, which for ``"irregular"`` sampling
                reloads the database.
        """
        if opds is None:
            opds = self.opds()
        return FabryPerotInterferometer(
            transmittance_coefficients=self.interferometer_options.transmittance,
            opds=opds,
            phase_shift=np.array([0]),
            reflectance_coefficients=self.interferometer_options.reflectivity,
            order=0,
        )

    def spectrum_continuous(self) -> Spectrum:
        """Return the dataset spectrum interpolated onto an oversampled wavenumber grid.

        The oversampling factor is ``ceil(wn_step / target_step)``. Here ``wn_step``
        is the dataset's mean wavenumber step, and ``target_step`` is the Nyquist
        step ``1 / (2 * max(opds))`` divided by the device's ``harmonic_order()``.
        """
        opds = self.opds()
        return self._spectrum_continuous(opds=opds, device=self.device(opds=opds))

    def _spectrum_continuous(self, opds: np.ndarray, device: FabryPerotInterferometer) -> Spectrum:
        """Oversample the dataset spectrum for an already-built OPD grid and device."""
        spectrum = self.spectrum()
        wn_step = np.mean(np.diff(spectrum.wavenumbers))
        wn_step_nyquist = 1 / (2 * opds.max())
        wn_step_target = wn_step_nyquist / device.harmonic_order()
        wn_factor_new = int(np.ceil(wn_step / wn_step_target))
        wavenumbers_new = oversample(array=spectrum.wavenumbers, factor=wn_factor_new)
        return spectrum.interpolate(wavenumbers=wavenumbers_new, kind="slinear")

    def create_experiment(self) -> Experiment:
        """Build the runnable ``Experiment`` described by these options.

        The OPD grid and the device are built once and shared between the
        oversampling step and the experiment itself.
        """
        opds = self.opds()
        device = self.device(opds=opds)
        return Experiment(
            spectrum=self._spectrum_continuous(opds=opds, device=device),
            device=device,
            noise=self.noise,
            protocols_id_list=[protocol.id_database for protocol in self.protocols_options],
        )

    def metadata(self) -> ExperimentMetadata:
        """Collect the descriptive labels/colors of the experiment, one per protocol."""
        return ExperimentMetadata(
            experiment_name=self.experiment_name,
            device_label=self.interferometer_options.label,
            protocol_label=[protocol.label for protocol in self.protocols_options],
            protocol_color=[protocol.color for protocol in self.protocols_options],
        )
        

# Experiment Options

In [None]:
# Simulated experiment: solar spectrum, regular OPD grid, no added noise,
# three inversion protocols compared side by side.
options = ExperimentOptionsSchema(
    experiment_name="simulated/reflectivity_levels",
    dataset_label="solar",
    opds_sampling="regular",
    interferometer_options=InterferometerOptionsSchema(
        label="fp_0_var_r",
        # NOTE(review): the two coefficient rows sum to [1, 0, 0, 0, 0, 0],
        # presumably encoding T + R = 1 (lossless mirrors) — confirm.
        transmittance=np.array([[6.07531869, -15.40598431, 16.45325227, -8.68203863, 2.25753936, -0.22934861]]),
        reflectivity=np.array([[-5.07531869, 15.40598431, -16.45325227, 8.68203863, -2.25753936, 0.22934861]]),
    ),
    noise=None,  # None -> the interferogram is used without added noise
    protocols_options=[
        # id_database values index the inversion protocols in the configured database.
        ProtocolOptionsSchema(id_database=0, label="IDCT", color="green"),
        ProtocolOptionsSchema(id_database=19, label="HAAR", color="red"),
        ProtocolOptionsSchema(id_database=1, label="PINV", color="black"),
    ],
)


In [None]:
experiment = options.create_experiment()

# reconstruct() returns a pair: the reconstructed spectra and, per protocol,
# the index of the selected lambdaa. Unpack it so each name holds what it says.
spectrum_rec_list, argmin_rmse_list = experiment.reconstruct()