In [20]:
import numpy as np
import xarray as xr
import pandas as pd
import copy
from datetime import datetime, timedelta
from keras.utils import to_categorical
# import visualkeras
import tensorflow as tf
from model_builders import *
from sklearn.metrics import balanced_accuracy_score
import optuna
from optuna.samplers import TPESampler
import keras
from keras.callbacks import ModelCheckpoint
from sklearn.utils.class_weight import compute_class_weight
import os
import warnings
import joblib
warnings.filterwarnings("ignore", category=UserWarning)

## GLOBAL SEED ##    
np.random.seed(42)
tf.random.set_seed(42)

In [2]:
def create_tf_datasets(input_data, output_data):
    # Convert xarray dataset to numpy array for TensorFlow Dataset
    input_images = input_data.transpose('time', 'lat', 'lon','channel').values
    output_one_hot = output_data.values

    # Create TensorFlow Datasets
    input_dataset = tf.data.Dataset.from_tensor_slices(input_images)
    output_dataset = tf.data.Dataset.from_tensor_slices(output_one_hot)

    # Combine input and output datasets into a joint dataset
    joint_dataset = tf.data.Dataset.zip((input_dataset, output_dataset))

    return joint_dataset

def create_datasets(input_anoms, var_name, df_shifts, week_out):
# Assuming you have the xarray.Dataset 'input_data' and the pandas.Series 'output_data'
    input_data = copy.deepcopy(input_anoms[var_name])

    array_temp = input_data.data
    array_temp[np.isfinite(array_temp)==False]=0
    input_data.data = array_temp

    input_data = (input_data - input_data.mean('time')) / (input_data.std('time'))
    # Reshape the data to add a new dimension
    values_reshaped = input_data.values.reshape(input_data.shape[0], input_data.shape[1], input_data.shape[2], 1)

    # Create a new xarray.DataArray with the reshaped data and the original coordinates
    input_data = xr.DataArray(values_reshaped, coords=input_data.coords, dims=('time', 'lat', 'lon', 'channel'))
    output_data = copy.deepcopy(df_shifts[f'week{week_out}']).dropna()

    # Step 1: Create a common date index that includes all dates in both the input and output data
    common_dates = np.intersect1d(input_data['time'].values, output_data.index)

    # Step 2: Reindex the input xarray dataset and the output DataFrame to the common date index
    input_data = input_data.sel(time=common_dates)
    output_data = output_data.loc[common_dates]

    # Step 3: One-hot encode the output DataFrame using to_categorical
    num_classes = len(output_data.unique())  # Number of classes (number of weeks in this case)
    output_data_encoded = to_categorical(output_data, num_classes=num_classes)
    output_data_encoded = pd.DataFrame(output_data_encoded,index=output_data.index)

    # Step 4: Create masks for training, validation, and testing periods
    train_mask = (output_data.index >= '1980-01-01') & (output_data.index <= '2010-12-31')
    val_mask = (output_data.index >= '2011-01-01') & (output_data.index <= '2015-12-31')
    test_mask = (output_data.index >= '2016-01-01') & (output_data.index <= '2020-12-31')

    # Step 5: Split the input xarray dataset and the output DataFrame into subsets
    input_train = input_data.sel(time=train_mask)
    input_val = input_data.sel(time=val_mask)
    input_test = input_data.sel(time=test_mask)

    output_train = output_data_encoded.loc[train_mask]
    output_val = output_data_encoded.loc[val_mask]
    output_test = output_data_encoded.loc[test_mask]

    train_joint_dataset = create_tf_datasets(input_train, output_train)
    val_joint_dataset = create_tf_datasets(input_val, output_val)
    test_joint_dataset = create_tf_datasets(input_test, output_test)

    buffer_size = train_joint_dataset.cardinality()
    train_joint_dataset = train_joint_dataset.shuffle(buffer_size)
    return train_joint_dataset, val_joint_dataset, test_joint_dataset

def get_output_from_dataset(dataset):
    output_array = []
    for input_data, output_data in dataset.as_numpy_iterator():
        output_array.append(output_data)

    # Convert the list of NumPy arrays into a single NumPy array
    output_array = np.array(output_array)
    return output_array

def balanced_accuracy(y_true, y_pred):
    y_true = tf.argmax(y_true, axis=1)
    y_pred = tf.argmax(y_pred, axis=1)
    return tf.py_function(balanced_accuracy_score, (y_true, y_pred), tf.float32)

def logging_callback(study, frozen_trial):
    previous_best_value = study.user_attrs.get("previous_best_value", None)
    if previous_best_value != study.best_value:
        study.set_user_attr("previous_best_value", study.best_value)
        print(
            "Trial {} finished with best value: {} and parameters: {}. ".format(
            frozen_trial.number,
            frozen_trial.value,
            frozen_trial.params,
            )
        )

In [3]:
# def check_dataset_integrity(dataset, name):
#     input_shapes = set()
#     output_shapes = set()
#     num_samples = 0

#     for input_data, output_data in dataset:
#         num_samples += 1
#         input_shapes.add(tuple(input_data.shape.as_list()))
#         output_shapes.add(tuple(output_data.shape.as_list()))

#     print(f"Dataset: {name}")
#     print(f"Input Shapes: {input_shapes}")
#     print(f"Output Shapes: {output_shapes}")
#     print(f"Number of Samples: {num_samples}")


# # Check shapes and number of samples in the training dataset
# check_dataset_integrity(train_joint_dataset, "Training Dataset")

# # Check shapes and number of samples in the validation dataset
# check_dataset_integrity(val_joint_dataset, "Validation Dataset")

# # Check shapes and number of samples in the test dataset
# check_dataset_integrity(test_joint_dataset, "Test Dataset")


In [28]:
class Objective(object):
    def __init__(self, train_joint_dataset, val_joint_dataset, test_joint_dataset,
                 path_models, variable, week):
        self.train_joint_dataset = train_joint_dataset
        self.val_joint_dataset = val_joint_dataset
        self.test_joint_dataset = test_joint_dataset
        self.path_models = path_models
        self.variable = variable
        self.week = week
 
    def __call__(self, trial):    
        keras.backend.clear_session()
        
        model_base = trial.suggest_categorical('model_base',['vanilla','resnet50','resnet101',\
                                                             'inception','xception','densenet'])
        ks = trial.suggest_categorical('ks',[3,5,7,9,11])
        ps = trial.suggest_categorical('ps',[2,4,6,8])
        type_pooling = trial.suggest_categorical('type_pooling',[None, 'avg','max'])
        stc = trial.suggest_categorical('stc',[1,2,3,4])
        stp = trial.suggest_categorical('stp',[1,2,3,4])
        do = trial.suggest_categorical('do',[0.3,0.4,0.5])
        md = trial.suggest_categorical('md',[2,4,8,16])
        nfilters = trial.suggest_categorical('nfilters',[4,8,16,32])
        activation = trial.suggest_categorical('activation',['LeakyReLU','ReLU'])
        weighted_loss = trial.suggest_categorical('weighted_loss',[True,False])
        
        dict_params = {'model_base':model_base,
                       'ks':ks,
                       'ps':ps,
                       'type_pooling':type_pooling,
                       'stc':stc,
                       'stp':stp,
                       'do':do,
                       'md':md,
                       'nfilters':nfilters,
                       'activation':activation,
                       'weighted_loss':weighted_loss}
        print(dict_params)                                      
        # instantiate and compile model
        if dict_params['model_base']=='vanilla':
            model = build_vanilla_cnn(dict_params['ks'],
                                      dict_params['ps'],
                                      dict_params['type_pooling'],
                                      dict_params['stc'],
                                      dict_params['stp'],
                                      dict_params['do'],
                                      dict_params['md'],
                                      dict_params['nfilters'],
                                      dict_params['activation'])
        elif dict_params['model_base']=='resnet50':
            model = build_resnet50_model(dict_params['type_pooling'],
                                         dict_params['do'],
                                         dict_params['md'],
                                         dict_params['activation'])
        elif dict_params['model_base']=='resnet101':
            model = build_resnet101_model(dict_params['type_pooling'],
                                         dict_params['do'],
                                         dict_params['md'],
                                         dict_params['activation'])
        elif dict_params['model_base']=='inception':
            model = build_inception_model(dict_params['type_pooling'],
                                         dict_params['do'],
                                         dict_params['md'],
                                         dict_params['activation'])
        elif dict_params['model_base']=='xception':
            model = build_xception_model(dict_params['type_pooling'],
                                         dict_params['do'],
                                         dict_params['md'],
                                         dict_params['activation'])
        elif dict_params['model_base']=='densenet':
            model = build_densenet_model(dict_params['type_pooling'],
                                         dict_params['do'],
                                         dict_params['md'],
                                         dict_params['activation'])
            
        model.compile(loss=keras.losses.categorical_crossentropy, 
                optimizer=keras.optimizers.Adam(lr=0.0001),metrics=[balanced_accuracy,'accuracy'])
        
        epochs = 100
        early_stopping_patience = 5

        # Create the EarlyStopping callback
        early_stopping_callback = tf.keras.callbacks.EarlyStopping(
            monitor='val_balanced_accuracy',  # Metric to monitor
            patience=early_stopping_patience,  # Number of epochs with no improvement
            restore_best_weights=True  # Restore the weights of the best model
        )

        # Train the model with early stopping
        try:
            os.mkdir(f'{self.path_models}{self.variable}')
        except: pass
    
        filepath = f'{self.path_models}{self.variable}/model_{self.week}_v9.h5'
        checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, 
                                     mode='auto',save_weights_only=False)
        
        if dict_params['weighted_loss']==True:
            
            y_train = get_output_from_dataset(self.train_joint_dataset)
            y_train_integers = np.argmax(y_train, axis=1)
            class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y_train_integers),
                                                 y = y_train_integers)
            d_class_weights = dict(enumerate(class_weights))
            
            history = model.fit(
                self.train_joint_dataset.batch(32),
                validation_data=self.val_joint_dataset.batch(32),
                class_weight = d_class_weights,
                epochs=epochs,
                callbacks=[checkpoint,early_stopping_callback]
            )
        else:
            history = model.fit(
                self.train_joint_dataset.batch(32),
                validation_data=self.val_joint_dataset.batch(32),
                epochs=epochs,
                callbacks=[checkpoint,early_stopping_callback],
                verbose=0
            )
        
        test_loss, test_balanced_accuracy, test_accuracy = model.evaluate(self.test_joint_dataset.batch(32))
        val_balanced_accuracy = np.max(history.history['val_balanced_accuracy'])
        val_accuracy = np.max(history.history['val_accuracy'])
        
        trial.set_user_attr('test_balanced_accuracy',test_balanced_accuracy)
        trial.set_user_attr('test_accuracy',test_accuracy)
        trial.set_user_attr('val_balanced_accuracy',val_balanced_accuracy)
        trial.set_user_attr('val_accuracy',val_accuracy)
        
        return val_balanced_accuracy

In [29]:
name_var = 'Z500_ERA5'
path_weekly_anoms = '/glade/scratch/jhayron/Data4Predictability/WeeklyAnoms/'
input_anoms = xr.open_dataset(f'{path_weekly_anoms}{name_var}.nc')
var_name = list(input_anoms.data_vars.keys())[0]
week_out = 3
week_out_str = f'week{week_out}'

wr_series = pd.read_csv('/glade/work/jhayron/Data4Predictability/WR_Series.csv',\
                index_col=0,names=['week0'],skiprows=1,parse_dates=True)
for wk in range(2,10):
    series_temp = copy.deepcopy(wr_series["week0"])
    series_temp.index = series_temp.index - timedelta(weeks = wk-1)
    series_temp.name = f'week{wk-1}'
    if wk==2:
        df_shifts = pd.concat([pd.DataFrame(wr_series["week0"]),pd.DataFrame(series_temp)],axis=1)  
    else:
        df_shifts = pd.concat([df_shifts,pd.DataFrame(series_temp)],axis=1)

In [30]:
train_joint_dataset, val_joint_dataset, test_joint_dataset = \
    create_datasets(input_anoms, var_name, df_shifts, week_out)

In [31]:
path_models = '/glade/work/jhayron/Data4Predictability/models/CNN/v0/'

In [32]:
optimizer_direction = 'maximize'
number_of_random_points = 30  # random searches to start opt process
maximum_time = 0.12*60*60  # seconds
objective = Objective(train_joint_dataset,val_joint_dataset,test_joint_dataset,
                      path_models,name_var,week_out_str)
    
results_directory = f'/glade/work/jhayron/Data4Predictability/models/CNN/results_optuna/{week_out_str}/'

try:
    os.mkdir(results_directory)
except:
    pass

study_name = f'study_{name_var}_{week_out_str}_v0'
storage_name = f'sqlite:///{study_name}.db'

optuna.logging.set_verbosity(optuna.logging.WARNING)
study = optuna.create_study(direction=optimizer_direction,
        sampler=TPESampler(n_startup_trials=number_of_random_points),
        study_name=study_name, storage=storage_name,load_if_exists=True)

study.optimize(objective, timeout=maximum_time, gc_after_trial=True,callbacks=[logging_callback],)

# save results
df_results = study.trials_dataframe()
df_results.to_pickle(results_directory + f'df_optuna_results_{name_var}_v0.pkl')
df_results.to_csv(results_directory + f'df_optuna_results_{name_var}_v0.csv')
#save study
joblib.dump(study, results_directory + f'optuna_study_{name_var}_v0.pkl')

{'model_base': 'densenet', 'ks': 11, 'ps': 2, 'type_pooling': None, 'stc': 4, 'stp': 2, 'do': 0.3, 'md': 2, 'nfilters': 4, 'activation': 'ReLU', 'weighted_loss': True}
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Trial 0 finished with best value: 0.2604166865348816 and parameters: {'model_base': 'densenet', 'ks': 11, 'ps': 2, 'type_pooling': None, 'stc': 4, 'stp': 2, 'do': 0.3, 'md': 2, 'nfilters': 4, 'activation': 'ReLU', 'weighted_loss': True}. 
{'model_base': 'xception', 'ks': 5, 'ps': 4, 'type_pooling': 'avg', 'stc': 4, 'stp': 1, 'do': 0.5, 'md': 4, 'nfilters': 4, 'activation': 'ReLU', 'weighted_loss': True}
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100


['/glade/work/jhayron/Data4Predictability/models/CNN/results_optuna/week3/optuna_study_Z500_ERA5_v0.pkl']

In [33]:
df_results

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_activation,params_do,params_ks,params_md,params_model_base,...,params_ps,params_stc,params_stp,params_type_pooling,params_weighted_loss,user_attrs_test_accuracy,user_attrs_test_balanced_accuracy,user_attrs_val_accuracy,user_attrs_val_balanced_accuracy,state
0,0,0.260417,2023-07-30 20:59:35.542110,2023-07-30 21:04:57.669291,0 days 00:05:22.127181,ReLU,0.3,11,2,densenet,...,2,4,2,,True,0.186047,0.25,0.272727,0.260417,COMPLETE
1,1,0.21875,2023-07-30 21:04:58.743773,2023-07-30 21:09:58.680051,0 days 00:04:59.936278,ReLU,0.5,5,4,xception,...,4,4,1,avg,True,0.293023,0.25,0.212121,0.21875,COMPLETE


In [27]:
study.trials_dataframe()

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_activation,params_do,params_ks,params_md,params_model_base,params_nfilters,params_ps,params_stc,params_stp,params_type_pooling,params_weighted_loss,state
0,0,0.29059,2023-07-30 18:55:52.954667,2023-07-30 18:59:48.385126,0 days 00:03:55.430459,LeakyReLU,0.5,7,4,resnet50,4,6,4,3,max,False,COMPLETE
1,1,0.260417,2023-07-30 18:59:48.788530,2023-07-30 19:01:40.926363,0 days 00:01:52.137833,ReLU,0.3,5,16,resnet50,8,8,2,4,max,True,COMPLETE
2,2,0.21875,2023-07-30 19:01:41.432905,2023-07-30 19:06:35.628478,0 days 00:04:54.195573,ReLU,0.4,9,2,xception,16,6,4,2,avg,True,COMPLETE
3,3,0.290136,2023-07-30 19:06:35.987198,2023-07-30 19:12:56.481747,0 days 00:06:20.494549,LeakyReLU,0.4,11,2,resnet101,4,4,1,4,avg,True,COMPLETE
4,4,0.306908,2023-07-30 19:12:56.992562,2023-07-30 19:18:16.798079,0 days 00:05:19.805517,ReLU,0.3,7,16,inception,4,4,3,4,max,True,COMPLETE
5,5,0.272874,2023-07-30 19:18:17.562916,2023-07-30 19:25:21.812051,0 days 00:07:04.249135,ReLU,0.4,3,2,densenet,4,4,3,1,max,True,COMPLETE
6,6,0.268188,2023-07-30 19:25:22.620584,2023-07-30 19:27:00.393401,0 days 00:01:37.772817,LeakyReLU,0.4,5,4,inception,32,6,3,3,max,False,COMPLETE
7,7,0.303073,2023-07-30 19:27:00.808516,2023-07-30 19:30:19.996218,0 days 00:03:19.187702,LeakyReLU,0.4,5,8,inception,4,8,2,4,max,False,COMPLETE
8,8,0.260417,2023-07-30 19:30:20.417347,2023-07-30 19:35:11.910293,0 days 00:04:51.492946,LeakyReLU,0.5,9,4,xception,16,8,1,3,avg,True,COMPLETE
9,9,0.260417,2023-07-30 19:35:12.261241,2023-07-30 19:38:00.927037,0 days 00:02:48.665796,ReLU,0.5,9,2,resnet101,32,8,3,4,avg,True,COMPLETE


In [2]:
# Parameters to test

#### 1. type_pooling
#### 2. 

In [None]:
#### Vanilla types

In [3]:
variables = ['z500','olr', 'sst', 'u10', 'sm_region', 'st_region']
name_var = ['z500','olr', 'sst', 'u10', 'sm', 'st']
units = ['m2/s2','J/m2','K','m/s','m3/m3','K']

In [24]:
for var_short, variable,unit in zip(name_var,variables,units):
    # for week in ['week1','week2','week3','week4','week5','week6']:
    for week in ['week3']:
        results_directory = f'/glade/scratch/jhayron/Weather_Regimes/models/CNN/results_optuna/{week}/'
        # study_optuna = joblib.load(results_directory + f'optuna_study_{var_short}_v4_acc.pkl')
        aaaa

NameError: name 'aaaa' is not defined

In [42]:
df_results = pd.read_csv(results_directory + f'df_optuna_results_olr_v4_acc.csv',index_col=1)

print(df_results[df_results.value==df_results.value.min()]\
    [['params_ks','params_nfilters','params_ps','params_stc','params_stp']])

        params_ks  params_nfilters  params_ps  params_stc  params_stp
number                                                               
296             9               16          4           3           1


In [43]:
df_results[df_results.value==df_results.value.min()]

Unnamed: 0_level_0,Unnamed: 0,value,datetime_start,datetime_complete,duration,params_activation,params_bn,params_bs,params_do,params_ks,params_md,params_nfilters,params_ps,params_stc,params_stp,params_type_pooling,params_wl,state
number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
296,296,0.200855,2023-04-16 21:47:44.863815,2023-04-16 21:47:52.692154,0 days 00:00:07.828339,LeakyReLU,False,128,0.3,9,32,16,4,3,1,Max,False,COMPLETE


In [35]:
df_results.keys()

Index(['Unnamed: 0', 'value', 'datetime_start', 'datetime_complete',
       'duration', 'params_activation', 'params_bn', 'params_bs', 'params_do',
       'params_ks', 'params_md', 'params_nfilters', 'params_ps', 'params_stc',
       'params_stp', 'params_type_pooling', 'params_wl', 'state'],
      dtype='object')

In [36]:
12*24*9

2592

In [None]:
df_results[df_results.value==df_results.value.min()].params_ks

In [24]:
model = build_resnet101_model('avg',0.5,8,'LeakyReLU')
model.compile(loss=keras.losses.categorical_crossentropy, 
                optimizer=keras.optimizers.Adam(lr=0.0001),metrics=[balanced_accuracy,'accuracy'])

In [27]:
# Create the EarlyStopping callback
epochs = 100
early_stopping_patience = 5
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_balanced_accuracy',  # Metric to monitor
    patience=early_stopping_patience,  # Number of epochs with no improvement
    restore_best_weights=True  # Restore the weights of the best model
)


In [29]:

y_train = get_output_from_dataset(train_joint_dataset)
y_train_integers = np.argmax(y_train, axis=1)
class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y_train_integers),
                                     y = y_train_integers)
d_class_weights = dict(enumerate(class_weights))

history = model.fit(
    train_joint_dataset.batch(16),
    validation_data=val_joint_dataset.batch(16),
    class_weight = d_class_weights,
    epochs=epochs,
    callbacks=[early_stopping_callback]
)

2023-07-30 16:39:27.350915: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2023-07-30 16:39:27.353584: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2300000000 Hz


Epoch 1/100


2023-07-30 16:39:40.213232: W tensorflow/core/common_runtime/bfc_allocator.cc:433] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.00MiB (rounded to 9437184)requested by op Fill
Current allocation summary follows.
2023-07-30 16:39:40.213490: I tensorflow/core/common_runtime/bfc_allocator.cc:972] BFCAllocator dump for GPU_0_bfc
2023-07-30 16:39:40.213506: I tensorflow/core/common_runtime/bfc_allocator.cc:979] Bin (256): 	Total Chunks: 329, Chunks in use: 329. 82.2KiB allocated for chunks. 82.2KiB in use in bin. 34.0KiB client-requested in use in bin.
2023-07-30 16:39:40.213515: I tensorflow/core/common_runtime/bfc_allocator.cc:979] Bin (512): 	Total Chunks: 147, Chunks in use: 147. 73.5KiB allocated for chunks. 73.5KiB in use in bin. 73.5KiB client-requested in use in bin.
2023-07-30 16:39:40.213523: I tensorflow/core/common_runtime/bfc_allocator.cc:979] Bin (1024): 	Total Chunks: 906, Chunks in use: 906. 908.0KiB allocated for chunks. 908.0KiB in use in bin. 906.0KiB clien

ResourceExhaustedError: in user code:

    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:757 train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:498 minimize
        return self.apply_gradients(grads_and_vars, name=name)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:604 apply_gradients
        self._create_all_weights(var_list)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:783 _create_all_weights
        self._create_slots(var_list)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/adam.py:127 _create_slots
        self.add_slot(var, 'm')
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:847 add_slot
        weight = tf_variables.Variable(
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:262 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:244 _variable_v2_call
        return previous_getter(
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
        return next_creator(**kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
        return next_creator(**kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3332 creator
        return next_creator(**kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:712 variable_capturing_scope
        v = UnliftedInitializerVariable(
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:227 __init__
        initial_value = initial_value()
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/keras/initializers/initializers_v2.py:139 __call__
        return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/init_ops_v2.py:154 __call__
        return array_ops.zeros(shape, dtype)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:2819 wrapped
        tensor = fun(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:2880 zeros
        output = fill(shape, constant(zero, dtype=dtype), name=name)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py:239 fill
        result = gen_array_ops.fill(dims, value, name=name)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py:3348 fill
        _ops.raise_from_not_ok_status(e, name)
    /glade/work/jhayron/conda-envs/cnn_wr/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:6862 raise_from_not_ok_status
        six.raise_from(core._status_to_exception(e.code, message), None)
    <string>:3 raise_from
        

    ResourceExhaustedError: OOM when allocating tensor with shape[3,3,512,512] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Fill]
