Enable training on XPU devices in OTX2.0 #3094

Merged
merged 67 commits into from
Mar 24, 2024
0abdf10
add raising an error when metric is None
kprokofi Jan 14, 2024
8756968
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 17, 2024
e171e9d
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 18, 2024
4e6e21e
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 22, 2024
c0abe24
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 23, 2024
1868961
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 24, 2024
d33e66e
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 25, 2024
35e925f
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Jan 31, 2024
3de253f
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Feb 1, 2024
3f0ce95
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Feb 2, 2024
bddffa6
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Feb 12, 2024
8b69f62
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Feb 14, 2024
7efa031
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Feb 18, 2024
1f8b9ce
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Mar 4, 2024
6e0f0b6
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Mar 8, 2024
4619997
Merge branch 'v2' of https://github.com/openvinotoolkit/training_exte…
kprokofi Mar 12, 2024
96104fc
added accelerators
kprokofi Mar 12, 2024
d09392c
fix packages
kprokofi Mar 12, 2024
8c59e04
fix assigning model
chuneuny-emily Mar 13, 2024
57b0f33
debug on MAX
kprokofi Mar 13, 2024
1471fbb
change precision
kprokofi Mar 13, 2024
0a4a69a
update MixedPrecisionXPUPlugin
kprokofi Mar 13, 2024
d77b84b
debug
kprokofi Mar 13, 2024
37601e2
merge
kprokofi Mar 13, 2024
272a534
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi Mar 13, 2024
79ec108
added monkey patching
kprokofi Mar 13, 2024
1558295
minor
kprokofi Mar 13, 2024
d71492a
minor
kprokofi Mar 13, 2024
1e7f005
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi Mar 13, 2024
d117eb5
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi Mar 14, 2024
bac7e72
added patch for mmengine
kprokofi Mar 14, 2024
52a4e78
Merge branch 'kp/xpu_otx2.0' of https://github.com/openvinotoolkit/tr…
kprokofi Mar 14, 2024
0df569a
fix OD and IS
kprokofi Mar 15, 2024
1292d1c
benchmark debug
kprokofi Mar 15, 2024
9e64563
change device
kprokofi Mar 16, 2024
defbf9e
quick fix for instance seg
kprokofi Mar 16, 2024
d585a05
merge
kprokofi Mar 16, 2024
df7d89e
fix pre-commit
kprokofi Mar 18, 2024
bbabd6e
fix pre-commit
kprokofi Mar 18, 2024
509a226
clean the code
kprokofi Mar 19, 2024
5640921
merge develop
kprokofi Mar 19, 2024
9a9aeb4
merge
kprokofi Mar 18, 2024
ea1ea19
added additional flag for mmcv
kprokofi Mar 19, 2024
98d9742
Merge branch 'develop' into kp/xpu_otx2.0
kprokofi Mar 19, 2024
aaf5568
added unit tests
kprokofi Mar 19, 2024
a3573c1
fixed unit test
kprokofi Mar 19, 2024
2184b7d
fix linter
kprokofi Mar 19, 2024
7e303a1
added unit tests and replied comments
kprokofi Mar 20, 2024
4dd325e
fix pre-commit
kprokofi Mar 20, 2024
64a4100
minor fix
kprokofi Mar 20, 2024
967b2db
added documentation
kprokofi Mar 20, 2024
5411d04
fix unit test
kprokofi Mar 20, 2024
2f1e411
add workaround for semantic segmentation
kprokofi Mar 20, 2024
ef4b93d
remove RoiAlignTest due to unstability
kprokofi Mar 20, 2024
121be65
minor
kprokofi Mar 20, 2024
047b0b7
remove strategy back
kprokofi Mar 20, 2024
028190a
try to patch SingleDeviceStrategy
kprokofi Mar 21, 2024
c1deb52
added auto xpu configuration
kprokofi Mar 21, 2024
52e529a
patch strategy
kprokofi Mar 21, 2024
8892171
small fix
kprokofi Mar 21, 2024
2489581
reply to comments
kprokofi Mar 21, 2024
8fb52df
move patching xpu packages to accelerator
kprokofi Mar 21, 2024
d3b9a6d
fix test_xpu test
kprokofi Mar 21, 2024
a47e884
Merge branch 'releases/2.0.0' into kp/xpu_otx2.0
kprokofi Mar 21, 2024
f99f022
remove do-not-install-mmcv
kprokofi Mar 21, 2024
a368e9e
fix pre-commit
kprokofi Mar 21, 2024
98e0b69
remove torch.xpu.optimize for segmentation
kprokofi Mar 21, 2024
68 changes: 65 additions & 3 deletions docs/source/guide/get_started/installation.rst
Original file line number Diff line number Diff line change
@@ -11,9 +11,9 @@ The current version of OpenVINO™ Training Extensions was tested in the followi
- Python >= 3.10


***********************************************
Install OpenVINO™ Training Extensions for users
***********************************************
**********************************************************
Install OpenVINO™ Training Extensions for users (CUDA/CPU)
**********************************************************

1. Install OpenVINO™ Training Extensions package:

@@ -57,6 +57,68 @@ Install OpenVINO™ Training Extensions for users
3. Once the package is installed in the virtual environment, you can use full
OpenVINO™ Training Extensions command line functionality.

*************************************************************
Install OpenVINO™ Training Extensions for users (XPU devices)
*************************************************************

1. Follow the first two steps of the instructions above
to clone the repository and create a virtual environment.

2. Install Intel Extension for PyTorch (IPEX).
Follow the `official documentation <https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.10%2Bxpu>`_ to install prerequisites such as oneAPI and the proper drivers.

.. code-block:: shell

    python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

3. Install MMCV.
MMCV must be installed from source so that it is built properly with IPEX.

.. code-block:: shell

    git clone https://github.com/open-mmlab/mmcv
    cd mmcv
    git checkout v2.1.0
    MMCV_WITH_OPS=1 pip install -e .

4. Install the OpenVINO™ Training Extensions
package from either:

* A local source in development mode

.. code-block:: shell

    pip install -e .

* PyPI

.. code-block:: shell

    pip install otx

5. Install the requirements for training,
excluding PyTorch.

.. code-block:: shell

    otx install -v --do-not-install-torch

6. Activate the oneAPI environment
and export the required IPEX environment variables.

.. code-block:: shell

    source /path/to/intel/oneapi/setvars.sh
    export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.30
    export IPEX_FP32_MATH_MODE=TF32

7. Once the package is installed in the virtual environment, you can use the full
OpenVINO™ Training Extensions command-line functionality.

.. code-block:: shell

    otx --help
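Step 6 is easy to miss in a fresh shell. As a minimal sketch (not part of this PR), the check below reports which of the variables exported in step 6 are unset. The helper `missing_xpu_env_vars` and the constant `REQUIRED_XPU_ENV_VARS` are hypothetical names; only the two variable names come from the instructions above.

```python
import os
from typing import List, Mapping, Optional

# Variables exported in step 6 of the XPU instructions (taken from the docs above).
REQUIRED_XPU_ENV_VARS = ("LD_PRELOAD", "IPEX_FP32_MATH_MODE")


def missing_xpu_env_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required variable names that are unset or empty.

    `env` defaults to the current process environment; passing a plain dict
    makes the function easy to test.
    """
    env = os.environ if env is None else env
    return [name for name in REQUIRED_XPU_ENV_VARS if not env.get(name)]
```

Calling `missing_xpu_env_vars()` before `otx train` would return an empty list when both variables are present. It only checks for presence and does not validate the values.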

****************************************************
Install OpenVINO™ Training Extensions for developers
****************************************************
22 changes: 20 additions & 2 deletions src/otx/algo/__init__.py
@@ -3,6 +3,24 @@
#
"""Module for OTX custom algorithms, e.g., model, losses, hook, etc..."""

from . import action_classification, classification, detection, segmentation, visual_prompting
from . import (
    accelerators,
    action_classification,
    classification,
    detection,
    plugins,
    segmentation,
    strategies,
    visual_prompting,
)

__all__ = ["action_classification", "classification", "detection", "segmentation", "visual_prompting"]
__all__ = [
    "action_classification",
    "classification",
    "detection",
    "segmentation",
    "visual_prompting",
    "strategies",
    "accelerators",
    "plugins",
]
8 changes: 8 additions & 0 deletions src/otx/algo/accelerators/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Lightning accelerator for XPU device."""

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
88 changes: 88 additions & 0 deletions src/otx/algo/accelerators/xpu.py
@@ -0,0 +1,88 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations

from typing import Any, Union

import numpy as np
import torch
from lightning.pytorch.accelerators import AcceleratorRegistry
from lightning.pytorch.accelerators.accelerator import Accelerator
from mmcv.ops.nms import NMSop
from mmcv.ops.roi_align import RoIAlign
from mmengine.structures import instance_data

from otx.algo.detection.utils import monkey_patched_nms, monkey_patched_roi_align
from otx.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
    """Support for an XPU, optimized for large-scale machine learning."""

    accelerator_name = "xpu"

    def setup_device(self, device: torch.device) -> None:
        """Sets up the specified device."""
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)

        torch.xpu.set_device(device)
        self.patch_packages_xpu()

    @staticmethod
    def parse_devices(devices: str | list | torch.device) -> list:
        """Parses devices for multi-GPU training."""
        if isinstance(devices, list):
            return devices
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Generates a list of parallel devices."""
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Returns number of XPU devices available."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Checks if XPU is available."""
        return is_xpu_available()

    def get_device_stats(self, device: str | torch.device) -> dict[str, Any]:
        """Returns XPU devices stats."""
        return {}

    def teardown(self) -> None:
        """Cleans-up XPU-related resources."""
        self.revert_packages_xpu()

    def patch_packages_xpu(self) -> None:
        """Patch packages when xpu is available."""
        # patch instance_data from mmengine
        long_type_tensor = Union[torch.LongTensor, torch.xpu.LongTensor]
        bool_type_tensor = Union[torch.BoolTensor, torch.xpu.BoolTensor]
        instance_data.IndexType = Union[str, slice, int, list, long_type_tensor, bool_type_tensor, np.ndarray]

        # patch nms and roi_align, keeping the originals so they can be restored
        self._nms_op_forward = NMSop.forward
        self._roi_align_forward = RoIAlign.forward
        NMSop.forward = monkey_patched_nms
        RoIAlign.forward = monkey_patched_roi_align

    def revert_packages_xpu(self) -> None:
        """Revert the package patches applied for XPU."""
        NMSop.forward = self._nms_op_forward
        RoIAlign.forward = self._roi_align_forward


AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name,
    XPUAccelerator,
    description="Accelerator supports XPU devices",
)
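The `patch_packages_xpu` / `revert_packages_xpu` pair above follows a save-and-restore monkey-patching pattern: keep a reference to the original method, swap in the replacement, and put the original back on teardown. The sketch below illustrates that pattern in plain Python. `Op`, `patched_forward`, and `PatchingAccelerator` are hypothetical stand-ins, not OTX or MMCV classes, and the guards against double patching and reverting without a prior patch are an assumption added for illustration; they do not appear in the diff.

```python
from typing import Callable, Optional


class Op:
    """Hypothetical stand-in for a third-party op class (e.g. an NMS op)."""

    def forward(self, x: int) -> int:
        return x + 1


def patched_forward(self: Op, x: int) -> int:
    """Replacement implementation used while the patch is active."""
    return x * 10


class PatchingAccelerator:
    """Hypothetical stand-in showing the save/patch/restore life cycle."""

    def __init__(self) -> None:
        self._original_forward: Optional[Callable[..., int]] = None

    def patch(self) -> None:
        # Save the original only once, so a second patch() call cannot
        # overwrite the saved reference with the already-patched function.
        if self._original_forward is None:
            self._original_forward = Op.forward
            Op.forward = patched_forward

    def revert(self) -> None:
        # Restore only when a patch is active, so revert() is safe to call
        # even if patch() never ran.
        if self._original_forward is not None:
            Op.forward = self._original_forward
            self._original_forward = None
```

In the diff as shown, `revert_packages_xpu` reads `self._nms_op_forward` and `self._roi_align_forward`, which are only assigned inside `patch_packages_xpu`. A guard like the one sketched here is one possible way to keep `teardown` safe if `setup_device` was never called.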
8 changes: 8 additions & 0 deletions src/otx/algo/detection/utils/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""utils for detection task."""

from .mmcv_patched_ops import monkey_patched_nms, monkey_patched_roi_align

__all__ = ["monkey_patched_nms", "monkey_patched_roi_align"]
73 changes: 73 additions & 0 deletions src/otx/algo/detection/utils/mmcv_patched_ops.py
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""utils for detection task."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from mmcv.utils import ext_loader
from torchvision.ops import nms as tv_nms
from torchvision.ops import roi_align as tv_roi_align

if TYPE_CHECKING:
    from mmcv.ops.nms import NMSop
    from mmcv.ops.roi_align import RoIAlign

ext_module = ext_loader.load_ext("_ext", ["nms", "softnms", "nms_match", "nms_rotated", "nms_quadri"])


def monkey_patched_nms(
    ctx: NMSop,
    bboxes: torch.Tensor,
    scores: torch.Tensor,
    iou_threshold: float,
    offset: float,
    score_threshold: float,
    max_num: int,
) -> torch.Tensor:
    """Runs MMCV's NMS with torchvision.nms, or forces NMS from MMCV to run on CPU."""
    _ = ctx
    is_filtering_by_score = score_threshold > 0
    if is_filtering_by_score:
        valid_mask = scores > score_threshold
        bboxes, scores = bboxes[valid_mask], scores[valid_mask]
        valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1)

    # torchvision's NMS is called in float32 here, so upcast bfloat16 inputs
    if bboxes.dtype == torch.bfloat16:
        bboxes = bboxes.to(torch.float32)
    if scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float32)

    if offset == 0:
        inds = tv_nms(bboxes, scores, float(iou_threshold))
    else:
        # torchvision's NMS has no offset argument; fall back to MMCV's op on CPU
        device = bboxes.device
        bboxes = bboxes.to("cpu")
        scores = scores.to("cpu")
        inds = ext_module.nms(bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
        bboxes = bboxes.to(device)
        scores = scores.to(device)

    if max_num > 0:
        inds = inds[:max_num]
    if is_filtering_by_score:
        # map indices of the filtered boxes back to indices into the original input
        inds = valid_inds[inds]
    return inds
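The control flow of `monkey_patched_nms` (optional score filtering, NMS on the surviving boxes, truncation to `max_num`, then mapping indices back to the unfiltered input) can be sketched without torch. The pure-Python version below is an illustrative reference only, not the torchvision or MMCV implementation; the names `iou`, `greedy_nms`, and `nms_with_score_filter` are hypothetical, and the handling of `offset` (added to box widths and heights) assumes MMCV's convention.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box, offset: float = 0.0) -> float:
    """Intersection over union; `offset` is added to every width and height."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + offset)
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + offset)
    inter = iw * ih
    area_a = (a[2] - a[0] + offset) * (a[3] - a[1] + offset)
    area_b = (b[2] - b[0] + offset) * (b[3] - b[1] + offset)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float, offset: float) -> List[int]:
    """Keep boxes in descending score order, dropping any with IoU > threshold."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[j], offset) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


def nms_with_score_filter(
    boxes: Sequence[Box],
    scores: Sequence[float],
    iou_threshold: float,
    offset: float = 0.0,
    score_threshold: float = 0.0,
    max_num: int = -1,
) -> List[int]:
    """Return kept indices into the ORIGINAL `boxes`, mirroring the patched op's flow."""
    filtering = score_threshold > 0
    if filtering:
        valid_inds = [i for i, s in enumerate(scores) if s > score_threshold]
        boxes = [boxes[i] for i in valid_inds]
        scores = [scores[i] for i in valid_inds]
    inds = greedy_nms(boxes, scores, iou_threshold, offset)
    if max_num > 0:
        inds = inds[:max_num]
    if filtering:
        # Indices refer to the filtered lists; translate them back.
        inds = [valid_inds[i] for i in inds]
    return inds
```

The final remapping step is what keeps the returned indices meaningful to the caller: without it, they would point into the filtered subset rather than into the tensors that were passed in.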


def monkey_patched_roi_align(self: RoIAlign, _input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
    """Replaces MMCV's RoIAlign with the one from torchvision.

    Args:
        self: patched instance
        _input: NCHW images
        rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
    """
    # newer torchvision versions accept the `aligned` argument directly
    if "aligned" in tv_roi_align.__code__.co_varnames:
        return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
    # otherwise emulate `aligned=True` by shifting the boxes by half a pixel
    if self.aligned:
        rois -= rois.new_tensor([0.0] + [0.5 / self.spatial_scale] * 4)
    return tv_roi_align(_input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
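The patched RoIAlign decides at call time whether the installed `roi_align` accepts an `aligned` argument, and otherwise shifts the boxes by half a pixel itself. The sketch below shows both ideas in plain Python under stated assumptions: `accepts_parameter`, `old_roi_align`, `new_roi_align`, and `shift_rois_half_pixel` are hypothetical helpers for illustration. It uses `inspect.signature` instead of `__code__.co_varnames`, because `co_varnames` also lists local variables, and it returns shifted copies of the boxes instead of modifying them in place as `rois -= ...` does.

```python
import inspect
from typing import Any, Callable, List, Sequence


def accepts_parameter(func: Callable[..., Any], name: str) -> bool:
    """Return True if `func` declares a parameter called `name`."""
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins and C extensions expose no introspectable signature.
        return False


def old_roi_align(x, rois, output_size, spatial_scale=1.0, sampling_ratio=-1):
    """Hypothetical older API: no `aligned` parameter, only a local variable."""
    aligned = False
    return ("old", aligned)


def new_roi_align(x, rois, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False):
    """Hypothetical newer API that accepts `aligned`."""
    return ("new", aligned)


def shift_rois_half_pixel(rois: Sequence[Sequence[float]], spatial_scale: float) -> List[List[float]]:
    """Return copies of Bx5 rois with the four coordinates shifted by 0.5 / spatial_scale.

    The first column (batch index) is left untouched and the input is not mutated.
    """
    delta = 0.5 / spatial_scale
    return [[roi[0]] + [coord - delta for coord in roi[1:5]] for roi in rois]
```

Here `"aligned" in old_roi_align.__code__.co_varnames` evaluates to `True` even though `old_roi_align` cannot be called with `aligned=...`, which is the false positive the signature-based check avoids.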

8 changes: 8 additions & 0 deletions src/otx/algo/plugins/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Plugin for mixed-precision training on XPU."""

from .xpu_precision import MixedPrecisionXPUPlugin

__all__ = ["MixedPrecisionXPUPlugin"]