In [1]:
from model import fox_model, losses, dot_prod_attention
from data import data_generation, batch_creator, gp_kernels
from keras.callbacks import ModelCheckpoint
from helpers import helpers, masks
from inference import infer
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib 
import time
import keras

Using TensorFlow backend.


In [2]:
save_dir = '/Users/omernivron/Downloads/GPT_fox'

In [3]:
df = np.load('/Users/omernivron/Downloads/fnr.npy')

In [4]:
# 0 = foxes
# 1 = rabbits
# 2 = time
# 3 = foxes_tokens
# 4 = rabbits_tokens (token id 3 means rabbit_ode)

In [5]:
t = df[2::5]
f = df[0::5]; r = df[1::5]
f_token = df[3::5]; r_token = df[4::5] 

In [6]:
pad_pos_tr = t[:800000]; f_tr = f[:800000]; r_tr = r[:800000]; f_token_tr = f_token[:800000]; r_token_tr = r_token[:800000]
pad_pos_te = t[800000:]; f_te = f[800000:]; r_te = r[800000:]; f_token_te = f_token[800000:]; r_token_te = r_token[800000:]

In [7]:
pad_pos_tr = np.repeat(pad_pos_tr, 2, axis = 0)
pad_pos_te = np.repeat(pad_pos_te, 2, axis = 0)

In [8]:
tar_tr = np.concatenate((f_tr, r_tr), axis = 0); tar_te = np.concatenate((f_te, r_te), axis = 0)
token_tr = np.concatenate((f_token_tr, r_token_tr), axis = 0); token_te = np.concatenate((f_token_te, r_token_te), axis = 0)


In [9]:
pp = masks.position_mask(pad_pos_tr)
pp_te = masks.position_mask(pad_pos_te)

In [27]:
loss_object = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
m_tr = tf.keras.metrics.Mean()
m_te = tf.keras.metrics.Mean()

In [28]:
@tf.function
def train_step(token_pos, time_pos, tar, pos_mask):
    '''
    A typical train step function for TF2. Elements which we wish to track their gradient
    has to be inside the GradientTape() clause. see (1) https://www.tensorflow.org/guide/migrate 
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask (np array): see description in position_mask function
    ------------------    
    '''
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        pred, pred_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
#         print('pred: ')
#         tf.print(pred_sig)

        loss, mse, mask = losses.loss_function(tar_real, pred, pred_sig)


    gradients = tape.gradient(loss, decoder.trainable_variables)
#     tf.print(gradients)
# Ask the optimizer to apply the processed gradients.
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    m_tr.update_state(mse, mask)
#     b = decoder.trainable_weights[0]
#     tf.print(tf.reduce_mean(b))
    return tar_inp, tar_real, pred, pred_sig, combined_mask_tar

In [22]:
@tf.function
def test_step(token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    
    ---------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask_te (np array): see description in position_mask function
    ---------------
    
    '''
    tar_inp_te = tar_te[:, :-1]
    tar_real_te = tar_te[:, 1:]
    combined_mask_tar_te = masks.create_masks(tar_inp_te)
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
    pred, pred_sig = decoder(token_pos_te, time_pos_te, tar_inp_te, False, pos_mask_te, combined_mask_tar_te)
#     tf.print(tf.math.reduce_max(pred_sig))
    t_loss, t_mse, t_mask = losses.loss_function(tar_real_te, pred, pred_sig)
    tf.print(t_loss)
    test_loss(t_loss)
    m_te.update_state(t_mse, t_mask)
    return tar_real_te, pred, pred_sig

In [23]:
tf.keras.backend.set_floatx('float64')

In [24]:
if __name__ == '__main__':
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam()
    decoder = fox_model.Decoder(16)
    EPOCHS = 1500
    batch_s  = 15
    run = 0; step = 0
    num_batches = int(tar_tr.shape[0] / batch_s)
    tf.random.set_seed(1)    
    checkpoint = tf.train.Checkpoint(optimizer = optimizer_c, model = decoder)
    main_folder = "/Users/omernivron/Downloads/GPT_fox/ckpt/check_"
    folder = main_folder + str(run); helpers.mkdir(folder)

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr , batch_pos_mask, _ = batch_creator.create_batch_foxes(token_tr, pad_pos_tr, tar_tr, pp)
                # batch_tar_tr shape := 128 X 59 = (batch_size, max_seq_len)
                # batch_pos_tr shape := 128 X 59 = (batch_size, max_seq_len)
                tar_inp, tar_real, pred, pred_sig, combined_mask_tar = train_step(batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                if batch_n % 50 == 0:
#                     batch_tok_pos_te, batch_tim_pos_te, batch_tar_te , batch_pos_mask_te, _ = batch_creator.create_batch_foxes(token_te, pad_pos_te, tar_te, pp_te)
#                     tar_real_te, pred, pred_sig = test_step(batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
#                     helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
#                     checkpoint.save(folder + '/')
                step += 1

            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Epoch 0 batch 0 train Loss 9778.1563 test Loss 0.0000 with MSE metric 9730.3418
Time taken for 1 epoch: 3.241412878036499 secs

Epoch 1 batch 0 train Loss 7238.7994 test Loss 0.0000 with MSE metric 7495.8036
Time taken for 1 epoch: 2.1934361457824707 secs

Epoch 2 batch 0 train Loss 6139.8840 test Loss 0.0000 with MSE metric 6620.6549
Time taken for 1 epoch: 2.17579984664917 secs

Epoch 3 batch 0 train Loss 5615.7915 test Loss 0.0000 with MSE metric 6848.3023
Time taken for 1 epoch: 2.213737964630127 secs

Epoch 4 batch 0 train Loss 5035.6222 test Loss 0.0000 with MSE metric 7232.3076
Time taken for 1 epoch: 2.173903226852417 secs

Epoch 5 batch 0 train Loss 4400.4060 test Loss 0.0000 with MSE metric 7260.6828
Time taken for 1 epoch: 2.1973352432250977 secs

Epoch 6 batch 0 train Loss 3818.8771 test Loss 0.0000 with MSE metric 7139.3505
Time taken for 1 epoch: 2.173475980758667 secs

Epoch 7 batch 0 train Loss 3377.3323 test Loss 0.0000 with MSE metric 7320.3442
Time tak

Epoch 64 batch 0 train Loss 466.4996 test Loss 0.0000 with MSE metric 4694.6869
Time taken for 1 epoch: 2.1731910705566406 secs

Epoch 65 batch 0 train Loss 459.9390 test Loss 0.0000 with MSE metric 4719.5069
Time taken for 1 epoch: 2.1990132331848145 secs

Epoch 66 batch 0 train Loss 453.5481 test Loss 0.0000 with MSE metric 4721.3711
Time taken for 1 epoch: 2.1838648319244385 secs

Epoch 67 batch 0 train Loss 447.2627 test Loss 0.0000 with MSE metric 4711.5311
Time taken for 1 epoch: 2.1827070713043213 secs

Epoch 68 batch 0 train Loss 441.1226 test Loss 0.0000 with MSE metric 4690.5008
Time taken for 1 epoch: 2.195844888687134 secs

Epoch 69 batch 0 train Loss 435.1559 test Loss 0.0000 with MSE metric 4703.4478
Time taken for 1 epoch: 2.1923351287841797 secs

Epoch 70 batch 0 train Loss 429.3686 test Loss 0.0000 with MSE metric 4700.4374
Time taken for 1 epoch: 2.2219629287719727 secs

Epoch 71 batch 0 train Loss 423.7252 test Loss 0.0000 with MSE metric 4704.2592
Time taken for 1 e

Epoch 128 batch 0 train Loss 241.9888 test Loss 0.0000 with MSE metric 4813.6963
Time taken for 1 epoch: 2.2110350131988525 secs

Epoch 129 batch 0 train Loss 240.2025 test Loss 0.0000 with MSE metric 4812.0421
Time taken for 1 epoch: 2.1772427558898926 secs

Epoch 130 batch 0 train Loss 238.4435 test Loss 0.0000 with MSE metric 4810.4142
Time taken for 1 epoch: 2.187378168106079 secs

Epoch 131 batch 0 train Loss 236.7071 test Loss 0.0000 with MSE metric 4806.8122
Time taken for 1 epoch: 2.1795690059661865 secs

Epoch 132 batch 0 train Loss 234.9964 test Loss 0.0000 with MSE metric 4786.6990
Time taken for 1 epoch: 2.1698381900787354 secs

Epoch 133 batch 0 train Loss 233.3169 test Loss 0.0000 with MSE metric 4798.1238
Time taken for 1 epoch: 2.190953016281128 secs

Epoch 134 batch 0 train Loss 231.6607 test Loss 0.0000 with MSE metric 4815.2159
Time taken for 1 epoch: 2.164705991744995 secs

Epoch 135 batch 0 train Loss 230.0235 test Loss 0.0000 with MSE metric 4800.0179
Time taken f

Epoch 192 batch 0 train Loss 164.7428 test Loss 0.0000 with MSE metric 4742.2826
Time taken for 1 epoch: 2.2510969638824463 secs

Epoch 193 batch 0 train Loss 163.9398 test Loss 0.0000 with MSE metric 4752.6094
Time taken for 1 epoch: 2.34002685546875 secs

Epoch 194 batch 0 train Loss 163.1388 test Loss 0.0000 with MSE metric 4741.7842
Time taken for 1 epoch: 2.359250783920288 secs

Epoch 195 batch 0 train Loss 162.3478 test Loss 0.0000 with MSE metric 4735.2095
Time taken for 1 epoch: 2.54773211479187 secs

Epoch 196 batch 0 train Loss 161.5628 test Loss 0.0000 with MSE metric 4730.7143
Time taken for 1 epoch: 2.5318751335144043 secs

Epoch 197 batch 0 train Loss 160.7925 test Loss 0.0000 with MSE metric 4734.9205
Time taken for 1 epoch: 2.4583370685577393 secs

Epoch 198 batch 0 train Loss 160.0208 test Loss 0.0000 with MSE metric 4722.9138
Time taken for 1 epoch: 2.4237418174743652 secs

Epoch 199 batch 0 train Loss 159.2588 test Loss 0.0000 with MSE metric 4714.9137
Time taken for

Epoch 256 batch 0 train Loss 125.5917 test Loss 0.0000 with MSE metric 4640.9786
Time taken for 1 epoch: 2.380682945251465 secs

Epoch 257 batch 0 train Loss 125.1326 test Loss 0.0000 with MSE metric 4639.7861
Time taken for 1 epoch: 2.351369857788086 secs

Epoch 258 batch 0 train Loss 124.6765 test Loss 0.0000 with MSE metric 4640.1807
Time taken for 1 epoch: 2.2309062480926514 secs

Epoch 259 batch 0 train Loss 124.2224 test Loss 0.0000 with MSE metric 4635.9967
Time taken for 1 epoch: 2.2347140312194824 secs

Epoch 260 batch 0 train Loss 123.7740 test Loss 0.0000 with MSE metric 4637.5249
Time taken for 1 epoch: 2.195651054382324 secs

Epoch 261 batch 0 train Loss 123.3256 test Loss 0.0000 with MSE metric 4631.6629
Time taken for 1 epoch: 2.245823860168457 secs

Epoch 262 batch 0 train Loss 122.8812 test Loss 0.0000 with MSE metric 4627.0031
Time taken for 1 epoch: 2.1795856952667236 secs

Epoch 263 batch 0 train Loss 122.4407 test Loss 0.0000 with MSE metric 4622.5440
Time taken fo

Epoch 320 batch 0 train Loss 101.7816 test Loss 0.0000 with MSE metric 4556.4266
Time taken for 1 epoch: 2.1179027557373047 secs

Epoch 321 batch 0 train Loss 101.4832 test Loss 0.0000 with MSE metric 4557.9267
Time taken for 1 epoch: 2.097100019454956 secs

Epoch 322 batch 0 train Loss 101.1863 test Loss 0.0000 with MSE metric 4558.1357
Time taken for 1 epoch: 2.090707302093506 secs

Epoch 323 batch 0 train Loss 100.8910 test Loss 0.0000 with MSE metric 4554.8647
Time taken for 1 epoch: 2.097384214401245 secs

Epoch 324 batch 0 train Loss 100.5967 test Loss 0.0000 with MSE metric 4549.5707
Time taken for 1 epoch: 2.1591339111328125 secs

Epoch 325 batch 0 train Loss 100.3051 test Loss 0.0000 with MSE metric 4547.9801
Time taken for 1 epoch: 2.1055281162261963 secs

Epoch 326 batch 0 train Loss 100.0156 test Loss 0.0000 with MSE metric 4547.1328
Time taken for 1 epoch: 2.082562208175659 secs

Epoch 327 batch 0 train Loss 99.7271 test Loss 0.0000 with MSE metric 4543.4950
Time taken for

Epoch 384 batch 0 train Loss 85.7445 test Loss 0.0000 with MSE metric 4430.0921
Time taken for 1 epoch: 2.09932017326355 secs

Epoch 385 batch 0 train Loss 85.5367 test Loss 0.0000 with MSE metric 4432.9507
Time taken for 1 epoch: 2.0879459381103516 secs

Epoch 386 batch 0 train Loss 85.3289 test Loss 0.0000 with MSE metric 4430.5888
Time taken for 1 epoch: 2.07283616065979 secs

Epoch 387 batch 0 train Loss 85.1230 test Loss 0.0000 with MSE metric 4432.2023
Time taken for 1 epoch: 2.094473361968994 secs

Epoch 388 batch 0 train Loss 84.9178 test Loss 0.0000 with MSE metric 4432.3165
Time taken for 1 epoch: 2.0779871940612793 secs

Epoch 389 batch 0 train Loss 84.7126 test Loss 0.0000 with MSE metric 4426.2870
Time taken for 1 epoch: 2.0440289974212646 secs

Epoch 390 batch 0 train Loss 84.5084 test Loss 0.0000 with MSE metric 4421.7598
Time taken for 1 epoch: 2.059885025024414 secs

Epoch 391 batch 0 train Loss 84.3085 test Loss 0.0000 with MSE metric 4422.5824
Time taken for 1 epoch:

Epoch 448 batch 0 train Loss 74.4692 test Loss 0.0000 with MSE metric 4366.6950
Time taken for 1 epoch: 2.2143821716308594 secs

Epoch 449 batch 0 train Loss 74.3159 test Loss 0.0000 with MSE metric 4366.8722
Time taken for 1 epoch: 2.3053297996520996 secs

Epoch 450 batch 0 train Loss 74.1651 test Loss 0.0000 with MSE metric 4367.7258
Time taken for 1 epoch: 2.431140184402466 secs

Epoch 451 batch 0 train Loss 74.0138 test Loss 0.0000 with MSE metric 4373.1165
Time taken for 1 epoch: 2.3329808712005615 secs

Epoch 452 batch 0 train Loss 73.8625 test Loss 0.0000 with MSE metric 4373.5946
Time taken for 1 epoch: 2.116121768951416 secs

Epoch 453 batch 0 train Loss 73.7128 test Loss 0.0000 with MSE metric 4372.7648
Time taken for 1 epoch: 2.145453929901123 secs

Epoch 454 batch 0 train Loss 73.5630 test Loss 0.0000 with MSE metric 4366.7246
Time taken for 1 epoch: 2.1397459506988525 secs

Epoch 455 batch 0 train Loss 73.4157 test Loss 0.0000 with MSE metric 4367.6265
Time taken for 1 epo

Epoch 512 batch 0 train Loss 65.8410 test Loss 0.0000 with MSE metric 4322.4747
Time taken for 1 epoch: 2.0990610122680664 secs

Epoch 513 batch 0 train Loss 65.7239 test Loss 0.0000 with MSE metric 4324.0972
Time taken for 1 epoch: 2.1070730686187744 secs

Epoch 514 batch 0 train Loss 65.6060 test Loss 0.0000 with MSE metric 4323.0146
Time taken for 1 epoch: 2.072208881378174 secs

Epoch 515 batch 0 train Loss 65.4887 test Loss 0.0000 with MSE metric 4319.9809
Time taken for 1 epoch: 2.083885908126831 secs

Epoch 516 batch 0 train Loss 65.3727 test Loss 0.0000 with MSE metric 4320.8481
Time taken for 1 epoch: 2.091214179992676 secs

Epoch 517 batch 0 train Loss 65.2578 test Loss 0.0000 with MSE metric 4327.0547
Time taken for 1 epoch: 2.0823071002960205 secs

Epoch 518 batch 0 train Loss 65.1421 test Loss 0.0000 with MSE metric 4324.5169
Time taken for 1 epoch: 2.0886600017547607 secs

Epoch 519 batch 0 train Loss 65.0272 test Loss 0.0000 with MSE metric 4326.6644
Time taken for 1 epo

Epoch 576 batch 0 train Loss 59.0995 test Loss 0.0000 with MSE metric 4273.6963
Time taken for 1 epoch: 2.1393589973449707 secs

Epoch 577 batch 0 train Loss 59.0059 test Loss 0.0000 with MSE metric 4272.6245
Time taken for 1 epoch: 2.1505792140960693 secs

Epoch 578 batch 0 train Loss 58.9130 test Loss 0.0000 with MSE metric 4273.8697
Time taken for 1 epoch: 2.1019508838653564 secs

Epoch 579 batch 0 train Loss 58.8202 test Loss 0.0000 with MSE metric 4274.7011
Time taken for 1 epoch: 2.1025660037994385 secs

Epoch 580 batch 0 train Loss 58.7280 test Loss 0.0000 with MSE metric 4275.7119
Time taken for 1 epoch: 2.093026638031006 secs

Epoch 581 batch 0 train Loss 58.6357 test Loss 0.0000 with MSE metric 4274.0007
Time taken for 1 epoch: 2.2002313137054443 secs

Epoch 582 batch 0 train Loss 58.5437 test Loss 0.0000 with MSE metric 4272.5940
Time taken for 1 epoch: 2.1981348991394043 secs

Epoch 583 batch 0 train Loss 58.4517 test Loss 0.0000 with MSE metric 4269.4436
Time taken for 1 e

Epoch 640 batch 0 train Loss 53.7272 test Loss 0.0000 with MSE metric 4215.7425
Time taken for 1 epoch: 2.089808702468872 secs

Epoch 641 batch 0 train Loss 53.6512 test Loss 0.0000 with MSE metric 4215.0159
Time taken for 1 epoch: 2.079240083694458 secs

Epoch 642 batch 0 train Loss 53.5755 test Loss 0.0000 with MSE metric 4214.1241
Time taken for 1 epoch: 2.087961196899414 secs

Epoch 643 batch 0 train Loss 53.4999 test Loss 0.0000 with MSE metric 4212.1755
Time taken for 1 epoch: 2.0876362323760986 secs

Epoch 644 batch 0 train Loss 53.4252 test Loss 0.0000 with MSE metric 4213.7628
Time taken for 1 epoch: 2.0877230167388916 secs

Epoch 645 batch 0 train Loss 53.3503 test Loss 0.0000 with MSE metric 4212.7380
Time taken for 1 epoch: 2.0604920387268066 secs

Epoch 646 batch 0 train Loss 53.2756 test Loss 0.0000 with MSE metric 4211.3296
Time taken for 1 epoch: 2.086854934692383 secs

Epoch 647 batch 0 train Loss 53.2012 test Loss 0.0000 with MSE metric 4209.3081
Time taken for 1 epoc

Epoch 704 batch 0 train Loss 49.3179 test Loss 0.0000 with MSE metric 4159.9947
Time taken for 1 epoch: 2.1718289852142334 secs

Epoch 705 batch 0 train Loss 49.2552 test Loss 0.0000 with MSE metric 4160.1883
Time taken for 1 epoch: 2.1464309692382812 secs

Epoch 706 batch 0 train Loss 49.1926 test Loss 0.0000 with MSE metric 4158.9464
Time taken for 1 epoch: 2.08197021484375 secs

Epoch 707 batch 0 train Loss 49.1304 test Loss 0.0000 with MSE metric 4158.3161
Time taken for 1 epoch: 2.098510980606079 secs

Epoch 708 batch 0 train Loss 49.0684 test Loss 0.0000 with MSE metric 4158.7097
Time taken for 1 epoch: 2.077913284301758 secs

Epoch 709 batch 0 train Loss 49.0064 test Loss 0.0000 with MSE metric 4158.0071
Time taken for 1 epoch: 2.303518056869507 secs

Epoch 710 batch 0 train Loss 48.9445 test Loss 0.0000 with MSE metric 4155.4872
Time taken for 1 epoch: 2.1376078128814697 secs

Epoch 711 batch 0 train Loss 48.8830 test Loss 0.0000 with MSE metric 4154.5512
Time taken for 1 epoch

In [26]:
tok_pos_tr[6, :]

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 2., 2., 2., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [28]:
e = decoder.e1
weights = e.get_weights()[0]

In [25]:
pred[3, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([23.84272654, 24.17451531, 24.50822531, 24.84380031, 25.18118373,
       25.52031876, 25.86114843, 26.20361569, 26.5476635 , 26.89323492,
       27.24027319, 27.58872182, 27.93852465, 28.2896259 , 28.64197031,
       28.99550315, 29.3501703 , 29.3338399 , 29.33422543, 29.34318017,
       29.36300396, 29.39202324, 29.42885157, 29.4723305 , 29.52148456,
       29.57548654, 29.63363044, 29.69531019, 29.7600027 , 29.82725433,
       29.89666989, 29.96790375, 30.0406525 , 30.11464894, 30.18965701,
       30.26546766, 30.34189531, 30.41877483, 30.49595909, 30.57331675,
       30.65073046, 30.40816779, 29.93889317, 29.46667307, 29.00188682,
       28.57246277, 28.17545528, 27.781603  , 27.37361636, 26.9439041 ,
       26.70652619, 26.48007823, 26.26382102, 26.05708063, 25.85924133,
       25.66973948, 25.48805809, 25.3137221 , 25.14629422, 24.98537123,
       24.83058072, 24.68157821, 24.53804459, 24.3996838 , 24.26622078,
       24.137399

In [88]:
combined_mask_tar[1, :, 44]

<tf.Tensor: shape=(148,), dtype=float32, numpy=
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)>

In [26]:
tar_inp[3, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([53.        , 53.7314    , 54.4575573 , 55.17807577, 55.89257113,
       56.60067168, 57.30201906, 57.99626881, 58.68309099, 59.36217067,
       60.03320836, 60.69592045, 61.35003944, 61.99531426, 62.63151042,
       63.25841015, 63.87581247, 64.48353318, 65.08140481, 65.66927651,
       66.24701391, 66.81449891, 67.37162937, 67.91831891, 68.45449649,
       68.98010607, 69.4951062 , 69.99946958, 70.49318263, 70.97624493,
       71.4486688 , 71.91047871, 72.3617108 , 72.8024123 , 73.232641  ,
       73.65246466, 74.0619605 , 74.46121461, 74.8503214 , 75.22938304,
       75.59850895, 75.95781523, 53.        , 50.        , 48.        ,
       49.        , 50.        , 48.        , 44.        , 42.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.      

In [95]:
tar_real[1, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([170.5377    , 158.71284099, 147.52172089, 136.95626314,
       127.00456081, 117.65142542, 108.87892606, 100.66690691,
        92.99347403,  85.83544469,  79.16875492,  72.96882289,
        67.21086754,  61.87018312,  56.92237156,  52.34353511,
        48.11043257,  44.2006024 ,  40.59245631,  37.26534692,
       183.        ,  23.        ,   5.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
       

In [90]:
pred[8, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([-0.03583945, -0.03630679, -0.03658146, -0.03666347, -0.03655284,
       -0.0362496 , -0.03575379, -0.03506542, -0.03418455, -0.0331112 ,
       -0.03184543, -0.03038729, -0.02873682, -0.02689408, -0.02485913,
       -0.02263204, -0.02021287, -0.01760169, -0.01479858, -0.01180361,
       -0.01453558, -0.01682202, -0.01887398, -0.02066269, -0.02222955,
       -0.02360825, -0.02482641, -0.02590686, -0.02686857, -0.02772741,
       -0.02849675, -0.02918788, -0.02981041, -0.03037256, -0.03088137,
       -0.03134291, -0.03176244, -0.03214451, -0.0324931 , -0.0328117 ,
       -0.03310334, -0.03337072, -0.0336162 , -0.03384188, -0.03404962,
       -0.03424107, -0.0344177 , -0.03458083, -0.03473165, -0.03487121,
       -0.03500047, -0.03512027, -0.0352314 , -0.03533456, -0.03543038,
       -0.03551943, -0.03560224, -0.03567929, -0.03575101, -0.03581778,
       -0.03587999, -0.03593796, -0.03589755, -0.03584261, -0.03577452,
       -0.035694

In [75]:
batch_pos_mask.shape

(32, 148, 149, 149)

In [41]:
tar_real_te[30, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.])>

In [43]:
pred[1, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([-0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.046835

In [46]:
pred_sig[2, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753