In [2]:
import os
os.environ["CHECKPOINTS_PATH"] = "../checkpoints"

import subprocess as sp
from src.config import Sam2Checkpoints
import dataclasses
from src.api.services import sam2_service
import torch
import time
import gc

# Measure VRAM requirements of models

In [3]:
def get_gpu_memory():
    """Return the free memory of each GPU as reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.free --format=csv`` and parses its
    CSV output: a header row followed by one row per GPU.

    Returns:
        list[int]: The leading integer of each data row, in the order
        ``nvidia-smi`` prints the GPUs (nvidia-smi reports this field in MiB).

    Raises:
        FileNotFoundError: If ``nvidia-smi`` is not on ``PATH``.
        subprocess.CalledProcessError: If ``nvidia-smi`` exits with a
            non-zero status.
        ValueError: If a data row does not start with an integer.
    """
    command = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv"]
    output = sp.check_output(command).decode("ascii")

    # splitlines() copes with both "\n" and "\r\n" and does not depend on the
    # output ending with a newline, so the last GPU row is never dropped.
    rows = [line for line in output.splitlines() if line.strip()]

    # rows[0] is the CSV header ("memory.free [MiB]"); the rest look like
    # "1234 MiB", so the first whitespace-separated token is the number.
    return [int(row.split()[0]) for row in rows[1:]]

In [7]:
# Measure how much additional GPU memory loading each SAM2 checkpoint needs,
# using the PyTorch CUDA allocator statistics of the current device.
# Sam2Checkpoints must be a dataclass instance (dataclasses.asdict requires it);
# each field value is handed to load_predictor, presumably as a checkpoint
# path -- TODO confirm against src.config.
checkpoints = Sam2Checkpoints()
requirements = {}  # checkpoint field name -> peak memory attributable to the load

for name, path in dataclasses.asdict(checkpoints).items():
    print(f"Loading {name} checkpoint...")

    # Clear any previous garbage so leftovers from the last iteration are freed
    gc.collect()
    torch.cuda.empty_cache()

    # Reset stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    # The peak counter restarts from whatever is still allocated, not from
    # zero. Record that baseline so tensors already resident on the GPU are
    # not attributed to this checkpoint.
    baseline_memory = torch.cuda.memory_allocated()

    # Load the model
    predictor = sam2_service.load_predictor(path)

    # Sync before measuring so pending GPU work has finished
    torch.cuda.synchronize()

    # Get peak memory used (in bytes) and keep only the part caused by the load
    peak_memory = torch.cuda.max_memory_allocated()
    requirements[name] = round((peak_memory - baseline_memory) / (1024 ** 2), 2)  # bytes / 1024**2

    # Cleanup: drop the only reference, then return cached blocks to the driver
    del predictor
    gc.collect()
    torch.cuda.empty_cache()

print("\nVRAM requirements per model:")
for name, mem in requirements.items():
    print(f"{name}: {mem} MB")

Loading BASE_PLUS checkpoint...
Loading LARGE checkpoint...
Loading SMALL checkpoint...
Loading TINY checkpoint...

VRAM requirements per model:
BASE_PLUS: 450.47 MB
LARGE: 1000.91 MB
SMALL: 315.29 MB
TINY: 288.71 MB
