## A copy of notebook [Keras Baseline Seq2Seq](https://www.kaggle.com/code/enzosebiane/keras-baseline-seq2seq?scriptVersionId=183823137) from the author [Enzo Sebiane](https://www.kaggle.com/enzosebiane)

Changes made:
- SEED = 42 --> 2024
- epochs = 50 --> 100
- Added an additional build_cnn11 layer to the model.

In [None]:
import os
# Select the Keras backend through the environment; this assignment is placed
# ahead of `import keras` below so the setting is in place when Keras loads.
os.environ["KERAS_BACKEND"] = "jax"

import gc
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt

# TensorFlow is used in this notebook for the tf.data input pipeline and for
# seeding/determinism settings; the model itself is built with `keras`.
import tensorflow as tf
import jax
import keras

from sklearn import metrics

from tqdm.notebook import tqdm

# Record the library versions in the cell output.
print(tf.__version__)
print(jax.__version__)

In [None]:
def is_interactive():
    """Return True when the kernel's connection-file path contains 'runtime'.

    The result is used throughout the notebook to pick lighter settings
    (fewer batches, more verbose logging) for interactive sessions.
    NOTE(review): presumably this distinguishes a live Kaggle editor session
    from a batch/commit run -- confirm against the hosting environment.
    """
    connection_file = get_ipython().config.IPKernelApp.connection_file
    return 'runtime' in connection_file


print('Interactive?', is_interactive())

In [None]:
# Seed value stated in the notebook's changelog ("SEED = 42 --> 2024"); the
# code previously still carried the upstream value 42.
# NOTE(review): the normalization statistics computed later are estimated from
# a shuffled sample of the training stream, so this seed presumably has to
# match the one used when the loaded checkpoint was trained -- confirm.
SEED = 2024  # upstream notebook: 42
keras.utils.set_random_seed(SEED)
tf.random.set_seed(SEED)
# Ask TensorFlow to run ops deterministically.
tf.config.experimental.enable_op_determinism()

In [None]:
# When True the model is fitted from scratch; when False a saved model is
# loaded from DATA_MODEL instead (see the train-or-load cell).
TRAIN = False
# Earlier checkpoint location, kept for reference:
#DATA_MODEL = "/kaggle/input/leap-climsim-keras-baseline-seq2seq-v01/01-keras-baseline-seq2seq.keras"
# Saved `.keras` model that is loaded when TRAIN is False.
DATA_MODEL = "/kaggle/input/01-keras-baseline-seq2seq/04-keras-baseline-seq2seq.keras"
# Competition data directory (sample_submission.csv is read from here).
DATA = "/kaggle/input/leap-atmospheric-physics-ai-climsim"
# Directory holding the `train_NNN.tfrec` TFRecord shards.
DATA_TFREC = "/kaggle/input/leap-train-tfrecords"

In [None]:
# Only the header is needed to learn the target column names, so a single row
# of the sample submission is read.
sample_submission_path = os.path.join(DATA, "sample_submission.csv")
sample = pl.read_csv(sample_submission_path, n_rows=1)

# Every column except the row identifier is a prediction target.
TARGETS = sample.select(pl.exclude('sample_id')).columns
print(len(TARGETS))

In [None]:
def _parse_function(example_proto):
    """Decode one serialized example into an ``(x, targets)`` pair.

    Args:
        example_proto: a serialized example as yielded by a TFRecordDataset.

    Returns:
        Tuple of two float32 tensors: ``x`` with 556 values and ``targets``
        with 368 values.
    """
    schema = {
        'x': tf.io.FixedLenFeature([556], tf.float32),
        'targets': tf.io.FixedLenFeature([368], tf.float32),
    }
    parsed = tf.io.parse_single_example(example_proto, schema)
    features, targets = parsed['x'], parsed['targets']
    return features, targets

In [None]:
# Shards 000..099 are used for training; shard 100 is held out for validation.
train_files = [
    os.path.join(DATA_TFREC, "train_%.3d.tfrec" % shard)
    for shard in range(0, 100)
]
valid_files = [
    os.path.join(DATA_TFREC, "train_%.3d.tfrec" % shard)
    for shard in range(100, 101)
]

In [None]:
BATCH_SIZE = 4096

train_options = tf.data.Options()
train_options.deterministic = True


def _read_shard(file):
    """Open one TFRecord shard and decode its records into (x, targets) pairs."""
    return tf.data.TFRecordDataset(file).map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)


# Training stream: shuffle the shard names, interleave records from 10 shards
# at a time in blocks of 1000, shuffle again at record level, then batch.
# The two shuffle stages are created in the same order as before.
shard_names = tf.data.Dataset.from_tensor_slices(train_files).with_options(train_options)
shuffled_shards = shard_names.shuffle(100)
interleaved_records = shuffled_shards.interleave(
    _read_shard,
    num_parallel_calls=tf.data.AUTOTUNE,
    cycle_length=10,
    block_length=1000,
    deterministic=True,
)
ds_train = interleaved_records.shuffle(4 * BATCH_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Validation stream: records in file order, no shuffling.
valid_records = tf.data.TFRecordDataset(valid_files).map(_parse_function)
ds_valid = valid_records.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
# Estimate per-feature mean/variance of the inputs from a sample of training
# batches (20 batches in interactive sessions, 1000 otherwise).
norm_x = keras.layers.Normalization()
n_batches_for_x_stats = 20 if is_interactive() else 1000
norm_x.adapt(ds_train.map(lambda x, y: x).take(n_batches_for_x_stats))

# Diagnostic plot: per-feature standard deviation against mean, log-log axes.
x_feature_mean = norm_x.mean.squeeze()
x_feature_std = norm_x.variance.squeeze() ** 0.5
plt.scatter(x_feature_mean, x_feature_std, marker=".", alpha=0.5)
plt.xscale('log')
plt.yscale('log')

In [None]:
# Estimate per-target mean/variance from a sample of training batches
# (20 batches in interactive sessions, 1000 otherwise).
norm_y = keras.layers.Normalization()
n_batches_for_y_stats = 20 if is_interactive() else 1000
norm_y.adapt(ds_train.map(lambda x, y: y).take(n_batches_for_y_stats))

# These two arrays are what the rest of the notebook uses to scale targets for
# training and to un-scale predictions. The standard deviation is floored at
# 1e-10 so that dividing by it never divides by zero.
mean_y = norm_y.mean
stdd_y = keras.ops.maximum(1e-10, norm_y.variance ** 0.5)

# Diagnostic plot: per-target standard deviation against mean, log-log axes.
y_target_mean = mean_y.squeeze()
y_target_std = stdd_y.squeeze()
plt.scatter(y_target_mean, y_target_std, marker=".", alpha=0.5)
plt.xscale('log')
plt.yscale('log')


In [None]:
# Per-target minimum and maximum over a sample of training batches
# (20 batches in interactive sessions, 1000 otherwise), each of shape (1, n_targets).
# Both extrema are gathered in ONE pass over the stream: iterating the shuffled
# training dataset twice doubled the input-pipeline cost and made min_y and
# max_y come from two different samples.
batch_minima, batch_maxima = [], []
for _, y_batch in ds_train.take(20 if is_interactive() else 1000):
    y_batch = np.asarray(y_batch)
    batch_minima.append(y_batch.min(axis=0))
    batch_maxima.append(y_batch.max(axis=0))

min_y = np.min(np.stack(batch_minima, 0), 0, keepdims=True)
max_y = np.max(np.stack(batch_maxima, 0), 0, keepdims=True)

### Model definition & Training

In [None]:
@keras.saving.register_keras_serializable(package="MyMetrics", name="ClippedR2Score")
class ClippedR2Score(keras.metrics.Metric):
    """Mean of the per-output R2 scores after clipping each one to [0, 1].

    Wraps ``keras.metrics.R2Score(class_aggregation=None)``, which reports one
    R2 value per output; each value is clipped to [0, 1] before averaging.

    Args:
        name: metric name shown in logs (default ``'r2_score'``).
        **kwargs: forwarded to ``keras.metrics.Metric``.
    """

    def __init__(self, name='r2_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.base_metric = keras.metrics.R2Score(class_aggregation=None)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate a batch into the wrapped R2 metric.

        ``sample_weight`` is forwarded to the wrapped metric; previously it was
        silently replaced by ``None``.
        """
        self.base_metric.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the mean of the per-output R2 values clipped to [0, 1]."""
        per_output_r2 = self.base_metric.result()
        return keras.ops.mean(keras.ops.clip(per_output_r2, 0.0, 1.0))

    def reset_state(self):
        """Reset the wrapped metric (``reset_state`` is the Keras 3 hook name)."""
        self.base_metric.reset_state()

    def reset_states(self):
        """Alias of ``reset_state`` kept for callers using the old method name."""
        self.reset_state()

In [None]:
# 100 epochs, as stated in the notebook's changelog ("epochs = 50 --> 100") and
# as shown by the recorded training log ("Epoch 1/100" ... "Epoch 100/100").
# The code previously still carried the upstream value 50.
epochs = 100  # upstream: 50 (earlier experiments: 25, 15, 12)
learning_rate = 1e-3

epochs_warmup = 1
epochs_ending = 2
# Steps per epoch, computed assuming 100_000 records in every training shard.
# NOTE(review): the recorded training log shows 4883 steps per epoch, whereas
# this formula gives 2442 for 100 shards and BATCH_SIZE = 4096 -- confirm the
# shard size / batch size that were used for the logged run.
steps_per_epoch = int(np.ceil(len(train_files) * 100_000 / BATCH_SIZE))

# Warm up from 1e-4 to `learning_rate` over `epochs_warmup` epochs, then apply
# a cosine decay spanning all remaining epochs except the last `epochs_ending`.
decay_steps = (epochs - epochs_warmup - epochs_ending) * steps_per_epoch
lr_scheduler = keras.optimizers.schedules.CosineDecay(
    1e-4,
    decay_steps,
    warmup_target=learning_rate,
    warmup_steps=steps_per_epoch * epochs_warmup,
    alpha=0.1
)

# Preview the schedule, sampled at the first step of every epoch.
plt.plot([lr_scheduler(it) for it in range(0, epochs * steps_per_epoch, steps_per_epoch)]);

In [None]:
keras.utils.clear_session()


def x_to_seq(x):
    """Rearrange a flat input batch into a 60-step sequence with 25 channels.

    Column layout read from ``x`` (second axis):
      * ``[0, 360)``   -> 6 channels of 60 values each,
      * ``[360, 376)`` -> 16 values, repeated at every one of the 60 steps,
      * ``[376, 556)`` -> 3 channels of 60 values each.

    Returns:
        Tensor of shape ``(batch, 60, 6 + 3 + 16)``.
    """
    steps = 60
    first_profiles_end = steps * 6            # 360
    scalars_end = first_profiles_end + 16     # 376
    second_profiles_end = steps * 9 + 16      # 556

    def as_sequence(block, channels):
        # (batch, channels * steps) -> (batch, channels, steps) -> (batch, steps, channels)
        stacked = keras.ops.reshape(block, (-1, channels, steps))
        return keras.ops.transpose(stacked, (0, 2, 1))

    x_seq0 = as_sequence(x[:, 0:first_profiles_end], 6)
    x_seq1 = as_sequence(x[:, scalars_end:second_profiles_end], 3)

    # Broadcast the 16 per-sample values along the sequence axis.
    x_flat = keras.ops.reshape(x[:, first_profiles_end:scalars_end], (-1, 1, 16))
    x_flat = keras.ops.repeat(x_flat, steps, axis=1)

    return keras.ops.concatenate([x_seq0, x_seq1, x_flat], axis=-1)


def build_cnn(activation='relu', kernel_size=3):
    """Build the convolutional + recurrent block used by the model.

    The block stacks three Conv1D layers (256, 128 and 64 filters), dropout,
    an LSTM and a GRU (both returning full sequences) and a final Dense layer,
    with batch normalization after each convolutional / recurrent layer. It
    maps ``(batch, steps, channels)`` to ``(batch, steps, 64)``.

    Args:
        activation: activation used by the Conv1D and Dense layers.
        kernel_size: width of the three Conv1D kernels (default 3).

    Returns:
        A new, unbuilt ``keras.Sequential``.
    """
    return keras.Sequential([
        # First Conv1D layer
        keras.layers.Conv1D(256, kernel_size, padding='same', activation=activation),
        keras.layers.BatchNormalization(),
        # Second Conv1D layer
        keras.layers.Conv1D(128, kernel_size, padding='same', activation=activation),
        keras.layers.BatchNormalization(),
        # Third Conv1D layer
        keras.layers.Conv1D(64, kernel_size, padding='same', activation=activation),
        keras.layers.BatchNormalization(),
        # Dropout layer for regularization
        keras.layers.Dropout(0.3),
        # LSTM layer to capture dependencies along the sequence
        keras.layers.LSTM(64, return_sequences=True),
        keras.layers.BatchNormalization(),
        # GRU layer to capture dependencies along the sequence
        keras.layers.GRU(64, return_sequences=True),
        keras.layers.BatchNormalization(),
        # Final Dense layer bringing the output to 64 channels
        keras.layers.Dense(64, activation=activation),
    ])


def build_cnn11(activation='relu'):
    """Same block as ``build_cnn`` but with Conv1D kernels of width 11."""
    return build_cnn(activation=activation, kernel_size=11)

# Model input: one flat feature vector per sample (batch axis dropped from the
# dataset's element spec).
X_input = x = keras.layers.Input(ds_train.element_spec[0].shape[1:])
# Standardize the inputs with the statistics estimated by `norm_x`.
x = keras.layers.Normalization(mean=norm_x.mean, variance=norm_x.variance)(x)
# Rearrange the flat vector into a (batch, 60, 25) sequence.
x = x_to_seq(x)

# Pointwise projection to 64 channels; the same tensor feeds both branches
# below and also serves as the residual term `e0`.
e11 = e = e0 = keras.layers.Conv1D(64, 1, padding='same')(x)
e = build_cnn()(e)
e11 = build_cnn11()(e11)
# add global average to allow some communication between all levels even in a small CNN
e = e0 + e + e11 + keras.layers.GlobalAveragePooling1D(keepdims=True)(e)
e = keras.layers.BatchNormalization()(e)
# Second block, applied with a residual connection.
e = e + build_cnn()(e)

# Pointwise head producing 14 channels per step: 6 sequence outputs + 8 pooled outputs.
p_all = keras.layers.Conv1D(14, 1, padding='same')(e)

# First 6 channels: move channels ahead of steps and flatten, so the 360
# outputs are ordered channel by channel.
p_seq = p_all[:, :, :6]
p_seq = keras.ops.transpose(p_seq, (0, 2, 1))
p_seq = keras.layers.Flatten()(p_seq)
assert p_seq.shape[-1] == 360

# Remaining 8 channels: averaged over the sequence axis to give 8 values per sample.
p_flat = p_all[:, :, 6:6 + 8]
p_flat = keras.ops.mean(p_flat, axis=1)
assert p_flat.shape[-1] == 8

# Final output: 360 sequence values followed by the 8 pooled values.
P = keras.ops.concatenate([p_seq, p_flat], axis=1)

# build & compile
model = keras.Model(X_input, P)
model.compile(
    loss='mse', 
    optimizer=keras.optimizers.Adam(lr_scheduler),
    metrics=[ClippedR2Score()]
)
model.build(tuple(ds_train.element_spec[0].shape))
model.summary()

In [None]:
if not TRAIN:
    # Inference-only run: replace the freshly built model with the saved one.
    model = keras.models.load_model(DATA_MODEL)

else:
    def _normalize_targets(x, y):
        """Standardize the targets with mean_y / stdd_y; inputs pass through."""
        return x, (y - mean_y) / stdd_y

    # The network is trained on standardized targets.
    ds_train_target_normalized = ds_train.map(_normalize_targets)
    ds_valid_target_normalized = ds_valid.map(_normalize_targets)

    # Write the model to 'model.keras' during training.
    checkpoint = keras.callbacks.ModelCheckpoint(filepath='model.keras')

    history = model.fit(
        ds_train_target_normalized,
        validation_data=ds_valid_target_normalized,
        epochs=epochs,
        verbose=1 if is_interactive() else 2,
        callbacks=[checkpoint],
    )

In [None]:
'''
Epoch 1/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 389s 78ms/step - loss: 0.5905 - r2_score: 0.2234 - val_loss: 0.3442 - val_r2_score: 0.3814
Epoch 2/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 374s 76ms/step - loss: 0.3988 - r2_score: 0.3838 - val_loss: 0.3066 - val_r2_score: 0.4247
Epoch 3/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3694 - r2_score: 0.4038 - val_loss: 0.2923 - val_r2_score: 0.4401
Epoch 4/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3866 - r2_score: 0.4117 - val_loss: 0.2835 - val_r2_score: 0.4508
Epoch 5/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3485 - r2_score: 0.4260 - val_loss: 0.2887 - val_r2_score: 0.4450
Epoch 6/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3579 - r2_score: 0.4276 - val_loss: 0.2865 - val_r2_score: 0.4490
Epoch 7/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3357 - r2_score: 0.4374 - val_loss: 0.2782 - val_r2_score: 0.4578
Epoch 8/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3228 - r2_score: 0.4458 - val_loss: 0.2865 - val_r2_score: 0.4473
Epoch 9/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3289 - r2_score: 0.4444 - val_loss: 0.2946 - val_r2_score: 0.4382
Epoch 10/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3393 - r2_score: 0.4454 - val_loss: 0.2650 - val_r2_score: 0.4728
Epoch 11/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3197 - r2_score: 0.4537 - val_loss: 0.2644 - val_r2_score: 0.4739
Epoch 12/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3339 - r2_score: 0.4472 - val_loss: 0.2693 - val_r2_score: 0.4700
Epoch 13/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3142 - r2_score: 0.4545 - val_loss: 0.2679 - val_r2_score: 0.4688
Epoch 14/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3372 - r2_score: 0.4520 - val_loss: 0.2623 - val_r2_score: 0.4758
Epoch 15/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3208 - r2_score: 0.4557 - val_loss: 0.2628 - val_r2_score: 0.4748
Epoch 16/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3082 - r2_score: 0.4627 - val_loss: 0.2675 - val_r2_score: 0.4689
Epoch 17/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3185 - r2_score: 0.4597 - val_loss: 0.2701 - val_r2_score: 0.4689
Epoch 18/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3033 - r2_score: 0.4654 - val_loss: 0.2651 - val_r2_score: 0.4727
Epoch 19/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3021 - r2_score: 0.4673 - val_loss: 0.2629 - val_r2_score: 0.4752
Epoch 20/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 374s 76ms/step - loss: 0.3064 - r2_score: 0.4639 - val_loss: 0.2560 - val_r2_score: 0.4821
Epoch 21/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3020 - r2_score: 0.4685 - val_loss: 0.2580 - val_r2_score: 0.4814
Epoch 22/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3014 - r2_score: 0.4679 - val_loss: 0.2551 - val_r2_score: 0.4838
Epoch 23/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3008 - r2_score: 0.4691 - val_loss: 0.2526 - val_r2_score: 0.4865
Epoch 24/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2975 - r2_score: 0.4732 - val_loss: 0.2608 - val_r2_score: 0.4775
Epoch 25/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2907 - r2_score: 0.4750 - val_loss: 0.2518 - val_r2_score: 0.4875
Epoch 26/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2969 - r2_score: 0.4762 - val_loss: 0.2520 - val_r2_score: 0.4873
Epoch 27/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.3044 - r2_score: 0.4738 - val_loss: 0.2717 - val_r2_score: 0.4649
Epoch 28/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2918 - r2_score: 0.4775 - val_loss: 0.2511 - val_r2_score: 0.4891
Epoch 29/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2829 - r2_score: 0.4781 - val_loss: 0.2797 - val_r2_score: 0.4581
Epoch 30/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2902 - r2_score: 0.4787 - val_loss: 0.2548 - val_r2_score: 0.4839
Epoch 31/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2887 - r2_score: 0.4793 - val_loss: 0.2520 - val_r2_score: 0.4855
Epoch 32/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 374s 77ms/step - loss: 0.2960 - r2_score: 0.4782 - val_loss: 0.2475 - val_r2_score: 0.4921
Epoch 33/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2865 - r2_score: 0.4780 - val_loss: 0.2595 - val_r2_score: 0.4792
Epoch 34/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2783 - r2_score: 0.4846 - val_loss: 0.2539 - val_r2_score: 0.4850
Epoch 35/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2732 - r2_score: 0.4846 - val_loss: 0.2504 - val_r2_score: 0.4897
Epoch 36/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2779 - r2_score: 0.4846 - val_loss: 0.2537 - val_r2_score: 0.4848
Epoch 37/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2812 - r2_score: 0.4865 - val_loss: 0.2470 - val_r2_score: 0.4935
Epoch 38/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2882 - r2_score: 0.4848 - val_loss: 0.2450 - val_r2_score: 0.4951
Epoch 39/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2817 - r2_score: 0.4881 - val_loss: 0.2475 - val_r2_score: 0.4937
Epoch 40/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2778 - r2_score: 0.4895 - val_loss: 0.2450 - val_r2_score: 0.4957
Epoch 41/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2760 - r2_score: 0.4895 - val_loss: 0.2517 - val_r2_score: 0.4898
Epoch 42/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2713 - r2_score: 0.4936 - val_loss: 0.2416 - val_r2_score: 0.5003
Epoch 43/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2726 - r2_score: 0.4943 - val_loss: 0.2599 - val_r2_score: 0.4777
Epoch 44/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2701 - r2_score: 0.4937 - val_loss: 0.2406 - val_r2_score: 0.5009
Epoch 45/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2736 - r2_score: 0.4946 - val_loss: 0.2476 - val_r2_score: 0.4934
Epoch 46/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2715 - r2_score: 0.4945 - val_loss: 0.2385 - val_r2_score: 0.5033
Epoch 47/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2763 - r2_score: 0.4951 - val_loss: 0.2401 - val_r2_score: 0.5019
Epoch 48/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2671 - r2_score: 0.4957 - val_loss: 0.2415 - val_r2_score: 0.5004
Epoch 49/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2730 - r2_score: 0.4952 - val_loss: 0.2382 - val_r2_score: 0.5038
Epoch 50/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2717 - r2_score: 0.4963 - val_loss: 0.2352 - val_r2_score: 0.5060
Epoch 51/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2726 - r2_score: 0.4974 - val_loss: 0.2363 - val_r2_score: 0.5057
Epoch 52/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2635 - r2_score: 0.4991 - val_loss: 0.2409 - val_r2_score: 0.5004
Epoch 53/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2603 - r2_score: 0.4991 - val_loss: 0.2426 - val_r2_score: 0.4999
Epoch 54/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2644 - r2_score: 0.5010 - val_loss: 0.2359 - val_r2_score: 0.5067
Epoch 55/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2646 - r2_score: 0.5008 - val_loss: 0.2347 - val_r2_score: 0.5079
Epoch 56/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2645 - r2_score: 0.5003 - val_loss: 0.2369 - val_r2_score: 0.5060
Epoch 57/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2641 - r2_score: 0.5022 - val_loss: 0.2363 - val_r2_score: 0.5064
Epoch 58/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2556 - r2_score: 0.5035 - val_loss: 0.2332 - val_r2_score: 0.5095
Epoch 59/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2619 - r2_score: 0.5032 - val_loss: 0.2359 - val_r2_score: 0.5068
Epoch 60/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2634 - r2_score: 0.5038 - val_loss: 0.2331 - val_r2_score: 0.5096
Epoch 61/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2611 - r2_score: 0.5030 - val_loss: 0.2329 - val_r2_score: 0.5100
Epoch 62/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2541 - r2_score: 0.5052 - val_loss: 0.2316 - val_r2_score: 0.5114
Epoch 63/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2616 - r2_score: 0.5050 - val_loss: 0.2327 - val_r2_score: 0.5100
Epoch 64/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2606 - r2_score: 0.5060 - val_loss: 0.2314 - val_r2_score: 0.5116
Epoch 65/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2530 - r2_score: 0.5068 - val_loss: 0.2317 - val_r2_score: 0.5117
Epoch 66/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2559 - r2_score: 0.5087 - val_loss: 0.2310 - val_r2_score: 0.5121
Epoch 67/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2632 - r2_score: 0.5067 - val_loss: 0.2312 - val_r2_score: 0.5121
Epoch 68/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2540 - r2_score: 0.5084 - val_loss: 0.2316 - val_r2_score: 0.5102
Epoch 69/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2610 - r2_score: 0.5079 - val_loss: 0.2313 - val_r2_score: 0.5118
Epoch 70/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2543 - r2_score: 0.5087 - val_loss: 0.2317 - val_r2_score: 0.5115
Epoch 71/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2597 - r2_score: 0.5083 - val_loss: 0.2318 - val_r2_score: 0.5115
Epoch 72/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2548 - r2_score: 0.5090 - val_loss: 0.2312 - val_r2_score: 0.5122
Epoch 73/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2557 - r2_score: 0.5092 - val_loss: 0.2301 - val_r2_score: 0.5134
Epoch 74/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2508 - r2_score: 0.5104 - val_loss: 0.2294 - val_r2_score: 0.5140
Epoch 75/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2524 - r2_score: 0.5098 - val_loss: 0.2289 - val_r2_score: 0.5143
Epoch 76/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2544 - r2_score: 0.5108 - val_loss: 0.2285 - val_r2_score: 0.5149
Epoch 77/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 374s 76ms/step - loss: 0.2551 - r2_score: 0.5129 - val_loss: 0.2290 - val_r2_score: 0.5146
Epoch 78/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2533 - r2_score: 0.5110 - val_loss: 0.2286 - val_r2_score: 0.5149
Epoch 79/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2521 - r2_score: 0.5127 - val_loss: 0.2288 - val_r2_score: 0.5148
Epoch 80/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2571 - r2_score: 0.5111 - val_loss: 0.2287 - val_r2_score: 0.5150
Epoch 81/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2541 - r2_score: 0.5129 - val_loss: 0.2288 - val_r2_score: 0.5147
Epoch 82/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2581 - r2_score: 0.5107 - val_loss: 0.2275 - val_r2_score: 0.5163
Epoch 83/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2489 - r2_score: 0.5151 - val_loss: 0.2280 - val_r2_score: 0.5156
Epoch 84/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2518 - r2_score: 0.5141 - val_loss: 0.2279 - val_r2_score: 0.5156
Epoch 85/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2523 - r2_score: 0.5139 - val_loss: 0.2291 - val_r2_score: 0.5147
Epoch 86/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2496 - r2_score: 0.5155 - val_loss: 0.2283 - val_r2_score: 0.5154
Epoch 87/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2519 - r2_score: 0.5132 - val_loss: 0.2281 - val_r2_score: 0.5154
Epoch 88/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2490 - r2_score: 0.5145 - val_loss: 0.2277 - val_r2_score: 0.5162
Epoch 89/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2555 - r2_score: 0.5140 - val_loss: 0.2291 - val_r2_score: 0.5147
Epoch 90/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2477 - r2_score: 0.5149 - val_loss: 0.2271 - val_r2_score: 0.5167
Epoch 91/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2541 - r2_score: 0.5152 - val_loss: 0.2278 - val_r2_score: 0.5160
Epoch 92/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2542 - r2_score: 0.5143 - val_loss: 0.2272 - val_r2_score: 0.5166
Epoch 93/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2494 - r2_score: 0.5141 - val_loss: 0.2267 - val_r2_score: 0.5171
Epoch 94/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 372s 76ms/step - loss: 0.2530 - r2_score: 0.5133 - val_loss: 0.2270 - val_r2_score: 0.5168
Epoch 95/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2519 - r2_score: 0.5155 - val_loss: 0.2273 - val_r2_score: 0.5165
Epoch 96/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2486 - r2_score: 0.5148 - val_loss: 0.2272 - val_r2_score: 0.5166
Epoch 97/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2526 - r2_score: 0.5141 - val_loss: 0.2279 - val_r2_score: 0.5161
Epoch 98/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2514 - r2_score: 0.5162 - val_loss: 0.2269 - val_r2_score: 0.5172
Epoch 99/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 373s 76ms/step - loss: 0.2534 - r2_score: 0.5155 - val_loss: 0.2268 - val_r2_score: 0.5169
Epoch 100/100
4883/4883 ━━━━━━━━━━━━━━━━━━━━ 372s 76ms/step - loss: 0.2465 - r2_score: 0.5149 - val_loss: 0.2265 - val_r2_score: 0.5174
'''

In [None]:
if TRAIN:
    # Training (blue) and validation (red) loss per epoch on a log scale.
    for curve, colour in (('loss', 'tab:blue'), ('val_loss', 'tab:red')):
        plt.plot(history.history[curve], color=colour)
    plt.yscale('log');

In [None]:
# Ground-truth targets of the validation shard, in dataset order.
y_valid = np.concatenate([yb for _, yb in ds_valid])
# Predictions are produced in standardized units and mapped back with the
# target statistics. `ds_valid` is already batched, so no `batch_size` is
# passed: Keras documents that `batch_size` must not be given for dataset inputs.
p_valid = model.predict(ds_valid) * stdd_y + mean_y

In [None]:
# R2 of every target column on the validation shard; the plot clips the
# scores to [-1, 1] for readability only.
scores_valid = np.array([
    metrics.r2_score(y_valid[:, column], p_valid[:, column])
    for column, _ in enumerate(TARGETS)
])
plt.plot(scores_valid.clip(-1, 1))

In [None]:
# Targets whose validation R2 is at most 1e-3; their predictions are later
# replaced by the target mean.
mask = scores_valid <= 1e-3
n_underperforming = mask.sum()
f"Number of under-performing targets: {n_underperforming}"

In [None]:
# Mean validation R2 after clipping every target's score to [0, 1].
clipped_score = scores_valid.clip(0, 1).mean()
f"Clipped score: {clipped_score}"

In [None]:
# Drop the validation arrays and collect garbage before loading the test set.
del y_valid
del p_valid
gc.collect();

# Submission

In [None]:
# Full sample submission (all rows this time). The path is built from the DATA
# constant instead of repeating the absolute directory.
sample = pl.read_csv(os.path.join(DATA, "sample_submission.csv"))

In [None]:
# Test features as float32, without the identifier column. The path is built
# from the DATA constant instead of repeating the absolute directory.
df_test = (
    pl.scan_csv(os.path.join(DATA, "test.csv"))
    .select(pl.exclude("sample_id"))
    .cast(pl.Float32)
    .collect()
)

In [None]:
# Predict in standardized units, map back with the target statistics, and
# materialize the result as a NumPy array so it can be modified in place.
p_test = np.array(
    model.predict(df_test.to_numpy(), batch_size=4 * BATCH_SIZE) * stdd_y + mean_y
)
# Targets that under-performed on validation fall back to the target mean.
p_test[:, mask] = mean_y[:, mask]

In [None]:
# Override the ptend_q0002 targets for indices 12..29 (inclusive): each one is
# set to minus the matching state_q0002 input divided by 1200.
df_p_test = pd.DataFrame(p_test, columns=TARGETS)

for level in range(12, 30):
    state_column = f"state_q0002_{level}"
    target_column = f"ptend_q0002_{level}"
    df_p_test[target_column] = -df_test[state_column].to_numpy() / 1200

p_test = df_p_test.values

In [None]:
# The values already present in the sample submission are multiplied
# element-wise by the predictions.
# NOTE(review): presumably those values are per-target weights -- confirm.
submission = sample.to_pandas()
scaled_predictions = submission[TARGETS] * p_test
submission[TARGETS] = scaled_predictions

output_columns = ["sample_id"] + TARGETS
pl.from_pandas(submission[output_columns]).write_csv("submission.csv")