parallel residual fixes
fattorib committed May 29, 2023
1 parent c0285c0 commit 16081f4
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/models/GPT.py
@@ -33,11 +33,11 @@ def __call__(
         train: bool = False,
     ) -> jnp.array:
 
-        x = f_psum(x)
-        x_ln = nn.LayerNorm(dtype=self.dtype, use_bias=False)(x)
-
-
+        if self.tp_comms:
+            x = f_psum(x)
+
+        x_ln = nn.LayerNorm(dtype=self.dtype, use_bias=False)(x)
 
         attn_out = CausalAttention(
             self.embedding_dim,
             self.num_head,
@@ -48,15 +48,20 @@
             self.dtype,
             tp_comms=self.tp_comms
         )(x_ln, train)
-        # x = x + attn_out
         mlp_out = MLPBlock(
             self.embedding_dim,
             dropout=self.residual_dropout,
             N=self.N,
             dtype=self.dtype,
             tp_comms=self.tp_comms
         )(x_ln, train)
-        return x_ln + g_psum(attn_out + mlp_out)
+
+        if self.tp_comms:
+            out = x_ln + g_psum(attn_out + mlp_out)
+        else:
+            out = x_ln + attn_out + mlp_out
+
+        return out
 
 
 class Transformer(nn.Module):

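For context, the change gates the tensor-parallel collectives behind tp_comms: f_psum before the LayerNorm and g_psum over the combined attention/MLP output are applied only when tensor-parallel communication is enabled, and the parallel residual otherwise reduces to x_ln + attn_out + mlp_out. The definitions of f_psum and g_psum are not part of this diff; the sketch below is only an assumption of how such helpers are commonly written in JAX (identity forward with an all-reduce backward, and all-reduce forward with an identity backward), using a hypothetical model-parallel axis name "mp".

import jax

# Assumed sketch, not this repository's actual definitions.
# Both helpers are expected to run inside pmap/shard_map with an axis named "mp".

@jax.custom_vjp
def f_psum(x):
    # Forward: identity. Backward: all-reduce the gradient across the
    # model-parallel axis so every shard sees the full incoming gradient.
    return x

def _f_psum_fwd(x):
    return x, None

def _f_psum_bwd(_, g):
    return (jax.lax.psum(g, axis_name="mp"),)

f_psum.defvjp(_f_psum_fwd, _f_psum_bwd)

@jax.custom_vjp
def g_psum(x):
    # Forward: all-reduce the per-shard partial sums. Backward: identity,
    # so the already-reduced gradient is not summed a second time.
    return jax.lax.psum(x, axis_name="mp")

def _g_psum_fwd(x):
    return jax.lax.psum(x, axis_name="mp"), None

def _g_psum_bwd(_, g):
    return (g,)

g_psum.defvjp(_g_psum_fwd, _g_psum_bwd)

With helpers along these lines, the new if self.tp_comms: branches simply skip both collectives when tensor parallelism is off, which keeps the single-device path free of communication overhead.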