### Import libraries

In [228]:
from model import climate_model, losses, dot_prod_attention
from data import data_generation, data_combine, batch_creator, gp_kernels
from keras.callbacks import ModelCheckpoint
from helpers import helpers, masks
from inference import infer
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib 
import time
import keras
np.random.seed(1)

In [216]:
save_dir = '/Users/omernivron/Downloads/GPT_climate'

In [187]:
temp, t, token = data_combine.climate_data_to_model_input('./data/t2m_monthly_averaged_ensemble_members_1989_2019.csv')

In [190]:
## create climate train/test split

In [210]:
time_tr = t[:8000]; temp_tr = temp[:8000]; token_tr = token[:8000]
time_te = t[8000:]; temp_te = temp[8000:]; token_te = token[8000:]

In [229]:
loss_object = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
m_tr = tf.keras.metrics.Mean()
m_te = tf.keras.metrics.Mean()

In [230]:
@tf.function
def train_step(decoder, optimizer_c, token_pos, time_pos, tar, pos_mask):
    '''
    A typical train step function for TF2. Elements which we wish to track their gradient
    has to be inside the GradientTape() clause. see (1) https://www.tensorflow.org/guide/migrate 
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask (np array): see description in position_mask function
    ------------------    
    '''
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        pred, pred_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
#         print('pred: ')
#         tf.print(pred_sig)

        loss, mse, mask = losses.loss_function(tar_real, pred, pred_sig)


    gradients = tape.gradient(loss, decoder.trainable_variables)
#     tf.print(gradients)
# Ask the optimizer to apply the processed gradients.
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    m_tr.update_state(mse, mask)
#     b = decoder.trainable_weights[0]
#     tf.print(tf.reduce_mean(b))
    return tar_inp, tar_real, pred, pred_sig, mask

In [231]:
@tf.function
def test_step(decoder, token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    
    ---------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask_te (np array): see description in position_mask function
    ---------------
    
    '''
    tar_inp_te = tar_te[:, :-1]
    tar_real_te = tar_te[:, 1:]
    combined_mask_tar_te = masks.create_masks(tar_inp_te)
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
    pred_te, pred_sig_te = decoder(token_pos_te, time_pos_te, tar_inp_te, False, pos_mask_te, combined_mask_tar_te)
    t_loss, t_mse, t_mask = losses.loss_function(tar_real_te, pred_te, pred_sig_te)
    test_loss(t_loss)
    m_te.update_state(t_mse, t_mask)
    return tar_real_te, pred_te, pred_sig_te, t_mask

In [232]:
tf.keras.backend.set_floatx('float64')

In [None]:
if __name__ == '__main__':
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam()
    decoder = climate_model.Decoder(16)
    EPOCHS = 500
    batch_s  = 32
    run = 0; step = 0
    num_batches = int(temp_tr.shape[0] / batch_s)
    tf.random.set_seed(1)
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer = optimizer_c, net = decoder)
    main_folder = "/Users/omernivron/Downloads/GPT_climate/ckpt/check_"
    folder = main_folder + str(run); helpers.mkdir(folder)
    #https://www.tensorflow.org/guide/checkpoint
    manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, _ = batch_creator.create_batch_foxes(token_tr, time_tr, temp_tr, batch_s=32)
                # batch_tar_tr shape := 128 X 59 = (batch_size, max_seq_len)
                # batch_pos_tr shape := 128 X 59 = (batch_size, max_seq_len)
                batch_pos_mask = masks.position_mask(batch_tok_pos_tr)
                tar_inp, tar_real, pred, pred_sig, mask = train_step(decoder, optimizer_c, batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                if batch_n % 10 == 0:
                    batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, _ = batch_creator.create_batch_foxes(token_te, time_te, temp_te, batch_s= 32)
                    batch_pos_mask_te = masks.position_mask(batch_tok_pos_te)
                    tar_real_te, pred_te, pred_sig_te, t_mask = test_step(decoder, batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
                    helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
                    manager.save()
                step += 1
                ckpt.step.assign_add(1)

            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Initializing from scratch.
Epoch 0 batch 0 train Loss 276635.7271 test Loss 92122.0270 with MSE metric 45551.5664
Epoch 0 batch 10 train Loss 147707.3831 test Loss 46122.4647 with MSE metric 46528.5883
Epoch 0 batch 20 train Loss 88084.5578 test Loss 30754.6986 with MSE metric 46206.7365
Epoch 0 batch 30 train Loss 62429.0003 test Loss 23068.2816 with MSE metric 45921.8479
Epoch 0 batch 40 train Loss 48274.3682 test Loss 18456.1269 with MSE metric 46211.5025
Epoch 0 batch 50 train Loss 39442.4337 test Loss 15381.2496 with MSE metric 46455.9027
Epoch 0 batch 60 train Loss 33508.3575 test Loss 13184.8707 with MSE metric 46346.1084
Epoch 0 batch 70 train Loss 29167.5388 test Loss 11537.5780 with MSE metric 46335.6242
Epoch 0 batch 80 train Loss 25797.9345 test Loss 10256.3357 with MSE metric 46273.0657
Epoch 0 batch 90 train Loss 23146.2318 test Loss 9231.3473 with MSE metric 46227.9679
Epoch 0 batch 100 train Loss 21060.0951 test Loss 8392.7229 with MSE metric 46165.1620
E

Epoch 3 batch 190 train Loss 2573.7582 test Loss 978.2907 with MSE metric 45097.1019
Epoch 3 batch 200 train Loss 2547.3261 test Loss 968.1856 with MSE metric 45086.2424
Epoch 3 batch 210 train Loss 2521.5919 test Loss 958.2892 with MSE metric 45083.5910
Epoch 3 batch 220 train Loss 2496.2625 test Loss 948.5948 with MSE metric 45077.8469
Epoch 3 batch 230 train Loss 2471.7906 test Loss 939.0961 with MSE metric 45075.7935
Epoch 3 batch 240 train Loss 2447.3417 test Loss 929.7878 with MSE metric 45074.1080
Time taken for 1 epoch: 30.357666015625 secs

Epoch 4 batch 0 train Loss 2423.4229 test Loss 920.6635 with MSE metric 45070.6918
Epoch 4 batch 10 train Loss 2399.9280 test Loss 911.7182 with MSE metric 45057.1591
Epoch 4 batch 20 train Loss 2376.9605 test Loss 902.9480 with MSE metric 45060.1760
Epoch 4 batch 30 train Loss 2354.4834 test Loss 894.3457 with MSE metric 45053.5382
Epoch 4 batch 40 train Loss 2332.3808 test Loss 885.9081 with MSE metric 45056.1516
Epoch 4 batch 50 train Lo

Epoch 7 batch 140 train Loss 1300.7095 test Loss 493.4961 with MSE metric 44818.2943
Epoch 7 batch 150 train Loss 1294.0040 test Loss 490.9602 with MSE metric 44817.6849
Epoch 7 batch 160 train Loss 1287.3511 test Loss 488.4506 with MSE metric 44813.3637
Epoch 7 batch 170 train Loss 1280.7618 test Loss 485.9669 with MSE metric 44810.6646
Epoch 7 batch 180 train Loss 1274.3010 test Loss 483.5091 with MSE metric 44811.2059
Epoch 7 batch 190 train Loss 1267.8261 test Loss 481.0765 with MSE metric 44810.5494
Epoch 7 batch 200 train Loss 1261.4127 test Loss 478.6689 with MSE metric 44809.8589
Epoch 7 batch 210 train Loss 1255.0965 test Loss 476.2854 with MSE metric 44813.5021
Epoch 7 batch 220 train Loss 1248.8292 test Loss 473.9266 with MSE metric 44812.7825
Epoch 7 batch 230 train Loss 1242.6034 test Loss 471.5913 with MSE metric 44808.9969
Epoch 7 batch 240 train Loss 1236.4422 test Loss 469.2792 with MSE metric 44808.6132
Time taken for 1 epoch: 29.421195030212402 secs

Epoch 8 batch 0 

Epoch 11 batch 90 train Loss 871.8436 test Loss 332.1477 with MSE metric 44707.9802
Epoch 11 batch 100 train Loss 868.8273 test Loss 331.0205 with MSE metric 44706.4711
Epoch 11 batch 110 train Loss 865.8288 test Loss 329.9011 with MSE metric 44702.7393
Epoch 11 batch 120 train Loss 862.8777 test Loss 328.7893 with MSE metric 44698.5043
Epoch 11 batch 130 train Loss 859.9492 test Loss 327.6854 with MSE metric 44698.4619
Epoch 11 batch 140 train Loss 857.0172 test Loss 326.5892 with MSE metric 44696.4532
Epoch 11 batch 150 train Loss 854.1239 test Loss 325.5003 with MSE metric 44692.4952
Epoch 11 batch 160 train Loss 851.2339 test Loss 324.4191 with MSE metric 44690.6223
Epoch 11 batch 170 train Loss 848.3575 test Loss 323.3452 with MSE metric 44689.6701
Epoch 11 batch 180 train Loss 845.5189 test Loss 322.2788 with MSE metric 44688.6152
Epoch 11 batch 190 train Loss 842.7060 test Loss 321.2197 with MSE metric 44688.2442
Epoch 11 batch 200 train Loss 839.8990 test Loss 320.1676 with MSE

Epoch 15 batch 40 train Loss 656.4897 test Loss 251.6091 with MSE metric 44649.6013
Epoch 15 batch 50 train Loss 654.7873 test Loss 250.9754 with MSE metric 44650.0012
Epoch 15 batch 60 train Loss 653.1027 test Loss 250.3452 with MSE metric 44650.0421
Epoch 15 batch 70 train Loss 651.4181 test Loss 249.7183 with MSE metric 44649.1121
Epoch 15 batch 80 train Loss 649.7408 test Loss 249.0947 with MSE metric 44647.7067
Epoch 15 batch 90 train Loss 648.0725 test Loss 248.4744 with MSE metric 44646.1725
Epoch 15 batch 100 train Loss 646.4165 test Loss 247.8572 with MSE metric 44645.2851
Epoch 15 batch 110 train Loss 644.7653 test Loss 247.2429 with MSE metric 44644.8591
Epoch 15 batch 120 train Loss 643.1229 test Loss 246.6321 with MSE metric 44643.1733
Epoch 15 batch 130 train Loss 641.4899 test Loss 246.0245 with MSE metric 44641.2036
Epoch 15 batch 140 train Loss 639.8646 test Loss 245.4198 with MSE metric 44639.2454
Epoch 15 batch 150 train Loss 638.2462 test Loss 244.8183 with MSE metr

Time taken for 1 epoch: 29.84224796295166 secs

Epoch 19 batch 0 train Loss 525.7473 test Loss 202.9620 with MSE metric 44585.1595
Epoch 19 batch 10 train Loss 524.6610 test Loss 202.5586 with MSE metric 44583.3717
Epoch 19 batch 20 train Loss 523.5810 test Loss 202.1569 with MSE metric 44583.4698
Epoch 19 batch 30 train Loss 522.5030 test Loss 201.7569 with MSE metric 44582.7458
Epoch 19 batch 40 train Loss 521.4332 test Loss 201.3587 with MSE metric 44581.1580
Epoch 19 batch 50 train Loss 520.3638 test Loss 200.9622 with MSE metric 44579.6533
Epoch 19 batch 60 train Loss 519.2989 test Loss 200.5672 with MSE metric 44577.5454
Epoch 19 batch 70 train Loss 518.2389 test Loss 200.1739 with MSE metric 44574.7417
Epoch 19 batch 80 train Loss 517.1837 test Loss 199.7822 with MSE metric 44573.0875
Epoch 19 batch 90 train Loss 516.1344 test Loss 199.3921 with MSE metric 44571.1268
Epoch 19 batch 100 train Loss 515.0874 test Loss 199.0036 with MSE metric 44570.0729
Epoch 19 batch 110 train Los

Epoch 22 batch 200 train Loss 439.5845 test Loss 170.9795 with MSE metric 44517.1503
Epoch 22 batch 210 train Loss 438.8287 test Loss 170.6996 with MSE metric 44518.5449
Epoch 22 batch 220 train Loss 438.0748 test Loss 170.4206 with MSE metric 44515.3944
Epoch 22 batch 230 train Loss 437.3248 test Loss 170.1427 with MSE metric 44514.9178
Epoch 22 batch 240 train Loss 436.5766 test Loss 169.8657 with MSE metric 44516.0122
Time taken for 1 epoch: 30.260584354400635 secs

Epoch 23 batch 0 train Loss 435.8310 test Loss 169.5898 with MSE metric 44513.3356
Epoch 23 batch 10 train Loss 435.0883 test Loss 169.3148 with MSE metric 44512.0308
Epoch 23 batch 20 train Loss 434.3486 test Loss 169.0408 with MSE metric 44513.1384
Epoch 23 batch 30 train Loss 433.6202 test Loss 168.7675 with MSE metric 44513.4129
Epoch 23 batch 40 train Loss 432.8849 test Loss 168.4954 with MSE metric 44511.8111
Epoch 23 batch 50 train Loss 432.1534 test Loss 168.2243 with MSE metric 44511.1223
Epoch 23 batch 60 train

Epoch 26 batch 150 train Loss 377.9733 test Loss 148.1580 with MSE metric 44454.7234
Epoch 26 batch 160 train Loss 377.4174 test Loss 147.9525 with MSE metric 44453.3625
Epoch 26 batch 170 train Loss 376.8630 test Loss 147.7476 with MSE metric 44453.6740
Epoch 26 batch 180 train Loss 376.3102 test Loss 147.5433 with MSE metric 44453.7684
Epoch 26 batch 190 train Loss 375.7591 test Loss 147.3397 with MSE metric 44451.9496
Epoch 26 batch 200 train Loss 375.2098 test Loss 147.1366 with MSE metric 44450.7005
Epoch 26 batch 210 train Loss 374.6619 test Loss 146.9341 with MSE metric 44452.3522
Epoch 26 batch 220 train Loss 374.1156 test Loss 146.7323 with MSE metric 44451.1834
Epoch 26 batch 230 train Loss 373.5711 test Loss 146.5309 with MSE metric 44450.5299
Epoch 26 batch 240 train Loss 373.0280 test Loss 146.3302 with MSE metric 44451.0870
Time taken for 1 epoch: 29.952781915664673 secs

Epoch 27 batch 0 train Loss 372.4866 test Loss 146.1301 with MSE metric 44450.3651
Epoch 27 batch 10 

Epoch 30 batch 100 train Loss 331.7017 test Loss 131.0458 with MSE metric 44402.5543
Epoch 30 batch 110 train Loss 331.2757 test Loss 130.8884 with MSE metric 44401.7364
Epoch 30 batch 120 train Loss 330.8509 test Loss 130.7314 with MSE metric 44401.1948
Epoch 30 batch 130 train Loss 330.4272 test Loss 130.5748 with MSE metric 44402.4479
Epoch 30 batch 140 train Loss 330.0047 test Loss 130.4186 with MSE metric 44402.8617
Epoch 30 batch 150 train Loss 329.5861 test Loss 130.2628 with MSE metric 44401.9661
Epoch 30 batch 160 train Loss 329.1658 test Loss 130.1074 with MSE metric 44401.9016
Epoch 30 batch 170 train Loss 328.7467 test Loss 129.9523 with MSE metric 44402.1558
Epoch 30 batch 180 train Loss 328.3287 test Loss 129.7978 with MSE metric 44402.1161
Epoch 30 batch 190 train Loss 327.9115 test Loss 129.6436 with MSE metric 44402.0686
Epoch 30 batch 200 train Loss 327.4956 test Loss 129.4898 with MSE metric 44400.9390
Epoch 30 batch 210 train Loss 327.0813 test Loss 129.3364 with MS

Epoch 34 batch 50 train Loss 295.7075 test Loss 117.7262 with MSE metric 44341.9728
Epoch 34 batch 60 train Loss 295.3707 test Loss 117.6016 with MSE metric 44341.8949
Epoch 34 batch 70 train Loss 295.0348 test Loss 117.4773 with MSE metric 44341.6695
Epoch 34 batch 80 train Loss 294.7001 test Loss 117.3532 with MSE metric 44340.4276
Epoch 34 batch 90 train Loss 294.3657 test Loss 117.2294 with MSE metric 44340.7906
Epoch 34 batch 100 train Loss 294.0321 test Loss 117.1060 with MSE metric 44340.4896
Epoch 34 batch 110 train Loss 293.6995 test Loss 116.9828 with MSE metric 44340.1295
Epoch 34 batch 120 train Loss 293.3678 test Loss 116.8599 with MSE metric 44338.9561
Epoch 34 batch 130 train Loss 293.0365 test Loss 116.7373 with MSE metric 44338.4681
Epoch 34 batch 140 train Loss 292.7062 test Loss 116.6149 with MSE metric 44336.6368
Epoch 34 batch 150 train Loss 292.3764 test Loss 116.4928 with MSE metric 44335.7128
Epoch 34 batch 160 train Loss 292.0485 test Loss 116.3710 with MSE met

Epoch 38 batch 10 train Loss 266.6070 test Loss 106.9376 with MSE metric 44281.3151
Epoch 38 batch 20 train Loss 266.3348 test Loss 106.8364 with MSE metric 44281.0579
Epoch 38 batch 30 train Loss 266.0629 test Loss 106.7355 with MSE metric 44280.2183
Epoch 38 batch 40 train Loss 265.7916 test Loss 106.6347 with MSE metric 44279.0542
Epoch 38 batch 50 train Loss 265.5209 test Loss 106.5341 with MSE metric 44279.6080
Epoch 38 batch 60 train Loss 265.2507 test Loss 106.4338 with MSE metric 44278.0860
Epoch 38 batch 70 train Loss 264.9811 test Loss 106.3336 with MSE metric 44276.4217
Epoch 38 batch 80 train Loss 264.7121 test Loss 106.2337 with MSE metric 44275.2926
Epoch 38 batch 90 train Loss 264.4436 test Loss 106.1339 with MSE metric 44274.6828
Epoch 38 batch 100 train Loss 264.1757 test Loss 106.0344 with MSE metric 44273.8267
Epoch 38 batch 110 train Loss 263.9085 test Loss 105.9350 with MSE metric 44273.1122
Epoch 38 batch 120 train Loss 263.6417 test Loss 105.8359 with MSE metric 

Epoch 41 batch 220 train Loss 242.8285 test Loss 98.0871 with MSE metric 44200.8859
Epoch 41 batch 230 train Loss 242.6036 test Loss 98.0032 with MSE metric 44200.7658
Epoch 41 batch 240 train Loss 242.3791 test Loss 97.9195 with MSE metric 44200.1878
Time taken for 1 epoch: 33.621551752090454 secs

Epoch 42 batch 0 train Loss 242.1551 test Loss 97.8359 with MSE metric 44200.0555
Epoch 42 batch 10 train Loss 241.9316 test Loss 97.7525 with MSE metric 44199.4589
Epoch 42 batch 20 train Loss 241.7084 test Loss 97.6693 with MSE metric 44199.1145
Epoch 42 batch 30 train Loss 241.4857 test Loss 97.5862 with MSE metric 44198.2599
Epoch 42 batch 40 train Loss 241.2634 test Loss 97.5032 with MSE metric 44196.2663
Epoch 42 batch 50 train Loss 241.0415 test Loss 97.4205 with MSE metric 44195.1520
Epoch 42 batch 60 train Loss 240.8200 test Loss 97.3378 with MSE metric 44194.8486
Epoch 42 batch 70 train Loss 240.5990 test Loss 97.2553 with MSE metric 44194.6064
Epoch 42 batch 80 train Loss 240.378

Epoch 45 batch 180 train Loss 223.0282 test Loss 90.6853 with MSE metric 44101.5011
Epoch 45 batch 190 train Loss 222.8394 test Loss 90.6146 with MSE metric 44100.6974
Epoch 45 batch 200 train Loss 222.6510 test Loss 90.5439 with MSE metric 44099.5101
Epoch 45 batch 210 train Loss 222.4628 test Loss 90.4734 with MSE metric 44098.1245
Epoch 45 batch 220 train Loss 222.2749 test Loss 90.4029 with MSE metric 44097.0320
Epoch 45 batch 230 train Loss 222.0874 test Loss 90.3326 with MSE metric 44096.3366
Epoch 45 batch 240 train Loss 221.9002 test Loss 90.2625 with MSE metric 44094.5095
Time taken for 1 epoch: 33.040276288986206 secs

Epoch 46 batch 0 train Loss 221.7133 test Loss 90.1924 with MSE metric 44093.0937
Epoch 46 batch 10 train Loss 221.5267 test Loss 90.1225 with MSE metric 44091.4200
Epoch 46 batch 20 train Loss 221.3404 test Loss 90.0526 with MSE metric 44089.4272
Epoch 46 batch 30 train Loss 221.1545 test Loss 89.9829 with MSE metric 44087.6704
Epoch 46 batch 40 train Loss 220

In [None]:
1 - (0.0165 / sum((tar[:, 5] - np.mean(tar[:, 5]))**2) / len(tar[:, 5]))

In [None]:
tar - np.mean(tar, 0)

In [None]:
tar.shape

In [None]:
np.mean(tar[:, 0])

In [None]:
sum((tar[:, 0] - np.mean(tar[:, 0]))**2 )/ 10000

In [None]:
sum(sum((tar - np.mean(tar))**2)) / (tar.shape[0] * tar.shape[1])

In [None]:
pos = df_te[560, :].reshape(1, -1)

In [None]:
tar = df_te[561, :39].reshape(1, -1)

In [None]:
df_te[561, :]

In [None]:
a = inference(pos, tar, 20)

In [None]:
with matplotlib.rc_context({'figure.figsize': [10,2.5]}):
    plt.scatter(pos[:, :39], tar[:, :39], c='black')
    plt.scatter(pos[:, 39:58], a[39:])
    plt.scatter(pos[:, 39:58], df_te[561, 39:58], c='red')

In [None]:
# tf.data.Dataset(tf.Tensor(pad_pos_tr, value_index = 0 , dtype = tf.float32))