# Mamba State Space Model with Docker for Local Development and Cloud Deployment

This documentation is a guide to the Mamba State Space Model (SSM), implemented in Python and packaged with Docker for both local development and cloud deployment. It covers the following topics:

1. **Introduction to Mamba SSM**: Overview of the Mamba State Space Model and its applications.
2. **Setting Up the Development Environment**: Step-by-step instructions for setting up a local development environment using Docker.
3. **Building and Running the Docker Container**: Instructions for building the Docker image and running the container.
4. **Deploying to the Cloud**: Guidelines for deploying the Mamba SSM to a cloud platform using Docker.
5. **Best Practices**: Tips and best practices for working with Mamba SSM and Docker.


## Prerequisites

Before you begin, ensure you have the following installed on your local machine:

- Docker: [Install Docker](https://docs.docker.com/get-docker/)
- An NVIDIA GPU (the Mamba SSM kernels require CUDA)
- NVIDIA drivers, plus the NVIDIA Container Toolkit so that `docker run --gpus` can expose the GPU to containers


## Sections
- [Introduction to Mamba SSM](#introduction-to-mamba-ssm)
- [Building and Running the Docker Container](#building-and-running-the-docker-container)
- [Using Mamba SSM](#using-mamba-ssm)
- [Deploying to the Cloud](#deploying-to-the-cloud)


## Introduction to Mamba SSM

At the heart of modern AI systems like ChatGPT and AlphaFold lies a critical challenge: understanding sequences. Whether you're parsing a sentence, listening to speech, decoding a genome, or analyzing financial time series, the task boils down to learning patterns across time or space. For the past few years, Transformers, powered by the self-attention mechanism, have dominated this space thanks to their ability to model relationships between every pair of inputs. But this power comes at a cost: the compute and memory of self-attention grow quadratically with sequence length, so Transformers become expensive or infeasible on long sequences.

Enter Mamba, a new kind of sequence model inspired by state space systems—an elegant mathematical framework traditionally used in physics and control theory. Think of it like replacing a massive all-to-all attention grid with a sleek, memory-efficient signal processor that knows when to pay attention and when to ignore noise.

What makes Mamba different isn’t just that it’s faster (the Mamba paper reports its scan running up to 3× faster than earlier SSM implementations on A100 GPUs), or that it scales to very long contexts (the paper reports quality that keeps improving on sequences up to a million elements long). It’s that Mamba is selective. Unlike older state space models that process information uniformly over time, Mamba can decide what matters at each step. It brings a kind of intelligent filtering, like memory with a spotlight: it selectively stores important details and discards the rest.

Here’s the kicker: despite being fully recurrent and running in time linear in the sequence length, Mamba is reported to match or exceed the accuracy of comparably sized Transformers in domains ranging from text and audio to genomics. It doesn’t need attention layers or even separate MLP blocks. Its design is minimal, clean, and hardware-aware, which makes it not only smart, but fast.

By blending the long-term memory of RNNs, the locality of CNNs, and the expressive power of Transformers—all within a scalable, streamlined architecture—Mamba represents a profound shift in how we think about modeling sequences at scale.
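
To make the idea of selectivity concrete, here is a toy, single-channel sketch of the recurrence behind a selective SSM. It is an illustration written for this guide, not the `mamba_ssm` kernel: the function name `selective_scan` and the plain-list representation are assumptions, the per-step values `deltas`, `bs`, and `cs` are passed in by hand (in Mamba they are computed from the input), and the discretization is the simplified form `h_t = exp(delta_t * a) * h_{t-1} + delta_t * b_t * x_t`.

```python
import math


def selective_scan(xs, deltas, a, bs, cs):
    """Toy single-channel selective SSM recurrence (illustrative only).

    xs:     input value per step
    deltas: non-negative step size per step (input-dependent in Mamba)
    a:      diagonal state matrix as a list of N negative floats
    bs, cs: per-step input and output projections, each a list of N floats
    """
    n = len(a)
    h = [0.0] * n  # hidden state carried along the sequence
    ys = []
    for x, delta, b, c in zip(xs, deltas, bs, cs):
        # Decay the old state and mix in the new input; both depend on delta.
        h = [math.exp(delta * a[i]) * h[i] + delta * b[i] * x for i in range(n)]
        ys.append(sum(c[i] * h[i] for i in range(n)))
    return ys


# Step 1 uses delta=1.0 and writes the input into the state.
# Step 2 uses delta=0.0, so the state is kept and the input 5.0 is ignored.
ys = selective_scan(
    xs=[1.0, 5.0],
    deltas=[1.0, 0.0],
    a=[-1.0],
    bs=[[1.0], [1.0]],
    cs=[[1.0], [1.0]],
)
print(ys)  # [1.0, 1.0]
```

The loop touches each step once, which is where the linear-time behaviour comes from. A large `delta` overwrites more of the state with the current input, while a `delta` near zero preserves the state, which is the "spotlight" behaviour described above.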


## Building and Running the Docker Container

To build and run the Docker container for Mamba SSM, follow these steps:

1. **Clone the Repository**: Clone the Mamba SSM repository to your local machine.

   ```bash
   git clone https://github.com/gabenavarro/MLContainerLab.git
   cd MLContainerLab
   ```

2. **Build the Docker Image**: Use the provided Dockerfile to build the Docker image.

   ```bash
   # You can choose any tag you want for the image.
   # You can also swap the base image, as long as the host driver supports the image's CUDA version (same or newer).
   docker build -f ./assets/build/Dockerfile.mamba.cu128py26cp312 -t ssm-mamba:128-26-312 .
   ```
3. **Run the Docker Container**: Start the container with the configuration your workflow needs. The example below runs the container locally with GPU support, which is the recommended setup during development. For scaling up, see [Deploying to the Cloud](#deploying-to-the-cloud).

   ```bash
   # Run the container with GPU support
   docker run -dt \
       --gpus all \
       -v "$(pwd):/workspace" \
       --name ssm-mamba \
       --env NVIDIA_VISIBLE_DEVICES=all \
       --env GOOGLE_APPLICATION_CREDENTIALS=/workspace/assets/secrets/gcp-key.json \
       ssm-mamba:128-26-312
   ```
> Note: The `-v "$(pwd):/workspace"` option mounts the current directory at `/workspace` in the container, so your local files are available inside it. The `--env` options set environment variables for GPU visibility and for the path to the Google Cloud credentials file.
>
> Note: The `--gpus all` option lets the container use all available GPUs.

4. **Access the Container with an IDE**: This example uses Visual Studio Code to attach to the running container, but any IDE with container support will work.

   ```bash
   # Build the attach URI in a scriptable way
   CONTAINER_NAME=ssm-mamba
   FOLDER=/workspace
   # Hex-encode the JSON config that identifies the container to attach to
   HEX_CONFIG=$(printf '{"containerName":"/%s"}' "$CONTAINER_NAME" | od -A n -t x1 | tr -d ' \n\t')
   code --folder-uri "vscode-remote://attached-container+$HEX_CONFIG$FOLDER"
   ```

> Note: The `code` command opens Visual Studio Code. Attaching to a running container requires the Dev Containers extension (formerly named Remote - Containers) to be installed in VS Code.
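
The hex-encoding step in the shell snippet above can also be sketched in Python, which may be easier to reuse from tooling scripts. This is a minimal sketch: the helper name `attached_container_uri` is hypothetical, and the URI layout is simply copied from the shell command rather than taken from a documented VS Code API.

```python
import json


def attached_container_uri(container_name: str, folder: str) -> str:
    """Build the same vscode-remote URI as the shell snippet (hypothetical helper)."""
    # Compact separators reproduce the exact JSON string emitted by printf.
    config = json.dumps({"containerName": f"/{container_name}"}, separators=(",", ":"))
    # VS Code expects the JSON config as a hex string.
    hex_config = config.encode("utf-8").hex()
    return f"vscode-remote://attached-container+{hex_config}{folder}"


uri = attached_container_uri("ssm-mamba", "/workspace")
print(uri)
```

Pass the printed value to `code --folder-uri` to open the folder inside the container.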

## Using Mamba SSM

We will now train a simple Mamba SSM model using the `mamba_ssm` package installed in the Docker image above. The training routine resembles language-model training, applied to a dataset of Bitcoin prices. The model will overfit severely, but the example is enough to show how to use the package. The routine follows the one in the [Mamba SSM examples](https://github.com/gabenavarro/MLContainerLab/tree/main/assets/examples/mamba).
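
For readers unfamiliar with language-model-style training on numeric data, the sketch below shows one common way to set up the objective: predict each row from the rows before it and score only the positions marked valid. This setup is an assumption, since the loss is not spelled out in this section, and the helper names `next_step_pairs` and `masked_mse` are hypothetical. Plain Python lists are used so the example runs without a GPU.

```python
def next_step_pairs(sequence):
    """Split one window into model inputs and the next-step targets."""
    return sequence[:-1], sequence[1:]


def masked_mse(predictions, targets, mask):
    """Mean squared error over the positions where mask is True."""
    total, count = 0.0, 0
    for pred, target, keep in zip(predictions, targets, mask):
        if keep:
            total += sum((p - t) ** 2 for p, t in zip(pred, target))
            count += len(target)
    return total / count if count else 0.0


# A window of three steps with one feature each.
window = [[1.0], [2.0], [3.0]]
inputs, targets = next_step_pairs(window)

# Naive baseline: predict that the next value equals the current one.
loss = masked_mse(inputs, targets, mask=[True, True])
print(loss)  # 1.0
```

In an actual training loop the baseline prediction would be replaced by the model output, and the mask would come from the `mask` field produced during data preparation.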


### Data Preparation

First, let's download the data. We will start with a limited dataset from Kaggle. For more advanced scenarios, it's highly recommended to use an API that offers more datasets, such as the [CoinGecko API](https://www.coingecko.com/en/api).

In [None]:
%%capture
# Download and unzip the Bitcoin historical data dataset from Kaggle
!curl -L -o /workspace/datasets/bitcoin-historical-data.zip \
  https://www.kaggle.com/api/v1/datasets/download/mczielinski/bitcoin-historical-data \
    && unzip -o /workspace/datasets/bitcoin-historical-data.zip -d /workspace/datasets/ \
    && rm /workspace/datasets/bitcoin-historical-data.zip

In [None]:
TIME_SERIES_CSV = "/workspace/datasets/btcusd_1-min_data.csv"
PROCESSED_DATA_DIR = "/workspace/datasets/auto_regressive_processed_timeseries"
CKPT_DIR = "/workspace/datasets/checkpoints"

In [None]:
# Make directories if they do not exist
import os
os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)
pd.read_csv(TIME_SERIES_CSV, low_memory=False, nrows=5).head()

Now that we have the raw data, let's create a dataset and dataloader for the model using LitData. LitData is a lightweight data loading library designed to work with PyTorch and other deep learning frameworks. It provides a simple and efficient way to preprocess data once, store it in an optimized format, and stream it for training and evaluation.

In [None]:
import pandas as pd
import litdata as ld
import numpy as np
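from typing import Any, Callable, Dict, Optional

def process_timeseries(file_path: str, sequence_length: int = 2048) -> Optional[Callable[[int], Dict[str, Any]]]:
    """Load a timeseries CSV and return a function that builds one normalized window per end index.

    Returns None when the file has too few rows for a single window.
    """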
    df = pd.read_csv(file_path, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Log-transform price values to handle exponential growth
    for price_col in ['Open', 'High', 'Low', 'Close']:
        # Replace missing, infinite, and non-positive values with 0.01 so the log is defined
        prices = df[price_col].replace([np.inf, -np.inf], np.nan).fillna(0.01)
        df[price_col] = np.log(prices.clip(lower=0.01))
    
    # log1p-transform volume to reduce extreme skew
    df['Volume'] = np.log1p(df['Volume'].replace([np.inf, -np.inf], np.nan).fillna(0.0))
    
    stats = {}
    for feature in numerical_features:
        vals = df[feature].replace([np.inf, -np.inf], np.nan).dropna()
        mean, std = (vals.mean(), vals.std() or 1.0)
        stats[feature] = {'mean': mean, 'std': std}
        df[feature] = (df[feature] - mean) / std

    if len(df) <= sequence_length:
        print(f"Warning: only {len(df)} rows < sequence_length={sequence_length}")
        return None

    def create_timeseries_sample(index: int) -> Dict[str, Any]:
        if index < sequence_length or index >= len(df):
            return {"index": index,
                    "inputs": np.zeros((sequence_length,5),dtype=np.float32),
                    "mask": np.zeros(sequence_length, dtype=bool),
                    "stats": stats}
        seq = df.iloc[index-sequence_length:index][numerical_features].values
        arr = np.nan_to_num(seq.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        mask = ~np.isnan(seq).any(axis=1)
        return {"index": index, "inputs": arr, "mask": mask, "stats": stats}

    # Store original stats for inverse transformation during inference
    stats['transform_type'] = 'log_then_zscore'
    
    return create_timeseries_sample


# Set the sequence length for your model
sequence_length = 2048

# Get the processing function configured for your specific file
process_function = process_timeseries(TIME_SERIES_CSV, sequence_length)

# Filter indices to exclude those with NaN values
def get_valid_indices():
    df = pd.read_csv(TIME_SERIES_CSV, low_memory=False)
    if "datetime" in df.columns:
        df = df.drop(columns=["datetime"])
    
    df = df.sort_values('Timestamp')
    numerical_features = ['Open', 'High', 'Low', 'Close', 'Volume']
    
    # Replace infinity values with NaN
    for feature in numerical_features:
        df[feature] = df[feature].replace([np.inf, -np.inf], np.nan)
    
    valid_indices = []
    for idx in range(sequence_length, len(df), int(sequence_length * 0.25)):
        # Check if the entire sequence has no NaN values
        sequence = df.iloc[idx-sequence_length:idx][numerical_features].values
        if not np.isnan(sequence).any() and not np.isinf(sequence).any():
            valid_indices.append(idx)
        else:
            print(f"Skipping index {idx} due to NaN or inf values in the sequence.")
    
    if not valid_indices:
        raise ValueError("No valid sequences found! All sequences contain NaN or inf values.")
    
    return valid_indices

valid_indices = get_valid_indices()

# The optimize function writes data in an optimized format
ld.optimize(
    fn=process_function,              # the function that processes each sample
    inputs=valid_indices,             # the indices of valid samples
    output_dir=PROCESSED_DATA_DIR,    # optimized data is stored here
    num_workers=4,                    # The number of workers on the same machine
    chunk_bytes="64MB"                # size of each chunk
)
# Takes about 30 seconds to run
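
The window selection in `get_valid_indices` can be seen in isolation in the sketch below. It reproduces only the index arithmetic (window end indices spaced 25% of the window length apart) and leaves out the NaN filtering. The helper name `window_end_indices` is hypothetical.

```python
def window_end_indices(n_rows: int, sequence_length: int, stride_fraction: float = 0.25) -> list[int]:
    """Return the end index of every candidate window, as in get_valid_indices."""
    stride = int(sequence_length * stride_fraction)
    return list(range(sequence_length, n_rows, stride))


# Small example: 20 rows, windows of 8 rows, stride of 2 rows.
ends = window_end_indices(n_rows=20, sequence_length=8)
print(ends)  # [8, 10, 12, 14, 16, 18]

# Each window covers rows [end - sequence_length, end).
print([(end - 8, end) for end in ends][:2])  # [(0, 8), (2, 10)]
```

With the settings used in this guide (`sequence_length = 2048`), the stride is 512 rows, so consecutive windows share 75% of their rows. That overlap multiplies the number of training samples, and it also means neighbouring samples are highly correlated.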

In [None]:
from litdata import StreamingDataset, StreamingDataLoader, train_test_split
import torch
streaming_dataset = StreamingDataset(PROCESSED_DATA_DIR)  # reads the optimized chunks written above (a local directory here)

def custom_collate(batch):
    # Filter out None values
    batch = [item for item in batch if item is not None]
    
    if not batch:
        # Return empty tensors if the batch is empty
        return {
            "index": torch.tensor([], dtype=torch.long),
            "inputs": torch.tensor([], dtype=torch.float32),
            "mask": torch.tensor([], dtype=torch.bool),
            "stats": {}
        }
    
    # Process each key separately
    indices = torch.tensor([item["index"] for item in batch], dtype=torch.long)
    
    # Make sure arrays are writable by copying and convert to tensor
    inputs = torch.stack([torch.tensor(np.nan_to_num(item["inputs"].copy(), nan=0.0), dtype=torch.float32) for item in batch])
    masks = torch.stack([torch.tensor(item["mask"].copy(), dtype=torch.bool) for item in batch])
    
    # Get stats (use first non-empty item's stats)
    stats = next((item["stats"] for item in batch if "stats" in item), {})
    
    return {
        "index": indices,
        "inputs": inputs,
        "mask": masks,
        "stats": stats
    }

print(len(streaming_dataset))  # number of optimized windows; depends on the CSV size and the stride

train_dataset, val_dataset, test_dataset = train_test_split(streaming_dataset, splits=[0.8, 0.1, 0.1])

print("Train ", len(train_dataset))  # about 80% of the windows
train_dataloader = StreamingDataLoader(train_dataset, num_workers=4, batch_size=32, shuffle=True, collate_fn=custom_collate)  # Create DataLoader for training

print("Validation ", len(val_dataset))  # about 10% of the windows
val_dataloader = StreamingDataLoader(val_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for validation

print("Test ", len(test_dataset))  # about 10% of the windows
test_dataloader = StreamingDataLoader(test_dataset, num_workers=4, batch_size=32, shuffle=False, collate_fn=custom_collate)  # Create DataLoader for testing
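
Each batch carries the `stats` dictionary so that model outputs can be mapped back to real prices and volumes. The sketch below shows the forward transform used during preprocessing (log, then z-score) next to its inverse. The function names are hypothetical, and the `mean` and `std` values in the example are made-up placeholders, not statistics from the dataset.

```python
import math


def normalize_price(price: float, mean: float, std: float) -> float:
    """Forward transform for price columns: natural log, then z-score."""
    return (math.log(price) - mean) / std


def denormalize_price(z: float, mean: float, std: float) -> float:
    """Inverse transform: undo the z-score, then exponentiate."""
    return math.exp(z * std + mean)


def normalize_volume(volume: float, mean: float, std: float) -> float:
    """Forward transform for volume: log1p, then z-score."""
    return (math.log1p(volume) - mean) / std


def denormalize_volume(z: float, mean: float, std: float) -> float:
    """Inverse transform: undo the z-score, then expm1."""
    return math.expm1(z * std + mean)


# Placeholder statistics in the same shape as stats["Close"].
close_stats = {"mean": 9.0, "std": 1.5}
z = normalize_price(30000.0, **close_stats)
restored = denormalize_price(z, **close_stats)
print(round(restored, 2))  # 30000.0
```

At inference time, the same inverse would be applied per feature, using the `mean` and `std` stored under that feature's name in `stats`.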


### Define the Model and Training Loop

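Now that we have the data, we can define the model and training loop. We will use the Mamba blocks from the `mamba_ssm` package. Mamba is a selective state space model: like a recurrent neural network (RNN), it carries a fixed-size hidden state along the sequence, but its state update depends on the input and is computed with a hardware-aware scan. The cells in this section first run smoke tests that build a `Mamba2` block and a `Mamba` block on the GPU and check that the output shape matches the input shape.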

In [1]:
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim, device="cuda")

model = Mamba2(
    d_model=dim,
    d_state=16,
    d_conv=4,
    expand=2,
    headdim=16,
).to("cuda")

y = model(x)
assert y.shape == x.shape

In [3]:
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
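
The comment in the cell above gives a rule of thumb of roughly `3 * expand * d_model^2` parameters per Mamba block. The sketch below simply evaluates that formula for a few widths. The helper name `approx_mamba_params` is hypothetical, and the estimate ignores the smaller state, convolution, and bias terms, so it is most useful at larger widths and should not be read as an exact count.

```python
def approx_mamba_params(d_model: int, expand: int = 2) -> int:
    """Rule-of-thumb parameter count for one Mamba block: 3 * expand * d_model^2."""
    return 3 * expand * d_model ** 2


for width in (16, 256, 768):
    print(width, approx_mamba_params(width))
# 16 1536
# 256 393216
# 768 3538944
```

For an exact figure, count the parameters on the instantiated module, for example with `sum(p.numel() for p in model.parameters())`.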