# Time-Series Forecasting Transformer (TSFT) model

In [None]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The Transformer Decoder

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    The Causal Attention Layer.

    Multi-head self-attention in which every position attends only to itself and to
    earlier positions of the sequence.

    Args:
        n_embed: embedding dimension of the input; must be divisible by ``n_head``.
        block_size: side length of the pre-built causal mask. Only used when
            ``flash_attn`` is False.
        n_head: number of attention heads.
        flash_attn: if True, attention is computed with
            ``F.scaled_dot_product_attention``; otherwise with the explicit
            matmul / mask / softmax implementation.
        dropout: dropout probability applied to the attention weights (training only).
    """

    def __init__(self, n_embed, block_size, n_head, flash_attn=True, dropout=0.1) -> None:
        super(MultiHeadSelfAttention, self).__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        self.n_embed= n_embed
        self.n_head = n_head
        self.d_head = n_embed // n_head
        self.flash_attn= flash_attn
        # query, key, value projections in a single batch
        self.c_attn= nn.Linear(n_embed, 3 * n_embed)
        # output projection
        self.o_proj= nn.Linear(n_embed, n_embed)
        # regularization
        self.dropout= nn.Dropout(p=dropout)
        # lower-triangular mask for the explicit implementation; the fused kernel builds
        # its own causal mask, so the buffer is only registered on the dense path
        if not self.flash_attn:
            self.register_buffer('causal_mask',
                torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
            )


    def forward(self, x):
        """
        Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == n_embed.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C= x.size() # batch_size, sequence length, embedding dim (d_model)
        assert C == self.n_embed, "Input embedding dimension must match model embedding dimension"
        # calculate query, key, values for all heads
        qkv= self.c_attn(x)
        q, k, v= qkv.split(self.n_embed, dim=2) # q,k,v -> (B, T, C)
        # reshape for Multi-Head Attention
        q= q.view(B, -1, self.n_head, self.d_head).transpose(1, 2) # q,k,v view   -> (B, T, nh, dh)
        k= k.view(B, -1, self.n_head, self.d_head).transpose(1, 2) # q,k,v transp -> (B, nh, T, dh)
        v= v.view(B, -1, self.n_head, self.d_head).transpose(1, 2)
        # Attention - the 'scaled dot product'
        if self.flash_attn:
            # The functional API is unaware of train()/eval(): dropout_p is applied whenever
            # it is non-zero, so it has to be switched off explicitly outside training.
            dropout_p= self.dropout.p if self.training else 0.0
            y= F.scaled_dot_product_attention(  # implements FlashAttention
                q, k, v, dropout_p=dropout_p, is_causal=True
            )
        else:  # the original implementation of Attention
            attn= (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_head))
            # apply causal mask and normalize Attention scores
            attn= attn.masked_fill(self.causal_mask[:,:,:T,:T]== 0, float('-inf'))
            attn= F.softmax(attn, dim=-1)
            attn= self.dropout(attn)
            # compute Attention output
            y= attn @ v # (B, nh, T, dh)
        # concatenate multi-head outputs -- re-assembly all head outputs side by side
        y= y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        return self.o_proj(y)


In [None]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network built as a Gated Linear Unit (GLU).

    The input is projected twice to ``d_ffn``. The SiLU-activated "gate" projection
    multiplies the "up" projection element-wise, dropout is applied to the product,
    and the result is projected back to ``n_embed``.
    """

    def __init__(self, n_embed, d_ffn, dropout=0.1) -> None:
        super().__init__()
        # two parallel input projections (n_embed -> d_ffn) ...
        self.gate_proj = nn.Linear(in_features=n_embed, out_features=d_ffn)
        self.up_proj = nn.Linear(in_features=n_embed, out_features=d_ffn)
        # ... and the projection back to the model width (d_ffn -> n_embed)
        self.down_proj = nn.Linear(in_features=d_ffn, out_features=n_embed)
        self.act_fn = nn.SiLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        hidden = self.dropout(gate * self.up_proj(x))
        return self.down_proj(hidden)


In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square layer normalization.

    Each vector along the last dimension is divided by the root of its mean square
    (plus ``eps``) and then multiplied by the learnable per-feature scale ``gamma``.
    """

    def __init__(self, dim, eps=1e-5) -> None:
        super().__init__()
        self.eps = eps
        # learnable scale, one entry per feature, starting at 1
        self.gamma = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # normalize in float32, then cast back to the dtype of the input
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.gamma


In [None]:
class DecoderBlock(nn.Module):
    """
    A single pre-normalization decoder block.

    Causal self-attention followed by a feed-forward network. Each sub-layer is fed the
    normalized input, its output passes through dropout, and the result is added back
    to the un-normalized input (residual connection).
    """

    def __init__(self, n_embed, block_size, n_head, d_ff, norm_type='layer', flash_attn=True,
                 dropout=0.1) -> None:
        super().__init__()
        # 'rms' selects RMSNorm; every other value falls back to LayerNorm
        norm_cls = RMSNorm if norm_type == 'rms' else nn.LayerNorm

        # attention sub-layer
        self.norm_1 = norm_cls(n_embed)
        self.attn = MultiHeadSelfAttention(n_embed, block_size, n_head, flash_attn, dropout)
        self.dropout1 = nn.Dropout(p=dropout)

        # feed-forward sub-layer
        self.norm_2 = norm_cls(n_embed)
        self.ffn = FeedForward(n_embed, d_ff, dropout)
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, x):
        attn_out = self.attn(self.norm_1(x))
        x = x + self.dropout1(attn_out)

        ffn_out = self.ffn(self.norm_2(x))
        return x + self.dropout2(ffn_out)


In [None]:
class TransformerDecoder(nn.Module):
    """
    The Transformer Decoder: a stack of ``n_layer`` Decoder Blocks followed by a final
    normalization layer.
    """

    def __init__(self, n_embed=512, block_size=512, n_layer=6, n_head=8, d_ff=1024,
                 norm_type='layer', flash_attn=True, dropout=0.1) -> None:
        super().__init__()
        blocks = [
            DecoderBlock(
                n_embed=n_embed,
                block_size=block_size,
                n_head=n_head,
                d_ff=d_ff,
                norm_type=norm_type,
                flash_attn=flash_attn,
                dropout=dropout,
            )
            for _ in range(n_layer)
        ]
        self.transformer = nn.ModuleList(blocks)
        # 'rms' selects RMSNorm; every other value falls back to LayerNorm
        if norm_type == 'rms':
            self.norm_final = RMSNorm(n_embed)
        else:
            self.norm_final = nn.LayerNorm(n_embed)

    def forward(self, x):
        hidden = x
        for block in self.transformer:
            hidden = block(hidden)
        return self.norm_final(hidden)


In [None]:
# Build the decoder with its default hyper-parameters on the selected device.
model = TransformerDecoder().to(device)

# Count only the parameters that require gradients.
total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'Number of parameters: {total_params}\n')

# Last expression of the cell: the notebook displays the module tree.
model

Number of parameters: 15769600



TransformerDecoder(
  (transformer): ModuleList(
    (0-5): 6 x DecoderBlock(
      (norm_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadSelfAttention(
        (c_attn): Linear(in_features=512, out_features=1536, bias=True)
        (o_proj): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (gate_proj): Linear(in_features=512, out_features=1024, bias=True)
        (up_proj): Linear(in_features=512, out_features=1024, bias=True)
        (down_proj): Linear(in_features=1024, out_features=512, bias=True)
        (act_fn): SiLU()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (norm_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)

In [None]:
# Hyper-parameter presets for TransformerDecoder, keyed by model size.
model_config = {
    'base': {
        'n_embed': 512, 'block_size': 512, 'n_layer': 6, 'n_head': 8, 'd_ff': 1024,
        'flash_attn': True, 'dropout': 0.1,
    },
    'medium': {
        'n_embed': 1024, 'block_size': 512, 'n_layer': 8, 'n_head': 16, 'd_ff': 2048,
        'flash_attn': True, 'dropout': 0.1,
    },
    'large': {
        'n_embed': 1280, 'block_size': 512, 'n_layer': 16, 'n_head': 20, 'd_ff': 2560,
        'flash_attn': True, 'dropout': 0.1,
    },
}

In [None]:
# ----- Medium config -----
# Instantiate the decoder from the 'medium' preset and report its trainable size.
model = TransformerDecoder(**model_config['medium']).to(device)

total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'Number of parameters: {total_params}\n')

Number of parameters: 83994624



In [None]:
# ----- Large config -----
# Instantiate the decoder from the 'large' preset and report its trainable size.
model = TransformerDecoder(**model_config['large']).to(device)

total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'Number of parameters: {total_params}\n')

Number of parameters: 262412800



In [None]:
import torch.cuda as cuda

# Random input batch, created once and moved to the selected device.
data= torch.randn(16, 128, 512).to(device)
# flash_attn=True: this cell must exercise the fused scaled_dot_product_attention path,
# otherwise the number reported below is just another dense-attention measurement.
model= TransformerDecoder(flash_attn=True).to(device)

# Start a fresh peak-memory window (requires a CUDA device).
cuda.reset_peak_memory_stats()
# Run the model with flash attention
model.eval()
dec= model(data)
# Measure peak memory usage
peak_flash_memory= cuda.max_memory_allocated()

# drop the output so its memory can be released
del dec

print(f"Flash Attention Memory: {np.round(peak_flash_memory / 10**6, decimals=2)} MB")

Flash Attention Memory: 676.35 MB


In [None]:
# Fresh model using the explicit (matmul / mask / softmax) attention, in eval mode.
model = TransformerDecoder(flash_attn=False).to(device).eval()

# Start a new peak-memory window, then do one forward pass with dense attention.
cuda.reset_peak_memory_stats()
dec = model(data)
# Peak allocation recorded since the reset above.
peak_dense_memory = cuda.max_memory_allocated()

# drop the output so its memory can be released
del dec

print(f"Dense Attention Memory: {np.round(peak_dense_memory / 10**6, decimals=2)} MB")

Dense Attention Memory: 682.38 MB


# The Time-Series Forecasting Transformer (TSFT)

In [None]:
class Embedding(nn.Module):
    """
    Placeholder for the input Embedding module.

    The patch and positional embeddings are not implemented yet: the constructor
    registers no sub-modules and ``forward`` returns None.
    """

    def __init__(self) -> None:
        super().__init__()
        # TODO: define the patch and positional embeddings

    def forward(self, ts):
        # not implemented yet
        return None


In [None]:
class TSFTransformer(nn.Module):
    """
    Skeleton of the Time-Series Forecasting Transformer (TSFT) model.

    Wires together the embedding module, the transformer decoder, an identity
    "latent space" layer and a bias-free linear head of width ``vocab_size``.
    ``forward`` is not implemented yet and returns None.
    """

    def __init__(self, vocab_size,
                 n_embed=512, block_size=512, n_layer=6, n_head=8, d_ff=1024, norm_type='layer',
                 flash_attn=True, dropout=0.1) -> None:
        super().__init__()
        # TODO: initial considerations

        # patch and positional embeddings
        self.embedding = Embedding()
        # stack of decoder blocks
        self.decoder = TransformerDecoder(
            n_embed=n_embed,
            block_size=block_size,
            n_layer=n_layer,
            n_head=n_head,
            d_ff=d_ff,
            norm_type=norm_type,
            flash_attn=flash_attn,
            dropout=dropout,
        )
        # pass-through layer (returns its input unchanged)
        self.latent_space = nn.Identity()
        # classification head
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

        # Glorot / fan_avg initialization for every parameter with more than one
        # dimension; 1-D parameters keep their default initialization
        weight_matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in weight_matrices:
            nn.init.xavier_normal_(weight)

    def forward(self, ts):
        # not implemented yet
        return None


# Training the model

In [None]:
# GPU-only backend settings; skipped entirely when running on the CPU.
if device== 'cuda':
    # 'high' lets float32 matrix multiplications use faster, reduced-precision internal
    # formats (see the torch.set_float32_matmul_precision documentation)
    torch.set_float32_matmul_precision('high')
    # Enable flash attention
    # (allows the FlashAttention backend of scaled_dot_product_attention)
    torch.backends.cuda.enable_flash_sdp(True)