### Import libraries

In [228]:
import os
import time

import keras
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from keras.callbacks import ModelCheckpoint

from data import batch_creator, data_combine, data_generation, gp_kernels
from helpers import helpers, masks
from inference import infer
from model import climate_model, dot_prod_attention, losses
np.random.seed(1)

In [216]:
# Root directory for TensorBoard logs and checkpoints. Set the GPT_CLIMATE_DIR
# environment variable to relocate it instead of editing this absolute local path.
save_dir = os.environ.get('GPT_CLIMATE_DIR', '/Users/omernivron/Downloads/GPT_climate')

In [187]:
# Read the 1989-2019 monthly-averaged t2m ensemble CSV and convert it into model inputs.
climate_csv_path = './data/t2m_monthly_averaged_ensemble_members_1989_2019.csv'
temp, t, token = data_combine.climate_data_to_model_input(climate_csv_path)

In [190]:
## Create the climate train/test split: the first 8000 rows are used for training, the remaining rows for testing.

In [210]:
# Split every model input at the same row index: rows [0, N_TRAIN) train, the rest test.
N_TRAIN = 8000
# Guard against silently producing an empty test split if the data set is small.
assert 0 < N_TRAIN < temp.shape[0], 'N_TRAIN must leave a non-empty test split'
time_tr, temp_tr, token_tr = t[:N_TRAIN], temp[:N_TRAIN], token[:N_TRAIN]
time_te, temp_te, token_te = t[N_TRAIN:], temp[N_TRAIN:], token[N_TRAIN:]

In [234]:
# MSE loss object. NOTE(review): appears unused in the visible cells, which
# compute the loss via losses.loss_function instead.
loss_object = tf.keras.losses.MeanSquaredError()

# Named means of the scalar loss returned by losses.loss_function
# (train_loss is created first, then test_loss).
train_loss, test_loss = (tf.keras.metrics.Mean(name=metric_name)
                         for metric_name in ('train_loss', 'test_loss'))

# Unnamed means updated as update_state(mse, mask) by train_step / test_step,
# i.e. the mask is passed as the sample_weight argument.
m_tr, m_te = (tf.keras.metrics.Mean() for _ in range(2))

In [235]:
@tf.function
def train_step(decoder, optimizer_c, token_pos, time_pos, tar, pos_mask):
    '''
    Run one optimisation step of `decoder` on a single batch.

    The targets are shifted by one position (teacher forcing): the decoder is
    fed tar[:, :-1] and scored against tar[:, 1:]. The forward pass and loss are
    recorded on a GradientTape, and the resulting gradients are applied with
    `optimizer_c`. See (1) https://www.tensorflow.org/guide/migrate
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    decoder: model called as decoder(token_pos, time_pos, tar_inp, training, pos_mask, tar_mask)
             and returning (pred, pred_sig)
    optimizer_c: Keras optimizer used to apply the gradients
    token_pos (array): token positions of the batch (1st output of batch_creator.create_batch_foxes)
    time_pos (array): time positions of the batch (2nd output of batch_creator.create_batch_foxes)
    tar (array): targets, indexed as (batch, sequence) — exact shape TODO confirm
    pos_mask (array): mask produced by masks.position_mask(token_pos)
    ------------------
    Returns:
    tar_inp, tar_real, pred, pred_sig, mask — `mask` is the third output of
    losses.loss_function (presumably marks the valid target positions; TODO confirm)
    Side effects: updates the module-level train_loss and m_tr metrics.
    '''
    # Shift by one: inputs are positions 0..n-2, targets are positions 1..n-1.
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    # Only a single tape.gradient call is made, so a non-persistent tape is enough
    # and its resources are released right after the gradient is computed.
    with tf.GradientTape() as tape:
        # The 4th positional argument is the training flag.
        pred, pred_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
        loss, mse, mask = losses.loss_function(tar_real, pred, pred_sig)

    gradients = tape.gradient(loss, decoder.trainable_variables)
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    m_tr.update_state(mse, mask)
    return tar_inp, tar_real, pred, pred_sig, mask

In [236]:
@tf.function
def test_step(decoder, token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    Evaluate `decoder` on one held-out batch without updating its weights.

    Uses the same shift-by-one targets as train_step (input tar_te[:, :-1],
    scored against tar_te[:, 1:]), but calls the decoder with the training flag
    set to False and takes no gradients.
    ---------------
    Parameters:
    decoder: model called as decoder(token_pos, time_pos, tar_inp, training, pos_mask, tar_mask)
             and returning (pred, pred_sig)
    token_pos_te (array): token positions of the test batch
    time_pos_te (array): time positions of the test batch
    tar_te (array): test targets, indexed as (batch, sequence) — exact shape TODO confirm
    pos_mask_te (array): mask produced by masks.position_mask(token_pos_te)
    ---------------
    Returns:
    tar_real_te, pred_te, pred_sig_te, t_mask
    Side effects: updates the module-level test_loss and m_te metrics.
    '''
    tar_inp_te = tar_te[:, :-1]
    tar_real_te = tar_te[:, 1:]
    combined_mask_tar_te = masks.create_masks(tar_inp_te)
    # training=False only matters for layers that behave differently during
    # training versus inference (e.g. Dropout).
    pred_te, pred_sig_te = decoder(token_pos_te, time_pos_te, tar_inp_te, False, pos_mask_te, combined_mask_tar_te)
    t_loss, t_mse, t_mask = losses.loss_function(tar_real_te, pred_te, pred_sig_te)
    test_loss(t_loss)
    m_te.update_state(t_mse, t_mask)
    return tar_real_te, pred_te, pred_sig_te, t_mask

In [237]:
# Make float64 the default Keras float type. The decoder is constructed in the
# training cell below, after this call.
# NOTE(review): the loss/metric objects were created in an earlier cell, before
# this call — they may keep the previous default dtype; consider moving this
# into the configuration cell at the top.
keras_float_dtype = 'float64'
tf.keras.backend.set_floatx(keras_float_dtype)

In [None]:
if __name__ == '__main__':
    # Training configuration.
    EPOCHS = 500
    batch_s = 32
    run = 0
    # Number of batches per epoch: floor(number of training rows / batch size).
    num_batches = temp_tr.shape[0] // batch_s

    # Seed TensorFlow before the model and optimizer exist so the run is reproducible.
    tf.random.set_seed(1)
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam()
    decoder = climate_model.Decoder(16)

    # Checkpoint the step counter, optimizer state and network weights.
    # https://www.tensorflow.org/guide/checkpoint
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer_c, net=decoder)
    main_folder = save_dir + '/ckpt/check_'
    folder = main_folder + str(run)
    helpers.mkdir(folder)
    manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    # ckpt.step starts at 1 and is advanced together with `step` below, so this
    # yields 0 on a fresh start and, after a restore, continues the step numbering
    # passed to helpers.tf_summaries instead of restarting it at 0.
    step = int(ckpt.step.numpy()) - 1

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                # Metrics are reset every batch, so the values printed/logged below
                # describe the current training batch and one test batch, not running averages.
                m_tr.reset_states()
                train_loss.reset_states()
                m_te.reset_states()
                test_loss.reset_states()
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, _ = batch_creator.create_batch_foxes(
                    token_tr, time_tr, temp_tr, batch_s=batch_s)
                # Batch arrays are presumably (batch_s, max_seq_len) — TODO confirm.
                batch_pos_mask = masks.position_mask(batch_tok_pos_tr)
                tar_inp, tar_real, pred, pred_sig, mask = train_step(
                    decoder, optimizer_c, batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                # Every 10th batch: evaluate one test batch, report, log and checkpoint.
                if batch_n % 10 == 0:
                    batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, _ = batch_creator.create_batch_foxes(
                        token_te, time_te, temp_te, batch_s=batch_s)
                    batch_pos_mask_te = masks.position_mask(batch_tok_pos_te)
                    tar_real_te, pred_te, pred_sig_te, t_mask = test_step(
                        decoder, batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
                    helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
                    manager.save()
                step += 1
                ckpt.step.assign_add(1)

            print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Restored from /Users/omernivron/Downloads/GPT_climate/ckpt/check_0/ckpt-12500
Epoch 0 batch 0 train Loss 5.2739 test Loss 5.3960 with MSE metric 7004.3525
Epoch 0 batch 10 train Loss 5.3193 test Loss 5.2997 with MSE metric 7671.2256
Epoch 0 batch 20 train Loss 5.2761 test Loss 5.3289 with MSE metric 7031.9185
Epoch 0 batch 30 train Loss 5.2238 test Loss 5.3786 with MSE metric 6280.2842
Epoch 0 batch 40 train Loss 5.3137 test Loss 5.2315 with MSE metric 7581.8154
Epoch 0 batch 50 train Loss 5.3869 test Loss 5.3905 with MSE metric 8709.8418
Epoch 0 batch 60 train Loss 5.2766 test Loss 5.4403 with MSE metric 7043.8154
Epoch 0 batch 70 train Loss 5.2409 test Loss 5.3325 with MSE metric 6418.5020
Epoch 0 batch 80 train Loss 5.3103 test Loss 5.4162 with MSE metric 7533.3545
Epoch 0 batch 90 train Loss 5.2801 test Loss 5.4246 with MSE metric 7092.9111
Epoch 0 batch 100 train Loss 5.3501 test Loss 5.3274 with MSE metric 8150.3296
Epoch 0 batch 110 train Loss 5.3044 test Loss 5.3

Epoch 4 batch 10 train Loss 5.2046 test Loss 5.3782 with MSE metric 5885.1895
Epoch 4 batch 20 train Loss 5.2004 test Loss 5.4595 with MSE metric 5833.0474
Epoch 4 batch 30 train Loss 5.3310 test Loss 5.3161 with MSE metric 7845.3711
Epoch 4 batch 40 train Loss 5.2773 test Loss 5.3506 with MSE metric 7018.6499
Epoch 4 batch 50 train Loss 5.2539 test Loss 5.2978 with MSE metric 6723.7568
Epoch 4 batch 60 train Loss 5.3238 test Loss 5.3999 with MSE metric 7739.4102
Epoch 4 batch 70 train Loss 5.3751 test Loss 5.3683 with MSE metric 8478.8232
Epoch 4 batch 80 train Loss 5.3449 test Loss 5.4134 with MSE metric 8035.2354
Epoch 4 batch 90 train Loss 5.2883 test Loss 5.3809 with MSE metric 7212.1602
Epoch 4 batch 100 train Loss 5.3226 test Loss 5.4721 with MSE metric 7688.2002
Epoch 4 batch 110 train Loss 5.2845 test Loss 5.4208 with MSE metric 7059.3350
Epoch 4 batch 120 train Loss 5.3521 test Loss 5.3900 with MSE metric 8157.3882
Epoch 4 batch 130 train Loss 5.3484 test Loss 5.4910 with MSE

Epoch 8 batch 30 train Loss 5.3523 test Loss 5.3570 with MSE metric 8130.1357
Epoch 8 batch 40 train Loss 5.3379 test Loss 5.4210 with MSE metric 7960.9014
Epoch 8 batch 50 train Loss 5.2990 test Loss 5.3978 with MSE metric 7366.9414
Epoch 8 batch 60 train Loss 5.4245 test Loss 5.4119 with MSE metric 9181.5176
Epoch 8 batch 70 train Loss 5.4154 test Loss 5.3382 with MSE metric 9231.1846
Epoch 8 batch 80 train Loss 5.3191 test Loss 5.3831 with MSE metric 7659.1997
Epoch 8 batch 90 train Loss 5.2329 test Loss 5.4139 with MSE metric 6442.1250
Epoch 8 batch 100 train Loss 5.3064 test Loss 5.2479 with MSE metric 7468.5527
Epoch 8 batch 110 train Loss 5.2248 test Loss 5.3372 with MSE metric 6130.6709
Epoch 8 batch 120 train Loss 5.4192 test Loss 5.3596 with MSE metric 9036.5684
Epoch 8 batch 130 train Loss 5.2147 test Loss 5.3605 with MSE metric 5992.6211
Epoch 8 batch 140 train Loss 5.4226 test Loss 5.2635 with MSE metric 9048.9121
Epoch 8 batch 150 train Loss 5.3399 test Loss 5.3707 with M

Epoch 12 batch 50 train Loss 5.2323 test Loss 5.4770 with MSE metric 6420.9409
Epoch 12 batch 60 train Loss 5.2427 test Loss 5.3771 with MSE metric 6472.5073
Epoch 12 batch 70 train Loss 5.3223 test Loss 5.3763 with MSE metric 7712.1934
Epoch 12 batch 80 train Loss 5.4112 test Loss 5.3833 with MSE metric 9044.3457
Epoch 12 batch 90 train Loss 5.2680 test Loss 5.4529 with MSE metric 6919.2607
Epoch 12 batch 100 train Loss 5.2028 test Loss 5.4810 with MSE metric 5981.0146
Epoch 12 batch 110 train Loss 5.3260 test Loss 5.4718 with MSE metric 7757.2139
Epoch 12 batch 120 train Loss 5.3832 test Loss 5.2961 with MSE metric 8608.3809
Epoch 12 batch 130 train Loss 5.2799 test Loss 5.3087 with MSE metric 7036.4336
Epoch 12 batch 140 train Loss 5.3311 test Loss 5.3596 with MSE metric 7788.6904
Epoch 12 batch 150 train Loss 5.3213 test Loss 5.3949 with MSE metric 7667.3364
Epoch 12 batch 160 train Loss 5.2538 test Loss 5.3747 with MSE metric 6689.4326
Epoch 12 batch 170 train Loss 5.2982 test Los

Epoch 16 batch 60 train Loss 5.3081 test Loss 5.4167 with MSE metric 7501.0283
Epoch 16 batch 70 train Loss 5.3401 test Loss 5.3055 with MSE metric 7996.6289
Epoch 16 batch 80 train Loss 5.2296 test Loss 5.3125 with MSE metric 6339.5635
Epoch 16 batch 90 train Loss 5.3085 test Loss 5.2942 with MSE metric 7493.1260
Epoch 16 batch 100 train Loss 5.2698 test Loss 5.3992 with MSE metric 6927.7588
Epoch 16 batch 110 train Loss 5.2844 test Loss 5.3589 with MSE metric 7150.3516
Epoch 16 batch 120 train Loss 5.3036 test Loss 5.3022 with MSE metric 7426.6602
Epoch 16 batch 130 train Loss 5.2113 test Loss 5.3171 with MSE metric 6098.0010
Epoch 16 batch 140 train Loss 5.3042 test Loss 5.5114 with MSE metric 7410.3408
Epoch 16 batch 150 train Loss 5.2805 test Loss 5.3687 with MSE metric 7014.1953
Epoch 16 batch 160 train Loss 5.2917 test Loss 5.3728 with MSE metric 7257.1143
Epoch 16 batch 170 train Loss 5.2513 test Loss 5.4339 with MSE metric 6672.2949
Epoch 16 batch 180 train Loss 5.3040 test Lo

Epoch 20 batch 70 train Loss 5.3765 test Loss 5.4545 with MSE metric 8501.7148
Epoch 20 batch 80 train Loss 5.3774 test Loss 5.3016 with MSE metric 8579.6816
Epoch 20 batch 90 train Loss 5.2769 test Loss 5.3781 with MSE metric 6971.7314
Epoch 20 batch 100 train Loss 5.2834 test Loss 5.2555 with MSE metric 7140.6270
Epoch 20 batch 110 train Loss 5.3379 test Loss 5.4733 with MSE metric 7902.1211
Epoch 20 batch 120 train Loss 5.3165 test Loss 5.2908 with MSE metric 7629.6182
Epoch 20 batch 130 train Loss 5.3233 test Loss 5.4684 with MSE metric 7729.4155
Epoch 20 batch 140 train Loss 5.3484 test Loss 5.3744 with MSE metric 8063.7632
Epoch 20 batch 150 train Loss 5.2689 test Loss 5.3994 with MSE metric 6908.4932
Epoch 20 batch 160 train Loss 5.4169 test Loss 5.4031 with MSE metric 8961.4062
Epoch 20 batch 170 train Loss 5.2086 test Loss 5.2898 with MSE metric 5927.1987
Epoch 20 batch 180 train Loss 5.3014 test Loss 5.3365 with MSE metric 7355.6338
Epoch 20 batch 190 train Loss 5.3013 test L

Epoch 24 batch 80 train Loss 5.3227 test Loss 5.3723 with MSE metric 7722.0688
Epoch 24 batch 90 train Loss 5.2788 test Loss 5.4071 with MSE metric 7052.8052
Epoch 24 batch 100 train Loss 5.3226 test Loss 5.2790 with MSE metric 7625.8306
Epoch 24 batch 110 train Loss 5.3138 test Loss 5.3249 with MSE metric 7587.9473
Epoch 24 batch 120 train Loss 5.2280 test Loss 5.3580 with MSE metric 6222.6162
Epoch 24 batch 130 train Loss 5.2136 test Loss 5.5158 with MSE metric 6096.2202
Epoch 24 batch 140 train Loss 5.3582 test Loss 5.3409 with MSE metric 8267.5576
Epoch 24 batch 150 train Loss 5.2662 test Loss 5.4692 with MSE metric 6897.3403
Epoch 24 batch 160 train Loss 5.2827 test Loss 5.3919 with MSE metric 7130.6953
Epoch 24 batch 170 train Loss 5.3766 test Loss 5.3910 with MSE metric 8543.7656
Epoch 24 batch 180 train Loss 5.3430 test Loss 5.3930 with MSE metric 8042.4297
Epoch 24 batch 190 train Loss 5.3402 test Loss 5.2976 with MSE metric 7900.7344
Epoch 24 batch 200 train Loss 5.3143 test 

Epoch 28 batch 90 train Loss 5.3966 test Loss 5.2274 with MSE metric 8701.2422
Epoch 28 batch 100 train Loss 5.4307 test Loss 5.2877 with MSE metric 9476.3291
Epoch 28 batch 110 train Loss 5.3141 test Loss 5.4105 with MSE metric 7590.2349
Epoch 28 batch 120 train Loss 5.3634 test Loss 5.4162 with MSE metric 8325.0186
Epoch 28 batch 130 train Loss 5.3004 test Loss 5.4366 with MSE metric 7370.6855
Epoch 28 batch 140 train Loss 5.3311 test Loss 5.4390 with MSE metric 7856.0454
Epoch 28 batch 150 train Loss 5.3819 test Loss 5.3760 with MSE metric 8464.3955
Epoch 28 batch 160 train Loss 5.2506 test Loss 5.4584 with MSE metric 6634.3037
Epoch 28 batch 170 train Loss 5.3849 test Loss 5.3949 with MSE metric 8690.5293
Epoch 28 batch 180 train Loss 5.3343 test Loss 5.3440 with MSE metric 7875.7554
Epoch 28 batch 190 train Loss 5.2767 test Loss 5.3288 with MSE metric 6973.7007
Epoch 28 batch 200 train Loss 5.1749 test Loss 5.2244 with MSE metric 5623.4897
Epoch 28 batch 210 train Loss 5.2180 test

In [None]:
# Coefficient of determination for column 5 of `tar`: R^2 = 1 - MSE / Var(y),
# where Var(y) is the population variance (mean squared deviation from the mean).
# The previous expression `0.0165 / sum(...) / len(...)` divided by sum(...) * len(...)
# because `/` is left-associative; the variance must be formed before dividing.
# NOTE(review): 0.0165 looks like an MSE copied by hand from a training run — confirm its source.
# NOTE(review): no earlier visible cell defines `tar` at top level; this relies on hidden kernel state.
mse_value = 0.0165
1 - mse_value / np.var(tar[:, 5])

In [None]:
# Subtract each column's mean from `tar` (column-centring); displays the whole result.
# NOTE(review): `tar` is not defined at top level by any earlier visible cell (it is a
# parameter of train_step); a later cell assigns it from df_te. Depends on hidden kernel state.
tar - np.mean(tar, 0)

In [None]:
# Inspect the shape of `tar`.
# NOTE(review): relies on `tar` already existing in the kernel; no earlier visible cell defines it.
tar.shape

In [None]:
# Mean of column 0 of `tar`.
# NOTE(review): relies on `tar` already existing in the kernel; no earlier visible cell defines it.
np.mean(tar[:, 0])

In [None]:
# Population variance of column 0 of `tar` (mean squared deviation from the column mean).
# This divides by the actual number of rows; the previous hard-coded divisor of 10000
# was only correct when `tar` had exactly 10000 rows.
# NOTE(review): relies on `tar` already existing in the kernel; no earlier visible cell defines it.
np.var(tar[:, 0])

In [None]:
# Population variance over all entries of the 2-D array `tar`: total squared deviation
# from the overall mean divided by rows * columns.
# NOTE(review): equivalent to np.var(tar) up to floating-point summation order.
# NOTE(review): relies on `tar` already existing in the kernel; no earlier visible cell defines it.
sum(sum((tar - np.mean(tar))**2)) / (tar.shape[0] * tar.shape[1])

In [None]:
# Take row 560 of `df_te` as a single-row (1, n) array; used as the x positions in the plot below.
# NOTE(review): `df_te` is never defined in the visible cells (the split cell creates
# time_te / temp_te / token_te instead) — this depends on hidden kernel state.
pos = df_te[560, :].reshape(1, -1)

In [None]:
# Take the first 39 values of row 561 of `df_te` as a single-row array; presumably the
# observed context values matching the positions in row 560 — TODO confirm.
# NOTE(review): `df_te` is never defined in the visible cells — depends on hidden kernel state.
tar = df_te[561, :39].reshape(1, -1)

In [None]:
# Display the full row 561 of `df_te`.
# NOTE(review): `df_te` is never defined in the visible cells — depends on hidden kernel state.
df_te[561, :]

In [None]:
# Run the model's inference routine on the positions and observed context; 20 is
# presumably the number of additional points to generate — TODO confirm.
# NOTE(review): the import cell only binds `infer` (`from inference import infer`), so the
# name `inference` is unbound on a fresh kernel; perhaps a function in `infer` was intended.
a = inference(pos, tar, 20)

In [None]:
# Compare the observed context, the model output and the ground truth for test row 561.
with matplotlib.rc_context({'figure.figsize': [10, 2.5]}):
    fig, ax = plt.subplots()
    # Observed context: the first 39 points.
    ax.scatter(pos[:, :39], tar[:, :39], c='black', label='context')
    # Model output at positions 39..57. Slice `a` with the same bounds as `pos` and the
    # ground truth so x and y always have matching lengths.
    ax.scatter(pos[:, 39:58], a[39:58], label='prediction')
    # Ground-truth values at the same positions.
    ax.scatter(pos[:, 39:58], df_te[561, 39:58], c='red', label='ground truth')
    ax.set(xlabel='position', ylabel='target', title='Inference on test row 561')
    ax.legend()

In [None]:
# tf.data.Dataset(tf.Tensor(pad_pos_tr, value_index = 0 , dtype = tf.float32))