RuntimeError: CUDA out of memory. Tried to allocate 258.96 GiB (GPU 0; 15.75 GiB total capacity; 2.26 GiB already allocated; 11.11 GiB free; 172.34 MiB cached) #6908

Closed
sanmulab opened this issue Dec 28, 2021 · 16 comments

@sanmulab

I use the SOLO framework to train on the LVIS dataset. Training completes fine, but the evaluation afterwards runs far past the available GPU memory. I tested Mask R-CNN in the same environment without any problems.

Error: RuntimeError: CUDA out of memory. Tried to allocate 258.96 GiB (GPU 0; 15.75 GiB total capacity; 2.26 GiB already allocated; 11.11 GiB free; 172.34 MiB cached)

Environment: mmdet=2.18.1, mmcv-full=1.3.17, pytorch=1.3.1

config:

dataset_type = 'LVISV05Dataset'
data_root = 'data/lvis_v05/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
                   (1333, 768), (1333, 800)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='LVISV05Dataset',
        ann_file='data/lvis_v05/annotations/lvis_v0.5_train.json',
        img_prefix='data/lvis_v05/train2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
            dict(
                type='Resize',
                img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
                           (1333, 768), (1333, 800)],
                multiscale_mode='value',
                keep_ratio=True),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(
                type='Collect',
                keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'])
        ]),
    val=dict(
        type='LVISV05Dataset',
        ann_file='data/lvis_v05/annotations/lvis_v0.5_val.json',
        img_prefix='data/lvis_v05/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='LVISV05Dataset',
        ann_file='data/lvis_v05/annotations/lvis_v0.5_val.json',
        img_prefix='data/lvis_v05/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
evaluation = dict(metric=['bbox', 'segm'], interval=12)
optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
checkpoint_config = dict(interval=2)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
model = dict(
    type='SOLO',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    mask_head=dict(
        type='SOLOHead',
        num_classes=1230,
        in_channels=256,
        stacked_convs=7,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    test_cfg=dict(
        nms_pre=500,
        score_thr=0.0001,  # LVIS setting
        mask_thr=0.5,
        filter_thr=0.05,
        kernel='gaussian',
        sigma=2.0,
        max_per_img=300))  # LVIS setting
work_dir = None
gpu_ids = range(0, 1)

@BIGWangYuDong (Collaborator)

SOLO needs more GPU memory at inference time because it uses matrix_nms. A possible workaround is to move that computation from GPU to CPU:
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/post_processing/matrix_nms.py
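
A minimal sketch of that GPU-to-CPU idea (an assumption about where to apply it, not the actual fix), assuming flatten_masks is the (N, H*W) float tensor built inside matrix_nms.py:

import torch

def inter_matrix_on_cpu(flatten_masks: torch.Tensor) -> torch.Tensor:
    # Build the N x N intersection matrix on the CPU, then move the
    # (comparatively small) result back to the original device.
    masks_cpu = flatten_masks.cpu()
    inter = torch.mm(masks_cpu, masks_cpu.transpose(1, 0))
    return inter.to(flatten_masks.device)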

@sanmulab (Author)

I don't understand why. Testing usually should not need this much GPU memory; it mostly uses system memory. Why would matrix_nms need hundreds of gigabytes of GPU memory? Is this a problem with my environment or with the model?

@BIGWangYuDong (Collaborator)

It is due to this line:

inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))

It seems like a model problem.
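
As a rough illustration (not from the thread), assuming flatten_masks stores N candidate masks flattened to H*W float32 values each, the memory touched by that line scales as follows:

# Back-of-envelope helper (illustration only): the torch.mm above reads
# N*H*W*4 bytes of masks and writes an N x N float32 inter_matrix.
# With a very low score_thr the candidate count N can explode.
def matrix_nms_bytes(num_masks, mask_h, mask_w):
    input_bytes = num_masks * mask_h * mask_w * 4   # flatten_masks itself
    output_bytes = num_masks * num_masks * 4        # inter_matrix (N x N)
    return input_bytes, output_bytes

# Example: 456 masks at 200 x 304 (roughly the (456, 60800) tensor used in the
# benchmark below) versus 50,000 low-score candidates at the same resolution.
print(matrix_nms_bytes(456, 200, 304))     # ~111 MB masks, ~0.8 MB inter_matrix
print(matrix_nms_bytes(50000, 200, 304))   # ~12 GB masks, ~10 GB inter_matrix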

@sanmulab (Author)

So how can I solve it? I only changed the dataset and adjusted the score threshold and the maximum number of detections per image in the test configuration. I will try to retrain the model.

@sanmulab (Author)

> It is due to this line:
>
> inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
>
> It seems like a model problem.

Hello, I want to ask: how should I write test_cfg if I don't use matrix_nms?

ZwwWayne added the "enhancement (New feature or request)" label on Mar 3, 2022
@ZwwWayne (Collaborator) commented Mar 3, 2022

This part can be optimized with torch.einsum, which should save some memory. You could try rewriting this line with torch.einsum(). @BIGWangYuDong will create a PR to fix this in the next release.

@BIGWangYuDong (Collaborator)

@ZwwWayne I have compared torch.einsum and torch.mm; it does not seem to save memory.

I initialized a torch.Tensor of size (456, 60800) and ran 2,000 iterations (500 warmup iterations). Then I checked the following code:

inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
inter_matrix = torch.einsum('ik, kj -> ij', flatten_masks, flatten_masks.transpose(1, 0))

Here is the result:

                   torch.mm   torch.einsum
process time (s)   4.69856    4.5713
GPU memory (MiB)   807        807
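
For reference, a minimal sketch of how such a comparison could be reproduced (the tensor size, warmup, and iteration counts are taken from the description above; the timing and peak-memory helpers are an assumption, not the exact script used):

import time
import torch

flatten_masks = torch.rand(456, 60800, device='cuda')

def bench(fn, warmup=500, iters=2000):
    """Return (elapsed seconds, peak GPU memory in MiB) for repeated calls to fn."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.time() - start, torch.cuda.max_memory_allocated() / 1024 ** 2

print(bench(lambda: torch.mm(flatten_masks, flatten_masks.transpose(1, 0))))
print(bench(lambda: torch.einsum('ik, kj -> ij', flatten_masks,
                                 flatten_masks.transpose(1, 0))))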

@sanmulab (Author) commented Mar 8, 2022

So should I change:
inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
to
inter_matrix = torch.einsum(flatten_masks, flatten_masks.transpose(1, 0))
?

> @ZwwWayne I have compared torch.einsum and torch.mm; it does not seem to save memory.
>
> I initialized a torch.Tensor of size (456, 60800) and ran 2,000 iterations (500 warmup iterations). Then I checked the following code:
>
> inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
> inter_matrix = torch.einsum('ik, kj -> ij', flatten_masks, flatten_masks.transpose(1, 0))
>
> Here is the result:
>
>                    torch.mm   torch.einsum
> process time (s)   4.69856    4.5713
> GPU memory (MiB)   807        807

@BIGWangYuDong (Collaborator)

It seems you cannot save GPU memory by using torch.einsum. One suggestion is to use .cpu() to avoid this error. One more thing, @sanmulab: at which line did you hit the OOM?

@sanmulab (Author) commented Mar 9, 2022

Sorry, for some reason I haven't tried the torch.einsum() method yet. Previously I just used the SOLO framework to train on the LVIS dataset and then hit the GPU out-of-memory error when evaluating the model. With the Mask R-CNN framework I can train and evaluate normally.

@BIGWangYuDong (Collaborator)

Yep, I have also hit this error when using SOLO to test COCO instance segmentation (but only once or twice), and I never got an OOM error during training.

I have discussed this on the PyTorch forum. My error does not only report out of memory, but also another error: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling cublasCreate(handle).

So, if you still hit the OOM error, please give us more error details so that we can locate the problem.

@BIGWangYuDong (Collaborator)

Because I cannot reproduce this error right now, please give the following a try:

Change this line

inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))

to

inter_matrix = torch.einsum('ik, kj -> ij', flatten_masks, flatten_masks.transpose(1, 0))

@sanmulab (Author) commented Mar 9, 2022

OK, thank you!

@BIGWangYuDong (Collaborator)

@sanmulab In #7338 you can try AvoidOOM, which converts the tensor to FP16 or moves it to the CPU. The PR is not ready yet, but you can give it a try first.
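
A hypothetical usage sketch of that utility; the interface is assumed to match what later shipped in mmdet.utils as AvoidCUDAOOM (retry a call that hits CUDA OOM in FP16, then on the CPU), which may differ from the in-progress PR:

import torch
from mmdet.utils import AvoidCUDAOOM  # assumed import path once the PR is merged

flatten_masks = torch.rand(456, 60800, device='cuda')  # stand-in for the real masks

# On CUDA OOM, the wrapped torch.mm is retried in FP16 and, failing that, on the CPU.
inter_matrix = AvoidCUDAOOM.retry_if_cuda_oom(torch.mm)(
    flatten_masks, flatten_masks.transpose(1, 0))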

@sanmulab (Author)

Thanks, but I seem to have found the cause of the memory explosion. When I change score_thr in test_cfg from 0.0001 to 0.05, evaluation runs fine. Sadly, AP=0 (I only used epoch_6.pth for testing on the LVIS dataset).
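
For reference, the change described above as a config fragment (all other test_cfg values are unchanged from the config posted at the top of the issue):

model = dict(
    test_cfg=dict(
        nms_pre=500,
        score_thr=0.05,  # was 0.0001 (LVIS setting); keeps far fewer candidate masks
        mask_thr=0.5,
        filter_thr=0.05,
        kernel='gaussian',
        sigma=2.0,
        max_per_img=300))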

@BIGWangYuDong (Collaborator)

That is probably because you changed test_cfg; there may be no results with a score larger than 0.05. I'll close this issue; please feel free to open a new one if you still run into problems.
