In [3]:
# import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow import keras
# from tensorflow.keras import layers, Model
from tensorflow.keras.models import *
from tensorflow.keras.layers import *

from dataclasses import dataclass
from einops import rearrange, repeat
from typing import Union

from transformers import AutoTokenizer

import datasets
import math
import numpy as np

In [4]:
class Time2Vector(Layer):
    def __init__(self, seq_len, **kwargs):
        super(Time2Vector, self).__init__()
        self.seq_len = seq_len

    def build(self, input_shape):
        '''Initialize weights and biases with shape (batch, seq_len)'''
        self.weights_linear = self.add_weight(name='weight_linear',
                                              shape=(int(self.seq_len),),
                                              initializer='uniform',
                                              trainable=True)

        self.bias_linear = self.add_weight(name='bias_linear',
                                           shape=(int(self.seq_len),),
                                           initializer='uniform',
                                           trainable=True)

        self.weights_periodic = self.add_weight(name='weight_periodic',
                                                shape=(int(self.seq_len),),
                                                initializer='uniform',
                                                trainable=True)

        self.bias_periodic = self.add_weight(name='bias_periodic',
                                             shape=(int(self.seq_len),),
                                             initializer='uniform',
                                             trainable=True)

    def call(self, x):
        '''Calculate linear and periodic time features'''
        x = tf.math.reduce_mean(x[:, :, :4], axis=-1)
        time_linear = self.weights_linear * x + self.bias_linear  # Linear time feature
        time_linear = tf.expand_dims(time_linear, axis=-1)  # Add dimension (batch, seq_len, 1)

        time_periodic = tf.math.sin(tf.multiply(x, self.weights_periodic) + self.bias_periodic)
        time_periodic = tf.expand_dims(time_periodic, axis=-1)  # Add dimension (batch, seq_len, 1)
        return tf.concat([time_linear, time_periodic], axis=-1)  # shape = (batch, seq_len, 2)

    def get_config(self):  # Needed for saving and loading model with custom layer
        config = super().get_config().copy()
        config.update({'seq_len': self.seq_len})
        return config


In [5]:
@dataclass
class ModelArgs:
    model_input_dims: int = 64
    model_states: int = 64
    projection_expand_factor: int = 2
    conv_kernel_size: int = 4
    delta_t_min: float = 0.001
    delta_t_max: float = 0.1
    delta_t_scale: float = 0.1
    delta_t_init_floor: float = 1e-4
    conv_use_bias: bool = True
    dense_use_bias: bool = False
    layer_id: int = -1
    seq_length: int = 30  # 30 days stock price data
    num_layers: int = 5
    dropout_rate: float = 0.2
#     use_lm_head: float = False
#     num_classes: int = None
#     vocab_size: int = None
#     final_activation = None
#     loss:Union[str, keras.losses.Loss] = None
    loss: str = 'mse'
#     optimizer: Union[str, keras.optimizers.Optimizer] = keras.optimizers.AdamW()
    optimizer: str = 'adam'
    metrics = ['mae', 'mape']

    def __post_init__(self):
        self.model_internal_dim: int = int(self.projection_expand_factor * self.model_input_dims)

        self.delta_t_rank = math.ceil(self.model_input_dims/16)
        if self.layer_id == -1:
            self.layer_id = np.round(np.random.randint(0, 1000), 4)

#         if self.vocab_size == None:
#             raise ValueError("vocab size cannot be none")

#         if self.use_lm_head:
#             self.num_classes=self.vocab_size
#         else:
#             if self.num_classes == None:
#                 raise ValueError(f'num classes cannot be {self.num_classes}')

#             if self.num_classes == 1:
#                 self.final_activation = 'sigmoid'
#             else:
#                 self.final_activation = 'softmax'

#         if self.loss == None:
#             raise ValueError(f"loss cannot be {self.loss}")

In [6]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size

In [7]:
def selective_scan(u, delta, A, B, C, D):
    # first step of A_bar = exp(ΔA), i.e., ΔA
    dA = tf.einsum('bld,dn->bldn', delta, A) 
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)
    
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]
    
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1
    
    # Cumulative sum along all the input tokens, parallel prefix sum, 
    # calculates dA for all the input tokens parallely
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)  

    # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.exp(dA_cumsum)  
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1

    x = dB_u * dA_cumsum
    # 1e-12 to avoid division by 0
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12) 

    y = tf.einsum('bldn,bln->bld', x, C)
    
    return y + u * D 

In [9]:
class MambaBlock(Layer):  # layers.
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        args = modelargs
        self.layer_id = modelargs.layer_id

#         self.in_projection = layers.Dense(
#             args.model_internal_dim * 2, 
#             input_shape=(args.model_input_dims,), use_bias=False)
        self.in_projection = Dense(
            args.model_internal_dim * 2, 
            input_shape=(args.model_input_dims,), use_bias=False)

#         self.conv1d = layers.Conv1D(
#             filters=args.model_internal_dim,
#             use_bias=args.conv_use_bias,
#             kernel_size=args.conv_kernel_size,
#             groups=args.model_internal_dim,
#             data_format='channels_first',
#             padding='causal'
#         )
        self.conv1d = Conv1D(
            filters=args.model_internal_dim,
            use_bias=args.conv_use_bias,
            kernel_size=args.conv_kernel_size,
            groups=args.model_internal_dim,
            data_format='channels_first',
            padding='causal'
        )

        # this layer takes in current token 'x' 
        # and outputs the input-specific Δ, B, C (according to S6)
#         self.x_projection = layers.Dense(args.delta_t_rank + args.model_states * 2, use_bias=False)
        self.x_projection = Dense(args.delta_t_rank + args.model_states * 2, use_bias=False)

        # this layer projects Δ from delta_t_rank to the mamba internal 
        # dimension
#         self.delta_t_projection = layers.Dense(args.model_internal_dim, 
#                                                input_shape=(args.delta_t_rank,), use_bias=True)
        self.delta_t_projection = Dense(args.model_internal_dim, 
                                               input_shape=(args.delta_t_rank,), use_bias=True)

        self.A = repeat(
                tf.range(1, args.model_states+1, dtype=tf.float32), 
                'n -> d n', d=args.model_internal_dim)

        self.A_log = tf.Variable(
                tf.math.log(self.A), 
                trainable=True, dtype=tf.float32, 
                name=f"SSM_A_log_{args.layer_id}")

        self.D = tf.Variable(
                np.ones(args.model_internal_dim), 
                trainable=True, dtype=tf.float32, 
                name=f"SSM_D_{args.layer_id}")

#         self.out_projection = layers.Dense(
#                 args.model_input_dims, 
#                 input_shape=(args.model_internal_dim,), 
#                 use_bias=args.dense_use_bias)
        self.out_projection = Dense(
                args.model_input_dims, 
                input_shape=(args.model_internal_dim,), 
                use_bias=args.dense_use_bias)

    def call(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba pape.
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """

        (batch_size, seq_len, dimension) = x.shape

        x_and_res = self.in_projection(x) # shape = (batch, seq_len, 2 * model_internal_dimension)
        (x, res) = tf.split(x_and_res, 
                            [self.args.model_internal_dim, 
                             self.args.model_internal_dim], axis=-1)
        
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :seq_len]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = tf.nn.swish(x)
        y = self.ssm(x)
        y = y * tf.nn.swish(res)
        return self.out_projection(y)
    
    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper
            - run_SSM(A, B, C, u) in The Annotated S4
            Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -tf.exp(tf.cast(self.A_log, tf.float32)) # shape -> (d_in, n)
        D = tf.cast(self.D, tf.float32)

        x_dbl = self.x_projection(x) # shape -> (batch, seq_len, delta_t_rank + 2*n)

        (delta, B, C) = tf.split(
                x_dbl, 
                num_or_size_splits=[self.args.delta_t_rank, n, n], 
                axis=-1) # delta.shape -> (batch, seq_len) & B, C shape -> (batch, seq_len, n)

        delta = tf.nn.softplus(self.delta_t_projection(delta)) # shape -> (batch, seq_len, model_input_dim)

        return selective_scan(x, delta, A, B, C, D)

In [10]:
class ResidualBlock(Layer):  # layers.
    def __init__(self, modelargs: ModelArgs, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args = modelargs
        self.mixer = MambaBlock(modelargs)
        self.norm = LayerNormalization(epsilon=1e-5)  # layers.  deleted

    def call(self, x):
        """
        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
            
            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
            
        """
        return self.mixer(self.norm(x)) + x

In [11]:
def init_model(args: ModelArgs):
    time_embedding = Time2Vector(args.seq_length)
    
    input_layer = Input(shape=(args.seq_length,1), name='input_ids') # layers.
#     x = layers.Embedding(
#                 args.vocab_size, 
#                 args.model_input_dims, 
#                 input_length=args.seq_length)(input_layer)
    x = time_embedding(input_layer)
    x = Concatenate(axis=-1)([input_layer, x]) # layers.

    for i in range(args.num_layers):
        x = ResidualBlock(args, name=f"Residual_{i}")(x)
        x = Dropout(args.dropout_rate)(x) # for regularization  # layers.

    x = LayerNormalization(epsilon=1e-5)(x) # normalization layer  # layers.
    
    # use flatten only if we are not using the model as an LM
    # if not args.use_lm_head: 
    x = Flatten()(x) # layers.
    x = Dense(64, activation=tf.nn.gelu)(x)  # layers.
    x = Dropout(0.1)(x)                     # consider to use args.dropout  # layers.
    output_layer = Dense(1, activation='linear')(x)  # layers.
#     output_layer = layers.Dense(
#                 args.num_classes, 
#                 activation=args.final_activation)(x)

    model = Model(
                inputs=[input_layer], 
                outputs=[output_layer], name='Mamba_ka_Mamba')
    model.compile(
        loss=args.loss,
        optimizer=args.optimizer,
        metrics=args.metrics
    )
#     model.compile(loss='mse', optimizer='adam', metrics=['mae', 'mape'])

    return model

In [13]:
args = ModelArgs(
    model_input_dims=3,
    model_states=32,
#     num_layers=12,
    dropout_rate=0.2,
#     vocab_size=vocab_size,
#     num_classes=1,
#     loss='binary_crossentropy',
)
model = init_model(args)
model.summary()

In [15]:
from datetime import datetime
import pandas as pd

tsmc_data = pd.read_csv('./tsmc_stock_prices_INT_close_only.csv')
tsmc_data.index = tsmc_data["date"]
tsmc_data = tsmc_data.drop(columns=["date"])
print(tsmc_data.head())

          close
date           
20100104     64
20100105     64
20100106     64
20100107     64
20100108     64


In [16]:
from sklearn.preprocessing import MinMaxScaler
input_length = 30
output_length = 1
test_percentage = 0.2
dataset = tsmc_data['close'].to_numpy()

scaler = MinMaxScaler()
dataset_norm = scaler.fit_transform(dataset.reshape(-1, 1)).flatten()
dataset_list = []
for i in range(len(dataset) - input_length - output_length):
    dataset_list.append(dataset_norm[i:i + input_length + output_length])
dataset_list = np.array(dataset_list)
trainset = dataset_list[:int(len(dataset_list) * (1 - test_percentage))]
testset = dataset_list[int(len(dataset_list) * (1 - test_percentage)):]

x_train = trainset[:, :-1]
y_train = trainset[:, -1:]
x_test = testset[:, :-1]
y_test = testset[:, -1:]

print('x_train.shape:' + str(x_train.shape))
print('y_train.shape:' + str(y_train.shape))
print('x_test.shape:' + str(x_test.shape))
print('y_test.shape' + str(y_test.shape))

x_train.shape:(2718, 30)
y_train.shape:(2718, 1)
x_test.shape:(680, 30)
y_test.shape(680, 1)
