# Mamba State Space Model with Docker for Local Development and Cloud Deployment

This guide covers the Mamba State Space Model (SSM), implemented in Python and packaged with Docker for both local development and cloud deployment. It covers the following topics:

1. **Introduction to Mamba SSM**: Overview of the Mamba State Space Model and its applications.
2. **Setting Up the Development Environment**: Step-by-step instructions for setting up a local development environment using Docker.
3. **Building and Running the Docker Container**: Instructions for building the Docker image and running the container.
4. **Deploying to the Cloud**: Guidelines for deploying the Mamba SSM to a cloud platform using Docker.
5. **Best Practices**: Tips and best practices for working with Mamba SSM and Docker.


## Prerequisites

Before you begin, ensure you have the following installed on your local machine:

- Docker: [Install Docker](https://docs.docker.com/get-docker/)
- A compatible NVIDIA GPU (for Mamba SSM)
- NVIDIA drivers and the NVIDIA Container Toolkit (needed for the `--gpus all` option used below)


## Sections
- [Introduction to Mamba SSM](#introduction-to-mamba-ssm)
- [Building and Running the Docker Container](#building-and-running-the-docker-container)
- [Using Mamba SSM](#using-mamba-ssm)
- [Deploying to the Cloud](#deploying-to-the-cloud)


## Introduction to Mamba SSM

At the heart of modern AI systems like ChatGPT and AlphaFold lies a critical challenge: understanding sequences. Whether you're parsing a sentence, listening to speech, decoding a genome, or analyzing financial time series, the task boils down to learning patterns across time or space. For the past few years, Transformers, powered by the self-attention mechanism, have dominated this space due to their remarkable ability to model complex relationships between every pair of inputs. But this power comes at a cost: the compute and memory needed for self-attention grow quadratically with sequence length, so Transformers become expensive or infeasible as sequences grow beyond a few thousand steps.

Enter Mamba, a new kind of sequence model inspired by state space systems—an elegant mathematical framework traditionally used in physics and control theory. Think of it like replacing a massive all-to-all attention grid with a sleek, memory-efficient signal processor that knows when to pay attention and when to ignore noise.

What makes Mamba different isn't just that it's faster (the original paper reports up to 5× higher inference throughput than Transformers of comparable size), or that it scales to very long contexts (the paper reports gains on real data up to million-token sequences). It's that Mamba is selective. Unlike older state space models that process information uniformly over time, Mamba can decide what matters at each step. It brings a kind of intelligent filtering, like memory with a spotlight, selectively storing important details and discarding the rest.

Here’s the kicker: despite being fully recurrent and operating in linear time, Mamba matches or exceeds the accuracy of Transformers in domains ranging from text and audio to genomics. It doesn’t need attention layers or even separate MLP blocks. Its design is minimal, clean, and hardware-aware—making it not only smart, but fast.

By blending the long-term memory of RNNs, the locality of CNNs, and the expressive power of Transformers—all within a scalable, streamlined architecture—Mamba represents a profound shift in how we think about modeling sequences at scale.
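To make "selective" concrete, here is a toy, single-channel sketch of a selective state space recurrence in NumPy. It is not the Mamba implementation, which uses learned projections and a fused, hardware-aware scan. The function name and the parameters `w_delta`, `w_B`, and `w_C` are hypothetical, chosen only to show how the step size and the input/output maps can depend on the current input while the update stays linear in sequence length.

```python
import numpy as np

def selective_scan_toy(x, A, w_delta, w_B, w_C):
    """Toy single-channel selective SSM recurrence (illustrative only).

    x: (T,) input sequence
    A: (N,) negative diagonal entries of the state matrix
    w_delta, w_B, w_C: hypothetical parameters that make the step size,
        input map, and output map depend on the current input
    """
    T, N = x.shape[0], A.shape[0]
    h = np.zeros(N)                                # hidden state carried across time
    y = np.zeros(T)
    for t in range(T):
        delta = np.log1p(np.exp(w_delta * x[t]))   # softplus keeps the step size positive
        B_t = w_B * x[t]                           # input-dependent input map, shape (N,)
        C_t = w_C * x[t]                           # input-dependent output map, shape (N,)
        A_bar = np.exp(delta * A)                  # discretized decay in (0, 1) because A < 0
        h = A_bar * h + delta * B_t * x[t]         # one state update per step: linear time
        y[t] = C_t @ h
    return y

rng = np.random.default_rng(0)
x = rng.normal(size=16)
A = -np.arange(1, 5, dtype=float)
y = selective_scan_toy(x, A, w_delta=1.0, w_B=np.ones(4), w_C=np.ones(4))
print(y.shape)  # (16,)
```

Because `delta`, `B_t`, and `C_t` are recomputed from `x[t]` at every step, the recurrence can let some inputs overwrite the state and let others pass with little effect, which is the filtering behavior described above.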


## Building and Running the Docker Container

To build and run the Docker container for Mamba SSM, follow these steps:

1. **Clone the Repository**: Clone the Mamba SSM repository to your local machine.

   ```bash
   git clone https://github.com/gabenavarro/MLContainerLab.git
   cd MLContainerLab
   ```

2. **Build the Docker Image**: Use the provided Dockerfile to build the Docker image.

   ```bash
   # You can choose any tag you want for the image
   # Feel free to play around with the base image, just make sure the host has the same or higher CUDA version
   docker build -f ./assets/build/Dockerfile.mamba.cu128py26cp312 -t ssm-mamba:128-26-312 .
   ```
3. **Run the Docker Container**: Run the container with the configuration it needs. The example below runs it locally with GPU support, which is the recommended setup during development. For scaling up, see [Deploying to the Cloud](#deploying-to-the-cloud).

   ```bash
   # Run the container with GPU support
   docker run -dt \
       --gpus all \
       -v "$(pwd):/workspace" \
       --name ssm-mamba \
       --env NVIDIA_VISIBLE_DEVICES=all \
       --env GOOGLE_APPLICATION_CREDENTIALS=/workspace/assets/secrets/gcp-key.json \
       ssm-mamba:128-26-312
   ```
> Note: The `-v "$(pwd):/workspace"` option mounts the current directory to `/workspace` in the container, allowing you to access your local files from within the container. The `--env` options set environment variables for GPU visibility and Google Cloud credentials.
>
> Note: The `--gpus all` option allows the container to use all available GPUs.

4. **Access the Container with IDE**: In this example, we will use Visual Studio Code to access the container. You can use any IDE of your choice.

   ```bash
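   # Attach VS Code to the running container from the command line
   CONTAINER_NAME=ssm-mamba
   FOLDER=/workspace
   # VS Code expects the container config as a hex-encoded JSON string
   HEX_CONFIG=$(printf '{"containerName":"/%s"}' "$CONTAINER_NAME" | od -A n -t x1 | tr -d '[\n\t ]')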
   code --folder-uri "vscode-remote://attached-container+$HEX_CONFIG$FOLDER"
   ```

> Note: The `code` command opens Visual Studio Code. To attach to the container directly, install the Remote - Containers extension (now published as Dev Containers) in VS Code.
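The hex-encoded part of the folder URI in step 4 can be hard to read in shell form. As a minimal sketch, the same string can be built in Python; the helper name `attached_container_uri` is hypothetical, and the URI scheme is copied from the shell snippet above rather than from any documented API.

```python
import json

def attached_container_uri(container_name: str, folder: str = "/workspace") -> str:
    """Build the VS Code folder URI used in the shell snippet for a running container."""
    # Compact JSON, matching the string the shell version passes to `od`
    config = json.dumps({"containerName": f"/{container_name}"}, separators=(",", ":"))
    # Hex-encode the JSON bytes (the `od -A n -t x1 | tr -d ...` step)
    hex_config = config.encode("utf-8").hex()
    return f"vscode-remote://attached-container+{hex_config}{folder}"

uri = attached_container_uri("ssm-mamba")
print(uri)
```

The resulting string can be passed to `code --folder-uri`, exactly as in the shell version.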

## Using Mamba SSM

We will now train a simple Mamba SSM model using the `mamba_ssm` package installed in the Docker image above. The training routine resembles language-model training, but uses a dataset of Bitcoin prices. The model will overfit severely, but it is enough to show how to use the package. The routine is similar to the one in the [Mamba SSM repository](https://github.com/gabenavarro/MLContainerLab/tree/main/assets/examples/mamba).


### Data Preparation

First, let's download the data. We will start with a limited dataset from Kaggle; for more advanced scenarios, consider an API with access to more datasets, such as the [CoinGecko API](https://www.coingecko.com/en/api).

In [None]:
%%capture
# Download and unzip the Bitcoin historical data dataset from Kaggle
!curl -L -o /workspace/datasets/bitcoin-historical-data.zip \
  https://www.kaggle.com/api/v1/datasets/download/mczielinski/bitcoin-historical-data \
    && unzip -o /workspace/datasets/bitcoin-historical-data.zip -d /workspace/datasets/ \
    && rm /workspace/datasets/bitcoin-historical-data.zip

In [1]:
TIME_SERIES_CSV = "/workspace/datasets/btcusd_1-min_data.csv"
PROCESSED_DATA_DIR = "/workspace/datasets/auto_regressive_processed_timeseries"
CKPT_DIR = "/workspace/datasets/checkpoints"

In [2]:
import pandas as pd
pd.set_option('display.max_columns', None)
pd.read_csv(TIME_SERIES_CSV, low_memory=False, nrows=5).head()

| Unnamed: 0 | Timestamp | Open | High | Low | Close | Volume | datetime |
|---|---|---|---|---|---|---|---|
| 0 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:01:00+00:00 |
| 1 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:02:00+00:00 |
| 2 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:03:00+00:00 |
| 3 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:04:00+00:00 |
| 4 | 1325412000.0 | 4.58 | 4.58 | 4.58 | 4.58 | 0.0 | 2012-01-01 10:05:00+00:00 |


Now that we have the raw data, let's create a dataset and dataloader for the model using LitData. LitData is a lightweight data loading library designed to work with PyTorch and other deep learning frameworks. It provides a simple and efficient way to load and preprocess data for training and evaluation.

In [None]:
import pandas as pd
import litdata as ld
import numpy as np
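from typing import Any, Callable, Dict, Optional

def process_timeseries(file_path: str, sequence_length: int = 2048) -> Optional[Callable[[int], Dict[str, Any]]]:
    """Build a per-index sample function for autoregressive modeling from a timeseries CSV.

    Returns None when the file has too few rows for a single sequence."""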
    df = pd.read_csv(file_path, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Log-transform price values to handle exponential growth
    for price_col in ['Open', 'High', 'Low', 'Close']:
        # Replace zeros, infinities, and NaNs with a small positive floor so the log is defined
        df[price_col] = np.log(df[price_col].replace([np.inf, -np.inf, 0, np.nan], 0.01).fillna(0.01))

    # log1p-transform volume to reduce extreme skew
    df['Volume'] = np.log1p(df['Volume'].replace([np.inf, -np.inf], np.nan).fillna(0.0))
    
    stats = {}
    for feature in numerical_features:
        vals = df[feature].replace([np.inf, -np.inf], np.nan).dropna()
        mean, std = (vals.mean(), vals.std() or 1.0)
        stats[feature] = {'mean': mean, 'std': std}
        df[feature] = (df[feature] - mean) / std

    if len(df) <= sequence_length:
        print(f"Warning: only {len(df)} rows < sequence_length={sequence_length}")
        return None

    def create_timeseries_sample(index: int) -> Dict[str, Any]:
        if index < sequence_length or index >= len(df):
            return {"index": index,
                    "inputs": np.zeros((sequence_length,5),dtype=np.float32),
                    "mask": np.zeros(sequence_length, dtype=bool),
                    "stats": stats}
        seq = df.iloc[index-sequence_length:index][numerical_features].values
        arr = np.nan_to_num(seq.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        mask = ~np.isnan(seq).any(axis=1)
        return {"index": index, "inputs": arr, "mask": mask, "stats": stats}

    # Store original stats for inverse transformation during inference
    stats['transform_type'] = 'log_then_zscore'
    
    return create_timeseries_sample


# Set the sequence length for your model
sequence_length = 2048

# Get the processing function configured for your specific file
process_function = process_timeseries(TIME_SERIES_CSV, sequence_length)

# Filter indices to exclude those with NaN values
def get_valid_indices():
    df = pd.read_csv(TIME_SERIES_CSV, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Replace infinity values with NaN
    for feature in numerical_features:
        df[feature] = df[feature].replace([np.inf, -np.inf], np.nan)
    
    valid_indices = []
    for idx in range(sequence_length, len(df), int(sequence_length * 0.25)):
        # Check if the entire sequence has no NaN values
        sequence = df.iloc[idx-sequence_length:idx][numerical_features].values
        if not np.isnan(sequence).any() and not np.isinf(sequence).any():
            valid_indices.append(idx)
        else:
            print(f"Skipping index {idx} due to NaN or inf values in the sequence.")
    
    if not valid_indices:
        raise ValueError("No valid sequences found! All sequences contain NaN or inf values.")
    
    return valid_indices

valid_indices = get_valid_indices()

# The optimize function writes data in an optimized format
ld.optimize(
    fn=process_function,              # the function that processes each sample
    inputs=valid_indices,             # the indices of valid samples
    output_dir=PROCESSED_DATA_DIR,    # optimized data is stored here
    num_workers=4,                    # The number of workers on the same machine
    chunk_bytes="64MB"                # size of each chunk
)
# Takes about 30 seconds to run
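The preprocessing above stores per-feature `mean` and `std` so that predictions can be mapped back to the original scale. A minimal sketch of that inverse step follows, assuming the same order of operations as in `process_timeseries` (log or log1p, then z-score). The helper `invert_feature` is hypothetical, and the statistics in the round trip are computed from made-up values, not from the dataset.

```python
import numpy as np

def invert_feature(z, mean, std, feature):
    """Undo z-scoring, then undo log (prices) or log1p (Volume)."""
    logged = np.asarray(z, dtype=np.float64) * std + mean
    return np.expm1(logged) if feature == "Volume" else np.exp(logged)

# Round trip on made-up prices (placeholder statistics, not dataset values)
prices = np.array([4.58, 100.0, 65000.0])
logged = np.log(prices)
mean, std = logged.mean(), logged.std()
z = (logged - mean) / std
print(np.allclose(invert_feature(z, mean, std, "Close"), prices))  # True
```

At inference time, `mean` and `std` would come from the `stats` dictionary attached to each sample, for example `stats["Close"]["mean"]`.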

In [3]:
from litdata import StreamingDataset, StreamingDataLoader, train_test_split
import torch
streaming_dataset = StreamingDataset(PROCESSED_DATA_DIR)  # optimized data written by ld.optimize (a local directory here)

def custom_collate(batch):
    # Filter out None values
    batch = [item for item in batch if item is not None]
    
    if not batch:
        # Return empty tensors if the batch is empty
        return {
            "index": torch.tensor([], dtype=torch.long),
            "inputs": torch.tensor([], dtype=torch.float32),
            "mask": torch.tensor([], dtype=torch.bool),
            "stats": {}
        }
    
    # Process each key separately
    indices = torch.tensor([item["index"] for item in batch], dtype=torch.long)
    
    # Make sure arrays are writable by copying and convert to tensor
    inputs = torch.stack([torch.tensor(np.nan_to_num(item["inputs"].copy(), nan=0.0), dtype=torch.float32) for item in batch])
    masks = torch.stack([torch.tensor(item["mask"].copy(), dtype=torch.bool) for item in batch])
    
    # Get stats (use first non-empty item's stats)
    stats = next((item["stats"] for item in batch if "stats" in item), {})
    
    return {
        "index": indices,
        "inputs": inputs,
        "mask": masks,
        "stats": stats
    }

print(len(streaming_dataset))  # number of optimized samples
# out: 13690

train_dataset, val_dataset, test_dataset = train_test_split(streaming_dataset, splits=[0.8, 0.1, 0.1])

print("Train ", len(train_dataset))
# out: 10952
train_dataloader = StreamingDataLoader(train_dataset, num_workers=4, batch_size=32, shuffle=True, collate_fn=custom_collate)  # Create DataLoader for training

print("Validation ", len(val_dataset))
# out: 1369
val_dataloader = StreamingDataLoader(val_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for validation

print("Test ", len(test_dataset))
# out: 1369
test_dataloader = StreamingDataLoader(test_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for testing




13690
Train  10952
Validation  1369
Test  1369
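The sample counts above follow from two pieces of arithmetic: the sliding window in `get_valid_indices` (stride of 25% of the sequence length) and the 80/10/10 split. The sketch below reproduces both. The helper `window_end_indices` is hypothetical, the row count of 10,000 is a made-up example, and the split arithmetic is only shown to be consistent with the printed counts; it is not taken from the LitData source.

```python
def window_end_indices(n_rows: int, sequence_length: int = 2048, stride_fraction: float = 0.25) -> range:
    """End indices of overlapping windows, mirroring the loop in get_valid_indices."""
    stride = int(sequence_length * stride_fraction)  # 512 rows for the defaults
    return range(sequence_length, n_rows, stride)

ends = window_end_indices(10_000)  # hypothetical row count
print(len(ends), ends[0], ends[-1])  # 16 2048 9728

total = 13690  # dataset length printed above
train, val = total * 8 // 10, total // 10
print(train, val, total - train - val)  # 10952 1369 1369
```

Each window overlaps the previous one by 75%, so neighboring samples share most of their rows. That overlap is one reason to expect the severe overfitting mentioned earlier.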


### Define the Model and Training Loop

Now that we have the data, we can define the model and training loop. We will use the `Mamba2` block from the `mamba_ssm` package. The model stacks several Mamba2 layers between a linear input projection and a linear output projection. Each layer uses a selective state space model to learn patterns in the sequence, and the whole network is trained to predict the next time step.

In [8]:
import torch
from mamba_ssm import Mamba2
from torch import nn, optim, Tensor
import torch.nn.functional as F
import lightning as pl
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import numpy as np

batch, length, dim = 2, 64, 16

class Mamba2Layer(nn.Module):
    def __init__(
            self, d_model:int, d_state:int, d_conv:int, expand:int, headdim:int,
            layer_idx:int
        ):
        super(Mamba2Layer, self).__init__()
        self.mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=headdim,
            layer_idx=layer_idx,
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.mamba(x)
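When choosing `d_model`, `expand`, and `headdim` for the layer above, it helps to check how they combine. The sketch below assumes that Mamba2 splits an inner width of `expand * d_model` into heads of size `headdim` and requires the division to be exact. That is an assumption about the library's internals, so verify it against your installed `mamba_ssm` version. The helper `mamba2_head_count` is hypothetical.

```python
def mamba2_head_count(d_model: int, expand: int, headdim: int) -> int:
    """Number of heads implied by the hyperparameters (assumed rule, not the library API)."""
    d_inner = expand * d_model  # assumed inner width of the block
    if d_inner % headdim != 0:
        raise ValueError(
            f"expand * d_model = {d_inner} is not divisible by headdim = {headdim}"
        )
    return d_inner // headdim

# Values used for training later in this guide
print(mamba2_head_count(d_model=64, expand=4, headdim=16))  # 16
```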


In [14]:
class Mamba2Model(pl.LightningModule):
    def __init__(
            self, input_dim:int, d_model:int, d_state:int, d_conv:int, 
            expand:int, headdim:int, n_layers:int, 
            lr:float=1e-4, weight_decay:float=0.01, **kwargs
        ):
        super(Mamba2Model, self).__init__()

        self.save_hyperparameters()

        # Linear layer for input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Mamba2 layers
        self.layers = nn.ModuleList([
            Mamba2Layer(d_model, d_state, d_conv, expand, headdim, i) for i in range(n_layers)
        ])

        # Linear layer for output projection
        self.linear = nn.Linear(d_model, input_dim)

        # Provide a small weight for the volume feature
        fw = torch.tensor([1.,1.,1.,1.,0.25])
        self.register_buffer('feature_weights', fw)

        # For tracking metrics
        self.train_loss = 0.0
        self.val_loss = 0.0
        
        # Learning rate
        self.lr = lr
        self.weight_decay = weight_decay
        
        # For storing test predictions
        self.test_predictions = []
        self.test_targets = []
        self.test_masks = []
        self.feature_names = ['Open', 'High', 'Low', 'Close', 'Volume']


    def forward(self, x: Tensor) -> Tensor:
        x = self.input_projection(x)
        for layer in self.layers:
            x = layer(x)
        return self.linear(x)
    
    def _calculate_autoregressive_loss(self, preds, targets, mask):
        # shift for next-step prediction
        p, t = preds[:, :-1], targets[:, 1:]
        m = mask[:, 1:].unsqueeze(-1)
        # mse = F.mse_loss(p, t, reduction='none')
        mse = F.smooth_l1_loss(p, t, reduction='none', beta=0.01)
        # apply per-feature weights
        mse = mse * self.feature_weights.view(1,1,-1)
        # mask and reduce
        mse = mse * m
        denom = m.sum().clamp_min(1.0)
        return mse.sum() / denom

    
    def training_step(self, batch, batch_idx):
        """
        Training step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)        
        
        # Log the loss for monitoring (don't try to compute gradient norm here)
        self.train_loss = loss
        self.log('train_loss', loss, prog_bar=True)  
        
        return loss
    
    def on_after_backward(self):
        """
        Called after .backward() and before optimizers do anything.
        This is the right place to check gradients.
        """
        # Safely compute gradient norm - after backward pass when gradients exist
        if any(p.grad is not None for p in self.parameters()):
            grad_list = [p.grad.detach().norm(2) for p in self.parameters() if p.grad is not None]
            if grad_list:  # Make sure list is not empty
                grad_norm = torch.stack(grad_list).norm(2)
                self.log('grad_norm', grad_norm, prog_bar=True)
            else:
                self.log('grad_norm', torch.tensor(0.0), prog_bar=True)
        else:
            self.log('grad_norm', torch.tensor(0.0), prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """
        Validation step for autoregressive prediction
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
                
        # Log the loss for monitoring
        self.val_loss = loss
        self.log('val_loss', loss, prog_bar=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        """
        Test step for evaluating the model after training
        """
        # Get inputs and mask from batch
        inputs = batch['inputs']
        mask = batch['mask']
        
        # Safety check for NaN in inputs
        if torch.isnan(inputs).any():
            inputs = torch.nan_to_num(inputs, nan=0.0)
        
        # Forward pass to get predictions
        predictions = self(inputs)
        
        # Calculate autoregressive loss
        loss = self._calculate_autoregressive_loss(predictions, inputs, mask)
        
        # Store predictions and targets for later analysis
        # Use detach() and cpu() to avoid memory leaks
        self.test_predictions.append(predictions.detach().cpu())
        self.test_targets.append(inputs.detach().cpu())
        self.test_masks.append(mask.detach().cpu())
        
        # Log the test loss
        self.log('test_loss', loss, prog_bar=True)
        
        return loss
    
    # Configuring the optimizer
    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode='min',
            factor=0.5,
            patience=2
        )
        return {
            "optimizer": opt,
            # Lightning reads the scheduler settings from the "lr_scheduler" key
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }
    
    def on_test_epoch_end(self):
        """
        Calculate and log test metrics at the end of the test epoch
        """
        # Concatenate all batches
        all_preds = torch.cat(self.test_predictions, dim=0)
        all_targets = torch.cat(self.test_targets, dim=0)
        all_masks = torch.cat(self.test_masks, dim=0)
        
        # For autoregressive prediction, shift by one step
        pred_shifted = all_preds[:, :-1, :]
        target_shifted = all_targets[:, 1:, :]
        mask_shifted = all_masks[:, 1:]
        
        # Convert to numpy for sklearn metrics
        pred_np   = pred_shifted.detach().cpu().to(torch.float32).numpy()
        target_np = target_shifted.detach().cpu().to(torch.float32).numpy()
        mask_np   = mask_shifted.detach().cpu().to(torch.float32).numpy()
        
        # Calculate metrics for each feature
        metrics = {}
        for i, feature_name in enumerate(self.feature_names):
            # Extract predictions and targets for this feature
            feature_preds = pred_np[:, :, i]
            feature_targets = target_np[:, :, i]
            
            # Keep only valid (unmasked) positions via boolean indexing
            valid = mask_np.astype(bool)
            valid_preds = feature_preds[valid]
            valid_targets = feature_targets[valid]
            
            if len(valid_preds) > 0:
                # Calculate metrics
                mse = mean_squared_error(valid_targets, valid_preds)
                rmse = np.sqrt(mse)
                mae = mean_absolute_error(valid_targets, valid_preds)
                
                # Calculate MAPE (Mean Absolute Percentage Error) with handling for zeros
                with np.errstate(divide='ignore', invalid='ignore'):
                    mape = np.mean(np.abs((valid_targets - valid_preds) / np.maximum(np.abs(valid_targets), 1e-8))) * 100
                    mape = np.nan_to_num(mape, nan=0.0, posinf=0.0, neginf=0.0)
                
                # R-squared
                r2 = r2_score(valid_targets, valid_preds)
                
                # Store metrics
                metrics[f"{feature_name}_mse"] = mse
                metrics[f"{feature_name}_rmse"] = rmse
                metrics[f"{feature_name}_mae"] = mae
                metrics[f"{feature_name}_mape"] = mape
                metrics[f"{feature_name}_r2"] = r2
                
                # Log each metric
                self.log(f"test_{feature_name}_mse", mse)
                self.log(f"test_{feature_name}_rmse", rmse)
                self.log(f"test_{feature_name}_mae", mae)
                self.log(f"test_{feature_name}_mape", mape)
                self.log(f"test_{feature_name}_r2", r2)
        
        # Calculate average metrics across all features
        avg_mse = np.mean([metrics[f"{feature}_mse"] for feature in self.feature_names if f"{feature}_mse" in metrics])
        avg_rmse = np.mean([metrics[f"{feature}_rmse"] for feature in self.feature_names if f"{feature}_rmse" in metrics])
        avg_mae = np.mean([metrics[f"{feature}_mae"] for feature in self.feature_names if f"{feature}_mae" in metrics])
        avg_mape = np.mean([metrics[f"{feature}_mape"] for feature in self.feature_names if f"{feature}_mape" in metrics])
        avg_r2 = np.mean([metrics[f"{feature}_r2"] for feature in self.feature_names if f"{feature}_r2" in metrics])
        
        # Log average metrics
        self.log("test_avg_mse", avg_mse)
        self.log("test_avg_rmse", avg_rmse)
        self.log("test_avg_mae", avg_mae)
        self.log("test_avg_mape", avg_mape)
        self.log("test_avg_r2", avg_r2)
        
        # Print summary of test metrics
        print("\n===== TEST METRICS =====")
        print(f"Average MSE: {avg_mse:.6f}")
        print(f"Average RMSE: {avg_rmse:.6f}")
        print(f"Average MAE: {avg_mae:.6f}")
        print(f"Average MAPE: {avg_mape:.6f}%")
        print(f"Average R²: {avg_r2:.6f}")
        print("=======================\n")
        
        # Create and save visualizations
        self._create_prediction_visualizations(pred_shifted, target_shifted, mask_shifted)
        
        # Clear stored predictions to free memory
        self.test_predictions = []
        self.test_targets = []
        self.test_masks = []

    def _create_prediction_visualizations(self, predictions, targets, masks, num_samples=3):
        """
        Create and save visualization of predictions vs targets
        
        Args:
            predictions: Model predictions [batch_size, seq_len, feature_dim]
            targets: Target values [batch_size, seq_len, feature_dim]
            masks: Boolean masks [batch_size, seq_len]
            num_samples: Number of samples to visualize
        """
        # Convert to numpy for plotting
        preds_np = predictions.detach().cpu().to(torch.float32).numpy()
        targets_np = targets.detach().cpu().to(torch.float32).numpy()
        masks_np = masks.detach().cpu().to(torch.float32).numpy()
        
        # Create directory for plots if it doesn't exist
        import os
        os.makedirs(os.path.join(CKPT_DIR, "plots"), exist_ok=True)
        
        # Plot for each feature
        for feature_idx, feature_name in enumerate(self.feature_names):
            plt.figure(figsize=(15, 10))
            
            # Plot the first few samples
            for sample_idx in range(min(num_samples, preds_np.shape[0])):
                # Get predictions and targets for this sample and feature
                sample_preds = preds_np[sample_idx, :, feature_idx]
                sample_targets = targets_np[sample_idx, :, feature_idx]
                sample_mask = masks_np[sample_idx, :]
                
                # Create time index for x-axis
                time_idx = np.arange(len(sample_preds))
                
                # Plot targets
                plt.subplot(num_samples, 1, sample_idx + 1)
                plt.plot(time_idx, sample_targets, 'b-', label='Actual', alpha=0.7)
                
                # Plot predictions (only where mask is True)
                masked_preds = np.where(sample_mask, sample_preds, np.nan)
                plt.plot(time_idx, masked_preds, 'r-', label='Predicted', alpha=0.7)
                
                # Add title and legend
                plt.title(f"Sample {sample_idx+1}: {feature_name}")
                plt.legend()
                plt.grid(True, alpha=0.3)
                
                # Calculate metrics for this sample
                valid_indices = np.where(sample_mask)[0]
                if len(valid_indices) > 0:
                    valid_preds = sample_preds[valid_indices]
                    valid_targets = sample_targets[valid_indices]
                    
                    mse = mean_squared_error(valid_targets, valid_preds)
                    rmse = np.sqrt(mse)
                    r2 = r2_score(valid_targets, valid_preds)
                    
                    plt.figtext(0.01, 0.5 - 0.15 * sample_idx, 
                                f"MSE: {mse:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}", 
                                fontsize=9)
            
            plt.tight_layout()
            plt.savefig(os.path.join(CKPT_DIR, "plots", f"{feature_name}_predictions.svg"))
            plt.close()
        
        # Create a summary plot with all features for the first sample
        plt.figure(figsize=(15, 12))
        
        for feature_idx, feature_name in enumerate(self.feature_names):
            plt.subplot(len(self.feature_names), 1, feature_idx + 1)
            
            # Get predictions and targets for first sample and this feature
            sample_preds = preds_np[0, :, feature_idx]
            sample_targets = targets_np[0, :, feature_idx]
            sample_mask = masks_np[0, :]
            
            # Create time index for x-axis
            time_idx = np.arange(len(sample_preds))
            
            # Plot targets
            plt.plot(time_idx, sample_targets, 'b-', label='Actual', alpha=0.7)
            
            # Plot predictions (only where mask is True)
            masked_preds = np.where(sample_mask, sample_preds, np.nan)
            plt.plot(time_idx, masked_preds, 'r-', label='Predicted', alpha=0.7)
            
            # Add title and legend
            plt.title(f"{feature_name}")
            plt.legend()
            plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(CKPT_DIR, "plots", "all_features_predictions.svg"))
        plt.close()
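The loss in `_calculate_autoregressive_loss` combines three ideas: shifting predictions against targets by one step, a smooth L1 penalty, and masking. The sketch below is a NumPy re-implementation for illustration only, not the training code. The function name `next_step_smooth_l1` is hypothetical, and the piecewise formula is the standard smooth L1 definition with threshold `beta`.

```python
import numpy as np

def next_step_smooth_l1(preds, targets, mask, beta=0.01, feature_weights=None):
    """Masked next-step smooth L1 loss for arrays shaped [batch, time, features]."""
    p, t = preds[:, :-1], targets[:, 1:]       # prediction at step i is scored against target i+1
    m = mask[:, 1:, None].astype(np.float64)   # mask aligned with the shifted targets
    diff = np.abs(p - t)
    # Quadratic near zero, linear beyond beta
    loss = np.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
    if feature_weights is not None:
        loss = loss * np.asarray(feature_weights, dtype=np.float64).reshape(1, 1, -1)
    # Sum over features, average over valid time steps
    return float((loss * m).sum() / max(m.sum(), 1.0))

targets = np.random.default_rng(0).normal(size=(2, 6, 5))
mask = np.ones((2, 6), dtype=bool)
perfect = np.zeros_like(targets)
perfect[:, :-1] = targets[:, 1:]               # a model that predicts the next step exactly
print(next_step_smooth_l1(perfect, targets, mask))  # 0.0
```

Because the denominator counts valid time steps rather than individual elements, the value is a per-step sum over the five features.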

In [16]:
parameters = dict(
    input_dim=5,
    d_model=64, 
    d_state=64, 
    d_conv=4, 
    expand=4, 
    headdim=16, 
    n_layers=4,
    lr=1e-4,
    weight_decay=0.01,
    epochs=100,
    data_dir = PROCESSED_DATA_DIR,  # Directory for processed data
    ckpt_dir = CKPT_DIR,            # Checkpoint directory
    num_workers = 4,                # Number of workers for data loading
    batch_size = 32,                # Batch size
    sequence_length = 2048,         # Sequence length
    limit_val_batches = 100,        # Cap validation at 100 batches per epoch (an int is a batch count, not a fraction)
    accumulate_grad_batches = 8,    # Gradient accumulation
    gradient_clip_val = 0.5,        # Gradient clipping for stability
    precision = "bf16-mixed",       # Mixed precision training
    log_every_n_steps = 10,         # Log every n steps
    accelerator = "gpu",            # Use GPU for training
    nodes = 1,                      # Number of nodes
    devices = 1,                    # Number of devices (GPUs)
    strategy = "auto",              # Let Lightning pick the strategy for the available devices

)
model = Mamba2Model(**parameters)



# Lightning callbacks for early stopping and model checkpointing
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min'
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=parameters.get('ckpt_dir', CKPT_DIR),
    filename='mamba-best-model-{epoch:02d}-{val_loss:.4f}',
    save_top_k=1,
    mode='min'
)

trainer = pl.Trainer(
    limit_val_batches=parameters['limit_val_batches'],
    max_epochs=parameters['epochs'], 
    accumulate_grad_batches=parameters['accumulate_grad_batches'], 
    gradient_clip_val=parameters['gradient_clip_val'],  
    default_root_dir=parameters['ckpt_dir'],
    precision=parameters['precision'],
    log_every_n_steps=parameters['log_every_n_steps'],
    accelerator=parameters['accelerator'],
    devices=parameters['devices'],
    strategy=parameters['strategy'],
    num_nodes=parameters['nodes'],
    callbacks=[early_stop_callback, checkpoint_callback]
) 

# Train the model
trainer.fit(model, train_dataloader, val_dataloader)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type       | Params | Mode 
--------------------------------------------------------
0 | input_projection | Linear     | 384    | train
1 | layers           | ModuleList | 242 K  | train
2 | linear           | Linear     | 325    | train
--------------------------------------------------------
243 K     Trainable params
0         Non-trainable params
243 K     Total params
0.972     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 25. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]
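
With gradient accumulation, the optimizer only steps once every `accumulate_grad_batches` batches, so the effective batch size is larger than `batch_size`. The arithmetic for the settings above can be sketched as follows. The helper name `effective_batch_size` is illustrative and not part of Lightning.

```python
def effective_batch_size(batch_size: int, accumulate_grad_batches: int,
                         devices: int = 1, nodes: int = 1) -> int:
    """Number of samples that contribute to one optimizer step."""
    return batch_size * accumulate_grad_batches * devices * nodes


# Values taken from the `parameters` dict above
samples_per_step = effective_batch_size(batch_size=32, accumulate_grad_batches=8,
                                        devices=1, nodes=1)
timesteps_per_step = samples_per_step * 2048  # sequence_length = 2048

print(samples_per_step)    # 256
print(timesteps_per_step)  # 524288
```

Each optimizer step therefore sees 256 sequences. If you change `batch_size` or `accumulate_grad_batches`, the learning rate may need to be revisited as well.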

In [19]:
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
import os

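def plot_tensorboard_scalars(ckpt_dir: str, tag=None, tags=None):
    """
    Plot scalar curves from the latest TensorBoard run under `ckpt_dir` and
    save the figure to `<ckpt_dir>/plots/mamba_tensorBoard_scalars.svg`.

    Args:
        ckpt_dir (str): Directory that contains `lightning_logs/version_*`.
        tag (str): Single tag to plot; ignored if `tags` is given.
        tags (list): Tags to plot; tags missing from the logs are skipped.
    """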

    import glob
    
    # Find the most recent version
    version_dirs = glob.glob(f"{ckpt_dir}/lightning_logs/version_*")
    if not version_dirs:
        print("No training logs found.")
        return
    

    latest_version = max(version_dirs, key=lambda x: int(x.split('_')[-1]))

    # Find any event files in the directory
    files = glob.glob(os.path.join(latest_version, 'events.out.tfevents.*'))
    if not files:
        raise FileNotFoundError(f"No TensorBoard event files found in {latest_version}")

    # Initialize the EventAccumulator to read scalars
    ea = event_accumulator.EventAccumulator(
        latest_version,
        size_guidance={  # Load all scalar data
            event_accumulator.SCALARS: 0,
        }
    )
    ea.Reload()  # Load the event data

    # Determine which tags to plot
    available_tags = ea.Tags().get('scalars', [])
    if tags:
        plot_tags = [t for t in tags if t in available_tags]
    elif tag:
        if tag not in available_tags:
            raise ValueError(f"Tag '{tag}' not found. Available tags: {available_tags}")
        plot_tags = [tag]
    else:
        # If no tag(s) specified, plot all scalar tags
        plot_tags = available_tags

    if not plot_tags:
        raise ValueError("No valid scalar tags to plot.")

    # Plot each tag's values over steps
    plt.figure()
    for t in plot_tags:
        events = ea.Scalars(t)
        steps = [e.step for e in events]
        values = [e.value for e in events]
        plt.plot(steps, values, label=t)

    plt.xlabel('Step')
    plt.ylabel('Value')
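    plt.title("TensorBoard Scalars")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Make sure the output folder exists before saving
    os.makedirs(os.path.join(ckpt_dir, "plots"), exist_ok=True)
    plt.savefig(os.path.join(ckpt_dir, "plots", "mamba_tensorBoard_scalars.svg"))
    plt.close()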


# Plot the training and validation loss curves
plot_tensorboard_scalars(ckpt_dir=parameters['ckpt_dir'], tags=['train_loss', 'val_loss'])

Below is the scalar plot of the training and validation loss over the training steps. The training loss is shown in blue and the validation loss in orange. The training loss decreases over time, indicating that the model is learning. The validation loss also decreases, though more slowly; because it does not turn upward while the training loss keeps falling, there is no clear sign of overfitting.

![Mamba training and validation loss curves](../assets/images/mamba_tensorBoard_scalars.svg)
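
Both the plotting cell above and the report cell later in this section pick the newest `lightning_logs/version_*` folder by comparing the numeric suffix, which matters because a plain string sort would rank `version_3` above `version_10`. A minimal sketch of that lookup as a reusable helper is shown here. The function name `latest_version_dir` is hypothetical and does not appear in the notebook.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional


def latest_version_dir(ckpt_dir: str) -> Optional[Path]:
    """Return the highest-numbered lightning_logs/version_N folder, or None."""
    log_root = Path(ckpt_dir) / "lightning_logs"
    if not log_root.is_dir():
        return None
    candidates = []
    for path in log_root.glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", path.name)
        if match and path.is_dir():
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare on the integer suffix, not on the folder name as a string
    return max(candidates, key=lambda item: item[0])[1]


# Demonstration on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for n in (2, 10, 3):
        (Path(tmp) / "lightning_logs" / f"version_{n}").mkdir(parents=True)
    newest_name = latest_version_dir(tmp).name

print(newest_name)  # version_10
```

Returning `None` instead of printing lets the caller decide whether a missing log folder is an error.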


## Model Evaluation 

### Test Dataset
Now that the model is trained, we can evaluate it on the test dataset. We load the best checkpoint saved by `ModelCheckpoint` and call Lightning's `trainer.test`, which runs the model's test step over `test_dataloader` and reports the logged metrics. Below is the code to evaluate the model on the test dataset.

In [24]:
# Load the best checkpoint
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
    print(f"Loading best model from {best_model_path}")
    model = Mamba2Model.load_from_checkpoint(best_model_path)

# Test the model and calculate metrics
print("Evaluating model on test set...")
trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loading best model from /workspace/datasets/checkpoints/mamba-best-model-epoch=21-val_loss=0.1543.ckpt
Evaluating model on test set...


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Testing: |          | 0/? [00:00<?, ?it/s]


===== TEST METRICS =====
Average MSE: 0.252570
Average RMSE: 0.234710
Average MAE: 0.161060
Average MAPE: 67.145914%
Average R²: 0.832838



[{'test_loss': 0.21124550700187683,
  'test_Open_mse': 0.00015961345343384892,
  'test_Open_rmse': 0.012633821927011013,
  'test_Open_mae': 0.009417413733899593,
  'test_Open_mape': 1.3162744045257568,
  'test_Open_r2': 0.9975064992904663,
  'test_High_mse': 0.000122522353194654,
  'test_High_rmse': 0.011068981140851974,
  'test_High_mae': 0.0075338599272072315,
  'test_High_mape': 1.0993212461471558,
  'test_High_r2': 0.9980871081352234,
  'test_Low_mse': 0.00015940971206873655,
  'test_Low_rmse': 0.012625755742192268,
  'test_Low_mae': 0.008318182080984116,
  'test_Low_mape': 1.2755264043807983,
  'test_Low_r2': 0.9975081086158752,
  'test_Close_mse': 0.0001886643876787275,
  'test_Close_rmse': 0.013735515996813774,
  'test_Close_mae': 0.009779187850654125,
  'test_Close_mape': 1.4570692777633667,
  'test_Close_r2': 0.9970526695251465,
  'test_Volume_mse': 1.2622184753417969,
  'test_Volume_rmse': 1.1234849691390991,
  'test_Volume_mae': 0.7702534794807434,
  'test_Volume_mape': 330.
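
The "Average" lines in the summary are consistent with unweighted means over the five features, which is why the average RMSE (0.2347) is not the square root of the average MSE (0.2526). The sketch below recomputes the averages from the per-feature values printed above. It assumes each per-feature RMSE is the square root of that feature's MSE and that the average R² is a plain mean, which then lets us back out the Volume R² missing from the truncated output.

```python
import math

# Per-feature test MSE and R² values copied from the trainer.test output above
mse = {
    "Open": 0.00015961345343384892,
    "High": 0.000122522353194654,
    "Low": 0.00015940971206873655,
    "Close": 0.0001886643876787275,
    "Volume": 1.2622184753417969,
}
price_r2 = [0.9975064992904663, 0.9980871081352234,
            0.9975081086158752, 0.9970526695251465]
reported_avg_r2 = 0.832838

avg_mse = sum(mse.values()) / len(mse)                          # mean of per-feature MSEs
avg_rmse = sum(math.sqrt(v) for v in mse.values()) / len(mse)   # mean of per-feature RMSEs
volume_share = mse["Volume"] / sum(mse.values())                # Volume's share of total MSE
# Assumption: the reported average R² is an unweighted mean over five features
implied_volume_r2 = len(mse) * reported_avg_r2 - sum(price_r2)

print(f"avg MSE  = {avg_mse:.6f}")
print(f"avg RMSE = {avg_rmse:.6f}")
print(f"Volume share of total MSE = {volume_share:.2%}")
print(f"implied Volume R² = {implied_volume_r2:.3f}")
```

Under these assumptions the recomputed averages match the reported 0.252570 and 0.234710 to within rounding, and Volume accounts for almost all of the total error, so the averages say very little about the four price features.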

In [26]:
import os
import glob

import pandas as pd


# Create a performance report with all the metrics
def create_performance_report(ckpt_dir: str = CKPT_DIR, model_name: str = "mamba"):
    # Get metrics from the lightning logs    
    # Find the most recent version
    version_dirs = glob.glob(f"{ckpt_dir}/lightning_logs/version_*")
    if not version_dirs:
        print("No training logs found.")
        return
    
    latest_version = max(version_dirs, key=lambda x: int(x.split('_')[-1]))
    metrics_path = f"{latest_version}/metrics.csv"
    
    if not os.path.exists(metrics_path):
        print(f"No metrics file found at {metrics_path}")
        return
    
    # Load metrics
    metrics_df = pd.read_csv(metrics_path)
    
    # Filter for test metrics only
    test_metrics = metrics_df[metrics_df['step'].isna()]
    
    # Create a report
    with open(f"{ckpt_dir}/test_performance_report.md", 'w') as f:
        f.write("# Model Performance Report\n\n")
        f.write("## Overall Metrics\n\n")
        
        # Add overall metrics
        if 'test_avg_mse' in test_metrics.columns:
            f.write(f"* **Average MSE:** {test_metrics['test_avg_mse'].iloc[-1]:.6f}\n")
        if 'test_avg_rmse' in test_metrics.columns:
            f.write(f"* **Average RMSE:** {test_metrics['test_avg_rmse'].iloc[-1]:.6f}\n")
        if 'test_avg_mae' in test_metrics.columns:
            f.write(f"* **Average MAE:** {test_metrics['test_avg_mae'].iloc[-1]:.6f}\n")
        if 'test_avg_mape' in test_metrics.columns:
            f.write(f"* **Average MAPE:** {test_metrics['test_avg_mape'].iloc[-1]:.6f}%\n")
        if 'test_avg_r2' in test_metrics.columns:
            f.write(f"* **Average R²:** {test_metrics['test_avg_r2'].iloc[-1]:.6f}\n")
        
        f.write("\n## Feature-Specific Metrics\n\n")
        
        # Add feature-specific metrics
        for feature in ['Open', 'High', 'Low', 'Close', 'Volume']:
            f.write(f"### {feature}\n\n")
            
            if f'test_{feature}_mse' in test_metrics.columns:
                f.write(f"* **MSE:** {test_metrics[f'test_{feature}_mse'].iloc[-1]:.6f}\n")
            if f'test_{feature}_rmse' in test_metrics.columns:
                f.write(f"* **RMSE:** {test_metrics[f'test_{feature}_rmse'].iloc[-1]:.6f}\n")
            if f'test_{feature}_mae' in test_metrics.columns:
                f.write(f"* **MAE:** {test_metrics[f'test_{feature}_mae'].iloc[-1]:.6f}\n")
            if f'test_{feature}_mape' in test_metrics.columns:
                f.write(f"* **MAPE:** {test_metrics[f'test_{feature}_mape'].iloc[-1]:.6f}%\n")
            if f'test_{feature}_r2' in test_metrics.columns:
                f.write(f"* **R²:** {test_metrics[f'test_{feature}_r2'].iloc[-1]:.6f}\n")
            
            f.write("\n")
        
        # Add prediction plots
        f.write("## Prediction Visualizations\n\n")
        for feature in ['Open', 'High', 'Low', 'Close', 'Volume']:
            f.write(f"### {feature} Predictions\n\n")
            f.write(f"![{feature} Predictions](plots/{feature}_predictions.svg)\n\n")
        
        f.write("### All Features (Sample 1)\n\n")
        f.write(f"![All Features](plots/all_features_predictions.svg)\n\n")

# Create the performance report
create_performance_report(ckpt_dir=parameters['ckpt_dir'], model_name="mamba")
print(f"Test performance report saved to {parameters['ckpt_dir']}/test_performance_report.md")

No metrics file found at /workspace/datasets/checkpoints/lightning_logs/version_3/metrics.csv
Test performance report saved to /workspace/datasets/checkpoints/test_performance_report.md
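
The output above reports both that no metrics file was found and that the report was saved, because the final `print` runs unconditionally even when `create_performance_report` returns early. A likely reason for the missing `metrics.csv` is that the run logged to TensorBoard only (the plotting cell reads `events.out.tfevents.*` files), while `metrics.csv` is written by Lightning's `CSVLogger`. A minimal sketch of attaching both loggers follows; the `CKPT_DIR` value is a placeholder, and passing the same `version` to both loggers is an assumption intended to keep their outputs in one folder, so verify it against your installed Lightning version.

```python
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

CKPT_DIR = "checkpoints"  # placeholder; use the notebook's CKPT_DIR

# TensorBoard events for the plotting cell, CSV metrics for the report cell
tb_logger = TensorBoardLogger(save_dir=CKPT_DIR, name="lightning_logs")
csv_logger = CSVLogger(save_dir=CKPT_DIR, name="lightning_logs",
                       version=tb_logger.version)

# Pass both loggers next to the Trainer arguments used earlier, for example:
# trainer = pl.Trainer(..., logger=[tb_logger, csv_logger])
```

After training and testing with this configuration, the report cell should find `metrics.csv` in the latest `version_*` folder.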


![Mamba Predictions](../assets/images/mamba_all_features_predictions.svg)

Let's look at how the model performs across features. Rounded from the test output above, the R² values are:

- Close price prediction: R² ≈ 0.997
- Open price prediction: R² ≈ 0.998
- High price prediction: R² ≈ 0.998
- Low price prediction: R² ≈ 0.998
- Volume prediction: R² of only 0.17

Volume is clearly the most challenging feature to predict, with an R² of only 0.17 and a MAPE of roughly 330%. A MAPE this high means the percentage errors are very large, which typically happens when:

- The data is normalized, so many target values are close to zero and small denominators inflate the percentage error
- The series is highly volatile, as is common for cryptocurrency data
- The model struggles with abrupt changes
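
The first point can be illustrated with a small worked example: the same absolute error produces a modest MAPE on values near 1 and a very large MAPE on values near 0. This is a sketch with made-up numbers and a textbook MAPE definition; the notebook's own metric implementation is not shown here and may differ in details such as the epsilon guard.

```python
def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true, y_pred, eps=1e-8):
    """Mean absolute percentage error in percent; eps avoids division by zero."""
    total = sum(abs(t - p) / max(abs(t), eps) for t, p in zip(y_true, y_pred))
    return 100.0 * total / len(y_true)


error = 0.05                             # identical absolute error everywhere
price_like = [1.0, 1.1, 0.9, 1.2]        # targets far from zero
near_zero = [0.01, -0.02, 0.03, 0.015]   # normalized targets close to zero

price_pred = [t + error for t in price_like]
near_zero_pred = [t + error for t in near_zero]

price_mape = mape(price_like, price_pred)
near_zero_mape = mape(near_zero, near_zero_pred)

print(f"MAE:  {mae(price_like, price_pred):.3f} vs {mae(near_zero, near_zero_pred):.3f}")
print(f"MAPE: {price_mape:.1f}% vs {near_zero_mape:.1f}%")  # roughly 4.8% vs 312.5%
```

Because MAE is identical in both cases, a high MAPE on a normalized feature is better read together with MAE or RMSE than on its own.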

The model can be improved in several ways:
- Data transformation: better normalization or scaling.
- Feature engineering: adding features such as technical indicators, or more complex ones such as wavelet transforms.
- Model architecture: comparing against or combining with other architectures such as CNNs, LSTMs, or attention-based models.
- Hyperparameter tuning: tuning the model and training hyperparameters to improve performance.
- Regularization: adding regularization to the model to prevent overfitting.
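
As an illustration of the data transformation point, one common option for a heavy-tailed, non-negative feature such as Volume is to apply `log1p` before standardizing, so that a few large spikes do not dominate the scale. This is a sketch under the assumption that raw volume is non-negative; the sample values are made up, and the function names are illustrative rather than part of the notebook's preprocessing.

```python
import math
import statistics


def fit_log_standardizer(values):
    """Return mean and std of log1p(values), computed on the training split."""
    logged = [math.log1p(v) for v in values]
    mean = statistics.fmean(logged)
    std = statistics.pstdev(logged) or 1.0  # fall back to 1.0 for constant input
    return mean, std


def transform(values, mean, std):
    """log1p followed by z-score scaling."""
    return [(math.log1p(v) - mean) / std for v in values]


def inverse_transform(scaled, mean, std):
    """Undo the scaling and the log to recover raw values."""
    return [math.expm1(s * std + mean) for s in scaled]


raw_volume = [120.0, 95.0, 30_000.0, 210.0, 180.0, 15.0]  # made-up values with one spike
mean, std = fit_log_standardizer(raw_volume)
scaled = transform(raw_volume, mean, std)
restored = inverse_transform(scaled, mean, std)

print([round(s, 2) for s in scaled])
```

Fit the statistics on the training split only and reuse them for validation and test data, then apply `inverse_transform` to predictions before computing metrics in the original units.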