### Import libraries

In [1]:
from model import climate_model, losses, dot_prod_attention
from data import data_generation, data_combine, batch_creator, gp_kernels
from keras.callbacks import ModelCheckpoint
from helpers import helpers, masks
from inference import infer
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib 
import time
import keras
np.random.seed(1)

Using TensorFlow backend.


In [216]:
# Root output directory; the training cell writes TensorBoard logs to save_dir + '/logs/'.
# NOTE(review): absolute, machine-specific path - change it before running on another machine.
save_dir = '/Users/omernivron/Downloads/GPT_climate'

In [187]:
# Read the climate CSV and convert it into three model-input arrays.
# Below, `temp` is used as the target, `t` as the time positions and `token` as the token positions.
# NOTE(review): the path is relative to the notebook's working directory.
temp, t, token = data_combine.climate_data_to_model_input('./data/t2m_monthly_averaged_ensemble_members_1989_2019.csv')

In [190]:
## create climate train/test split

In [210]:
# Sequential split: the first n_train rows are used for training, the rest for testing.
# A single named split index keeps the six slices in agreement with each other.
n_train = 8000
time_tr, temp_tr, token_tr = t[:n_train], temp[:n_train], token[:n_train]
time_te, temp_te, token_te = t[n_train:], temp[n_train:], token[n_train:]

In [281]:
# Keras mean-squared-error loss object.
# NOTE(review): train_step/test_step in this notebook call losses.loss_function instead and never
# reference loss_object - confirm whether it is still needed.
loss_object = tf.keras.losses.MeanSquaredError()
# Running means of the loss values passed in by train_step / test_step.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
# Running means of the MSE term; train_step / test_step pass the loss mask as the sample weight.
m_tr = tf.keras.metrics.Mean()
m_te = tf.keras.metrics.Mean()

In [282]:
@tf.function
def train_step(decoder, optimizer_c, token_pos, time_pos, tar, pos_mask):
    '''
    Run one optimisation step of the decoder on a single batch (TF2 custom training loop).
    Everything whose gradient is needed has to be computed inside the GradientTape() clause.
    see (1) https://www.tensorflow.org/guide/migrate
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    decoder: the model; called as decoder(token_pos, time_pos, tar_inp, True, pos_mask, mask) and
             expected to return (pred, pred_log_sig) and to expose trainable_variables
    optimizer_c: optimizer whose apply_gradients() updates the decoder variables
    token_pos (np array): token positions of the batch
    time_pos (np array): time positions of the batch
    tar (np array): array of targets. It is split into the model input tar[:, :-1] and the
                    one-step-shifted targets tar[:, 1:] that the predictions are scored against
    pos_mask (np array): see description in position_mask function
    ------------------
    Returns:
    tar_inp, tar_real, pred, pred_log_sig, mask - the input/target split, the two decoder outputs
    and the mask returned by losses.loss_function
    ------------------
    Side effects:
    updates the decoder variables and the module-level metrics train_loss and m_tr
    '''
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    # Only one gradient is taken from this tape, so it does not need to be persistent;
    # a non-persistent tape frees what it recorded as soon as tape.gradient() has run.
    with tf.GradientTape() as tape:
        # The 4th positional argument is True here and False in test_step
        # (presumably the Keras `training` flag - confirm against climate_model.Decoder).
        pred, pred_log_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
        loss, mse, mask = losses.loss_function(tar_real, pred, pred_log_sig)

    gradients = tape.gradient(loss, decoder.trainable_variables)
    # Ask the optimizer to apply the gradients.
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    # The mask is passed as the sample weight of the running MSE mean.
    m_tr.update_state(mse, mask)
    return tar_inp, tar_real, pred, pred_log_sig, mask

In [283]:
@tf.function
def test_step(decoder, token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    Score the decoder on one held-out batch; no weights are updated.
    ---------------
    Parameters:
    decoder: the model; called as decoder(token_pos_te, time_pos_te, inputs, False, pos_mask_te, mask)
             and expected to return a (prediction, log-sigma) pair
    token_pos_te (np array): token positions of the test batch
    time_pos_te (np array): time positions of the test batch
    tar_te (np array): array of targets; all steps but the last are fed to the model and
                       all steps but the first are what the predictions are compared with
    pos_mask_te (np array): see description in position_mask function
    ---------------
    Returns:
    the shifted targets, the two decoder outputs and the mask returned by losses.loss_function
    ---------------
    Side effects:
    updates the module-level metrics test_loss and m_te
    '''
    model_input = tar_te[:, :-1]
    expected = tar_te[:, 1:]
    target_mask = masks.create_masks(model_input)
    # The 4th positional argument is False here; a non-training flag only matters for layers
    # that behave differently during training and inference (e.g. Dropout).
    predicted, predicted_log_sig = decoder(
        token_pos_te, time_pos_te, model_input, False, pos_mask_te, target_mask
    )
    batch_loss, batch_mse, batch_mask = losses.loss_function(expected, predicted, predicted_log_sig)
    test_loss(batch_loss)
    # The mask is passed as the sample weight of the running MSE mean.
    m_te.update_state(batch_mse, batch_mask)
    return expected, predicted, predicted_log_sig, batch_mask

In [284]:
# Make float64 the default Keras float type.
# NOTE(review): this cell comes after the metric objects above are constructed; presumably it
# should run before any Keras object is created - consider moving it next to the imports.
tf.keras.backend.set_floatx('float64')

In [285]:
if __name__ == '__main__':
    # Train the decoder: every batch runs one train_step; every 100th batch of an epoch also
    # evaluates one test batch, prints/logs the metrics and saves a checkpoint.
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam(0.0003)
    decoder = climate_model.Decoder(16)
    EPOCHS = 500
    batch_s = 128
    run = 0
    # Floor division: only whole batches are counted.
    num_batches = temp_tr.shape[0] // batch_s
    tf.random.set_seed(1)
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer_c, net=decoder)
    # Checkpoints live under the same root as the logs, so only save_dir has to be edited
    # when the notebook is moved to another machine.
    main_folder = save_dir + '/ckpt/check_'
    folder = main_folder + str(run); helpers.mkdir(folder)
    # https://www.tensorflow.org/guide/checkpoint
    manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    # ckpt.step starts at 1 and is incremented once per batch together with `step`, so it is
    # always step + 1. Deriving `step` from it gives 0 on a fresh run and continues the count
    # after a restore, so resumed runs do not write summaries over the earlier steps.
    step = int(ckpt.step.numpy()) - 1

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                # Metrics are reset on every batch, so .result() below reflects the current batch only.
                m_tr.reset_states(); train_loss.reset_states()
                m_te.reset_states(); test_loss.reset_states()
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, _ = batch_creator.create_batch_foxes(token_tr, time_tr, temp_tr, batch_s=batch_s)
                # batch_tar_tr / batch_tok_pos_tr shape := (batch_size, max_seq_len) - TODO confirm
                batch_pos_mask = masks.position_mask(batch_tok_pos_tr)
                tar_inp, tar_real, pred, pred_log_sig, mask = train_step(decoder, optimizer_c, batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                if batch_n % 100 == 0:
                    batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, _ = batch_creator.create_batch_foxes(token_te, time_te, temp_te, batch_s=batch_s)
                    batch_pos_mask_te = masks.position_mask(batch_tok_pos_te)
                    # The test outputs are kept in module-level names so later cells can inspect them.
                    tar_real_te, pred_te, pred_log_sig_te, t_mask = test_step(decoder, batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
                    helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
                    manager.save()
                step += 1
                ckpt.step.assign_add(1)

            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Restored from /Users/omernivron/Downloads/GPT_climate/ckpt/check_0/ckpt-21073
Epoch 0 batch 0 train Loss 5.2659 test Loss 5.3509 with MSE metric 6894.6094
Time taken for 1 epoch: 31.018823385238647 secs

Epoch 1 batch 0 train Loss 5.3307 test Loss 5.3369 with MSE metric 7835.0283
Time taken for 1 epoch: 28.722605228424072 secs

Epoch 2 batch 0 train Loss 5.3091 test Loss 5.3692 with MSE metric 7518.0879
Time taken for 1 epoch: 30.544169902801514 secs

Epoch 3 batch 0 train Loss 5.3132 test Loss 5.3349 with MSE metric 7579.1050
Time taken for 1 epoch: 29.89436411857605 secs

Epoch 4 batch 0 train Loss 5.2974 test Loss 5.3666 with MSE metric 7339.3164
Time taken for 1 epoch: 29.931660652160645 secs

Epoch 5 batch 0 train Loss 5.3256 test Loss 5.3490 with MSE metric 7759.5596
Time taken for 1 epoch: 29.423455953598022 secs

Epoch 6 batch 0 train Loss 5.3310 test Loss 5.3658 with MSE metric 7828.8369
Time taken for 1 epoch: 29.46803903579712 secs

Epoch 7 batch 0 train Loss 

Time taken for 1 epoch: 33.79659104347229 secs

Epoch 65 batch 0 train Loss 5.3253 test Loss 5.3724 with MSE metric 7761.4111
Time taken for 1 epoch: 33.87038278579712 secs

Epoch 66 batch 0 train Loss 5.2509 test Loss 5.3747 with MSE metric 6678.9521
Time taken for 1 epoch: 34.280285120010376 secs

Epoch 67 batch 0 train Loss 5.2947 test Loss 5.3625 with MSE metric 7303.0029
Time taken for 1 epoch: 34.226808071136475 secs

Epoch 68 batch 0 train Loss 5.3557 test Loss 5.3580 with MSE metric 8235.4824
Time taken for 1 epoch: 34.44808387756348 secs

Epoch 69 batch 0 train Loss 5.2454 test Loss 5.3364 with MSE metric 6590.9746
Time taken for 1 epoch: 34.71144676208496 secs

Epoch 70 batch 0 train Loss 5.2392 test Loss 5.3947 with MSE metric 6440.5928
Time taken for 1 epoch: 34.59360694885254 secs

Epoch 71 batch 0 train Loss 5.3005 test Loss 5.3160 with MSE metric 7389.2988
Time taken for 1 epoch: 34.09790802001953 secs

Epoch 72 batch 0 train Loss 5.3500 test Loss 5.3577 with MSE metric 

Time taken for 1 epoch: 33.73696303367615 secs

Epoch 130 batch 0 train Loss 5.3141 test Loss 5.3837 with MSE metric 7569.5996
Time taken for 1 epoch: 33.592840909957886 secs

Epoch 131 batch 0 train Loss 5.3119 test Loss 5.3808 with MSE metric 7560.3696
Time taken for 1 epoch: 33.674437046051025 secs

Epoch 132 batch 0 train Loss 5.2930 test Loss 5.3651 with MSE metric 7276.0117
Time taken for 1 epoch: 33.67853403091431 secs

Epoch 133 batch 0 train Loss 5.3433 test Loss 5.3540 with MSE metric 8027.8589
Time taken for 1 epoch: 33.61728262901306 secs

Epoch 134 batch 0 train Loss 5.3048 test Loss 5.3455 with MSE metric 7453.8911
Time taken for 1 epoch: 33.59117293357849 secs

Epoch 135 batch 0 train Loss 5.2646 test Loss 5.3740 with MSE metric 6857.3970
Time taken for 1 epoch: 33.04614281654358 secs

Epoch 136 batch 0 train Loss 5.2385 test Loss 5.3475 with MSE metric 6471.9102
Time taken for 1 epoch: 33.19486904144287 secs

Epoch 137 batch 0 train Loss 5.2940 test Loss 5.3690 with MSE

Time taken for 1 epoch: 32.93606114387512 secs

Epoch 195 batch 0 train Loss 5.3213 test Loss 5.3381 with MSE metric 7701.3604
Time taken for 1 epoch: 32.99523591995239 secs

Epoch 196 batch 0 train Loss 5.3201 test Loss 5.4093 with MSE metric 7682.8110
Time taken for 1 epoch: 33.00985312461853 secs

Epoch 197 batch 0 train Loss 5.3280 test Loss 5.3448 with MSE metric 7801.4521
Time taken for 1 epoch: 33.20161414146423 secs

Epoch 198 batch 0 train Loss 5.3055 test Loss 5.3698 with MSE metric 7451.8906
Time taken for 1 epoch: 33.01970100402832 secs

Epoch 199 batch 0 train Loss 5.2964 test Loss 5.3434 with MSE metric 7328.7271
Time taken for 1 epoch: 33.19156217575073 secs

Epoch 200 batch 0 train Loss 5.3085 test Loss 5.3917 with MSE metric 7485.5469
Time taken for 1 epoch: 32.95593595504761 secs

Epoch 201 batch 0 train Loss 5.2906 test Loss 5.3503 with MSE metric 7243.3066
Time taken for 1 epoch: 33.09499788284302 secs

Epoch 202 batch 0 train Loss 5.3319 test Loss 5.3685 with MSE m

Time taken for 1 epoch: 33.379859924316406 secs

Epoch 260 batch 0 train Loss 5.3226 test Loss 5.3249 with MSE metric 7722.3721
Time taken for 1 epoch: 33.25590991973877 secs

Epoch 261 batch 0 train Loss 5.3034 test Loss 5.3670 with MSE metric 7414.1768
Time taken for 1 epoch: 33.301154136657715 secs

Epoch 262 batch 0 train Loss 5.2910 test Loss 5.3193 with MSE metric 7246.1348
Time taken for 1 epoch: 33.03551697731018 secs

Epoch 263 batch 0 train Loss 5.3414 test Loss 5.3709 with MSE metric 7983.7114
Time taken for 1 epoch: 33.27055907249451 secs

Epoch 264 batch 0 train Loss 5.3201 test Loss 5.3102 with MSE metric 7657.6895
Time taken for 1 epoch: 32.89598894119263 secs

Epoch 265 batch 0 train Loss 5.2917 test Loss 5.3371 with MSE metric 7260.5352
Time taken for 1 epoch: 32.96501803398132 secs

Epoch 266 batch 0 train Loss 5.3118 test Loss 5.3651 with MSE metric 7553.8955
Time taken for 1 epoch: 33.28825068473816 secs

Epoch 267 batch 0 train Loss 5.3050 test Loss 5.3955 with MSE

Epoch 324 batch 0 train Loss 5.2766 test Loss 5.3067 with MSE metric 7039.6465
Time taken for 1 epoch: 27.46802282333374 secs

Epoch 325 batch 0 train Loss 5.2574 test Loss 5.3660 with MSE metric 6774.1328
Time taken for 1 epoch: 27.36406111717224 secs

Epoch 326 batch 0 train Loss 5.3108 test Loss 5.3470 with MSE metric 7542.7148
Time taken for 1 epoch: 27.33210802078247 secs

Epoch 327 batch 0 train Loss 5.3219 test Loss 5.3807 with MSE metric 7706.2109
Time taken for 1 epoch: 27.385365962982178 secs

Epoch 328 batch 0 train Loss 5.2749 test Loss 5.3798 with MSE metric 7014.1011
Time taken for 1 epoch: 27.40013098716736 secs

Epoch 329 batch 0 train Loss 5.3617 test Loss 5.3536 with MSE metric 8268.8066
Time taken for 1 epoch: 27.404680013656616 secs

Epoch 330 batch 0 train Loss 5.2785 test Loss 5.3822 with MSE metric 7070.4385
Time taken for 1 epoch: 27.43275499343872 secs

Epoch 331 batch 0 train Loss 5.2780 test Loss 5.3742 with MSE metric 7059.3291
Time taken for 1 epoch: 27.395

Time taken for 1 epoch: 27.09631609916687 secs

Epoch 389 batch 0 train Loss 5.3136 test Loss 5.4178 with MSE metric 7583.7886
Time taken for 1 epoch: 27.03241991996765 secs

Epoch 390 batch 0 train Loss 5.2998 test Loss 5.3530 with MSE metric 7377.3188
Time taken for 1 epoch: 26.982159852981567 secs

Epoch 391 batch 0 train Loss 5.2987 test Loss 5.3956 with MSE metric 7361.8965
Time taken for 1 epoch: 26.956141233444214 secs

Epoch 392 batch 0 train Loss 5.3177 test Loss 5.3762 with MSE metric 7646.4727
Time taken for 1 epoch: 27.08856439590454 secs

Epoch 393 batch 0 train Loss 5.2915 test Loss 5.3664 with MSE metric 7234.6675
Time taken for 1 epoch: 26.894054889678955 secs

Epoch 394 batch 0 train Loss 5.2709 test Loss 5.3892 with MSE metric 6951.7754
Time taken for 1 epoch: 27.004430055618286 secs

Epoch 395 batch 0 train Loss 5.2352 test Loss 5.3525 with MSE metric 6370.6357
Time taken for 1 epoch: 27.033483028411865 secs

Epoch 396 batch 0 train Loss 5.3258 test Loss 5.3356 with 

Epoch 453 batch 0 train Loss 5.3219 test Loss 5.3439 with MSE metric 7709.7695
Time taken for 1 epoch: 27.216262817382812 secs

Epoch 454 batch 0 train Loss 5.3025 test Loss 5.3443 with MSE metric 7418.3789
Time taken for 1 epoch: 27.065304040908813 secs

Epoch 455 batch 0 train Loss 5.3019 test Loss 5.3073 with MSE metric 7405.6260
Time taken for 1 epoch: 27.164268016815186 secs

Epoch 456 batch 0 train Loss 5.2873 test Loss 5.3817 with MSE metric 7180.5107
Time taken for 1 epoch: 27.02651882171631 secs

Epoch 457 batch 0 train Loss 5.2866 test Loss 5.3515 with MSE metric 7187.3516
Time taken for 1 epoch: 27.192296981811523 secs

Epoch 458 batch 0 train Loss 5.2974 test Loss 5.3324 with MSE metric 7338.7100
Time taken for 1 epoch: 27.25898504257202 secs

Epoch 459 batch 0 train Loss 5.2738 test Loss 5.3536 with MSE metric 6986.7490
Time taken for 1 epoch: 27.075589895248413 secs

Epoch 460 batch 0 train Loss 5.3162 test Loss 5.3490 with MSE metric 7624.3311
Time taken for 1 epoch: 27.

In [275]:
# Shifted targets of row 10 of the test batch returned by test_step.
# The trailing zeros in the output are presumably padding - TODO confirm.
tar_real_te[10, :]

<tf.Tensor: shape=(39,), dtype=float64, numpy=
array([277.4391 , 277.1047 , 280.83145, 277.54877, 276.43674, 280.73575,
       279.6379 , 278.9441 , 280.72787, 279.9052 , 276.1898 , 278.17764,
       276.66025, 277.8398 , 279.3794 , 277.32608, 279.0153 , 276.279  ,
       277.99994, 276.37442, 280.85577, 279.12204, 280.74197, 277.27045,
       278.04602,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
         0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
         0.     ,   0.     ,   0.     ])>

In [266]:
# Loss mask of row 1 of the test batch returned by test_step.
# NOTE(review): the neighbouring cells inspect row 10 - confirm which row is intended.
t_mask[1, :]

<tf.Tensor: shape=(39,), dtype=float64, numpy=
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.])>

In [274]:
# Decoder predictions for row 10 of the test batch returned by test_step.
pred_te[10, :]

<tf.Tensor: shape=(39,), dtype=float64, numpy=
array([305.5369064 , 305.53690675, 305.53690678, 305.53690751,
       305.53690744, 305.79655562, 305.79655595, 305.79655609,
       305.84506049, 305.84506066, 305.84506075, 310.90365222,
       310.90365221, 310.90365213, 105.75508275, 105.75508295,
       105.75508276, 105.75508288, 105.75508255, 105.75508251,
       105.75508224, 105.75508261, 105.75508272, 105.75508303,
       105.75508289, 105.75508286, 105.75508286, 105.75508286,
       105.75508286, 105.75508286, 105.75508286, 105.75508286,
       105.75508286, 105.75508286, 105.75508286, 105.75508286,
       105.75508286, 105.75508286, 105.75508286])>

In [273]:
# Element-wise error for row 10: prediction multiplied by the loss mask, minus the target.
# Only the prediction is masked here; the target is used as is.
((pred_te[10, :] * t_mask[10, :] ) - tar_real_te[10, :])

<tf.Tensor: shape=(39,), dtype=float64, numpy=
array([  28.0978064 ,   28.43220675,   24.70545678,   27.98813751,
         29.10016744,   25.06080562,   26.15865595,   26.85245609,
         25.11719049,   25.93986066,   29.65526075,   32.72601222,
         34.24340221,   33.06385213, -173.62431725, -171.57099705,
       -173.26021724, -170.52391712, -172.24485745, -170.61933749,
       -175.10068776, -173.36695739, -174.98688728, -171.51536697,
       -172.29093711,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,    0.        ])>

In [None]:
# R^2-style score for target column 5: 1 - error / variance, where the variance is the sum of
# squared deviations divided by the number of samples. The variance is parenthesised so the
# error is divided by it as a whole (a / b / c would divide by b * c instead).
# 0.0165 is presumably an MSE obtained elsewhere - TODO confirm.
# NOTE(review): `tar` must already be bound in the kernel; it is assigned in a later cell.
1 - 0.0165 / (sum((tar[:, 5] - np.mean(tar[:, 5]))**2) / len(tar[:, 5]))

In [None]:
# Targets with the mean along axis 0 subtracted.
tar - np.mean(tar, 0)

In [None]:
# Dimensions of the target array.
tar.shape

In [None]:
# Mean of target column 0.
np.mean(tar[:, 0])

In [None]:
# Population variance of target column 0: sum of squared deviations from the column mean,
# divided by the actual number of samples rather than a hard-coded count.
sum((tar[:, 0] - np.mean(tar[:, 0]))**2) / len(tar[:, 0])

In [None]:
# Mean squared deviation of every entry of the 2-D array `tar` from the global mean:
# the nested sum() adds up rows and then columns, and the divisor is the number of entries.
sum(sum((tar - np.mean(tar))**2)) / (tar.shape[0] * tar.shape[1])

In [None]:
# Row 560 of df_te reshaped to a single-row 2-D array.
# NOTE(review): df_te is not defined in the visible cells of this notebook.
pos = df_te[560, :].reshape(1, -1)

In [None]:
# First 39 entries of row 561 of df_te as a single-row 2-D array.
# This rebinds `tar`, which the cells above also read.
tar = df_te[561, :39].reshape(1, -1)

In [None]:
# Full row 561 of df_te, the row `tar` is sliced from.
df_te[561, :]

In [None]:
# NOTE(review): the import cell binds `infer` (from the `inference` package), not `inference`,
# so this call raises NameError unless `inference` is defined elsewhere - confirm the intended function.
# The meaning of the third argument (20) is not visible here.
a = inference(pos, tar, 20)

In [None]:
# Scatter plot on a temporary 10 x 2.5 figure size (rc_context restores the default afterwards).
with matplotlib.rc_context({'figure.figsize': [10,2.5]}):
    # black: the first 39 positions against the values in `tar`
    plt.scatter(pos[:, :39], tar[:, :39], c='black')
    # default colour: positions 39..57 against the output of the inference call from index 39 onwards
    plt.scatter(pos[:, 39:58], a[39:])
    # red: the same positions against row 561 of df_te (presumably the reference values - TODO confirm)
    plt.scatter(pos[:, 39:58], df_te[561, 39:58], c='red')

In [None]:
# tf.data.Dataset(tf.Tensor(pad_pos_tr, value_index = 0 , dtype = tf.float32))