In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
from pathlib import Path
from typing import Literal, cast

from shapely.geometry import Point
from scipy.signal import find_peaks
import numpy as np

from s2shores.bathy_debug.spatial_correlation_bathy_estimator_debug import SpatialCorrelationBathyEstimatorDebug
from s2shores.local_bathymetry.spatial_correlation_bathy_estimation import SpatialCorrelationBathyEstimation
from s2shores.waves_exceptions import WavesEstimationError, NotExploitableSinogram
from s2shores.image_processing.waves_radon import WavesRadon
from s2shores.generic_utils.image_utils import normalized_cross_correlation
from s2shores.generic_utils.signal_utils import find_period_from_zeros
from s2shores.bathy_physics import celerity_offshore, period_offshore, wavelength_offshore

from utils import initialize_sequential_run, build_dataset

In [5]:
# Scenario selection: which SWASH test case and which estimation method to debug.
test_case: Literal["7_4", "8_2"] = "8_2"
method: Literal["spatial_corr", "spatial_dft", "temporal_corr"] = "spatial_corr"

# All input files live under this root, resolved once to an absolute path.
base_path = Path("../TestsS2Shores").resolve()

# Input product, inversion configuration and debug-points file for the selected scenario.
product_path: Path = base_path.joinpath(
    "products",
    f"SWASH_{test_case}/testcase_{test_case}.tif",
)
config_path: Path = base_path.joinpath(
    f"reference_results/debug_pointswash_{method}/wave_bathy_inversion_config.yaml",
)
debug_file: Path = base_path.joinpath(
    f"debug_points/debug_points_SWASH_{test_case}.yaml",
)

# Point handed to the run initialisation and to the debug estimator in the cells below.
estimation_point = Point(451.0, 499.0)

In [17]:
# Set up a run for the selected product, configuration file and point using the helper
# imported from the local `utils` module. The four returned values are unpacked here and
# reused by the cells below (`config` can be edited like a dict, see the note that follows).
bathy_estimator, ortho_bathy_estimator, ortho_sequence, config = initialize_sequential_run(
    product_path=product_path,
    config_path=config_path,
    point=estimation_point,
)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 33621 instead


If you want to change any parameter of the configuration, modify the values of the object `config` as you would with a dict.  

Example:
```python
config["parameter"] = "new_value"
```

In [18]:
# Build the debug estimator for the chosen point.
# NOTE(review): the instance is bound to the name `self`, presumably so that the method
# bodies reproduced in the `else:` branches of the cells below can be run unchanged.
self = SpatialCorrelationBathyEstimatorDebug(
    estimation_point,
    ortho_sequence,
    bathy_estimator,
)

# Stop the notebook early when the estimator reports that it cannot run at this point.
if not self.can_estimate_bathy():
    raise WavesEstimationError("Cannot estimate bathy.")

## Preprocess images

Modified attributes:
- local_estimator.ortho_sequence.\<elements\>

In [None]:
from copy import deepcopy, copy

def debug(
        obj,
        func,
        *args,
        display_obj=lambda obj, res: display(res),
        **kwargs,
):
    copied_obj = deepcopy(obj)
    result = func(copied_obj, *args, **kwargs)
    display(display_obj(copied_obj, result))

    user_response = input("Do you want to rollback those modifications? [y]/n\n")

    match user_response.lower():
        case "" | "y" | "yes":
            return obj
        case "n" | "no":
            return func(obj, *args, **kwargs)
        case _:
            print("Input not understood. Rolling back...")


In [37]:
# Show the pixel array of the first image of the sequence, for comparison with the
# values displayed around the preprocessing step below.
self.ortho_sequence[0].pixels

array([[0.2707773 , 0.39541137, 0.52350542, ..., 1.23557564, 1.29505486,
        1.31401977],
       [0.27077882, 0.39541111, 0.52350444, ..., 1.23556214, 1.29504876,
        1.31402177],
       [0.27077367, 0.39540125, 0.52348993, ..., 1.23552036, 1.2950183 ,
        1.31400467],
       ...,
       [0.27065968, 0.39527219, 0.52331731, ..., 1.2352353 , 1.29468456,
        1.31365961],
       [0.27069273, 0.39531591, 0.52336067, ..., 1.23525244, 1.29470838,
        1.31369012],
       [0.27072097, 0.39535941, 0.52340656, ..., 1.23525693, 1.29471312,
        1.31369999]])

In [36]:
# Run `preprocess_images` through `debug`: it is first applied to a deep copy, the
# first image's pixels of that copy are displayed, and a prompt asks whether to roll
# back or to apply the step to `self`.
debug(
    self,
    SpatialCorrelationBathyEstimatorDebug.preprocess_images,
    display_obj=lambda obj, res: obj.ortho_sequence[0].pixels,
)


() {}
<s2shores.bathy_debug.spatial_correlation_bathy_estimator_debug.SpatialCorrelationBathyEstimatorDebug object at 0x7f4840168bf0>


array([[0.2707773 , 0.39541137, 0.52350542, ..., 1.23557564, 1.29505486,
        1.31401977],
       [0.27077882, 0.39541111, 0.52350444, ..., 1.23556214, 1.29504876,
        1.31402177],
       [0.27077367, 0.39540125, 0.52348993, ..., 1.23552036, 1.2950183 ,
        1.31400467],
       ...,
       [0.27065968, 0.39527219, 0.52331731, ..., 1.2352353 , 1.29468456,
        1.31365961],
       [0.27069273, 0.39531591, 0.52336067, ..., 1.23525244, 1.29470838,
        1.31369012],
       [0.27072097, 0.39535941, 0.52340656, ..., 1.23525693, 1.29471312,
        1.31369999]])

<s2shores.bathy_debug.spatial_correlation_bathy_estimator_debug.SpatialCorrelationBathyEstimatorDebug at 0x7f48401b39e0>

In [13]:
# Show the first image's pixels again to check whether `self` was modified by the
# preprocessing step above (it is unchanged when the modifications were rolled back).
self.ortho_sequence[0].pixels

array([[0.2707773 , 0.39541137, 0.52350542, ..., 1.23557564, 1.29505486,
        1.31401977],
       [0.27077882, 0.39541111, 0.52350444, ..., 1.23556214, 1.29504876,
        1.31402177],
       [0.27077367, 0.39540125, 0.52348993, ..., 1.23552036, 1.2950183 ,
        1.31400467],
       ...,
       [0.27065968, 0.39527219, 0.52331731, ..., 1.2352353 , 1.29468456,
        1.31365961],
       [0.27069273, 0.39531591, 0.52336067, ..., 1.23525244, 1.29470838,
        1.31369012],
       [0.27072097, 0.39535941, 0.52340656, ..., 1.23525693, 1.29471312,
        1.31369999]])

## Find direction

Modified attributes:
- None

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    estimated_direction = self.find_direction()
else:
    # Explicit variant: build a WavesRadon on the first image over the selected
    # directions and keep the first element returned by get_direction_maximum_variance().
    first_image = self.ortho_sequence[0]
    direction_finder = WavesRadon(first_image, self.selected_directions)
    estimated_direction, _unused = direction_finder.get_direction_maximum_variance()

## Compute radon transforms

New elements:
- local_estimator.radon_transforms

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    self.compute_radon_transforms(estimated_direction)
else:
    # Explicit variant: one WavesRadon per image, restricted to the estimated direction,
    # augmented with the estimator's factor and stored in `self.radon_transforms`.
    for frame in self.ortho_sequence:
        single_direction = np.array([estimated_direction])
        frame_radon = WavesRadon(frame, single_direction)
        augmented_radon = frame_radon.radon_augmentation(self.radon_augmentation_factor)
        self.radon_transforms.append(augmented_radon)

## Compute spatial correlation

New elements:
- local_estimator.sinograms

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    correlation_signal = self.compute_spatial_correlation(estimated_direction)
else:
    # Take the sinogram of each radon transform along the estimated direction, multiply
    # its values by its variance and store it in `self.sinograms`.
    for radon_transform in self.radon_transforms:
        tmp_wavessinogram = radon_transform[estimated_direction]
        tmp_wavessinogram.values *= tmp_wavessinogram.variance
        self.sinograms.append(tmp_wavessinogram)
    sinogram_1 = self.sinograms[0].values
    # TODO: should be independent from 0/1 (for multiple pairs of frames)
    sinogram_2 = self.sinograms[1].values
    correl_mode = self.local_estimator_params['CORRELATION_MODE']
    # Chain of normalized cross-correlations: the two sinograms together, that result
    # with itself, the outcome with each sinogram, and finally those two results together.
    corr_init = normalized_cross_correlation(sinogram_1, sinogram_2, correl_mode)
    corr_init_ac = normalized_cross_correlation(corr_init, corr_init, correl_mode)
    corr_1 = normalized_cross_correlation(corr_init_ac, sinogram_1, correl_mode)
    corr_2 = normalized_cross_correlation(corr_init_ac, sinogram_2, correl_mode)
    correlation_signal = normalized_cross_correlation(corr_1, corr_2, correl_mode)

## Compute wavelength

Modified attributes:
- None

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    wavelength = self.compute_wavelength(correlation_signal)
else:
    # Explicit variant: derive a minimum period (in samples of the augmented resolution)
    # from the offshore wavelength at the minimum wave period, then measure the period
    # of the correlation signal and scale it back by the augmented resolution.
    shortest_wavelength = wavelength_offshore(
        self.global_estimator.waves_period_min,
        self.gravity,
    )
    min_period_samples = int(shortest_wavelength / self.augmented_resolution)
    try:
        measured_period, _unused = find_period_from_zeros(correlation_signal, min_period_samples)
        wavelength = measured_period * self.augmented_resolution
    except ValueError as period_error:
        raise NotExploitableSinogram('Wave length can not be computed from sinogram') from period_error

## Compute delta position

Modified attributes:
- None

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    delta_position = self.compute_delta_position(correlation_signal, wavelength)
else:
    # Explicit variant: locate the peaks of the correlation signal, keep those whose
    # distance to the signal centre lies within the allowed shift range, and convert
    # the strongest of them into a distance.
    peaks_pos, _ = find_peaks(correlation_signal)
    if peaks_pos.size == 0:
        raise WavesEstimationError('Unable to find any directional peak')
    # Centre index of the correlation signal, used as the reference position.
    argmax_ac = len(correlation_signal) // 2
    relative_distance = (peaks_pos - argmax_ac) * self.augmented_resolution

    # Symmetric shift bounds derived from the offshore celerity at the maximum wave
    # period multiplied by the propagation duration.
    celerity_offshore_max = celerity_offshore(
        self.global_estimator.waves_period_max,
        self.gravity,
    )
    spatial_shift_offshore_max = celerity_offshore_max * self.propagation_duration
    spatial_shift_min = min(-spatial_shift_offshore_max, spatial_shift_offshore_max)
    spatial_shift_max = -spatial_shift_min

    stroboscopic_factor_offshore = self.propagation_duration / period_offshore(
        1 / wavelength, self.gravity)
    
    if abs(stroboscopic_factor_offshore) >= 1:
        # unused for s2
        print('test stroboscopie vrai')
        # NOTE(review): spatial_shift_min / spatial_shift_max were computed above and are
        # not recomputed after this reassignment, so it does not affect the filter below.
        spatial_shift_offshore_max = (
            self.local_estimator_params['PEAK_POSITION_MAX_FACTOR']
            * stroboscopic_factor_offshore
            * wavelength
        )

    # Keep the peaks whose relative distance falls in [spatial_shift_min, spatial_shift_max).
    pt_in_range = peaks_pos[np.where(
        (relative_distance >= spatial_shift_min)
        & (relative_distance < spatial_shift_max)
    )]
    if pt_in_range.size == 0:
        raise WavesEstimationError('Unable to find any directional peak')
    # Among the retained peaks, take the one with the highest correlation value.
    argmax = pt_in_range[correlation_signal[pt_in_range].argmax()]
    delta_position = (argmax - argmax_ac) * self.augmented_resolution


## Save wave field estimation

New elements:
- local_estimator.bathymetry_estimations

In [None]:
# Toggle: `True` delegates to the estimator method, `False` runs the explicit steps.
if True:
    self.save_wave_field_estimation(estimated_direction, wavelength, delta_position)
else:
    # Explicit variant: create the estimation for this direction and wavelength, attach
    # the measured delta position and append it to the estimator's list of estimations.
    new_estimation = self.create_bathymetry_estimation(estimated_direction, wavelength)
    new_estimation = cast(SpatialCorrelationBathyEstimation, new_estimation)
    new_estimation.delta_position = delta_position
    self.bathymetry_estimations.append(new_estimation)


In [None]:
# Query `is_physical()` on the first estimation stored by the previous step.
self.bathymetry_estimations[0].is_physical()

In [None]:
# Assemble the output dataset from the two global estimators and the local debug
# estimator with the helper imported from the local `utils` module, then display it.
dataset = build_dataset(
    bathy_estimator,
    ortho_bathy_estimator,
    self,
)
dataset

## TODO

- Créer une fonction pour les premières étapes ✅
- Documenter les modifications des attributs d'instance de chaque méthode du run ✅
- Eclater la cellule de run ✅
- rentrer dans chaque étape ✅
- flag _DEFAULT pour chaque étape ou custom function ✅
- faire la même chose pour les 3 méthodes



Option: Configuration pydantic ✅