# ML flow:

## Setting up

In [None]:
# run script that installs missing libraries
#! chmod 755 ../../scripts.sh
#! ../../scripts.sh

# run script that installs missing libraries
#! chmod 755 ../../ML_scripts.sh
#! ../../ML_scripts.sh

### Imports:

In [None]:
# Basics
from matplotlib import pyplot as plt
import matplotlib.path as mpath
import os
import sys
from os import listdir
from os.path import isfile, join
import numpy as np
import pandas as pd
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm
from re import search
from math import cos,sin,pi
import random as rn

# xarray and cartopy plots
import xarray as xr
import cartopy
import cartopy.crs as ccrs
import pyproj
from pyproj import Transformer
#import cf_units
#import rasterio

# Google file system
#import gcsfs
#from google.cloud import storage

# ML
from scipy import ndimage
import torch
import keras
import tensorflow as tf 
from keras import backend as K
from tensorflow.python.keras.backend import set_session

from keras.models import load_model, Model
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, UpSampling2D, Conv2DTranspose, Reshape, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.models import Sequential
from netCDF4 import Dataset
from sklearn.model_selection import train_test_split
import setGPU

# Import custom scripts
sys.path.append('../')
#from process_pangeo import *
#from GC_scripts import *
#from processRCM import *
#from reprojectionFunctions import *
#from MakeInputFunctions import *
from model import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Report what PyTorch sees. The device queries are guarded because
# torch.cuda.current_device() raises when no CUDA device is usable.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA supported by this system? {cuda_available}")
print(f"CUDA version: {torch.version.cuda}")

# Storing ID of current CUDA device (None when running on CPU only)
cuda_id = torch.cuda.current_device() if cuda_available else None
if cuda_id is not None:
    print(f"ID of current CUDA device:{cuda_id}")
    print(f"Name of current CUDA device:{torch.cuda.get_device_name(cuda_id)}")
else:
    print("No CUDA device available, running on CPU.")

In [None]:
# Select the PyTorch device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Query the GPUs visible to TensorFlow once, then enable memory growth on every
# one of them (not only the first) so TensorFlow allocates GPU memory on demand.
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

### Functions:

In [None]:
def plotAllVar(GCM_xy, m=3, n=3, name="GCM", time=0):
    """Plot every data variable of a dataset at one time step on polar maps.

    Parameters
    ----------
    GCM_xy : dataset exposing ``data_vars``, a ``time`` coordinate and ``x``/``y``
        coordinates (used as an xarray Dataset in this notebook).
    m, n : int
        Number of rows and columns of the subplot grid; the dataset must not
        hold more than ``m * n`` variables.
    name : str
        Dataset label used in the figure title.
    time : int
        Index of the time step to plot.
    """
    var_names = list(GCM_xy.data_vars)
    map_proj = ccrs.SouthPolarStereo()
    plt.figure(figsize=(20, 10))
    for position, var in enumerate(var_names, start=1):
        ax = plt.subplot(m, n, position, projection=map_proj)
        GCM_xy[var].isel(time=time).plot(
            ax=ax, x="x", y="y", transform=map_proj, add_colorbar=True
        )
        ax.coastlines("10m", color="black")
        ax.gridlines()
        # Fall back to the variable name when no long_name attribute is set.
        long_name = GCM_xy[var].attrs.get("long_name", var)
        ax.set_title(f"{long_name} ({var})")
    # The title reports the time step that was actually plotted, not always step 0.
    step_label = "First time step" if time == 0 else f"Time step {time}"
    plt.suptitle(f"{step_label} {GCM_xy.time[time].values} of {name}")

In [None]:
def resize(df, sizex, sizey, print_=True):
    """Resize an array to ``(sizex, sizey)`` with ``tf.image.resize``.

    ``df`` is wrapped in a TensorFlow constant, resized, and converted back to
    a NumPy array. When ``print_`` is truthy the shapes before and after the
    resize are printed.
    """
    if print_:
        print("Shape before resizing:", df.shape)

    # tf.image.resize acts on the spatial axes; result is returned as NumPy.
    df_resized = tf.image.resize(tf.constant(df), (sizex, sizey)).numpy()

    if print_:
        print("Shape after resizing:", df_resized.shape)
    return df_resized

In [None]:
def cutBoundaries(df, max_x, max_y, lowerHalf = False):
    """Crop ``df`` to a window centred on the origin of its x/y coordinates.

    Keeps the points with ``-max_x <= x < max_x`` and ``-max_y <= y < max_y``.
    With ``lowerHalf=True`` the upper y bound becomes 0, i.e. only
    ``-max_y <= y < 0`` is kept. Points outside the window are dropped.
    """
    # Crop along x first: upper bound, then lower bound.
    cropped = df.where(df.x < max_x, drop=True)
    cropped = cropped.where(-max_x <= cropped.x, drop=True)

    # Upper y bound depends on whether only the lower half is requested.
    upper_y = 0 if lowerHalf else max_y
    cropped = cropped.where(cropped.y < upper_y, drop=True)
    return cropped.where(-max_y <= cropped.y, drop=True)

In [None]:
def takeRandomSamples(full_input, full_target, pred=False, full_prediction=None):
    """Draw one random sample (same index) from inputs, target and prediction.

    Parameters
    ----------
    full_input : sequence whose first element holds the 2D inputs and whose
        second element holds the 1D inputs, both indexed by sample on axis 0.
    full_target : array indexed by sample on axis 0.
    pred : bool
        When True, also return the matching entry of ``full_prediction``.
    full_prediction : array indexed by sample on axis 0 (required if ``pred``).

    Returns
    -------
    ``(sample2d, sample1d, target, index)`` or, when ``pred`` is True,
    ``(sample2d, sample1d, target, prediction, index)``.

    Raises
    ------
    ValueError
        If there is no sample to draw, or ``pred`` is True without predictions.
    """
    n_samples = len(full_input[0])
    if n_samples == 0:
        raise ValueError("full_input[0] is empty, cannot draw a random sample")

    # random.randint includes both end points, so the last valid index is
    # n_samples - 1 (the original `len(x - 1)` allowed an out-of-range index).
    randTime = rn.randint(0, n_samples - 1)
    sample2dtrain = full_input[0][randTime]
    sample1dtrain = full_input[1][randTime]
    sampletarget = full_target[randTime]

    if pred:
        if full_prediction is None:
            raise ValueError("full_prediction is required when pred=True")
        samplepred = full_prediction[randTime]
        return sample2dtrain, sample1dtrain, sampletarget, samplepred, randTime
    return sample2dtrain, sample1dtrain, sampletarget, randTime


def plotTrain(GCMLike, sample2dtrain, numVar, ax, time, list_var, region="Whole Antarctica"):
    """Draw one input variable of a 2D sample on ``ax``.

    The x/y coordinates and attributes are taken from ``GCMLike`` (cropped with
    ``createLowerInput`` unless ``region`` is "Whole Antarctica"). Channel
    ``numVar`` of ``sample2dtrain`` is plotted under the name
    ``list_var[numVar]``, and ``time`` is only used in the panel title.
    """
    if region == "Whole Antarctica":
        ds = GCMLike
    else:
        ds = createLowerInput(GCMLike, region = region, Nx=48, Ny=25, print_=False)

    var_name = list_var[numVar]

    # Wrap the raw array in a dataset so xarray can plot it on the map grid.
    panel = xr.Dataset(
        coords={"y": ds.coords["y"], "x": ds.coords["x"]}, attrs=ds.attrs
    )
    panel[var_name] = xr.Variable(
        dims=("y", "x"), data=sample2dtrain[:, :, numVar], attrs=ds[var_name].attrs
    )
    panel[var_name].plot(
        ax=ax,
        x="x",
        transform=ccrs.SouthPolarStereo(),
        add_colorbar=True,
        cmap='RdYlBu_r',
    )

    ax.coastlines("10m", color="black")
    ax.gridlines()
    ax.set_title(f"{time} Input: {var_name}")


def plotTarget(target_dataset, sampletarget, ax, vmin, vmax, region="Whole Antarctica"):
    """Draw a target SMB sample on ``ax`` with a fixed colour range.

    Coordinates and attributes come from ``target_dataset`` (cropped with
    ``createLowerTarget`` unless ``region`` is "Whole Antarctica"). Channel 0 of
    ``sampletarget`` is plotted with the colour limits ``vmin``/``vmax``.
    """
    if region == "Whole Antarctica":
        ds = target_dataset
    else:
        ds = createLowerTarget(target_dataset, region = region, Nx=64, Ny = 64, print_=False)

    # Wrap the raw array in a dataset so xarray can plot it on the map grid.
    smb_panel = xr.Dataset(
        coords={"y": ds.coords["y"], "x": ds.coords["x"]}, attrs=ds.attrs
    )
    smb_panel["SMB"] = xr.Variable(
        dims=("y", "x"), data=sampletarget[:, :, 0], attrs=ds["SMB"].attrs
    )
    smb_panel.SMB.plot(
        ax=ax,
        x="x",
        transform=ccrs.SouthPolarStereo(),
        add_colorbar=True,
        cmap='RdYlBu_r',
        vmin = vmin,
        vmax = vmax,
    )
    ax.coastlines("10m", color="black")
    ax.gridlines()
    ax.set_title(f"Target: SMB")


def plotPred(target_dataset, samplepred, ax, vmin, vmax, region="Whole Antarctica"):
    """Draw a predicted SMB sample on ``ax`` with a fixed colour range.

    Mirrors ``plotTarget``: coordinates and attributes come from
    ``target_dataset`` (cropped with ``createLowerTarget`` unless ``region`` is
    "Whole Antarctica") and channel 0 of ``samplepred`` is plotted with the
    colour limits ``vmin``/``vmax``.
    """
    if region == "Whole Antarctica":
        ds = target_dataset
    else:
        ds = createLowerTarget(target_dataset, region = region, Nx=64, Ny = 64, print_=False)

    # Wrap the raw array in a dataset so xarray can plot it on the map grid.
    pred_panel = xr.Dataset(
        coords={"y": ds.coords["y"], "x": ds.coords["x"]}, attrs=ds.attrs
    )
    pred_panel["SMB"] = xr.Variable(
        dims=("y", "x"), data=samplepred[:, :, 0], attrs=ds["SMB"].attrs
    )
    pred_panel.SMB.plot(
        ax=ax,
        x="x",
        transform=ccrs.SouthPolarStereo(),
        add_colorbar=True,
        cmap='RdYlBu_r',
        vmin = vmin,
        vmax = vmax,
    )
    ax.coastlines("10m", color="black")
    ax.gridlines()
    ax.set_title(f"Prediction: SMB")

## Load data:

In [None]:
# Path on GC with data:
# (passed to downloadFileFromGC in the cells where downloadFromGC is True)
pathGC = f'Chris_data/RawData/MAR-ACCESS1.3/monthly/RCM/'
# Directory prefix prepended to the NetCDF file names when the data is opened
# locally, i.e. when downloadFromGC is False.
pathCluster = '../../../../../../mlodata1/marvande/data/'

### Input: GCM-like RCM
I.e., to create X, Z

In [None]:
# Open the GCM-like RCM dataset: either fetch it with downloadFileFromGC, open
# it and delete the local copy, or open it directly from pathCluster.
# NOTE(review): downloadFileFromGC is not defined in the visible notebook; it
# must come from `model` or one of the commented-out imports — confirm before
# setting downloadFromGC = True.
downloadFromGC = False
fileGCMLike = 'MAR(ACCESS1-3)-stereographic_monthly_GCM_like.nc'
if downloadFromGC:
    downloadFileFromGC(pathGC, '', fileGCMLike)
    GCMLike = xr.open_dataset(fileGCMLike)
    os.remove(fileGCMLike)
else:
    GCMLike = xr.open_dataset(pathCluster+fileGCMLike)
print(GCMLike.dims)
# Bare expression: shows the dataset summary as the cell output.
GCMLike

In [None]:
# Show the first SMB time step on two map projections, one panel per row:
# South Polar Stereographic on top, Plate Carree below.
fig = plt.figure(figsize=(20, 10))
panel_projections = (ccrs.SouthPolarStereo(), ccrs.PlateCarree())
for row, panel_projection in enumerate(panel_projections, start=1):
    ax = plt.subplot(2, 1, row, projection=panel_projection)
    # The data itself is always on the South Polar Stereographic grid.
    GCMLike.SMB.isel(time=0).plot(
        x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=False
    )
    ax.coastlines("10m", color="black")
    ax.gridlines(draw_labels=True)

### Target: RCM

In [None]:
# Open target file
downloadFromGC = False
fileTarget = 'MAR(ACCESS1-3)_monthly.nc'
if not downloadFromGC:
    target_dataset = xr.open_dataset(pathCluster+fileTarget)
else:
    downloadFileFromGC(pathGC, '', fileTarget)
    target_dataset = xr.open_dataset(fileTarget)
    os.remove(fileTarget)
print(target_dataset.dims)

# Crop the target to a square window of N cells centred on the origin; this
# removes a small, less important strip on the right-hand side.
N = 160
max_x = max_y = (N / 2) * 35 * 1000

target_dataset = cutBoundaries(target_dataset, max_x, max_y)
print("New target dimensions:", target_dataset.dims)

In [None]:
# Plot the first SMB time step of the cropped target on a polar map.
fig, ax = plt.subplots(
    figsize=(5, 5), subplot_kw={"projection": ccrs.SouthPolarStereo()}
)
target_dataset.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=True
)

ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)

## Cut to specific region:

### Target:
Cut to lower peninsula to get a 64x64 grid like in the paper.

Regions: 
- Larsen C ice shelf (Antarctic Peninsula)
- Pine Island & Thwaites ice shelves (Amundsen Sea Embayment)
- Roi Baudouin ice shelf (Dronning Maud Land)
- Shackleton ice shelf (Wilkes Land)

In [None]:
# NOTE(review): this re-defines cutBoundaries with exactly the same body as the
# definition in the "Functions" section; the later definition silently shadows
# the earlier one, so keep only one of them.
def cutBoundaries(df, max_x, max_y, lowerHalf = False):
    """Crop ``df`` to ``-max_x <= x < max_x`` and ``-max_y <= y < max_y``.

    With ``lowerHalf=True`` the upper y bound is 0 instead of ``max_y``.
    Points outside the window are dropped.
    """
    df = df.where(df.x < max_x, drop=True)
    df = df.where(-max_x <= df.x, drop=True)
    if lowerHalf:
        df = df.where(df.y < 0, drop=True)
    else:
        df = df.where(df.y < max_y, drop=True)
    df = df.where(-max_y <= df.y, drop=True)
    return df

In [None]:
def createLowerTarget(target_dataset, region, Nx=64, Ny = 64, print_=True):
    """Crop the target dataset to a named region.

    Parameters
    ----------
    target_dataset : dataset with ``x`` and ``y`` coordinates.
    region : {'Larsen', 'Lower Peninsula'}
        'Larsen' keeps ``0 <= y < Ny * cell`` and ``x < min(x) + Nx * cell``.
        'Lower Peninsula' keeps ``-Nx/2 * cell <= x < Nx/2 * cell`` and
        ``-Ny * cell <= y < 0`` (via ``cutBoundaries(..., lowerHalf=True)``).
    Nx, Ny : int
        Extent of the window in grid cells along x and y.
    print_ : bool
        Print the dimensions of the cropped dataset.

    Raises
    ------
    ValueError
        If ``region`` is not one of the supported names.
    """
    # Same 35 * 1000 factor as the rest of the notebook
    # (presumably a 35 km grid spacing in metres — TODO confirm).
    cell_size = 35 * 1000

    if region == 'Larsen':
        # Nx bounds the x extent and Ny the y extent (they were swapped before,
        # which only went unnoticed because both default to 64).
        max_x = Nx * cell_size
        max_y = Ny * cell_size
        min_x = target_dataset.x.min().values

        df = target_dataset.where(target_dataset.y>=0, drop = True)
        df = df.where(df.y < max_y, drop = True)
        df = df.where(df.x < min_x+max_x, drop=True)
    elif region == 'Lower Peninsula':
        max_x = (Nx / 2) * cell_size
        max_y = Ny * cell_size
        df = cutBoundaries(target_dataset, max_x, max_y, lowerHalf=True)
    else:
        # Previously fell through to an UnboundLocalError on `df`.
        raise ValueError(
            f"Unknown region {region!r}; expected 'Larsen' or 'Lower Peninsula'"
        )

    if print_:
        print("New target dimensions:", df.dims)
    return df

In [None]:
# Crop the target to the Larsen region and show its first SMB time step.
LarsenTarget = createLowerTarget(target_dataset, region='Larsen', Nx=64, Ny = 64, print_=True)

fig, ax = plt.subplots(
    figsize=(10, 10), subplot_kw={"projection": ccrs.SouthPolarStereo()}
)
LarsenTarget.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=True
)

ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)

In [None]:
# Crop the target to the Lower Peninsula region and show its first SMB time step.
lowerTarget = createLowerTarget(target_dataset, region='Lower Peninsula', Nx=64, Ny = 64, print_=True)

fig, ax = plt.subplots(
    figsize=(10, 10), subplot_kw={"projection": ccrs.SouthPolarStereo()}
)
lowerTarget.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=True
)

ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)

### Input:
Cut to grid around target domain defined above, grid of 25x48, will resize later to 32x32 or 16x16

In [None]:
def createLowerInput(GCMLike, region, Nx=48, Ny=25, print_=True):
    """Crop the GCM-like input dataset along x to a named region.

    Parameters
    ----------
    GCMLike : dataset with an ``x`` coordinate.
    region : {'Lower Peninsula', 'Larsen'}
        'Lower Peninsula' keeps ``-Nx/2 * 68000 <= x < Nx/2 * 68000``.
        'Larsen' keeps ``x < min(x) + Nx * 68000``.
    Nx : int
        Width of the window in input grid cells.
    Ny : int
        Accepted for interface compatibility; the y axis is not cropped here.
    print_ : bool
        Print the dimensions of the cropped dataset.

    Raises
    ------
    ValueError
        If ``region`` is not one of the supported names.
    """
    # 68 * 1000 is the x factor used by this function
    # (presumably the input grid spacing in metres — TODO confirm).
    cell_size_x = 68 * 1000

    if region == 'Lower Peninsula':
        max_x = (Nx/2) * cell_size_x
        df = GCMLike.where(GCMLike.x < max_x, drop=True)
        df = df.where(-max_x <= df.x, drop=True)
    elif region == 'Larsen':
        min_x = GCMLike.x.min().values
        df = GCMLike.where(GCMLike.x < min_x+ (Nx * cell_size_x), drop=True)
    else:
        # Previously fell through to an UnboundLocalError on `df`.
        raise ValueError(
            f"Unknown region {region!r}; expected 'Lower Peninsula' or 'Larsen'"
        )

    if print_:
        print("New dimensions:", df.dims)

    return df

In [None]:
# Crop the GCM-like input to the Larsen region and show its first SMB time step.
LarsenInput = createLowerInput(GCMLike,region = 'Larsen', Nx=48, Ny=25)

fig, ax = plt.subplots(
    figsize=(10, 10), subplot_kw={"projection": ccrs.SouthPolarStereo()}
)
LarsenInput.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=True
)

ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)

## Create input:

**Z**:
- (ignore for now) External forcing also given to RCM → total concentration of greenhouse gases and solar and ozone forcings
- Cosine and sine vectors that encode the position in the annual cycle (the month of the year for this monthly data)
- Spatial mean and standard deviation time series (one value per time step) for each $X_{i,j,m}$ (because normalising 2D variables removes temporal information)

**X**: 
SHAPE `[nbmonths, x, y, nb_vars]`

In [None]:
# Load the GCM-like RCM input data:
downloadFromGC = False
fileGCMLike = "MAR(ACCESS1-3)-stereographic_monthly_GCM_like.nc"
if downloadFromGC:
    downloadFileFromGC(pathGC, "", fileGCMLike)
    GCMLike = xr.open_dataset(fileGCMLike)
    os.remove(fileGCMLike)
else:
    GCMLike = xr.open_dataset(pathCluster+fileGCMLike)

region = 'Lower Peninsula'
size_input_domain = 32
# Remove the target variable; drop_vars replaces the deprecated Dataset.drop.
DATASET = GCMLike.drop_vars(["SMB"])

# Crop to the region and preview every remaining input variable.
DATASET = createLowerInput(DATASET,region =region, Nx=48, Ny=25, print_=False)
plotAllVar(DATASET)

In [None]:
def input_maker(
    GCMLike,
    size_input_domain=16,  # size of domain, format: 8,16,32, must be defined in advance
    stand=True,  # standardization
    seas=True,  # put a cos, sin vector to control the season, format : bool
    means=True,  # add the mean of the variables raw or stdz, format : bool
    stds=True,  # add the std of the variables raw or stdz, format : bool
    resize_input=True,  # resize input to size_input_domain
    region="Larsen",  # region of interest
    print_=True
):
    """Build the 2D input X and the 1D input Z from the GCM-like dataset.

    Parameters
    ----------
    GCMLike : dataset containing the input variables and the target ``SMB``.
    size_input_domain : int or None
        Side length the 2D input is resized to. ``None`` selects the larger of
        ``highestPowerof2`` applied to the two spatial sizes.
    stand : bool
        Pass the 2D input through ``standardize``.
    seas : bool
        Append a cosine and a sine channel with a 12-step period to Z.
    means, stds : bool
        Append the per-time-step spatial mean / standard deviation of every
        variable to Z. Each flag is honoured on its own (previously the
        statistics were only added when both flags were set).
    resize_input : bool
        Resize the 2D input to ``size_input_domain`` on both spatial axes.
    region : str
        "Whole Antarctica" uses the dataset as is; any other value is passed
        to ``createLowerInput``.
    print_ : bool
        Print a summary of the arrays that are built.

    Returns
    -------
    (INPUT_2D_ARRAY, INPUT_1D_ARRAY, VAR_LIST)
        The 2D input with variables on the last axis, the 1D input of shape
        ``(n_steps, 1, 1, n_features)`` and the list of variable names.
    """
    if region != "Whole Antarctica":
        DATASET = createLowerInput(GCMLike, region = region, Nx=48, Ny=25, print_=False)
    else:
        DATASET = GCMLike

    # ---- 2D input array -------------------------------------------------
    # Remove target variable from DATASET (drop_vars replaces deprecated drop):
    DATASET = DATASET.drop_vars(["SMB"])
    VAR_LIST = list(DATASET.data_vars)
    nb_vars = len(VAR_LIST)

    # Stack the variables and move them to the last axis.
    # (assumes every variable is 3-D with time first — TODO confirm)
    INPUT_2D_bf = np.transpose(
        np.asarray([DATASET[i].values for i in VAR_LIST]), [1, 2, 3, 0]
    )

    # If no size is given, derive it from the spatial shape.
    if size_input_domain is None:
        size_input_domain = np.max(
            [
                highestPowerof2(INPUT_2D_bf.shape[1]),
                highestPowerof2(INPUT_2D_bf.shape[2]),
            ]
        )

    if resize_input:
        # resize to size_input_domain
        INPUT_2D = resize(INPUT_2D_bf, size_input_domain, size_input_domain, print_)
    else:
        INPUT_2D = INPUT_2D_bf

    # Standardise only the array handed to the network; the statistics below
    # are computed on the non-standardised INPUT_2D.
    INPUT_2D_ARRAY = standardize(INPUT_2D) if stand else INPUT_2D

    if print_:
        print("Parameters:\n -------------------")
        print("Size of input domain:", size_input_domain)
        print("Region:", region)
        print("\nCreating 2D input X:\n -------------------")
        print(f"Number of variables: {nb_vars}")
        print(f"Variables: {VAR_LIST}")
        print(f"INPUT_2D shape: {INPUT_2D_ARRAY.shape}")
        print("\nCreating 1D input Z:\n -------------------")

    # ---- 1D input array: means, stds and season, each if requested -------
    n_steps = INPUT_2D.shape[0]
    INPUT_1D = []

    stat_shape = (n_steps, 1, 1, INPUT_2D.shape[3])
    if means:
        INPUT_1D.append(INPUT_2D.mean(axis=(1, 2)).reshape(stat_shape))
    if stds:
        INPUT_1D.append(INPUT_2D.std(axis=(1, 2)).reshape(stat_shape))
    if print_ and (means or stds):
        print(f"SpatialMean/std shape: {stat_shape}")

    if seas:
        months = 12
        cos_cycle = [cos(2 * i * pi / months) for i in range(months)]
        sin_cycle = [sin(2 * i * pi / months) for i in range(months)]
        # np.resize repeats the 12-step cycle until n_steps values are filled,
        # so a series whose length is not a multiple of 12 no longer fails.
        # (assumes the first time step is the first month of the cycle — TODO confirm)
        cosvect = np.resize(cos_cycle, n_steps).reshape(n_steps, 1, 1, 1)
        sinvect = np.resize(sin_cycle, n_steps).reshape(n_steps, 1, 1, 1)

        INPUT_1D.append(cosvect)
        INPUT_1D.append(sinvect)
        if print_:
            print(f"Cos/sin encoding shape: {cosvect.shape}")

    if INPUT_1D:
        INPUT_1D_ARRAY = np.concatenate(INPUT_1D, axis=3)
    else:
        # No 1D feature requested: return an array with zero features instead
        # of letting np.concatenate fail on an empty list.
        INPUT_1D_ARRAY = np.empty((n_steps, 1, 1, 0))
    if print_:
        print(f"INPUT_1D shape: {INPUT_1D_ARRAY.shape}")

    DATASET.close()
    return INPUT_2D_ARRAY, INPUT_1D_ARRAY, VAR_LIST

In [None]:
# Load the GCM-like RCM input data:
downloadFromGC = False
fileGCMLike = "MAR(ACCESS1-3)-stereographic_monthly_GCM_like.nc"
if not downloadFromGC:
    GCMLike = xr.open_dataset(pathCluster+fileGCMLike)
else:
    downloadFileFromGC(pathGC, "", fileGCMLike)
    GCMLike = xr.open_dataset(fileGCMLike)
    os.remove(fileGCMLike)

region = 'Lower Peninsula'
size_input_domain = 32

# Options shared by both input_maker calls below: seasonal cos/sin vector plus
# spatial means and standard deviations, for the selected region.
shared_options = dict(seas=True, means=True, stds=True, region=region)

# Model input: standardised and resized to size_input_domain.
i2D, i1D, VAR_LIST = input_maker(
    GCMLike,
    size_input_domain,
    stand=True,
    resize_input=True,
    **shared_options,
)
inputs_2D = [i2D]
inputs_1D = [i1D]

# Non-standardised, non-resized version, only used for plots.
i2D_ns, i1D_ns, var_list = input_maker(
    GCMLike,
    size_input_domain,
    stand=False,
    resize_input=False,
    print_=False,
    **shared_options,
)
inputs_2D_ns = [i2D_ns]
inputs_1D_ns = [i1D_ns]

## Create target:

In [None]:
def target_maker(
    target_dataset, 
    region="Larsen", # region of interest
    resize=True, # resize to target_size
    target_size=None # if none provided and resize true, set to min highest power of 2
):
    """Build the target array (SMB) for the model.

    Parameters
    ----------
    target_dataset : dataset exposing ``SMB`` and ``time``.
    region : str
        "Whole Antarctica" uses the dataset as is; any other value is passed
        to ``createLowerTarget``.
    resize : bool
        Resize the target to ``target_size`` on both spatial axes.
    target_size : int or None
        Side length used when resizing. ``None`` selects the smaller of
        ``highestPowerof2`` applied to the two spatial sizes.

    Returns
    -------
    (full_target, target_times)
        ``full_target`` has a trailing channel axis of length 1 and
        ``target_times`` is a one-element list holding the time values of
        ``target_dataset``.
    """
    if region != "Whole Antarctica":
        lowerTarget = createLowerTarget(target_dataset, region = region, Nx=64, Ny=64, print_=False)
        targetArray = lowerTarget.SMB.values
    else:
        targetArray = target_dataset.SMB.values

    # Add a trailing channel axis.
    targetArray = targetArray.reshape(
        targetArray.shape[0], targetArray.shape[1], targetArray.shape[2], 1
    )

    if resize:
        # The boolean parameter `resize` shadows the module-level resize()
        # helper inside this function, so calling `resize(...)` here would call
        # a bool. Fetch the helper from the module namespace instead.
        resize_array = globals()["resize"]
        if target_size is None:
            # resize to highest power of 2:
            target_size = np.min(
                [
                    highestPowerof2(targetArray.shape[1]),
                    highestPowerof2(targetArray.shape[2]),
                ]
            )
        target_SMB = resize_array(targetArray, target_size, target_size)
    else:
        target_SMB = targetArray

    full_target = np.concatenate([target_SMB], axis=0)
    target_times = [target_dataset.time.values]

    return full_target, target_times

In [None]:
# Build the target for the selected region; resize=False keeps the grid of the
# cropped dataset unchanged.
full_target, target_times = target_maker(
    target_dataset, region=region, resize=False
)
# Full target to model
target_time = np.concatenate(target_times, axis=0)
# NOTE(review): despite their names these hold the "x" and "y" coordinates of
# target_dataset (not of the regional crop) — confirm they are what is needed.
target_lon = target_dataset["x"]
target_lat = target_dataset["y"]

## U-Net:

In [None]:
# Scenario names, input variable names and predicted variable.
SCENARIO=['HIST' , 'RCP85']
# NOTE(review): this rebinds var_list, which was previously assigned from the
# non-standardised input_maker call — confirm the hard-coded order matches it.
var_list = ['RF', 'SP', 'LWD', 'SWD', 'TT', 'VVP', 'UUP']
var_pred = ['SMB']

In [None]:
# Full input to model: [2D array, 1D array], once standardised and once not.
full_input, full_input_ns = (
    [np.concatenate(parts_2d, axis=0), np.concatenate(parts_1d, axis=0)]
    for parts_2d, parts_1d in (
        (inputs_2D, inputs_1D),
        (inputs_2D_ns, inputs_1D_ns),
    )
)

print("Shapes of targets and inputs:\n---------------------------")
for shape_label, shaped_array in (
    ("Target:", full_target),
    ("Input 2D:", full_input[0]),
    ("Input 1D:", full_input[1]),
):
    print(shape_label, shaped_array.shape)

In [None]:
# Compare variable 1 of the first time step before and after standardisation.
randTime = 0
dt = pd.to_datetime([GCMLike.time.isel(time=randTime).values])
time = str(dt.date[0])

# Standardised sample goes in the right panel, unstandardised in the left one.
for panel_position, panel_inputs, panel_title in (
    (2, full_input, "Standardised"),
    (1, full_input_ns, "Unstandardised"),
):
    sample = panel_inputs[0][randTime, :, :, :]
    # Bring the sample to the 25 x 48 grid expected by plotTrain for the region.
    sample = resize(sample, 25, 48, print_=False)
    ax = plt.subplot(1, 2, panel_position, projection=ccrs.SouthPolarStereo())
    plotTrain(GCMLike, sample, 1, ax, time, VAR_LIST, region=region)
    ax.set_title(panel_title)

### Build U-Net:

In [None]:
## In Doury et al (2022) we chose to mask the over seas values as we are not interested in them, we did so by setting them always to 0.
# Ignore for now....
path_model = 'Results/'

#conv = 32
filter_size = 64
# Seed every random number generator in use (Python, NumPy, TensorFlow), not
# only the `random` module, so that runs are repeatable.
seed = 123
rn.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# We use here the unet_maker function defined in make_unet
unet=unet_maker(nb_inputs=len(full_input),
                size_target_domain=full_target.shape[1],
                shape_inputs=[A.shape[1:] for A in full_input],
                filters = filter_size, seed = seed)

LR = 0.005
unet.compile(optimizer=Adam(learning_rate=LR), loss="mse", metrics=[rmse_k])
#unet.to(device)

In [None]:
#unet.summary()

### Training, validation and test separation:

In [None]:
def train_test(full_target, full_input, full_input_ns, perc=0.8):
    """Randomly split targets and inputs into a training and a held-out part.

    ``int(perc * n)`` sample indices are drawn without replacement with
    ``rn.sample``; those samples form the training part and the remaining ones
    the held-out part. The same indices are applied to the standardised inputs,
    the non-standardised inputs and the target.

    Returns two dictionaries (training, held-out) with the keys ``"input"``,
    ``"input_ns"`` and ``"target"``.
    """
    n_total = full_target.shape[0]
    idx_train = rn.sample(range(n_total), int(perc * n_total))

    def _split(arrays):
        # Selected samples first, then everything that was not selected.
        selected = [arr[idx_train, :, :, :] for arr in arrays]
        remaining = [np.delete(arr, idx_train, axis=0) for arr in arrays]
        return selected, remaining

    input_train, input_test = _split(full_input)  # standardised inputs
    input_train_ns, input_test_ns = _split(full_input_ns)  # for plots

    target_train = full_target[idx_train, :, :]
    target_test = np.delete(full_target, idx_train, axis=0)

    train = {
        "input": input_train,
        "input_ns": input_train_ns,
        "target": target_train,
    }
    test = {
        "input": input_test,
        "input_ns": input_test_ns,
        "target": target_test,
    }
    return train, test

In [None]:
# First hold out a test set (10 %), then carve a validation set (20 %) out of
# what remains.
train_, test = train_test(full_target, full_input, full_input_ns, perc=0.9)
train, validation = train_test(train_['target'], train_['input'], train_['input_ns'], perc=0.8)

# Targets and standardised inputs per split
full_target_train, full_input_train = train['target'], train['input']
full_target_val, full_input_val = validation['target'], validation['input']
full_target_test, full_input_test = test['target'], test['input']

# Non-standardised inputs per split (used for plots)
full_input_train_ns = train['input_ns']
full_input_test_ns = test['input_ns']
full_input_val_ns = validation['input_ns']

print("Shapes of targets and inputs split into test and train:")
split_report = (
    ("test", full_target_test, full_input_test),
    ("train", full_target_train, full_input_train),
    ("val", full_target_val, full_input_val),
)
for report_position, (split_name, split_target, split_input) in enumerate(split_report):
    if report_position > 0:
        print("--")
    print(f"Target {split_name}:", split_target.shape)
    print(f"Input 2D {split_name}:", split_input[0].shape)

### Example of training data:

In [None]:
# Show N random training samples: input variable 4 (left) and target SMB (right).
f = plt.figure(figsize=(20, 30))

N = 5
M = 2
for i in range(N):
    sample2dtrain_, sample1dtrain_, sampletarget_, randTime = takeRandomSamples(
        full_input_train_ns, full_target_train, pred=False
    )
    # NOTE(review): randTime indexes the shuffled training subset, not the
    # original time axis, so the date in the title may not be the date of the
    # sample — confirm / carry the original indices through train_test.
    dt = pd.to_datetime([GCMLike.time.isel(time=randTime).values])
    time = str(dt.date[0])

    if region != "Whole Antarctica":
        sample2dtrain_ = resize(sample2dtrain_, 25, 48, print_=False)
    else:
        sample2dtrain_ = resize(sample2dtrain_, 25, 90, print_=False)

    # Keep the axis order (shape[0], shape[1]); the swapped order used before
    # would scramble the field for any non-square target.
    sampletarget_ = sampletarget_.reshape(
        sampletarget_.shape[0], sampletarget_.shape[1], 1
    )
    vmin = np.min(sampletarget_)
    vmax = np.max(sampletarget_)

    ax = plt.subplot(N, M, (i * M) + 1, projection=ccrs.SouthPolarStereo())
    plotTrain(GCMLike, sample2dtrain_, 4, ax, time, VAR_LIST, region=region)

    ax = plt.subplot(N, M, (i * M) + 2, projection=ccrs.SouthPolarStereo())
    plotTarget(target_dataset, sampletarget_, ax, vmin, vmax, region=region)

### Fit model:

In [None]:
batch_size, epochs = 32, 100
callbacks = [
    ReduceLROnPlateau(
        monitor="val_loss", factor=0.7, patience=4, verbose=1
    ),  ## callbacks to reduce the lr during the training
    EarlyStopping(
        monitor="val_loss", patience=15, verbose=1
    ),  ## Stops the fitting if val_loss does not improve after 15 iterations
    ModelCheckpoint(path_model, monitor="val_loss", verbose=1, save_best_only=True),
]  ## Save only best model

# target_maker already returns targets with a trailing channel axis (4-D), so
# adding another axis would hand 5-D targets to Keras. Only add the channel
# axis when it is really missing.
fit_target_train = (
    full_target_train if full_target_train.ndim == 4 else full_target_train[:, :, :, None]
)
fit_target_val = (
    full_target_val if full_target_val.ndim == 4 else full_target_val[:, :, :, None]
)

history = unet.fit(
    full_input_train,
    fit_target_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(full_input_val, fit_target_val),
    callbacks=callbacks,
)

In [None]:
# list all data in history
# Names of the series recorded by unet.fit; the plots below read 'loss',
# 'val_loss', 'rmse_k' and 'val_rmse_k' from this dictionary.
print(history.history.keys())

In [None]:
# Training and validation loss per epoch.
for loss_key in ('loss', 'val_loss'):
    plt.plot(history.history[loss_key])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

In [None]:
# summarize history for rmse (the rmse_k metric passed to unet.compile)
plt.plot(history.history['rmse_k'])
plt.plot(history.history['val_rmse_k'])
plt.title('model rmse')
plt.ylabel('rmse')  # this panel shows the RMSE metric, not the loss
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

## Predictions:

In [None]:
# Load best model:
# Reload the model from path_model, where the ModelCheckpoint callback wrote
# the best weights (save_best_only=True). rmse_k is a custom metric, so it has
# to be supplied through custom_objects.
unet = keras.models.load_model(path_model, 
                                custom_objects = {"rmse_k": rmse_k})

In [None]:
# plot predictions:
# Compute predictions for the training, validation and test inputs (in that order).
training_pred, val_pred, testing_pred = (
    unet.predict(split_inputs)
    for split_inputs in (full_input_train, full_input_val, full_input_test)
)

#### Training:

In [None]:
# Today's date as a string; it is used as a prefix of the saved figure names.
# NOTE(review): this import would normally sit in the import cell at the top.
from datetime import date
today = str(date.today())

In [None]:
# An xarray Dataset has no `.shape`; `.sizes` maps each dimension to its length.
GCMLike.sizes

In [None]:
# Show N random training samples: input variable 4, target SMB and predicted SMB.
f = plt.figure(figsize=(20, 30))

N = 5
M = 3
for i in range(N):
    (
        sample2dtrain_,
        sample1dtrain_,
        sampletarget_,
        samplepred_,
        randTime,
    ) = takeRandomSamples(
        full_input_train_ns, full_target_train, pred=True, full_prediction=training_pred
    )
    # NOTE(review): randTime indexes the shuffled training subset, not the
    # original time axis, so the date in the title may not be the date of the
    # sample — confirm.
    dt = pd.to_datetime([GCMLike.time.isel(time=randTime).values])
    time = str(dt.date[0])

    if region != "Whole Antarctica":
        sample2dtrain = resize(sample2dtrain_, 25, 48, print_=False)
    else:
        sample2dtrain = resize(sample2dtrain_, 25, 90, print_=False)

    sampletarget = sampletarget_.reshape(
        sampletarget_.shape[0], sampletarget_.shape[1], 1
    )
    samplepred = samplepred_.reshape(samplepred_.shape[0], samplepred_.shape[1], 1)

    # Shared colour range so target and prediction are directly comparable.
    vmin = np.min([sampletarget, samplepred])
    vmax = np.max([sampletarget, samplepred])

    ax = plt.subplot(N, M, (i * M) + 1, projection=ccrs.SouthPolarStereo())
    plotTrain(GCMLike, sample2dtrain, 4, ax, time,VAR_LIST, region)

    ax = plt.subplot(N, M, (i * M) + 2, projection=ccrs.SouthPolarStereo())
    plotTarget(target_dataset, sampletarget, ax, vmin, vmax, region)

    ax = plt.subplot(N, M, (i * M) + 3, projection=ccrs.SouthPolarStereo())
    plotPred(target_dataset, samplepred, ax, vmin, vmax, region)

# savefig does not create missing directories, so make sure the folder exists.
os.makedirs('Results_plots', exist_ok=True)
plt.savefig(f'Results_plots/{today}_train_numepoch_{epochs}_bs_{batch_size}.png')

#### Test:

In [None]:
# Show N random test samples: input variable 4, target SMB and predicted SMB.
f = plt.figure(figsize=(20, 30))

N = 5
M = 3
for i in range(N):
    (
        sample2dtrain_,
        sample1dtrain_,
        sampletarget_,
        samplepred_,
        randTime,
    ) = takeRandomSamples(
        full_input_test_ns, full_target_test, pred=True, full_prediction=testing_pred
    )
    # NOTE(review): randTime indexes the held-out test subset, not the original
    # time axis, so the date in the title may not be the date of the sample —
    # confirm.
    dt = pd.to_datetime([GCMLike.time.isel(time=randTime).values])
    time = str(dt.date[0])

    if region != "Whole Antarctica":
        sample2dtrain = resize(sample2dtrain_, 25, 48, print_=False)
    else:
        sample2dtrain = resize(sample2dtrain_, 25, 90, print_=False)

    sampletarget = sampletarget_.reshape(
        sampletarget_.shape[0], sampletarget_.shape[1], 1
    )
    samplepred = samplepred_.reshape(samplepred_.shape[0], samplepred_.shape[1], 1)

    # Shared colour range so target and prediction are directly comparable.
    vmin = np.min([sampletarget, samplepred])
    vmax = np.max([sampletarget, samplepred])

    ax = plt.subplot(N, M, (i * M) + 1, projection=ccrs.SouthPolarStereo())
    plotTrain(GCMLike, sample2dtrain, 4, ax, time,VAR_LIST, region)

    ax = plt.subplot(N, M, (i * M) + 2, projection=ccrs.SouthPolarStereo())
    plotTarget(target_dataset, sampletarget, ax, vmin, vmax, region)

    ax = plt.subplot(N, M, (i * M) + 3, projection=ccrs.SouthPolarStereo())
    plotPred(target_dataset, samplepred, ax, vmin, vmax, region)

# savefig does not create missing directories, so make sure the folder exists.
os.makedirs('Results_plots', exist_ok=True)
plt.savefig(f'Results_plots/{today}_test_numepoch_{epochs}_bs_{batch_size}.png')

In [None]:
# Overview of every variable of the full GCM-like dataset at the first time step.
plotAllVar(GCMLike, time = 0)

In [None]:
# Overview of every variable of the target dataset at the first time step.
plotAllVar(target_dataset, time = 0)

In [None]:
# `lowerInput` was never defined in this notebook; build the regional crop of
# the GCM-like input explicitly before plotting it.
# NOTE(review): assumes the crop for the current `region` is what was meant — confirm.
lowerInput = createLowerInput(GCMLike, region=region, Nx=48, Ny=25, print_=False)
plotAllVar(lowerInput, time = 0)

## Evaluation:

Numerical metrics for whole period:
- **RMSE:** The Root Mean Squared Error measures the standard deviation
of the prediction error (for SMB)
- **Temporal Anomalies Correlation**: Pearson correlation coefficient after removing seasonal cycle
- **Ratio of Variance:** indicates performance of emulator in reproducing local daily variability
- **Wasserstein distance:** measures distance between two probability density functions (P,Q)


Metrics to evaluate over specific periods:
- **Climatology:** compare climatology maps over present (2006-2025) and future (2081-2100) climate
- **Number of days over 30°C**
- **99th Percentile**
- **Climate Change**

In [None]:
# rmse (from model.py), evaluated on the held-out test set.
# y_true / y_pred were never defined; take them from the test split and its
# predictions, cast to a common dtype.
# NOTE(review): assumes rmse_k accepts NumPy arrays of matching shape — confirm
# against its definition in model.py.
y_true = np.asarray(full_target_test, dtype="float32")
y_pred = np.asarray(testing_pred, dtype="float32")
rmse = rmse_k(y_true, y_pred)

# temporal anomalies correlation

# variance ratio

# wasserstein distance