In [None]:
from PS_S2_fusion import (
    rasterio,
    NDVIDataLoader,
    NDVIProcessor,
    random,
    Dict,
    NDVIImage,
    np,
    LinearRegression,
    RandomForestRegressor,
    NDVIVisualizer,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import os
from pathlib import Path

In [None]:
# Global flag: Set to True to include before/after Sentinel-2 images as features, False to exclude
INCLUDE_BEFORE_AFTER = False

# Global input and output directories.
# Defaults are the original workshop locations; set the NDVI_INPUT_DIR /
# NDVI_OUTPUT_DIR environment variables to run elsewhere without editing code
# (e.g. NDVI_INPUT_DIR=./input/13 NDVI_OUTPUT_DIR=./output).
input_dir = Path(
    os.environ.get("NDVI_INPUT_DIR", "D:/KiKompAg/Workshop/Data/clip_for_workshop/13")
)
output_dir = Path(
    os.environ.get("NDVI_OUTPUT_DIR", "D:/KiKompAg/Workshop/Data/output")
)

In [None]:
def save_tif(data, output_file, crs, transform, nodata=None):
    """Save a 2D (single-band) or 3D (bands-first) array as a GeoTIFF.

    Args:
        data (array-like): Raster values. A 2D array is written as band 1;
            a 3D array is interpreted as ``(bands, rows, cols)``.
        output_file (str or pathlib.Path): Path of the output GeoTIFF file.
        crs (str or rasterio.crs.CRS): The Coordinate Reference System,
            e.g. an EPSG code string ('EPSG:4326') or a rasterio.crs.CRS.
        transform (affine.Affine): The affine transformation for the raster.
        nodata (number, optional): Value recorded as the GeoTIFF nodata
            value (e.g. ``np.nan`` for float NDVI). Omitted when None.

    Returns:
        None

    Raises:
        ValueError: If ``data`` is not 2D or 3D.
    """
    # Accept any array-like (lists etc.), as the docstring promises.
    data = np.asarray(data)

    if data.ndim == 2:
        count = 1
        height, width = data.shape
    elif data.ndim == 3:
        count, height, width = data.shape  # bands are the first dimension
    else:
        raise ValueError("Data must be 2D or 3D.")

    profile = {
        "driver": "GTiff",
        "height": height,
        "width": width,
        "count": count,
        "dtype": data.dtype.name,  # keep the input data type
        "crs": crs,
        "transform": transform,
    }
    if nodata is not None:
        profile["nodata"] = nodata

    with rasterio.open(output_file, "w", **profile) as dst:
        if data.ndim == 2:
            dst.write(data, 1)
        else:
            # Write the whole (bands, rows, cols) stack at once. This also
            # handles a single-band 3D stack, which must not be passed to
            # write(..., 1) (that form expects a 2D array).
            dst.write(data)

In [None]:
# 1. Load NDVI data
# Shared loader/processor instances; `processor` is also used by later cells.
data_loader = NDVIDataLoader()
processor = NDVIProcessor()

# Field shapefiles used to mask each season's images.
shapefile_2021 = input_dir / "2021" / "shp" / "2021_fid13_maize_subset.shp"
shapefile_2022 = input_dir / "2022" / "shp" / "2022_fid13_maize_subset.shp"

# NDVI images per sensor and season (DOY-keyed dictionaries).
sentinel_2_images_2021 = data_loader.load_ndvi_images(
    input_dir / "2021" / "ndvi" / "S2", "S2*.tif"
)
planet_scope_images_2021 = data_loader.load_ndvi_images(
    input_dir / "2021" / "ndvi" / "PS", "PS*.tif"
)
sentinel_2_images_2022 = data_loader.load_ndvi_images(
    input_dir / "2022" / "ndvi" / "S2", "S2*.tif"
)
planet_scope_images_2022 = data_loader.load_ndvi_images(
    input_dir / "2022" / "ndvi" / "PS", "PS*.tif"
)

In [None]:
def find_matching_image_pairs(sentinel_images, planet_images, year, max_offset=1):
    """Pair each PlanetScope image with a near-simultaneous Sentinel-2 image.

    For every PlanetScope DOY, the closest Sentinel-2 DOY within
    ``±max_offset`` days is chosen. An exact same-day match always wins, and
    on a tie the earlier date wins. The nearest Sentinel-2 acquisitions
    strictly before and after that matching date are recorded too.

    A PlanetScope date is skipped when no Sentinel-2 image lies within the
    window. It is also skipped when its match is the first or last Sentinel-2
    date, because no before/after neighbour exists (downstream code indexes
    both).

    Args:
        sentinel_images: Dictionary with DOY as key and Sentinel-2 NDVI data as value.
        planet_images: Dictionary with DOY as key and PlanetScope NDVI data as value.
        year: Year stored in every result entry.
        max_offset: Maximum allowed |DOY difference| in days (default 1).

    Returns:
        A list of dictionaries, one per match, in ``planet_images`` order:
            {'year': ...,
             'planet_doy': ...,
             'sentinel_matching_doy': ...,
             'sentinel_before_doy': ...,
             'sentinel_after_doy': ...}
    """
    # Candidate offsets ordered by distance: 0, -1, +1, -2, +2, ...
    offsets = sorted(range(-max_offset, max_offset + 1), key=lambda o: (abs(o), o))

    matching_pairs = []
    for planet_doy in planet_images:
        matching_sentinel_doy = next(
            (planet_doy + o for o in offsets if planet_doy + o in sentinel_images),
            None,
        )
        if matching_sentinel_doy is None:
            continue  # no Sentinel-2 image close enough

        earlier = [doy for doy in sentinel_images if doy < matching_sentinel_doy]
        later = [doy for doy in sentinel_images if doy > matching_sentinel_doy]
        if not earlier or not later:
            # Edge of the Sentinel-2 series: no before/after neighbour.
            continue

        matching_pairs.append(
            {
                "year": year,
                "planet_doy": planet_doy,
                "sentinel_matching_doy": matching_sentinel_doy,
                "sentinel_before_doy": max(earlier),
                "sentinel_after_doy": min(later),
            }
        )
    return matching_pairs

In [None]:
# Pair PlanetScope and Sentinel-2 acquisitions separately for each season.
matching_pairs_2021 = find_matching_image_pairs(
    sentinel_2_images_2021, planet_scope_images_2021, 2021
)
matching_pairs_2022 = find_matching_image_pairs(
    sentinel_2_images_2022, planet_scope_images_2022, 2022
)

In [None]:
# Hold out a few matched pairs (from both years) for validation; the rest are
# used for training. A seeded RNG keeps the split, and therefore every metric
# below, reproducible across kernel restarts.
N_VALIDATION_PAIRS = 5
SPLIT_SEED = 42

matching_pairs = matching_pairs_2021 + matching_pairs_2022
if len(matching_pairs) <= N_VALIDATION_PAIRS:
    raise ValueError(
        f"Need more than {N_VALIDATION_PAIRS} matched pairs to split, "
        f"got {len(matching_pairs)}."
    )
validation_pairs = random.Random(SPLIT_SEED).sample(matching_pairs, N_VALIDATION_PAIRS)

training_pairs = [pair for pair in matching_pairs if pair not in validation_pairs]

In [None]:
print(str(matching_pairs_2021))  # inspect which dates were paired in 2021

In [None]:
def create_feature_target_pair(
    match,
    sentinel_images: Dict[int, NDVIImage],
    planet_images: Dict[int, NDVIImage],
    shapefile_path: Path,
):
    """Build the per-pixel feature matrix and target vector for one matched pair.

    The PlanetScope image is resampled onto the grid of the matching Sentinel-2
    image. All involved images are then masked with ``shapefile_path``, and
    only pixels outside the invalid mask are kept.

    Note:
        The Sentinel-2 ``NDVIImage`` objects looked up in ``sentinel_images``
        get their ``.ndvi`` attribute replaced by the masked array, so the
        caller's dictionary is modified in place. Later cells read these
        masked arrays.

    Args:
        match: One entry from ``find_matching_image_pairs`` (keys ``year``,
            ``planet_doy``, ``sentinel_matching_doy``, ``sentinel_before_doy``,
            ``sentinel_after_doy``).
        sentinel_images: Dictionary with DOY as key and Sentinel-2 NDVIImage.
        planet_images: Dictionary with DOY as key and PlanetScope NDVIImage.
        shapefile_path: Shapefile passed to ``mask_array_with_shapefile``.

    Returns:
        A tuple ``(features, targets, invalid_mask)``:
            - features: 2D NumPy array, one row per valid pixel. Columns are
              [planet NDVI, planet DOY, year]. With ``INCLUDE_BEFORE_AFTER``
              they are [planet NDVI, S2-before NDVI, S2-after NDVI,
              planet DOY, before DOY, after DOY, year].
            - targets: 1D NumPy array of the matching Sentinel-2 NDVI.
            - invalid_mask: the mask used for pixel selection (True pixels
              are excluded; callers write predictions to ``~invalid_mask``).
    """
    processor = NDVIProcessor()

    year = match["year"]
    planet_doy = match["planet_doy"]
    sentinel_matching_doy = match["sentinel_matching_doy"]
    sentinel_before_doy = match["sentinel_before_doy"]
    sentinel_after_doy = match["sentinel_after_doy"]

    sentinel_matching_image = sentinel_images[sentinel_matching_doy]
    sentinel_before_image = sentinel_images[sentinel_before_doy]
    sentinel_after_image = sentinel_images[sentinel_after_doy]

    # Resample PlanetScope onto the grid of the matching Sentinel-2 image.
    planet_image = processor.resample_planet_to_sentinel(
        sentinel_matching_image, planet_images[planet_doy]
    )

    # Mask every image with the field shapefile (replaces .ndvi in place).
    for image in (
        planet_image,
        sentinel_matching_image,
        sentinel_before_image,
        sentinel_after_image,
    ):
        image.ndvi = processor.mask_array_with_shapefile(
            image.ndvi, image.transform, shapefile_path
        )

    invalid_mask = processor.get_invalid_mask(
        sentinel_ndvi_data=sentinel_matching_image.ndvi,
        planet_ndvi_data=planet_image.ndvi,
    )
    if INCLUDE_BEFORE_AFTER:
        # The before/after images become feature columns, so their invalid
        # pixels must be excluded as well. Previously this union was computed
        # and then discarded.
        invalid_mask = np.logical_or(
            invalid_mask,
            processor.get_invalid_mask(
                sentinel_before_image.ndvi, sentinel_after_image.ndvi
            ),
        )

    sentinel_matching_data = processor.get_preprocessed_ndiv_data(
        sentinel_matching_image, invalid_mask
    ).flatten()
    if INCLUDE_BEFORE_AFTER:
        sentinel_before_data = processor.get_preprocessed_ndiv_data(
            sentinel_before_image, invalid_mask
        ).flatten()
        sentinel_after_data = processor.get_preprocessed_ndiv_data(
            sentinel_after_image, invalid_mask
        ).flatten()
    planet_data = processor.get_preprocessed_ndiv_data(
        planet_image, invalid_mask
    ).flatten()

    num_valid_pixels = len(sentinel_matching_data)

    # One row per valid pixel; date columns are constant for this pair.
    if INCLUDE_BEFORE_AFTER:
        columns = (
            planet_data,
            sentinel_before_data,
            sentinel_after_data,
            np.full(num_valid_pixels, planet_doy),
            np.full(num_valid_pixels, sentinel_before_doy),
            np.full(num_valid_pixels, sentinel_after_doy),
            np.full(num_valid_pixels, year),
        )
    else:
        columns = (
            planet_data,
            np.full(num_valid_pixels, planet_doy),
            np.full(num_valid_pixels, year),
        )
    features = np.column_stack(columns)

    targets = sentinel_matching_data
    return features, targets, invalid_mask

In [None]:
# Build the training matrix: one feature block per training pair, stacked row-wise.
all_features, all_targets, all_masks = [], [], []
for match in training_pairs:
    if match["year"] != 2022:
        year_args = (sentinel_2_images_2021, planet_scope_images_2021, shapefile_2021)
    else:
        year_args = (sentinel_2_images_2022, planet_scope_images_2022, shapefile_2022)

    X, y, invalid_mask = create_feature_target_pair(match, *year_args)

    all_features.append(X)
    all_targets.append(y)
    all_masks.append(invalid_mask)

# Stack the per-pair blocks into one design matrix / target vector.
X = np.concatenate(all_features, axis=0)
y = np.concatenate(all_targets)

print(X.shape, y.shape)

In [None]:
# Quick hold-out score on a random 80/20 pixel split (fixed seed).
# NOTE(review): neighbouring pixels are correlated, so a pixel-level split is
# likely optimistic; the pair-level validation further down is stricter.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

lr_model = LinearRegression()
lr_model.fit(X_train, y_train)

predictions = lr_model.predict(X_test)
mse = mean_squared_error(y_test, predictions)
print(mse)

In [None]:
# Same 80/20 pixel split as for the linear model (identical random_state).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

rf_model = RandomForestRegressor(
    n_estimators=40,
    max_depth=15,
    random_state=42,
    n_jobs=8,
)
rf_model.fit(X_train, y_train)

predictions = rf_model.predict(X_test)
mse = mean_squared_error(y_test, predictions)
print(mse)

In [None]:
import copy

# Validation of the linear model on the held-out pairs.
# Features are built exactly as for training. The model's synthetic image and
# the plainly resampled PlanetScope image are then compared against the real
# Sentinel-2 image.
year_inputs = {
    2021: (sentinel_2_images_2021, planet_scope_images_2021, shapefile_2021),
    2022: (sentinel_2_images_2022, planet_scope_images_2022, shapefile_2022),
}

# Per-pair results use dedicated names, so the training matrix X / y built
# earlier is not overwritten by this cell.
all_features = []
all_targets = []
all_masks = []
for match in validation_pairs:
    features_val, targets_val, mask_val = create_feature_target_pair(
        match, *year_inputs[match["year"]]
    )
    all_features.append(features_val)
    all_targets.append(targets_val)
    all_masks.append(mask_val)

# Zoom window (rows, cols) for the close-up comparison plot.
crop_window = (slice(300, 500), slice(200, 400))

mse_dict_synth = {}
mse_dict_resampled = {}
mean_ndiv_dict_synth = {}
mean_ndiv_dict_resampled = {}
mean_ndiv_dict_real = {}
visualizer = NDVIVisualizer()

# NOTE(review): the result dicts are keyed by DOY only, so a 2021 and a 2022
# validation pair on the same DOY overwrite each other (also in the totals).
for match, features_val, targets_val, mask_val in zip(
    validation_pairs, all_features, all_targets, all_masks
):
    year = match["year"]
    planet_doy = match["planet_doy"]
    sentinel_2_images, planet_scope_images, shapefile = year_inputs[year]
    sentinel_image = sentinel_2_images[match["sentinel_matching_doy"]]

    predictions = lr_model.predict(features_val)

    # Paint the predictions into a copy of the real image; pixels flagged in
    # the invalid mask keep their original values.
    synth_image = copy.deepcopy(sentinel_image)
    synth_image.ndvi[~mask_val] = predictions
    # Optional export:
    # save_tif(synth_image.ndvi, output_dir / f"{year}-{planet_doy}-linear.tif",
    #          synth_image.crs, synth_image.transform)

    # Baseline: PlanetScope resampled to the Sentinel-2 grid, masked the same way.
    resampled_image = processor.resample_planet_to_sentinel(
        sentinel_image, planet_scope_images[planet_doy]
    )
    resampled_image.ndvi = processor.mask_array_with_shapefile(
        resampled_image.ndvi, resampled_image.transform, shapefile
    )

    # Full scene first, then the zoomed window.
    visualizer.visualize_ndvi_images(
        sentinel_image.ndvi,
        synth_image.ndvi,
        resampled_image.ndvi,
        planet_doy,
        year,
    )
    visualizer.visualize_ndvi_images(
        sentinel_image.ndvi[crop_window],
        synth_image.ndvi[crop_window],
        resampled_image.ndvi[crop_window],
        planet_doy,
        year,
    )

    mse_synth = mean_squared_error(targets_val, predictions)
    print(f"Mean Squared Error for Synthetic Planetscope DOY {planet_doy}: {mse_synth}")

    mse_resampled = mean_squared_error(
        targets_val,
        processor.get_preprocessed_ndiv_data(resampled_image, mask_val),
    )
    print(
        f"Mean Squared Error for Resampled Planetscope DOY {planet_doy}: {mse_resampled}"
    )

    mse_dict_synth[planet_doy] = mse_synth
    mse_dict_resampled[planet_doy] = mse_resampled
    # NOTE(review): the synthetic mean covers only the valid pixels, while the
    # resampled/real means cover every non-NaN pixel of the scene.
    mean_ndiv_dict_synth[planet_doy] = np.mean(predictions)
    mean_ndiv_dict_resampled[planet_doy] = np.nanmean(resampled_image.ndvi)
    mean_ndiv_dict_real[planet_doy] = np.nanmean(sentinel_image.ndvi)


visualizer.plot_time_series(
    mse_dict_synth,
    mse_dict_resampled,
    mean_ndiv_dict_synth,
    mean_ndiv_dict_resampled,
    mean_ndiv_dict_real,
)

In [None]:
total_resampled = sum(mse_dict_resampled.values())
total_synth = sum(mse_dict_synth.values())
print(f"Total Error Resampled: {total_resampled}")
print(f"Total Error Synthetic: {total_synth}")

In [None]:
import copy

# Validation of the random-forest model on the held-out pairs.
# Features are built exactly as for training. The model's synthetic image and
# the plainly resampled PlanetScope image are then compared against the real
# Sentinel-2 image.
year_inputs = {
    2021: (sentinel_2_images_2021, planet_scope_images_2021, shapefile_2021),
    2022: (sentinel_2_images_2022, planet_scope_images_2022, shapefile_2022),
}

# Per-pair results use dedicated names, so the training matrix X / y built
# earlier is not overwritten by this cell.
all_features = []
all_targets = []
all_masks = []
for match in validation_pairs:
    features_val, targets_val, mask_val = create_feature_target_pair(
        match, *year_inputs[match["year"]]
    )
    all_features.append(features_val)
    all_targets.append(targets_val)
    all_masks.append(mask_val)

# Zoom window (rows, cols) for the close-up comparison plot.
crop_window = (slice(300, 500), slice(200, 400))

mse_dict_synth = {}
mse_dict_resampled = {}
mean_ndiv_dict_synth = {}
mean_ndiv_dict_resampled = {}
mean_ndiv_dict_real = {}
visualizer = NDVIVisualizer()

# NOTE(review): the result dicts are keyed by DOY only, so a 2021 and a 2022
# validation pair on the same DOY overwrite each other (also in the totals).
for match, features_val, targets_val, mask_val in zip(
    validation_pairs, all_features, all_targets, all_masks
):
    year = match["year"]
    planet_doy = match["planet_doy"]
    sentinel_2_images, planet_scope_images, shapefile = year_inputs[year]
    sentinel_image = sentinel_2_images[match["sentinel_matching_doy"]]

    predictions = rf_model.predict(features_val)

    # Paint the predictions into a copy of the real image; pixels flagged in
    # the invalid mask keep their original values.
    synth_image = copy.deepcopy(sentinel_image)
    synth_image.ndvi[~mask_val] = predictions
    # Optional export:
    # save_tif(synth_image.ndvi, output_dir / f"{year}-{planet_doy}-random-forest.tif",
    #          synth_image.crs, synth_image.transform)

    # Baseline: PlanetScope resampled to the Sentinel-2 grid, masked the same way.
    resampled_image = processor.resample_planet_to_sentinel(
        sentinel_image, planet_scope_images[planet_doy]
    )
    resampled_image.ndvi = processor.mask_array_with_shapefile(
        resampled_image.ndvi, resampled_image.transform, shapefile
    )

    # Full scene first, then the zoomed window.
    visualizer.visualize_ndvi_images(
        sentinel_image.ndvi,
        synth_image.ndvi,
        resampled_image.ndvi,
        planet_doy,
        year,
    )
    visualizer.visualize_ndvi_images(
        sentinel_image.ndvi[crop_window],
        synth_image.ndvi[crop_window],
        resampled_image.ndvi[crop_window],
        planet_doy,
        year,
    )

    mse_synth = mean_squared_error(targets_val, predictions)
    print(f"Mean Squared Error for Synthetic Planetscope DOY {planet_doy}: {mse_synth}")

    mse_resampled = mean_squared_error(
        targets_val,
        processor.get_preprocessed_ndiv_data(resampled_image, mask_val),
    )
    print(
        f"Mean Squared Error for Resampled Planetscope DOY {planet_doy}: {mse_resampled}"
    )

    mse_dict_synth[planet_doy] = mse_synth
    mse_dict_resampled[planet_doy] = mse_resampled
    # NOTE(review): the synthetic mean covers only the valid pixels, while the
    # resampled/real means cover every non-NaN pixel of the scene.
    mean_ndiv_dict_synth[planet_doy] = np.mean(predictions)
    mean_ndiv_dict_resampled[planet_doy] = np.nanmean(resampled_image.ndvi)
    mean_ndiv_dict_real[planet_doy] = np.nanmean(sentinel_image.ndvi)


visualizer.plot_time_series(
    mse_dict_synth,
    mse_dict_resampled,
    mean_ndiv_dict_synth,
    mean_ndiv_dict_resampled,
    mean_ndiv_dict_real,
)

In [None]:
total_resampled = sum(mse_dict_resampled.values())
total_synth = sum(mse_dict_synth.values())
print(f"Total Error Resampled: {total_resampled}")
print(f"Total Error Synthetic: {total_synth}")

# Everything below is meant to run after the models are trained: it loads all
# images and synthesizes/merges them.


In [None]:
def find_closest_sentinel_images(sentinel_images, planet_images, year):
    """Find the Sentinel-2 acquisitions bracketing each PlanetScope date.

    For every PlanetScope DOY, the nearest Sentinel-2 DOY strictly before it
    and the nearest strictly after it are looked up. A PlanetScope date
    outside the Sentinel-2 time span (no earlier or no later acquisition) is
    skipped, because downstream code needs both neighbours. Previously this
    case crashed with a ValueError from ``max()``/``min()``.

    Args:
        sentinel_images: Dictionary with DOY as key and Sentinel-2 NDVI data as value.
        planet_images: Dictionary with DOY as key and PlanetScope NDVI data as value.
        year: Year stored in every result entry.

    Returns:
        A list of dictionaries, one per usable PlanetScope date, in
        ``planet_images`` order:
            {'year': ...,
             'planet_doy': ...,
             'sentinel_before_doy': ...,
             'sentinel_after_doy': ...}
    """
    set_of_input_images = []

    for planet_doy in planet_images:
        earlier = [doy for doy in sentinel_images if doy < planet_doy]
        later = [doy for doy in sentinel_images if doy > planet_doy]
        if not earlier or not later:
            continue  # outside the Sentinel-2 time span

        set_of_input_images.append(
            {
                "year": year,
                "planet_doy": planet_doy,
                "sentinel_before_doy": max(earlier),
                "sentinel_after_doy": min(later),
            }
        )

    return set_of_input_images


# Inputs for synthesis: every PlanetScope image plus the closest Sentinel-2
# acquisitions before and after it. The earlier one later serves as the
# resampling grid.
set_of_input_images_2021 = find_closest_sentinel_images(
    sentinel_2_images_2021, planet_scope_images_2021, 2021
)
set_of_input_images_2022 = find_closest_sentinel_images(
    sentinel_2_images_2022, planet_scope_images_2022, 2022
)


In [None]:
print(str(set_of_input_images_2021))  # inspect the 2021 synthesis inputs

In [None]:
def create_input_data(
    match,
    sentinel_images: Dict[int, NDVIImage],
    planet_images: Dict[int, NDVIImage],
    shapefile_path: Path,
):
    """Build the model input matrix for one PlanetScope date (no target).

    This is the inference counterpart of ``create_feature_target_pair``:
    1. The PlanetScope image is resampled onto the grid of the preceding
       Sentinel-2 image (``sentinel_before_doy``).
    2. The images are masked with ``shapefile_path``.
    3. One feature row is produced per valid pixel.

    Note:
        The before/after Sentinel-2 ``NDVIImage`` objects in
        ``sentinel_images`` get their ``.ndvi`` replaced by the masked array,
        so the caller's dictionary is modified in place.

    Args:
        match: One entry from ``find_closest_sentinel_images`` (keys ``year``,
            ``planet_doy``, ``sentinel_before_doy``, ``sentinel_after_doy``).
        sentinel_images: Dictionary with DOY as key and Sentinel-2 NDVIImage.
        planet_images: Dictionary with DOY as key and PlanetScope NDVIImage.
        shapefile_path: Shapefile passed to ``mask_array_with_shapefile``.

    Returns:
        A tuple ``(features, invalid_mask)``:
            - features: 2D NumPy array, one row per valid pixel, with the same
              column layout as ``create_feature_target_pair``.
            - invalid_mask: the mask used for pixel selection (True pixels are
              excluded; callers write predictions to ``~invalid_mask``).
    """
    processor = NDVIProcessor()

    year = match["year"]
    planet_doy = match["planet_doy"]
    sentinel_before_doy = match["sentinel_before_doy"]
    sentinel_after_doy = match["sentinel_after_doy"]

    sentinel_before_image = sentinel_images[sentinel_before_doy]
    sentinel_after_image = sentinel_images[sentinel_after_doy]

    # Resample PlanetScope onto the grid of the preceding Sentinel-2 image.
    planet_image = processor.resample_planet_to_sentinel(
        sentinel_before_image, planet_images[planet_doy]
    )

    # Mask every image with the field shapefile (replaces .ndvi in place).
    for image in (planet_image, sentinel_before_image, sentinel_after_image):
        image.ndvi = processor.mask_array_with_shapefile(
            image.ndvi, image.transform, shapefile_path
        )

    invalid_mask = processor.get_invalid_mask(
        sentinel_ndvi_data=sentinel_before_image.ndvi,
        planet_ndvi_data=planet_image.ndvi,
    )
    if INCLUDE_BEFORE_AFTER:
        # The before/after images are feature columns; exclude their invalid
        # pixels too. Previously this union was computed and then discarded.
        invalid_mask = np.logical_or(
            invalid_mask,
            processor.get_invalid_mask(
                sentinel_before_image.ndvi, sentinel_after_image.ndvi
            ),
        )

    if INCLUDE_BEFORE_AFTER:
        sentinel_before_data = processor.get_preprocessed_ndiv_data(
            sentinel_before_image, invalid_mask
        ).flatten()
        sentinel_after_data = processor.get_preprocessed_ndiv_data(
            sentinel_after_image, invalid_mask
        ).flatten()
    planet_data = processor.get_preprocessed_ndiv_data(
        planet_image, invalid_mask
    ).flatten()

    # All columns are selected with the same mask, so they share one length.
    num_valid_pixels = len(planet_data)

    # One row per valid pixel; date columns are constant for this image.
    if INCLUDE_BEFORE_AFTER:
        columns = (
            planet_data,
            sentinel_before_data,
            sentinel_after_data,
            np.full(num_valid_pixels, planet_doy),
            np.full(num_valid_pixels, sentinel_before_doy),
            np.full(num_valid_pixels, sentinel_after_doy),
            np.full(num_valid_pixels, year),
        )
    else:
        columns = (
            planet_data,
            np.full(num_valid_pixels, planet_doy),
            np.full(num_valid_pixels, year),
        )
    features = np.column_stack(columns)
    return features, invalid_mask

In [None]:
# Build model inputs for every usable PlanetScope date of both years.
input_images = set_of_input_images_2021 + set_of_input_images_2022

year_inputs = {
    2021: (sentinel_2_images_2021, planet_scope_images_2021, shapefile_2021),
    2022: (sentinel_2_images_2022, planet_scope_images_2022, shapefile_2022),
}

all_inputs = []
all_masks = []
for match in input_images:
    # Dedicated names so the training matrix X from above is not overwritten
    # (re-running a training cell afterwards would otherwise fit on one image).
    features_in, mask_in = create_input_data(match, *year_inputs[match["year"]])
    all_inputs.append(features_in)
    all_masks.append(mask_in)

In [None]:
import copy

# For every PlanetScope date without a real Sentinel-2 acquisition, predict a
# Sentinel-2-like NDVI image with the random-forest model and write it to
# output_dir. (rasterio cannot create missing parent directories.)
output_dir.mkdir(parents=True, exist_ok=True)

for match, model_inputs, mask in zip(input_images, all_inputs, all_masks):
    year = match["year"]
    planet_doy = match["planet_doy"]
    sentinel_before_doy = match["sentinel_before_doy"]
    sentinel_2_images = (
        sentinel_2_images_2022 if year == 2022 else sentinel_2_images_2021
    )

    # A real Sentinel-2 image exists for this date: nothing to synthesize.
    if planet_doy in sentinel_2_images:
        continue

    predictions = rf_model.predict(model_inputs)

    # create_input_data resampled PlanetScope onto the preceding Sentinel-2
    # image, so that image is the template (grid, crs, transform) for output.
    # NOTE(review): pixels flagged in the invalid mask keep that earlier
    # Sentinel-2 image's NDVI instead of NaN; confirm this gap-filling is intended.
    synth_image = copy.deepcopy(sentinel_2_images[sentinel_before_doy])
    synth_image.ndvi[~mask] = predictions

    save_tif(
        synth_image.ndvi,
        output_dir / f"Synth-rf_ndvi_doy_{year}_{planet_doy:03d}_.tif",
        synth_image.crs,
        synth_image.transform,
    )