In [1]:
import os
# os.environ["KERAS_BACKEND"] = "jax"

import gc
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt

import tensorflow as tf
# import jax
import keras

from sklearn import metrics

from tqdm.notebook import tqdm
import random

print(tf.__version__)
# print(jax.__version__)

2.16.1


In [2]:
def is_interactive():
    return 'runtime' in get_ipython().config.IPKernelApp.connection_file

print('Interactive?', is_interactive())

Interactive? True


In [3]:
SEED = 42
keras.utils.set_random_seed(SEED)
tf.random.set_seed(SEED)
tf.config.experimental.enable_op_determinism()

In [4]:
# DATA = "/kaggle/input/leap-atmospheric-physics-ai-climsim"
DATA_NPY = "data"

In [5]:
sample = pl.read_csv('sample_submission.csv', n_rows=1)
TARGETS = sample.select(pl.exclude('sample_id')).columns
print(len(TARGETS))

368


In [6]:
BATCH_SIZE = 4096
DATA_TFREC = "data"
# create list of all tfrec files in folder and sub folders
TFREC_FILES = tf.io.gfile.glob(DATA_TFREC + '/**/*.tfrecord')
print(len(TFREC_FILES))
train_files = TFREC_FILES[1:46]
valid_files = TFREC_FILES[46:48]
train_options = tf.data.Options()
train_options.deterministic = True

def _parse_function(example_proto):
    feature_description = {
        'x': tf.io.FixedLenFeature([556], tf.float32),
        'targets': tf.io.FixedLenFeature([368], tf.float32)
    }
    e = tf.io.parse_single_example(example_proto, feature_description)
    return e['x'], e['targets'][240:360]

ds_train = (
  tf.data.TFRecordDataset(train_files, compression_type="GZIP")
    .with_options(train_options)
    .shuffle(100)
  .map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(4 * BATCH_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

ds_valid = (
    tf.data.TFRecordDataset(valid_files, compression_type="GZIP")
    .map(_parse_function)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
    )

47


In [7]:
norm_x = keras.layers.Normalization()
norm_x.adapt(ds_train.map(lambda x, y: x).take(1000))

2024-06-19 22:06:32.390647: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [8]:
norm_y = keras.layers.Normalization()
norm_y.adapt(ds_train.map(lambda x, y: y).take(1000))

mean_y = norm_y.mean
stdd_y =keras.ops.maximum(1e-10, norm_y.variance ** 0.5)

2024-06-19 22:08:58.066518: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [9]:
# mean_y = tf.cast(mean_y, tf.float64)
# stdd_y = tf.cast(stdd_y, tf.float64)

In [10]:
# min_y = np.min(np.stack([np.min(yb, 0) for _, yb in ds_train.take(1000)], 0), 0, keepdims=True)
# max_y = np.max(np.stack([np.max(yb, 0) for _, yb in ds_train.take(1000)], 0), 0, keepdims=True)

In [11]:
# min_y = tf.cast(min_y, tf.float64)
# max_y = tf.cast(max_y, tf.float64)

### Model definition & Training

In [12]:
@keras.saving.register_keras_serializable(package="MyMetrics", name="ClippedR2Score")
class ClippedR2Score(keras.metrics.Metric):
    def __init__(self, name='r2_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.base_metric = keras.metrics.R2Score(class_aggregation=None)
        
    def update_state(self, y_true, y_pred, sample_weight=None):
        self.base_metric.update_state(y_true, y_pred, sample_weight=None)
        
    def result(self):
        return keras.ops.mean(keras.ops.clip(self.base_metric.result(), 0.0, 1.0))
        
    def reset_states(self):
        self.base_metric.reset_states()

In [13]:
epochs = 12
learning_rate = 1e-3

epochs_warmup = 1
epochs_ending = 2
steps_per_epoch = int(np.ceil(len(train_files) * 100_000 / BATCH_SIZE))

lr_scheduler = tf.keras.optimizers.schedules.CosineDecay(
    1e-4,
    (epochs - epochs_warmup - epochs_ending) * steps_per_epoch, 
    warmup_target=learning_rate,
    warmup_steps=steps_per_epoch * epochs_warmup,
    alpha=0.1
)

In [14]:
import keras
import tensorflow as tf

keras.utils.clear_session()

def x_to_seq(x):
    x_seq0 = keras.ops.transpose(keras.ops.reshape(x[:, 0:60 * 6], (-1, 6, 60)), (0, 2, 1))
    x_seq1 = keras.ops.transpose(keras.ops.reshape(x[:, 60 * 6 + 16:60 * 9 + 16], (-1, 3, 60)), (0, 2, 1))
    x_flat = keras.ops.reshape(x[:, 60 * 6:60 * 6 + 16], (-1, 1, 16))
    x_flat = keras.ops.repeat(x_flat, 60, axis=1)
    # x_flat2 = keras.ops.reshape(x[:, 556:], (-1, 1, 19))
    # x_flat2 = keras.ops.repeat(x_flat2, 60, axis=1)
    return keras.ops.concatenate([x_seq0, x_seq1, x_flat], axis=-1)

def build_cnn(activation='relu'):    
    return keras.Sequential([
        # keras.layers.Conv1D(1024, 3, padding='same', activation=activation),
        keras.layers.Conv1D(512, 3, padding='same', activation=activation),
        # keras.layers.BatchNormalization(),
        keras.layers.Conv1D(256, 3, padding='same', activation=activation),
        # keras.layers.BatchNormalization(),
        keras.layers.Conv1D(128, 3, padding='same', activation=activation),
        # keras.layers.BatchNormalization(),
        keras.layers.Conv1D(164, 3, padding='same', activation=activation),
        # keras.layers.BatchNormalization(),
    ])

# Current track
X_input = x = keras.layers.Input(shape=(556,))
x = keras.layers.Normalization(mean=norm_x.mean, variance=norm_x.variance)(x)
x_seq = x_to_seq(x)

e = e0 = keras.layers.Conv1D(164, 1, padding='same')(x_seq)
e = build_cnn()(e)
# Add global average to allow some communication between all levels even in a small CNN
e = e0 + e + keras.layers.GlobalAveragePooling1D(keepdims=True)(e)
# e = keras.layers.BatchNormalization()(e)
e = e + build_cnn()(e)

p_all = keras.layers.Conv1D(2, 1, padding='same')(e)

# p_seq = p_all[:, :, :6]
p_seq = keras.ops.transpose(p_all, (0, 2, 1))
p_seq = keras.layers.Flatten()(p_seq)
assert p_seq.shape[-1] == 120

# p_flat = p_all[:, :, 6:6 + 8+19]
# p_flat = keras.ops.mean(p_flat, axis=1)
# print(p_flat.shape)
# assert p_flat.shape[-1] == 27

# combined = keras.layers.concatenate([p_seq, p_flat], axis=1)

# shallow neural network combining cnn and input
# nn = keras.layers.Dense(60, activation='relu')(p_seq)
# nn = keras.layers.Dense(750, activation='relu')(nn)
# nn = keras.layers.Dense(500, activation='relu')(nn)
# output = keras.layers.Dense(350, activation='linear')(nn)
output = keras.layers.Conv1D(1, 1, padding='same')(p_seq)

# Build & compile
model = keras.Model(inputs=X_input, outputs=output)
model.compile(
    loss='mse', 
    optimizer=keras.optimizers.Adam(lr_scheduler),
    metrics=[ClippedR2Score()]
)

model.build(input_shape=(None, 556))
model.summary()

In [15]:
# model_path = '/kaggle/input/leapseq2seq/model.keras'
# model = tf.keras.models.load_model(model_path)

In [16]:
ds_train_target_normalized = ds_train.map(lambda x, y: (x, (y - mean_y) / stdd_y))
ds_valid_target_normalized = ds_valid.map(lambda x, y: (x, (y - mean_y) / stdd_y))

history = model.fit(
    ds_train_target_normalized,
    validation_data=ds_valid_target_normalized,
    epochs=epochs,
    verbose=1 if is_interactive() else 2,
    callbacks=[
        keras.callbacks.ModelCheckpoint(filepath='models/cnn3_extra.keras'),
        # keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1, monitor='val_r2_score', mode='max', min_lr=1e-6)
    ]
)

Epoch 1/12
   9234/Unknown [1m13028s[0m 1s/step - loss: 0.4239 - r2_score: 0.4047

2024-06-20 01:46:06.631734: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 01:46:06.631793: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_20/_60]]
  self.gen.throw(typ, value, traceback)


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13127s[0m 1s/step - loss: 0.4238 - r2_score: 0.4048 - val_loss: 0.3048 - val_r2_score: 0.5212
Epoch 2/12


2024-06-20 01:47:45.719582: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 01:47:45.719612: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - loss: 0.2640 - r2_score: 0.5509

2024-06-20 05:21:18.912836: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 05:21:18.912867: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_44]]
2024-06-20 05:22:57.389624: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 05:22:57.389638: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IsInf/_18]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12912s[0m 1s/step - loss: 0.2640 - r2_score: 0.5509 - val_loss: 0.2759 - val_r2_score: 0.5435
Epoch 3/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - loss: 0.2503 - r2_score: 0.5638

2024-06-20 08:57:07.502937: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 08:57:07.502957: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_44]]
2024-06-20 08:58:45.594127: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 08:58:45.594153: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[Shape/_4]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12948s[0m 1s/step - loss: 0.2503 - r2_score: 0.5638 - val_loss: 0.2668 - val_r2_score: 0.5520
Epoch 4/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - loss: 0.2427 - r2_score: 0.5709

2024-06-20 12:45:30.455145: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 12:45:30.455205: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_20/_60]]
2024-06-20 12:47:20.263417: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 12:47:20.263436: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IsInf/_18]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13715s[0m 1s/step - loss: 0.2427 - r2_score: 0.5709 - val_loss: 0.2666 - val_r2_score: 0.5538
Epoch 5/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - loss: 0.2372 - r2_score: 0.5761

2024-06-20 16:37:03.289696: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 16:37:03.289720: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_22/_62]]
2024-06-20 16:38:48.651804: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 16:38:48.651834: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IsInf/_18]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13888s[0m 2s/step - loss: 0.2372 - r2_score: 0.5761 - val_loss: 0.2572 - val_r2_score: 0.5610
Epoch 6/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - loss: 0.2332 - r2_score: 0.5800

2024-06-20 20:29:53.893506: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 20:29:53.893551: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_40/_46]]
2024-06-20 20:31:38.143794: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-20 20:31:38.143809: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13969s[0m 2s/step - loss: 0.2332 - r2_score: 0.5800 - val_loss: 0.2518 - val_r2_score: 0.5654
Epoch 7/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - loss: 0.2297 - r2_score: 0.5832

2024-06-21 00:05:04.256301: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 00:05:04.256320: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_22/_62]]
2024-06-21 00:06:36.879511: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 00:06:36.879530: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12899s[0m 1s/step - loss: 0.2297 - r2_score: 0.5832 - val_loss: 0.2484 - val_r2_score: 0.5685
Epoch 8/12
[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 991ms/step - loss: 0.2269 - r2_score: 0.5859

2024-06-21 02:39:06.229819: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 02:39:06.229838: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_24/_64]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9215s[0m 998ms/step - loss: 0.2269 - r2_score: 0.5859 - val_loss: 0.2433 - val_r2_score: 0.5720
Epoch 9/12


2024-06-21 02:40:11.817107: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 02:40:11.817125: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[Size/_6]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 848ms/step - loss: 0.2245 - r2_score: 0.5881

2024-06-21 04:50:41.002042: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 04:50:41.002065: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_20/_60]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7894s[0m 855ms/step - loss: 0.2245 - r2_score: 0.5881 - val_loss: 0.2426 - val_r2_score: 0.5734
Epoch 10/12


2024-06-21 04:51:45.626578: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 04:51:45.626591: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 848ms/step - loss: 0.2223 - r2_score: 0.5903

2024-06-21 07:02:20.036892: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 07:02:20.036915: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_22/_62]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7899s[0m 855ms/step - loss: 0.2223 - r2_score: 0.5903 - val_loss: 0.2411 - val_r2_score: 0.5745
Epoch 11/12


2024-06-21 07:03:25.102055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 07:03:25.102086: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 848ms/step - loss: 0.2204 - r2_score: 0.5921

2024-06-21 09:13:52.525366: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 09:13:52.525387: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_22/_62]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7893s[0m 855ms/step - loss: 0.2204 - r2_score: 0.5921 - val_loss: 0.2390 - val_r2_score: 0.5762
Epoch 12/12


2024-06-21 09:14:57.650646: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 09:14:57.650667: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_2]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 917ms/step - loss: 0.2188 - r2_score: 0.5936

2024-06-21 11:36:09.839034: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 11:36:09.839051: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[StatefulPartitionedCall/adam/add_22/_62]]


[1m9234/9234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8546s[0m 925ms/step - loss: 0.2188 - r2_score: 0.5936 - val_loss: 0.2366 - val_r2_score: 0.5784


2024-06-21 11:37:23.716900: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
2024-06-21 11:37:23.716917: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IsInf/_18]]


In [1]:
plt.plot(history.history['loss'], color='tab:blue')
plt.plot(history.history['val_loss'], color='tab:red')
plt.xlim(1, epochs)
plt.yscale('log');

NameError: name 'plt' is not defined

In [18]:
y_valid = np.concatenate([yb for _, yb in ds_valid])
p_valid = model.predict(ds_valid, batch_size=BATCH_SIZE) * stdd_y + mean_y

2024-06-21 11:37:46.181011: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


[1m210/210[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 330ms/step


  self.gen.throw(typ, value, traceback)
2024-06-21 11:38:55.523725: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [19]:
p_valid.numpy()[0]

array([-1.8088508e-12,  1.0556233e-12, -1.9225639e-12,  9.6331688e-13,
       -2.5689341e-13, -2.3366053e-12, -1.2194444e-12, -3.7685324e-12,
       -3.5009552e-12, -1.0893185e-12,  3.1943136e-13, -1.6524829e-13,
        1.0486838e-09, -1.9031741e-09,  4.4999716e-08,  7.2476691e-09,
        6.1228761e-08,  2.9778428e-07,  6.8113218e-07,  2.7952956e-07,
        6.2085235e-09,  6.4804822e-06, -1.4421671e-05,  4.8758061e-06,
       -8.8860577e-07, -4.2299223e-07, -9.6962276e-07, -8.6652017e-07,
       -4.3466406e-07, -5.2529919e-07, -2.3073341e-07,  8.7964395e-08,
        1.2671904e-07,  2.8674899e-07,  3.6867334e-07,  4.3405191e-07,
        3.0297915e-07,  1.6084070e-07,  2.8657325e-07,  3.1572733e-07,
        6.8252575e-08,  1.7916852e-07, -2.3450309e-07, -5.5631369e-07,
       -9.4659777e-07, -1.3160350e-06, -1.3508510e-06,  5.9617150e-08,
        6.9474186e-07, -1.0056681e-06,  1.9624435e-06,  2.4372134e-06,
       -4.6854802e-07,  1.5955166e-06,  2.1511621e-06,  5.6719500e-07,
      

In [20]:
scores_valid = np.array([metrics.r2_score(y_valid[:, i], p_valid[:, i]) for i in range(len(TARGETS))])
plt.plot(scores_valid.clip(-1, 1))

ValueError: Input arrays use different devices: cpu, /job:localhost/replica:0/task:0/device:GPU:0

In [None]:
mask = scores_valid <= 1e-3
f"Number of under-performing targets: {sum(mask)}"

In [None]:
f"Clipped score: {scores_valid.clip(0, 1).mean()}"

In [None]:
del y_valid, p_valid
gc.collect();