# In [1]:
import ivy
import os
import gc
import abc
import math
import time
import queue
import psutil
import inspect
import logging
import nvidia_smi
from typing import Optional

from typing import Optional, Union, Tuple, Sequence

# local
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
    infer_device,
    infer_dtype,
    outputs_to_ivy_arrays,
    handle_out_argument,
    to_native_arrays_and_back,
    handle_nestable,
)

def used_mem_on_dev(
    device: Union[ivy.Device, ivy.NativeDevice], process_specific=False
) -> float:
    """
    Get the used memory (in GB) for a given device string. In case of CPU, the used
    RAM is returned.

    ``ivy.clear_mem_on_dev(device)`` is called before the measurement is taken.

    Parameters
    ----------
    device
        The device to query, of the form ``"cpu"`` or ``"gpu:idx"``.
        NOTE(review): the checks below use ``"gpu" in device`` and
        ``device == "cpu"``, so a string-like device is assumed -- confirm that
        native device objects are converted by callers.
    process_specific
        Whether to check the memory used by this python process alone. Default is
        False. Only supported for ``"cpu"``.

    Returns
    -------
    ret
        The used memory on the device in GB.

    Raises
    ------
    Exception
        If ``process_specific`` is True for a GPU device, or if ``device`` is
        neither ``"cpu"`` nor a string containing ``"gpu"``.

    Examples
    --------
    Used RAM across the whole system:

    >>> ivy.used_mem_on_dev("cpu")  # doctest: +SKIP

    RAM used by the current python process only:

    >>> ivy.used_mem_on_dev("cpu", process_specific=True)  # doctest: +SKIP

    Used memory on the first GPU:

    >>> ivy.used_mem_on_dev("gpu:0")  # doctest: +SKIP
    """
    ivy.clear_mem_on_dev(device)
    if "gpu" in device:
        if process_specific:
            raise Exception("process-specific GPU queries are currently not supported")
        # Device-wide query; this must sit outside the `process_specific` guard,
        # otherwise it is unreachable and the function silently returns None.
        handle = _get_nvml_gpu_handle(device)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1e9
    elif device == "cpu":
        if process_specific:
            # rss is reported in bytes; convert to GB to match every other path.
            return psutil.Process(os.getpid()).memory_info().rss / 1e9
        # System-wide usage: everything that is not currently available.
        vm = psutil.virtual_memory()
        return (vm.total - vm.available) / 1e9
    else:
        raise Exception(
            'Invalid device string input, must be on the form "gpu:idx" or "cpu", '
            "but found {}".format(device)
        )