# Multi-GPU Torch sanity check (ROCm)

This notebook does the following (if everything is working as expected):
- prints PyTorch build info
- prints how many GPUs torch sees
- runs a small matmul on **each** GPU and prints `OK` per device :)


In [None]:
import sys, platform
# Report which interpreter and operating system this kernel is running on.
for label, value in (
    ('Python:', sys.version),
    ('Platform:', platform.platform()),
    ('Executable:', sys.executable),
):
    print(label, value)


In [None]:
import torch
# Gather the build/device facts first, then report them in one place.
# getattr with a None default keeps this working on builds where
# torch.version has no `hip` attribute.
hip_version = getattr(torch.version, 'hip', None)
gpu_visible = torch.cuda.is_available()
gpu_count = torch.cuda.device_count()

print('torch.__version__:', torch.__version__)
print('torch.version.hip:', hip_version)
print('torch.cuda.is_available():', gpu_visible)
print('torch.cuda.device_count():', gpu_count)

if not gpu_visible:
    print('No GPU visible to torch.')
else:
    # One line per device index that torch reports.
    for i in range(gpu_count):
        device_name = torch.cuda.get_device_name(i)
        print(f'Device {i}:', device_name)


In [None]:
import time
import torch

def matmul_on_device(i: int, n: int = 2048, iters: int = 20, warmup: int = 5):
    """Time repeated float16 matrix multiplications on a single GPU.

    Two random ``n`` x ``n`` float16 matrices are created on ``cuda:{i}``,
    multiplied ``warmup`` times without timing, and then ``iters`` times
    between two device synchronisations.

    Args:
        i: Index of the torch device to exercise (``cuda:{i}``).
        n: Side length of the square matrices; must be >= 1. Reduce it for
            a faster check.
        iters: Number of timed multiplications; must be >= 1.
        warmup: Number of untimed multiplications run first. Zero or a
            negative value skips the warmup.

    Returns:
        A tuple ``(elapsed_seconds, mean)``: the duration of the timed loop
        and the mean of the last product as a Python float.

    Raises:
        ValueError: If ``n`` or ``iters`` is smaller than 1.
        RuntimeError: If the product contains NaN or infinite values.
            Errors raised by torch itself (for example when the device
            cannot be used) propagate unchanged.
    """
    # Validate before touching the device: with no timed iteration there is
    # nothing to measure and no product to inspect, and an empty matrix
    # would only yield a NaN mean.
    if n < 1:
        raise ValueError(f'n must be >= 1, got {n}')
    if iters < 1:
        raise ValueError(f'iters must be >= 1, got {iters}')

    device = torch.device(f'cuda:{i}')
    a = torch.randn((n, n), device=device, dtype=torch.float16)
    b = torch.randn((n, n), device=device, dtype=torch.float16)

    # Untimed warmup so one-off startup costs stay out of the measurement.
    for _ in range(warmup):
        a @ b

    # Synchronise on both sides of the timed loop so the interval covers the
    # device work itself, and use perf_counter because it is a monotonic
    # clock intended for measuring short durations.
    torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    for _ in range(iters):
        c = a @ b
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - t0

    # A device that silently produces garbage must not be reported as OK.
    if not bool(torch.isfinite(c).all()):
        raise RuntimeError(f'matmul on cuda:{i} produced non-finite values')

    # Reduce the result to one host-side scalar that the caller can print.
    m = c.mean().item()
    return elapsed, m

# Stop the cell early when torch cannot see any GPU at all.
if not torch.cuda.is_available():
    raise SystemExit('No torch-visible GPU; nothing to test ;((.')

count = torch.cuda.device_count()
print(f'Running matmul test on {count} GPU(s)...')

# One (index, ok, elapsed, mean) record per device, in device order.
results = []
for i in range(count):
    try:
        elapsed, meanv = matmul_on_device(i)
        report = f'GPU {i}: OK | elapsed={elapsed:.4f}s | mean={meanv:.6f}'
        record = (i, True, elapsed, meanv)
    except Exception as e:
        # A failing device is reported and recorded; the loop carries on
        # with the remaining devices.
        report = f'GPU {i}: FAIL | {type(e).__name__}: {e}'
        record = (i, False, None, None)
    print(report)
    results.append(record)

# The overall verdict is OK only when every device passed.
all_ok = all(ok for _, ok, _, _ in results)
print('\nSUMMARY:', 'OK' if all_ok else 'FAIL')
