# Setup virtual environment and import modules

- Run next cell once, if you don't have all of the modules in the cell after installed
- First line converts the cell into bash commands
- Second line creates virtual environment
- Third line links existing modules to your newly created environment
- `pip install` any missing modules. Just tensorflow in my case
- Activate virtual environment
- Enable virtual environment into a Jupyter kernel

- Restart kernel and select

In [1]:
# %%sh
# python -m venv ~/venvs/myenv
# realpath /env/lib/python3.10/site-packages > ~/venvs/myenv/lib/python3.10/site-packages/base_venv.pth
# pip install tensorflow
# source ~/venvs/myenv/bin/activate
# python -m ipykernel install --user --name=myenv --display-name "myenv"

In [1]:
import gc

import s3fs
import xarray as xr
import numpy as np
import pandas as pd
import dask.array as da
import tensorflow as tf

from tensorflow.keras import Input, Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization, Conv2D, Conv3D, Flatten, Dense, MaxPool3D, GlobalAveragePooling3D, Dropout
from tensorflow.keras.callbacks import EarlyStopping

from dask.distributed import Client, LocalCluster
from dask.delayed import delayed
from sklearn.model_selection import train_test_split

2023-08-10 07:16:15.780155: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-10 07:16:17.031019: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Initialise Dask LocalCluster

In [2]:
cluster = LocalCluster(n_workers=4)
client = Client(cluster)

In [3]:
# Hao's dashboard link if in CSIRO environment
# https://hub.csiro.easi-eo.solutions/user/csiro-csiro-aad_tan196@csiro.au/proxy/8787/status
# Otherwise use provided dashboard link
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 32,Total memory: 124.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:36151,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 32
Started: Just now,Total memory: 124.00 GiB

0,1
Comm: tcp://127.0.0.1:41843,Total threads: 8
Dashboard: http://127.0.0.1:41151/status,Memory: 31.00 GiB
Nanny: tcp://127.0.0.1:36077,
Local directory: /tmp/dask-worker-space/worker-tjs8ouno,Local directory: /tmp/dask-worker-space/worker-tjs8ouno

0,1
Comm: tcp://127.0.0.1:41091,Total threads: 8
Dashboard: http://127.0.0.1:44275/status,Memory: 31.00 GiB
Nanny: tcp://127.0.0.1:33851,
Local directory: /tmp/dask-worker-space/worker-6cthcnx8,Local directory: /tmp/dask-worker-space/worker-6cthcnx8

0,1
Comm: tcp://127.0.0.1:43015,Total threads: 8
Dashboard: http://127.0.0.1:33063/status,Memory: 31.00 GiB
Nanny: tcp://127.0.0.1:43309,
Local directory: /tmp/dask-worker-space/worker-2wor0qbq,Local directory: /tmp/dask-worker-space/worker-2wor0qbq

0,1
Comm: tcp://127.0.0.1:34085,Total threads: 8
Dashboard: http://127.0.0.1:39823/status,Memory: 31.00 GiB
Nanny: tcp://127.0.0.1:45317,
Local directory: /tmp/dask-worker-space/worker-z4w_ei0s,Local directory: /tmp/dask-worker-space/worker-z4w_ei0s


# Load MUR Satellite Data (Skip if using ERA5)

In [4]:
# Bypass AWS tokens, keys etc.
s3 = s3fs.S3FileSystem(anon=True)

# Verify that we're in the right place
sst_files = s3.ls("mur-sst/zarr-v1/")
sst_files

['mur-sst/zarr-v1/',
 'mur-sst/zarr-v1/.zattrs',
 'mur-sst/zarr-v1/.zgroup',
 'mur-sst/zarr-v1/.zmetadata',
 'mur-sst/zarr-v1/analysed_sst',
 'mur-sst/zarr-v1/analysis_error',
 'mur-sst/zarr-v1/lat',
 'mur-sst/zarr-v1/lon',
 'mur-sst/zarr-v1/mask',
 'mur-sst/zarr-v1/sea_ice_fraction',
 'mur-sst/zarr-v1/time']

In [5]:
# Load
ds = xr.open_zarr(
        store=s3fs.S3Map(
            root=f"s3://{sst_files[0]}", s3=s3, check=False
        )
     )

# Subsetting dataset so JupyterHub can handle it

In [6]:
# Check that data is daily only and no missing days
assert pd.infer_freq(ds.time.values) == 'D'

# Good job NASA :) 

In [7]:
def dates_to_mur_indices(start, end):
    
    start_of_dataset = np.datetime64('2002-06-01')
    end_of_dataset = np.datetime64('2020-01-20')
    
    if (type(start) == str) and (type(end) == str):
        try:
            start = np.datetime64(start)
            end = np.datetime64(end)
        except ValueError as e:
            print(e)
            raise TypeError(f'Date(s) not in the format YYYY-MM-DD')
    else:
        raise TypeError(f'Please enter dates as np.datetime64 or strings in the format YYYY-MM-DD')
        
    assert (start >= start_of_dataset) and (start < end_of_dataset), f'{start} out of dataset range {start_of_dataset} - {end_of_dataset}'
    assert (end <= end_of_dataset) and (end > start_of_dataset), f'{end} out of dataset range {start_of_dataset} - {end_of_dataset}'
    assert start <= end, f'start date {start} after end date {end}'
    
    start_index = (start - start_of_dataset).astype(int)
    end_index = (end - start_of_dataset).astype(int) + 1
    
    return range(start_index, end_index)

In [8]:
# Taking 2 weeks' worth of data
ds_sample = ds.isel(time=dates_to_mur_indices('2009-01-01', '2009-01-15'))
# ds_sample = ds.isel(time=0)

In [9]:
ds_sample

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 36.21 GiB 123.53 MiB Shape (15, 17999, 36000) (5, 1799, 3600) Dask graph 440 chunks in 3 graph layers Data type float32 numpy.ndarray",36000  17999  15,

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 36.21 GiB 123.53 MiB Shape (15, 17999, 36000) (5, 1799, 3600) Dask graph 440 chunks in 3 graph layers Data type float32 numpy.ndarray",36000  17999  15,

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 36.21 GiB 123.53 MiB Shape (15, 17999, 36000) (5, 1799, 3600) Dask graph 440 chunks in 3 graph layers Data type float32 numpy.ndarray",36000  17999  15,

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 36.21 GiB 123.53 MiB Shape (15, 17999, 36000) (5, 1799, 3600) Dask graph 440 chunks in 3 graph layers Data type float32 numpy.ndarray",36000  17999  15,

Unnamed: 0,Array,Chunk
Bytes,36.21 GiB,123.53 MiB
Shape,"(15, 17999, 36000)","(5, 1799, 3600)"
Dask graph,440 chunks in 3 graph layers,440 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
# Taking a smaller sample area
# ds_sample = ds_sample.sel(lat=slice(-5, 35), lon=slice(45,90))
ds_sample = ds.sel(lat=slice(-5, 35), lon=slice(45,90))

In [11]:
# Downsamplce by factor of 10
ds_sample = ds_sample.isel(lat=slice(0, None, 10), lon=slice(0, None, 10))

In [12]:
# # It seems that there are NaN values when I inspected my sample in memory later on - I'm not sure if I did this cell properly

# # Get dates with NaN values
# all_nan_dates = np.isnan(ds_sample["analysed_sst"]).all(dim=["lon", "lat"]).compute()

# # Were there any?
# if not all_nan_dates.any():
#     print('No NaN values')
# else:
#     print('NaN values exist')

In [13]:
# Rename SST variable so that I don't need to make changes to Jiarui and team's functions below
ds_sample = ds_sample.rename({'analysed_sst': 'sst'})

In [14]:
ds_sample

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


# Preprocess and split data (thanks for the functions Jiarui and team)

In [15]:
def preprocess_day_data(day_data):
    day_data = da.squeeze(day_data)
    mean_val = da.nanmean(day_data).compute()  # compute here to get scalar value
    return day_data - mean_val

#TODO: Utilise multiprocessing to parallelise if possible?
def preprocess_data(zarr_ds, chunk_size=200):
    total_len = zarr_ds['sst'].shape[0]
    chunk_shape = (chunk_size,) + zarr_ds['sst'].shape[1:]  # Adjusted chunking
    chunks = []

    for start_idx in range(0, total_len, chunk_size):
        end_idx = min(start_idx + chunk_size, total_len)
        
        # Directly slice the dask array without wrapping it with da.from_array again
        chunk = zarr_ds['sst'][start_idx:end_idx]
        
        processed_chunk = chunk.map_blocks(preprocess_day_data)
        
        # Use da.where to replace NaNs with 0.0
        processed_chunk = da.where(da.isnan(processed_chunk), 0.0, processed_chunk)
        
        chunks.append(processed_chunk)

    return da.concatenate(chunks, axis=0)


def prepare_data_from_processed(processed_data, window_size=5): 
    length = processed_data.shape[0]
    X, y = [], []

    for i in range(length - window_size):
        X.append(processed_data[i:i+window_size])
        y.append(processed_data[i+window_size])

    X, y = da.array(X), da.array(y)
    return X, y


def time_series_split(X, y, train_ratio=0.7, val_ratio=0.2):
    total_length = X.shape[0]
    
    # Compute end indices for each split
    train_end = int(total_length * train_ratio)
    val_end = int(total_length * (train_ratio + val_ratio))
    
    X_train = X[:train_end]
    y_train = y[:train_end]
    
    X_val = X[train_end:val_end]
    y_val = y[train_end:val_end]
    
    X_test = X[val_end:]
    y_test = y[val_end:]
    
    return X_train, y_train, X_val, y_val, X_test, y_test


In [16]:
%%time

processed_data = preprocess_data(ds_sample)
processed_data

CPU times: user 14.7 s, sys: 785 ms, total: 15.5 s
Wall time: 3min 8s


Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 170 graph layers,7734 chunks in 170 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 170 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 170 graph layers,7734 chunks in 170 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [17]:
# Split data into training and test
X, y = prepare_data_from_processed(processed_data)
X_train, y_train, X_val, y_val, X_test, y_test = time_series_split(X, y)

In [18]:
training_dims = np.shape(X_train)
training_dims

(4506, 5, 401, 451)

In [19]:
# ValueError: Dimensions must be equal, but are 451 and 401 for '{{node mean_squared_error/SquaredDifference}} = SquaredDifference[T=DT_FLOAT](sequential_1/batch_normalization/FusedBatchNormV3, IteratorGetNext:1)' with input shapes: [?,5,401,451,32], [?,401,451].

# Some more processing - standardising SST values to fit into model

In [21]:
np.max(ds_sample.sst.values)

nan

In [23]:
ds_sample

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.34 GiB 0.93 MiB Shape (6443, 401, 451) (5, 180, 271) Dask graph 7734 chunks in 4 graph layers Data type float32 numpy.ndarray",451  401  6443,

Unnamed: 0,Array,Chunk
Bytes,4.34 GiB,0.93 MiB
Shape,"(6443, 401, 451)","(5, 180, 271)"
Dask graph,7734 chunks in 4 graph layers,7734 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [22]:
ds_sample.sst.values

array([[[300.401  , 300.52   , 300.348  , ..., 302.787  , 302.568  ,
         302.527  ],
        [300.483  , 300.501  , 300.45398, ..., 303.034  , 302.744  ,
         302.716  ],
        [300.46   , 300.43   , 300.481  , ..., 302.288  , 302.483  ,
         302.69998],
        ...,
        [      nan,       nan,       nan, ...,       nan,       nan,
               nan],
        [      nan,       nan,       nan, ...,       nan,       nan,
               nan],
        [      nan,       nan,       nan, ...,       nan,       nan,
               nan]],

       [[300.399  , 300.54   , 300.373  , ..., 302.867  , 302.432  ,
         301.947  ],
        [300.52798, 300.522  , 300.503  , ..., 303.044  , 302.682  ,
         302.612  ],
        [300.469  , 300.43298, 300.5    , ..., 302.175  , 302.304  ,
         302.62   ],
        ...,
        [      nan,       nan,       nan, ...,       nan,       nan,
               nan],
        [      nan,       nan,       nan, ...,       nan,       nan,
   

# Build model

In [26]:
def create_simple_model(input_shape):
#     model = Sequential()

#     # 3D-CNN code goes here 

#     model.add(Conv3D(filters=32, kernel_size=(3, 3, 3),
#                          input_shape=input_shape,
#                          padding='same'
#                     )
#              )
#     model.add(BatchNormalization())
    
    # Let's start from scratch
    width = input_shape[0]
    height = input_shape[1]
    depth = input_shape[2]
    
    inputs = Input((width, height, depth, 1))

    x = Conv3D(filters=32, kernel_size=(3,3,3), activation="relu")(inputs)
    x = MaxPool3D(pool_size=2)(x)
    x = BatchNormalization()(x)
    
    print(x)

#     x = Conv3D(filters=64, kernel_size=(3,3,3), activation="relu")(x)
#     x = MaxPool3D(pool_size=2)(x)
#     x = BatchNormalization()(x)

#     x = Conv3D(filters=128, kernel_size=3, activation="relu")(x)
#     x = MaxPool3D(pool_size=2)(x)
#     x = BatchNormalization()(x)

#     x = Conv3D(filters=256, kernel_size=3, activation="relu")(x)
#     x = MaxPool3D(pool_size=2)(x)
#     x = BatchNormalization()(x)

    x = GlobalAveragePooling3D()(x)
    x = Dense(units=512, activation="relu")(x)
    x = Dropout(0.3)(x)

    outputs = Dense(units=1, activation="sigmoid")(x)

    # Define the model.
    model = Model(inputs, outputs, name="3dcnn")

    return model


In [27]:
%%time

# how to determine first dimension/argument?
model = create_simple_model((training_dims[1], training_dims[2], training_dims[3], 1))
# model.build((5, 401, 451, 1))
model.summary()

KerasTensor(type_spec=TensorSpec(shape=(None, 1, 199, 224, 32), dtype=tf.float32, name=None), name='batch_normalization_1/FusedBatchNormV3:0', description="created by layer 'batch_normalization_1'")
Model: "3dcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 5, 401, 451, 1)   0         
                             ]                                   
                                                                 
 conv3d_2 (Conv3D)           (None, 3, 399, 449, 32)   896       
                                                                 
 max_pooling3d_1 (MaxPoolin  (None, 1, 199, 224, 32)   0         
 g3D)                                                            
                                                                 
 batch_normalization_1 (Bat  (None, 1, 199, 224, 32)   128       
 chNormalization)                                           

# Evalute model

In [28]:
%%time

# Try other optimizers (ensemble of Adagrad and others)
model.compile(optimizer='adam', loss='mse', metrics=['mse'])

early_stop = EarlyStopping(patience=5, restore_best_weights=True)

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(32)

val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.batch(32)

history = model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=[early_stop])



Epoch 1/20


InvalidArgumentError: Graph execution error:

Detected at node 'gradient_tape/mean_squared_error/BroadcastGradientArgs' defined at (most recent call last):
    File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/env/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/env/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
      app.start()
    File "/env/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 725, in start
      self.io_loop.start()
    File "/env/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.10/asyncio/base_events.py", line 600, in run_forever
      self._run_once()
    File "/usr/lib/python3.10/asyncio/base_events.py", line 1896, in _run_once
      handle._run()
    File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/env/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
      await self.process_one()
    File "/env/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 502, in process_one
      await dispatch(*args)
    File "/env/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
      await result
    File "/env/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/env/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/env/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
      result = self._run_cell(
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
      result = runner(coro)
    File "/env/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_118/879337435.py", line 1, in <module>
      get_ipython().run_cell_magic('time', '', "\n# Try other optimizers (ensemble of Adagrad and others)\nmodel.compile(optimizer='adam', loss='mse', metrics=['mse'])\n\nearly_stop = EarlyStopping(patience=5, restore_best_weights=True)\n\ntrain_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(32)\n\nval_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))\nval_dataset = val_dataset.batch(32)\n\nhistory = model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=[early_stop])\n")
    File "/env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2430, in run_cell_magic
      result = fn(*args, **kwargs)
    File "/env/lib/python3.10/site-packages/IPython/core/magics/execution.py", line 1319, in time
      exec(code, glob, local_ns)
    File "<timed exec>", line 12, in <module>
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/engine/training.py", line 1742, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/engine/training.py", line 1338, in train_function
      return step_function(self, iterator)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/engine/training.py", line 1322, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/engine/training.py", line 1303, in run_step
      outputs = model.train_step(data)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/engine/training.py", line 1084, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py", line 543, in minimize
      grads_and_vars = self.compute_gradients(loss, var_list, tape)
    File "/home/jovyan/venvs/SST/lib/python3.10/site-packages/keras/src/optimizers/optimizer.py", line 276, in compute_gradients
      grads = tape.gradient(loss, var_list)
Node: 'gradient_tape/mean_squared_error/BroadcastGradientArgs'
Incompatible shapes: [32,1] vs. [32,401,451]
	 [[{{node gradient_tape/mean_squared_error/BroadcastGradientArgs}}]] [Op:__inference_train_function_1417]