[fix] Rotary embeddings respecting input types #326

Merged (2 commits) on Jun 6, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## TBD
 ### Fixed
 - Removed duplicated biases in the FusedMLP layers [#317]
+- Rotary embeddings respecting input types [#326]

 ### Added
 - Four blocksparsity layouts from DeepSpeed [#320]
13 changes: 11 additions & 2 deletions tests/test_rotary_embeddings.py
@@ -52,16 +52,25 @@ def test_helper_methods():


@pytest.mark.parametrize("device", DEVICES)
def test_rotary_embeddings(device):
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_rotary_embeddings(device, dtype):
rotary = RotaryEmbedding(EMB).to(device)

# Generate dummy inputs
q = torch.ones((BATCH, HEADS, SEQ, EMB), device=device) # uniform on purpose
q = torch.ones(
(BATCH, HEADS, SEQ, EMB), device=device, dtype=dtype
) # uniform on purpose
k = q.clone()

q_rot, k_rot = rotary(q, k)

assert q_rot.dtype == q.dtype
assert k_rot.dtype == k.dtype

# Check that the sequences now encode relative position information
q, k = q.float(), k.float()
q_rot, k_rot = q_rot.float(), k_rot.float()

att = torch.einsum("bhne,bhme->bhnm", q, k)
att_rot = torch.einsum("bhne,bhme->bhnm", q_rot, k_rot)

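For quick orientation, here is a minimal standalone sketch of the behaviour the updated test asserts: after this fix, half-precision queries and keys come back from the rotary embedding in their original dtype (previously the float32 cos/sin cache promoted the outputs). This sketch is not part of the PR; the direct module import mirrors the file path changed below, and the embedding size, tensor shapes, and device selection are illustrative assumptions.

import torch

# Illustrative usage sketch -- not from the PR. Import path follows the file
# modified in this PR: xformers/components/positional_embedding/rotary.py
from xformers.components.positional_embedding.rotary import RotaryEmbedding

device = "cuda" if torch.cuda.is_available() else "cpu"
rotary = RotaryEmbedding(64).to(device)  # 64 = illustrative embedding size

# Illustrative (batch, heads, seq, emb) shapes, mirroring the test above
q = torch.ones(2, 4, 16, 64, device=device, dtype=torch.float16)
k = q.clone()

q_rot, k_rot = rotary(q, k)
assert q_rot.dtype == torch.float16  # dtype preserved, as the new asserts check
assert k_rot.dtype == torch.float16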
16 changes: 10 additions & 6 deletions xformers/components/positional_embedding/rotary.py
@@ -61,16 +61,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=1):

         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
             self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
-                self.inv_freq
+            t = torch.arange(
+                x.shape[seq_dimension], device=x.device, dtype=torch.float32
             )
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

-            self._cos_cached = emb.cos()[None, None, :, :]
-            self._sin_cached = emb.sin()[None, None, :, :]
+            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
+            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

         return self._cos_cached, self._sin_cached

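Two notes on the change above, plus a small illustrative sketch. The position index t is now created in float32 rather than in the input dtype, which avoids fp16 rounding of large position indices; only the finished cos/sin tables are cast to the input dtype. And because the cache is keyed on dtype as well as device and sequence length, a module that alternately sees float32 and float16 tensors rebuilds the tables instead of reusing ones in the wrong dtype. The sketch below is not from the PR: it peeks at the private _cos_cached attribute shown in the diff purely to make the caching visible, and the embedding size and shapes are assumptions.

import torch
from xformers.components.positional_embedding.rotary import RotaryEmbedding

rotary = RotaryEmbedding(64)  # illustrative embedding size

q32 = torch.ones(1, 2, 8, 64)  # float32 inputs -> float32 cos/sin tables
rotary(q32, q32)
assert rotary._cos_cached.dtype == torch.float32

q16 = q32.half()  # same shape and device, new dtype
rotary(q16, q16)  # dtype check in the cache condition fires, tables rebuilt
assert rotary._cos_cached.dtype == torch.float16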