
Commit

Set fp16 as default behavior if user has supported Nvidia GPU (or arm mac) (#2203)

* Set FP16 as default behavior if user has supported nvidia GPU

* oops

* Make some adjustments

* Update backend/src/gpu.py

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>

* cleanup

* Apply suggestions from code review

---------

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>
joeyballentine and RunDevelopment committed Sep 13, 2023
1 parent 37d1f27 commit 926ad07
Showing 3 changed files with 48 additions and 4 deletions.
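
Taken together, the three files make FP16 the default when the detected hardware is known to handle it: an NVIDIA GPU of a supported architecture (per the new check in backend/src/gpu.py below), or an Apple Silicon Mac in the PyTorch package. A minimal sketch of that decision, condensed from the diffs below rather than quoted from them:

# Sketch only: condenses the logic added in the diffs below; not a verbatim excerpt.
from gpu import get_nvidia_helper
from system import is_arm_mac

nv = get_nvidia_helper()  # None when no NVIDIA GPU is available

should_fp16 = False
if nv is not None:
    # True only if every detected NVIDIA GPU supports FP16 (Volta and newer; Turing only for RTX cards).
    should_fp16 = nv.supports_fp16()
else:
    # The PyTorch package additionally defaults to FP16 on Apple Silicon.
    should_fp16 = is_arm_mac
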
32 changes: 31 additions & 1 deletion backend/src/gpu.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import pynvml as nv
 from sanic.log import logger
@@ -22,6 +22,29 @@ class _GPU:
     uuid: str
     index: int
     handle: int
+    arch: int
+
+
+FP16_ARCH_ABILITY_MAP = {
+    nv.NVML_DEVICE_ARCH_KEPLER: False,
+    nv.NVML_DEVICE_ARCH_MAXWELL: False,
+    nv.NVML_DEVICE_ARCH_PASCAL: False,
+    nv.NVML_DEVICE_ARCH_VOLTA: True,
+    nv.NVML_DEVICE_ARCH_TURING: True,
+    nv.NVML_DEVICE_ARCH_AMPERE: True,
+    nv.NVML_DEVICE_ARCH_ADA: True,
+    nv.NVML_DEVICE_ARCH_HOPPER: True,
+    nv.NVML_DEVICE_ARCH_UNKNOWN: False,
+}
+
+
+def supports_fp16(gpu: _GPU):
+    # This generation also contains the GTX 1600 cards, which do not support FP16.
+    if gpu.arch == nv.NVML_DEVICE_ARCH_TURING:
+        # There may be a more robust way to check this, but for now I think this will do.
+        return "RTX" in gpu.name
+    # Future proofing. We can be reasonably sure that future architectures will support FP16.
+    return FP16_ARCH_ABILITY_MAP.get(gpu.arch, gpu.arch > nv.NVML_DEVICE_ARCH_HOPPER)


 class NvidiaHelper:
@@ -39,6 +62,7 @@ def __init__(self):
                     uuid=nv.nvmlDeviceGetUUID(handle),
                     index=i,
                     handle=handle,
+                    arch=nv.nvmlDeviceGetArchitecture(handle),
                 )
             )

@@ -57,6 +81,12 @@ def get_current_vram_usage(self, gpu_index=0) -> Tuple[int, int, int]:

         return info.total, info.used, info.free

+    def supports_fp16(self, gpu_index: Union[int, None] = None) -> bool:
+        if gpu_index is None:
+            return all(supports_fp16(gpu) for gpu in self.__gpus)
+        gpu = self.__gpus[gpu_index]
+        return supports_fp16(gpu)
+

 _cachedNvidiaHelper = None

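The capability check is exposed both as the module-level supports_fp16(gpu) helper and as NvidiaHelper.supports_fp16(gpu_index=None), which reports True only if every GPU qualifies when no index is given. A small usage sketch, assuming an environment where pynvml can initialize; get_nvidia_helper() is the existing cached accessor that the settings files call, and it returns None when no NVIDIA GPU is usable:

# Usage sketch under the assumptions above; not part of the commit itself.
from gpu import get_nvidia_helper

nv = get_nvidia_helper()
if nv is None:
    print("No usable NVIDIA GPU; FP16 defaults stay off")
else:
    print("All GPUs FP16-capable:", nv.supports_fp16())  # every detected GPU
    print("GPU 0 FP16-capable:", nv.supports_fp16(0))    # a single GPU by index
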
9 changes: 7 additions & 2 deletions backend/src/packages/chaiNNer_onnx/settings.py
@@ -13,8 +13,9 @@

 from . import package

+nv = get_nvidia_helper()
+
 if not is_arm_mac:
-    nv = get_nvidia_helper()
     gpu_list = nv.list_gpus() if nv is not None else []

 package.add_setting(
@@ -66,12 +67,16 @@ def get_providers():
     )
 )

+should_fp16 = False
+if nv is not None:
+    should_fp16 = nv.supports_fp16()
+
 package.add_setting(
     ToggleSetting(
         label="Use TensorRT FP16 Mode",
         key="tensorrt_fp16_mode",
         description="Runs TensorRT in half-precision (FP16) mode for less VRAM usage. RTX GPUs also get a speedup.",
-        default=False,
+        default=should_fp16,
         disabled="TensorrtExecutionProvider" not in execution_providers,
     )
 )
11 changes: 10 additions & 1 deletion backend/src/packages/chaiNNer_pytorch/settings.py
@@ -3,10 +3,13 @@
 import torch

 from api import DropdownSetting, ToggleSetting
+from gpu import get_nvidia_helper
 from system import is_arm_mac

 from . import package

+nv = get_nvidia_helper()
+
 if not is_arm_mac:
     gpu_list = []
     for i in range(torch.cuda.device_count()):
@@ -32,6 +35,12 @@
         ),
     )

+should_fp16 = False
+if nv is not None:
+    should_fp16 = nv.supports_fp16()
+else:
+    should_fp16 = is_arm_mac
+
 package.add_setting(
     ToggleSetting(
         label="Use FP16 Mode",
@@ -41,7 +50,7 @@
             if is_arm_mac
             else "Runs PyTorch in half-precision (FP16) mode for less VRAM usage. RTX GPUs also get a speedup."
         ),
-        default=False,
+        default=should_fp16,
     ),
 )

