In [1]:
from model import fox_model, losses, dot_prod_attention
from data import data_generation, batch_creator, gp_kernels
from keras.callbacks import ModelCheckpoint
from helpers import helpers, masks
from inference import infer
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import matplotlib 
import time
import keras

Using TensorFlow backend.


In [2]:
save_dir = '/Users/omernivron/Downloads/GPT_fox'

In [3]:
df = np.load('/Users/omernivron/Downloads/fnr.npy')

In [4]:
# 0 = foxes
# 1 = rabbits
# 2 = time
# 3 = foxes_tokens
# 4 = rabbits_tokens (token id 3 means rabbit_ode)

In [5]:
t = df[2::5]
f = df[0::5]; r = df[1::5]
f_token = df[3::5]; r_token = df[4::5] 

In [6]:
pad_pos_tr = t[:80]; f_tr = f[:80]; r_tr = r[:80]; f_token_tr = f_token[:80]; r_token_tr = r_token[:80]
pad_pos_te = t[80:]; f_te = f[80:]; r_te = r[80:]; f_token_te = f_token[80:]; r_token_te = r_token[80:]

In [7]:
pad_pos_tr = np.repeat(pad_pos_tr, 2, axis = 0)
pad_pos_te = np.repeat(pad_pos_te, 2, axis = 0)

In [8]:
tar_tr = np.concatenate((f_tr, r_tr), axis = 0); tar_te = np.concatenate((f_te, r_te), axis = 0)
token_tr = np.concatenate((f_token_tr, r_token_tr), axis = 0); token_te = np.concatenate((f_token_te, r_token_te), axis = 0)

In [9]:
pp = masks.position_mask(pad_pos_tr)
pp_te = masks.position_mask(pad_pos_te)

In [10]:
loss_object = tf.keras.losses.MeanSquaredError()
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
m_tr = tf.keras.metrics.Mean()
m_te = tf.keras.metrics.Mean()

In [11]:
@tf.function
def train_step(token_pos, time_pos, tar, pos_mask):
    '''
    A typical train step function for TF2. Elements which we wish to track their gradient
    has to be inside the GradientTape() clause. see (1) https://www.tensorflow.org/guide/migrate 
    (2) https://www.tensorflow.org/tutorials/quickstart/advanced
    ------------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask (np array): see description in position_mask function
    ------------------    
    '''
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    combined_mask_tar = masks.create_masks(tar_inp)
    with tf.GradientTape(persistent=True) as tape:
        pred, pred_sig = decoder(token_pos, time_pos, tar_inp, True, pos_mask, combined_mask_tar)
#         print('pred: ')
#         tf.print(pred_sig)

        loss, mse, mask = losses.loss_function(tar_real, pred, pred_sig)


    gradients = tape.gradient(loss, decoder.trainable_variables)
#     tf.print(gradients)
# Ask the optimizer to apply the processed gradients.
    optimizer_c.apply_gradients(zip(gradients, decoder.trainable_variables))
    train_loss(loss)
    m_tr.update_state(mse, mask)
#     b = decoder.trainable_weights[0]
#     tf.print(tf.reduce_mean(b))
    return tar_inp, tar_real, pred, pred_sig, combined_mask_tar

In [12]:
@tf.function
def test_step(token_pos_te, time_pos_te, tar_te, pos_mask_te):
    '''
    
    ---------------
    Parameters:
    pos (np array): array of positions (x values) - the 1st/2nd output from data_generator_for_gp_mimick_gpt
    tar (np array): array of targets. Notice that if dealing with sequnces, we typically want to have the targets go from 0 to n-1. The 3rd/4th output from data_generator_for_gp_mimick_gpt  
    pos_mask_te (np array): see description in position_mask function
    ---------------
    
    '''
    tar_inp_te = tar_te[:, :-1]
    tar_real_te = tar_te[:, 1:]
    combined_mask_tar_te = masks.create_masks(tar_inp_te)
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
    pred, pred_sig = decoder(token_pos_te, time_pos_te, tar_inp_te, False, pos_mask_te, combined_mask_tar_te)
#     tf.print(tf.math.reduce_max(pred_sig))
    t_loss, t_mse, t_mask = losses.loss_function(tar_real_te, pred, pred_sig)
    tf.print(t_loss)
    test_loss(t_loss)
    m_te.update_state(t_mse, t_mask)
    return tar_real_te, pred, pred_sig

In [13]:
tf.keras.backend.set_floatx('float64')

In [None]:
if __name__ == '__main__':
    writer = tf.summary.create_file_writer(save_dir + '/logs/')
    optimizer_c = tf.keras.optimizers.Adam()
    decoder = fox_model.Decoder(16)
    EPOCHS = 1500
    batch_s  = 15
    run = 0; step = 0
    num_batches = int(tar_tr.shape[0] / batch_s)
    tf.random.set_seed(1)    
    checkpoint = tf.train.Checkpoint(optimizer = optimizer_c, model = decoder)
    main_folder = "/Users/omernivron/Downloads/GPT_fox/ckpt/check_"
    folder = main_folder + str(run); helpers.mkdir(folder)

    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()

            for batch_n in range(num_batches):
                batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr , batch_pos_mask, _ = batch_creator.create_batch_foxes(token_tr, pad_pos_tr, tar_tr, pp)
                # batch_tar_tr shape := 128 X 59 = (batch_size, max_seq_len)
                # batch_pos_tr shape := 128 X 59 = (batch_size, max_seq_len)
                tar_inp, tar_real, pred, pred_sig, combined_mask_tar = train_step(batch_tok_pos_tr, batch_tim_pos_tr, batch_tar_tr, batch_pos_mask)

                if batch_n % 50 == 0:
#                     batch_tok_pos_te, batch_tim_pos_te, batch_tar_te , batch_pos_mask_te, _ = batch_creator.create_batch_foxes(token_te, pad_pos_te, tar_te, pp_te)
#                     tar_real_te, pred, pred_sig = test_step(batch_tok_pos_te, batch_tim_pos_te, batch_tar_te, batch_pos_mask_te)
                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result())
#                     helpers.tf_summaries(run, step, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
#                     checkpoint.save(folder + '/')
                step += 1

            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

Already exists
Epoch 0 batch 0 train Loss 601.5988 test Loss 0.0000 with MSE metric 5679.6123
Time taken for 1 epoch: 149.25540828704834 secs

Epoch 1 batch 0 train Loss 189.5434 test Loss 0.0000 with MSE metric 6058.0791
Time taken for 1 epoch: 149.5448489189148 secs

Epoch 2 batch 0 train Loss 105.0678 test Loss 0.0000 with MSE metric 6041.6445
Time taken for 1 epoch: 142.45593404769897 secs

Epoch 3 batch 0 train Loss 73.0988 test Loss 0.0000 with MSE metric 5971.8579
Time taken for 1 epoch: 157.56733798980713 secs

Epoch 4 batch 0 train Loss 56.6541 test Loss 0.0000 with MSE metric 6039.2402
Time taken for 1 epoch: 153.1128089427948 secs

Epoch 5 batch 0 train Loss 46.6552 test Loss 0.0000 with MSE metric 6091.9624
Time taken for 1 epoch: 140.58529090881348 secs

Epoch 6 batch 0 train Loss 39.9267 test Loss 0.0000 with MSE metric 6013.5962
Time taken for 1 epoch: 141.0576012134552 secs

Epoch 7 batch 0 train Loss 35.0920 test Loss 0.0000 with MSE metric 5994.8862
Time taken for 1 e

In [26]:
tok_pos_tr[6, :]

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 2., 2., 2., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [28]:
e = decoder.e1
weights = e.get_weights()[0]

In [25]:
pred[3, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([23.84272654, 24.17451531, 24.50822531, 24.84380031, 25.18118373,
       25.52031876, 25.86114843, 26.20361569, 26.5476635 , 26.89323492,
       27.24027319, 27.58872182, 27.93852465, 28.2896259 , 28.64197031,
       28.99550315, 29.3501703 , 29.3338399 , 29.33422543, 29.34318017,
       29.36300396, 29.39202324, 29.42885157, 29.4723305 , 29.52148456,
       29.57548654, 29.63363044, 29.69531019, 29.7600027 , 29.82725433,
       29.89666989, 29.96790375, 30.0406525 , 30.11464894, 30.18965701,
       30.26546766, 30.34189531, 30.41877483, 30.49595909, 30.57331675,
       30.65073046, 30.40816779, 29.93889317, 29.46667307, 29.00188682,
       28.57246277, 28.17545528, 27.781603  , 27.37361636, 26.9439041 ,
       26.70652619, 26.48007823, 26.26382102, 26.05708063, 25.85924133,
       25.66973948, 25.48805809, 25.3137221 , 25.14629422, 24.98537123,
       24.83058072, 24.68157821, 24.53804459, 24.3996838 , 24.26622078,
       24.137399

In [88]:
combined_mask_tar[1, :, 44]

<tf.Tensor: shape=(148,), dtype=float32, numpy=
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)>

In [26]:
tar_inp[3, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([53.        , 53.7314    , 54.4575573 , 55.17807577, 55.89257113,
       56.60067168, 57.30201906, 57.99626881, 58.68309099, 59.36217067,
       60.03320836, 60.69592045, 61.35003944, 61.99531426, 62.63151042,
       63.25841015, 63.87581247, 64.48353318, 65.08140481, 65.66927651,
       66.24701391, 66.81449891, 67.37162937, 67.91831891, 68.45449649,
       68.98010607, 69.4951062 , 69.99946958, 70.49318263, 70.97624493,
       71.4486688 , 71.91047871, 72.3617108 , 72.8024123 , 73.232641  ,
       73.65246466, 74.0619605 , 74.46121461, 74.8503214 , 75.22938304,
       75.59850895, 75.95781523, 53.        , 50.        , 48.        ,
       49.        , 50.        , 48.        , 44.        , 42.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.      

In [95]:
tar_real[1, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([170.5377    , 158.71284099, 147.52172089, 136.95626314,
       127.00456081, 117.65142542, 108.87892606, 100.66690691,
        92.99347403,  85.83544469,  79.16875492,  72.96882289,
        67.21086754,  61.87018312,  56.92237156,  52.34353511,
        48.11043257,  44.2006024 ,  40.59245631,  37.26534692,
       183.        ,  23.        ,   5.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
       

In [90]:
pred[8, :]

<tf.Tensor: shape=(148,), dtype=float64, numpy=
array([-0.03583945, -0.03630679, -0.03658146, -0.03666347, -0.03655284,
       -0.0362496 , -0.03575379, -0.03506542, -0.03418455, -0.0331112 ,
       -0.03184543, -0.03038729, -0.02873682, -0.02689408, -0.02485913,
       -0.02263204, -0.02021287, -0.01760169, -0.01479858, -0.01180361,
       -0.01453558, -0.01682202, -0.01887398, -0.02066269, -0.02222955,
       -0.02360825, -0.02482641, -0.02590686, -0.02686857, -0.02772741,
       -0.02849675, -0.02918788, -0.02981041, -0.03037256, -0.03088137,
       -0.03134291, -0.03176244, -0.03214451, -0.0324931 , -0.0328117 ,
       -0.03310334, -0.03337072, -0.0336162 , -0.03384188, -0.03404962,
       -0.03424107, -0.0344177 , -0.03458083, -0.03473165, -0.03487121,
       -0.03500047, -0.03512027, -0.0352314 , -0.03533456, -0.03543038,
       -0.03551943, -0.03560224, -0.03567929, -0.03575101, -0.03581778,
       -0.03587999, -0.03593796, -0.03589755, -0.03584261, -0.03577452,
       -0.035694

In [75]:
batch_pos_mask.shape

(32, 148, 149, 149)

In [41]:
tar_real_te[30, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.])>

In [43]:
pred[1, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([-0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.04683568, -0.04683568, -0.04683568, -0.04683568, -0.04683568,
       -0.046835

In [46]:
pred_sig[2, :]

<tf.Tensor: shape=(189,), dtype=float64, numpy=
array([2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753801, 2.79753801, 2.79753801, 2.79753801, 2.79753801,
       2.79753