# Time-Series Forecasting Transformer (TSFT) model

In [1]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [2]:
# run on the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device= 'cuda'
else:
    device= 'cpu'

In [3]:
# count how many trainable weights the model has
def count_parameters(model) -> None:
    """
    Print the number of trainable parameters of `model`.
    Only parameters whose `requires_grad` flag is set are counted; nothing is returned.
    """
    trainable_sizes= [param.numel() for param in model.parameters() if param.requires_grad]
    total_params= sum(trainable_sizes)
    print(f'Number of parameters: {total_params}')

# The Transformer Architecture

In [4]:
class MultiHeadSelfAttention(nn.Module):
    """
    The Multi-Headed Self-Attention Layer.

    Args:
        n_embed: embedding (model) dimension; must be divisible by n_head.
        n_head: number of attention heads.
        dropout: dropout probability applied to the attention weights (training mode only).
        flash_attn: if True, attention is computed with F.scaled_dot_product_attention,
            otherwise with the explicit softmax(Q K^T / sqrt(d_head)) V formulation.
        bias: whether the Linear projections carry a bias term.
    """

    def __init__(self, n_embed, n_head, dropout=0.1, flash_attn=True, bias=True) -> None:
        super().__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        self.n_embed= n_embed
        self.n_head = n_head
        self.d_head = n_embed // n_head
        self.flash_attn= flash_attn
        # query, key, value projections in a single batch
        self.c_attn= nn.Linear(n_embed, 3 * n_embed, bias=bias)
        # output projection
        self.o_proj= nn.Linear(n_embed, n_embed, bias=bias)
        # regularization
        self.dropout= nn.Dropout(p=dropout)


    def forward(self, x, causal_mask=None):
        """
        Apply self-attention to x of shape (B, T, n_embed) and return a tensor of the same shape.

        causal_mask: None for unmasked attention. On the flash path any non-None value enables
        causal masking; on the explicit path it must be a tensor that can be sliced as
        causal_mask[:, :, :T, :T], where entries equal to 0 mark positions to be masked out.
        """
        B, T, C= x.size()  # x(batch_size, sequence length, n_embed)
        assert C== self.n_embed, "Input embedding dimension must match model embedding dimension"
        # calculate query, key, values for all heads
        qkv= self.c_attn(x)
        q, k, v= qkv.split(self.n_embed, dim=2) # q,k,v -> (B, T, C)
        # reshape for Multi-Head Attention
        q= q.view(B, T, self.n_head, self.d_head).transpose(1, 2) # q,k,v view   -> (B, T, nh, dh)
        k= k.view(B, T, self.n_head, self.d_head).transpose(1, 2) # q,k,v transp -> (B, nh, T, dh)
        v= v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        # Attention - the 'scaled dot product'
        if self.flash_attn:
            is_causal= causal_mask is not None
            # the functional API applies dropout whenever dropout_p > 0, it does not know about
            # train/eval mode -- so dropout has to be switched off explicitly outside of training
            dropout_p= self.dropout.p if self.training else 0.0
            # implements FlashAttention
            y= F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=is_causal
            )
        else:  # the original implementation of Attention
            attn= (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
            # apply causal mask (when the mask is not None)
            if causal_mask is not None:
                attn= attn.masked_fill(causal_mask[:,:,:T,:T]== 0, float('-inf'))
            # normalize Attention scores
            attn= F.softmax(attn, dim=-1)
            # nn.Dropout is a no-op in eval mode
            attn= self.dropout(attn)
            # compute Attention output
            y= attn @ v # (B, nh, T, dh)
        # concatenate multi-head outputs -- re-assembly all head outputs side by side
        y= y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        return self.o_proj(y)


In [5]:
class FeedForward(nn.Module):
    """
    Position-wise Feed Forward Network (FFN) built as a Gated Linear Unit (GLU).
    A SiLU-activated gate branch multiplies a linear "up" branch element-wise; the product is
    regularized with dropout and projected back to the embedding dimension.
    """

    def __init__(self, n_embed, d_ff, dropout=0.1, bias=True) -> None:
        super().__init__()
        # the two input branches both widen n_embed -> d_ff
        self.gate_proj= nn.Linear(n_embed, d_ff, bias=bias)
        self.up_proj  = nn.Linear(n_embed, d_ff, bias=bias)
        # the output projection narrows d_ff -> n_embed
        self.down_proj= nn.Linear(d_ff, n_embed, bias=bias)
        self.act_fn= nn.SiLU()
        self.dropout= nn.Dropout(p=dropout)


    def forward(self, x):
        gate= self.act_fn(self.gate_proj(x))
        hidden= gate * self.up_proj(x)

        return self.down_proj(self.dropout(hidden))


In [6]:
class RMSNorm(nn.Module):
    """
    Root Mean Square normalization (RMSNorm): the input is divided by the root of the mean of
    its squares over the last dimension and then scaled by a learnable per-feature gain.
    """

    def __init__(self, dim, eps=1e-5) -> None:
        super().__init__()
        self.eps= eps
        # learnable gain, one value per feature, starting at 1 (no rescaling)
        self.gamma= nn.Parameter(torch.ones(dim))


    def _norm(self, x):
        # mean of squares over the feature (last) dimension; eps keeps rsqrt finite
        mean_sq= torch.square(x).mean(dim=-1, keepdim=True)

        return x * torch.rsqrt(mean_sq + self.eps)


    def forward(self, x):
        # normalize in float32, then cast back to the dtype of the input
        normalized= self._norm(x.float())

        return normalized.type_as(x) * self.gamma


In [7]:
class TransformerBlock(nn.Module):
    """
    A single pre-normalization Transformer Block (usable as Encoder or Decoder block):
    norm -> self-attention -> dropout -> residual, then norm -> FFN -> dropout -> residual.
    """

    def __init__(self, n_embed, n_head, d_ff, dropout=0.1, norm_type='layer', flash_attn=True,
                 bias=True) -> None:
        super().__init__()

        def make_norm():
            # 'rms' selects RMSNorm, any other value falls back to LayerNorm
            if norm_type=='rms':
                return RMSNorm(n_embed)
            return nn.LayerNorm(n_embed)

        # attention sub-layer
        self.norm1= make_norm()
        self.attn= MultiHeadSelfAttention(n_embed, n_head, dropout, flash_attn, bias)
        self.dropout1= nn.Dropout(p=dropout)
        # feed-forward sub-layer
        self.norm2= make_norm()
        self.ffn= FeedForward(n_embed, d_ff, dropout, bias)
        self.dropout2= nn.Dropout(p=dropout)


    def forward(self, x, causal_mask=None):
        # residual connection around the (pre-normalized) attention
        attn_out= self.attn(self.norm1(x), causal_mask)
        x= x + self.dropout1(attn_out)
        # residual connection around the (pre-normalized) feed-forward network
        ffn_out= self.ffn(self.norm2(x))

        return x + self.dropout2(ffn_out)


In [8]:
class TransformerModel(nn.Module):
    """
    A stack of n_layer Transformer Blocks followed by a final normalization.
    With is_causal=True the stack acts as a Decoder (each position attends to the past only),
    otherwise as an Encoder.
    """

    def __init__(self, is_causal=True, n_layer=6, n_embed=512, block_size=768, n_head=8, d_ff=1024,
                 dropout=0.1, norm_type='layer', flash_attn=True, bias=True) -> None:
        super().__init__()
        self.block_size= block_size
        # define the transformer
        blocks= [
            TransformerBlock(n_embed, n_head, d_ff, dropout, norm_type, flash_attn, bias)
            for _ in range(n_layer)
        ]
        self.transformer= nn.ModuleList(blocks)
        if norm_type=='rms':
            self.norm_final= RMSNorm(n_embed)
        else:
            self.norm_final= nn.LayerNorm(n_embed)
        # what gets handed to every block as `causal_mask`:
        #   Encoder               -> None (no masking)
        #   Decoder + flash_attn  -> True (a flag is enough, no mask tensor is stored)
        #   Decoder, no flash     -> lower triangular 0/1 mask of shape (1, 1, block_size, block_size)
        if not is_causal:
            self.causal_mask= None
        elif flash_attn:
            self.causal_mask= True
        else:
            lower_tri= torch.tril(torch.ones(block_size, block_size))
            self.register_buffer('causal_mask', lower_tri.view(1, 1, block_size, block_size))


    def forward(self, x):
        B, T, C= x.size()  # x(batch_size, sequence length, n_embed)
        assert T <= self.block_size, \
            f'Cannot forward sequence of length {T}, block size is only {self.block_size}'
        # push the embedding through every block in order
        hidden= x
        for layer in self.transformer:
            hidden= layer(hidden, self.causal_mask)

        return self.norm_final(hidden)


In [9]:
# default-sized model without bias terms, moved to the selected device
model= TransformerModel(bias=False)
model= model.to(device)

count_parameters(model)
# last expression of the cell -- the notebook displays the module structure
model

Number of parameters: 15741952


TransformerModel(
  (transformer): ModuleList(
    (0-5): 6 x TransformerBlock(
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadSelfAttention(
        (c_attn): Linear(in_features=512, out_features=1536, bias=False)
        (o_proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (gate_proj): Linear(in_features=512, out_features=1024, bias=False)
        (up_proj): Linear(in_features=512, out_features=1024, bias=False)
        (down_proj): Linear(in_features=1024, out_features=512, bias=False)
        (act_fn): SiLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (norm_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)

In [10]:
# keyword arguments for TransformerModel, one entry per model size
model_config= {
    'base': {
        'n_layer': 6, 'n_embed': 512, 'block_size': 768, 'n_head': 8, 'd_ff': 1024,
        'dropout': 0.1, 'bias': False,
    },
    'medium': {
        'n_layer': 12, 'n_embed': 768, 'block_size': 1024, 'n_head': 12, 'd_ff': 1536,
        'dropout': 0.1, 'bias': False,
    },
    'large': {
        'n_layer': 16, 'n_embed': 1024, 'block_size': 1280, 'n_head': 16, 'd_ff': 2048,
        'dropout': 0.1, 'bias': False,
    },
    'xlarge': {
        'n_layer': 24, 'n_embed': 1280, 'block_size': 2048, 'n_head': 20, 'd_ff': 2560,
        'dropout': 0.1, 'bias': False,
    },
}

In [11]:
# ----- Medium config -----
# build the 'medium' model, report its size and drop the reference again
model= TransformerModel(**model_config['medium'])
model= model.to(device)

count_parameters(model)
del model

Number of parameters: 70817280


In [12]:
# ----- Large config -----
# build the 'large' model, report its size and drop the reference again
model= TransformerModel(**model_config['large'])
model= model.to(device)

count_parameters(model)
del model

Number of parameters: 167839744


In [13]:
# ----- XLarge config -----
# build the 'xlarge' model, report its size and drop the reference again
model= TransformerModel(**model_config['xlarge'])
model= model.to(device)

count_parameters(model)
del model

Number of parameters: 393341440


# Memory usage

In [20]:
import torch.cuda as cuda

# Peak-memory measurement of one forward pass with the explicit (non-flash) Attention.
# random input whose last two sizes match TransformerModel's defaults
# (block_size=768, n_embed=512); the leading 32 is the batch size
data= torch.randn(32, 768, 512).to(device)
# flash_attn=False selects the explicit softmax(Q K^T) Attention path
model= TransformerModel(flash_attn=False, bias=False).to(device)

# reset the peak counter first, so the reading below reflects the peak reached from here on
# NOTE(review): the torch.cuda calls presumably need a CUDA device -- confirm what happens
# when device == 'cpu'
cuda.reset_peak_memory_stats()
# Run the model with original Attention
# NOTE(review): the forward pass is not wrapped in torch.no_grad(); tensors kept for autograd
# presumably count toward the peak -- confirm this is what should be measured
model.eval()
dec= model(data)
# Measure peak memory usage
peak_dense_memory= cuda.max_memory_allocated()

# drop the references to the output and the model; `data` is kept
del dec, model

# the reading is divided by 10**6 to report it in MB, rounded to 2 decimals
print(f"Original Attention Memory: {np.round(peak_dense_memory / 10**6, decimals=2)} MB")

Original Attention Memory: 8751.33 MB


In [22]:
# Peak-memory measurement of one forward pass with FlashAttention.
# Uses `data` and the `cuda` alias (torch.cuda) defined by the previous memory-usage cell.
if device== 'cuda':
    # set PyTorch's float32 matmul precision mode to 'high'
    torch.set_float32_matmul_precision('high')
    # Enable flash attention
    torch.backends.cuda.enable_flash_sdp(True)

# flash_attn=True routes Attention through F.scaled_dot_product_attention
model= TransformerModel(flash_attn=True, bias=False).to(device)

# reset the peak counter first, so the reading below reflects the peak reached from here on
# NOTE(review): the torch.cuda calls presumably need a CUDA device -- confirm what happens
# when device == 'cpu'
cuda.reset_peak_memory_stats()
# Run the model with FlashAttention
model.eval()
dec= model(data)
# Measure peak memory usage
peak_flash_memory= cuda.max_memory_allocated()

# drop the references to the output and the model; `data` is kept
del dec, model

# the reading is divided by 10**6 to report it in MB, rounded to 2 decimals
print(f"Flash Attention Memory: {np.round(peak_flash_memory / 10**6, decimals=2)} MB")

Flash Attention Memory: 5076.33 MB


# The Time-Series Forecasting Transformer (TSFT) model

In [None]:
class Embedding(nn.Module):
    """
    Placeholder for the Embedding module (patch and positional embeddings).
    Not implemented yet: it holds no layers and its forward pass returns None.
    """

    def __init__(self) -> None:
        super().__init__()
        # define the patch and positional embeddings... TODO


    def forward(self, ts):
        # nothing is computed yet -- the input is ignored
        return None


In [None]:
class OutputBlock(nn.Module):
    """
    The MLP (classification) head mapping n_embed features to num_classes outputs.
    With fine_tune=True the head is a single Linear layer, otherwise it is
    Linear -> GELU -> Dropout -> Linear with a hidden width of d_ff.
    """

    def __init__(self, n_embed, d_ff, num_classes, dropout=0.1, bias=True,
                 fine_tune=False) -> None:
        super().__init__()
        if not fine_tune:
            # two-layer MLP head
            layers= [
                nn.Linear(n_embed, d_ff, bias=bias),
                nn.GELU(),
                nn.Dropout(p=dropout),
                nn.Linear(d_ff, num_classes, bias=bias),
            ]
            self.c_head= nn.Sequential(*layers)
        else:
            # single linear projection
            self.c_head= nn.Linear(n_embed, num_classes, bias=bias)


    def forward(self, x):
        return self.c_head(x)


In [None]:
class TSFTransformer(nn.Module):
    """
    The Time-Series Forecasting Transformer (TSFT) model:
    Embedding -> TransformerModel -> Identity (latent space) -> OutputBlock head.
    If is_causal=True, we have a Decoder Transformer, otherwise, an Encoder Transformer.
    The forward pass is not implemented yet and returns None.
    """

    def __init__(self, vocab_size, is_causal=True,
                 n_layer=6, n_embed=512, block_size=768, n_head=8, d_ff=1024, dropout=0.1,
                 norm_type='layer', flash_attn=True, bias=True, fine_tune=False) -> None:
        super().__init__()
        # initial considerations ... TODO


        # define the patch and positional embeddings
        self.embedding= Embedding()
        # define the transformer decoder
        self.decoder= TransformerModel(
            is_causal=is_causal, n_layer=n_layer, n_embed=n_embed, block_size=block_size,
            n_head=n_head, d_ff=d_ff, dropout=dropout, norm_type=norm_type,
            flash_attn=flash_attn, bias=bias
        )
        # identity layer (no change to the tensor)
        self.latent_space= nn.Identity()
        # classification head
        self.lm_head= OutputBlock(n_embed, d_ff, vocab_size, dropout, bias, fine_tune)

        # initialize Linear modules with Glorot / fan_avg
        # let Normalization and Embedding modules use default initializations
        linear_layers= (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)


    def forward(self, ts):
        # not implemented yet -- the input is ignored
        return None


# Training the model

In [19]:
# GPU-only global settings applied before training
if device == 'cuda':
    # Enable flash attention
    torch.backends.cuda.enable_flash_sdp(True)
    torch.set_float32_matmul_precision('high')