Utilities

In [None]:
# %pip install xarray[complete] netcdf4 h5netcdf
# %pip install matplotlib
# %pip install numpy
# %pip install pandas
# %pip install scipy
# %pip install dask
# %pip install tensorflow --user
# %pip install scikit-learn
# %pip install pyyaml h5py

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, Conv2DTranspose, Concatenate, concatenate, AveragePooling2D, UpSampling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

Dataset loading and train/validation/test splitting

In [None]:
# LOAD DATASETS
# All inputs live under one root directory; change DATA_ROOT to relocate the data.
DATA_ROOT = 'dati/mist'
DATASETS_DIR = f'{DATA_ROOT}/datasets'

# Combined (day + night) arrays, currently disabled
# dataset = np.load(f'{DATASETS_DIR}/dataset.npy')
# date = np.load(f'{DATASETS_DIR}/date.npy')

# Day ("_d") and night ("_n") images and their dates (NaN pixels = land/clouds)
dataset_d = np.load(f'{DATASETS_DIR}/dataset_d.npy')
dataset_n = np.load(f'{DATASETS_DIR}/dataset_n.npy')
date_d = np.load(f'{DATASETS_DIR}/date_d.npy')
date_n = np.load(f'{DATASETS_DIR}/date_n.npy')

# Per-day-of-year baseline fields, indexed by day_of_year - 1 in the generators
# baseline = np.load(f'{DATASETS_DIR}/baseline.npy')
baseline_d = np.load(f'{DATASETS_DIR}/baseline_d.npy')
baseline_n = np.load(f'{DATASETS_DIR}/baseline_n.npy')

# Land-sea mask and the min/max used to undo the [-1, 1] normalisation later on
italy_mask = np.load(f'{DATASETS_DIR}/italy_mask.npy')
data_min = np.load(f'{DATA_ROOT}/data_min.npy')
data_max = np.load(f'{DATA_ROOT}/data_max.npy')

In [None]:
# Split the day and the night dataset into training, validation and testing sets.
# 25% of the samples are held out; 40% of those go to test and the rest to validation,
# i.e. roughly 75% train / 15% validation / 10% test.

def split_indices(n_samples, holdout=0.25, test_fraction=0.4, seed=42):
    """Return (train, val, test) index arrays for a dataset of ``n_samples`` items.

    ``holdout`` of the samples are removed from training; ``test_fraction`` of those
    held-out samples form the test set and the remainder the validation set. The same
    ``seed`` is used for both splits.
    """
    train_idx, temp_idx = train_test_split(np.arange(n_samples), test_size=holdout, random_state=seed)
    val_idx, test_idx = train_test_split(temp_idx, test_size=test_fraction, random_state=seed)
    return train_idx, val_idx, test_idx


# Day data
train_indices_d, val_indices_d, test_indices_d = split_indices(dataset_d.shape[0])

x_train_d = dataset_d[train_indices_d]
x_val_d = dataset_d[val_indices_d]
x_test_d = dataset_d[test_indices_d]

dates_train_d = date_d[train_indices_d]
dates_val_d = date_d[val_indices_d]
dates_test_d = date_d[test_indices_d]

# Night data
train_indices_n, val_indices_n, test_indices_n = split_indices(dataset_n.shape[0])

x_train_n = dataset_n[train_indices_n]
x_val_n = dataset_n[val_indices_n]
x_test_n = dataset_n[test_indices_n]

dates_train_n = date_n[train_indices_n]
dates_val_n = date_n[val_indices_n]
dates_test_n = date_n[test_indices_n]

Metrics, Loss and Hyperparameters

In [None]:
def customLoss(y_true, y_pred):
    """Squared SST error restricted to clear-sea pixels (training loss).

    y_true channels: 0 = true SST (obscured pixels already set to 0),
                     1 = clear-sea mask (1 clear sea, 0 land/clouds).
    y_pred: predicted SST, one channel.

    The masked errors are averaged over *all* pixels (masked ones contribute 0),
    so the value scales with the fraction of clear sea in the batch.
    """
    target = tf.cast(y_true, tf.float32)
    sst_true = target[:, :, :, 0:1]
    clear_mask = target[:, :, :, 1:2]

    masked_squared_error = tf.square(sst_true - y_pred) * clear_mask
    return tf.reduce_mean(masked_squared_error)

In [None]:
def ClearMetric(y_true, y_pred):
    """Mean squared SST error over clear-sea pixels only.

    y_true channels: 0 = true SST (obscured pixels already set to 0),
                     1 = clear-sea mask (1 clear sea, 0 land/clouds).
    y_pred: predicted SST, one channel.

    Unlike customLoss, the sum of errors is divided by the number of clear pixels.
    A batch with no clear-sea pixel yields 0 instead of NaN, because one NaN batch
    would make the batch-averaged Keras metric NaN for the rest of the epoch.
    """
    y_true = tf.cast(y_true, tf.float32)   # metrics may receive float64 targets

    clear_mask = y_true[:, :, :, 1:2]
    sst_true = y_true[:, :, :, 0:1]

    clear_masked_error = tf.square(sst_true - y_pred) * clear_mask
    return tf.math.divide_no_nan(tf.reduce_sum(clear_masked_error), tf.reduce_sum(clear_mask))

In [None]:
def ArtificialMetric(y_true, y_pred):
    """Mean squared SST error over the artificially clouded pixels only.

    y_true channels: 0 = true SST (obscured pixels already set to 0),
                     2 = artificial-cloud mask (1 where the generator hid a clear pixel).
    y_pred: predicted SST, one channel.

    A batch in which the generator hid no pixels yields 0 instead of NaN, so a single
    such batch does not turn the epoch-averaged metric into NaN.
    """
    y_true = tf.cast(y_true, tf.float32)    # y_true in the metrics can be float64 instead of float32

    artificial_mask = y_true[:, :, :, 2:3]
    sst_true = y_true[:, :, :, 0:1]

    artificial_masked_error = tf.square(sst_true - y_pred) * artificial_mask
    return tf.math.divide_no_nan(tf.reduce_sum(artificial_masked_error), tf.reduce_sum(artificial_mask))

In [None]:
# Hyperparameters

epochs = 100
batch_size = 32

lr = 1e-4

loss = customLoss
metrics = [ClearMetric, ArtificialMetric]
# Stop training once val_loss has not improved for 10 consecutive epochs
early_stop = EarlyStopping(monitor='val_loss', patience=10, verbose=1)

# At most 100 steps per epoch, sized from the day training set only.
# max(1, ...) avoids steps_per_epoch=0 when that set is smaller than one batch.
steps_per_epoch = max(1, min(100, len(x_train_d) // batch_size))
validation_steps = 20
testing_steps = 20

# (H, W, C): the 4 input channels are masked SST, clear mask, land-sea mask, baseline
input_shape = (256, 256, 4)

Data generators

In [None]:
# Baseline generator function

def baseline_generator(batch_size, data_d, data_n, dates_d, dates_n, dayChance=0.5,
                       baseline_day=None, baseline_night=None, land_mask=None):
    """Yield endless (batch_x, batch_y) batches with artificially added clouds.

    Each batch is drawn entirely from the day set (probability ``dayChance``) or the
    night set. For every sample a random "current" image is picked. Its cloud mask is
    OR-ed with the masks of three other random images, and the combined mask with the
    median coverage hides pixels from the model input.

    Parameters
    ----------
    batch_size : int
    data_d, data_n : arrays where data[i] is an (H, W) image, NaN on land/clouds.
    dates_d, dates_n : day counts since 1970-01-01, aligned with data_d / data_n.
    dayChance : probability that a batch comes from the day set.
    baseline_day, baseline_night : per-day-of-year fields indexed by day_of_year - 1.
        Default to the notebook globals ``baseline_d`` / ``baseline_n``.
    land_mask : (H, W) land-sea mask. Defaults to the notebook global ``italy_mask``.

    Yields
    ------
    batch_x : (batch_size, H, W, 4) -> [masked image, clear mask, land-sea mask, baseline]
    batch_y : (batch_size, H, W, 3) -> [real image, clear mask, artificial-cloud mask]
    """
    # Fall back to the notebook-level arrays so existing calls keep working
    if baseline_day is None:
        baseline_day = baseline_d
    if baseline_night is None:
        baseline_night = baseline_n
    if land_mask is None:
        land_mask = italy_mask

    height, width = data_d.shape[1:3]

    while True:
        batch_x = np.zeros((batch_size, height, width, 4))
        batch_y = np.zeros((batch_size, height, width, 3))

        # Randomly choose between day and night dataset (one choice per batch)
        if np.random.rand() < dayChance:
            dataset, date, baseline = data_d, dates_d, baseline_day
        else:
            dataset, date, baseline = data_n, dates_n, baseline_night

        for b in range(batch_size):
            # A random current day plus 3 random days whose cloud masks are borrowed
            i, r1, r2, r3 = np.random.randint(0, dataset.shape[0], 4)

            image_current = np.nan_to_num(dataset[i], nan=0)
            mask_current = np.isnan(dataset[i])

            # OR each borrowed mask with the real one; keep the median-coverage result
            masks = [np.logical_or(mask_current, np.isnan(dataset[r])) for r in (r1, r2, r3)]
            masks.sort(key=np.sum)
            artificial_mask = masks[1]

            # Hide the masked pixels from the input image
            image_masked = np.where(artificial_mask, 0, image_current)

            date_series = pd.to_datetime(date[i], unit='D', origin='unix')
            day_of_year = date_series.dayofyear
            # Clamp so a 365-row baseline does not overflow on day 366 of a leap year
            baseline_index = min(day_of_year, len(baseline)) - 1

            # Masks as used by the loss/metrics
            artificial_mask = np.logical_xor(artificial_mask, mask_current)  # 1 = artificially obscured
            mask_current = np.logical_not(mask_current)                       # 1 = clear sea

            batch_x[b, ..., 0] = image_masked               # artificially cloudy image
            batch_x[b, ..., 1] = mask_current               # real mask
            batch_x[b, ..., 2] = land_mask                  # land-sea mask
            batch_x[b, ..., 3] = baseline[baseline_index]   # baseline for the current day

            batch_y[b, ..., 0] = image_current              # real image
            batch_y[b, ..., 1] = mask_current               # real mask
            batch_y[b, ..., 2] = artificial_mask            # artificial mask used for the input

        yield batch_x, batch_y

In [None]:
#Generator that returns dates
def gen_qual_dates(batch_size, data_d, data_n, qual_d, qual_n, dates_d, dates_n, dayChance=0.5,
                   baseline_day=None, baseline_night=None, land_mask=None):
    """Yield endless batches with artificial clouds plus quality values, dates and day/night labels.

    Each batch is drawn entirely from the day set (probability ``dayChance``) or the
    night set. For every sample a random "current" image is picked. Its cloud mask is
    OR-ed with the masks of three other random images, and the combined mask with the
    median coverage hides pixels from the model input.

    Parameters
    ----------
    batch_size : int
    data_d, data_n : arrays where data[i] is an (H, W) image, NaN on land/clouds.
    qual_d, qual_n : quality arrays aligned with data_d / data_n (NaN = missing).
    dates_d, dates_n : day counts since 1970-01-01, aligned with data_d / data_n.
    dayChance : probability that a batch comes from the day set.
    baseline_day, baseline_night : per-day-of-year fields indexed by day_of_year - 1.
        Default to the notebook globals ``baseline_d`` / ``baseline_n``.
    land_mask : (H, W) land-sea mask. Defaults to the notebook global ``italy_mask``.

    Yields
    ------
    batch_x : (batch_size, H, W, 4) -> [masked image, clear mask, land-sea mask, baseline]
    batch_y : (batch_size, H, W, 4) -> [real image, clear mask, artificial-cloud mask, quality]
        (quality NaNs are replaced by -1)
    batch_dates : list of datetime.date, one per sample
    batch_day_night : list of 'diurnal' / 'nocturnal', one per sample
    """
    # Fall back to the notebook-level arrays so existing calls keep working
    if baseline_day is None:
        baseline_day = baseline_d
    if baseline_night is None:
        baseline_night = baseline_n
    if land_mask is None:
        land_mask = italy_mask

    height, width = data_d.shape[1:3]

    while True:
        batch_x = np.zeros((batch_size, height, width, 4))
        batch_y = np.zeros((batch_size, height, width, 4))
        batch_dates = []
        batch_day_night = []

        # Randomly choose between day and night dataset (one choice per batch)
        is_day = np.random.rand() < dayChance
        if is_day:
            dataset, qdataset, date, baseline = data_d, qual_d, dates_d, baseline_day
        else:
            dataset, qdataset, date, baseline = data_n, qual_n, dates_n, baseline_night

        for b in range(batch_size):
            # A random current day plus 3 random days whose cloud masks are borrowed
            i, r1, r2, r3 = np.random.randint(0, dataset.shape[0], 4)

            image_current = np.nan_to_num(dataset[i], nan=0)
            mask_current = np.isnan(dataset[i])

            # OR each borrowed mask with the real one; keep the median-coverage result
            masks = [np.logical_or(mask_current, np.isnan(dataset[r])) for r in (r1, r2, r3)]
            masks.sort(key=np.sum)
            artificial_mask = masks[1]

            # Hide the masked pixels from the input image
            image_masked = np.where(artificial_mask, 0, image_current)

            date_series = pd.to_datetime(date[i], unit='D', origin='unix')
            day_of_year = date_series.dayofyear
            # Clamp so a 365-row baseline does not overflow on day 366 of a leap year
            baseline_index = min(day_of_year, len(baseline)) - 1

            # Masks as used by the loss/metrics
            artificial_mask = np.logical_xor(artificial_mask, mask_current)  # 1 = artificially obscured
            mask_current = np.logical_not(mask_current)                       # 1 = clear sea

            q_image = np.nan_to_num(qdataset[i], nan=-1)    # -1 for missing quality measurements

            batch_x[b, ..., 0] = image_masked               # artificially cloudy image
            batch_x[b, ..., 1] = mask_current               # real mask
            batch_x[b, ..., 2] = land_mask                  # land-sea mask
            batch_x[b, ..., 3] = baseline[baseline_index]   # baseline for the current day

            batch_y[b, ..., 0] = image_current              # real image
            batch_y[b, ..., 1] = mask_current               # real mask
            batch_y[b, ..., 2] = artificial_mask            # artificial mask used for the input
            batch_y[b, ..., 3] = q_image                    # quality measurement

            batch_dates.append(date_series.date())
            batch_day_night.append('diurnal' if is_day else 'nocturnal')

        yield batch_x, batch_y, batch_dates, batch_day_night

In [None]:
# Create the generators: one per split, each mixing day and night batches.
train_gen, val_gen, test_gen = (
    baseline_generator(batch_size, day_images, night_images, day_dates, night_dates)
    for day_images, night_images, day_dates, night_dates in (
        (x_train_d, x_train_n, dates_train_d, dates_train_n),
        (x_val_d, x_val_n, dates_val_d, dates_val_n),
        (x_test_d, x_test_n, dates_test_d, dates_test_n),
    )
)

# Test generator that returns dates.
# NOTE(review): the "Experiment on Errors" cells call next(test_gen_dates), so they fail while
# this line is disabled; q_test_d / q_test_n are not defined in the visible cells — TODO load them.
#test_gen_dates = gen_qual_dates(batch_size, x_test_d, x_test_n, q_test_d, q_test_n, dates_test_d, dates_test_n)

In [None]:
# Test the generator: draw one training batch and display a random sample.
x, y = next(train_gen)
r = np.random.randint(0, batch_size)    # index of the sample to display

plt.figure(figsize=(8, 8))
panels = (
    (x[r, :, :, 0], "x_0 (model input)"),     # artificially clouded input
    (y[r, :, :, 0], "y_0 (ground truth)"),    # real image
)
for slot, (panel_image, panel_title) in enumerate(panels, start=1):
    plt.subplot(2, 2, slot)
    plt.imshow(panel_image, cmap='jet')
    plt.title(panel_title)
    plt.colorbar()
plt.show()

# Information about the data (channel 0 holds the image values)
print("x.shape:", x.shape)
print("y.shape:", y.shape)
input_values = x[..., 0]
print("min of all x:", input_values.min())
print("max of all x:", input_values.max())
print("min of this x:", input_values[r].min())
print("max of this x:", input_values[r].max())

Model and Training

In [None]:
# U-Net model with residual blocks

def ResidualBlock(depth):
    """Return a callable that applies a residual block with ``depth`` filters to a 4-D tensor.

    Layout: BatchNorm (no learned center/scale) -> 3x3 conv (swish) -> 3x3 conv,
    summed with a shortcut. The shortcut is the input itself when it already has
    ``depth`` channels; otherwise it is a 1x1 projection.
    NOTE(review): layers are created in the same order as before (projection first),
    which keeps auto-generated layer names aligned with saved weights — keep it so.
    """
    def apply(tensor):
        shortcut = tensor if tensor.shape[3] == depth else Conv2D(depth, kernel_size=1)(tensor)

        out = BatchNormalization(center=False, scale=False)(tensor)
        out = Conv2D(depth, kernel_size=3, padding="same", activation='swish')(out)
        out = Conv2D(depth, kernel_size=3, padding="same")(out)
        return Add()([out, shortcut])

    return apply


def DownBlock(depth, block_depth):
    """Return a callable taking ``[tensor, skips]`` for one encoder level.

    It applies ``block_depth`` residual blocks and appends each output to ``skips``
    (mutated in place, for the decoder). It then halves the spatial resolution with
    2x2 average pooling and returns the pooled tensor.
    """
    def apply(inputs):
        features, skip_stack = inputs
        for _ in range(block_depth):
            features = ResidualBlock(depth)(features)
            skip_stack.append(features)
        return AveragePooling2D(pool_size=2)(features)    # downsampling

    return apply


def UpBlock(depth, block_depth):
    """Return a callable taking ``[tensor, skips]`` for one decoder level.

    It doubles the spatial resolution with bilinear upsampling. Then, ``block_depth``
    times, it pops the most recent skip connection from ``skips`` (mutated in place),
    concatenates it on the channel axis and applies a residual block.
    """
    def apply(inputs):
        features, skip_stack = inputs
        features = UpSampling2D(size=2, interpolation="bilinear")(features)    # upsampling
        for _ in range(block_depth):
            merged = Concatenate()([features, skip_stack.pop()])
            features = ResidualBlock(depth)(merged)
        return features

    return apply


def get_Unet(image_size, depths, block_depth):
    """Build the residual U-Net inpainting model.

    image_size: input shape (H, W, C).
    depths: filter counts per level. Every entry but the last gets a Down/Up block
        pair; the last is used for the bottleneck.
    block_depth: residual blocks per level.
    H and W must be divisible by 2 ** (len(depths) - 1) so the skip connections line up.
    Output: one channel from a zero-initialised 1x1 conv named "output_noise".
    """
    inputs = Input(shape=image_size)
    features = Conv2D(depths[0], kernel_size=1)(inputs)    # project the input channels to depths[0]

    skip_stack = []    # filled by the encoder, consumed (LIFO) by the decoder
    encoder_depths = depths[:-1]

    for level_depth in encoder_depths:
        features = DownBlock(level_depth, block_depth)([features, skip_stack])

    for _ in range(block_depth):    # bottleneck
        features = ResidualBlock(depths[-1])(features)

    for level_depth in encoder_depths[::-1]:
        features = UpBlock(level_depth, block_depth)([features, skip_stack])

    outputs = Conv2D(1, kernel_size=1, kernel_initializer="zeros", name="output_noise")(features)
    return Model(inputs, outputs=outputs, name="UNetInpainter")

In [None]:
# Build the model: filter counts per level (the last one is the bottleneck) and
# the number of residual blocks per level. Call model.summary() to inspect it.
depths, block_depth = [32, 64, 128], 2
model = get_Unet(image_size=input_shape, depths=depths, block_depth=block_depth)

In [None]:
# Compile the model with the custom masked loss and the clear/artificial metrics
opt = Adam(learning_rate=lr)
model.compile(loss=loss, metrics=metrics, optimizer=opt)

In [None]:
# LOAD WEIGHTS
# Presumably the HDF5 file written by the save_weights call below; the
# 'weights/baseline.weights.h5' variant does not work on the local machine.
model.load_weights(filepath='weights/baseline.h5')

# SAVE WEIGHTS (execute remotely after training)
#model.save_weights('weights/baseline.h5')

In [None]:
# Train model
#history = model.fit(train_gen, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=val_gen, validation_steps=validation_steps, verbose=1, callbacks=[early_stop])

Error analysis experiments

In [None]:
# Error statistics over `tot` test batches, in original units, on clear-sea pixels only.
# Per-image averages/maxima are pooled across all batches, so the variance below is the
# true variance of the per-image maxima. The previous mean of per-batch variances
# ignored the variation between batches.

tot = 100
per_image_avg_errors = []
per_image_max_errors = []

for _ in range(tot):
    x_true, y_true = next(test_gen)
    predictions = model.predict(x_true)

    # Undo the [-1, 1] normalisation
    predictions_denorm = ((predictions[..., 0] + 1) / 2) * (data_max - data_min) + data_min
    true_values_denorm = ((y_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min

    # Absolute errors on clear sea, NaN elsewhere
    clearMask = y_true[..., 1]
    errors = np.where(clearMask, np.abs(predictions_denorm - true_values_denorm), np.nan)

    # Images without any clear-sea pixel have no measurable error; leave them out
    # instead of letting their NaN poison every aggregate.
    has_clear = clearMask.any(axis=(1, 2))
    per_image_avg_errors.append(np.nanmean(errors[has_clear], axis=(1, 2)))
    per_image_max_errors.append(np.nanmax(errors[has_clear], axis=(1, 2)))

all_avg_errors = np.concatenate(per_image_avg_errors)
all_max_errors = np.concatenate(per_image_max_errors)

print(f"Average error over {tot} batches:", np.mean(all_avg_errors))
print(f"Average maximum error over {tot} batches:", np.mean(all_max_errors))
print(f"Variance of maximum errors over {tot} batches:", np.var(all_max_errors))

In [None]:
# Calculate errors in original units on one test batch and inspect the worst pixel of each image.
# NOTE(review): needs `test_gen_dates` (disabled in the generator-creation cell), which also
# provides dates, day/night labels and the quality channel y_true[..., 3].

MAX_ERROR_THRESHOLD = 5   # images whose max error is >= this are printed and plotted

# Prediction
x_true, y_true, t_dates, t_dn = next(test_gen_dates)
predictions = model.predict(x_true)
# Denormalization (undo the [-1, 1] scaling)
predictions_denorm = ((predictions[..., 0] + 1) / 2) * (data_max - data_min) + data_min
true_values_denorm = ((y_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min
input_images_denorm = ((x_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min

clearMask = y_true[..., 1]       # 1 for all clear sea, 0 for land/clouds
artificialMask = y_true[..., 2]  # 1 for artificially obscured sea pixels, 0 for the rest
NonArtificialMask = np.logical_xor(clearMask, artificialMask)   # 1 for clear sea the model could see

# Errors (batch, H, W), NaN outside clear sea
errors = np.where(clearMask, np.abs(predictions_denorm - true_values_denorm), np.nan)
# Per-image average and maximum error (NaN for an image without clear sea)
avg_errors = np.nanmean(errors, axis=(1, 2))
max_errors = np.nanmax(errors, axis=(1, 2))

# nan-aware aggregates so one fully clouded image does not turn them into NaN
avg_error = np.nanmean(avg_errors)
print("Average error:", avg_error)
avg_max_error = np.nanmean(max_errors)
print("Average maximum error:", avg_max_error)
var_max_error = np.nanvar(max_errors)
print("Variance of maximum errors:", var_max_error)

print("===================================================")


def printValues(i, y, x):
    """Print error, prediction and ground truth at pixel (y, x) of sample i.

    Also prints 3x3 neighbourhoods of real values, predictions and quality values,
    with NaN where a neighbour is outside the image or not clear sea.
    """
    print("Sample n°", i+1, ", date:", t_dates[i], ",", t_dn[i], "; y and x:", y, x)
    print("Error in specific point:", errors[i, y, x])
    print("Predicted value in specific point:", predictions_denorm[i, y, x])
    print("Real value in specific point:", true_values_denorm[i, y, x])

    height, width = errors.shape[1:]
    real_values_matrix = np.full((3, 3), np.nan)
    predicted_values_matrix = np.full((3, 3), np.nan)
    quality_assertion_matrix = np.full((3, 3), np.nan)

    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            ny, nx = y + dy, x + dx
            if 0 <= ny < height and 0 <= nx < width and clearMask[i, ny, nx]:
                real_values_matrix[dy+1, dx+1] = true_values_denorm[i, ny, nx]
                predicted_values_matrix[dy+1, dx+1] = predictions_denorm[i, ny, nx]
                quality_assertion_matrix[dy+1, dx+1] = y_true[i, ny, nx, 3]

    print("3x3 matrix of real values:")
    print(real_values_matrix)
    print("3x3 matrix of predicted values:")
    print(predicted_values_matrix)
    print("3x3 matrix of quality assertion values:")
    print(quality_assertion_matrix)


# Position of the maximum error in each image. np.nanargmax raises on an all-NaN row
# (an image with no clear sea), so NaN is mapped to -inf before taking argmax.
flat_errors = errors.reshape(errors.shape[0], -1)
max_error_indices = np.argmax(np.where(np.isnan(flat_errors), -np.inf, flat_errors), axis=1)
max_error_positions = np.unravel_index(max_error_indices, errors.shape[1:])

# Plot each image whose maximum error reaches the threshold (NaN maxima never do)
for i, (y, x) in enumerate(zip(*max_error_positions)):
    if max_errors[i] >= MAX_ERROR_THRESHOLD:
        printValues(i, y, x)

        plt.figure(figsize=(12, 12))

        plt.subplot(2, 2, 1)
        plt.imshow(errors[i], cmap='jet')
        plt.title(f"Error - Batch {i+1}")

        plt.subplot(2, 2, 2)
        masked_prediction = np.where(italy_mask, predictions_denorm[i], np.nan)
        plt.imshow(masked_prediction, cmap='viridis')
        plt.scatter(x, y, color='red')
        plt.title(f"Prediction (masked) with max error marked - Batch {i+1}")

        plt.subplot(2, 2, 3)
        mask_overlay = np.where(clearMask[i], true_values_denorm[i], np.nan)
        plt.imshow(mask_overlay, cmap='viridis')
        plt.title(f"Real image - Batch {i+1}")

        plt.subplot(2, 2, 4)
        mask_overlay = np.where(NonArtificialMask[i], input_images_denorm[i], np.nan)
        plt.imshow(mask_overlay, cmap='viridis')
        plt.title(f"Input image - Batch {i+1}")

        plt.show()

In [None]:
# Check for problematic spots in real images: clear-sea pixels whose value differs by more
# than SPOT_THRESHOLD from at least one clear-sea neighbour in their 3x3 window.
# Fixes over the previous loop:
#   - it scanned only rows 0..batch_size-1 (shape[0] is the batch axis);
#   - it wrote NaN into true_values_denorm through slice views.
# The scan is now vectorised.
# NOTE(review): needs `test_gen_dates` (disabled in the generator-creation cell) for dates,
# day/night labels and the quality channel y_true[..., 3].

SPOT_THRESHOLD = 5


def neighbour_max_diff(values, mask):
    """Per pixel p: max |values[q] - values[p]| over clear pixels q in p's 3x3 window.

    values, mask: 2-D arrays of equal shape; mask != 0 marks clear pixels.
    Returns 0 at masked pixels (and for clear pixels whose neighbours are all masked).
    """
    clear = mask != 0
    clear_values = np.where(clear, values, np.nan)
    height, width = clear_values.shape
    padded = np.pad(clear_values, 1, mode='constant', constant_values=np.nan)
    result = np.zeros((height, width))
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            shifted = padded[1 + dy:1 + dy + height, 1 + dx:1 + dx + width]
            # fmax ignores NaN, so masked or out-of-image neighbours never win
            result = np.fmax(result, np.abs(shifted - clear_values))
    return np.where(clear, result, 0)


# Generate a batch
x_true, y_true, t_dates, t_dn = next(test_gen_dates)
predictions = model.predict(x_true)

# Denormalize
predictions_denorm = ((predictions[..., 0] + 1) / 2) * (data_max - data_min) + data_min
true_values_denorm = ((y_true[..., 0] + 1) / 2) * (data_max - data_min) + data_min

clearMask = y_true[..., 1]  # 1 for all clear sea, 0 for land/clouds

# Quality values on clear sea, NaN elsewhere
quality_values = np.where(clearMask, y_true[..., 3], np.nan)

# Absolute prediction errors on clear sea, NaN elsewhere
errors = np.where(clearMask, np.abs(predictions_denorm - true_values_denorm), np.nan)

n_images, height, width = true_values_denorm.shape
for i in range(n_images):
    max_diffs = neighbour_max_diff(true_values_denorm[i], clearMask[i])
    problematic_spots = []

    # argwhere returns (row, col) pairs in row-major order, matching a row-by-row scan
    for y, x in np.argwhere(max_diffs > SPOT_THRESHOLD):
        window = (slice(max(0, y - 1), min(y + 2, height)), slice(max(0, x - 1), min(x + 2, width)))
        # np.where builds a new array, so the denormalised data is left untouched
        neighbors_values = np.where(clearMask[i][window] != 0, true_values_denorm[i][window], np.nan)

        print("Sample n°", i+1, ", date:", t_dates[i], ",", t_dn[i])
        print(f"Problematic spot found in position ({y}, {x}), max difference: {max_diffs[y, x]}")
        print("Adjacent values:")
        print(neighbors_values)
        print("Quality assertion values:")
        print(quality_values[i][window])

        problematic_spots.append((x, y))

    # Plot the error map, plain and with the problematic spots marked
    if problematic_spots:
        fig, axs = plt.subplots(1, 2, figsize=(24, 12))
        axs[0].imshow(errors[i], cmap='jet')
        axs[0].set_title(f"Error map - Batch {i+1}")
        axs[1].imshow(errors[i], cmap='jet')
        axs[1].scatter(*zip(*problematic_spots), color='magenta')
        axs[1].set_title(f"Error map - Batch {i+1} (Problematic spots marked)")

        plt.show()

Model Evaluation

In [None]:
# Utility to show loss and metrics: evaluate the model on a single test batch
x_true, y_true = next(test_gen)
results = model.evaluate(x=x_true, y=y_true)

In [None]:
# Evaluate the model on one test batch: sanity-check value ranges, compare Keras' loss with
# customLoss, and plot ground truth vs prediction for the first few samples.

N_PLOTS = 10   # number of samples to plot (capped at the batch length)

# Generate predictions. This generates a batch of data.
x_true, y_true = next(test_gen)
print("x_true shape:", x_true.shape)
print("y_true shape:", y_true.shape)
print("is there a Nan in x?", np.isnan(x_true).any())
print("is there a Nan in y?", np.isnan(y_true).any())

predictions = model.predict(x_true)

print("--------------------")

print("y's min:", np.min(y_true[:, :, :, 0]))
print("y's max:", np.max(y_true[:, :, :, 0]))
print("prediction's min:", np.min(predictions[:, :, :, 0]))
print("prediction's max:", np.max(predictions[:, :, :, 0]))

print("--------------------")

evalx = model.evaluate(x_true, y_true)
print("evalx: ", evalx)
xloss = customLoss(y_true, predictions)
print("xloss: ", xloss)

print("--------------------")

# Location of the min and max value in the first prediction (printed as column, row)
first_prediction = predictions[0, :, :, 0]
row_min, col_min = np.unravel_index(np.argmin(first_prediction), first_prediction.shape)
row_max, col_max = np.unravel_index(np.argmax(first_prediction), first_prediction.shape)
print("first prediction's min (col, row, value):", col_min, row_min, np.nanmin(first_prediction))
print("first prediction's max (col, row, value):", col_max, row_max, np.nanmax(first_prediction))
print("prediction at min position:", first_prediction[row_min, col_min])
print("prediction at max position:", first_prediction[row_max, col_max])

print("--------------------")

# Plot the predictions and true values (normalised scale, hence vmin/vmax = -1/1)
for i in range(min(N_PLOTS, len(x_true))):
    plt.figure(figsize=(20, 8))

    # Ground truth on clear sea only
    plt.subplot(1, 3, 1)
    mask_overlay = np.where(y_true[i, :, :, 1], y_true[i, :, :, 0], np.nan)
    plt.imshow(mask_overlay, cmap='jet', vmin=-1, vmax=1)
    plt.title("y_0 (Ground Truth)")
    plt.colorbar()

    # Prediction where the land-sea mask is non-zero
    plt.subplot(1, 3, 2)
    masked_prediction = np.where(italy_mask, predictions[i, :, :, 0], np.nan)
    plt.imshow(masked_prediction, cmap='jet', vmin=-1, vmax=1)
    plt.title("Prediction with Land Mask")
    plt.colorbar()

    plt.show()

Baseline comparison

In [None]:
# Per-image MSE of the model vs. the baseline channel (x[..., 3]) on clear-sea pixels only.
# An image without clear sea gets NaN: sklearn's mean_squared_error raises on empty input.
# Such images are ignored by the averages and by the "worse than baseline" check.

batch_x, batch_y = next(test_gen)
predictions = model.predict(batch_x)
n_images = len(batch_x)

filter_mask = batch_y[..., 1].astype(bool)    # keep only clear sea (drop land and clouds)

mse_predictions = []
mse_baseline = []
for i in range(n_images):
    keep = filter_mask[i]
    if not keep.any():
        mse_predictions.append(np.nan)
        mse_baseline.append(np.nan)
        continue
    truth = batch_y[i, :, :, 0][keep]
    mse_predictions.append(mean_squared_error(truth, predictions[i, :, :, 0][keep]))
    mse_baseline.append(mean_squared_error(truth, batch_x[i, :, :, 3][keep]))
print('MSE for predictions:', mse_predictions)
print('MSE for baseline:', mse_baseline)

avg_mse_predictions = np.nanmean(mse_predictions)
avg_mse_baseline = np.nanmean(mse_baseline)
print('Average MSE for predictions:', avg_mse_predictions)
print('Average MSE for baseline:', avg_mse_baseline)

# Bar chart: one pair of bars per image plus a final pair for the averages
indices = np.arange(n_images + 1)

fig, ax = plt.subplots()
ax.bar(indices[:-1] - 0.2, mse_predictions, width=0.4, label='Predictions')
ax.bar(indices[:-1] + 0.2, mse_baseline, width=0.4, label='Baseline')
ax.bar(n_images - 0.2, avg_mse_predictions, width=0.4, color='blue', label='Avg of Predictions')
ax.bar(n_images + 0.2, avg_mse_baseline, width=0.4, color='red', label='Avg of Baseline')

ax.set_xlabel('Batch Index')
ax.set_ylabel('MSE')
ax.set_title('MSE for Predictions and Baseline')
ax.legend()

plt.show()
#fig.savefig("mseGraph.png")


# Show every image where the model does worse than the baseline (NaN comparisons are False)
for i in range(n_images):
    if mse_predictions[i] > mse_baseline[i]:
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        mask_overlay = np.where(batch_y[i, :, :, 1], batch_y[i, :, :, 0], np.nan)
        plt.imshow(mask_overlay, cmap='jet')
        plt.title(f"Ground Truth {i}")

        plt.subplot(1, 2, 2)
        italy_overlay = np.where(italy_mask, predictions[i, :, :, 0], np.nan)
        plt.imshow(italy_overlay, cmap='jet')
        plt.title(f"Prediction {i}")

        plt.show()