Training speed of mxnet-ssd slows down? #135

Closed
nopattern opened this issue May 20, 2018 · 16 comments

@nopattern

nopattern commented May 20, 2018

I have used a record file (VOC07+12) to train the old-style SSD at a speed of about 40 images/s, but the speed is only about 25 images/s when I try the new train_ssd.py in gluoncv.
I used a rec dataset plus a transform to replace the original file-based datasets in the new SSD code. But when I set **num-workers=4**, the gdata.DetectionDataLoader fails, while with **num-workers=1** it works but is almost as slow as the original data reading method.
The error information is as follows:
Process Process-3:
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataloader.py", line 134, in worker_loop
    batch = batchify_fn([dataset[i] for i in samples])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 126, in __getitem__
    item = self._data[idx]
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/vision/datasets.py", line 257, in __getitem__
    record = super(ImageRecordDataset, self).__getitem__(idx)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 180, in __getitem__
    return self._record.read_idx(self._record.keys[idx])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 265, in read_idx
    return self.read()
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 163, in read
    ctypes.byref(size)))
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
MXNetError: [16:12:48] src/recordio.cc:65: Check failed: header[0] == RecordIOWriter::kMagic Invalid RecordIO File

It seems to be a multiprocessing problem with the old rec file dataset?

@zhreshold
Member

Can you elaborate on how you use the record file with the Gluon data loader?
Did you check out this one, which was added recently: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/data/recordio/detection.py
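
For reference, here is a minimal sketch of how that class is typically constructed and sanity-checked (the file path is a placeholder, not taken from this issue):

```python
from gluoncv.data import RecordFileDetection

# The .rec file must be in the detection format produced by im2rec / used by
# mxnet.image.ImageDetIter (header fields followed by per-object label fields).
train_dataset = RecordFileDetection('./data/train.rec')

# Inspect one sample before wiring the dataset into a DataLoader:
# img is an HWC image NDArray, label is a (num_objects, label_width) array.
img, label = train_dataset[0]
print(img.shape, label.shape)
```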

@nopattern
Author

I added a dataset transform function:
import numpy as np

def old_rec_2_new(rec_old):
    """Convert an old-style flat label
         rec_old: [2, 6, cls_id, xmin, ymin, xmax, ymax, difficult, ...]
    into the new per-object layout
         rec_new: [[xmin, ymin, xmax, ymax, cls_id, difficult], ...]
    """
    label_width = int(rec_old[1])
    obj_num = (rec_old.shape[0] - 2) // label_width  # number of objects in this record
    rec_new = np.zeros((obj_num, label_width))
    offset = 2  # skip the 2-element header
    for i in range(obj_num):
        rec_new[i][0:4] = rec_old[offset + 1:offset + 5]  # box coordinates
        rec_new[i][4] = rec_old[offset]                   # class id
        rec_new[i][5] = rec_old[offset + 5]               # difficult flag
        offset += label_width
    return rec_new

and read the record file:

def get_dataset_rec(dataset):
    data_dir = './data/'
    train_dataset = ImageRecordDataset(data_dir + 'train.rec', transform=rec_old2_new_transform)
    val_dataset = ImageRecordDataset(data_dir + 'val.rec', transform=rec_old2_new_transform)
    return train_dataset, val_dataset
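
The `rec_old2_new_transform` name above isn't defined in the snippets; presumably it wraps `old_rec_2_new` to match the two-argument `transform(image, label)` hook that `ImageRecordDataset` calls. A minimal sketch of such a wrapper, under that assumption (the wrapper name and signature are guesses, not the exact code used):

```python
import numpy as np
from mxnet.gluon.data.vision import ImageRecordDataset

def rec_old2_new_transform(img, label):
    # Keep the decoded image as-is; convert the flat old-style label with
    # old_rec_2_new() defined above.
    return img, old_rec_2_new(np.asarray(label))

train_dataset = ImageRecordDataset('./data/train.rec',
                                   transform=rec_old2_new_transform)
```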

I pulled the new code and changed the function to use RecordFileDetection:

def get_dataset_rec(dataset):
    data_dir = './data/'
    train_dataset = RecordFileDetection(data_dir + 'train.rec')
    val_dataset = RecordFileDetection(data_dir + 'val.rec')
    return train_dataset, val_dataset

The error info is:

[09:09:14] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Traceback (most recent call last):
  File "train_ssd.py", line 238, in <module>
    train(net, train_data, val_data, classes, args)
  File "train_ssd.py", line 182, in train
    autograd.backward(sum_loss)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/autograd.py", line 267, in backward
    ctypes.c_void_p(0)))
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [09:09:22] src/imperative/imperative.cc:373: Check failed: !AGInfo::IsNone(*i) Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward.

Stack trace returned 10 entries:
[bt] (0) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::StackTraceabi:cxx11+0x5b) [0x7fd6a6077b6b]
[bt] (1) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7fd6a60786d8]
[bt] (2) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::Imperative::Backward(std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, bool, bool, bool)+0x291) [0x7fd6a858ded1]
[bt] (3) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(MXAutogradBackwardEx+0x9a8) [0x7fd6a893aae8]
[bt] (4) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fd6b6514e40]
[bt] (5) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7fd6b65148ab]
[bt] (6) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7fd6b67243df]
[bt] (7) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(+0x11d82) [0x7fd6b6728d82]
[bt] (8) python(PyObject_Call+0x43) [0x4b0c93]
[bt] (9) python(PyEval_EvalFrameEx+0x602f) [0x4c9f9f]

@nopattern nopattern reopened this May 21, 2018
@zhreshold
Member

The latest error seems to be unrelated to the data loader. Did you change the network training part?

@nopattern
Author

nopattern commented May 21, 2018

Sorry, the earlier error may have been caused by gluoncv not being updated correctly. I updated gluoncv again, and the error is as follows:

Corrupt JPEG data: premature end of data segment
Process Process-4:
Process Process-2:
(both worker processes fail with the same traceback)
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataloader.py", line 134, in worker_loop
    batch = batchify_fn([dataset[i] for i in samples])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 126, in __getitem__
    item = self._data[idx]
  File "build/bdist.linux-x86_64/egg/gluoncv/data/recordio/detection.py", line 37, in __getitem__
    img, label = super(RecordFileDetection, self).__getitem__(idx)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/vision/datasets.py", line 257, in __getitem__
    record = super(ImageRecordDataset, self).__getitem__(idx)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 180, in __getitem__
    return self._record.read_idx(self._record.keys[idx])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 265, in read_idx
    return self.read()
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 163, in read
    ctypes.byref(size)))
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
MXNetError: [12:09:26] src/recordio.cc:65: Check failed: header[0] == RecordIOWriter::kMagic Invalid RecordIO File

Stack trace returned 10 entries:
[bt] (0) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::StackTraceabi:cxx11+0x5b) [0x7fdc39641b6b]
[bt] (1) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::RecordIOReader::NextRecord(std::__cxx11::basic_string<char, std::char_traits, std::allocator >*)+0x6ec) [0x7fdc3bfe941c]
[bt] (2) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(MXRecordIOReaderReadRecord+0x26) [0x7fdc3bf2e236]
[bt] (3) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fdc49adee40]
[bt] (4) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7fdc49ade8ab]
[bt] (5) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7fdc49cee3df]
[bt] (6) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(+0x11d82) [0x7fdc49cf2d82]
[bt] (7) python(PyObject_Call+0x43) [0x4b0c93]
[bt] (8) python(PyEval_EvalFrameEx+0x602f) [0x4c9f9f]
[bt] (9) python(PyEval_EvalCodeEx+0x255) [0x4c2705]

@nopattern
Author

Trying again:
The previous error happened when I set num-workers = 1. The "Invalid RecordIO File" error happened when num-workers = 4.

@zhreshold
Member

Can you use num-workers=0 to disable multiprocessing and make sure the record file is good?
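
A minimal sketch of that check, assuming the RecordFileDetection dataset built from ./data/train.rec as in the snippets above: iterate every record in the main process, so a corrupt record fails with a clean stack trace instead of inside a worker process.

```python
from gluoncv.data import RecordFileDetection

dataset = RecordFileDetection('./data/train.rec')

# Decoding happens in __getitem__, so touching every index in the main
# process (equivalent to num-workers=0) exercises the whole record file.
for i in range(len(dataset)):
    img, label = dataset[i]
print('decoded %d records without error' % len(dataset))
```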

@zhreshold
Member

I suspect it might be related to the multi-worker loading, but I cannot confirm it.

@nopattern
Author

nopattern commented May 21, 2018

The record file is good, because I have tested it with my own transform with num-workers = 1 and num-workers = 0.
The output is as follows when setting num-workers = 0 and using RecordFileDetection:

Traceback (most recent call last):
  File "train_ssd.py", line 271, in <module>
    train(net, train_data, val_data, classes, args)
  File "train_ssd.py", line 214, in train
    autograd.backward(sum_loss)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/autograd.py", line 267, in backward
    ctypes.c_void_p(0)))
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [07:37:51] src/imperative/imperative.cc:373: Check failed: !AGInfo::IsNone(*i) Cannot differentiate node because it is not in a computational graph. You need to set is_recording to true or use autograd.record() to save computational graphs for backward. If you want to differentiate the same graph twice, you need to pass retain_graph=True to backward.

Stack trace returned 10 entries:
[bt] (0) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::StackTraceabi:cxx11+0x5b) [0x7f8bfe309b6b]
[bt] (1) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x28) [0x7f8bfe30a6d8]
[bt] (2) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::Imperative::Backward(std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, std::vector<mxnet::NDArray*, std::allocatormxnet::NDArray* > const&, bool, bool, bool)+0x291) [0x7f8c0081fed1]
[bt] (3) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(MXAutogradBackwardEx+0x9a8) [0x7f8c00bccae8]
[bt] (4) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f8c0e7a6e40]
[bt] (5) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7f8c0e7a68ab]
[bt] (6) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7f8c0e9b63df]
[bt] (7) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(+0x11d82) [0x7f8c0e9bad82]
[bt] (8) python(PyObject_Call+0x43) [0x4b0c93]
[bt] (9) python(PyEval_EvalFrameEx+0x602f) [0x4c9f9f]

@zhreshold
Member

I am confused. Can you post your training script?

@nopattern
Author

OK, but I haven't changed the training script:

def train(net, train_data, val_data, classes, args):
   """Training pipeline"""
   net.collect_params().reset_ctx(ctx)
   trainer = gluon.Trainer(
       net.collect_params(), 'sgd',
       {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})

   # lr decay policy
   lr_decay = float(args.lr_decay)
   lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])

   mbox_loss = gcv.loss.SSDMultiBoxLoss()
   ce_metric = mx.metric.Loss('CrossEntropy')
   smoothl1_metric = mx.metric.Loss('SmoothL1')

   # set up logger
   logging.basicConfig()
   logger = logging.getLogger()
   logger.setLevel(logging.INFO)
   log_file_path = args.save_prefix + '_train.log'
   log_dir = os.path.dirname(log_file_path)
   if log_dir and not os.path.exists(log_dir):
       os.makedirs(log_dir)
   fh = logging.FileHandler(log_file_path)
   logger.addHandler(fh)
   logger.info(args)
   logger.info('Start training from [Epoch %d]' % args.start_epoch)
   best_map = [0]
   for epoch in range(args.start_epoch, args.epochs):
       while lr_steps and epoch >= lr_steps[0]:
           new_lr = trainer.learning_rate * lr_decay
           lr_steps.pop(0)
           trainer.set_learning_rate(new_lr)
           logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
       ce_metric.reset()
       smoothl1_metric.reset()
       tic = time.time()
       btic = time.time()
       net.hybridize()
       for i, batch in enumerate(train_data):
           batch_size = batch[0].shape[0]
           data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
           cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
           box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
           with autograd.record():
               cls_preds = []
               box_preds = []
               for x in data:
                   cls_pred, box_pred, _ = net(x)
                   cls_preds.append(cls_pred)
                   box_preds.append(box_pred)
               sum_loss, cls_loss, box_loss = mbox_loss(
                   cls_preds, box_preds, cls_targets, box_targets)
               autograd.backward(sum_loss)
           # since we have already normalized the loss, we don't want to normalize
           # by batch-size anymore
           trainer.step(1)
           ce_metric.update(0, [l * batch_size for l in cls_loss])
           smoothl1_metric.update(0, [l * batch_size for l in box_loss])
           if args.log_interval and not (i + 1) % args.log_interval:
               name1, loss1 = ce_metric.get()
               name2, loss2 = smoothl1_metric.get()
               logger.info('[Epoch %d][Batch %d], Speed: %f samples/sec, %s=%f, %s=%f'%(
                   epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
           btic = time.time()

       name1, loss1 = ce_metric.get()
       name2, loss2 = smoothl1_metric.get()
       logger.info('[Epoch %d] Training cost: %f, %s=%f, %s=%f'%(
           epoch, (time.time()-tic), name1, loss1, name2, loss2))
       map_name, mean_ap = validate(net, val_data, ctx, classes)
       val_msg = '\n'.join(['%s=%f'%(k, v) for k, v in zip(map_name, mean_ap)])
       logger.info('[Epoch %d] Validation: \n%s'%(epoch, val_msg))
       save_params(net, best_map, mean_ap[-1], epoch, args.save_interval, args.save_prefix)

I wrote the classes directly in the main function:

if __name__ == '__main__':
    
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
               'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
               
    args = parse_args()
    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]

    # network
    net_name = '_'.join(('ssd', str(args.data_shape), args.network, args.dataset))
    net = get_model(net_name, pretrained_base=True)
    if args.resume.strip():
        net.load_params(args.resume.strip())
    else:
        for param in net.collect_params().values():
            if param._data is not None:
                continue
            param.initialize()

    # training data
    #train_dataset, val_dataset = get_dataset(args.dataset)
    #train_dataset, val_dataset = get_dataset_rec(args.dataset)
    train_dataset, val_dataset = get_dataset_rec_me(args.dataset)
    train_data, val_data = get_dataloader(
        net, train_dataset, val_dataset, args.data_shape, args.batch_size, args.num_workers)
    classes = CLASSES#train_dataset.classes  # class names

    # training
    args.save_prefix += net_name
    train(net, train_data, val_data, classes, args)

@nopattern
Author

@zhreshold The previous error occurred because I had changed the position format from float to int (original pixel coordinates, not normalized) in the record file for my transformer, sorry.
Now I have restored the record file to the original format. num-workers = 1 works, but the speed drops to 12 images/s; num-workers = 4 doesn't work. The output info:

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 40 extraneous bytes before marker 0xdb
Corrupt JPEG data: premature end of data segment
Process Process-2:
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataloader.py", line 134, in worker_loop
    batch = batchify_fn([dataset[i] for i in samples])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 126, in __getitem__
    item = self._data[idx]
  File "build/bdist.linux-x86_64/egg/gluoncv/data/recordio/detection.py", line 37, in __getitem__
    img, label = super(RecordFileDetection, self).__getitem__(idx)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/vision/datasets.py", line 257, in __getitem__
    record = super(ImageRecordDataset, self).__getitem__(idx)
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/gluon/data/dataset.py", line 180, in __getitem__
    return self._record.read_idx(self._record.keys[idx])
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 265, in read_idx
    return self.read()
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/recordio.py", line 163, in read
    ctypes.byref(size)))
  File "/home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/base.py", line 149, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
MXNetError: [09:24:49] src/recordio.cc:65: Check failed: header[0] == RecordIOWriter::kMagic Invalid RecordIO File

Stack trace returned 10 entries:
[bt] (0) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::StackTraceabi:cxx11+0x5b) [0x7fe3e9440b6b]
[bt] (1) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::RecordIOReader::NextRecord(std::__cxx11::basic_string<char, std::char_traits, std::allocator >*)+0x6ec) [0x7fe3ebde841c]
[bt] (2) /home/deep/workssd/mxnet/incubator-mxnet/python/mxnet/../../lib/libmxnet.so(MXRecordIOReaderReadRecord+0x26) [0x7fe3ebd2d236]
[bt] (3) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fe3f9a1ee40]
[bt] (4) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x2eb) [0x7fe3f9a1e8ab]
[bt] (5) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(_ctypes_callproc+0x48f) [0x7fe3f9c2e3df]
[bt] (6) /usr/lib/python2.7/lib-dynload/_ctypes.x86_64-linux-gnu.so(+0x11d82) [0x7fe3f9c32d82]
[bt] (7) python(PyObject_Call+0x43) [0x4b0c93]
[bt] (8) python(PyEval_EvalFrameEx+0x602f) [0x4c9f9f]
[bt] (9) python(PyEval_EvalCodeEx+0x255) [0x4c2705]

@zhreshold
Member

Tracked to apache/mxnet#9974

@nopattern
Author

nopattern commented May 23, 2018

@zhreshold
I have installed the latest version of gluoncv, and it slows the SSD training speed from about 25 images/s down to about 15 images/s (unstable), while your original mxnet-ssd runs at about 40 images/s!
It would be better to add the expand image augmentation to the old ImageDetIter.
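
For context, the Gluon pipeline already exposes the expand augmentation as a standalone transform; a rough sketch of applying it to one sample follows (this reflects my understanding of the gluoncv transforms API, so treat the exact argument names and return values as an assumption):

```python
import mxnet as mx
import numpy as np
from gluoncv.data.transforms import image as timage, bbox as tbbox

img = mx.nd.zeros((300, 300, 3), dtype='uint8')    # placeholder image
boxes = np.array([[50., 60., 200., 220., 7.]])     # [xmin, ymin, xmax, ymax, cls_id]

# Paste the image at a random position on a larger filled canvas ("zoom out"),
# then shift the boxes by the same offset so they still match the image.
expanded, expand = timage.random_expand(img, max_ratio=4, fill=128)
boxes = tbbox.translate(boxes, x_offset=expand[0], y_offset=expand[1])
```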

@WalterMa
Contributor

@zhreshold
According to the issue you referenced, RecordFileDetection cannot be used in multiprocessing mode.
Am I right?

@zhreshold
Member

@WalterMa Yes. This bug should be easy to fix, but we need to be careful not to change the current API, so we are still discussing it.

@zhreshold
Member

A temporary solution has been added to RecordFileDetection so that multiple workers can be enabled.
I am closing this due to lack of activity. Feel free to ping me to reopen.
