In [3]:
import torch
import subprocess

def get_available_gpus(min_vram_gb=20):
    """Return indices of GPUs whose free memory is at least ``min_vram_gb`` GiB.

    Queries ``nvidia-smi`` for per-GPU free memory. With ``nounits`` the
    values are reported in MiB. It then keeps the GPUs that clear the
    threshold.

    Args:
        min_vram_gb: Minimum free memory required, in GiB (1 GiB = 1024 MiB).

    Returns:
        List of GPU indices in nvidia-smi's enumeration order. The list is
        empty when ``nvidia-smi`` is missing, fails or times out, or when no
        GPU qualifies. An empty list lets callers fall back to the CPU.

    NOTE(review): nvidia-smi enumerates GPUs in PCI bus order and ignores
    CUDA_VISIBLE_DEVICES. ``cuda:N`` follows CUDA's own ordering and
    visibility instead. The two only coincide when
    CUDA_DEVICE_ORDER=PCI_BUS_ID and CUDA_VISIBLE_DEVICES is unset. Confirm
    this for the target machine.
    """
    try:
        result = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
            timeout=30,  # nvidia-smi can hang on a wedged driver; don't block the notebook forever
        )
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing or non-executable binary.
        # SubprocessError covers a non-zero exit and a timeout.
        # In every case no GPU is usable.
        return []

    threshold_mib = min_vram_gb * 1024  # convert GiB to the MiB nvidia-smi reports
    available_gpus = []
    # ''.splitlines() == [], so an empty report yields no GPUs instead of int('') crashing.
    for index, line in enumerate(result.decode('utf-8').strip().splitlines()):
        try:
            free_mib = int(line.strip())
        except ValueError:
            # Skip entries like "[N/A]" while keeping index alignment for the others.
            continue
        if free_mib >= threshold_mib:
            available_gpus.append(index)
    return available_gpus

def set_device_with_sufficient_memory(min_vram_gb=20):
    """Pick the first GPU with at least ``min_vram_gb`` GiB free, else the CPU.

    The chosen GPU is also made PyTorch's current CUDA device via
    ``torch.cuda.set_device``. A message describing the choice is printed.

    Args:
        min_vram_gb: Minimum free GPU memory required, in GiB.

    Returns:
        ``torch.device('cuda:N')`` for the selected GPU, or
        ``torch.device('cpu')`` when no usable GPU qualifies.
    """
    available_gpus = get_available_gpus(min_vram_gb)

    # nvidia-smi can report GPUs that PyTorch cannot address, for example with
    # a CPU-only torch build or devices hidden by CUDA_VISIBLE_DEVICES.
    # Calling torch.cuda.set_device on such an index raises, so keep only
    # indices torch can use.
    # NOTE(review): this prevents the crash but does not remap indices.
    # If CUDA_VISIBLE_DEVICES is set, index N may still refer to a different
    # physical GPU.
    usable_gpus = []
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        usable_gpus = [index for index in available_gpus if index < gpu_count]

    if usable_gpus:
        gpu_index = usable_gpus[0]
        device = torch.device(f'cuda:{gpu_index}')
        torch.cuda.set_device(device)
        print(f'Set GPU {gpu_index} with at least {min_vram_gb}GB available memory as device.')
    else:
        # No GPU qualifies (or none is usable by torch): run on the CPU.
        device = torch.device('cpu')
        print(f'No GPUs with at least {min_vram_gb}GB available memory. Falling back to CPU.')

    return device

# Choose the compute device once (first GPU with >= 20 GB free, else CPU);
# later cells read this module-level `device`.
device = set_device_with_sufficient_memory(min_vram_gb=20)

Set GPU 0 with 20GB available memory as device.


In [4]:
# Display the torch.device selected above (cuda:N or cpu).
device

device(type='cuda', index=0)

In [5]:
# Show per-GPU memory usage and utilization to sanity-check the device choice.
!nvidia-smi

Mon Oct 21 15:37:23 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 NVL                On  |   00000000:17:00.0 Off |                    0 |
| N/A   61C    P0            302W /  400W |   31665MiB /  95830MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 NVL                On  |   00