
fix problem of dataparallel #155

Merged
8 commits merged into espnet:master on May 7, 2018

Conversation


@bobchennan (Contributor) commented May 4, 2018

PyTorch's default DataParallel tries to split each element of the input list in order to support multiple inputs.
In our case each element is a JSON object corresponding to one sample.
We want to split the list itself instead.
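For illustration, a minimal sketch of the idea (an assumption about the shape of the fix, not necessarily the exact code merged here): subclass torch.nn.DataParallel and override scatter so that the list of sample objects is chunked as a whole, one sublist per GPU, instead of being recursed into element by element.

```python
import torch


class DataParallel(torch.nn.DataParallel):
    """Split a list of per-sample JSON/dict objects across GPUs as sublists."""

    def scatter(self, inputs, kwargs, device_ids):
        # `inputs` is the positional-args tuple passed to forward(); its first
        # element is the list of samples (x). Chunk that list per device and
        # leave each sample object untouched.
        samples = inputs[0]
        chunk_size = -(-len(samples) // len(device_ids))  # ceil division
        chunks = [samples[i:i + chunk_size]
                  for i in range(0, len(samples), chunk_size)]
        # Note: a small batch can yield fewer chunks than devices, which is
        # exactly the situation behind the AssertionError reported below.
        scattered_inputs = [(chunk,) for chunk in chunks]
        scattered_kwargs = [dict(kwargs) for _ in chunks]
        return scattered_inputs, scattered_kwargs
```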


@sw005320 (Contributor) commented May 4, 2018

It seems good.
My only concern is where we should put `class DataParallel(torch.nn.DataParallel):`.
You actually put it in asr_pytorch.py.
See `class ChainerMultiProcessParallelUpdaterKaldi(training.updaters.MultiprocessParallelUpdater):` in asr_chainer.py; the multi-GPU related classes are in asr_chainer.py, not in e2e_asr_attctc.py.

@bobchennan (Contributor, Author) commented:

OK, I changed it.

@sw005320 merged commit f96d2a6 into espnet:master on May 7, 2018

@sw005320 (Contributor) commented May 8, 2018

I got the following error:

Traceback (most recent call last):
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/chainer/training/trainer.py", line 306, in run
    update()
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/chainer/training/updaters/standard_updater.py", line 149, in update
    self.update_core()
  File "/export/a08/shinji/201707e2e/espnet_dev2/src/asr/asr_pytorch.py", line 113, in update_core
    loss = 1. / self.num_gpu * self.model(x)
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/export/a08/shinji/201707e2e/espnet_dev2/src/asr/asr_pytorch.py", line 155, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 83, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/parallel/parallel_apply.py", line 24, in parallel_apply
    assert len(modules) == len(inputs)
Will finalize trainer extensions and updater before reraising the exception.
Traceback (most recent call last):
  File "/export/a08/shinji/201707e2e/espnet_dev2/egs/librispeech/asr1/../../../src/bin/asr_train.py", line 196, in <module>
    main()
  File "/export/a08/shinji/201707e2e/espnet_dev2/egs/librispeech/asr1/../../../src/bin/asr_train.py", line 190, in main
    train(args)
  File "/export/a08/shinji/201707e2e/espnet_dev2/src/asr/asr_pytorch.py", line 332, in train
    trainer.run()
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/chainer/training/trainer.py", line 320, in run
    six.reraise(*sys.exc_info())
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/chainer/training/trainer.py", line 306, in run
    update()
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/chainer/training/updaters/standard_updater.py", line 149, in update
    self.update_core()
  File "/export/a08/shinji/201707e2e/espnet_dev2/src/asr/asr_pytorch.py", line 113, in update_core
    loss = 1. / self.num_gpu * self.model(x)
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/export/a08/shinji/201707e2e/espnet_dev2/src/asr/asr_pytorch.py", line 155, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 83, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/export/a08/shinji/201707e2e/espnet_dev2/tools/venv/local/lib/python2.7/site-packages/torch/nn/parallel/parallel_apply.py", line 24, in parallel_apply
    assert len(modules) == len(inputs)
AssertionError
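The assertion at parallel_apply.py line 24 means the number of module replicas differs from the number of scattered input chunks, which can happen, for example, when a (last) minibatch splits into fewer sublists than there are GPUs while the model is still replicated onto every device. Below is a hedged sketch of the kind of guard that avoids this, modeled on torch.nn.DataParallel.forward; it is an illustration, not necessarily the fix that was applied later.

```python
import torch


class DataParallel(torch.nn.DataParallel):
    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # Trim device_ids to the number of input chunks so that
        # len(replicas) == len(inputs) and parallel_apply's
        # `assert len(modules) == len(inputs)` cannot fire.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)
```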

@bobchennan mentioned this pull request on May 10, 2018