[Enhance] Fix and enhance logger in AvoidOOM (#8091)
BIGWangYuDong committed May 31, 2022
1 parent 2382ee5 commit e06b0d5
Showing 4 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -55,7 +55,7 @@ jobs:
           condition:
             equal: [ "3.9.0", << parameters.python >> ]
           steps:
-            - run: pip install protobuf && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake
+            - run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake
             - run:
                 name: Install mmdet dependencies
                 command: |
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -129,7 +129,7 @@ jobs:
       - name: Install PyTorch
         run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
       - name: Install dependencies for compiling onnx when python=3.9
-        run: python -m pip install protobuf && apt-get install libprotobuf-dev protobuf-compiler
+        run: python -m pip install "protobuf <= 3.20.1" && apt-get install libprotobuf-dev protobuf-compiler
         if: ${{matrix.python-version == '3.9'}}
       - name: Install mmdet dependencies
         run: |
@@ -210,7 +210,7 @@ jobs:
       - name: Install PyTorch
         run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
       - name: Install dependencies for compiling onnx when python=3.9
-        run: python -m pip install protobuf && apt-get update && apt-get -y install libprotobuf-dev protobuf-compiler cmake
+        run: python -m pip install "protobuf <= 3.20.1" && apt-get update && apt-get -y install libprotobuf-dev protobuf-compiler cmake
         if: ${{matrix.python-version == '3.9'}}
       - name: Install mmdet dependencies
         run: |
19 changes: 9 additions & 10 deletions mmdet/utils/memory.py
@@ -118,7 +118,6 @@ class AvoidOOM:
     """

     def __init__(self, to_cpu=True, test=False):
-        self.logger = get_root_logger()
         self.to_cpu = to_cpu
         self.test = test

@@ -164,8 +163,9 @@ def wrapped(*args, **kwargs):
             # Convert to FP16
             fp16_args = cast_tensor_type(args, dst_type=torch.half)
             fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
-            self.logger.info(f'Attempting to copy inputs of {str(func)} '
-                             f'to FP16 due to CUDA OOM')
+            logger = get_root_logger()
+            logger.warning(f'Attempting to copy inputs of {str(func)} '
+                           'to FP16 due to CUDA OOM')

             # get input tensor type, the output type will same as
             # the first parameter type.
@@ -175,21 +175,20 @@ def wrapped(*args, **kwargs):
                     output, src_type=torch.half, dst_type=dtype)
                 if not self.test:
                     return output
-            self.logger.info('Using FP16 still meet CUDA OOM')
+            logger.warning('Using FP16 still meet CUDA OOM')

             # Try on CPU. This will slow down the code significantly,
             # therefore print a notice.
             if self.to_cpu:
-                self.logger.info(f'Attempting to copy inputs of {str(func)} '
-                                 f'to CPU due to CUDA OOM')
+                logger.warning(f'Attempting to copy inputs of {str(func)} '
+                               'to CPU due to CUDA OOM')
                 cpu_device = torch.empty(0).device
                 cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                 cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                 # convert outputs to GPU
                 with _ignore_torch_cuda_oom():
-                    self.logger.info(f'Convert outputs to GPU '
-                                     f'(device={device})')
+                    logger.warning(f'Convert outputs to GPU (device={device})')
                     output = func(*cpu_args, **cpu_kwargs)
                     output = cast_tensor_type(
                         output, src_type=cpu_device, dst_type=device)
@@ -199,8 +198,8 @@ def wrapped(*args, **kwargs):
                               'the output is now on CPU, which might cause '
                              'errors if the output need to interact with GPU '
                              'data in subsequent operations')
-                self.logger.info('Cannot convert output to GPU due to '
-                                 'CUDA OOM, the output is on CPU now.')
+                logger.warning('Cannot convert output to GPU due to '
+                               'CUDA OOM, the output is on CPU now.')

                 return func(*cpu_args, **cpu_kwargs)
             else:
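With this change, AvoidOOM no longer caches a logger in __init__; the wrapped function fetches the root logger only when a CUDA OOM actually occurs and reports the FP16/CPU fallbacks at WARNING level instead of INFO. Below is a minimal usage sketch, not part of the commit; it assumes the wrapper is exposed as AvoidOOM.retry_if_cuda_oom in mmdet/utils/memory.py (the method name is not visible in this diff) and that mmdet is installed.

import torch

from mmdet.utils.memory import AvoidOOM


def pairwise_scores(feats):
    # Hypothetical placeholder op that can raise CUDA OOM for large inputs.
    return feats @ feats.transpose(-1, -2)


avoid_oom = AvoidOOM(to_cpu=True)  # __init__ signature as shown in the diff above
# The returned callable (method name assumed) retries with FP16 inputs, then on
# CPU, and after this commit logs each fallback via get_root_logger() at
# WARNING level.
safe_scores = avoid_oom.retry_if_cuda_oom(pairwise_scores)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
out = safe_scores(torch.randn(8, 256, device=device))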
1 change: 1 addition & 0 deletions requirements/tests.txt
@@ -8,6 +8,7 @@ kwarray
 -e git+https://github.com/open-mmlab/mmtracking#egg=mmtrack
 onnx==1.7.0
 onnxruntime>=1.8.0
+protobuf<=3.20.1
 pytest
 ubelt
 xdoctest>=0.10.0
