### Import libraries

In [228]:
from model import climate_model, losses, dot_prod_attention
from data import data_generation, data_combine, batch_creator, gp_kernels
from keras.callbacks import ModelCheckpoint
from helpers import helpers, masks
from inference import infer
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib 
import time
import keras
np.random.seed(1)

In [216]:
save_dir = '/Users/omernivron/Downloads/GPT_climate'

In [187]:
temp, t, token = data_combine.climate_data_to_model_input('./data/t2m_monthly_averaged_ensemble_members_1989_2019.csv')

In [190]:
## create climate train/test split

In [210]:
time_tr = t[:8000]; temp_tr = temp[:8000]; token_tr = token[:8000]
time_te = t[8000:]; temp_te = temp[8000:]; token_te = token[8000:]

In [247]:
loss_object = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
m_tr = tf.keras.metrics.Mean()
m_te = tf.keras.metrics.Mean()

In [248]:
@tf.function
def train_step(decoder, optimizer_c, token_pos, time_pos, tar, pos_mask):
    '''
    A typical train step function for TF2. Elements which we wish to track their gradient
    has to be inside the GradientTape() clause. see (1) https://www.tensorflow.org/guide/migrate 
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask (np array): see description in position_mask function
    ------------------    
    '''
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        pred, pred_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
#         print('pred: ')
#         tf.print(pred_sig)

        loss, mse, mask = losses.loss_function(tar_real, pred, pred_sig)


    gradients = tape.gradient(loss, decoder.trainable_variables)
#     tf.print(gradients)
# Ask the optimizer to apply the processed gradients.
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    m_tr.update_state(mse, mask)
#     b = decoder.trainable_weights[0]
#     tf.print(tf.reduce_mean(b))
    return tar_inp, tar_real, pred, pred_sig, mask

In [249]:
@tf.function
def test_step(decoder, token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    
    ---------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask_te (np array): see description in position_mask function
    ---------------
    
    '''
    tar_inp_te = tar_te[:, :-1]
    tar_real_te = tar_te[:, 1:]
    combined_mask_tar_te = masks.create_masks(tar_inp_te)
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
    pred_te, pred_sig_te = decoder(token_pos_te, time_pos_te, tar_inp_te, False, pos_mask_te, combined_mask_tar_te)
    t_loss, t_mse, t_mask = losses.loss_function(tar_real_te, pred_te, pred_sig_te)
    test_loss(t_loss)
    m_te.update_state(t_mse, t_mask)
    return tar_real_te, pred_te, pred_sig_te, t_mask

In [250]:
tf.keras.backend.set_floatx('float64')

In [251]:
if __name__ == '__main__':
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam(0.0004)
    decoder = climate_model.Decoder(16)
    EPOCHS = 500
    batch_s  = 32
    run = 0; step = 0
    num_batches = int(temp_tr.shape[0] / batch_s)
    tf.random.set_seed(1)
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer = optimizer_c, net = decoder)
    main_folder = "/Users/omernivron/Downloads/GPT_climate/ckpt/check_"
    folder = main_folder + str(run); helpers.mkdir(folder)
    #https://www.tensorflow.org/guide/checkpoint
    manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                m_tr.reset_states(); train_loss.reset_states()
                m_te.reset_states(); test_loss.reset_states()
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, _ = batch_creator.create_batch_foxes(token_tr, time_tr, temp_tr, batch_s=32)
                # batch_tar_tr shape := 128 X 59 = (batch_size, max_seq_len)
                # batch_pos_tr shape := 128 X 59 = (batch_size, max_seq_len)
                batch_pos_mask = masks.position_mask(batch_tok_pos_tr)
                tar_inp, tar_real, pred, pred_sig, mask = train_step(decoder, optimizer_c, batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                if batch_n % 100 == 0:
                    batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, _ = batch_creator.create_batch_foxes(token_te, time_te, temp_te, batch_s= 32)
                    batch_pos_mask_te = masks.position_mask(batch_tok_pos_te)
                    tar_real_te, pred_te, pred_sig_te, t_mask = test_step(decoder, batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
                    helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
                    manager.save()
                step += 1
                ckpt.step.assign_add(1)

            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Restored from /Users/omernivron/Downloads/GPT_climate/ckpt/check_0/ckpt-16573
Epoch 0 batch 0 train Loss 5.2503 test Loss 5.3098 with MSE metric 6651.7617
Epoch 0 batch 100 train Loss 5.3334 test Loss 5.3419 with MSE metric 7829.4185
Epoch 0 batch 200 train Loss 5.3171 test Loss 5.3750 with MSE metric 7637.5947
Time taken for 1 epoch: 25.849642992019653 secs

Epoch 1 batch 0 train Loss 5.3063 test Loss 5.2360 with MSE metric 7469.3574
Epoch 1 batch 100 train Loss 5.3863 test Loss 5.4182 with MSE metric 8681.6406
Epoch 1 batch 200 train Loss 5.3110 test Loss 5.4515 with MSE metric 7478.4180
Time taken for 1 epoch: 23.60657000541687 secs

Epoch 2 batch 0 train Loss 5.3496 test Loss 5.3069 with MSE metric 8074.8354
Epoch 2 batch 100 train Loss 5.1630 test Loss 5.2977 with MSE metric 5534.8193
Epoch 2 batch 200 train Loss 5.1804 test Loss 5.4084 with MSE metric 5231.6543
Time taken for 1 epoch: 24.078569889068604 secs

Epoch 3 batch 0 train Loss 5.3359 test Loss 5.3252 with 

Epoch 28 batch 200 train Loss 5.3002 test Loss 5.3958 with MSE metric 7368.8594
Time taken for 1 epoch: 24.058087825775146 secs

Epoch 29 batch 0 train Loss 5.2617 test Loss 5.4073 with MSE metric 6769.7119
Epoch 29 batch 100 train Loss 5.4562 test Loss 5.3170 with MSE metric 9675.4932
Epoch 29 batch 200 train Loss 5.3578 test Loss 5.3442 with MSE metric 8129.1162
Time taken for 1 epoch: 23.884035110473633 secs

Epoch 30 batch 0 train Loss 5.3557 test Loss 5.3563 with MSE metric 8173.7373
Epoch 30 batch 100 train Loss 5.3098 test Loss 5.3838 with MSE metric 7522.5220
Epoch 30 batch 200 train Loss 5.3745 test Loss 5.3140 with MSE metric 8515.8477
Time taken for 1 epoch: 23.977431058883667 secs

Epoch 31 batch 0 train Loss 5.2341 test Loss 5.4183 with MSE metric 6371.4541
Epoch 31 batch 100 train Loss 5.3232 test Loss 5.3746 with MSE metric 7728.9375
Epoch 31 batch 200 train Loss 5.2468 test Loss 5.3602 with MSE metric 6606.9883
Time taken for 1 epoch: 24.268733024597168 secs

Epoch 32 b

Epoch 57 batch 100 train Loss 5.2768 test Loss 5.2575 with MSE metric 7034.8906
Epoch 57 batch 200 train Loss 5.1675 test Loss 5.2635 with MSE metric 5167.5264
Time taken for 1 epoch: 27.70429277420044 secs

Epoch 58 batch 0 train Loss 5.2523 test Loss 5.2944 with MSE metric 6652.1499
Epoch 58 batch 100 train Loss 5.2737 test Loss 5.4599 with MSE metric 6984.2568
Epoch 58 batch 200 train Loss 5.2744 test Loss 5.2755 with MSE metric 6965.1821
Time taken for 1 epoch: 23.318583011627197 secs

Epoch 59 batch 0 train Loss 5.3634 test Loss 5.4142 with MSE metric 8377.0752
Epoch 59 batch 100 train Loss 5.2843 test Loss 5.3189 with MSE metric 7135.9751
Epoch 59 batch 200 train Loss 5.3167 test Loss 5.4112 with MSE metric 7633.4160
Time taken for 1 epoch: 23.31714367866516 secs

Epoch 60 batch 0 train Loss 5.2706 test Loss 5.3299 with MSE metric 6923.7129
Epoch 60 batch 100 train Loss 5.3463 test Loss 5.4544 with MSE metric 8073.1250
Epoch 60 batch 200 train Loss 5.2807 test Loss 5.3599 with MS

Epoch 86 batch 100 train Loss 5.3635 test Loss 5.2790 with MSE metric 8243.8467
Epoch 86 batch 200 train Loss 5.3255 test Loss 5.4615 with MSE metric 7698.4170
Time taken for 1 epoch: 23.096320152282715 secs

Epoch 87 batch 0 train Loss 5.2727 test Loss 5.3547 with MSE metric 6957.2783
Epoch 87 batch 100 train Loss 5.3596 test Loss 5.3885 with MSE metric 8241.0059
Epoch 87 batch 200 train Loss 5.3039 test Loss 5.2472 with MSE metric 7423.5825
Time taken for 1 epoch: 23.191839933395386 secs

Epoch 88 batch 0 train Loss 5.3660 test Loss 5.3232 with MSE metric 8376.6064
Epoch 88 batch 100 train Loss 5.3406 test Loss 5.4629 with MSE metric 7999.7759
Epoch 88 batch 200 train Loss 5.2865 test Loss 5.4052 with MSE metric 7183.7788
Time taken for 1 epoch: 23.118685960769653 secs

Epoch 89 batch 0 train Loss 5.3343 test Loss 5.2537 with MSE metric 7863.9746
Epoch 89 batch 100 train Loss 5.3742 test Loss 5.3112 with MSE metric 8497.5234
Epoch 89 batch 200 train Loss 5.3451 test Loss 5.4739 with 

Time taken for 1 epoch: 22.76567816734314 secs

Epoch 115 batch 0 train Loss 5.3127 test Loss 5.3904 with MSE metric 7569.4131
Epoch 115 batch 100 train Loss 5.2722 test Loss 5.3495 with MSE metric 6974.2471
Epoch 115 batch 200 train Loss 5.3063 test Loss 5.3424 with MSE metric 7473.8213
Time taken for 1 epoch: 22.785687923431396 secs

Epoch 116 batch 0 train Loss 5.3021 test Loss 5.3540 with MSE metric 7390.6094
Epoch 116 batch 100 train Loss 5.2696 test Loss 5.4576 with MSE metric 6946.2207
Epoch 116 batch 200 train Loss 5.2412 test Loss 5.3245 with MSE metric 6495.5737
Time taken for 1 epoch: 22.696279048919678 secs

Epoch 117 batch 0 train Loss 5.3386 test Loss 5.2392 with MSE metric 7953.4053
Epoch 117 batch 100 train Loss 5.3437 test Loss 5.2938 with MSE metric 8052.5874
Epoch 117 batch 200 train Loss 5.3384 test Loss 5.2812 with MSE metric 7878.6128
Time taken for 1 epoch: 22.70211887359619 secs

Epoch 118 batch 0 train Loss 5.2441 test Loss 5.4138 with MSE metric 6518.5903
Epoc

Epoch 143 batch 100 train Loss 5.3631 test Loss 5.3626 with MSE metric 8276.8018
Epoch 143 batch 200 train Loss 5.2443 test Loss 5.3615 with MSE metric 6457.0249
Time taken for 1 epoch: 25.156306982040405 secs

Epoch 144 batch 0 train Loss 5.3876 test Loss 5.4072 with MSE metric 8667.7598
Epoch 144 batch 100 train Loss 5.3050 test Loss 5.3772 with MSE metric 7450.9160
Epoch 144 batch 200 train Loss 5.2017 test Loss 5.2178 with MSE metric 5984.2568
Time taken for 1 epoch: 25.413201808929443 secs

Epoch 145 batch 0 train Loss 5.2165 test Loss 5.2883 with MSE metric 6166.8740
Epoch 145 batch 100 train Loss 5.2306 test Loss 5.4070 with MSE metric 6365.7524
Epoch 145 batch 200 train Loss 5.2778 test Loss 5.4011 with MSE metric 7061.0869
Time taken for 1 epoch: 25.551324367523193 secs

Epoch 146 batch 0 train Loss 5.2987 test Loss 5.3305 with MSE metric 7350.6748
Epoch 146 batch 100 train Loss 5.4136 test Loss 5.4091 with MSE metric 9053.6875
Epoch 146 batch 200 train Loss 5.2763 test Loss 5

Epoch 171 batch 200 train Loss 5.2752 test Loss 5.3184 with MSE metric 7024.9180
Time taken for 1 epoch: 25.895297050476074 secs

Epoch 172 batch 0 train Loss 5.3632 test Loss 5.2931 with MSE metric 8359.8232
Epoch 172 batch 100 train Loss 5.3189 test Loss 5.3237 with MSE metric 7640.3486
Epoch 172 batch 200 train Loss 5.3363 test Loss 5.3022 with MSE metric 7923.2402
Time taken for 1 epoch: 25.682255029678345 secs

Epoch 173 batch 0 train Loss 5.2615 test Loss 5.3010 with MSE metric 6773.4639
Epoch 173 batch 100 train Loss 5.2832 test Loss 5.3077 with MSE metric 7104.6567
Epoch 173 batch 200 train Loss 5.2452 test Loss 5.3224 with MSE metric 6584.0088
Time taken for 1 epoch: 26.326202869415283 secs

Epoch 174 batch 0 train Loss 5.2333 test Loss 5.3246 with MSE metric 6391.9243
Epoch 174 batch 100 train Loss 5.2138 test Loss 5.4214 with MSE metric 6118.4834
Epoch 174 batch 200 train Loss 5.3338 test Loss 5.3556 with MSE metric 7894.9746
Time taken for 1 epoch: 25.731635093688965 secs



Time taken for 1 epoch: 26.607697010040283 secs

Epoch 200 batch 0 train Loss 5.3449 test Loss 5.4078 with MSE metric 7988.3018
Epoch 200 batch 100 train Loss 5.2565 test Loss 5.4328 with MSE metric 6720.6626
Epoch 200 batch 200 train Loss 5.2096 test Loss 5.3525 with MSE metric 6099.7471
Time taken for 1 epoch: 26.57576608657837 secs

Epoch 201 batch 0 train Loss 5.3040 test Loss 5.2785 with MSE metric 7436.6777
Epoch 201 batch 100 train Loss 5.3422 test Loss 5.3292 with MSE metric 8000.6528
Epoch 201 batch 200 train Loss 5.2788 test Loss 5.3307 with MSE metric 7056.0127
Time taken for 1 epoch: 26.602109909057617 secs

Epoch 202 batch 0 train Loss 5.2766 test Loss 5.3853 with MSE metric 6988.5352
Epoch 202 batch 100 train Loss 5.3545 test Loss 5.3651 with MSE metric 8217.6328
Epoch 202 batch 200 train Loss 5.3293 test Loss 5.3244 with MSE metric 7790.2935
Time taken for 1 epoch: 26.600423097610474 secs

Epoch 203 batch 0 train Loss 5.3108 test Loss 5.4500 with MSE metric 7542.4604
Epo

Epoch 228 batch 100 train Loss 5.2827 test Loss 5.3778 with MSE metric 7129.8027
Epoch 228 batch 200 train Loss 5.3432 test Loss 5.2428 with MSE metric 8041.5356
Time taken for 1 epoch: 26.323616981506348 secs

Epoch 229 batch 0 train Loss 5.2660 test Loss 5.3340 with MSE metric 6868.2856
Epoch 229 batch 100 train Loss 5.3116 test Loss 5.3820 with MSE metric 7555.7383
Epoch 229 batch 200 train Loss 5.3031 test Loss 5.3482 with MSE metric 7423.9746
Time taken for 1 epoch: 26.67631196975708 secs

Epoch 230 batch 0 train Loss 5.2470 test Loss 5.2811 with MSE metric 6577.9238
Epoch 230 batch 100 train Loss 5.3574 test Loss 5.2389 with MSE metric 8255.4170
Epoch 230 batch 200 train Loss 5.3420 test Loss 5.4372 with MSE metric 7999.0879
Time taken for 1 epoch: 26.78191590309143 secs

Epoch 231 batch 0 train Loss 5.3114 test Loss 5.2593 with MSE metric 7502.8765
Epoch 231 batch 100 train Loss 5.2784 test Loss 5.2600 with MSE metric 7011.4531
Epoch 231 batch 200 train Loss 5.2177 test Loss 5.3

Epoch 256 batch 200 train Loss 5.1455 test Loss 5.4848 with MSE metric 5265.2969
Time taken for 1 epoch: 26.530865907669067 secs

Epoch 257 batch 0 train Loss 5.4327 test Loss 5.3424 with MSE metric 9252.1309
Epoch 257 batch 100 train Loss 5.3090 test Loss 5.2705 with MSE metric 7504.7627
Epoch 257 batch 200 train Loss 5.3108 test Loss 5.2630 with MSE metric 7542.9102
Time taken for 1 epoch: 26.848063945770264 secs

Epoch 258 batch 0 train Loss 5.3651 test Loss 5.3498 with MSE metric 8338.7686
Epoch 258 batch 100 train Loss 5.3199 test Loss 5.3622 with MSE metric 7657.1255
Epoch 258 batch 200 train Loss 5.2952 test Loss 5.3230 with MSE metric 7310.7720
Time taken for 1 epoch: 26.51411199569702 secs

Epoch 259 batch 0 train Loss 5.2157 test Loss 5.3151 with MSE metric 6088.8916
Epoch 259 batch 100 train Loss 5.1974 test Loss 5.3478 with MSE metric 5967.9463
Epoch 259 batch 200 train Loss 5.3181 test Loss 5.2993 with MSE metric 7653.1401
Time taken for 1 epoch: 26.479319095611572 secs

E

Epoch 285 batch 100 train Loss 5.3868 test Loss 5.3579 with MSE metric 8731.7686
Epoch 285 batch 200 train Loss 5.4217 test Loss 5.2743 with MSE metric 9029.5938
Time taken for 1 epoch: 23.031870126724243 secs

Epoch 286 batch 0 train Loss 5.2674 test Loss 5.4392 with MSE metric 6855.3926
Epoch 286 batch 100 train Loss 5.3344 test Loss 5.2938 with MSE metric 7867.7266
Epoch 286 batch 200 train Loss 5.2657 test Loss 5.3238 with MSE metric 6818.1133
Time taken for 1 epoch: 23.02661967277527 secs

Epoch 287 batch 0 train Loss 5.3316 test Loss 5.3596 with MSE metric 7850.6460
Epoch 287 batch 100 train Loss 5.3401 test Loss 5.3464 with MSE metric 7998.5684
Epoch 287 batch 200 train Loss 5.2437 test Loss 5.2596 with MSE metric 6566.4326
Time taken for 1 epoch: 23.06029987335205 secs

Epoch 288 batch 0 train Loss 5.3233 test Loss 5.4849 with MSE metric 7653.8105
Epoch 288 batch 100 train Loss 5.3475 test Loss 5.2819 with MSE metric 8117.6426
Epoch 288 batch 200 train Loss 5.2052 test Loss 5.3

Time taken for 1 epoch: 24.494733095169067 secs

Epoch 314 batch 0 train Loss 5.2390 test Loss 5.4291 with MSE metric 6530.1152
Epoch 314 batch 100 train Loss 5.3744 test Loss 5.4189 with MSE metric 8412.9902
Epoch 314 batch 200 train Loss 5.3735 test Loss 5.4526 with MSE metric 8464.1230
Time taken for 1 epoch: 23.71581196784973 secs

Epoch 315 batch 0 train Loss 5.2913 test Loss 5.4041 with MSE metric 7253.0059
Epoch 315 batch 100 train Loss 5.3261 test Loss 5.3346 with MSE metric 7742.5107
Epoch 315 batch 200 train Loss 5.3984 test Loss 5.4105 with MSE metric 8832.5391
Time taken for 1 epoch: 24.06606101989746 secs

Epoch 316 batch 0 train Loss 5.3407 test Loss 5.4171 with MSE metric 7867.4609
Epoch 316 batch 100 train Loss 5.2997 test Loss 5.3442 with MSE metric 7377.8086
Epoch 316 batch 200 train Loss 5.3054 test Loss 5.3280 with MSE metric 7436.1182
Time taken for 1 epoch: 23.755082845687866 secs

Epoch 317 batch 0 train Loss 5.2895 test Loss 5.3110 with MSE metric 7223.5830
Epoc

Epoch 342 batch 100 train Loss 5.3616 test Loss 5.3792 with MSE metric 8321.7988
Epoch 342 batch 200 train Loss 5.2658 test Loss 5.2171 with MSE metric 6831.0264
Time taken for 1 epoch: 23.395739793777466 secs

Epoch 343 batch 0 train Loss 5.3072 test Loss 5.3200 with MSE metric 7484.1992
Epoch 343 batch 100 train Loss 5.3500 test Loss 5.2942 with MSE metric 8158.5601
Epoch 343 batch 200 train Loss 5.3112 test Loss 5.3119 with MSE metric 7549.9312
Time taken for 1 epoch: 23.60182762145996 secs

Epoch 344 batch 0 train Loss 5.2042 test Loss 5.3921 with MSE metric 5969.0522
Epoch 344 batch 100 train Loss 5.3262 test Loss 5.4200 with MSE metric 7753.6377
Epoch 344 batch 200 train Loss 5.2643 test Loss 5.3070 with MSE metric 6872.8809
Time taken for 1 epoch: 23.63816499710083 secs

Epoch 345 batch 0 train Loss 5.3065 test Loss 5.3583 with MSE metric 7474.7173
Epoch 345 batch 100 train Loss 5.2342 test Loss 5.3397 with MSE metric 6446.3955
Epoch 345 batch 200 train Loss 5.2785 test Loss 5.3

Time taken for 1 epoch: 24.312705993652344 secs

Epoch 371 batch 0 train Loss 5.2393 test Loss 5.3640 with MSE metric 6491.9272
Epoch 371 batch 100 train Loss 5.3382 test Loss 5.3831 with MSE metric 7960.5757
Epoch 371 batch 200 train Loss 5.2739 test Loss 5.4045 with MSE metric 6985.0107
Time taken for 1 epoch: 23.379619121551514 secs

Epoch 372 batch 0 train Loss 5.2601 test Loss 5.2996 with MSE metric 6664.7207
Epoch 372 batch 100 train Loss 5.2934 test Loss 5.3316 with MSE metric 7284.2847
Epoch 372 batch 200 train Loss 5.3184 test Loss 5.3281 with MSE metric 7652.9473
Time taken for 1 epoch: 25.820667266845703 secs

Epoch 373 batch 0 train Loss 5.3073 test Loss 5.2161 with MSE metric 7489.0615
Epoch 373 batch 100 train Loss 5.2794 test Loss 5.3763 with MSE metric 7083.6108
Epoch 373 batch 200 train Loss 5.2902 test Loss 5.4384 with MSE metric 7233.6099
Time taken for 1 epoch: 23.016315937042236 secs

Epoch 374 batch 0 train Loss 5.3406 test Loss 5.2704 with MSE metric 8002.8730
Ep

Epoch 399 batch 100 train Loss 5.3569 test Loss 5.3377 with MSE metric 8271.0723
Epoch 399 batch 200 train Loss 5.2587 test Loss 5.3148 with MSE metric 6751.7808
Time taken for 1 epoch: 22.78200912475586 secs

Epoch 400 batch 0 train Loss 5.2319 test Loss 5.2675 with MSE metric 6429.4756
Epoch 400 batch 100 train Loss 5.4003 test Loss 5.3196 with MSE metric 8916.1562
Epoch 400 batch 200 train Loss 5.2827 test Loss 5.3708 with MSE metric 7130.4609
Time taken for 1 epoch: 22.81319284439087 secs

Epoch 401 batch 0 train Loss 5.3000 test Loss 5.3546 with MSE metric 7380.0967
Epoch 401 batch 100 train Loss 5.2042 test Loss 5.4125 with MSE metric 5983.3613
Epoch 401 batch 200 train Loss 5.2826 test Loss 5.4023 with MSE metric 7104.2920
Time taken for 1 epoch: 22.747803926467896 secs

Epoch 402 batch 0 train Loss 5.3460 test Loss 5.4022 with MSE metric 8093.0020
Epoch 402 batch 100 train Loss 5.2858 test Loss 5.3099 with MSE metric 7171.9082
Epoch 402 batch 200 train Loss 5.2822 test Loss 5.3

Time taken for 1 epoch: 22.586896896362305 secs

Epoch 428 batch 0 train Loss 5.3534 test Loss 5.2841 with MSE metric 8167.2373
Epoch 428 batch 100 train Loss 5.3101 test Loss 5.4304 with MSE metric 7533.4023
Epoch 428 batch 200 train Loss 5.2828 test Loss 5.3584 with MSE metric 7103.8730
Time taken for 1 epoch: 22.27204418182373 secs

Epoch 429 batch 0 train Loss 5.2722 test Loss 5.3948 with MSE metric 6980.4922
Epoch 429 batch 100 train Loss 5.2714 test Loss 5.3702 with MSE metric 6963.9951
Epoch 429 batch 200 train Loss 5.3332 test Loss 5.3722 with MSE metric 7869.1992
Time taken for 1 epoch: 22.03881287574768 secs

Epoch 430 batch 0 train Loss 5.2826 test Loss 5.2925 with MSE metric 7129.0176
Epoch 430 batch 100 train Loss 5.3509 test Loss 5.4483 with MSE metric 8161.0586
Epoch 430 batch 200 train Loss 5.3127 test Loss 5.3917 with MSE metric 7552.7280
Time taken for 1 epoch: 22.216126918792725 secs

Epoch 431 batch 0 train Loss 5.3011 test Loss 5.2509 with MSE metric 7389.9873
Epoc

Epoch 456 batch 100 train Loss 5.3067 test Loss 5.4482 with MSE metric 7481.0298
Epoch 456 batch 200 train Loss 5.2987 test Loss 5.3478 with MSE metric 7362.8682
Time taken for 1 epoch: 22.934611082077026 secs

Epoch 457 batch 0 train Loss 5.4416 test Loss 5.5021 with MSE metric 9599.8525
Epoch 457 batch 100 train Loss 5.2699 test Loss 5.4269 with MSE metric 6949.8408
Epoch 457 batch 200 train Loss 5.4079 test Loss 5.2897 with MSE metric 8940.8799
Time taken for 1 epoch: 23.09203791618347 secs

Epoch 458 batch 0 train Loss 5.2831 test Loss 5.2672 with MSE metric 7084.2646
Epoch 458 batch 100 train Loss 5.3507 test Loss 5.3522 with MSE metric 8158.6836
Epoch 458 batch 200 train Loss 5.3422 test Loss 5.3558 with MSE metric 8027.7637
Time taken for 1 epoch: 23.04433798789978 secs

Epoch 459 batch 0 train Loss 5.3444 test Loss 5.4092 with MSE metric 8039.5596
Epoch 459 batch 100 train Loss 5.3554 test Loss 5.3728 with MSE metric 8232.0977
Epoch 459 batch 200 train Loss 5.2736 test Loss 5.3

Epoch 484 batch 200 train Loss 5.2420 test Loss 5.3028 with MSE metric 6458.3608
Time taken for 1 epoch: 23.660197019577026 secs

Epoch 485 batch 0 train Loss 5.3051 test Loss 5.3489 with MSE metric 7457.2070
Epoch 485 batch 100 train Loss 5.2350 test Loss 5.3045 with MSE metric 6478.1973
Epoch 485 batch 200 train Loss 5.3132 test Loss 5.3144 with MSE metric 7576.9502
Time taken for 1 epoch: 24.92916989326477 secs

Epoch 486 batch 0 train Loss 5.2404 test Loss 5.4326 with MSE metric 6519.9209
Epoch 486 batch 100 train Loss 5.2721 test Loss 5.3489 with MSE metric 6980.4683
Epoch 486 batch 200 train Loss 5.2665 test Loss 5.3194 with MSE metric 6896.2842
Time taken for 1 epoch: 23.154803037643433 secs

Epoch 487 batch 0 train Loss 5.3395 test Loss 5.3868 with MSE metric 7960.2202
Epoch 487 batch 100 train Loss 5.2903 test Loss 5.3609 with MSE metric 7230.6240
Epoch 487 batch 200 train Loss 5.3537 test Loss 5.2279 with MSE metric 8185.1792
Time taken for 1 epoch: 23.24729585647583 secs

Ep

In [None]:
1 - (0.0165 / sum((tar[:, 5] - np.mean(tar[:, 5]))**2) / len(tar[:, 5]))

In [None]:
tar - np.mean(tar, 0)

In [None]:
tar.shape

In [None]:
np.mean(tar[:, 0])

In [None]:
sum((tar[:, 0] - np.mean(tar[:, 0]))**2 )/ 10000

In [None]:
sum(sum((tar - np.mean(tar))**2)) / (tar.shape[0] * tar.shape[1])

In [None]:
pos = df_te[560, :].reshape(1, -1)

In [None]:
tar = df_te[561, :39].reshape(1, -1)

In [None]:
df_te[561, :]

In [None]:
a = inference(pos, tar, 20)

In [None]:
with matplotlib.rc_context({'figure.figsize': [10,2.5]}):
    plt.scatter(pos[:, :39], tar[:, :39], c='black')
    plt.scatter(pos[:, 39:58], a[39:])
    plt.scatter(pos[:, 39:58], df_te[561, 39:58], c='red')

In [None]:
# tf.data.Dataset(tf.Tensor(pad_pos_tr, value_index = 0 , dtype = tf.float32))