# Abstract

This notebook was created for OHW23's Sea Surface Temperature (SST) Prediction project, as one of the suite of deep learning models used. We will build a simple 3D-CNN model on historical SST data and compare its outputs to known SST measurements to evaluate the model's accuracy.

# Setup virtual environment and import modules

If you don't have admin access to your machine/environment (you run into something like `ERROR: Could not install packages due to an OSError: [Errno 13] Permission denied: '/env/lib/python3.10/site-packages/bin'` when running `pip install`) AND you are missing some of the modules imported two cells down (you run into `ImportError`), uncomment the next cell and run it once. Otherwise, you can just `pip install` the missing modules. If you have everything, nothing needs to be done :)

Explanation:
- First line `%%sh` runs the rest of the cell as a shell script
- Second line creates the virtual environment, named `myenv` here
- Third line writes the path of the existing (base) `site-packages` into a `.pth` file inside the new environment, so the new environment can import modules that are already installed
- Fourth line activates the virtual environment
- Fifth line runs `pip install` for any missing modules. Just `tensorflow` and `tensorboard` in my case
- Sixth line registers the virtual environment as a Jupyter kernel

Once the next cell finishes running, restart your server; your newly created kernel will be available to select in the top-right corner.

In [None]:
# %%sh
# python -m venv ~/venvs/myenv
# realpath /env/lib/python3.10/site-packages > ~/venvs/myenv/lib/python3.10/site-packages/base_venv.pth
# source ~/venvs/myenv/bin/activate
# pip install tensorflow tensorboard
# python -m ipykernel install --user --name=myenv --display-name "myenv"
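The third line relies on Python's `.pth` mechanism. When the interpreter processes a `site-packages` directory, it reads every `*.pth` file there and appends each line that names an existing directory to `sys.path`. The sketch below reproduces that behaviour with `site.addsitedir` on two temporary directories. These directories are hypothetical stand-ins for `/env/lib/python3.10/site-packages` and the venv's own `site-packages`.

```python
import os
import site
import sys
import tempfile

# Hypothetical stand-ins: "base" plays the shared /env site-packages,
# "venv_site" plays ~/venvs/myenv/lib/python3.10/site-packages
base = tempfile.mkdtemp(prefix="base_site_")
venv_site = tempfile.mkdtemp(prefix="venv_site_")

# Equivalent of: realpath <base> > <venv_site>/base_venv.pth
with open(os.path.join(venv_site, "base_venv.pth"), "w") as f:
    f.write(os.path.realpath(base) + "\n")

# addsitedir() processes the .pth files in venv_site the same way the
# interpreter does for its own site-packages at startup
site.addsitedir(venv_site)

# The base directory is now importable from the "venv"
print(os.path.realpath(base) in sys.path)
```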

In [1]:
import gc

import s3fs
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dask.array as da
import tensorflow as tf
import tensorboard

from tensorflow.keras import Input, Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization, Conv3D, Flatten, Dense, MaxPool3D, GlobalAveragePooling3D, Dropout, MaxPooling3D, Reshape
from tensorflow.keras.callbacks import EarlyStopping

from dask.distributed import Client, LocalCluster
from dask.delayed import delayed
from sklearn.model_selection import train_test_split

2023-08-11 07:26:29.214907: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-11 07:26:29.252784: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


I tend to ignore warnings like the main character in a horror movie, but they are sometimes helpful. The messages above are informational (note the `I` log level), not errors.

# Initialise Dask LocalCluster

In [2]:
cluster = LocalCluster(n_workers=4)
client = Client(cluster)

In [3]:
# Hao's dashboard link if in CSIRO EASI environment:
# https://hub.csiro.easi-eo.solutions/user/csiro-csiro-aad_tan196@csiro.au/proxy/8787/status
# Otherwise use provided dashboard link

client

| Property | Value |
|---|---|
| Cluster type | distributed.LocalCluster (connection method: Cluster object) |
| Dashboard | http://127.0.0.1:8787/status |
| Status | running (using processes) |
| Workers | 4 (8 threads and 31.00 GiB memory each) |
| Total threads | 32 |
| Total memory | 124.00 GiB |


# Load MUR Satellite Data

Read the data straight from the public S3 bucket: https://registry.opendata.aws/mur/

In [4]:
# Bypass AWS tokens, keys etc.
s3 = s3fs.S3FileSystem(anon=True)

# Verify that we're in the right place
sst_files = s3.ls("mur-sst/zarr-v1/")
sst_files

['mur-sst/zarr-v1/',
 'mur-sst/zarr-v1/.zattrs',
 'mur-sst/zarr-v1/.zgroup',
 'mur-sst/zarr-v1/.zmetadata',
 'mur-sst/zarr-v1/analysed_sst',
 'mur-sst/zarr-v1/analysis_error',
 'mur-sst/zarr-v1/lat',
 'mur-sst/zarr-v1/lon',
 'mur-sst/zarr-v1/mask',
 'mur-sst/zarr-v1/sea_ice_fraction',
 'mur-sst/zarr-v1/time']

In [16]:
%%time

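# Open the Zarr store. Data variables are loaded lazily as dask arrays,
# so only metadata and coordinates are read at this point
ds = xr.open_zarr(
    store=s3fs.S3Map(root=f"s3://{sst_files[0]}", s3=s3, check=False)
)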

CPU times: user 1.82 s, sys: 0 ns, total: 1.82 s
Wall time: 23.6 s


In [17]:
ds

Each of the four data variables is a dask array with the same layout:

| | Array | Chunk |
|---|---|---|
| Bytes | 15.19 TiB | 123.53 MiB |
| Shape | (6443, 17999, 36000) | (5, 1799, 3600) |
| Dask graph | 141790 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


# Mask sea ice and land

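We only want open water in our dataset. We keep pixels where `mask == 1` (open sea in MUR's mask convention) and the sea-ice fraction is below 15% or missing. Everything else is set to NaN.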

In [7]:
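# Keep open-sea pixels (mask == 1) that are ice-free (< 15% sea ice) or have no ice estimate
cond = (ds.mask == 1) & ((ds.sea_ice_fraction < 0.15) | ds.sea_ice_fraction.isnull())

# Note: this returns a DataArray named 'analysed_sst', not a Dataset
ds = ds['analysed_sst'].where(cond)
ds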

| | Array | Chunk |
|---|---|---|
| Bytes | 15.19 TiB | 123.53 MiB |
| Shape | (6443, 17999, 36000) | (5, 1799, 3600) |
| Dask graph | 141790 chunks in 12 graph layers | |
| Data type | float32 numpy.ndarray | |
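The masking condition can be checked on a toy grid. The NumPy sketch below applies the same boolean expression to small hand-made arrays. The values are invented, and treating mask code 2 as land is an assumption for illustration only. `np.where(cond, x, np.nan)` mimics what xarray's `.where(cond)` does: keep values where the condition is True and insert NaN elsewhere.

```python
import numpy as np

# Toy 2x3 grid (invented values). Mask value 1 = open sea; 2 is assumed here
# to stand for land. NaN sea-ice fraction = no ice estimate for that pixel.
analysed_sst = np.array([[300.0, 301.0, 302.0],
                         [280.0, 271.0, 290.0]])
mask = np.array([[1, 1, 2],
                 [1, 1, 1]])
sea_ice_fraction = np.array([[np.nan, 0.10, np.nan],
                             [0.50, 0.14, np.nan]])

# Same condition as the cell above: open sea AND (little ice OR no ice estimate)
cond = (mask == 1) & ((sea_ice_fraction < 0.15) | np.isnan(sea_ice_fraction))

# Equivalent of xarray's .where(cond): keep where True, NaN elsewhere
masked = np.where(cond, analysed_sst, np.nan)
print(masked)
```

The land pixel and the pixel with 50% ice become NaN. Pixels with no ice estimate survive.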


# Subsetting dataset so JupyterHub can handle it

In [18]:
# Check that data is daily only and no missing days
assert pd.infer_freq(ds.time.values) == 'D'

Good job NASA :)
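The assertion works because `pd.infer_freq` only reports a daily frequency when every step is exactly one day. This matters for the index arithmetic below, which assumes position along the time axis equals days since the first date. A small pandas sketch with a complete index and one with a missing day:

```python
import pandas as pd

# Six consecutive days (Mon 2020-01-06 to Sat 2020-01-11): regular spacing
complete = pd.date_range("2020-01-06", periods=6, freq="D")
print(pd.infer_freq(complete))

# Drop Wednesday: the spacing is no longer regular, so no frequency is inferred
gappy = complete.delete(2)
print(pd.infer_freq(gappy))
```

A gap like this would make `assert pd.infer_freq(...) == 'D'` fail instead of silently shifting every index by one day.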

In [19]:
def dates_to_mur_indices(start, end):
    """Return the range of MUR time indices covering start..end (inclusive).

    Dates can be np.datetime64 values or strings in the format YYYY-MM-DD.
    Relies on the dataset being daily with no missing days (checked above).
    """
    start_of_dataset = np.datetime64('2002-06-01')
    end_of_dataset = np.datetime64('2020-01-20')

    try:
        # Normalise both inputs (strings or np.datetime64) to day precision
        start = np.datetime64(start, 'D')
        end = np.datetime64(end, 'D')
    except (TypeError, ValueError) as e:
        raise TypeError('Please enter dates as np.datetime64 or strings in the format YYYY-MM-DD') from e

    assert start_of_dataset <= start < end_of_dataset, f'{start} out of dataset range {start_of_dataset} - {end_of_dataset}'
    assert start_of_dataset < end <= end_of_dataset, f'{end} out of dataset range {start_of_dataset} - {end_of_dataset}'
    assert start <= end, f'start date {start} after end date {end}'

    # Days since the first day of the dataset == position along the time axis
    start_index = (start - start_of_dataset).astype(int)
    end_index = (end - start_of_dataset).astype(int) + 1  # +1 so the end date is included

    return range(start_index, end_index)
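As a sanity check on the hard-coded dates, the index arithmetic can be worked through directly. The sketch below uses a hypothetical helper, `mur_index`, that mirrors the subtraction inside `dates_to_mur_indices`. It shows two things. First, 2020-01-20 lands on index 6442, the last position of the 6443-step time axis. Second, the 2009-01-01 to 2019-07-01 selection spans 3834 days, which matches the shape of `ds_sample` further down.

```python
import numpy as np

START_OF_DATASET = np.datetime64("2002-06-01")  # first MUR day, as hard-coded above


def mur_index(date):
    """Hypothetical helper: time index of `date` in a gap-free daily series."""
    return int((np.datetime64(date, "D") - START_OF_DATASET) / np.timedelta64(1, "D"))


first = mur_index("2009-01-01")   # 2406
last = mur_index("2019-07-01")    # 6239
print(first, last, last - first + 1)  # inclusive length: 3834
print(mur_index("2020-01-20"))        # 6442, i.e. the 6443rd (last) time step
```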

In [20]:
# Take ~10.5 years of data (2009-01-01 to 2019-07-01 inclusive, 3834 days)
ds_sample = ds.isel(time=dates_to_mur_indices('2009-01-01', '2019-07-01'))
# ds_sample = ds.isel(time=0)

In [21]:
ds_sample

| | Array | Chunk |
|---|---|---|
| Bytes | 9.04 TiB | 123.53 MiB |
| Shape | (3834, 17999, 36000) | (5, 1799, 3600) |
| Dask graph | 84370 chunks in 3 graph layers | |
| Data type | float32 numpy.ndarray | |


In [22]:
# Focus on our area of interest
ds_sample = ds_sample.sel(lat=slice(-5, 35), lon=slice(45,90))

In [13]:
# Downsample by factor of 10 - if your environment can't handle volume of data
ds_sample = ds_sample.isel(lat=slice(0, None, 10), lon=slice(0, None, 10))
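The final `(3834, 401, 451)` shape follows from two facts. First, `.sel(slice(a, b))` is label-based and includes both endpoints. Second, `.isel(slice(0, None, 10))` keeps every 10th point starting from the first. The sketch works this through on a 0.01° grid stored as integer hundredths of a degree, which avoids float rounding. The grid extents (latitude -89.99 to 89.99, longitude -179.99 to 180.00) are an assumption chosen to reproduce the 17999 x 36000 shape shown above.

```python
import numpy as np

# Assumed MUR grid, in integer hundredths of a degree
lat = np.arange(-8999, 9000)      # -89.99 .. 89.99
lon = np.arange(-17999, 18001)    # -179.99 .. 180.00

# .sel(lat=slice(-5, 35), lon=slice(45, 90)): label-based, both ends included
lat_sel = lat[(lat >= -500) & (lat <= 3500)]
lon_sel = lon[(lon >= 4500) & (lon <= 9000)]

# .isel(..., slice(0, None, 10)): every 10th point, i.e. ~0.1 degree spacing
lat_ds = lat_sel[::10]
lon_ds = lon_sel[::10]
print(lat.size, lon.size, lat_sel.size, lon_sel.size, lat_ds.size, lon_ds.size)
```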

Here I was thinking that NaNs would interfere with our model outputs and stop the model from running efficiently, or at all. It turns out this isn't a problem, because the preprocessing below replaces NaNs with 0.0, so the next three cells can be skipped. Thanks Jiarui and team for the code snippet.

In [None]:
# # Get dates with NaN values
# all_nan_dates = np.isnan(ds_sample["analysed_sst"]).any(dim=["lon", "lat"]).compute()

# # Were there any?
# if not all_nan_dates.any():
#     print('No NaN values')
# else:
#     print('NaN values exist')

In [None]:
# # How many?
# np.isnan(ds_sample["analysed_sst"]).sum(dim=["lon", "lat"]).compute()

In [None]:
# # Fill dataset with -32768
# # May be better to remove
# ds_sample['analysed_sst'] = ds_sample['analysed_sst'].fillna(-32768)

In [14]:
# `.where()` above turned `ds` into a DataArray named 'analysed_sst', so
# `.rename({'analysed_sst': 'sst'})` raises a ValueError. Convert it back to a
# Dataset with an 'sst' variable so Jiarui and team's functions below work unchanged
ds_sample = ds_sample.to_dataset(name='sst')

In [15]:
ds_sample

| | Array | Chunk |
|---|---|---|
| Bytes | 2.58 GiB | 0.93 MiB |
| Shape | (3834, 401, 451) | (5, 180, 271) |
| Dask graph | 4602 chunks in 15 graph layers | |
| Data type | float32 numpy.ndarray | |


# Preprocess and split data (thanks for the functions Jiarui and team)

In [None]:
# Dask parallelises this across chunks, so no explicit multiprocessing is needed
def preprocess_data(zarr_ds):
    # Underlying dask array with shape (time, lat, lon)
    sst = zarr_ds['sst'].data

    # Subtract each day's spatial mean, ignoring NaNs. Reducing over the full
    # lat/lon extent gives exactly one mean per day. The previous map_blocks
    # version averaged over each dask chunk (5 days x part of the grid) instead.
    # A per-day mean matches preprocess_vis_input_data below.
    daily_mean = da.nanmean(sst, axis=(1, 2), keepdims=True)
    anomalies = sst - daily_mean

    # Replace NaNs (masked land/ice and missing values) with 0.0
    return da.where(da.isnan(anomalies), 0.0, anomalies)


def prepare_data_from_processed(processed_data, window_size=5): 
    length = processed_data.shape[0]
    X, y = [], []

    for i in range(length - window_size):
        X.append(processed_data[i:i+window_size])
        y.append(processed_data[i+window_size])

    X, y = da.array(X), da.array(y)
    return X, y


def time_series_split(X, y, train_ratio=0.7, val_ratio=0.2):
    total_length = X.shape[0]
    
    # Compute end indices for each split
    train_end = int(total_length * train_ratio)
    val_end = int(total_length * (train_ratio + val_ratio))
    
    X_train = X[:train_end]
    y_train = y[:train_end]
    
    X_val = X[train_end:val_end]
    y_val = y[train_end:val_end]
    
    X_test = X[val_end:]
    y_test = y[val_end:]
    
    return X_train, y_train, X_val, y_val, X_test, y_test
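The NumPy sketch below runs the same steps on a small random array (invented data, one permanently NaN pixel standing in for land). It computes per-day anomalies, fills NaNs with 0.0 and builds sliding-window samples. `sliding_window_view` is swapped in as a vectorised alternative to the Python loop in `prepare_data_from_processed`. The final comparison checks that both approaches produce the same `(X, y)` pairs.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
n_days, n_lat, n_lon, window_size = 12, 4, 6, 5

sst = 290.0 + rng.normal(size=(n_days, n_lat, n_lon))
sst[:, 0, 0] = np.nan  # a permanently masked pixel (e.g. land)

# Per-day anomaly: subtract each day's spatial mean (ignoring NaNs), then fill NaNs
anomaly = sst - np.nanmean(sst, axis=(1, 2), keepdims=True)
anomaly = np.nan_to_num(anomaly, nan=0.0)

# Sliding windows over time: windows[i, ..., k] == anomaly[i + k]
windows = sliding_window_view(anomaly, window_size, axis=0)  # (n-w+1, lat, lon, w)
X = np.moveaxis(windows, -1, 1)[:-1]  # (n-w, w, lat, lon); last window has no target
y = anomaly[window_size:]             # target = the day right after each window

# Same pairs as the explicit loop in prepare_data_from_processed
X_loop = np.array([anomaly[i:i + window_size] for i in range(n_days - window_size)])
y_loop = np.array([anomaly[i + window_size] for i in range(n_days - window_size)])
print(X.shape, y.shape, np.array_equal(X, X_loop), np.array_equal(y, y_loop))
```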


In [None]:
%%time

processed_data = preprocess_data(ds_sample)
processed_data

In [None]:
# Build sliding-window samples, then split chronologically into training, validation and test sets
X, y = prepare_data_from_processed(processed_data)
X_train, y_train, X_val, y_val, X_test, y_test = time_series_split(X, y)

In [None]:
training_dims = np.shape(X_train)
training_dims
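The expected split sizes can be worked out by hand. This assumes 3834 days after subsetting, `window_size=5` and the `time_series_split` defaults (70/20/10). The sketch repeats that arithmetic and converts the first validation and test targets back to calendar dates. It also shows that the training targets cover roughly 2009 to early 2016, which is worth keeping in mind when choosing a date to evaluate below.

```python
import numpy as np

# Assumptions: shapes as shown above and time_series_split defaults
n_days, window_size = 3834, 5
train_ratio, val_ratio = 0.7, 0.2

n_samples = n_days - window_size  # one sample per target day
train_end = int(n_samples * train_ratio)
val_end = int(n_samples * (train_ratio + val_ratio))
sizes = (train_end, val_end - train_end, n_samples - val_end)

# Sample i predicts day i + window_size (counted from 2009-01-01)
base_date = np.datetime64("2009-01-01")
val_start_date = base_date + np.timedelta64(train_end + window_size, "D")
test_start_date = base_date + np.timedelta64(val_end + window_size, "D")

print(n_samples, sizes, val_start_date, test_start_date)
```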

# Build model

In [None]:
def create_simple_model(input_shape, target_shape):
    model = Sequential()

    # Only the first layer needs input_shape; Keras infers the shapes of later layers
    model.add(Conv3D(filters=8, kernel_size=(3, 3, 3), input_shape=input_shape, padding='same'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2)))

    model.add(Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2)))

    model.add(Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same'))

    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(target_shape[0] * target_shape[1], activation='linear'))  # Output flattened to match target shape
    model.add(Reshape(target_shape))  # Reshape output to match target shape

    return model

# Define input shape and target shape and create the model
input_shape = X_train.shape[1:] + (1,)
target_shape = y_train.shape[1:]
model = create_simple_model(input_shape, target_shape)
model.summary()

# Compile the model
model.compile(optimizer='adam', loss='mse', metrics=['mse'])
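The sketch below hand-computes the layer shapes and parameter counts that `model.summary()` should report, assuming an input of `(5, 401, 451, 1)`. It relies on Keras defaults: `padding='same'` convolutions keep the spatial size, `MaxPooling3D` uses `padding='valid'` (floor division), and every layer has a bias. The helper names are hypothetical. The calculation shows the time axis collapsing from 5 to 1 after two poolings, and the two `Dense` layers holding almost all of the roughly 34.7 million parameters.

```python
def pool(shape, size=2):
    """'valid' max-pooling with stride equal to the pool size (floor division)."""
    return tuple((d - size) // size + 1 for d in shape)


def conv_params(in_ch, out_ch, k=3):
    return k * k * k * in_ch * out_ch + out_ch  # 3x3x3 kernel weights + biases


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


spatial = (5, 401, 451)            # (time, lat, lon), unchanged by 'same' convs
params = [conv_params(1, 8)]
spatial = pool(spatial)            # after the first MaxPooling3D: (2, 200, 225)
params.append(conv_params(8, 16))
spatial = pool(spatial)            # after the second MaxPooling3D: (1, 100, 112)
params.append(conv_params(16, 32))

flat = spatial[0] * spatial[1] * spatial[2] * 32  # Flatten() after 32 filters
target = 401 * 451                                # flattened output grid
params += [dense_params(flat, 64), dense_params(64, target)]

print(spatial, flat, params, sum(params))
```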

# Train and evaluate model

In [None]:
%%time

# TODO: Try other optimizers (ensemble of Adagrad and others)

model.compile(optimizer='adam', loss='mse', metrics=['mse'])

early_stop = EarlyStopping(patience=5, restore_best_weights=True)

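# Inputs need a trailing channel axis to match the model's (5, lat, lon, 1) input shape.
# Note: from_tensor_slices materialises the dask-backed arrays in memory
train_dataset = tf.data.Dataset.from_tensor_slices((X_train[..., np.newaxis], y_train))

train_dataset = train_dataset.shuffle(buffer_size=1024).batch(32)

val_dataset = tf.data.Dataset.from_tensor_slices((X_val[..., np.newaxis], y_val))

val_dataset = val_dataset.batch(32)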

history = model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=[early_stop])

In [None]:
def preprocess_vis_input_data(day_data):
    day_data = np.squeeze(day_data)
    mean_val = np.nanmean(day_data)
    processed_data = day_data - mean_val
    # Replace NaNs with 0.0
    processed_data = np.where(np.isnan(processed_data), 0.0, processed_data)
    return processed_data

def postprocess_prediction(prediction, input_data):
    # Land masking is disabled for now: land_mask.npy was built on the ERA5 grid,
    # whose resolution does not match the MUR data
    # land_mask = np.load('land_mask.npy')
    # prediction[land_mask] = np.nan

    # Add back the mean of the raw input window (the target day's own mean
    # is unknown at prediction time)
    mean_val = np.nanmean(input_data)
    prediction = np.where(np.isnan(prediction), np.nan, prediction + mean_val)

    return prediction

def predict_and_plot(date_to_predict, window_size, model, dataset, plot=True):
    # Step 1: Select the time window
    time_index = np.where(dataset['time'].values.astype('datetime64[D]') == np.datetime64(date_to_predict))[0][0]
    input_data_raw = dataset['sst'][time_index-window_size:time_index].values
    true_output_raw = dataset['sst'][time_index].values
    print(input_data_raw.shape)
    print(true_output_raw.shape)
    # Preprocess the input data
    input_data = np.array([preprocess_vis_input_data(day) for day in input_data_raw])
    
    # Step 2: Make prediction (add batch and channel axes: (1, window, lat, lon, 1))
    prediction = model.predict(input_data[np.newaxis, ..., np.newaxis])[0]
    
    # Postprocess the prediction
    prediction_postprocessed = postprocess_prediction(prediction, input_data_raw)
    print(prediction_postprocessed.shape)
    # Step 3: Visualize
    if plot:
        # Determine common scale for all plots
        input_data_raw = input_data_raw[..., np.newaxis]
        true_output_raw = true_output_raw[np.newaxis, ..., np.newaxis]
        prediction_postprocessed = prediction_postprocessed[np.newaxis, ..., np.newaxis]
        
        all_data = np.concatenate([input_data_raw, prediction_postprocessed, true_output_raw])
        vmin = np.nanmin(all_data)
        vmax = np.nanmax(all_data)
        
        def plot_sample(sample, title=''):
            sample_2d = np.squeeze(sample)
            # Latitude increases with row index, so draw row 0 at the bottom;
            # with imshow's default origin='upper' the map appears upside down
            plt.imshow(sample_2d, cmap='viridis', vmin=vmin, vmax=vmax, origin='lower')
            plt.title(title)
            plt.colorbar()
            plt.show()

        # show input frames
        for i, frame in enumerate(input_data_raw):
            plot_sample(frame, title=f'Input Frame {i+1} ({dataset["time"].values[time_index-window_size+i]})')
        
        # show predicted output
        plot_sample(prediction_postprocessed, title=f'Predicted Output ({date_to_predict})')
        
        # show true output
        plot_sample(true_output_raw, title=f'True Output ({date_to_predict})')

    return input_data_raw, prediction_postprocessed, true_output_raw
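The commented-out code in `postprocess_prediction` aims to blank out positions that are empty in the last input day. The ERA5-based `land_mask.npy` doesn't match the MUR grid, so one alternative is to derive the mask from the raw MUR input itself. The sketch below is a minimal version of that idea. `mask_like_last_input` is a hypothetical helper, not part of the original notebook. It would be called inside `predict_and_plot` right after `postprocess_prediction`, while the prediction is still a 2D `(lat, lon)` array.

```python
import numpy as np


def mask_like_last_input(prediction, input_data_raw):
    """Hypothetical helper: set pixels to NaN where the most recent raw input
    day is NaN (land, sea ice or missing), instead of loading land_mask.npy."""
    invalid = np.isnan(input_data_raw[-1])
    return np.where(invalid, np.nan, prediction)


# Toy example (invented values): 2 input days on a 2x2 grid,
# pixel (0, 1) is masked on the last day
input_data_raw = np.array([[[300.0, 301.0], [302.0, 303.0]],
                           [[300.5, np.nan], [302.5, 303.5]]])
prediction = np.array([[300.7, 301.2], [302.9, 303.1]])
print(mask_like_last_input(prediction, input_data_raw))
```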

In [None]:
def compute_mae(y_true, y_pred):
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    return np.mean(np.abs(y_true[mask] - y_pred[mask]))

In [None]:
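# Note: 2009-01-07 falls in the training split. For a held-out check, pick a date
# in the test split (the last ~10% of the period, i.e. from mid-2018 onwards)
date_to_predict = '2009-01-07'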
window_size = 5
input_data, predicted_output, true_output = predict_and_plot(date_to_predict, window_size, model, ds_sample)

predicted_mae = compute_mae(true_output, predicted_output)
print(f"MAE between Predicted Output and True Output: {predicted_mae}")

last_input_frame = input_data[-1]
last_input_frame_2d = np.squeeze(last_input_frame)
true_output_2d = np.squeeze(true_output)
last_frame_mae = compute_mae(true_output_2d, last_input_frame_2d)
print(f"MAE between Last Input Frame and True Output: {last_frame_mae}")
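The last-input-frame comparison is a persistence forecast ("tomorrow looks like today"), a useful baseline for the model. A single day is noisy, so a fairer check aggregates over many target days, for example stacks built by calling `predict_and_plot(..., plot=False)` for each date in the test split. The sketch below uses a hypothetical helper, `mae_vs_persistence`, that reports both MAEs and a skill score of the form 1 - MAE_model / MAE_persistence, where positive values mean the model beats persistence. It assumes all three arrays share the same shape and units, and demonstrates on invented toy values.

```python
import numpy as np


def mae_vs_persistence(y_true, y_pred, last_inputs):
    """Hypothetical helper: NaN-aware MAE of the model and of persistence,
    plus a skill score (> 0 means the model beats persistence)."""
    def mae(a, b):
        valid = ~np.isnan(a) & ~np.isnan(b)
        return np.mean(np.abs(a[valid] - b[valid]))

    model_mae = mae(y_true, y_pred)
    persistence_mae = mae(y_true, last_inputs)
    skill = 1.0 - model_mae / persistence_mae
    return model_mae, persistence_mae, skill


# Toy stacks with shape (days, lat, lon); one masked pixel in the truth
y_true = np.array([[[1.0, 2.0]], [[3.0, np.nan]]])
y_pred = np.array([[[1.5, 2.0]], [[3.0, 5.0]]])
last_inputs = np.array([[[2.0, 3.0]], [[4.0, 4.0]]])
print(mae_vs_persistence(y_true, y_pred, last_inputs))
```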