Skip to content

Commit

Permalink
fix rotate_every_two
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Mar 21, 2022
1 parent c1af180 commit 052fa2f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_flax_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def create_sinusoidal_positions(num_pos, dim):


def rotate_every_two(tensor):
rotate_half_tensor = jnp.stack((tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
return rotate_half_tensor

Expand Down

0 comments on commit 052fa2f

Please sign in to comment.