
Problems in training MJ + ST training set using robustscanner algorithm #508

Closed
GaoXinJian-USTC opened this issue Sep 26, 2021 · 18 comments

@GaoXinJian-USTC

At the beginning of training, the value of data_time is small, but a few minutes later data_time gradually increases, which leads to low GPU utilization.
[training log screenshots showing data_time increasing]

@GaoXinJian-USTC
Author

GaoXinJian-USTC commented Sep 26, 2021

I have trained different models on the MJ + ST datasets and tried setting num_workers to 0, 2, 4, and 8, but got the same problem.

@GaoXinJian-USTC
Author

environment:
sys.platform: linux
Python: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0]
CUDA available: True
GPU 0,1: A100-PCIE-40GB
CUDA_HOME: /usr/local/cuda-11.1
NVCC: Build cuda_11.1.TC455_06.29190527_0
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.9.0+cu111
PyTorch compiling details: PyTorch built with:

  • GCC 7.3
  • C++ Version: 201402
  • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 11.1
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  • CuDNN 8.0.5
  • Magma 2.5.2
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

TorchVision: 0.10.0+cu111
OpenCV: 4.5.3
MMCV: 1.3.14
MMCV Compiler: GCC 7.5
MMCV CUDA Compiler: 11.1
MMOCR: 0.3.0+0bc1362

@GaoXinJian-USTC
Author

[training log screenshot]

@gaotongxiao
Collaborator

I suspect that it has something to do with a specific part of data. You can rerun the training process several times and check if the slowdown occurs at certain batches. Also, could you share the config that you have been using?
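For example, you could time the data pipeline in isolation to see whether particular batches are slow. This is only a rough sketch, assuming `dataset` is an already-built Dataset object matching your config (batch size and worker count below mirror your training settings):

import time
from torch.utils.data import DataLoader

def time_batches(dataset, batch_size=128, num_workers=16, limit=500):
    # Iterate the dataset alone (no model) and report unusually slow batches.
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                        shuffle=True, collate_fn=lambda samples: samples)
    t = time.time()
    for i, _ in enumerate(loader):
        dt = time.time() - t  # time spent fetching this batch
        if dt > 1.0:  # threshold for "slow"; tune as needed
            print(f'batch {i}: data_time={dt:.2f}s')
        if i >= limit:
            break
        t = time.time()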

@GaoXinJian-USTC
Author

> I suspect that it has something to do with a specific part of data. You can rerun the training process several times and check if the slowdown occurs at certain batches. Also, could you share the config that you have been using?

Of course. Thank you very much for your help. My configuration file is below

@GaoXinJian-USTC
Author

GaoXinJian-USTC commented Sep 26, 2021

_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/recog_models/robust_scanner.py'
]

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=160,
        keep_aspect_ratio=True,
        width_downsample_ratio=0.25),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=160,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'

train_prefix = 'data/mixture/'

train_img_prefix1 = train_prefix + \
    'SynthText/synthtext/SynthText_patch_horizontal'
train_img_prefix2 = train_prefix + 'mjsynth/mnt/ramdisk/max/90kDICT32px'

train_ann_file1 = train_prefix + 'SynthText/label.lmdb'
train_ann_file2 = train_prefix + 'mjsynth/label.lmdb'

train1 = dict(
    type=dataset_type,
    img_prefix=train_img_prefix1,
    ann_file=train_ann_file1,
    loader=dict(
        type='LmdbLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=False)

train2 = {key: value for key, value in train1.items()}
train2['img_prefix'] = train_img_prefix2
train2['ann_file'] = train_ann_file2

test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'ic13/'
test_img_prefix4 = test_prefix + 'ic15/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'CUTE80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'ic13/test_label_1015.txt'
test_ann_file4 = test_prefix + 'ic15/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'CUTE80/lable.txt'

test1 = dict(
    type=dataset_type,
    img_prefix=test_img_prefix1,
    ann_file=test_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=True)

test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2

test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5

test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6

data = dict(
    samples_per_gpu=128,
    workers_per_gpu=16,
    val_dataloader=dict(samples_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=1),
    train=dict(
        type='UniformConcatDataset',
        datasets=[train1, train2],
        pipeline=train_pipeline),
    val=dict(
        type='UniformConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6],
        pipeline=test_pipeline),
    test=dict(
        type='UniformConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6],
        pipeline=test_pipeline))

evaluation = dict(interval=1, metric='acc')

@gaotongxiao
Collaborator

The config looks fine. So have you tested if specific data batches are leading to the slowdown? I went over our RobustScanner log file and noticed that the data time had also been fluctuating over the entire training process. You might keep on training and let us know if there is a significant difference between your log and ours.

@GaoXinJian-USTC
Author

GaoXinJian-USTC commented Sep 28, 2021

I changed the training set from ST + MJ to IIIT5K and the training went fine, so I guessed there was something wrong with my ST + MJ data. I downloaded the ST + MJ datasets again and processed them according to the documentation, but I still get the same problem: after a few minutes of training, the GPU power draw stays very low, and only occasionally reaches full power.
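In case it helps, here is a quick sanity check for the LMDB annotation files (a minimal sketch, assuming the `lmdb` Python package is installed; the path comes from the config above). It only prints the entry count and a few raw records, without assuming any MMOCR-specific key layout:

import lmdb

def inspect_lmdb(path, num_samples=5):
    # Open read-only and print the number of entries plus a few key/value pairs.
    env = lmdb.open(path, readonly=True, lock=False, readahead=False)
    with env.begin() as txn:
        print('entries:', txn.stat()['entries'])
        for i, (key, value) in enumerate(txn.cursor()):
            print(key, value[:80])
            if i + 1 >= num_samples:
                break
    env.close()

inspect_lmdb('data/mixture/SynthText/label.lmdb')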

@GaoXinJian-USTC
Author

What confuses me is that when the bug occurs, the GPU stays in a low power state even when the data loading time is very short, as shown in the figure.
[screenshot: training log with short data_time but low GPU power]

@GaoXinJian-USTC GaoXinJian-USTC changed the title Problem in traning Problems in training MJ + ST training set using robustscanner algorithm Sep 28, 2021
@gaotongxiao
Collaborator

It's probably because some cropped images can be much larger than others and result in a longer processing time. Let's keep monitoring it over several hours of running first.

Such a large gap between data_time and time may be due to the latency introduced by transferring data from memory to GPU. You may also track the utilization of your memory, CPU, and disk for a more in-depth analysis.
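A minimal utilization logger could look like this (just a sketch, assuming `psutil` is installed and `nvidia-smi` is on the PATH; it prints CPU, memory, disk read throughput, and GPU utilization once per second):

import subprocess
import time

import psutil

def log_utilization(interval=1.0):
    prev_io = psutil.disk_io_counters()
    while True:
        time.sleep(interval)
        io = psutil.disk_io_counters()
        read_mb = (io.read_bytes - prev_io.read_bytes) / 1e6  # MB read since last sample
        prev_io = io
        # Query per-GPU utilization via nvidia-smi (one number per GPU).
        gpu = subprocess.run(
            ['nvidia-smi', '--query-gpu=utilization.gpu',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True).stdout.split()
        print(f'cpu={psutil.cpu_percent():.0f}% '
              f'mem={psutil.virtual_memory().percent:.0f}% '
              f'disk_read={read_mb:.1f}MB/s '
              f'gpu={",".join(gpu)}%')

if __name__ == '__main__':
    log_utilization()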

@GaoXinJian-USTC
Author

GaoXinJian-USTC commented Sep 29, 2021

I have run the code all night and got this result. It shows that my GPU always stays in a low power state when the bug occurs. The bug does not occur when I train on other datasets, so I think there is a bug in the data preprocessing code.
[screenshots: GPU status and training log from the overnight run]

@gaotongxiao
Collaborator

Yep there must be something wrong. I suspect this PR causes the slowdown. Could you comment out the changes in mmocr/models/textrecog/recognizer/encode_decode_recognizer.py and rerun the training process?

@GaoXinJian-USTC
Author

> Yep there must be something wrong. I suspect this PR causes the slowdown. Could you comment out the changes in mmocr/models/textrecog/recognizer/encode_decode_recognizer.py and rerun the training process?

I went back to this version, retrained, and got the same error.

@wushilian

I got the same error too.

@GaoXinJian-USTC
Author

> I got the same error too.

Have you solved it? Which type of GPU do you use?

@gaotongxiao
Collaborator

I'm looking into this issue but it may take some time to find the solution. @wushilian could you also share more details to help us locate the problem? Specifically:

  1. The config that you were using;
  2. The log file(s) of your training process;
  3. Preferably, the profile results of the training script. To profile it and save the results to profile.prof:
python -m cProfile -o profile.prof tools/train.py PATH_TO_CONFIG

(@Gaoxj2020 it would be great if you can share the profile results with us :))
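Once you have profile.prof, the standard-library pstats module can show where the time goes, for example:

import pstats

stats = pstats.Stats('profile.prof')
stats.sort_stats('cumulative').print_stats(30)  # top 30 functions by cumulative time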

@GaoXinJian-USTC
Author

> I'm looking into this issue but it may take some time to find the solution. @wushilian could you also share more details to help us locate the problem? Specifically:
>
>   1. The config that you were using;
>   2. The log file(s) of your training process;
>   3. Preferably, the profile results of the training script. To profile it and save the results to profile.prof:
> python -m cProfile -o profile.prof tools/train.py PATH_TO_CONFIG
>
> (@Gaoxj2020 it would be great if you can share the profile results with us :))

I am sorry to report that I ran into the same bug when training on IIIT5K again. Now I suspect it is caused by a hardware/environment problem. Thank you very much for your help!

@gaotongxiao
Collaborator

I'm gonna close this but feel free to reopen it if you find anything worthy of reporting.
