In [9]:
from model import dot_prod
import numpy as np

In [4]:
import tensorflow as tf
from tensorflow import keras

In [105]:

####### Requires dot_prod.py from ATP/model/ (imported in the first cell via `from model import dot_prod`) #######

class FFN(tf.keras.layers.Layer):
    """Residual + position-wise feed-forward sub-block applied after attention.

    Computes::

        h   = LN0(x + query)
        out = LN1(Dense_c(Dropout(GELU(Dense_b(h)))) + h)

    Args:
        output_shape: Width of both dense layers. The residual add
            ``Dense_c(...) + h`` requires the last dimension of ``x``/``query``
            to be compatible with this width.
        dropout_rate: Dropout rate applied after the GELU activation.
    """

    def __init__(self, output_shape, dropout_rate=0.1):
        super().__init__()

        self.dense_b = tf.keras.layers.Dense(output_shape)
        self.dense_c = tf.keras.layers.Dense(output_shape)
        # layernorm[0]: after the attention residual; layernorm[1]: after the FFN residual.
        self.layernorm = [tf.keras.layers.LayerNormalization() for _ in range(2)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, query, training=None):
        """Apply the attention residual, the feed-forward net and its residual.

        Args:
            x: Output of the preceding attention layer.
            query: The tensor that was fed to that attention layer as its
                query; it is added back to ``x`` as the first residual.
            training: Forwarded explicitly to the dropout layer. When left as
                ``None``, Keras fills it from the enclosing call's ``training``
                value, so ``model(..., training=False)`` disables dropout.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.layernorm[0](x + query)
        # Tensors are immutable, so `h` can serve directly as the skip
        # connection; no tf.identity copy is needed.
        out = self.dense_b(h)
        out = tf.nn.gelu(out)
        out = self.dropout(out, training=training)
        out = self.dense_c(out)
        return self.layernorm[1](out + h)

class MHA_XY(tf.keras.layers.Layer):
    """One attention block: project-local multi-head attention followed by `FFN`.

    Args:
        num_heads: Number of attention heads.
        projection_shape: Passed to ``dot_prod.MultiHeadAttention`` as its
            third positional argument.
        output_shape: Passed to ``dot_prod.MultiHeadAttention`` as its second
            positional argument, and used as the FFN width.
        dropout_rate: Dropout rate used inside the FFN.
    """

    def __init__(self,
                 num_heads,
                 projection_shape,
                 output_shape,
                 dropout_rate=0.1):
        super().__init__()
        # NOTE(review): arguments go to dot_prod as (num_heads, output_shape,
        # projection_shape) -- the reverse of this class's own
        # (projection_shape, output_shape) order. Confirm against
        # dot_prod.MultiHeadAttention's signature.
        self.mha = dot_prod.MultiHeadAttention(num_heads, output_shape, projection_shape)
        self.ffn = FFN(output_shape, dropout_rate)

    def call(self, query, key, value, mask):
        """Attend from ``query`` to ``key``/``value`` under ``mask``, then apply the FFN.

        ``query`` is also handed to the FFN as its residual input.
        Per the original note, the result has shape
        ``(batch_size, seq_len, output_shape)``.
        """
        attended = self.mha(query, key, value, mask)
        return self.ffn(attended, query)
    

class embed_layers(tf.keras.layers.Layer):
    """Stack of dense layers that embeds input vectors into the model width.

    The first ``num_layers_embed - 1`` layers use ReLU; the final layer is
    linear.

    Args:
        output_shape: Width of every dense layer.
        num_layers_embed: Total number of dense layers; must be >= 1.

    Raises:
        ValueError: If ``num_layers_embed`` < 1. Such values used to build one
            Dense layer but apply none, silently acting as the identity.
    """

    def __init__(self,output_shape,num_layers_embed=4):
        super().__init__()
        if num_layers_embed < 1:
            raise ValueError(
                f"num_layers_embed must be >= 1, got {num_layers_embed}")
        self.num_layers = num_layers_embed
        self.embed = [tf.keras.layers.Dense(output_shape,activation="relu") for _ in range(num_layers_embed-1)]
        self.embed.append(tf.keras.layers.Dense(output_shape))

    def call(self,inputs):
        """Apply every dense layer in order and return the embedding."""
        x = inputs
        for layer in self.embed:
            x = layer(x)
        return x

class TNP_Decoder(tf.keras.models.Model):
    """Transformer decoder that predicts a mean and a log standard deviation.

    The two input sequences are concatenated along axis 1 and embedded by
    ``embed_layers``. The result is passed through ``num_layers`` ``MHA_XY``
    self-attention blocks, then a ReLU dense layer and a linear head of width
    ``2 * target_y_dim``, which is split into mean and scale parts.

    Args:
        output_shape: Model width (embedding / attention output size).
        num_layers: Number of ``MHA_XY`` blocks; 0 skips attention entirely.
        projection_shape: Forwarded to every ``MHA_XY`` block.
        num_heads: Attention heads per block.
        dropout_rate: Dropout rate inside each block's FFN.
        target_y_dim: Dimensionality of the predicted targets.
        bound_std: If True, ``sigma = 0.05 + 0.95 * softplus(raw)``, which
            bounds sigma below by 0.05. Otherwise the raw head output is
            interpreted directly as ``log(sigma)``.
    """

    def __init__(self,output_shape=64,num_layers=6,projection_shape=16,
                 num_heads=4,dropout_rate=0.0,target_y_dim=1,bound_std=False):
        super().__init__()

        self.num_layers = num_layers

        self.mha_xy = [MHA_XY(num_heads,projection_shape,
                              output_shape,dropout_rate) for _ in range(num_layers)]

        self.embed = embed_layers(output_shape,num_layers_embed=4)

        self.dense = tf.keras.layers.Dense(output_shape,activation="relu")
        self.linear = tf.keras.layers.Dense(2*target_y_dim)
        self.target_y_dim = target_y_dim
        self.bound_std = bound_std

    def call(self,inputs,training=True):
        """Run the decoder.

        Args:
            inputs: ``(context_target_pairs, target_masked_pairs, mask)``.
                The first two are concatenated along axis 1. ``mask`` is
                passed unchanged to every attention block.
            training: Keras propagates this value through its call context to
                nested layers, so calling with ``training=False`` disables the
                FFN dropout. Note that the default here is True, whereas the
                Keras convention is None.

        Returns:
            ``(mean, log_sigma)``, each with last dimension ``target_y_dim``.
        """
        context_target_pairs,target_masked_pairs,mask = inputs
        input_for_mha = tf.concat([context_target_pairs,target_masked_pairs],axis=1)

        # Start from the embedding so `x` is defined even when num_layers == 0.
        # Tensors are immutable, so query/key/value can share one tensor.
        x = self.embed(input_for_mha)
        for block in self.mha_xy:
            x = block(x, x, x, mask)

        head = self.linear(self.dense(x))
        mean = head[:, :, :self.target_y_dim]
        raw_scale = head[:, :, self.target_y_dim:]

        if self.bound_std:
            sigma = 0.05 + 0.95 * tf.math.softplus(raw_scale)
            log_sigma = tf.math.log(sigma)
        else:
            # The head already outputs log(sigma). Skipping the exp -> log round
            # trip avoids float32 overflow/underflow (exp -> 0 gives log = -inf).
            log_sigma = raw_scale
        return mean, log_sigma

In [106]:
# Seed TF before the model's first call (when Dense weights are created) so
# weight initialisation and dropout are reproducible on Restart & Run All.
TF_SEED = 0
tf.random.set_seed(TF_SEED)
tnp_model = TNP_Decoder(output_shape=4, num_layers=2, projection_shape=4*3,
                        num_heads=4, dropout_rate=0.1, target_y_dim=1, bound_std=False)

In [107]:
# Seeded toy data. Cast to float32 to match TF's default dtype, so no implicit
# float64 -> float32 conversion is needed when it is combined with TF tensors.
DATA_SEED = 0
rng = np.random.default_rng(DATA_SEED)
x = rng.normal(size=(2,20,1)).astype(np.float32)
y = rng.normal(size=(2,20,1)).astype(np.float32)
n_C = 3  # number of context points
n_T = 7  # number of target points

In [108]:

        n_points = n_C + n_T  # keep only the context + target points of each sequence
        x, y = x[:, :n_points, :], y[:, :n_points, :]

In [109]:

        # Boolean mask over the sequence [context (n_C) | targets (n_T) | masked targets (n_T)].
        # NOTE(review): True presumably means "may attend" -- confirm against dot_prod.
        # Built from index comparisons instead of tf.linalg.band_part: band_part treats a
        # negative bandwidth as unbounded, so num_upper=n_C-1 produced an all-True block
        # (leaking the true targets into the masked rows) when n_C == 0.
        seq_len = n_C + 2*n_T
        target_rows = tf.range(n_T)[:, None]
        seq_cols = tf.range(seq_len)[None, :]
        # Context rows: True only on the context columns.
        context_part = tf.concat([tf.ones((n_C, n_C), tf.bool), tf.zeros((n_C, 2*n_T), tf.bool)], axis=-1)
        # Target row i: True on columns <= n_C + i (context + targets 0..i).
        first_part = seq_cols <= target_rows + n_C
        # Masked-target row i: True on columns < n_C + i (context + targets 0..i-1).
        second_part = seq_cols < target_rows + n_C
        mask = tf.concat([context_part, first_part, second_part], axis=0)
        

In [110]:
        batch_s = tf.shape(x)[0]

        # Observed (x, y) pairs: context points followed by target points.
        context_target_pairs = tf.concat([x,y],axis=2)

        # Target x values paired with zeroed y. Use x's dtype: tf.zeros defaults to
        # float32, which would otherwise be mixed with float64 numpy data in the concat.
        y_masked = tf.zeros((batch_s,n_T,y.shape[-1]), dtype=x.dtype)
        target_masked_pairs = tf.concat([x[:,n_C:],y_masked],axis=2)


In [115]:
# Inference-mode forward pass (dropout disabled); displays (mean, log_sigma).
model_inputs = [context_target_pairs, target_masked_pairs, mask]
tnp_model(model_inputs, training=False)

(<tf.Tensor: shape=(2, 17, 1), dtype=float32, numpy=
 array([[[0.44684225],
         [0.7466157 ],
         [0.80075735],
         [0.90359604],
         [0.48757184],
         [0.4199815 ],
         [0.48619607],
         [0.427507  ],
         [0.41764975],
         [0.86133313],
         [0.44897962],
         [0.44548956],
         [0.42017996],
         [0.47749928],
         [0.4582056 ],
         [0.49849993],
         [0.42031136]],
 
        [[0.4165408 ],
         [0.733211  ],
         [0.52880526],
         [0.9691214 ],
         [0.4716235 ],
         [0.92942685],
         [0.38826334],
         [0.916807  ],
         [0.35382468],
         [0.9291175 ],
         [0.40406883],
         [0.35289133],
         [0.35821268],
         [0.32225105],
         [0.3457569 ],
         [0.3155364 ],
         [0.3388884 ]]], dtype=float32)>,
 <tf.Tensor: shape=(2, 17, 1), dtype=float32, numpy=
 array([[[-0.20082581],
         [-0.33149752],
         [-0.3544551 ],
         [-0.39711