In [None]:
import importlib
import os
import random
import time

import torch

An Intel GPU card may be made up of multiple stacks, also known as tiles.
Stack visibility is controlled by the environment variable `ZE_FLAT_DEVICE_HIERARCHY`.
If it is set to `"FLAT"` (default), each GPU stack is made visible as a separate device.
If it is set to `"COMPOSITE"`, the GPU card is seen as a single device.

In `torch`, the visibility of a device needs to be set before the device is initialised,
and can't subsequently be changed.  If you are on a system with Intel GPUs, try running this notebook for the different visibility modes.

**WARNING**: When running on a GPU, the function `torch_matrix_multiplication` terminates when it attempts multiplication of matrices that are sufficiently large to generate an out-of-memory error.  This gives an idea of the GPU memory available.  Recovery from the error can take a little while (usually seconds rather than minutes).  If you want to avoid this, you can set a limit on matrix size for all device types - not just for `"cpu"`.

In [None]:
# Values of ZE_FLAT_DEVICE_HIERARCHY that this notebook can be run with.
modes = ["FLAT", "COMPOSITE"]
# Visibility mode passed to the function calls in the cells below;
# change the index to select the other mode.
mode = modes[0]

In [None]:
def pytorch_check_devices(mode="FLAT"):
    """
    Print, for each candidate device type, the devices that torch reports
    when the stack-visibility mode is set to the value given.

    The mode is written to the environment variable ZE_FLAT_DEVICE_HIERARCHY
    before any device module is queried.  Nothing is returned.
    """
    os.environ["ZE_FLAT_DEVICE_HIERARCHY"] = mode

    # Candidate device types; one of them is deliberately non-existent,
    # to show that unknown types are reported as having no devices.
    candidate_types = ["cpu", "cuda", "mps", "xpu", "fictional_device"]

    def count_devices(name):
        """Return the device count given by torch.<name>, or 0 if unavailable."""
        try:
            module = importlib.import_module(f"torch.{name}")
        except ModuleNotFoundError:
            return 0
        if not hasattr(module, "device_count"):
            return 0
        return module.device_count()

    # Report the device types in alphabetical order.
    print(f"{mode} mode - devices seen by torch:")
    for name in sorted(candidate_types):
        labels = [f"{name}:{idx}" for idx in range(count_devices(name))]
        print(f"    {name}: {labels}")
    print()

# Report the devices torch can see for the visibility mode selected earlier.
pytorch_check_devices(mode=mode)

In [None]:
def _time_matmul(device_name, dim):
    """
    Time the creation and multiplication of two random dim x dim matrices.

    Returns the elapsed wall-clock time in seconds, or None if torch raised
    a RuntimeError (for example, on running out of device memory).
    """
    t0 = time.perf_counter()
    try:
        device = torch.device(device_name)
        x = torch.randn((dim, dim), device=device)
        y = torch.randn((dim, dim), device=device)
        z = torch.matmul(x, y)
        # Accelerator kernels are queued asynchronously.  Copying one element
        # of the result back to the host blocks until the multiplication has
        # finished, so the clock measures execution, not just kernel launch.
        z[0, 0].item()
    except RuntimeError:
        return None
    return time.perf_counter() - t0


def torch_matrix_multiplication(mode="FLAT", max_dim=None):
    """
    Check time for multiplication of square matrices of different orders,
    using different device types, and specified visibility mode.

    Parameters:
        mode: value assigned to the environment variable
            ZE_FLAT_DEVICE_HIERARCHY before devices are queried.
        max_dim: largest matrix order to try, applied to every device type.
            If None (default), "cpu" is limited to order 1024, and other
            device types are tested at increasing order until torch raises
            a RuntimeError, which is reported as out of memory.

    Matrix order starts at 1 and doubles at each step.  Each order is
    attempted three times on every device of the current type, with the
    devices visited in random order.  Results are printed; nothing is
    returned.
    """
    os.environ["ZE_FLAT_DEVICE_HIERARCHY"] = mode
    # Define device types to be considered.
    device_types = ["cpu", "cuda", "mps", "xpu", "fictional_device"]
    # Number of times to attempt matrix multiplication.
    n_attempt = 3

    for device_type in device_types:
        # Determine number of devices of each type.
        try:
            device_module = importlib.import_module(f"torch.{device_type}")
        except ModuleNotFoundError:
            device_module = None
        if (hasattr(device_module, "is_available")
                and hasattr(device_module, "device_count")
                and device_module.is_available()):
            n_device = device_module.device_count()
        else:
            n_device = 0
        print(f"\n{mode} mode - device type: {device_type}")
        print(f"Number of devices: {n_device}")
        if not n_device:
            continue

        # Largest order to try for this device type (None means no limit).
        if max_dim is not None:
            limit = max_dim
        elif device_type == "cpu":
            limit = 1024
        else:
            limit = None

        # Test matrix-multiplication time for all devices of current type,
        # considering devices in random order.
        indices = list(range(n_device))
        random.shuffle(indices)

        dim = 1
        keep_going = True
        while keep_going:
            print()
            if limit is not None and dim > limit:
                break
            for i_attempt in range(1, n_attempt + 1):
                for i_device in indices:
                    device_name = f"{device_type}:{i_device}"
                    elapsed = _time_matmul(device_name, dim)
                    if elapsed is None:
                        # Stop testing this device type at the first failure.
                        print(f"{device_type}: order = {dim}; out of memory")
                        keep_going = False
                        break
                    print(f"{device_name}: order = {dim}; "
                          f"attempt ={i_attempt : 3d}; "
                          f"time ={elapsed * 1.e6 : 8.1f} microseconds")
                if not keep_going:
                    break
            dim *= 2

# Run the matrix-multiplication timing test for the visibility mode selected earlier.
torch_matrix_multiplication(mode=mode)