In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import time

import tensorflow.keras.layers as kl
import tensorflow.keras.models as km
import tensorflow.keras.optimizers as ko
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau
import tensorflow.keras as keras
from tensorflow.keras.initializers import HeNormal
from sklearn.metrics import auc, accuracy_score, roc_curve, recall_score
import tensorflow_addons as tfa
import random
import os
import pandas as pd
import math

from tensorflow.keras.callbacks import EarlyStopping


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

 The versions of TensorFlow you are currently using is 2.14.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [3]:
import tensorflow as tf

class MultiViewAggregator(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MultiViewAggregator, self).__init__(**kwargs)

    def build(self, input_shape):
        # Assuming input_shape is (subjects, node_seed, node_target, number_of_views)
        self.num_views = input_shape[-1]
        feat_dim = input_shape[-2]  # node_target dimension is considered as feature dimension

        # This is the learnable weight vector b̅ mentioned in the formula (6)
        self.weight_vector = self.add_weight(shape=(feat_dim,),
                                             initializer='glorot_uniform',
                                             name='weight_vector')

    def call(self, inputs):
        print(f'inputs -> {inputs.shape}')
        # inputs shape is (subjects, node_seed, node_target, number_of_views)
        # Compute the dot product between weight vector and node_target features for each view
        attention_scores = tf.tensordot(inputs, self.weight_vector, axes=[2, 0])
        
        # attention_scores shape is (subjects, node_seed, number_of_views)
        # Apply softmax over the number_of_views dimension to get the attention coefficients
        attention_coefficients = tf.nn.softmax(attention_scores, axis=-1)

        # Expand dims of attention_coefficients for element-wise multiplication
        attention_coefficients_expanded = tf.expand_dims(attention_coefficients, axis=2)

        # Multiply the inputs with the attention coefficients and sum over the views
        weighted_sum = tf.reduce_sum(inputs * attention_coefficients_expanded, axis=-1)

        return weighted_sum

    def compute_output_shape(self, input_shape):
        # The output shape will be the input shape without the number_of_views dimension
        return input_shape[:-1]

# Test the MultiViewAggregator
# Define the input shape parameters
num_subjects = 10
num_node_seed = 5
num_node_target = 5
num_views = 4

# Generate random data to simulate the inputs
inputs = tf.random.normal((num_subjects, num_node_seed, num_node_target, num_views))

# Create the MultiViewAggregator layer
aggregator = MultiViewAggregator()

# Call the aggregator on the test inputs
aggregated_representation = aggregator(inputs)

# Print the shapes to verify the output
print(f"Shape of inputs: {inputs.shape}")
print(f"Shape of aggregated representation: {aggregated_representation.shape}")
# Output the actual aggregated representation for inspection
print(aggregated_representation)

inputs -> (10, 5, 5, 4)
Shape of inputs: (10, 5, 5, 4)
Shape of aggregated representation: (10, 5, 5)
tf.Tensor(
[[[ 0.14682235  0.29236728  0.19625227  1.1891878   0.2711566 ]
  [-1.3149323   0.12830108 -0.23348776 -0.21816854  0.02447005]
  [-0.6903639   0.11053761  0.2368384   0.1969232  -0.00595765]
  [ 0.5013736   0.3270502   0.6146605   0.11175621  0.09466723]
  [ 0.06489319 -0.08328563  0.40416318  0.37804925  0.7366225 ]]

 [[-0.37081495 -0.21770597 -1.4158242   0.12318621  0.32984906]
  [-0.616259    0.6813698   1.0692375   0.0424664   0.05229074]
  [ 0.07170535  0.27779216 -0.27730748 -0.6043559  -0.06504062]
  [-1.0289661   0.8543185   0.36709952 -0.6010849  -0.37804586]
  [-0.40803087  0.09522653  0.20614402 -0.5239665   0.6885087 ]]

 [[-1.244364   -0.4229477  -0.610304   -0.12081103  1.1612086 ]
  [ 0.24314263  1.4712534  -0.47533157  0.465911   -0.06506082]
  [ 0.11330339 -1.2681531  -0.26203465  0.30726242  0.03871648]
  [ 0.23435465  0.1849189  -0.2935584  -0.23248768 

In [19]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras import initializers


        
class GraphSAGE(tf.keras.Model):
    def __init__(self,
                 feature_dim, embed_dim,
                 activation='relu'):
        super(GraphSAGE, self).__init__()

        if activation == 'relu':
            self.activation = tf.keras.layers.ReLU()
        elif activation == 'sigmoid':
            self.activation = tf.keras.layers.Activation('sigmoid')
        elif activation == 'tanh':
            self.activation = tf.keras.layers.Activation('tanh')
        elif activation == 'prelu':
            self.activation = tf.keras.layers.Activation('prelu')
        else:
            raise ValueError('Provide a valid activation for GNN')
        
        
        self.embed_dim = embed_dim
        self.feat_dim = feature_dim
        self.weight = self.add_weight(shape=(embed_dim, self.feat_dim),
                                       initializer=initializers.GlorotUniform(),
                                       trainable=True)
    
    """
    This normalization exists some problem to generate NaN value. Please amend. 23 Nov 2023 by JY
    """
    def normalize_adjacency(self, adj):
        d = tf.reduce_sum(adj, axis=-1)
        d_sqrt_inv = tf.pow(d, -0.5)
        d_sqrt_inv = tf.where(tf.math.is_inf(d_sqrt_inv), 0., d_sqrt_inv)
        d_mat_inv_sqrt = tf.linalg.diag(d_sqrt_inv)
        return tf.matmul(tf.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)


    def call(self, inputs, adj):
        # adj_normalized = self.normalize_adjacency(adj)
        # print(adj_normalized)
        combined = tf.linalg.matmul(adj, inputs)
        outputs = tf.tensordot(combined, self.weight, axes=[[2], [1]])
        """
        combined = tf.random.normal((40,52,125))
        weight = tf.random.normal((64,125))
        combined_transformed = tf.tensordot(combined, weight, axes=[[2], [1]])
        print(combined_transformed.shape) # (40, 52, 64)
        """
        return self.activation(outputs)
    
    
    
# Test the MultiViewAggregator
# Define the input shape parameters
num_subjects = 10
num_node_seed = 5
num_node_target = 5
num_views = 4




feature_dim = 125
channel = 52
embed_dim = 64 

input = tf.random.normal((num_subjects, channel, feature_dim))
# Generate random data to simulate the inputs
adj = tf.random.normal((num_subjects, channel, channel))

# Create the MultiViewAggregator layer
aggregator = GraphSAGE(feature_dim=feature_dim, embed_dim=embed_dim)

# Call the aggregator on the test inputs
aggregated_representation = aggregator(input, adj)

# Print the shapes to verify the output
print(f"Shape of inputs: {input.shape}")
print(f"Shape of aggregated representation: {aggregated_representation.shape}")
# Output the actual aggregated representation for inspection
print(aggregated_representation)


Shape of inputs: (10, 52, 125)
Shape of aggregated representation: (10, 52, 64)
tf.Tensor(
[[[ 0.          0.          6.3000364  ...  0.          0.
    0.        ]
  [ 6.777787    0.          0.         ...  8.4236      0.
    0.        ]
  [17.504349    3.576948    0.         ...  0.         12.221742
    7.5417676 ]
  ...
  [ 0.          0.          0.         ... 24.122377    0.
    0.        ]
  [11.16812     0.          1.3118283  ...  0.          2.1922848
    0.        ]
  [ 0.         12.110241    0.         ...  0.          4.3504896
    0.        ]]

 [[ 0.55670196  4.906233    0.         ... 10.509783    7.5247827
    0.        ]
  [ 0.          0.          2.6273222  ...  0.          0.
    6.6759257 ]
  [ 0.          0.          6.5667315  ...  2.2529225  13.30647
    0.03883433]
  ...
  [15.294917    1.71078     7.650405   ...  0.          0.
   11.589546  ]
  [ 0.          5.3014593   0.         ...  0.18327388  4.9906635
    7.1529317 ]
  [ 4.3943405   1.6992395   1.1

In [47]:
        
class GraphSAGE(tf.keras.Model):
    def __init__(self,
                 feature_dim, embed_dim,
                 activation='relu'):
        super(GraphSAGE, self).__init__()

        if activation == 'relu':
            self.activation = tf.keras.layers.ReLU()
        elif activation == 'sigmoid':
            self.activation = tf.keras.layers.Activation('sigmoid')
        elif activation == 'tanh':
            self.activation = tf.keras.layers.Activation('tanh')
        elif activation == 'prelu':
            self.activation = tf.keras.layers.Activation('prelu')
        else:
            raise ValueError('Provide a valid activation for GNN')
        
        
        self.embed_dim = embed_dim
        self.feat_dim = feature_dim
        self.weight = self.add_weight(shape=(embed_dim, self.feat_dim),
                                       initializer=initializers.GlorotUniform(),
                                       trainable=True)
    
    """
    This normalization exists some problem to generate NaN value. Please amend. 23 Nov 2023 by JY
    """
    def normalize_adjacency(self, adj):
        d = tf.reduce_sum(adj, axis=-1)
        d_sqrt_inv = tf.pow(d, -0.5)
        d_sqrt_inv = tf.where(tf.math.is_inf(d_sqrt_inv), 0., d_sqrt_inv)
        d_mat_inv_sqrt = tf.linalg.diag(d_sqrt_inv)
        return tf.matmul(tf.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)


    def call(self, inputs, adj):
        # adj_normalized = self.normalize_adjacency(adj)
        # print(adj_normalized)
        combined = tf.linalg.matmul(adj, inputs)
        outputs = tf.tensordot(combined, self.weight, axes=[[2], [1]])
        """
        combined = tf.random.normal((40,52,125))
        weight = tf.random.normal((64,125))
        combined_transformed = tf.tensordot(combined, weight, axes=[[2], [1]])
        print(combined_transformed.shape) # (40, 52, 64)
        """
        return self.activation(outputs)
    
    
    

class Classifier_GraphSAGE_Transformer():
    def __init__(self, output_directory, callbacks, input_shape, epochs, sweep_config, info):
        # input_shape = (200, 52, 128, 1)

        self.epochs = epochs

        # 随机给定超参数进行训练
        self.info = info
        self.batch_size = sweep_config['batch_size'] if sweep_config else 128
        d_model = 64  # 125# # random.choice([64, 128, 256])
        kernel_size_1 = (4, 5)  # 2, 3, 4
        stride_size_1 = (1, 2)
        kernel_size_2 = (1, 5)  # 2: random.randint(2,8)  (2,5 are the best)
        stride_size_2 = (1, 2)
        kernel_size = [kernel_size_1, kernel_size_2]
        stride_size = [stride_size_1, stride_size_2]
        adam_beta_1, adam_beta_2 = 0.9, 0.999
        num_class = 2  # 2
        learning_rate = 0.01
        num_of_view = 5 
        depth = 3
        feature_dim = 250
        channel = 52
        num_of_last_dense=3
        FFN_units = 256
        optimizer = tf.keras.optimizers.AdamW(learning_rate,
                                              beta_1=adam_beta_1,
                                              beta_2=adam_beta_2,
                                              epsilon=1e-9)

        # If you change these two hyperparameters, remember to change the  self.hyperparameters
        inputs = tf.keras.Input(shape=input_shape[1:])
        input_adj = tf.keras.Input(shape=(input_shape[1], input_shape[1]))
        
        conc_outputs = []
        outputs = GraphSAGE(feature_dim=feature_dim, embed_dim=d_model)(inputs, input_adj)
        conc_outputs.append(tf.expand_dims(outputs, axis=0))

        if depth > 1: 
            for k in range(depth-1):
                outputs = GraphSAGE(feature_dim=d_model, embed_dim=d_model)(outputs, input_adj)
                # outputs = ClsPositionEncodingLayer(
                #     input_channel=input_shape[1], kenerl_size=kernel_size[0], strides=stride_size[0], d_model=d_model, dropout_rate=0.5, name=f'CLS_pos_encoding_{depth}')(outputs)
                outputs = Transformer(0.01,
                                4,
                                FFN_units,
                                4,
                                'relu')(outputs)
                conc_outputs.append(tf.expand_dims(outputs, axis=0))
        outputs = tf.concat(conc_outputs, axis=0)
        print(outputs.shape)
        outputs = tf.reshape(outputs, (depth, -1, d_model * channel))
        outputs = layers.LSTM(units=d_model, return_sequences=True, return_state=False)(outputs)
        outputs = tf.math.reduce_mean(outputs, axis=0)
        outputs = layers.LayerNormalization(epsilon=1e-6)(outputs)

        "Doing this in here is to get the layer[-2] feature"
        for i in range(num_of_last_dense):
            outputs = layers.Dense(FFN_units/(2**i),
                                   activation='relu',
                                   kernel_regularizer=tf.keras.regularizers.l2(0.0001))(outputs)
        outputs = layers.Dense(num_class, activation='softmax')(outputs)
                
        # if depth > 1: 
        #     for k in range(depth-1):
        #         outputs = GraphSAGE(feature_dim=d_model, embed_dim=d_model)(outputs, input_adj)
        #         conc_outputs.append(tf.expand_dims(outputs, axis=0))
        # outputs = tf.concat(conc_outputs, axis=0)
        # print(outputs.shape)
        # outputs = tf.reshape(outputs, (depth, -1, embed_dim * channel))
        # outputs = layers.LSTM(units=d_model, return_sequences=True, return_state=False)(outputs)
        # outputs = tf.math.reduce_mean(outputs, axis=0)
        # print(outputs.shape)
        # # outputs = layers.Dense(units=d_model)(outputs)
        # # outputs = layers.GlobalAveragePooling1D(
        # #     data_format='channels_first', keepdims=False)(outputs)
        # outputs = layers.Dense(num_class, activation='softmax')(outputs)
        model = tf.keras.Model(inputs=[inputs, input_adj], outputs=outputs)
        model.summary()
        model.compile(optimizer=optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        self.model = model

    def fit(self, X_train, Y_train, X_val, Y_val, X_test, Y_test, adj_train, adj_val, adj_test):
        hist = self.model.fit(
            x=[X_train, adj_train],
            y=Y_train,
            validation_data=([X_val, adj_val], Y_val),
            batch_size=self.batch_size,
            epochs=self.epochs,
            verbose=True,
            shuffle=True  # Set shuffle to True
        )



# hbo_data = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/HbO-All-HC-MDD/correct_channel_data.npy')
hbo_data = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/HbO-All-HC-MDD/correct_channel_data.npy')

hbo_data = np.transpose(hbo_data, (0,2,1))
# hbr_data = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/HbR-All-HC-MDD/correct_channel_data.npy')
hbr_data = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/HbR-All-HC-MDD/correct_channel_data.npy')

hbr_data = np.transpose(hbr_data, (0,2,1))
labels = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/HbO-All-HC-MDD/label.npy')

def normalize(data):
    # Iterate over each subject
    normalized_data = np.empty_like(data)
    # Calculate the mean and standard deviation for the current subject
    mean = np.mean(data)
    std = np.std(data)

    # Perform z-normalization for the current subject
    normalized_data= (data - mean) / std
    return normalized_data

epochs = 100

# Concatenate the HbO: (, , 125); HbR: (, , 125) -> Con_HbO_HbR: (, , 250)
hb_input = normalize(np.concatenate((hbo_data,hbr_data),axis=2))
connectivity = np.load('/Users/shanxiafeng/Documents/Code/python/fnirs_DL/JinyuanWang_pythonCode/allData/Output_npy/twoDoctor/Hb-All-HC-MDD/adj_matrix.npy') 
adj = connectivity[...,0]

classifer = Classifier_GraphSAGE_Transformer(None, None, hb_input.shape, epochs=epochs, sweep_config=None, info=None)

def onehotEncode(x):
    t = np.zeros((x.size, x.max()+1))
    t[np.arange(x.size), x] = 1
    return t.astype(int)
onehot_labels = onehotEncode(labels)

X_train, Y_train = hb_input[:350], onehot_labels[:350]
X_val, Y_val = hb_input[350:400], onehot_labels[350:400]
X_test, Y_test = hb_input[400:], onehot_labels[400:]
adj_train, adj_val, adj_test = adj[:350], adj[350:400], adj[400:]

classifer.fit(X_train, Y_train, X_val, Y_val, X_test, Y_test, adj_train, adj_val, adj_test)



(3, None, 52, 64)
Model: "model_9"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_49 (InputLayer)       [(None, 52, 250)]            0         []                            
                                                                                                  
 input_50 (InputLayer)       [(None, 52, 52)]             0         []                            
                                                                                                  
 graph_sage_80 (GraphSAGE)   (None, 52, 64)               16000     ['input_49[0][0]',            
                                                                     'input_50[0][0]']            
                                                                                                  
 graph_sage_81 (GraphSAGE)   (None, 52, 64)               4096      ['grap

In [48]:
X_train.shape

(350, 52, 250)

In [44]:

def scaled_fot_product_attention(queries, keys, values):

    product = tf.matmul(queries, keys, transpose_b=True)
    key_dim = tf.cast(tf.shape(keys)[-1], tf.float32)
    scaled_product = product / tf.math.sqrt(key_dim)

    attention = tf.matmul(tf.nn.softmax(scaled_product, axis=-1), values)
    return attention


class MultiHeadAttention(layers.Layer):

    def __init__(self, n_heads, name='multi_head_attention'):
        super(MultiHeadAttention, self).__init__(name=name)
        self.n_heads = n_heads

    def build(self, input_shape):
        self.d_model = input_shape[-1]
        assert self.d_model % self.n_heads == 0
        self.d_head = self.d_model // self.n_heads

        self.query_lin = layers.Dense(units=self.d_model)
        self.key_lin = layers.Dense(units=self.d_model)
        self.value_lin = layers.Dense(units=self.d_model)

        self.final_lin = layers.Dense(units=self.d_model)

    def split_proj(self, inputs, batch_size):  # inputs: (batch_size, seq_length, d_model)
        shape = (batch_size,
                 -1,
                 self.n_heads,
                 self.d_head)

        # outputs: (batch_size, seq_length, nb_proj, d_proj)
        splited_inputs = tf.reshape(inputs, shape=shape)
        # outputs: (batch_size, nb_proj, seq_length,  d_proj)
        return tf.transpose(splited_inputs, perm=[0, 2, 1, 3])

    def call(self, queries, keys, values):

        batch_size = tf.shape(queries)[0]
        queries = self.query_lin(queries)
        keys = self.key_lin(keys)
        values = self.value_lin(values)

        queries = self.split_proj(queries, batch_size)
        keys = self.split_proj(keys, batch_size)
        values = self.split_proj(values, batch_size)

        attention = scaled_fot_product_attention(queries, keys, values)

        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention,
                                      shape=(batch_size, -1, self.d_model))
        outputs = self.final_lin(concat_attention)
        return outputs


class PositionalEncoding(layers.Layer):

    def __init__(self):
        super(PositionalEncoding, self).__init__()

    def get_angles(self, pos, i, d_model):  # pos: (seq_length, 1) i: (1, d_model)
        # 2*(i//2) => if i = 5 -> ans = 4
        angles = 1 / np.power(10000., (2*(i//2)) / np.float32(d_model))
        return pos * angles  # (seq_length, d_model)

    def call(self, inputs):
        # input shape batch_size, seq_length, d_model
        seq_length = inputs.shape.as_list()[-2]
        d_model = inputs.shape.as_list()[-1]
        # Calculate the angles given the input
        angles = self.get_angles(np.arange(seq_length)[:, np.newaxis],
                                 np.arange(d_model)[np.newaxis, :],
                                 d_model)
        # Calculate the positional encodings
        angles[:, 0::2] = np.sin(angles[:, 0::2])
        angles[:, 1::2] = np.cos(angles[:, 1::2])
        # Expand the encodings with a new dimension
        pos_encoding = angles[np.newaxis, ...]

        return inputs + tf.cast(pos_encoding, tf.float32)


class EncoderLayer(layers.Layer):

    def __init__(self, FFN_units, n_heads, dropout_rate, activation, name='encoder_layer'):
        super(EncoderLayer, self).__init__(name=name)

        self.FFN_units = FFN_units
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.activation = activation

    def build(self, input_shape):
        self.d_model = input_shape[-1]
        self.multi_head_attention = MultiHeadAttention(self.n_heads)
        self.dropout_1 = layers.Dropout(rate=self.dropout_rate)
        self.norm_1 = layers.LayerNormalization(epsilon=1e-6)

        self.ffn1_relu_gelu = layers.Dense(
            units=self.FFN_units, activation=self.activation)
        self.ffn2 = layers.Dense(units=self.d_model)
        self.dropout_2 = layers.Dropout(rate=self.dropout_rate)
        self.norm_2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        attention = self.multi_head_attention(inputs, inputs, inputs)
        attention = self.dropout_1(attention)
        attention = self.norm_1(attention+inputs)

        outputs = self.ffn1_relu_gelu(attention)
        outputs = self.ffn2(outputs)
        outputs = self.dropout_2(outputs)

        outputs = self.norm_2(outputs + attention)
        return outputs


class EmbeddingLayer(layers.Layer):
    def __init__(self, d_model, filters, kernel_size, strides, l2_rate, name="EmbeddingLayer"):
        super(EmbeddingLayer, self).__init__(name=name)
        self.filters = filters
        self.kernel_size = kernel_size
        self.stride_size = strides
        self.d_model = d_model
        self.l2_rate = l2_rate
    # Why we must have a input_shape but we can not use it, or it will have issues.

    def build(self, input_shape):  # input_shape : batch, channel_dimension, sample_points, HbO/HbR(1,2)
        self.cnn_1 = layers.Conv2D(filters=self.filters,
                                   kernel_size=self.kernel_size,
                                   strides=self.stride_size)

        # self.size_1 = (input_shape[1] - self.kernel_size[0]) // self.stride_size[0] + 1
        # print(f'here size_1 = {self.size_1}')
        self.out_dimension = (
            input_shape[2] - self.kernel_size[1]) // self.stride_size[1] + 1  # {(𝑛 + 2𝑝 − 𝑓 + 1) / 𝑠} + 1 |n=len, p=padding, f=kernel, s=stride ;

        # check_shape = (None, 52, x * y) # using the einsum can be more elegant
        # equal to layers.Reshape((-1, self.out_dimension * self.filters)) , batch_size is ignored
        self.flatten = layers.Reshape((-1, self.out_dimension * self.filters))
        self.lin = layers.Dense(
            self.d_model, kernel_regularizer=tf.keras.regularizers.l2(self.l2_rate))
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        outputs = self.cnn_1(inputs)

        # # check here updated at 14 July 2023 by adding the transpose operation if you can not have good result from this time,
        # outputs = tf.transpose(outputs, perm=[0, 2, 1, 3]) this step change the dimension of channel and sample point, which is not a good choice because you will get (None, 128, channel * output_channel_of_CNN), you lose the comparison of different channel.
        outputs = self.flatten(outputs)
        outputs = self.lin(outputs)
        outputs = self.norm(outputs)
        return outputs


class Encoder(layers.Layer):
    def __init__(self,
                 n_layers,
                 FFN_units,
                 n_heads,
                 dropout_rate,
                 activation,
                 name="encoder"):
        super(Encoder, self).__init__(name=name)
        self.n_layers = n_layers
        self.enc_layers = [EncoderLayer(
            FFN_units, n_heads, dropout_rate, activation) for _ in range(n_layers)]

    def call(self, inputs):
        outputs = inputs
        for i in range(self.n_layers):
            outputs = self.enc_layers[i](outputs)
        return outputs


class ClsPositionEncodingLayer(layers.Layer):
    def __init__(self, input_channel, kenerl_size, strides, d_model, dropout_rate, name="ClsPositionEncodingLayer"):
        super(ClsPositionEncodingLayer, self).__init__(name=name)

        patch = (input_channel - kenerl_size[0]) // strides[0] + 1
        self.cls_token_patch = tf.Variable(tf.random.normal((1, 1, d_model)))
        self.pos_embedding = PositionalEncoding()
        self.dropout_patch = layers.Dropout(dropout_rate)

    def call(self, inputs):
        cls_token_patch_tiled = tf.tile(
            self.cls_token_patch, [tf.shape(inputs)[0], 1, 1])

        outputs = tf.concat([cls_token_patch_tiled, inputs], axis=1)
        outputs = self.pos_embedding(outputs)
        outputs = self.dropout_patch(outputs)
        return outputs
    

        
class GraphSAGE(tf.keras.Model):
    def __init__(self,
                 feature_dim, embed_dim,
                 activation='relu'):
        super(GraphSAGE, self).__init__()

        if activation == 'relu':
            self.activation = tf.keras.layers.ReLU()
        elif activation == 'sigmoid':
            self.activation = tf.keras.layers.Activation('sigmoid')
        elif activation == 'tanh':
            self.activation = tf.keras.layers.Activation('tanh')
        elif activation == 'prelu':
            self.activation = tf.keras.layers.Activation('prelu')
        else:
            raise ValueError('Provide a valid activation for GNN')
        
        
        self.embed_dim = embed_dim
        self.feat_dim = feature_dim
        self.weight = self.add_weight(shape=(embed_dim, self.feat_dim),
                                       initializer=initializers.GlorotUniform(),
                                       trainable=True)
    
    """
    This normalization exists some problem to generate NaN value. Please amend. 23 Nov 2023 by JY
    """
    def normalize_adjacency(self, adj):
        d = tf.reduce_sum(adj, axis=-1)
        d_sqrt_inv = tf.pow(d, -0.5)
        d_sqrt_inv = tf.where(tf.math.is_inf(d_sqrt_inv), 0., d_sqrt_inv)
        d_mat_inv_sqrt = tf.linalg.diag(d_sqrt_inv)
        return tf.matmul(tf.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)


    def call(self, inputs, adj):
        # adj_normalized = self.normalize_adjacency(adj)
        # print(adj_normalized)
        combined = tf.linalg.matmul(adj, inputs)
        outputs = tf.tensordot(combined, self.weight, axes=[[2], [1]])
        """
        combined = tf.random.normal((40,52,125))
        weight = tf.random.normal((64,125))
        combined_transformed = tf.tensordot(combined, weight, axes=[[2], [1]])
        print(combined_transformed.shape) # (40, 52, 64)
        """
        return self.activation(outputs)
    
    

class Transformer(tf.keras.Model):
    # input_shape = (None, channel_size, sample_point, datapoint)
    def __init__(self,
                 dropout_rate,
                 n_layers,
                 FFN_units,
                 n_heads,
                 activation):
        super(Transformer, self).__init__()

        self.encoder = Encoder(n_layers,
                               FFN_units,
                               n_heads,
                               dropout_rate,
                               activation,
                               name="encoder_1")

        self.global_average_pooling = layers.GlobalAveragePooling1D(
            data_format='channels_first', keepdims=False)

        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):

        output_1 = self.encoder(inputs)  # self.encoder_1(output_1)

        # output_1 = self.global_average_pooling(output_1)

        return output_1