In [None]:
'''
Developed by Juirai, Boris, Hao, combined by Paula Birocchi 

date: 09-08-2023

Ocean Hack Week 2023

This script trains a ConvLSTM machine-learning model to predict the spatial distribution of sea surface temperature (SST).

'''
# getting the libraries necessary to run this script:
import pandas as pd
from pathlib import Path
import xarray as xr
import numpy as np
import calendar
import os.path
import dask.array as da
from dask.delayed import delayed
from sklearn.model_selection import train_test_split
import gc
from tensorflow.keras import layers, regularizers, optimizers
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Dropout, Dense, Add, LayerNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
# Core libraries for this tutorial
from eosdis_store import EosdisStore # Available via `pip install zarr zarr-eosdis-store`
import requests
from pqdm.threads import pqdm
from matplotlib import animation, pyplot as plt
from IPython.core.display import display, HTML
from pprint import pprint
# importing library for s3 buckets
import s3fs

# Anonymous S3 access: no AWS tokens or keys are needed for this bucket.
s3 = s3fs.S3FileSystem(anon=True)

# List the contents of the MUR SST zarr store to verify that we're in the right place.
# Getting the satellite data from the S3 bucket:
sst_files = s3.ls("mur-sst/zarr-v1/")
sst_files

# Open the first listed entry as a zarr-backed dataset.
# NOTE(review): this relies on the order returned by s3.ls — confirm that
# sst_files[0] is the store root that is meant to be opened.
ds = xr.open_zarr(
        store=s3fs.S3Map(
            root=f"s3://{sst_files[0]}", s3=s3, check=False
        )
)
# We reduced the size of the matrix to test the model, as we were having issues.
# Slice the data (June 2002, lat 5..7, lon 50..52) so that the model can run:
dscut = ds.sel(time=slice("2002-06-01", "2002-06-30"),lat=slice(5,7),lon=slice(50,52))



# Truncate every timestamp to the start of its day; predict_and_plot looks a
# day up by exact equality with np.datetime64(<date string>).
dscut['time'] = dscut['time'].dt.floor('D')


In [None]:
import dask.array as da
from dask.delayed import delayed
from sklearn.model_selection import train_test_split
import gc

def preprocess_day_data(day_data):
    """Subtract the NaN-aware mean of an array from the array itself.

    Singleton dimensions are squeezed out, one scalar mean is computed over
    all non-NaN elements of the result, and that scalar is subtracted from
    every element. NaN elements remain NaN after the subtraction.

    NOTE(review): despite the name, nothing in this function restricts the
    mean to a single day. It is invoked through ``map_blocks`` in
    ``preprocess_data``, so the mean is taken over whatever block is handed
    in — confirm that the chunking yields the intended per-day means.
    """
    day_data = da.squeeze(day_data)
    mean_val = da.nanmean(day_data).compute()  # compute here to get scalar value
    return day_data - mean_val

def preprocess_data(zarr_ds, chunk_size=200):
    """Mean-centre the ``analysed_sst`` variable and replace its NaNs by 0.0.

    The variable is walked along its first axis in slabs of ``chunk_size``
    entries. Every slab is passed through ``map_blocks(preprocess_day_data)``,
    NaNs in the result are replaced by 0.0, and the slabs are concatenated
    back together along axis 0.

    Parameters
    ----------
    zarr_ds : mapping
        Dataset exposing an ``'analysed_sst'`` entry whose slices provide
        ``map_blocks`` (the xarray dataset ``dscut`` in this notebook).
    chunk_size : int, default 200
        Number of entries along the first axis handled per slab.

    Returns
    -------
    dask array
        The result of ``da.concatenate`` over all processed slabs.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive or the variable has no entries
        along its first axis (both cases previously failed with a less
        helpful ``ValueError`` from ``range`` / ``da.concatenate``).

    NOTE(review): the mean removed by ``preprocess_day_data`` is computed per
    block handed over by ``map_blocks`` — confirm this matches the per-day
    normalisation used at prediction time.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Look the variable up once instead of on every loop iteration.
    sst = zarr_ds['analysed_sst']
    total_len = sst.shape[0]
    if total_len == 0:
        raise ValueError("'analysed_sst' has no entries along its first axis")

    slabs = []
    for start_idx in range(0, total_len, chunk_size):
        end_idx = min(start_idx + chunk_size, total_len)

        # Slice the variable directly; no extra da.from_array wrapping.
        slab = sst[start_idx:end_idx]
        centred = slab.map_blocks(preprocess_day_data)

        # Replace NaNs with 0.0 so that the network never sees NaN inputs.
        slabs.append(da.where(da.isnan(centred), 0.0, centred))

    return da.concatenate(slabs, axis=0)

# Earlier eager variant, kept for reference (`zarr_ds` is not defined in the visible cells):
# processed_data = preprocess_data(zarr_ds).compute()
# Preprocess the June 2002 subset; the result is kept as a dask array (no .compute() here).
processed_data = preprocess_data(dscut)

def prepare_data_from_processed(processed_data, window_size=5):
    """Build sliding-window samples from a sequence of frames.

    Sample ``k`` pairs the ``window_size`` consecutive frames starting at
    index ``k`` (the input) with the single frame that directly follows them
    (the target). There are ``len - window_size`` samples; none if the
    sequence is not longer than the window.

    Returns
    -------
    tuple
        ``(X, y)``, each produced by ``da.array`` from the per-sample lists.
    """
    n_samples = processed_data.shape[0] - window_size
    starts = range(n_samples)

    windows = [processed_data[k:k + window_size] for k in starts]
    targets = [processed_data[k + window_size] for k in starts]

    return da.array(windows), da.array(targets)

# Sliding-window samples with the default window_size=5:
# X holds the 5-frame inputs, y the frame that follows each window.
X, y = prepare_data_from_processed(processed_data)

In [None]:
def time_series_split(X, y, train_ratio=0.7, val_ratio=0.2):
    """Split samples chronologically into train, validation and test parts.

    No shuffling is done: the first ``train_ratio`` share of the samples is
    the training set, the next ``val_ratio`` share the validation set, and
    whatever remains is the test set.

    Parameters
    ----------
    X, y : sliceable sequences
        Inputs and targets; ``X.shape[0]`` defines the number of samples and
        both are cut at the same indices.
    train_ratio : float, default 0.7
        Fraction of samples used for training.
    val_ratio : float, default 0.2
        Fraction of samples used for validation.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``.
    """
    total_length = X.shape[0]

    # Round before truncating: in floating point 0.7 + 0.2 is
    # 0.8999999999999999, so int(10 * (0.7 + 0.2)) would give 8 instead of 9
    # and silently move one validation sample into the test set.
    train_end = int(round(total_length * train_ratio, 6))
    val_end = int(round(total_length * (train_ratio + val_ratio), 6))

    train_part = slice(None, train_end)
    val_part = slice(train_end, val_end)
    test_part = slice(val_end, None)

    return (
        X[train_part], y[train_part],
        X[val_part], y[val_part],
        X[test_part], y[test_part],
    )

# Chronological split using the defaults of time_series_split:
# 70% train, 20% validation, remainder test.
X_train, y_train, X_val, y_val, X_test, y_test = time_series_split(X, y)

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, BatchNormalization, Conv2D

# please change the number of points in your matrix before running the following code:
def create_simple_model(input_shape=(5, 201, 201, 1)):
    """Assemble the (uncompiled) ConvLSTM model used for next-frame prediction.

    Layer stack, in order:
      1. ``ConvLSTM2D`` — 32 filters, 3x3 kernel, 'same' padding,
         ``return_sequences=False``; receives ``input_shape``.
      2. ``BatchNormalization``.
      3. ``Conv2D`` — 1 filter, 3x3 kernel, 'same' padding, linear activation,
         producing the output.

    Parameters
    ----------
    input_shape : tuple, default (5, 201, 201, 1)
        Shape of one sample, forwarded to the first layer. Adjust the two
        201 entries to the number of grid points of your subset.

    Returns
    -------
    Sequential
        The assembled model; compiling is left to the caller.
    """
    recurrent = ConvLSTM2D(
        filters=32,
        kernel_size=(3, 3),
        input_shape=input_shape,
        padding='same',
        return_sequences=False,
    )
    output_head = Conv2D(
        filters=1,
        kernel_size=(3, 3),
        padding='same',
        activation='linear',
    )
    return Sequential([recurrent, BatchNormalization(), output_head])

# Build the network with its default input shape and print the layer table.
model = create_simple_model()
model.summary()

# Adam optimiser with mean squared error as loss; MSE is also reported as a metric.
model.compile(loss='mse', optimizer='adam', metrics=['mse'])

# Early stopping with a patience of 5 epochs, restoring the best weights seen.
early_stop = EarlyStopping(restore_best_weights=True, patience=5)

# Training pipeline: one element per sample, shuffled with a 1024-element
# buffer, then grouped into batches of 32.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(32)
)

# Validation pipeline: same slicing and batch size, but no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(32)

In [None]:
# Train for at most 20 epochs, validating on val_dataset after each epoch; the
# early_stop callback may end training sooner. `history` keeps the object returned by fit.
history = model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=[early_stop])


In [None]:
def create_land_mask(data, flip=True):
    """Return a boolean mask that is True wherever ``data`` is NaN.

    Parameters
    ----------
    data : array-like
        Field in which NaN marks the cells to mask (presumably land, going
        by the function name).
    flip : bool, default True
        When True (the historical behaviour) the mask is flipped along its
        first axis with ``np.flipud``. Pass False to get a mask that lines up
        index-for-index with ``data``.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same shape as ``data``.
    """
    land_mask = np.isnan(data)
    return np.flipud(land_mask) if flip else land_mask


# Build the mask from the RAW SST field. The arrays in `X` have already had
# their NaNs replaced by 0.0 in preprocess_data, so a mask derived from them is
# always all-False. flip=False keeps the mask aligned with the un-flipped
# prediction arrays that postprocess_prediction indexes with it.
land_mask_resized = create_land_mask(dscut['analysed_sst'][0].values, flip=False)

np.save('land_mask_resized.npy', land_mask_resized)
def preprocess_vis_input_data(day_data):
    """Centre one frame on its NaN-aware mean and fill the gaps with 0.0.

    Singleton axes are squeezed out first. The mean of all non-NaN entries is
    then subtracted, and every entry that is NaN afterwards is replaced by
    0.0.

    Returns
    -------
    numpy.ndarray
        The squeezed, mean-centred, NaN-free frame.
    """
    frame = np.squeeze(day_data)
    anomaly = frame - np.nanmean(frame)
    missing = np.isnan(anomaly)
    # Replace NaNs with 0.0
    return np.where(missing, 0.0, anomaly)

def postprocess_prediction(prediction, input_data, land_mask_resized):
    """Mask a prediction and shift it back by the mean of the raw input.

    Parameters
    ----------
    prediction : array-like
        Model output. It is NOT modified; the result is a new array.
    input_data : array-like
        Raw input frames; their NaN-aware mean (one scalar over the whole
        array) is added to the prediction.
    land_mask_resized : boolean array
        True where the result must be NaN. It indexes the leading axes of
        ``prediction``, so a 2-D mask also works on a prediction that has a
        trailing channel axis.

    Returns
    -------
    numpy.ndarray
        ``prediction + nanmean(input_data)`` with NaN at the masked cells.
    """
    # Work on a copy: the previous version wrote the NaNs straight into the
    # caller's array.
    result = np.array(prediction, copy=True)
    if not np.issubdtype(result.dtype, np.floating):
        # NaN cannot be stored in an integer/boolean array.
        result = result.astype(float)

    # Set the masked positions to NaN
    result[land_mask_resized] = np.nan

    # Add back the historical mean. NaN + mean is still NaN, so the masked
    # cells survive the shift without any extra np.where.
    return result + np.nanmean(input_data)

def predict_and_plot(date_to_predict, window_size, model, dataset, plot=True, land_mask=None):
    """Predict the SST field of one day from the preceding days and optionally plot it.

    Parameters
    ----------
    date_to_predict : str or datetime-like
        Day to predict; it is matched by exact equality against
        ``dataset['time'].values`` after conversion with ``np.datetime64``.
    window_size : int
        Number of days directly before ``date_to_predict`` used as model input.
    model : object with ``predict``
        Trained model; called with the preprocessed frames plus a leading
        batch axis.
    dataset : mapping
        Must expose ``'time'`` and ``'analysed_sst'`` entries with ``.values``.
    plot : bool, default True
        When True, every input frame, the prediction and the true field are
        written to ``figure1_test<k>.png`` with a shared colour scale.
    land_mask : boolean array, optional
        Mask passed to ``postprocess_prediction``. Defaults to the
        module-level ``land_mask_resized``.

    Returns
    -------
    tuple
        ``(input_data_raw, prediction_postprocessed, true_output_raw)``.
        The input frames and the true field get a trailing axis, the true
        field and the prediction a leading axis — identically for
        ``plot=True`` and ``plot=False``.

    Raises
    ------
    IndexError
        If ``date_to_predict`` is not present in ``dataset['time']``.
    ValueError
        If ``window_size`` is smaller than 1 or fewer than ``window_size``
        time steps precede the requested date.
    """
    # Step 1: Select the time window
    times = dataset['time'].values
    matches = np.where(times == np.datetime64(date_to_predict))[0]
    if matches.size == 0:
        raise IndexError(f"date {date_to_predict!r} not found in dataset['time']")
    time_index = matches[0]
    if window_size < 1 or time_index < window_size:
        # A negative slice start would silently wrap around to the end of the
        # series instead of failing.
        raise ValueError(
            f"need {window_size} time steps before {date_to_predict!r}, "
            f"but only {time_index} are available (window_size must be >= 1)"
        )

    sst = dataset['analysed_sst']
    input_data_raw = sst[time_index - window_size:time_index].values
    true_output_raw = sst[time_index].values
    print(input_data_raw.shape)
    print(true_output_raw.shape)

    # Preprocess the input data frame by frame
    input_data = np.array([preprocess_vis_input_data(day) for day in input_data_raw])

    # Step 2: Make prediction (add the batch axis, take the single result)
    prediction = model.predict(input_data[np.newaxis, ...])[0]

    # Postprocess the prediction
    if land_mask is None:
        land_mask = land_mask_resized  # module-level mask built earlier in the notebook
    prediction_postprocessed = postprocess_prediction(prediction, input_data_raw, land_mask)
    print(prediction_postprocessed.shape)

    # Bring all three arrays to a common layout regardless of `plot`, so the
    # returned shapes do not depend on the flag.
    # assumes each raw frame is 2-D and the prediction already carries a
    # trailing channel axis — TODO confirm against the model output shape
    input_data_raw = input_data_raw[..., np.newaxis]
    true_output_raw = true_output_raw[np.newaxis, ..., np.newaxis]
    prediction_postprocessed = prediction_postprocessed[np.newaxis, ...]

    # Step 3: Visualize
    if plot:
        # Determine common scale for all plots
        all_data = np.concatenate([input_data_raw, prediction_postprocessed, true_output_raw])
        vmin = np.nanmin(all_data)
        vmax = np.nanmax(all_data)

        def plot_sample(sample, i, title=''):
            """Save one frame as figure1_test<i>.png using the shared colour scale."""
            sample_2d = np.squeeze(sample)
            plt.imshow(sample_2d, cmap='viridis', vmin=vmin, vmax=vmax)
            plt.title(title)
            plt.colorbar()
            plt.savefig('figure1_test'+str(i)+'.png', bbox_inches='tight')
            plt.close()

        # show input frames
        for i, frame in enumerate(input_data_raw):
            plot_sample(frame, i, title=f'Input Frame {i+1} ({times[time_index-window_size+i]})')

        # Number the remaining figures from the frame count rather than from
        # the loop variable, which only exists if the loop ran.
        n_inputs = len(input_data_raw)

        # show predicted output
        plot_sample(prediction_postprocessed, n_inputs, title=f'Predicted Output ({date_to_predict})')

        # show true output
        plot_sample(true_output_raw, n_inputs + 1, title=f'True Output ({date_to_predict})')

    return input_data_raw, prediction_postprocessed, true_output_raw
                


def compute_mae(y_true, y_pred):
    """Mean absolute error over the cells where both inputs are non-NaN.

    Parameters
    ----------
    y_true, y_pred : array-like
        Arrays (or nested lists) of identical shape; NaN marks missing cells.

    Returns
    -------
    float
        Mean of ``|y_true - y_pred|`` over the cells valid in both inputs, or
        NaN when no such cell exists.
    """
    # Accept plain lists as well: boolean-mask indexing needs real arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    valid = ~(np.isnan(y_true) | np.isnan(y_pred))
    if not valid.any():
        # Nothing to compare; avoid the "mean of empty slice" RuntimeWarning.
        return np.nan
    return np.mean(np.abs(y_true[valid] - y_pred[valid]))

# Predict the last day of the loaded subset from the 5 days before it
# (window_size matches the first entry of the model's default input_shape).
date_to_predict = '2002-06-30'
window_size = 5
input_data, predicted_output, true_output = predict_and_plot(date_to_predict, window_size, model, dscut)

# Error of the model prediction against the true field.
predicted_mae = compute_mae(true_output, predicted_output)
print(f"MAE between Predicted Output and True Output: {predicted_mae}")

# Baseline for comparison: use the last input frame itself as the "prediction".
# Both arrays are squeezed so that their shapes match before comparing.
last_input_frame = input_data[-1]
last_input_frame_2d = np.squeeze(last_input_frame)
true_output_2d = np.squeeze(true_output)
last_frame_mae = compute_mae(true_output_2d, last_input_frame_2d)
print(f"MAE between Last Input Frame and True Output: {last_frame_mae}")

#model.save('ConvLSTM_nc_2002-.keras')
                
# just plotting land mask
import numpy as np
import matplotlib.pyplot as plt
# Reload the mask that was written with np.save earlier in the notebook and
# save it as a greyscale image with a colour bar.
data = np.load('land_mask_resized.npy')
plt.imshow(data, cmap='gray')
plt.colorbar()
plt.title('Land Mask')
plt.savefig('land_mask_version2.png')