# Imports

In [1]:
# System imports
import os
import gc
import sys
import warnings
from itertools import chain
warnings.filterwarnings(action='ignore')

# Data handling
import numpy as np
import xarray as xr

# Plotting
import matplotlib.pyplot as plt
from IPython.display import Video, Image, HTML, display

np.set_printoptions(suppress = True, formatter = {'float_kind':'{:2f}'.format})

# Read in data

In [2]:
MITGCM_filename = "../../data/raw/cat_tave.nc"
ds = xr.open_dataset(MITGCM_filename)
ds

In [3]:
# Set other variables
subsample_rate = 200
start = 0
data_end_index = 7200
trainval_split_ratio = 0.7
valtest_split_ratio = 0.9
trainval_split = int(data_end_index * trainval_split_ratio)
valtest_split = int(data_end_index * valtest_split_ratio)
print(f"start: {start}")
print(f"trainval_split: {trainval_split}")
print(f"valtest_split: {valtest_split}")
print(f"test_split: {data_end_index}")
print(f"subsample_rate: {subsample_rate}")

start: 0
trainval_split: 5040
valtest_split: 6480
test_split: 7200
subsample_rate: 200


In [4]:
# Subsample the Dataset
trainval_range = range(start, trainval_split, subsample_rate)
valtest_range = range(trainval_split, valtest_split, subsample_rate)
test_range = range(valtest_split, data_end_index, subsample_rate)
sample_times = [(t, t + 1) for t in chain(trainval_range, valtest_range)]
ds_reduced = ds.isel(T=np.array(sample_times).flatten())
ds_reduced

In [7]:
# Free memory
ds = None
del ds
gc.collect()

469

In [9]:
da_T = ds_reduced["Ttave"].values
da_S = ds_reduced["Stave"].values
da_U_tmp = ds_reduced["uVeltave"].values
da_V_tmp = ds_reduced["vVeltave"].values
da_Kwx = ds_reduced["Kwx"].values
da_Kwy = ds_reduced["Kwy"].values
da_Kwz = ds_reduced["Kwz"].values
da_Eta = ds_reduced["ETAtave"].values
da_lat = ds_reduced["Y"].values
da_lon = ds_reduced["X"].values
da_depth = ds_reduced["Z"].values
# Calc U and V by averaging surrounding points, to get on same grid as other variables
da_U = (da_U_tmp[:, :, :, :-1] + da_U_tmp[:, :, :, 1:]) / 2.0
da_V = (da_V_tmp[:, :, :-1, :] + da_V_tmp[:, :, 1:, :]) / 2.0

In [12]:
density_file = "../../data/raw/DensityData.npy"
density = np.load(density_file, mmap_mode="r")
density.shape

(18000, 42, 78, 11)

In [15]:
clim_filename = "../../data/raw/ncra_cat_tave.nc"
ds_clim = xr.open_dataset(clim_filename)
da_clim_T = ds_clim["Ttave"].values

In [19]:
x_size = ds_reduced.dims["X"]
y_size = ds_reduced.dims["Y"]
z_size = ds_reduced.dims["Z"]
print(f"X: {x_size}\nY: {y_size}\nZ: {z_size}")

X: 11
Y: 78
Z: 42


In [21]:
# Region 1
x_lw_1 = 1
x_up_1 = (x_size - 2)
y_lw_1 = 1
y_up_1 = (y_size - 3)
z_lw_1 = 1
z_up_1 = (z_size - 1)

In [22]:
# Region 2
da_T2 = np.concatenate((da_T[:, :, :, -1:], da_T[:, :, :, :-1]), axis=3)
da_S2 = np.concatenate((da_S[:, :, :, -1:], da_S[:, :, :, :-1]), axis=3)
da_U2 = np.concatenate((da_U[:, :, :, -1:], da_U[:, :, :, :-1]), axis=3)
da_V2 = np.concatenate((da_V[:, :, :, -1:], da_V[:, :, :, :-1]), axis=3)
da_Kwx2 = np.concatenate((da_Kwx[:, :, :, -1:], da_Kwx[:, :, :, :-1]), axis=3)
da_Kwy2 = np.concatenate((da_Kwy[:, :, :, -1:], da_Kwy[:, :, :, :-1]), axis=3)
da_Kwz2 = np.concatenate((da_Kwz[:, :, :, -1:], da_Kwz[:, :, :, :-1]), axis=3)
da_Eta2 = np.concatenate((da_Eta[:, :, -1:], da_Eta[:, :, :-1]), axis=2)
da_lon2 = np.concatenate((da_lon[-1:], da_lon[:-1]), axis=0)
density2 = np.concatenate((density[:, :, :, -1:], density[:, :, :, :-1]), axis=3)

x_lw_2 = 1  # Note zero column is now what was at the -1 column!
x_up_2 = 2  # one higher than the point we want to forecast for, i.e. first point we're not forecasting
y_lw_2 = 1
y_up_2 = 15  # one higher than the point we want to forecast for, i.e. first point we're not forecasting
z_lw_2 = 1
z_up_2 = 31  # one higher than the point we want to forecast for, i.e. first point we're not forecasting

In [23]:
# Region 3
da_T3 = np.concatenate((da_T[:, :, :, 1:], da_T[:, :, :, :1]), axis=3)
da_S3 = np.concatenate((da_S[:, :, :, 1:], da_S[:, :, :, :1]), axis=3)
da_U3 = np.concatenate((da_U[:, :, :, 1:], da_U[:, :, :, :1]), axis=3)
da_V3 = np.concatenate((da_V[:, :, :, 1:], da_V[:, :, :, :1]), axis=3)
da_Kwx3 = np.concatenate((da_Kwx[:, :, :, 1:], da_Kwx[:, :, :, :1]), axis=3)
da_Kwy3 = np.concatenate((da_Kwy[:, :, :, 1:], da_Kwy[:, :, :, :1]), axis=3)
da_Kwz3 = np.concatenate((da_Kwz[:, :, :, 1:], da_Kwz[:, :, :, :1]), axis=3)
da_Eta3 = np.concatenate((da_Eta[:, :, 1:], da_Eta[:, :, :1]), axis=2)
da_lon3 = np.concatenate((da_lon[1:], da_lon[:1]), axis=0)
density3 = np.concatenate((density[:, :, :, 1:], density[:, :, :, :1]), axis=3)

x_lw_3 = x_size - 3
x_up_3 = (x_size - 1)
y_lw_3 = 1
y_up_3 = 15
z_lw_3 = 1
z_up_3 = 31

In [32]:
for t in range(len(trainval_range)):
    print(f"Index: {t} \t Time: {int(ds_reduced.T[t].values)} \t ")

Index: 0 	 Time: 1555286400
Index: 1 	 Time: 1555372800
Index: 2 	 Time: 1572566400
Index: 3 	 Time: 1572652800
Index: 4 	 Time: 1589846400
Index: 5 	 Time: 1589932800
Index: 6 	 Time: 1607126400
Index: 7 	 Time: 1607212800
Index: 8 	 Time: 1624406400
Index: 9 	 Time: 1624492800
Index: 10 	 Time: 1641686400
Index: 11 	 Time: 1641772800
Index: 12 	 Time: 1658966400
Index: 13 	 Time: 1659052800
Index: 14 	 Time: 1676246400
Index: 15 	 Time: 1676332800
Index: 16 	 Time: 1693526400
Index: 17 	 Time: 1693612800
Index: 18 	 Time: 1710806400
Index: 19 	 Time: 1710892800
Index: 20 	 Time: 1728086400
Index: 21 	 Time: 1728172800
Index: 22 	 Time: 1745366400
Index: 23 	 Time: 1745452800
Index: 24 	 Time: 1762646400
Index: 25 	 Time: 1762732800


In [41]:
for t in range(len(trainval_range), len(trainval_range) + len(valtest_range)):
    print(f"Index: {t} \t Time: {int(ds_reduced.T[t].values)}")

Index: 26 	 Time: 1779926400
Index: 27 	 Time: 1780012800
Index: 28 	 Time: 1797206400
Index: 29 	 Time: 1797292800
Index: 30 	 Time: 1814486400
Index: 31 	 Time: 1814572800
Index: 32 	 Time: 1831766400
Index: 33 	 Time: 1831852800


In [56]:
for t in range(len(trainval_range) + len(valtest_range), len(trainval_range) + len(valtest_range) + len(test_range)):
    print(f"Index: {t} \t Time: {int(ds_reduced.T[t].values)}")

Index: 34 	 Time: 1849046400
Index: 35 	 Time: 1849132800
Index: 36 	 Time: 1866326400
Index: 37 	 Time: 1866412800
