In [None]:
from pathlib import Path
from typing import Literal, cast

from shapely.geometry import Point
from scipy.signal import find_peaks
import numpy as np

from s2shores.bathy_debug.spatial_correlation_bathy_estimator_debug import SpatialCorrelationBathyEstimatorDebug
from s2shores.local_bathymetry.spatial_correlation_bathy_estimation import SpatialCorrelationBathyEstimation
from s2shores.waves_exceptions import WavesEstimationError, NotExploitableSinogram
from s2shores.image_processing.waves_radon import WavesRadon
from s2shores.generic_utils.image_utils import normalized_cross_correlation
from s2shores.generic_utils.signal_utils import find_period_from_zeros
from s2shores.bathy_physics import celerity_offshore, period_offshore, wavelength_offshore

from utils import initialize_sequential_run, build_dataset

In [None]:
# Selection of the SWASH test case and of the estimation method to debug.
test_case: Literal["7_4", "8_2"] = "8_2"
method: Literal["spatial_corr", "spatial_dft", "temporal_corr"] = "spatial_corr"

# Root of the test data, resolved to an absolute path.
base_path = Path("../TestsS2Shores").resolve()

# Input files derived from the selection above.
product_path: Path = base_path.joinpath(
    "products",
    f"SWASH_{test_case}/testcase_{test_case}.tif",
)
config_path: Path = base_path.joinpath(
    f"reference_results/debug_pointswash_{method}/wave_bathy_inversion_config.yaml",
)
debug_file: Path = base_path.joinpath(
    f"debug_points/debug_points_SWASH_{test_case}.yaml",
)

# Location at which the bathymetry is estimated; it is handed to
# initialize_sequential_run and to the local estimator in the cells below.
# NOTE(review): the coordinates are presumably expressed in the product's
# spatial reference system — confirm against the test product.
estimation_point = Point(451.0, 499.0)

In [None]:
# Prepare the run for the single estimation point and unpack the four objects
# returned by the helper. bathy_estimator, ortho_bathy_estimator and
# ortho_sequence are consumed by later cells; config is the configuration
# object that can be edited like a dict before going further.
bathy_estimator, ortho_bathy_estimator, ortho_sequence, config = initialize_sequential_run(
    product_path=product_path,
    config_path=config_path,
    point=estimation_point,
)

If you want to change any parameter of the configuration, modify the values of the object `config` as you would with a dict.  

Example:
```python
config["parameter"] = "new_value"
```

In [None]:
# Build the debug variant of the spatial-correlation local estimator from the
# estimation point, the image sequence and the estimator returned by the
# initialisation cell.
local_estimator = SpatialCorrelationBathyEstimatorDebug(
    estimation_point,
    ortho_sequence,
    bathy_estimator,
)

# Stop immediately when the estimator reports that no bathymetry can be
# estimated, so that the following cells never run on an unusable estimator.
if not local_estimator.can_estimate_bathy():
    raise WavesEstimationError("Cannot estimate bathy.")

## Preprocess images

Modified attributes:
- local_estimator.ortho_sequence.\<elements\>

In [None]:
if True:
    # Run the estimator's preprocessing filters on every image of the sequence
    # and store the filtered pixels back on the image itself.
    # NOTE(review): the pixels are overwritten, so running this cell a second
    # time presumably filters already-filtered images — re-run the
    # initialisation cells first to start again from the original pixels.
    for image in local_estimator.ortho_sequence:
        image.pixels = image.apply_filters(local_estimator.preprocessing_filters).pixels
else:
    # Write your custom code here.
    ...


## Find direction

Modified attributes:
- None

In [None]:
if True:
    # The direction is searched on the first image of the sequence only, over
    # the directions selected by the estimator.
    first_image = local_estimator.ortho_sequence[0]
    direction_finder = WavesRadon(first_image, local_estimator.selected_directions)

    # Only the first element of the returned pair (the direction) is kept.
    estimated_direction, _ = direction_finder.get_direction_maximum_variance()
else:
    # Write your custom code here.
    ...


## Compute radon transforms

New elements:
- local_estimator.radon_transforms

In [None]:
if True:
    # Start from an empty container so that re-running this cell (for instance
    # after changing the estimated direction) replaces the transforms instead
    # of appending new ones behind the stale ones.
    # NOTE(review): assumes radon_transforms is a list (only `.append` is
    # visible in this notebook) — confirm.
    local_estimator.radon_transforms.clear()

    for image in local_estimator.ortho_sequence:
        # Radon transform of the image restricted to the single estimated
        # direction, then augmented by the estimator's augmentation factor.
        radon_transform = WavesRadon(image, np.array([estimated_direction]))
        radon_transform_augmented = radon_transform.radon_augmentation(
            local_estimator.radon_augmentation_factor)
        local_estimator.radon_transforms.append(radon_transform_augmented)
else:
    # Write your custom code here.
    ...


## Compute spatial correlation

New elements:
- local_estimator.sinograms

In [None]:
if True:
    # Start from an empty container: without this, re-running the cell appends
    # new sinograms while sinograms[0] and sinograms[1] below keep pointing at
    # the ones stored by the first run.
    # NOTE(review): assumes sinograms is a list (only `.append` and indexing
    # are visible in this notebook) — confirm.
    local_estimator.sinograms.clear()

    for radon_transform in local_estimator.radon_transforms:
        # Sinogram of the estimated direction, weighted by its own variance.
        # NOTE(review): `values` is scaled in place; if the transform hands
        # back the same sinogram object on every access (not visible here),
        # re-running this cell alone scales it again — re-run the radon
        # transform cell first.
        tmp_wavessinogram = radon_transform[estimated_direction]
        tmp_wavessinogram.values *= tmp_wavessinogram.variance
        local_estimator.sinograms.append(tmp_wavessinogram)

    sinogram_1 = local_estimator.sinograms[0].values
    # TODO: should be independent from 0/1 (for multiple pairs of frames)
    sinogram_2 = local_estimator.sinograms[1].values
    correl_mode = local_estimator.local_estimator_params['CORRELATION_MODE']

    # Cross-correlation of the two sinograms, then its correlation with itself.
    corr_init = normalized_cross_correlation(sinogram_1, sinogram_2, correl_mode)
    corr_init_ac = normalized_cross_correlation(corr_init, corr_init, correl_mode)
    # That signal is correlated with each sinogram, and the two results are
    # finally correlated together to give the signal used by the next cells.
    corr_1 = normalized_cross_correlation(corr_init_ac, sinogram_1, correl_mode)
    corr_2 = normalized_cross_correlation(corr_init_ac, sinogram_2, correl_mode)
    correlation_signal = normalized_cross_correlation(corr_1, corr_2, correl_mode)
else:
    # Write your custom code here.
    ...


## Compute wavelength

Modified attributes:
- None

In [None]:
if True:
    # Offshore wavelength associated with the minimum wave period of the
    # global estimator.
    shortest_wavelength = wavelength_offshore(
        local_estimator.global_estimator.waves_period_min,
        local_estimator.gravity,
    )

    # Same quantity divided by the augmented resolution and truncated to an
    # integer, as expected by the zero-based period search.
    resolution = local_estimator.augmented_resolution
    min_period_unitless = int(shortest_wavelength / resolution)

    try:
        period, _ = find_period_from_zeros(correlation_signal, min_period_unitless)
        # Scale the detected period by the resolution to get the wavelength.
        wavelength = period * resolution
    except ValueError as error:
        # A ValueError here is reported as a non exploitable sinogram.
        raise NotExploitableSinogram('Wave length can not be computed from sinogram') from error
else:
    # Write your custom code here.
    ...


## Compute delta position

Modified attributes:
- None

In [None]:
if True:
    # Local maxima of the correlation signal are the candidate positions.
    peaks_pos, _ = find_peaks(correlation_signal)
    if peaks_pos.size == 0:
        raise WavesEstimationError('Unable to find any directional peak')

    # Offset of every peak from the middle sample of the signal, scaled by the
    # augmented resolution.
    center_index = len(correlation_signal) // 2
    resolution = local_estimator.augmented_resolution
    relative_distance = (peaks_pos - center_index) * resolution

    # Largest accepted shift: offshore celerity at the maximum wave period
    # multiplied by the propagation duration. The accepted range is symmetric
    # around zero whatever the sign of that product.
    spatial_shift_offshore_max = celerity_offshore(
        local_estimator.global_estimator.waves_period_max,
        local_estimator.gravity,
    ) * local_estimator.propagation_duration
    shift_bound = abs(spatial_shift_offshore_max)

    offshore_period = period_offshore(1 / wavelength, local_estimator.gravity)
    stroboscopic_factor_offshore = local_estimator.propagation_duration / offshore_period

    if abs(stroboscopic_factor_offshore) >= 1:
        # unused for s2
        print('test stroboscopie vrai')
        # NOTE(review): this new value is not read again in this cell — the
        # accepted range was already derived above. Kept unchanged on purpose.
        spatial_shift_offshore_max = (
            local_estimator.local_estimator_params['PEAK_POSITION_MAX_FACTOR']
            * stroboscopic_factor_offshore
            * wavelength
        )

    # Keep the peaks whose offset lies in [-shift_bound, shift_bound).
    in_range = (relative_distance >= -shift_bound) & (relative_distance < shift_bound)
    pt_in_range = peaks_pos[in_range]
    if pt_in_range.size == 0:
        raise WavesEstimationError('Unable to find any directional peak')

    # Among the remaining peaks, take the one with the highest correlation and
    # express its offset from the middle sample with the same scaling.
    best_peak = pt_in_range[correlation_signal[pt_in_range].argmax()]
    delta_position = (best_peak - center_index) * resolution
else:
    # Write your custom code here.
    ...


## Save wave field estimation

New elements:
- local_estimator.bathymetry_estimations

In [None]:
if True:
    # Start from an empty container so that re-running this cell stores the
    # freshly computed estimation in first position; the next cell reads
    # bathymetry_estimations[0] and would otherwise keep showing the result of
    # the first run.
    # NOTE(review): assumes bathymetry_estimations is a list or a list subclass
    # (only `.append` and `[0]` are visible in this notebook) — confirm.
    local_estimator.bathymetry_estimations.clear()

    # Create the estimation for the estimated direction and wavelength; the
    # cast only informs type checkers of the concrete estimation class.
    bathymetry_estimation = cast(
        SpatialCorrelationBathyEstimation,
        local_estimator.create_bathymetry_estimation(estimated_direction, wavelength),
    )
    bathymetry_estimation.delta_position = delta_position
    local_estimator.bathymetry_estimations.append(bathymetry_estimation)
else:
    # Write your custom code here.
    ...


In [None]:
# Ask the first stored estimation whether it is physical; being the last
# expression of the cell, the result is displayed as the cell output.
local_estimator.bathymetry_estimations[0].is_physical()

In [None]:
# Assemble the output dataset from the two estimators returned by the
# initialisation cell and from the local estimator filled in by the cells above.
dataset = build_dataset(
    bathy_estimator,
    ortho_bathy_estimator,
    local_estimator,
)
# Last expression of the cell: shows the dataset as the cell output.
dataset

## TODO

- Créer une fonction pour les premières étapes ✅
- Documenter les modifications des attributs d'instance de chaque méthode du run
- Eclater la cellule de run ✅
- rentrer dans chaque étape ✅
- flag _DEFAULT pour chaque étape ou custom function ✅
- faire la même chose pour les 3 méthodes



Option: Configuration pydantic