Skip to content

Commit

Permalink
[Reformer] Make random seed generator available on random seed and no…
Browse files Browse the repository at this point in the history
…t on model device (#6244)

* improve if else statement random seeds

* Apply suggestions from code review

* Update src/transformers/modeling_reformer.py
  • Loading branch information
patrickvonplaten committed Aug 4, 2020
1 parent d5b0a0e commit 6c9ba1d
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/transformers/modeling_reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,15 +1399,16 @@ def _init_attention_seed(self):
"""

# randomize seeds
if next(self.parameters()).device.type == "cuda":
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:

This comment has been minimized.

Copy link
@sgugger

sgugger Aug 5, 2020

Collaborator

Ploum ploum, this function is only available in pytorch 1.6 so we should find a way to not use it in the earlier versions (pytorch 1.5.1 cpu does not have it for instance).

# GPU
device_idx = torch.cuda.current_device()
self.attention_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.attention_seed)
else:
# CPU
self.attention_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.attention_seed)

torch.manual_seed(self.attention_seed)

def _init_feed_forward_seed(self):
"""
Expand All @@ -1417,17 +1418,17 @@ def _init_feed_forward_seed(self):
call and 1 forward call in backward
to recalculate activations.
"""

# randomize seeds
if next(self.parameters()).device.type == "cuda":
# use cuda generator if available
if len(torch.cuda.default_generators) > 0:
# GPU
device_idx = torch.cuda.current_device()
self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
torch.cuda.manual_seed(self.feed_forward_seed)
else:
# CPU
self.feed_forward_seed = int(torch.seed() % sys.maxsize)
torch.manual_seed(self.feed_forward_seed)

torch.manual_seed(self.feed_forward_seed)

def forward(
self,
Expand Down

0 comments on commit 6c9ba1d

Please sign in to comment.