This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[T5] Support Distributed Training (#4434)
* t5 distributed

* skip unless gpu
klshuster committed Mar 21, 2022
1 parent b6a51af commit f036542
Showing 3 changed files with 36 additions and 2 deletions.
9 changes: 7 additions & 2 deletions parlai/agents/hugging_face/t5.py
@@ -54,10 +54,14 @@ def set_device(func):
     """
 
     def wrap(*args, **kwargs):
-        if torch.cuda.is_available():
+        self = args[0]
+        # self.paralleled indicates whether the model has been parallelized;
+        # it is set to the opposite of `opt['t5_model_parallel']`.
+        parallel = hasattr(self, 'paralleled') and not self.paralleled
+        if torch.cuda.is_available() and parallel:
             torch.cuda.set_device('cuda:0')
         ret = func(*args, **kwargs)
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and parallel:
             torch.cuda.set_device('cuda:0')
         return ret

@@ -293,6 +297,7 @@ def __init__(self, opt, dictionary):
         self.t5 = build_t5(opt)
         self.encoder = ParlaiT5Encoder(opt, self.t5.get_encoder(), self.pad_idx)
         self.decoder = ParlaiT5Decoder(opt, self.t5.get_decoder(), self.pad_idx)
+        self.paralleled = not opt['t5_model_parallel']
 
     @set_device
     def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
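For readers skimming the diff, the following is a minimal, self-contained sketch of how the decorator reads once this change is applied. The docstring and the functools.wraps call are illustrative additions rather than part of the ParlAI source; the gating logic mirrors the hunk above.

import functools

import torch


def set_device(func):
    """
    Pin work to cuda:0, but only when T5 model parallelism is in use.

    Sketch of the pattern in this commit: in ordinary data-parallel or
    distributed runs, the current rank's device is left untouched.
    """

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        self = args[0]
        # `self.paralleled` is set to `not opt['t5_model_parallel']`, so
        # `parallel` is True only when model parallelism was requested.
        parallel = hasattr(self, 'paralleled') and not self.paralleled
        if torch.cuda.is_available() and parallel:
            torch.cuda.set_device('cuda:0')
        ret = func(*args, **kwargs)
        if torch.cuda.is_available() and parallel:
            torch.cuda.set_device('cuda:0')
        return ret

    return wrap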
1 change: 1 addition & 0 deletions parlai/agents/rag/modules.py
@@ -535,6 +535,7 @@ def __init__(self, opt, dictionary, retriever_shared=None):
         super().__init__(opt, dictionary, retriever_shared)
         self.embedding_size = opt['t5'].model_dim
         self.t5 = opt.pop('t5', None)
+        self.paralleled = not opt['t5_model_parallel']
 
     @classmethod
     def build_encoder(
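The RAG T5 model gets the same attribute so that the shared set_device decorator sees `paralleled` here as well. A tiny illustration of the resulting gate, using a hypothetical stand-in class rather than the real ParlAI models:

class _FakeT5Model:
    """Hypothetical stand-in for the T5/RAG models (illustration only)."""

    def __init__(self, t5_model_parallel: bool):
        # Mirrors the attribute added in this commit.
        self.paralleled = not t5_model_parallel


def pins_to_cuda0(model) -> bool:
    # The same condition as the `parallel` gate in the decorator above.
    return hasattr(model, 'paralleled') and not model.paralleled


# Model parallelism requested: the decorator pins work to cuda:0.
assert pins_to_cuda0(_FakeT5Model(t5_model_parallel=True))
# Plain data-parallel / distributed runs: each rank keeps its own device.
assert not pins_to_cuda0(_FakeT5Model(t5_model_parallel=False))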
28 changes: 28 additions & 0 deletions tests/nightly/gpu/test_t5.py
@@ -34,6 +34,8 @@
 from parlai.utils.torch import padded_tensor
 from parlai.utils.testing import tempdir
 
+from tests.test_distributed import _AbstractTest
+
 
 device = 'cpu' if not torch.cuda.is_available() else 'cuda:0'
 
@@ -262,5 +264,31 @@ def test_t5_model_parallel(self):
         )
 
 
+@testing_utils.skipUnlessGPU
+class TestT5Distributed(_AbstractTest):
+    base_config = dict(
+        task='integration_tests:overfit',
+        model='hugging_face/t5',
+        optimizer='adam',
+        batchsize=1,
+        num_epochs=50,
+        short_final_eval=True,
+        validation_max_exs=12,
+        t5_model_arch='t5-small',
+        validation_metric='ppl',
+        skip_generation=True,
+        learningrate=1e-2,
+        validation_every_n_epochs=25,
+        verbose=True,
+        save_after_valid=False,
+    )
+
+    def test_t5_distributed(self):
+        valid, test = self._distributed_train_model()
+
+        self.assertLessEqual(valid['ppl'], 1.60)
+        self.assertLessEqual(test['ppl'], 1.60)
+
+
 if __name__ == '__main__':
     unittest.main()

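The new test reuses the distributed-training harness from tests/test_distributed.py and is skipped on machines without GPUs. Assuming a multi-GPU box with pytest available (an assumption; the nightly GPU suite is normally driven by CI), it can be run in isolation roughly as follows:

import subprocess

# Run only the new distributed T5 test from the nightly GPU suite.
subprocess.run(
    ['pytest', 'tests/nightly/gpu/test_t5.py', '-k', 'test_t5_distributed'],
    check=True,
)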