Skip to content

Commit

Permalink
fixing arange not knowing half on cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Jun 4, 2022
1 parent 0ef428a commit 2050923
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/test_rotary_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_rotary_embeddings(device, dtype):
assert k_rot.dtype == k.dtype

# Check that the sequences now encode relative position information
att = torch.einsum("bhne,bhme->bhnm", q, k)
att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)
att = torch.einsum("bhne,bhme->bhnm", q, k).float()
att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot).float()

# - the attention for the same positions is not meaningfully changed
assert torch.allclose(
Expand Down
8 changes: 5 additions & 3 deletions xformers/components/positional_embedding/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ def _update_cos_sin_tables(self, x, seq_dimension=1):
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=x.dtype)
t = torch.arange(
x.shape[seq_dimension], device=x.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

self._cos_cached = emb.cos()[None, None, :, :]
self._sin_cached = emb.sin()[None, None, :, :]
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

return self._cos_cached, self._sin_cached

Expand Down

0 comments on commit 2050923

Please sign in to comment.