Python objects without back references (suspecting memory leak) #905
Description
I am opening this issue because I suspect a memory leak. I am running Faster R-CNN inference on every image with 3 sets of model parameters. When I call the inference function repeatedly within the same process, the number of live objects grows steadily, although I expect objects that are no longer used to be freed once the function returns.
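To make the expectation concrete, here is a minimal stdlib-only sketch of the kind of measurement objgraph.show_growth performs: snapshot per-type counts of gc-tracked objects, call the function, and diff. FakeOp, fake_inference and track are hypothetical names, and the _registry list only simulates whatever might be keeping the real ops alive.

```python
import gc
from collections import Counter


def type_counts():
    """Count live gc-tracked objects by type name, after a full collection."""
    gc.collect()  # free unreachable cycles first so only live objects are counted
    return Counter(type(o).__name__ for o in gc.get_objects())


def growth(before, after):
    """Per-type increase between two snapshots (only types that grew)."""
    return {name: after[name] - before[name]
            for name in after if after[name] > before[name]}


class FakeOp:
    """Hypothetical stand-in for an operator object such as GenerateProposalsOp."""


_registry = []  # hypothetical: something outside the call keeps each op alive


def fake_inference():
    _registry.append(FakeOp())  # the simulated leak: the op outlives the call


def track(fn, runs, type_name):
    """Call fn `runs` times; return how many type_name objects each call left behind."""
    deltas = []
    baseline = type_counts()
    for _ in range(runs):
        fn()
        current = type_counts()
        deltas.append(growth(baseline, current).get(type_name, 0))
        baseline = current
    return deltas
```

With this sketch, track(fake_inference, 3, "FakeOp") reports one leftover FakeOp per call, while a function that drops everything it creates, such as track(lambda: FakeOp(), 3, "FakeOp"), reports zero per call.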
Expected results
- workspace.ResetWorkspace will remove all objects created, and gc.collect() will free the memory.
- Back references can be found for all objects created.
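The first expectation relies on gc.collect() reclaiming Python objects that are no longer reachable, including reference cycles. A weakref is a cheap way to confirm that one specific object really went away, rather than inferring it from counts. This is a minimal stdlib sketch; Node, make_cycle and freed_after_collect are hypothetical names. Note that gc.collect() cannot free anything that is still referenced, whether from Python or from native code.

```python
import gc
import weakref


class Node:
    """Hypothetical object that can take part in a reference cycle."""

    def __init__(self):
        self.partner = None


def make_cycle():
    """Build an a <-> b cycle, drop every strong reference, keep only a weakref."""
    a, b = Node(), Node()
    a.partner, b.partner = b, a  # refcounting alone can never free this pair
    return weakref.ref(a)        # a weak reference does not keep `a` alive


def freed_after_collect(ref):
    """True if the object behind `ref` is gone once the cycle collector has run."""
    gc.collect()
    return ref() is None
```

An unreachable cycle is reported as freed, while an object that is still bound to a name is not.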
Actual results
Some objects ("GenerateProposalsOp", "CollectAndDistributeFpnRpnProposalsOp") are not cleared after the inference function returns. They are created during model creation in 'test_engine.py', in particular by:
model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)
objgraph finds no back references to "CollectAndDistributeFpnRpnProposalsOp" after inference with 1, 2 and 3 models, using the following two lines of code:
obj = objgraph.by_type("CollectAndDistributeFpnRpnProposalsOp")
objgraph.show_backrefs(obj, max_depth=10, filename="collect_fpn_obj" + str(i) + ".png")
The result is similar for "GenerateProposalsOp" (no back references found).
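objgraph builds its back-reference graphs from gc.get_referrers(), and the Python documentation notes that this only locates containers that support garbage collection. A reference owned by C or C++ code (for example, a registry inside an extension module) is therefore invisible to it, yet it still keeps the object alive through gc.collect(). The CPython-only sketch below simulates such an owner by bumping the refcount with ctypes.pythonapi.Py_IncRef. Whether Caffe2 actually holds these Python ops from C++ is an assumption, not something this issue establishes; Probe, container_referrers and native_pin_demo are hypothetical names.

```python
import ctypes
import gc
import types
import weakref


class Probe:
    """Hypothetical stand-in for an operator instance."""


def container_referrers(obj):
    """gc-tracked objects that directly refer to obj, ignoring stack frames."""
    return [r for r in gc.get_referrers(obj)
            if not isinstance(r, types.FrameType)]


# Declare the two C-API functions used to simulate a reference owned by native code.
_incref = ctypes.pythonapi.Py_IncRef
_incref.argtypes = [ctypes.py_object]
_incref.restype = None
_decref = ctypes.pythonapi.Py_DecRef
_decref.argtypes = [ctypes.py_object]
_decref.restype = None


def native_pin_demo():
    """Keep a Probe alive via a raw refcount (no Python container), then inspect it."""
    probe = Probe()
    _incref(probe)             # simulated native owner: +1 refcount, no container
    ref = weakref.ref(probe)
    del probe                  # drop the only Python-level reference
    gc.collect()
    alive = ref() is not None  # still alive: the collector sees an outside reference
    n_referrers = len(container_referrers(ref()))  # no Python container points at it
    _decref(ref())             # undo the simulated native reference
    freed = ref() is None      # now nothing owns it, so it is gone
    return alive, n_referrers, freed
```

native_pin_demo() returns (True, 0, True): the object survives collection with zero visible referrers until the simulated native reference is dropped. An object held in an ordinary list, by contrast, shows that list among its referrers.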

And here's the growth of objects:
===== model: 0 =====
function 16231 +16231
list 8262 +8262
dict 6896 +6896
tuple 5837 +5837
weakref 3271 +3271
method_descriptor 2540 +2540
wrapper_descriptor 2510 +2510
builtin_function_or_method 2408 +2408
getset_descriptor 2102 +2102
type 1924 +1924
cell 1899 +1899
set 1045 +1045
property 744 +744
module 690 +690
ModuleSpec 672 +672
SourceFileLoader 580 +580
member_descriptor 486 +486
Parameter 426 +426
staticmethod 372 +372
FontEntry 325 +325
WeakSet 299 +299
MovedAttribute 254 +254
classmethod 246 +246
And 213 +213
itemgetter 189 +189
OrderedDict 189 +189
Literal 151 +151
instancemethod 146 +146
MovedModule 130 +130
LinearSegmentedColormap 126 +126
===== model: 1 =====
list 9145 +883
method 115 +6
GenerateProposalsOp 10 +5
dict 6899 +3
Random 2 +1
CollectAndDistributeFpnRpnProposalsOp 2 +1
_RandomNameSequence 1 +1
===== model: 2 =====
list 10027 +882
method 121 +6
GenerateProposalsOp 15 +5
dict 6901 +2
CollectAndDistributeFpnRpnProposalsOp 3 +1
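To compare runs programmatically, a small parser for the show_growth dump written to mem_leak_suspect.txt may help. Read against the numbers above, every additional model adds exactly 5 GenerateProposalsOp, 1 CollectAndDistributeFpnRpnProposalsOp and 6 method objects. That 6 = 5 + 1 is consistent with one bound method surviving per leaked op, but this is an inference from the counts, not a confirmed cause. Likewise, 5 would match one GenerateProposalsOp per FPN RPN level if Detectron's defaults FPN.RPN_MIN_LEVEL = 2 and FPN.RPN_MAX_LEVEL = 6 apply (the config below does not override them; treat this as an assumption). parse_growth is a hypothetical helper, and its regular expressions assume the "name count +delta" layout shown above.

```python
import re

# Matches show_growth rows such as "GenerateProposalsOp 10 +5".
ROW = re.compile(r"^(\S+)\s+(\d+)\s+\+(\d+)$")
# Matches the section headers the script writes, e.g. "===== model: 1 =====".
HEADER = re.compile(r"^=+ model: (\d+) =+$")


def parse_growth(text):
    """Return {model_index: {type_name: (count, delta)}} from a growth dump."""
    result, current = {}, None
    for raw in text.splitlines():
        line = raw.strip()
        header = HEADER.match(line)
        if header:
            current = int(header.group(1))
            result[current] = {}
            continue
        row = ROW.match(line)
        if row and current is not None:
            result[current][row.group(1)] = (int(row.group(2)), int(row.group(3)))
    return result
```

Feeding the "model: 1" and "model: 2" sections above through parse_growth makes the per-model deltas easy to assert on or plot.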
Detailed steps to reproduce
ResNeXt-152-32x8d model downloaded from:
https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl
Configuration file:
MODEL:
  TYPE: generalized_rcnn
  CONV_BODY: FPN.add_fpn_ResNet152_conv5_body
  NUM_CLASSES: 2
  FASTER_RCNN: True
NUM_GPUS: 1
SOLVER:
  WEIGHT_DECAY: 0.0025
  LR_POLICY: steps_with_decay
  BASE_LR: 0.01
  GAMMA: 0.1
  MAX_ITER: 360000
  STEPS: [0, 240000, 320000]
FPN:
  FPN_ON: True
  MULTILEVEL_ROIS: True
  MULTILEVEL_RPN: True
RESNETS:
  STRIDE_1X1: False  # default True for MSRA; False for C2 or Torch models
  TRANS_FUNC: bottleneck_transformation
  NUM_GROUPS: 32
  WIDTH_PER_GROUP: 8
FAST_RCNN:
  ROI_BOX_HEAD: fast_rcnn_heads.add_roi_2mlp_head
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 7
  ROI_XFORM_SAMPLING_RATIO: 2
TRAIN:
  WEIGHTS: https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl
  DATASETS: ('coco_2014_train',)
  SCALES: (800,)
  MAX_SIZE: 1333
  BATCH_SIZE_PER_IM: 512
TEST:
  DATASETS: ('coco_2014_minival',)
  SCALE: 800
  MAX_SIZE: 1333
  NMS: 0.5
OUTPUT_DIR: .
Minimal reproducible code:
DIRECTORY_TO_WEIGHTS contains: 'model0.pkl', 'model1.pkl', 'model2.pkl'
DIRECTORY_TO_CONFIGS contains: 'config_model0.yaml', 'config_model1.yaml', 'config_model2.yaml'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import cv2  # NOQA (Must import before importing caffe2 due to bug in cv2)
import gc
import argparse
import objgraph

import detectron.utils.c2 as c2_utils
from caffe2.python import workspace
from detectron.core.config import assert_and_infer_cfg
from detectron.core.config import cfg
from detectron.core.config import merge_cfg_from_file
from detectron.utils.io import cache_url
from detectron.utils.setup_logging import setup_logging
import detectron.core.test_engine as infer_engine

c2_utils.import_detectron_ops()

# OpenCL may be enabled by default in OpenCV3; disable it because it's not
# thread safe and causes unwanted GPU memory allocations.
cv2.ocl.setUseOpenCL(False)


def parse_args():
    parser = argparse.ArgumentParser(description='End-to-end inference')
    # ADD SOME ARGUMENT HERE, DOES NOT AFFECT THE GROWTH OF OBJECTS #####
    return parser.parse_args()


def inference(args):
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    merge_cfg_from_file(args.cfg)
    cfg.NUM_GPUS = 1
    args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
    assert_and_infer_cfg(cache_urls=False, make_immutable=False)
    assert not cfg.MODEL.RPN_ONLY, \
        'RPN models are not supported'
    assert not cfg.TEST.PRECOMPUTED_PROPOSALS, \
        'Models that require precomputed proposals are not supported'
    model = infer_engine.initialize_model_from_cfg(args.weights)
    # MODEL INFERENCE HERE, DOES NOT AFFECT THE GROWTH OF OBJECTS #####
    workspace.ResetWorkspace('.')
    gc.collect()


def main(args):
    model_names = ["model0", "model1", 'model2']
    config_dir = "DIRECTORY_TO_CONFIGS"
    weight_dir = "DIRECTORY_TO_WEIGHTS"
    f = open("mem_leak_suspect.txt", "w")
    objgraph.show_most_common_types(limit=30, file=f)
    f.writelines("\n")
    for i in range(len(model_names)):
        args.cfg = os.path.join(config_dir, 'config_' + model_names[i] + '.yaml')
        args.weights = os.path.join(weight_dir, model_names[i] + '.pkl')
        inference(args)
        f.writelines("\n===== model: {} =====\n".format(str(i)))
        objgraph.show_growth(limit=30, file=f)
        obj = objgraph.by_type("GenerateProposalsOp")
        objgraph.show_backrefs(obj, max_depth=10, filename="proposalop_obj" + str(i) + ".png")
        obj = objgraph.by_type("CollectAndDistributeFpnRpnProposalsOp")
        objgraph.show_backrefs(obj, max_depth=10, filename="collect_fpn_obj" + str(i) + ".png")
        f.flush()
    f.close()
    # POST-PROCESS HERE, BUT DOES NOT AFFECT THE GROWTH OF OBJECTS #####


if __name__ == '__main__':
    setup_logging(__name__)
    args = parse_args()
    main(args)
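One thing that may be worth ruling out in the script above (an observation about the reproduction code, not a confirmed cause): objgraph.by_type returns a list holding strong references to the objects it finds, and the last such list assigned to obj in each iteration (the CollectAndDistributeFpnRpnProposalsOp one) is still bound when the next iteration's show_growth runs. This cannot explain the GenerateProposalsOp growth, because that list is overwritten within the same iteration, but it can inflate the CollectAndDistributeFpnRpnProposalsOp count. Adding del obj after the last show_backrefs call is a cheap way to check. The sketch below uses a hypothetical stdlib by_type in place of objgraph's and a hypothetical FakeProposalOp class.

```python
import gc


class FakeProposalOp:
    """Hypothetical stand-in for GenerateProposalsOp and similar op objects."""


def by_type(type_name):
    """Rough stdlib analogue of objgraph.by_type: live gc-tracked objects of that type."""
    return [o for o in gc.get_objects() if type(o).__name__ == type_name]


def live_count(type_name):
    gc.collect()
    return len(by_type(type_name))


def snapshot_pitfall_demo():
    owner = [FakeProposalOp(), FakeProposalOp()]  # ops alive while a "model" exists
    obj = by_type("FakeProposalOp")               # snapshot kept for inspection
    owner.clear()                                 # "reset": the owner lets go
    while_held = live_count("FakeProposalOp")     # 2: the snapshot list keeps them alive
    del obj                                       # release the snapshot before measuring
    after_del = live_count("FakeProposalOp")      # 0: nothing else referenced them
    return while_held, after_del
```

snapshot_pitfall_demo() returns (2, 0): the same objects look leaked while the inspection list is alive and disappear once it is deleted.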
System information
- Operating system: Ubuntu 18.04.2 LTS
- Compiler version: gcc version 7.4.0
- CUDA version: 10.0
- cuDNN version: 7.5.1
- NVIDIA driver version: 410.104
- GPU models (for all devices if they are not all the same): GeForce GTX 1080 Ti
- Python version (python --version output): Python 3.6.8 :: Anaconda, Inc.


