
Commit

fix attn_mask (#50)
stas00 committed Aug 5, 2021
1 parent bdef48f commit 42fe3b3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pretrain_gpt.py
@@ -59,15 +59,16 @@ def model_provider(pre_process=True, post_process=True):
         attention_mask = torch.tril(torch.ones(
             (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view(
                 1, 1, args.seq_length, args.seq_length)
 
         # Convert attention mask to binary:
         attention_mask = (attention_mask < 0.5)
         if args.fp16:
             attention_mask = attention_mask.half()
         elif args.bf16:
             attention_mask = attention_mask.bfloat16()
 
-        args.attn_mask = attention_mask
+        # must be bool or the training crashes expecting bool, but getting Half
+        args.attn_mask = attention_mask.to(torch.bool)
 
     else:
         model = GPTModel(
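
For context, the precomputed mask is later applied to the attention scores with an in-place masked_fill_-style call, and recent PyTorch only accepts boolean masks there. The following is a minimal sketch (not part of the commit; it assumes a masked_fill_-style consumer and a recent PyTorch build) of why a Half mask crashes while the bool cast added above works:

    import torch

    seq_length = 8

    # Same construction as in the hunk above: True where attention is blocked.
    attention_mask = torch.tril(
        torch.ones(1, 1, seq_length, seq_length)) < 0.5   # bool tensor

    scores = torch.randn(1, 1, seq_length, seq_length)

    # Bool mask: masked_fill_ accepts it and blocks the upper triangle.
    scores.masked_fill_(attention_mask, -10000.0)

    # Half mask (what the old args.attn_mask could become under --fp16):
    # recent PyTorch rejects non-bool masks, which matches the crash the
    # commit comment describes ("expecting bool, but getting Half").
    try:
        scores.masked_fill_(attention_mask.half(), -10000.0)
    except RuntimeError as err:
        print(f"masked_fill_ rejected the Half mask: {err}")

Casting with .to(torch.bool) right before storing the mask in args keeps the precomputed mask valid regardless of whether --fp16 or --bf16 is in effect.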
