Fix other PyTorch models
julien-c committed Nov 6, 2019
1 parent d531979 commit 2f3a421
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 4 additions & 2 deletions templates/adding_a_new_model/modeling_xxx.py
@@ -309,10 +309,12 @@ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, posi
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape)
+            attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long)
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
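The pattern applied in both files is the same: derive the target device from whichever input tensor was provided, then allocate the default masks on that device instead of on the CPU. A minimal, self-contained sketch of the idea (the helper name build_default_masks is hypothetical, used only for illustration, not part of the library):

import torch

def build_default_masks(input_ids=None, inputs_embeds=None):
    # Hypothetical helper illustrating the fix: default tensors must live on
    # the same device as the inputs, otherwise operations that combine them
    # with input-derived tensors fail when the model runs on GPU.
    if input_ids is not None:
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    # Same device derivation as in the patch above.
    device = input_ids.device if input_ids is not None else inputs_embeds.device

    attention_mask = torch.ones(input_shape, device=device)
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    return attention_mask, token_type_ids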
4 changes: 3 additions & 1 deletion transformers/modeling_distilbert.py
@@ -450,8 +450,10 @@ def forward(self,
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
         if attention_mask is None:
-            attention_mask = torch.ones(input_shape) # (bs, seq_length)
+            attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
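The same fix applied to DistilBERT's forward, which has no token_type_ids. A quick check of the behavior using the hypothetical helper above; before this change, pairing CUDA inputs with a CPU-allocated default mask would trigger a device-mismatch error in downstream ops:

input_ids = torch.tensor([[101, 2054, 2003, 102]])
if torch.cuda.is_available():
    input_ids = input_ids.cuda()

attention_mask, token_type_ids = build_default_masks(input_ids=input_ids)
# Both defaults now follow the inputs' device, so later attention
# computations stay on a single device.
assert attention_mask.device == input_ids.device
assert token_type_ids.device == input_ids.device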

0 comments on commit 2f3a421
