optimised embedding (#59)
bclarkson-code committed May 21, 2024
1 parent 544da6e commit cdc124c
Showing 2 changed files with 13 additions and 13 deletions.
18 changes: 9 additions & 9 deletions src/tricycle/layers.py
@@ -160,11 +160,12 @@ def forward(self, tensor: Tensor):
             return tensor
         xp = tensor.xp
         coef = 1 / (1 - self.probability)
-        random_mask = (xp.random.rand(*tensor.shape) > self.probability).astype(
-            tensor.dtype
-        ) * coef
+        random_mask = (
+            xp.random.rand(*tensor.shape) > self.probability
+        ).astype(tensor.dtype) * coef
         random_mask = to_tensor(
-            random_mask, is_vector=True, requires_grad=False)
+            random_mask, is_vector=True, requires_grad=False
+        )
         return BinaryMultiply()(tensor, random_mask)
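
The hunk above only reflows the inverted-dropout mask; the logic is unchanged. As a minimal NumPy sketch of the same idea, assuming a plain array in place of tricycle's Tensor and a hypothetical drop probability:

    import numpy as np

    p = 0.2                                  # drop probability (hypothetical value)
    x = np.ones((4, 3), dtype=np.float32)    # stand-in activations

    coef = 1 / (1 - p)                       # rescale survivors so E[out] == x
    mask = (np.random.rand(*x.shape) > p).astype(x.dtype) * coef
    out = x * mask                           # 0 where dropped, x * coef elsewhere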


@@ -404,8 +405,7 @@ def back_fn(self, grad: Tensor):
             case 1:
                 xp.add.at(out, self.input._data, grad._data)
             case 2:
-                for batch in range(grad._data.shape[0]):
-                    xp.add.at(out, self.input._data, grad._data[batch])
+                xp.add.at(out, self.input._data, grad._data.sum(axis=0))
             case _:
                 raise NotImplementedError(
                     f"{grad.ndim=}, {self.input.ndim=} are not supported"
@@ -420,11 +420,11 @@ def forward(self, tensor: Tensor):

         self.input = tensor
         if tensor.is_vector:
-            self._out = tensor.xp.stack(
-                [self.weights._data[idx] for idx in tensor._data]
+            self._out = self.weights._data[tensor._data.flatten()].reshape(
+                tensor._data.shape + (-1,)
             )
         else:
-            self._out = self.weights[tensor._data]
+            self._out = self.weights._data[tensor._data]
         result = to_tensor(self._out, is_vector=tensor.is_vector)

         result.args = (tensor, self.weights)
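The forward pass gets the matching treatment: the per-row Python loop plus stack is replaced by a single fancy-indexing gather. A NumPy sketch of why the two forms agree, using hypothetical sizes (vocab 10, embedding dim 4, a 2 x 3 batch of token ids) in place of the layer's weights and input:

    import numpy as np

    weights = np.random.rand(10, 4)            # stand-in for self.weights._data
    ids = np.array([[1, 0, 3], [2, 2, 9]])     # stand-in for tensor._data

    # Old: gather row by row in Python, then stack.
    stacked = np.stack([weights[idx] for idx in ids])

    # New: one vectorised gather, then restore the leading shape.
    gathered = weights[ids.flatten()].reshape(ids.shape + (-1,))

    assert np.allclose(stacked, gathered)      # shape (2, 3, 4) either way

In NumPy (and CuPy) `weights[ids]` alone would also perform this gather; the flatten-and-reshape form spells out the handling of arbitrary leading shapes explicitly.
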
8 changes: 4 additions & 4 deletions tests/test_layers.py
@@ -7,7 +7,7 @@
 from tricycle.layers import (  # noqa: E501
     Dense,
     Dropout,
-    Embedding,
+    EmbeddingV2,
     LayerNorm,
     RMSNorm,
     Sequential,
@@ -66,7 +66,7 @@ def test_dropout():  # sourcery skip: square-identity
     assert in_tensor.grad is not None
     assert in_tensor.grad.shape == in_tensor.shape

-    coef = 1 / (1-dropout_prob)
+    coef = 1 / (1 - dropout_prob)
     correct_grad = np.full(in_tensor.shape, coef)
     correct_grad[zero_x_idx, zero_y_idx] = 0

@@ -102,7 +102,7 @@ def test_embedding():
         dtype=int,
     )

-    embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
+    embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
     weights = np.indices((vocab_size * out_shape,)).reshape(
         vocab_size, out_shape
     )
@@ -135,7 +135,7 @@ def test_embedding_vectorised():
         dtype=np.int8,
     ).to_vector()

-    embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
+    embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
     weights = np.indices((vocab_size * out_shape,)).reshape(
         vocab_size, out_shape
     )
