In [1]:
import sys
import os

# Put the parent of the notebook's working directory on the import path so that
# the `src.*` packages imported in the following cell can be resolved.
notebook_cwd = os.getcwd()
src_path = os.path.join(notebook_cwd, os.pardir)
# Re-running the cell must not stack duplicate entries on sys.path.
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)


# Imports

In [2]:
from enum import Enum
from dataclasses import dataclass
from typing import Optional

import numpy as np

from src.common_utils.interferogram import Interferogram
from src.common_utils.light_wave import Spectrum
from src.common_utils.utils import calculate_rmse
from src.interface.configuration import load_config
from src.direct_model.interferometer import FabryPerotInterferometer
from src.common_utils.utils import oversample
from dataclasses import replace
from src.inverse_model.analytical_inverter import HaarInverter
from src.common_utils.custom_vars import InversionProtocolType
from src.common_utils.custom_vars import DatasetTitle


# Experiment Class

In [3]:
@dataclass(frozen=True)
class ExperimentMetadata:
    """Immutable set of labels describing one experiment.

    Instances are built by ``ExperimentOptionsSchema.metadata``. The two
    protocol lists are filled by iterating over the same ``protocols_options``
    sequence, so entry ``i`` of each list refers to the same protocol.
    """
    # Experiment name, copied verbatim from the options schema.
    experiment_name: str
    # Label of the interferometer configuration (``InterferometerOptionsSchema.label``).
    device_label: str
    # One label per inversion protocol, e.g. "IDCT", "HAAR", "PINV".
    protocol_label: list[str]
    # One colour name per inversion protocol, in the same order as `protocol_label`.
    # NOTE(review): presumably consumed by plotting code outside this notebook - confirm.
    protocol_color: list[str]


@dataclass(frozen=True)
class Experiment:
    """Simulated acquisition of a spectrum and its reconstruction by several protocols.

    The reference ``spectrum`` is passed through ``device`` to obtain an
    interferogram, optionally corrupted with noise, and then inverted once per
    entry of ``protocols_id_list``.
    """
    # Reference spectrum: input of the simulated acquisition and ground truth
    # for the lambdaa selection performed in `invert_protocol`.
    spectrum: Spectrum
    # Interferometer model used both to acquire and to invert.
    device: FabryPerotInterferometer
    # Signal-to-noise ratio in dB handed to `Interferogram.add_noise`;
    # `None` means the interferogram is left noiseless.
    noise: Optional[float]
    # Database identifiers of the inversion protocols to run, in order.
    protocols_id_list: list[int]

    def interferogram(self) -> Interferogram:
        """Return the noiseless interferogram of ``spectrum`` acquired by ``device``."""
        return self.device.acquire_interferogram(spectrum=self.spectrum)

    def reconstruct(self) -> tuple[list[Spectrum], list[Optional[int]]]:
        """Reconstruct the spectrum with every configured protocol.

        Returns:
            A pair ``(spectrum_rec_list, argmin_rmse_list)`` with one entry per
            protocol in ``protocols_id_list``. ``argmin_rmse_list[i]`` is the
            index of the selected lambdaa for protocol ``i``, or ``None`` for
            the Haar protocol, which performs no lambdaa sweep.
        """
        interferogram = self.interferogram()
        if self.noise is not None:
            # Reset NumPy's global RNG so every call starts from the same state.
            # NOTE(review): assumes `add_noise` draws from the global
            # `np.random` state - confirm; this also affects any other user of
            # that global state.
            np.random.seed(0)
            interferogram = interferogram.add_noise(snr_db=self.noise)

        # The protocol table does not depend on the loop variable: look it up once.
        inversion_protocols = load_config().database().inversion_protocols

        spectrum_rec_list = []
        argmin_rmse_list = []
        for protocol_id in self.protocols_id_list:
            if inversion_protocols[protocol_id].type == InversionProtocolType.HAAR:
                spectrum_rec = self.invert_haar(
                    interferogram=interferogram,
                    device=self.device,
                    wavenumbers=self.spectrum.wavenumbers,
                )
                # Haar has no lambdaa sweep. Without this assignment the name
                # would be unbound (Haar first) or would silently carry over
                # the index selected for the previous protocol.
                argmin_rmse = None
            else:
                spectrum_rec, argmin_rmse = self.invert_protocol(
                    interferogram=interferogram,
                    device=self.device,
                    wavenumbers=self.spectrum.wavenumbers,
                    protocol_id=protocol_id,
                    spectrum_ref=self.spectrum,
                )

            spectrum_rec_list.append(spectrum_rec)
            argmin_rmse_list.append(argmin_rmse)

        return spectrum_rec_list, argmin_rmse_list

    @staticmethod
    def invert_haar(
        interferogram: Interferogram,
        device: FabryPerotInterferometer,
        wavenumbers: np.ndarray,
        order: int = 20,
    ) -> Spectrum:
        """Reconstruct a spectrum with the analytical ``HaarInverter``.

        Args:
            interferogram: Interferogram to invert.
            device: Interferometer whose transmittance and reflectance are
                evaluated at ``wavenumbers`` and averaged to scalars.
            wavenumbers: Wavenumber grid of the reconstruction.
            order: ``order`` argument forwarded to ``HaarInverter``
                (defaults to the previously hard-coded value of 20).

        Returns:
            The spectrum returned by ``HaarInverter.reconstruct_spectrum``.
        """
        # The inverter receives one scalar per coefficient: the mean over `wavenumbers`.
        transmittance = device.transmittance(wavenumbers=wavenumbers).mean()
        reflectivity = device.reflectance(wavenumbers=wavenumbers).mean()
        haar = HaarInverter(
            transmittance=transmittance,
            wavenumbers=wavenumbers,
            reflectance=reflectivity,
            order=order,
            is_mean_center=True,
        )
        return haar.reconstruct_spectrum(interferogram=interferogram)

    @staticmethod
    def invert_protocol(
        interferogram: Interferogram,
        device: FabryPerotInterferometer,
        wavenumbers: np.ndarray,
        protocol_id: int,
        spectrum_ref: Spectrum,
    ) -> tuple[Spectrum, int]:
        """Reconstruct with a database protocol, keeping the best lambdaa.

        The protocol is run once for every lambdaa stored in the database for
        ``protocol_id``; the reconstruction with the lowest RMSE against
        ``spectrum_ref`` is kept. The selection therefore relies on the ground
        truth being known.

        Args:
            interferogram: Interferogram to invert.
            device: Interferometer providing the transmittance response.
            wavenumbers: Wavenumber grid of the reconstruction.
            protocol_id: Database identifier of the inversion protocol.
            spectrum_ref: Reference spectrum used to score each candidate and
                as the template of the returned ``Spectrum``.

        Returns:
            ``(spectrum_rec, argmin_rmse)``: the selected reconstruction and
            the index of the selected lambdaa.
        """
        transfer_matrix = device.transmittance_response(wavenumbers=wavenumbers)
        transfer_matrix = transfer_matrix.rescale(new_max=1., axis=None)

        db = load_config().database()
        lambdaas = db.inversion_protocol_lambdaas(inv_protocol_id=protocol_id)

        # One reconstruction per lambdaa, stacked along the first axis.
        # NOTE(review): assumes `interferogram.data` is 2-D, with the second
        # axis indexing acquisitions - confirm.
        spectrum_rec_lambdaas = np.zeros(
            shape=(lambdaas.size, wavenumbers.size, interferogram.data.shape[1])
        )
        for i_lmd, lambdaa in enumerate(lambdaas):
            inverter = db.inversion_protocol(inv_protocol_id=protocol_id, lambdaa=lambdaa)
            spectrum_rec_lambdaa = inverter.reconstruct_spectrum(
                interferogram=interferogram, transmittance_response=transfer_matrix
            )
            spectrum_rec_lambdaas[i_lmd] = spectrum_rec_lambdaa.data

        rmse_lambdaas = calculate_rmse(
            array=spectrum_rec_lambdaas,
            reference=spectrum_ref.data,
            is_match_axis=-2,
            is_match_stats=True,
            is_rescale_reference=True,
        )
        # Plain `int` so the index is a built-in value for callers.
        argmin_rmse = int(np.argmin(rmse_lambdaas))
        spectrum_rec = replace(spectrum_ref, data=spectrum_rec_lambdaas[argmin_rmse])
        return spectrum_rec, argmin_rmse
    

# Experiment Options Schemas

In [4]:
@dataclass(frozen=True)
class ProtocolOptionsSchema:
    """Immutable description of one inversion protocol used in an experiment."""
    # Identifier of the protocol in the configuration database; it is forwarded
    # to `Experiment.protocols_id_list` by `ExperimentOptionsSchema.create_experiment`.
    id_database: int
    # Short label of the protocol, e.g. "IDCT", "HAAR", "PINV".
    label: str
    # Colour name associated with the protocol, e.g. "green".
    color: str


@dataclass(frozen=True)
class InterferometerOptionsSchema:
    """Immutable description of an interferometer configuration.

    ``ExperimentOptionsSchema.device`` forwards ``transmittance`` and
    ``reflectivity`` to ``FabryPerotInterferometer`` as
    ``transmittance_coefficients`` and ``reflectance_coefficients``.
    """
    # Label identifying the configuration, e.g. "fp_0_low_r".
    label: str
    # Transmittance coefficients. The instances in this notebook use 2-D arrays
    # with a single row, e.g. np.array([[1]]).
    # NOTE(review): presumably polynomial coefficients - confirm against the device model.
    transmittance: np.ndarray
    # Reflectance coefficients, same layout as `transmittance`.
    reflectivity: np.ndarray


@dataclass(frozen=True)
class ExperimentOptionsSchema:
    """Declarative options from which an ``Experiment`` and its metadata are built."""
    # Name of the experiment, copied into the metadata.
    experiment_name: str
    # Dataset to load; only SOLAR and SPECIM are handled by `dataset_id`.
    dataset_label: DatasetTitle
    # Either "regular" or "irregular"; see `opds`.
    opds_sampling: str
    # Interferometer coefficients and label.
    interferometer_options: InterferometerOptionsSchema
    # Forwarded unchanged to `Experiment.noise`; `None` means no noise is added.
    noise: Optional[float]
    # Inversion protocols to run, in order.
    protocols_options: list[ProtocolOptionsSchema]

    def dataset_id(self) -> int:
        """Map ``dataset_label`` to its database identifier.

        Raises:
            ValueError: If the label is neither SOLAR nor SPECIM.
        """
        if self.dataset_label == DatasetTitle.SOLAR:
            dataset_id = 0
        elif self.dataset_label == DatasetTitle.SPECIM:
            dataset_id = 1
        else:
            raise ValueError(f"Dataset label \"{self.dataset_label}\" is not supported.")
        return dataset_id

    def spectrum(self) -> Spectrum:
        """Load the dataset selected by ``dataset_label`` from the database."""
        return load_config().database().dataset(dataset_id=self.dataset_id())

    def opds(self) -> np.ndarray:
        """Return the optical path differences (OPDs) sampled by the device.

        ``"regular"`` yields 319 equally spaced values starting at 0 with a
        step of 0.175. ``"irregular"`` takes the OPDs of characterization 0
        from the database, sorts them, and prepends values from 0 up to (but
        excluding) the smallest one, spaced by the mean step of the sorted OPDs.

        Raises:
            ValueError: If ``opds_sampling`` is not one of the two options.
        """
        if self.opds_sampling == "regular":
            step = 0.175
            num = 319
            # Scale an integer ramp rather than calling
            # np.arange(0, step * num, step): with a float step the length of
            # arange depends on rounding, whereas this always yields exactly
            # `num` samples with the same values i * step.
            opds = step * np.arange(num)
        elif self.opds_sampling == "irregular":
            opds = load_config().database().characterization(characterization_id=0).opds
            opds = np.sort(opds)
            opd_mean_step = np.mean(np.diff(opds))
            # Fill the gap between 0 and the first OPD with a regular ramp.
            lowest_missing_opds = np.arange(start=0., stop=opds.min(), step=opd_mean_step)
            opds = np.concatenate((lowest_missing_opds, opds))
        else:
            raise ValueError("OPDs sampling option must either be \"regular\" or \"irregular\".")
        return opds

    def device(self) -> FabryPerotInterferometer:
        """Build the Fabry-Perot interferometer described by these options."""
        return FabryPerotInterferometer(
            transmittance_coefficients=self.interferometer_options.transmittance,
            opds=self.opds(),
            phase_shift=np.array([0]),
            reflectance_coefficients=self.interferometer_options.reflectivity,
            order=0,
        )

    def spectrum_continuous(self) -> Spectrum:
        """Return the dataset spectrum interpolated on an oversampled wavenumber grid.

        The oversampling factor is the smallest integer that makes the mean
        wavenumber step no larger than ``1 / (2 * max(opds))`` divided by the
        device's harmonic order.
        """
        spectrum = self.spectrum()
        wn_step = np.mean(np.diff(spectrum.wavenumbers))
        wn_step_nyquist = 1 / (2 * self.opds().max())
        wn_step_target = wn_step_nyquist / self.device().harmonic_order()
        wn_factor_new = int(np.ceil(wn_step / wn_step_target))
        wavenumbers_new = oversample(array=spectrum.wavenumbers, factor=wn_factor_new)
        return spectrum.interpolate(wavenumbers=wavenumbers_new, kind="slinear")

    def create_experiment(self) -> Experiment:
        """Instantiate the ``Experiment`` described by these options."""
        return Experiment(
            spectrum=self.spectrum_continuous(),
            device=self.device(),
            noise=self.noise,
            protocols_id_list=[protocol.id_database for protocol in self.protocols_options],
        )

    def metadata(self) -> ExperimentMetadata:
        """Collect the labels of the experiment, device and protocols."""
        return ExperimentMetadata(
            experiment_name=self.experiment_name,
            device_label=self.interferometer_options.label,
            protocol_label=[protocol.label for protocol in self.protocols_options],
            protocol_color=[protocol.color for protocol in self.protocols_options],
        )
        

## Simple Test Example (setting options manually)

In [5]:
# Manually configured example: one interferometer with multi-coefficient
# transmittance/reflectivity, no added noise, regularly sampled OPDs.
options = ExperimentOptionsSchema(
    experiment_name="simulated/reflectivity_levels",
    interferometer_options=InterferometerOptionsSchema(
        label="fp_0_var_r",
        transmittance=np.array([[6.07531869, -15.40598431, 16.45325227, -8.68203863, 2.25753936, -0.22934861]]),
        reflectivity=np.array([[-5.07531869, 15.40598431, -16.45325227, 8.68203863, -2.25753936, 0.22934861]]),
    ),
    noise=None,
    opds_sampling="regular",
    # Pass the enum member rather than the raw string "solar": the schema field
    # is typed as DatasetTitle and `dataset_id` compares against its members.
    dataset_label=DatasetTitle.SOLAR,
    protocols_options=[
        ProtocolOptionsSchema(id_database=0, label="IDCT", color="green"),
        ProtocolOptionsSchema(id_database=19, label="HAAR", color="red"),
        ProtocolOptionsSchema(id_database=1, label="PINV", color="black"),
    ],
)

experiment = options.create_experiment()

# `reconstruct` returns two parallel lists (reconstructions and selected
# lambdaa indices); unpack them so each name holds what it says.
spectrum_rec_list, argmin_rmse_list = experiment.reconstruct()


# Paper Experiment Options

In [10]:
class SimulatedTestType(str, Enum):
    """Simulated scenarios for which ``load_paper_options`` provides options.

    Members are also ``str`` instances, so they compare equal to their values.
    """
    # Several interferometer reflectivity settings, noiseless, regular sampling.
    REFLECTIVITY = "reflectivity"
    # Regular versus irregular OPD sampling for a single interferometer.
    IRREGULAR_SAMPLING = "irregular_sampling"
    # Several noise levels for a single interferometer.
    NOISE_LEVELS = "noise_levels"

def load_paper_options(test_type: SimulatedTestType) -> dict:
    """Return the option grid of one simulated scenario.

    Args:
        test_type: Scenario to load.

    Returns:
        A dictionary with the keys ``"experiment_name"``,
        ``"interferometer_options_list"``, ``"noise_list"``,
        ``"opds_sampling_list"``, ``"dataset_label_list"`` and
        ``"protocols_options"``. The ``*_list`` entries hold the values to
        sweep over; ``"protocols_options"`` is used as a whole.

    Raises:
        ValueError: If ``test_type`` is not a supported scenario.
    """
    # Building blocks shared by several scenarios.
    # NOTE(review): every scenario uses the experiment name
    # "simulated/reflectivity_levels" - confirm this is intended and not a
    # copy-paste leftover.
    shared_name = "simulated/reflectivity_levels"
    datasets = [DatasetTitle.SOLAR, DatasetTitle.SPECIM]
    low_r_device = InterferometerOptionsSchema(
        label="fp_0_low_r",
        transmittance=np.array([[1]]),
        reflectivity=np.array([[0.2]]),
    )
    idct = ProtocolOptionsSchema(id_database=0, label="IDCT", color="green")
    haar = ProtocolOptionsSchema(id_database=19, label="HAAR", color="red")
    pinv = ProtocolOptionsSchema(id_database=1, label="PINV", color="black")

    if test_type == SimulatedTestType.REFLECTIVITY:
        print("Loading Reflectivity options.")
        devices = [
            low_r_device,
            InterferometerOptionsSchema(
                label="fp_0_medium_r",
                transmittance=np.array([[1]]),
                reflectivity=np.array([[0.4]]),
            ),
            InterferometerOptionsSchema(
                label="fp_0_high_r",
                transmittance=np.array([[1]]),
                reflectivity=np.array([[0.7]]),
            ),
            InterferometerOptionsSchema(
                label="fp_0_variable_r",
                transmittance=np.array([[6.07531869, -15.40598431, 16.45325227, -8.68203863, 2.25753936, -0.22934861]]),
                reflectivity=np.array([[-5.07531869, 15.40598431, -16.45325227, 8.68203863, -2.25753936, 0.22934861]]),
            ),
        ]
        return {
            "experiment_name": shared_name,
            "interferometer_options_list": devices,
            "noise_list": [None],
            "opds_sampling_list": ["regular"],
            "dataset_label_list": datasets,
            "protocols_options": [idct, haar, pinv],
        }

    if test_type == SimulatedTestType.IRREGULAR_SAMPLING:
        print("Loading Irregular Sampling options.")
        return {
            "experiment_name": shared_name,
            "interferometer_options_list": [low_r_device],
            "noise_list": [None],
            "opds_sampling_list": ["regular", "irregular"],
            "dataset_label_list": datasets,
            "protocols_options": [idct, haar, pinv],
        }

    if test_type == SimulatedTestType.NOISE_LEVELS:
        print("Loading Noise Levels options.")
        return {
            "experiment_name": shared_name,
            "interferometer_options_list": [low_r_device],
            "noise_list": [20., 15.],
            "opds_sampling_list": ["regular"],
            "dataset_label_list": datasets,
            "protocols_options": [
                idct,
                haar,
                ProtocolOptionsSchema(id_database=2, label="TSVD", color="purple"),
                ProtocolOptionsSchema(id_database=3, label="RR", color="orange"),
                ProtocolOptionsSchema(id_database=4, label="LV-L1", color="black"),
            ],
        }

    raise ValueError(f"Simulated test type option \"{test_type}\" is not supported.")


In [11]:
# Run every combination of dataset, OPD sampling, interferometer and noise
# level defined for the selected scenario.
test_type = SimulatedTestType.NOISE_LEVELS
test_options = load_paper_options(test_type=test_type)

# One record per combination, so that earlier runs are not overwritten by
# later ones as the loop progresses.
experiment_results = []
for dataset_label in test_options["dataset_label_list"]:
    for opds_sampling in test_options["opds_sampling_list"]:
        for interferometer_options in test_options["interferometer_options_list"]:
            for noise in test_options["noise_list"]:
                options = ExperimentOptionsSchema(
                    experiment_name=test_options["experiment_name"],
                    interferometer_options=interferometer_options,
                    noise=noise,
                    opds_sampling=opds_sampling,
                    dataset_label=dataset_label,
                    protocols_options=test_options["protocols_options"],
                )
                experiment = options.create_experiment()
                # `reconstruct` returns two parallel lists: the reconstructed
                # spectra and the selected lambdaa index of each protocol.
                spectrum_rec_list, argmin_rmse_list = experiment.reconstruct()
                experiment_results.append(
                    {
                        "options": options,
                        "metadata": options.metadata(),
                        "spectrum_rec_list": spectrum_rec_list,
                        "argmin_rmse_list": argmin_rmse_list,
                    }
                )


Loading Noise Levels options.


KeyboardInterrupt: 