# PatchTST with data pipelining

PatchTST with functions that streamline data loading and prefetching, through the `tf.data` API.

As a preprocessing step, we subtract each signal's median from the signal itself; the network is then used in conjunction with a separate median predictor.

We use a TensorBoard callback to monitor performance.

This configuration takes roughly 90 s per epoch on an NVIDIA RTX 3080.

In [None]:
# %%
# Hyper-parameters shared by the data pipeline and the model definition.
input_len = 200  # length of each input window (reused as `window` below)
telescope = 18  # number of future points used as the forecast target
P = 16  # patch size
S = 8  # patch stride
# Number of patches per window (23 with the values above).
N = int((input_len-P)/S)
D = 128  # embedding size of each projected patch
input_shape = (P,N)  # per-sample model input: (patch size, number of patches)
attention_modules = 6  # number of encoder modules stacked in the model
attention_heads = 16
key_dim = 8  # `key_dim` passed to every MultiHeadAttention layer
# NOTE(review): the next two values are forwarded to the model as
# `ff_layers_count` / `ff_layers_units`, but `encoder` never reads them.
dense_layers_count = 2
dense_units = D*N
stride = 10  # window stride passed to build_sequences_filtered
dataset_stride = 20  # window stride passed to build_sequences
window = input_len
# Keyword arguments unpacked into patchTST(**params_for_patchTST).
params_for_patchTST = {
    'input_shape': input_shape,
    'output_shape': telescope,
    'patch_size': P,
    'patch_stride': S,
    'num_patches': N,
    'num_modules': attention_modules,
    'num_heads':attention_heads,
    'key_dim': key_dim,
    'embedding_size':D,
    'ff_layers_count':dense_layers_count,
    'ff_layers_units': dense_units
}

In [None]:
import os
os.environ["TF_GPU_THREAD_MODE"]="gpu_private"

### Import libraries

In [None]:
# Fix randomness and hide warnings
seed = 42

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ['MPLCONFIGDIR'] = os.getcwd()+'/configs/'

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

import numpy as np
np.random.seed(seed)

import logging
import gc
import random
random.seed(seed)

In [None]:
# Import tensorflow
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.ERROR)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)
print(tf.__version__)

In [None]:
import pandas as pd
import seaborn as sns
from datetime import datetime
import matplotlib.pyplot as plt
plt.rc('font', size=16)
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl
import matplotlib.pyplot as pyplot
import random

from tensorflow.keras import mixed_precision

In [None]:
# Make 'mixed_float16' the global Keras dtype policy for every layer created
# after this point.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
#tf.keras.backend.set_floatx('float16')
# NOTE(review): the value returned by this call is discarded (it is neither
# assigned nor used as a decorator), so it presumably does not enable XLA for
# anything in this notebook -- confirm the intent.
tf.function(jit_compile=True) 
# Displays the backend default float type as the cell output.
tf.keras.backend.floatx()

This function takes two arrays `X` and `y` (i.e. the input windows and their telescopes), pairs them element-wise so that each dataset element is an (input, telescope) tuple, batches the elements (so that retrieving one dataset element yields a batch of `batch_size` input/telescope pairs), caches the result (this step is mainly a leftover from a previous version that had a `map` call) and prefetches batches automatically.

In [None]:
def make_dataset(X, y, batch_size=128, prefetch_amt=tf.data.experimental.AUTOTUNE):
    """Wrap two aligned arrays in a batched, cached, prefetching dataset.

    Args:
        X: Inputs; sliced along the first axis.
        y: Targets; sliced along the first axis together with ``X``.
        batch_size: Number of (input, target) pairs per dataset element.
        prefetch_amt: Buffer size handed to ``Dataset.prefetch``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(X_batch, y_batch)`` tuples. The last
        batch is kept even when it is smaller than ``batch_size``.
    """
    pairs = tf.data.Dataset.from_tensor_slices((X, y))
    batched = pairs.batch(
        batch_size,
        drop_remainder=False,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # Cache the already-batched elements, then overlap loading with training.
    return batched.cache().prefetch(prefetch_amt)

## Train/Val data generation

In [None]:
# Load the raw arrays from disk; the signals are cast to float32.
data = np.load("data/training_data.npy")
data = np.float32(data)
categories = np.load("data/categories.npy")
# Indexed below as valid_periods[i][0] / valid_periods[i][1]; presumably the
# start and end of the meaningful part of signal i -- TODO confirm.
valid_periods = np.load("data/valid_periods.npy")
#valid_periods = np.int16(valid_periods)
data.shape, categories.shape, valid_periods.shape

data.dtype

In [None]:
def build_sequences(target_data, valid_periods, window=200, stride=50, telescope=18):
    """Cut every signal into (input window, forecast target) pairs.

    Each signal is first trimmed to start at ``valid_periods[i][0]`` and then
    left-padded with zeros up to the next multiple of ``window`` (no padding
    is added when the length already is a multiple). Windows of ``window``
    samples are taken every ``stride`` samples; the ``telescope`` samples
    that follow each window are its target.

    Args:
        target_data: Iterable of 1-D signals.
        valid_periods: Per-signal sequence whose first item is the index at
            which the signal starts.
        window: Number of samples in each input window.
        stride: Step between the starts of consecutive windows; must divide
            ``window``.
        telescope: Number of samples in each target.

    Returns:
        Tuple ``(dataset, labels)`` of arrays holding the input windows and
        the matching targets, in signal order.

    Raises:
        AssertionError: If ``stride`` does not divide ``window``.
    """
    # Sanity check to avoid runtime errors
    assert window % stride == 0, "Window is " + str(window) + ", stride is " + str(stride)
    dataset = []
    labels = []
    for i, signal in enumerate(target_data):
        # Remove everything before the start of the valid period
        temp_sig = signal[valid_periods[i][0]:]
        # Zeros needed to reach the next multiple of `window`; 0 when the
        # length is already aligned, so no all-zero window is ever prepended.
        padding_len = (-len(temp_sig)) % window
        if padding_len:
            padding = np.zeros((padding_len), dtype='float32')
            temp_sig = np.concatenate((padding, temp_sig))
        assert len(temp_sig) % window == 0, "Length is " + str(len(temp_sig)) + ", windows length is " + str(window)

        # Largest start index that still leaves room for a whole window plus
        # its target; the `+ 1` keeps that final position in the range.
        last_start = len(temp_sig) - window - telescope
        for j in np.arange(0, last_start + 1, stride):
            dataset.append(temp_sig[j:j+window])
            labels.append(temp_sig[j+window:j+window+telescope])
    return np.array(dataset), np.array(labels)

In [None]:
def build_sequences_filtered(target_data, valid_periods, window=200, stride=50, telescope=18):
    """Build (window, target) pairs plus their medians, skipping short signals.

    A signal is used only when its valid period
    (``valid_periods[i][1] - valid_periods[i][0]``) spans at least
    ``2 * telescope`` samples and the signal itself is long enough to hold
    one window plus one target. Windows start at the beginning of the valid
    period, or earlier when that is needed to fit a whole window and target
    before the end of the signal, and advance by ``stride`` up to and
    including the last position that fits.

    Args:
        target_data: Iterable of 1-D signals.
        valid_periods: Per-signal ``(start, end)`` indices of the valid part.
        window: Number of samples in each input window.
        stride: Step between consecutive window starts; must divide ``window``.
        telescope: Number of samples in each target.

    Returns:
        Tuple ``(dataset, labels, medians, median_labels_telescope)``:
        the input windows, their targets, the median of every window and the
        median of every target. All four are empty arrays when no signal
        qualifies.

    Raises:
        AssertionError: If ``stride`` does not divide ``window``.
    """
    assert window % stride == 0
    dataset = []
    labels = []
    for i, signal in enumerate(target_data):
        valid_start, valid_end = valid_periods[i][0], valid_periods[i][1]
        if valid_end - valid_start < telescope*2:
            continue
        # Last start index that still leaves room for a window and a target.
        last_start = len(signal) - window - telescope
        if last_start < 0:
            continue
        # Clamp the first start so short signals still yield one window; the
        # `+ 1` makes the end inclusive, otherwise the clamp yields nothing.
        first_start = min(valid_start, last_start)
        for j in np.arange(first_start, last_start + 1, stride):
            dataset.append(signal[j:j+window])
            labels.append(signal[j+window:j+window+telescope])
    if not dataset:
        # np.median(..., axis=1) cannot be applied to an empty 1-D array.
        return (np.empty((0, window), dtype='float32'),
                np.empty((0, telescope), dtype='float32'),
                np.empty((0,), dtype='float32'),
                np.empty((0,), dtype='float32'))
    dataset = np.array(dataset)
    labels = np.array(labels)
    medians = np.median(dataset, axis=1)
    median_labels_telescope = np.median(labels, axis=1)
    return dataset, labels, medians, median_labels_telescope

In [None]:
def build_patches(target_data, patch_size=16, stride=8):
    """Slice every signal into overlapping patches.

    Patch starts run from 0 in steps of ``stride`` and stop strictly before
    ``len(signal) - patch_size``, so the patch that would end exactly on the
    last sample is not produced.

    Args:
        target_data: Iterable of 1-D arrays (each must provide ``.copy()``).
        patch_size: Number of samples per patch.
        stride: Step between consecutive patch starts; must divide
            ``patch_size``.

    Returns:
        Array with one entry per signal; each entry is the transposed patch
        matrix, i.e. ``(patch_size, number_of_patches)``. The overall shape
        is also printed.
    """
    # Sanity check to avoid runtime errors
    assert patch_size % stride == 0
    patched_signals = []
    for series in target_data:
        working = series.copy()
        starts = np.arange(0, len(working) - patch_size, stride)
        segments = [working[begin:begin + patch_size] for begin in starts]
        # Rows are patches here; transpose so that columns are patches.
        patched_signals.append(np.array(segments).T)
    result = np.array(patched_signals)
    print(result.shape)
    return result

In [None]:
# Build (window, telescope) pairs from the raw signals.
# NOTE(review): X_train/X_val/y_train/y_val produced here are reassigned by
# the median-subtraction cell further down, so this cell looks redundant --
# confirm before removing.
X, y = build_sequences(data,valid_periods,window=input_len,stride=dataset_stride,telescope=telescope)

X.shape, y.shape

# 80/20 train/validation split with a fixed random state.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# Drop the unsplit arrays once the split copies exist.
del X,y

X_train.shape, X_val.shape, y_train.shape, y_val.shape

Median subtraction step

In [None]:
# Rebuild the windows with the filtered builder, which also returns the
# median of every input window and of every target.
X, y, median_X, median_y = build_sequences_filtered(data, valid_periods, window, stride, telescope)
# Centre each input window and each target on its own median.
norm_X, norm_y = (X-median_X[:,None]), (y-median_y[:,None])#/iqr_y[:,None]
print(X.shape, y.shape)

# Split every centred window into patches; build_patches transposes each
# entry so that its first axis is the position inside the patch.
norm_X = build_patches(norm_X,patch_size=P,stride=S)

# A single call, so all five arrays are split into train/validation together.
X_train, X_val, y_train, y_val, norm_X_train, norm_X_val, norm_y_train, norm_y_val, median_y_train, median_y_val = train_test_split(X, y, norm_X, norm_y, median_y, test_size=0.2, random_state=seed)

gc.collect()
X_train.shape, y_train.shape, X_val.shape, y_val.shape

X_train.dtype, y_train.dtype, X_val.dtype, y_val.dtype

## PatchTST implementation

In [None]:
@tfk.saving.register_keras_serializable()
class ProjectionAndPosition(tfkl.Layer):
    """Linear patch projection plus a learned additive position term.

    Computes ``W_p @ inputs + W_pos`` with ``W_p`` of shape
    ``(embedding_size, patch_size)`` and ``W_pos`` of shape
    ``(embedding_size, num_patches)``. The last two axes of ``inputs`` must
    therefore be ``(patch_size, num_patches)``, and the last two axes of the
    output are ``(embedding_size, num_patches)``.
    """

    def __init__(self, patch_size=10, embedding_size=10, num_patches=20, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embedding_size = embedding_size
        self.num_patches = num_patches

    def build(self, input_shape):
        """Create the two trainable matrices, both drawn from 'random_normal'."""
        super().build(input_shape)
        projection_shape = (self.embedding_size, self.patch_size)
        position_shape = (self.embedding_size, self.num_patches)
        self.W_p = self.add_weight(
            shape=projection_shape,
            initializer='random_normal',
            trainable=True,
            name='W_p',
        )
        self.W_pos = self.add_weight(
            shape=position_shape,
            initializer='random_normal',
            trainable=True,
            name='W_pos',
        )

    def call(self, inputs):
        """Project the patches, then add the position matrix."""
        projected = tf.matmul(self.W_p, inputs)
        return projected + self.W_pos

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        merged = dict(super().get_config())
        merged.update({
            'embedding_size': self.embedding_size,
            'num_patches': self.num_patches,
            'patch_size': self.patch_size,
        })
        return merged

@tfk.saving.register_keras_serializable()
class CustomFeedForwardLayer(tfkl.Layer):
    """GELU dense layer applied along axis 1 of a rank-3 tensor.

    For an input of shape ``(batch, D, N)`` the same dense transformation
    ``D -> num_units`` is applied to each of the ``N`` positions, giving an
    output of shape ``(batch, num_units, N)``.

    Args:
        in_shape: Shape of the input without the batch axis. It is only
            stored so that it can be returned by ``get_config``.
        num_units: Output size of the dense transformation.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, in_shape, num_units, **kwargs):
        super().__init__(**kwargs)
        self.in_shape = in_shape
        self.num_units = num_units
        # No explicit dtype: the Dense layer follows the active Keras dtype
        # policy. Forcing 'float16' here would also store its variables in
        # float16 and make the output dtype differ from the rest of the model
        # whenever the global policy is not a float16 one.
        self.dense = tfkl.Dense(units=num_units, activation='gelu')

    def call(self, inputs):
        # (batch, D, N) -> (batch, N, D): Dense acts on the last axis, so the
        # axis to transform has to be moved there first.
        outputs = tf.transpose(inputs, perm=[0, 2, 1])
        outputs = self.dense(outputs)
        # Back to (batch, num_units, N).
        return tf.transpose(outputs, perm=[0, 2, 1])

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = {'in_shape': self.in_shape, 'num_units': self.num_units}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
#from model_developing.custom_layers_for_patchtst import PatchingLayer, ProjectionAndPosition
def encoder(encoder_input, num_modules, num_heads, key_dim, ff_layers_count, ff_layers_units):
    """Stack ``num_modules`` attention + feed-forward modules on ``encoder_input``.

    Every module applies, in order: multi-head self-attention, a residual add
    followed by batch normalization, dropout (rate 0.2), two
    ``CustomFeedForwardLayer`` instances (widening axis 1 to twice its size
    and then restoring it), and a second residual add followed by batch
    normalization. The intermediate shapes are printed while the graph is
    being built.

    Args:
        encoder_input: Rank-3 Keras tensor.
        num_modules: Number of modules to stack.
        num_heads: ``num_heads`` of every ``MultiHeadAttention`` layer.
        key_dim: ``key_dim`` of every ``MultiHeadAttention`` layer.
        ff_layers_count: Accepted but never read in this function.
        ff_layers_units: Accepted but never read in this function.

    Returns:
        The output tensor of the last module.

    NOTE(review): when called from ``patchTST`` the input is the output of
    ``ProjectionAndPosition``, whose last two axes are (embedding, patches),
    while the printed messages say (BS,N,D). ``MultiHeadAttention`` attends
    over axis 1, so attention presumably runs across embedding channels
    rather than across patches -- confirm this is intended (see the
    commented-out transposes below).
    """
    x = encoder_input
    for _ in range(num_modules):
        # Multi-head self-attention mechanism
        print("output of dropout, is: ",x.shape)
        #y = tf.transpose(x,[0,2,1])
        y = tfkl.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x,x)        
        #y = tf.einsum('ijk->ikj', y)
        #y = tf.transpose(y,[0,2,1])
        print("output of multihead attention, should be (BS,N,D): ",y.shape)
        x = tfkl.BatchNormalization()(tf.math.add(x,y))  # Residual connection followed by batch normalization
        x = tfkl.Dropout(0.2)(x)
        print("output of batch normalization (after multihead), should be (BS,N,D): ",x.shape)
        #y = tf.transpose(x, perm=[0, 2, 1])
        # Feed-forward part: widen axis 1 to twice its size ...
        y = CustomFeedForwardLayer(x.shape[1:], x.shape[1]*2)(x)
        #y = tfkl.Activation('gelu')(y)
        # ... then bring it back to the size of x so the residual add works.
        y = CustomFeedForwardLayer(y.shape[1:],x.shape[1])(y)
        #y = tfkl.Activation('gelu')(y)
        #y = tf.transpose(y, perm=[0, 2, 1])
        print("output of dense layer, should be (BS,D,N): ",y.shape)
        # Second residual connection followed by batch normalization
        x = tfkl.BatchNormalization()(tf.math.add(x,y))
        print("output of batch normalization (afer dense), should be (BS,N,D): ",x.shape)
    return x

def patchTST(input_shape, output_shape, patch_size, patch_stride, num_patches, num_modules, num_heads, key_dim, embedding_size, ff_layers_count, ff_layers_units):
    '''
    Build and compile the PatchTST model.

    The model expects inputs that have ALREADY been split into patches: the
    input layer feeds ``ProjectionAndPosition`` directly and no patching is
    done inside the model.

    input_shape: per-sample shape of the pre-patched input, i.e.
        (patch_size, num_patches)
    output_shape: T, the number of points to predict (9 or 18 in our case);
        used as the number of units of the final dense layer
    patch_size: size of each patch; forwarded to ProjectionAndPosition
    patch_stride: accepted but not used here, since patching happens outside
        the model
    num_patches: number of patches per sample; forwarded to
        ProjectionAndPosition
    num_modules, num_heads, key_dim: forwarded to encoder
    embedding_size: size of the embedding each patch is projected to
    ff_layers_count, ff_layers_units: forwarded to encoder

    The model could be trained on 9 points in the first phase and then, for
    the last phase, have its final dense layer replaced by one with 18
    outputs and be fine tuned; alternatively it can be trained directly on 18
    points and the predictions cropped for the first phase.

    Returns the compiled model (mean squared error loss, AdamW optimizer with
    learning rate 1e-4).
    '''
    input_layer = tfkl.Input(shape=input_shape, name='input_layer')
    #patches = PatchingLayer(patch_size = patch_size, stride = patch_stride)(input_layer)
    encoder_input = ProjectionAndPosition(patch_size = patch_size, num_patches = num_patches, embedding_size = embedding_size)(input_layer)#(patches)
    print("output of projection and position, should be (BS,N,D):",encoder_input.shape)
    encoder_output = encoder(encoder_input, num_modules, num_heads, key_dim, ff_layers_count, ff_layers_units)
    print("output of encoder, should be (BS,N,D):",encoder_output.shape)
    encoder_output = tfkl.Flatten()(encoder_output)
    print("output of flatten, should be (BS,N*D):",encoder_output.shape)
    # Pin the output layer to float32 so that predictions and the loss are
    # computed in full precision even when a float16 compute policy is
    # active; this is a no-op under the default float32 policy.
    output = tfkl.Dense(units = output_shape, activation='linear', dtype='float32')(encoder_output)
    print("output of dense, should be (BS,T):",output.shape)
    model = tfk.Model(inputs=input_layer, outputs=output, name='PatchTST')
    # Compile the model with Mean Squared Error loss and AdamW optimizer
    model.compile(loss=tfk.losses.MeanSquaredError(), optimizer=tfk.optimizers.AdamW(learning_rate=1e-4))
    return model


In [None]:
# Build and compile the model from the hyper-parameter dictionary.
model = patchTST(**params_for_patchTST)

In [None]:
# Render the model graph, expanding nested models.
tf.keras.utils.plot_model(model, expand_nested=True)

In [None]:
# Print the layer-by-layer summary of the model.
model.summary()

## Model training

TensorBoard callback

In [None]:
# Create a TensorBoard callback
# Each run logs to its own time-stamped directory under logs/.
logs = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

# histogram_freq = 1 requests histograms every epoch; profile_batch selects
# the range of batches (500 to 520) that is profiled.
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = logs,
                                                 histogram_freq = 1,
                                                 profile_batch = '500,520')

X_train.dtype, y_train.dtype, X_val.dtype, y_val.dtype

In [None]:
# Train the model on the median-centred, patched windows.
# `use_multiprocessing` is not passed: it only applies to generator or
# keras.utils.Sequence inputs, whereas the inputs here are tf.data datasets.
history = model.fit(
    make_dataset(norm_X_train, norm_y_train),
    # Batching is done inside make_dataset, so no batch_size is given here.
    epochs = 200,
    validation_data=make_dataset(norm_X_val,norm_y_val),
    callbacks = [
        # Stop after 20 epochs without val_loss improvement, keeping the best weights.
        tfk.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=20, restore_best_weights=True),
        # Divide the learning rate by 10 after 5 stagnant epochs, down to 1e-6.
        tfk.callbacks.ReduceLROnPlateau(monitor='val_loss', mode='min', patience=5, factor=0.1, min_lr=1e-6),
        tboard_callback
    ]
).history

In [None]:
# Disabled save / reload / evaluate snippet, kept as a string literal so that
# it is not executed.
# NOTE(review): before re-enabling, check that (a) `custom_objects` only names
# ProjectionAndPosition although the model also uses CustomFeedForwardLayer,
# and (b) `evaluate` is given X_val / y_val, whereas training used the
# median-centred, patched arrays norm_X_val / norm_y_val.
'''model.save_weights('./Models/PatchTST_FINAL')
model.save('./Models/PatchTST_FINAL')
model = tfk.models.load_model(
    "./Models/PatchTST_FINAL",
    custom_objects={"ProjectionAndPosition": ProjectionAndPosition},
)
model.evaluate(X_val,y_val)
#model.save("Models/PatchTST")'''