This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Eliminate dummy batches and init_cuda_buffer. #3732

Merged
merged 1 commit on Jun 21, 2021
Changes from all commits
9 changes: 0 additions & 9 deletions parlai/agents/hred/hred.py
@@ -172,12 +172,3 @@ def _set_text_vec(self, obs, history, truncate):
truncated_vec = self._check_truncate(obs["text_vec"], truncate, True)
obs.force_set("text_vec", torch.LongTensor(truncated_vec))
return obs

def _dummy_batch(self, batchsize, maxlen):
"""
Overridden to add dummy context vec and hist lens.
"""
batch = super()._dummy_batch(batchsize, maxlen)
batch["context_vec"] = batch["text_vec"]
batch["hist_lens"] = torch.ones(batchsize, dtype=torch.long)
return batch
17 changes: 0 additions & 17 deletions parlai/agents/image_seq2seq/image_seq2seq.py
@@ -102,23 +102,6 @@ def _set_text_vec(self, *args, **kwargs) -> dict:
)
return obs

def _dummy_batch(self, batchsize: int, maxlen: int) -> Batch:
"""
Override to include image feats.
"""
b = super()._dummy_batch(batchsize, maxlen)
image = torch.ones(batchsize, self.image_features_dim).cuda()
if self.n_image_channels > 1:
image = image.unsqueeze(1).repeat(1, self.n_image_channels, 1)
if self.fp16:
image = image.half()
return Batch(
text_vec=b.text_vec,
label_vec=b.label_vec,
image=image,
personalities=torch.ones(batchsize, self.opt['embedding_size']).cuda(),
)

def batchify_image_features(self, batch: Batch) -> Batch:
"""
Format and return the batched image features.
21 changes: 3 additions & 18 deletions parlai/agents/rag/rag.py
@@ -328,19 +328,6 @@ def _encoder_input(
return self._model_input(batch)

##### 2. Standard TGA Function Overrides #####
def _dummy_batch(self, batchsize: int, maxlen: int) -> Batch:
"""
Add query/input turn vecs.
"""
batch = self._generation_agent._dummy_batch(self, batchsize, maxlen)
batch.query_vec = batch.text_vec.clone()
batch.input_turn_cnt_vec = (
None
if self.rag_model_type != 'turn'
else torch.ones(batch.query_vec.size(0)).to(batch.query_vec)
)
return batch

def build_model(self) -> RagModel:
"""
Build and return RagModel.
@@ -407,11 +394,9 @@ def _should_override_dpr_model_weights(self, opt: Opt):
"""
Determine if we need to override the DPR Model weights.

Under certain circumstances, one may wish to specify a different
`--dpr-model-file` for a pre-trained, RAG model. Thus, we additionally
check to make sure that the loaded DPR model weights are not overwritten
by the state loading.

Under certain circumstances, one may wish to specify a different `--dpr-model-
file` for a pre-trained, RAG model. Thus, we additionally check to make sure
that the loaded DPR model weights are not overwritten by the state loading.
"""
override_dpr = False
overrides = opt.get('override', {})
84 changes: 30 additions & 54 deletions parlai/core/torch_generator_agent.py
@@ -578,59 +578,36 @@ def set_interactive_mode(self, mode, shared=False):
else:
self.skip_generation = self.opt.get('skip_generation', False)

    def _dummy_batch(self, batchsize, maxlen):
        """
        Create a dummy batch.

        This is used to preinitialize the cuda buffer, or otherwise force a
        null backward pass after an OOM.

        If your model uses additional inputs beyond text_vec and label_vec,
        you will need to override it to add additional fields.
        """
        text_vec = (
            torch.arange(1, maxlen + 1)  # need it as long as specified
            .clamp(max=3)  # cap at 3 for testing with tiny dictionaries
            .unsqueeze(0)
            .expand(batchsize, maxlen)
            .cuda()
        )
        # label vec has two tokens to make it interesting, but we can't use the
        # start token, it's reserved.
        label_vec = (
            torch.LongTensor([self.END_IDX, self.NULL_IDX])
            .unsqueeze(0)
            .expand(batchsize, 2)
            .cuda()
        )
        return Batch(
            text_vec=text_vec, label_vec=label_vec, text_lengths=[maxlen] * batchsize
        )

    def _init_cuda_buffer(self, batchsize, maxlen, force=False):
        """
        Pre-initialize CUDA buffer by doing fake forward pass.

        This is also used in distributed mode to force a worker to sync with others.
        """
        if self.use_cuda and (force or not hasattr(self, 'buffer_initialized')):
            try:
                self._control_local_metrics(disabled=True)
                loss = 0 * self.compute_loss(self._dummy_batch(batchsize, maxlen))
                self._control_local_metrics(enabled=True)
                self._temporarily_disable_local_metrics = False
                self.backward(loss)
                self.buffer_initialized = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    m = (
                        'CUDA OOM: Lower batch size (-bs) from {} or lower '
                        ' max sequence length (-tr) from {}'
                        ''.format(batchsize, maxlen)
                    )
                    raise RuntimeError(m)
                else:
                    raise e

    def _cache_dummy_batch(self, batch: Batch):
        """
        Cache a batch to be used as a dummy during _fake_forward_backward_pass.
        """
        if not hasattr(self, '_dummy_batch'):
            self._dummy_batch = batch

    def _fake_forward_backward_pass(self):
        """
        Force a worker to synchronize with others in case of distributed mode.

        Necessary during recovery of OOMs to prevent hangs during the all-reduce of
        gradients.
        """
        try:
            self._control_local_metrics(disabled=True)
            loss = 0 * self.compute_loss(self._dummy_batch)
            self._control_local_metrics(enabled=True)
            self.backward(loss)
            self.buffer_initialized = True
        except RuntimeError as e:
            if 'out of memory' in str(e):
                m = (
                    'CUDA OOM: Lower batch size (-bs) from {} or lower '
                    ' max sequence length (-tr) from {}'
                    ''.format(self.opt['batchsize'], self.opt['truncate'])
                )
                raise RuntimeError(m)
            else:
                raise e

Contributor: does this unnecessarily use up memory? what if the first batch is huge for example?

Contributor Author: Could be a worse problem for images. For text, imagine it's a 2x2048x1024 LongTensor (2x for content and label; bs 2048; string length 1024). A long is 8 bytes. So we'll be wasting 32MiB of memory.
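
A quick sanity check of the reviewer's estimate above (a minimal sketch, not part of the diff; the 2 x 2048 x 1024 shape is the hypothetical worst case from the comment, not a ParlAI default):

    import torch

    # Two long tensors (text_vec + label_vec), batch size 2048, sequence length 1024.
    # A long is 8 bytes per element, so the cached batch holds on to this much memory:
    dummy = torch.ones(2, 2048, 1024, dtype=torch.long)
    size_bytes = dummy.numel() * dummy.element_size()  # 2 * 2048 * 1024 * 8
    print(f"{size_bytes / 2**20:.0f} MiB")  # prints: 32 MiB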

def reset_metrics(self):
"""
@@ -741,10 +718,9 @@ def train_step(self, batch):
"""
Train on a single batch of examples.
"""
# helps with memory usage
# note we want to use the opt's batchsize instead of the observed batch size
# in case dynamic batching is in use
self._init_cuda_buffer(self.opt['batchsize'], self.label_truncate or 256)
# cache a dummy batch in case we OOM and need to catch up
self._cache_dummy_batch(batch)

self.model.train()
self.zero_grad()

@@ -774,7 +750,7 @@

# gradients are synced on backward, now this model is going to be
# out of sync! catch up with the other workers
self._init_cuda_buffer(8, 8, True)
self._fake_forward_backward_pass()

def _construct_token_losses(self, labels, model_output):
# Get non-aggregated losses
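For readers skimming the diff, here is a minimal sketch of how the two new hooks are meant to fit together in the training loop. It is simplified from the changes above; the surrounding loop (compute_loss, backward, update_params) is illustrative rather than the exact ParlAI control flow:

    def train_step_sketch(agent, batch):
        # Remember a real batch so a later OOM recovery has something to run.
        agent._cache_dummy_batch(batch)
        try:
            loss = agent.compute_loss(batch)
            agent.backward(loss)
            agent.update_params()
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            # This worker skipped its backward pass. In distributed training the
            # other workers will wait in the gradient all-reduce, so run a cheap
            # zero-loss forward/backward on the cached batch to catch up with them.
            agent._fake_forward_backward_pass()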
5 changes: 0 additions & 5 deletions projects/dialogue_unlikelihood/agents.py
@@ -66,11 +66,6 @@ def add_cmdline_args(
grp.add_argument('--alpha', default=1.0, type=float)
return parser

def _dummy_batch(self, batchsize, maxlen):
batch = super()._dummy_batch(batchsize, maxlen)
batch['rewards'] = torch.ones(batchsize, dtype=torch.long).cuda()
return batch

def compute_loss(self, batch, return_output=False):
if batch.label_vec is None:
raise ValueError('Cannot compute loss without a label.')
10 changes: 0 additions & 10 deletions projects/wizard_of_wikipedia/generator/agents.py
@@ -114,16 +114,6 @@ def __init__(self, opt, shared=None):
self.max_knowledge = opt.get('max_knowledge')
self.knowledge_alpha = opt['knowledge_alpha']

def _dummy_batch(self, bsz, maxlen):
batch = super()._dummy_batch(bsz, maxlen)
batch['know_vec'] = th.zeros(bsz, 2, 2).long().cuda()
# bool/uint8 backwards for pytorch 1.0/1.2 compatibility
ck_mask = (th.ones(bsz, 2, dtype=th.uint8) != 0).cuda()
batch['ck_mask'] = ck_mask
batch['cs_ids'] = th.zeros(bsz).long().cuda()
batch['use_cs_ids'] = True
return batch

def compute_loss(self, batch, return_output=False):
# first compute our regular forced decoding loss
token_loss, model_output = super().compute_loss(batch, return_output=True)