In [1]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.0
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl (799.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m799.1/799.1 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.0
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.0
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu1

In [2]:
!pip install causal-conv1d && pip install mamba-ssm

Collecting causal-conv1d
  Downloading causal_conv1d-1.4.0.tar.gz (9.3 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ninja (from causal-conv1d)
  Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.9/422.9 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: causal-conv1d
  Building wheel for causal-conv1d (setup.py) ... [?25l[?25hdone
  Created wheel for causal-conv1d: filename=causal_conv1d-1.4.0-cp310-cp310-linux_x86_64.whl size=104867883 sha256=b5e7cf7e964b5e99275d97ba1e1b0ee4e3073f4593743ba1f1c6aa394a3008cc
  Stored in directory: /root/.cache/pip/wheels/e3/dd/4c/205f24e151736bd22f5980738dd10a19af6f093b6f4dcab006
Successfully built causal-conv1d
Installing collected packages: ninja, causal-conv1d
Successfully instal

In [3]:
!pip install memory_profiler

Collecting memory_profiler
  Downloading memory_profiler-0.61.0-py3-none-any.whl.metadata (20 kB)
Downloading memory_profiler-0.61.0-py3-none-any.whl (31 kB)
Installing collected packages: memory_profiler
Successfully installed memory_profiler-0.61.0


In [15]:
import torch
from torch.utils.checkpoint import checkpoint
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from memory_profiler import memory_usage  # For CPU memory profiling, if needed

In [16]:
def measure_memory_usage(model, input_ids, device="cuda", mixed_precision=False):
    """
    Measures GPU memory usage during a forward pass.
    """
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    model.to(device)
    input_ids = input_ids.to(device)

    with torch.no_grad():  # To exclude gradient computation memory
        if mixed_precision:
            with torch.cuda.amp.autocast():  # Mixed precision enabled
                _ = model(input_ids)
        else:
            _ = model(input_ids)

    memory_used = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
    return memory_used

In [17]:
if __name__ == "__main__":
    # Configuration
    pretrained_path = "state-spaces/mamba-130m"  # Use a smaller model if needed
    device = "cuda"
    dtype = torch.float32  # Default precision for the original model
    seq_length = 32        # Sequence length
    batch_size = 2         # Batch size

    # Generate dummy input
    input_ids = torch.randint(0, 50257, (batch_size, seq_length))

    # Load the original Mamba model
    original_model = MambaLMHeadModel.from_pretrained(pretrained_path, device=device, dtype=dtype)

    # Measure memory usage for the original model
    original_memory = measure_memory_usage(original_model, input_ids, device, mixed_precision=False)
    print(f"Original Model: {original_memory:.2f} MB")

    # Clear GPU memory after the original model measurement
    del original_model
    torch.cuda.empty_cache()

    # Load the model again for mixed precision
    mixed_precision_model = MambaLMHeadModel.from_pretrained(pretrained_path, device=device, dtype=dtype)

    # Measure memory usage for the mixed precision model
    mixed_precision_memory = measure_memory_usage(mixed_precision_model, input_ids, device, mixed_precision=True)
    print(f"Mixed Precision Model: {mixed_precision_memory:.2f} MB")

    # Clear GPU memory after the mixed precision model measurement
    del mixed_precision_model
    torch.cuda.empty_cache()

    # Compare memory usage
    print(f"Memory Saved: {original_memory - mixed_precision_memory:.2f} MB")

Original Model: 1747.43 MB
Mixed Precision Model: 1741.29 MB
Memory Saved: 6.14 MB


  with torch.cuda.amp.autocast():  # Mixed precision enabled
