# PyTorch ROCm install into a specific notebook kernel

This notebook is split into two parts:

- **Part A (install and diagnostics)**: choose the ROCm channel, optionally do a clean install, and install via the kernel's Python.
- **Part B (after a kernel restart, or in a separate notebook)**: verify PyTorch/ROCm and run a tiny GPU test.

It installs PyTorch ROCm wheels into the same environment as the running kernel.

It does not install the system ROCm driver/runtime stack (kernel modules, `/dev/kfd`, etc.).

## Important
- If you install or change PyTorch, **restart the kernel** before Part B.
- Large downloads (multiple GB) are expected and may take a while. The pip cache is disabled so that each install downloads fresh wheels.


## Part A — Install and diagnostics

### A1) Kernel diagnostics
Run this first.


In [None]:
import sys, platform, site, shutil, subprocess

print('=== KERNEL PYTHON ===')
print('sys.executable:', sys.executable)
print('sys.version:', sys.version)
print('platform:', platform.platform())
print()
print('=== PIP ON PATH vs KERNEL PIP(!!!! important) ===')
print('which pip:', shutil.which('pip'))
try:
    out = subprocess.check_output(['pip','-V'], text=True)
    print('pip -V:', out.strip())
except Exception as e:
    print('pip -V failed:', type(e).__name__, e)

try:
    out = subprocess.check_output([sys.executable,'-m','pip','-V'], text=True)
    print('kernel pip -V:', out.strip())
except Exception as e:
    print('kernel pip -V failed:', type(e).__name__, e)

print()
print('=== SITE-PACKAGES TARGETS ===')
try:
    print('site.getsitepackages():')
    for p in site.getsitepackages():
        print('  -', p)
except Exception:
    print('site.getsitepackages() not available in this environment')
print('site.getusersitepackages():', site.getusersitepackages())
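
The point of comparing `pip -V` with the kernel's `python -m pip -V` is to spot a `pip` on PATH that belongs to a different environment. A minimal sketch of that comparison follows. The helper names `pip_install_dir` and `same_pip` are hypothetical, the sample strings are made up, and the parsing assumes the usual `pip X.Y from <dir> (python N.M)` output shape.

```python
import re
from pathlib import Path


def pip_install_dir(pip_v_output):
    """Return the site-packages directory named in `pip -V` output, or None."""
    # Assumed shape: "pip 24.0 from /path/to/site-packages/pip (python 3.11)"
    match = re.search(r"\bfrom (.+?) \(python [^)]*\)\s*$", pip_v_output.strip())
    if match is None:
        return None
    # The reported path ends in ".../pip"; its parent is the site-packages dir.
    return Path(match.group(1)).parent


def same_pip(path_pip_output, kernel_pip_output):
    """True only if both outputs parse and point at the same directory."""
    a = pip_install_dir(path_pip_output)
    b = pip_install_dir(kernel_pip_output)
    return a is not None and a == b


# Made-up sample outputs, for illustration only.
path_pip = "pip 23.1 from /usr/lib/python3/dist-packages/pip (python 3.10)"
kernel_pip = "pip 24.0 from /opt/venv/lib/python3.11/site-packages/pip (python 3.11)"
print("PATH pip matches kernel pip:", same_pip(path_pip, kernel_pip))
```

If the two differ, a bare `pip install ...` would install into the wrong environment, which is why Part A always goes through `sys.executable -m pip`.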


### A2) Choose ROCm wheel channel and install behavior

Set `CHOICE` to:
- `'stable_rocm6.4'`  (stable wheels)
- `'nightly_rocm7.0'` (nightly wheels; uses `--pre`)

Set `CLEAN_INSTALL = True` to uninstall the core torch packages from this kernel first.

By default (`NO_PIP_CACHE = True`) this notebook passes `--no-cache-dir`, so pip won't use its wheel download cache.


In [None]:
# --- User options: configure these ---
CHOICE = 'stable_rocm6.4'   # 'stable_rocm6.4' or 'nightly_rocm7.0'
CLEAN_INSTALL = True        # uninstall core torch packages first
NO_PIP_CACHE = True         # keep True (adds --no-cache-dir)
FORCE_REINSTALL = False     # set True to force a reinstall even if already satisfied

OPTIONS = {
    'stable_rocm6.4': {
        'label': 'Stable ROCm 6.4',
        'index_url': 'https://download.pytorch.org/whl/rocm6.4',
        'pip_extra': [],
    },
    'nightly_rocm7.0': {
        'label': 'Nightly ROCm 7.0',
        'index_url': 'https://download.pytorch.org/whl/nightly/rocm7.0',
        'pip_extra': ['--pre'],
    },
}

if CHOICE not in OPTIONS:
    raise ValueError(f"Invalid CHOICE={CHOICE!r}. Pick one of: {list(OPTIONS)}")

SEL = OPTIONS[CHOICE]
INDEX_URL = SEL['index_url']
PIP_EXTRA = list(SEL['pip_extra'])

print('Selected channel:', SEL['label'])
print('Wheel index URL:', INDEX_URL)
print('Extra pip args:', PIP_EXTRA)
print('CLEAN_INSTALL:', CLEAN_INSTALL)
print('NO_PIP_CACHE:', NO_PIP_CACHE)
print('FORCE_REINSTALL:', FORCE_REINSTALL)
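
The two index URLs in `OPTIONS` follow a visible pattern: a base URL, an optional `nightly` segment, then `rocm<version>`. The sketch below builds a URL from those pieces. The function `rocm_index_url` is hypothetical, the pattern is inferred only from the two URLs above, and a URL built this way is not guaranteed to exist for every ROCm version.

```python
BASE_URL = "https://download.pytorch.org/whl"


def rocm_index_url(rocm_version, nightly=False):
    """Build a wheel index URL following the pattern used in OPTIONS.

    Assumption: other ROCm versions use the same layout. Check that the
    URL actually exists before relying on it.
    """
    parts = [BASE_URL]
    if nightly:
        parts.append("nightly")
    parts.append(f"rocm{rocm_version}")
    return "/".join(parts)


print(rocm_index_url("6.4"))
print(rocm_index_url("7.0", nightly=True))
```

Adding another channel would then mean adding one more entry to `OPTIONS` with the matching `index_url` and, for nightlies, `'pip_extra': ['--pre']`.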


### A3) (Optional) Host ROCm sanity check
If these checks fail, the host likely doesn't have ROCm installed or available. This is useful for debugging.


In [None]:
print('Checking /dev/kfd and /dev/dri...')
!ls -l /dev/kfd /dev/dri 2>/dev/null || true

print('\nrocminfo (first lines):')
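# Group the fallback with rocminfo so it fires on failure; otherwise `||` only sees head's exit status.
!(rocminfo 2>/dev/null || echo 'rocminfo not found or failed') | head -n 30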

print('\nrocm-smi:')
!rocm-smi 2>/dev/null || echo 'rocm-smi not found or failed'
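
The same host check can be sketched in plain Python, which also reports whether the current user can open the device nodes. The function `host_rocm_report` is a hypothetical helper. It assumes the GPU render nodes appear as `/dev/dri/renderD*`, and that missing read/write access usually points to a group membership problem (commonly the `render` or `video` group).

```python
import glob
import os
import shutil


def host_rocm_report(kfd="/dev/kfd", render_glob="/dev/dri/renderD*",
                     tools=("rocminfo", "rocm-smi")):
    """Collect basic facts about ROCm device nodes and CLI tools on this host."""
    report = {"devices": {}, "tools": {}}
    for node in [kfd] + sorted(glob.glob(render_glob)):
        report["devices"][node] = {
            "exists": os.path.exists(node),
            # False here often means the user lacks permission on the node.
            "read_write": os.access(node, os.R_OK | os.W_OK),
        }
    for tool in tools:
        # shutil.which returns the full path, or None if the tool is not on PATH.
        report["tools"][tool] = shutil.which(tool)
    return report


report = host_rocm_report()
for node, status in report["devices"].items():
    print(node, status)
for tool, path in report["tools"].items():
    print(tool, "->", path or "not found on PATH")
```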


### A4) Install (into the kernel's Python)

This cell:
- prints the exact Python and pip being used
- optionally uninstalls the core torch packages (clean install)
- upgrades pip tooling
- installs `torch torchvision torchaudio` from the selected ROCm channel


In [None]:
import os, sys, subprocess, shutil

def run(cmd):
    print('\n$ ' + ' '.join(cmd))
    subprocess.check_call(cmd)

print('=== INSTALL TARGET (KERNEL) ===')
print('sys.executable:', sys.executable)
print('kernel pip -V:', subprocess.check_output([sys.executable,'-m','pip','-V'], text=True).strip())
print('which pip on PATH:', shutil.which('pip'))

# Disable pip's cache so every package is downloaded fresh.
if NO_PIP_CACHE:
    os.environ['PIP_NO_CACHE_DIR'] = '1'

pip_base = [sys.executable, '-m', 'pip']

if CLEAN_INSTALL:
    # Core torch packages only (does NOT remove torchmetrics, pytorch-lightning, etc.)
    to_remove = ['torch', 'torchvision', 'torchaudio', 'torchtext', 'torchdata']
    run(pip_base + ['uninstall', '-y'] + to_remove)

tool_args = ['install', '-U', 'pip', 'wheel', 'setuptools']
if NO_PIP_CACHE:
    tool_args.insert(1, '--no-cache-dir')
run(pip_base + tool_args)

install_args = ['install']
if NO_PIP_CACHE:
    install_args.append('--no-cache-dir')
install_args += PIP_EXTRA
if FORCE_REINSTALL:
    install_args += ['--force-reinstall']
install_args += ['torch', 'torchvision', 'torchaudio', '--index-url', INDEX_URL]
run(pip_base + install_args)

print('\n=== DONE ===')
print('Now: Kernel → Restart Kernel, then run Part B.')
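
Before restarting, the installed versions can be read from package metadata without importing `torch`. The sketch below does that with the standard library's `importlib.metadata`. The helpers `installed_version` and `looks_like_rocm_build` are hypothetical, and the `+rocm` check assumes that ROCm wheels carry a local version tag such as `+rocm6.4`.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string for a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def looks_like_rocm_build(version):
    """Heuristic: assumes ROCm wheels use a '+rocm...' local version tag."""
    return version is not None and "+rocm" in version


for name in ("torch", "torchvision", "torchaudio"):
    version = installed_version(name)
    print(f"{name}: {version} | ROCm tag: {looks_like_rocm_build(version)}")
```

This only inspects metadata on disk. Part B remains the real check, because it imports `torch` in a fresh kernel.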


---

## Part B — Run AFTER restarting the kernel

**Restart the kernel now**, then continue below.


### B1) Verify installed torch + ROCm


In [None]:
import sys
print('sys.executable:', sys.executable)

import torch
print('torch.__version__:', torch.__version__)
print('torch file:', torch.__file__)
print('torch.version.hip:', getattr(torch.version, 'hip', None))
print('torch.cuda.is_available():', torch.cuda.is_available())
print('torch.cuda.device_count():', torch.cuda.device_count())
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f'Device {i}:', torch.cuda.get_device_name(i))
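
On ROCm builds PyTorch reuses the `torch.cuda` namespace, so `torch.cuda.is_available()` alone does not say which backend is in use. `torch.version.hip` is what distinguishes a ROCm build. The sketch below turns the values printed above into a one-line summary. The function `describe_backend` is a hypothetical helper, not part of PyTorch.

```python
def describe_backend(hip_version, cuda_version, gpu_available):
    """Summarize which kind of torch build is installed and whether it sees a GPU."""
    if hip_version:
        kind = f"ROCm/HIP build ({hip_version})"
    elif cuda_version:
        kind = f"CUDA build ({cuda_version})"
    else:
        kind = "CPU-only build"
    state = "GPU visible" if gpu_available else "no GPU visible"
    return f"{kind}, {state}"


try:
    import torch
except ImportError:
    print("torch is not importable in this kernel")
else:
    print(describe_backend(
        getattr(torch.version, "hip", None),
        getattr(torch.version, "cuda", None),
        torch.cuda.is_available(),
    ))
```

A "ROCm/HIP build, no GPU visible" result suggests the wheels installed correctly but the host driver or device permissions need attention (see A3).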


### B2) Tiny matmul test on GPU 0


In [None]:
import time, torch
assert torch.cuda.is_available(), 'No torch-visible GPU.'
device = torch.device('cuda:0')
a = torch.randn((2048, 2048), device=device, dtype=torch.float16)
b = torch.randn((2048, 2048), device=device, dtype=torch.float16)
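# Warm-up iterations so one-time setup cost is not timed.
for _ in range(5):
    c = a @ b
torch.cuda.synchronize(device)
# GPU work is asynchronous: synchronize before reading the clock.
t0 = time.perf_counter()
for _ in range(20):
    c = a @ b
torch.cuda.synchronize(device)
t1 = time.perf_counter()
print('OK | elapsed:', round(t1-t0, 6), 's | mean:', float(c.mean().item()))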
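
The elapsed time can be converted into a rough throughput figure. A square `n x n` matmul costs about `2 * n**3` floating-point operations, so the sketch below divides the total by the measured time. The function `matmul_tflops` is a hypothetical helper and the `0.01` seconds in the example is a made-up value, not a measurement.

```python
def matmul_tflops(n, iterations, seconds):
    """Approximate TFLOP/s for `iterations` square n x n matmuls."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    # Each output element needs about n multiplies and n adds: 2 * n**3 in total.
    flops = 2 * n**3 * iterations
    return flops / seconds / 1e12


# Made-up timing for illustration; in the notebook, pass t1 - t0 instead.
print(f"{matmul_tflops(2048, 20, 0.01):.2f} TFLOP/s")
```

With only 20 iterations of a small matrix this is a smoke test rather than a benchmark, so treat the number as an order-of-magnitude indication.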
