In [18]:
import torch

In [19]:
# Number of CUDA devices visible to PyTorch; kept in a name for later cells
print(num_cuda_devices := torch.cuda.device_count())


2


In [20]:
# Report the name and compute capability of every visible CUDA device
for device_index in range(torch.cuda.device_count()):
    device_name = torch.cuda.get_device_name(device_index)
    print(f"Device Name: {device_name}")
    capability = torch.cuda.get_device_capability(device_index)
    print(f"Device Capability: {capability}")

Device Name: NVIDIA GeForce RTX 4090
Device Capability: (8, 9)
Device Name: NVIDIA GeForce RTX 5090
Device Capability: (12, 0)


In [34]:
# Function to get the best CUDA device, ranked by CUDA compute capability
# https://docs.pytorch.org/docs/stable/generated/torch.cuda.get_device_capability.html
def get_best_cuda_device():
    """Return the visible CUDA device with the highest compute capability.

    Devices are ranked by the full ``(major, minor)`` tuple returned by
    ``torch.cuda.get_device_capability``, so e.g. (8, 9) beats (8, 6).
    When several devices share the top capability, the lowest index wins
    (``max`` keeps the first maximal element).

    The caller is expected to have checked ``torch.cuda.is_available()``.

    Returns:
        torch.device: ``cuda:<index>`` of the selected device.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        # max() on an empty range would otherwise fail with an opaque
        # "iterable argument is empty" error.
        raise RuntimeError(
            "No CUDA devices available; check torch.cuda.is_available() first"
        )
    # Compare the whole (major, minor) tuple: ranking by major alone treats
    # devices of the same generation as equal, regardless of minor version.
    best_index = max(range(device_count), key=torch.cuda.get_device_capability)
    return torch.device("cuda", best_index)

In [35]:
# Example usage. get_best_cuda_device() assumes CUDA is present, so guard the
# call; unguarded, this cell crashes on CPU-only / MPS machines and breaks
# Restart & Run All. Fall back to CPU so `device` is always defined.
if torch.cuda.is_available():
    device = get_best_cuda_device()
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:1


In [36]:
# Platform agnostic device selection
# https://mctm.web.id/blog/2024/PyTorchGPUSelect/
def get_device(override=None):
    """Return the ``torch.device`` to run on.

    Without an override the device is chosen in this order:
      1. the best CUDA device (``get_best_cuda_device()``) if CUDA is available,
      2. ``"mps"`` if the MPS backend is available,
      3. ``"cpu"`` otherwise.

    Args:
        override: Optional device spec passed straight to ``torch.device``
            (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``, ``"mps"``, an integer
            device ordinal, or an existing ``torch.device``). ``None`` (or an
            empty string) means auto-select.

    Returns:
        torch.device: the selected device.
    """
    # Explicit override wins. Test against None/"" rather than truthiness:
    # `not override` would silently ignore a valid override such as ordinal 0.
    if override is not None and override != "":
        return torch.device(override)
    # Step 1: prefer CUDA when present
    if torch.cuda.is_available():
        return get_best_cuda_device()
    # Step 2: Apple-silicon MPS backend
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # Step 3: nothing accelerated available
    return torch.device("cpu")

In [33]:
# Let get_device() choose automatically (CUDA, then MPS, then CPU) and show it
test_device = get_device(override=None)
test_device

device(type='cuda', index=1)