diff --git a/src/tricycle/layers.py b/src/tricycle/layers.py index c482fa0..429e165 100644 --- a/src/tricycle/layers.py +++ b/src/tricycle/layers.py @@ -160,11 +160,12 @@ def forward(self, tensor: Tensor): return tensor xp = tensor.xp coef = 1 / (1 - self.probability) - random_mask = (xp.random.rand(*tensor.shape) > self.probability).astype( - tensor.dtype - ) * coef + random_mask = ( + xp.random.rand(*tensor.shape) > self.probability + ).astype(tensor.dtype) * coef random_mask = to_tensor( - random_mask, is_vector=True, requires_grad=False) + random_mask, is_vector=True, requires_grad=False + ) return BinaryMultiply()(tensor, random_mask) @@ -404,8 +405,7 @@ def back_fn(self, grad: Tensor): case 1: xp.add.at(out, self.input._data, grad._data) case 2: - for batch in range(grad._data.shape[0]): - xp.add.at(out, self.input._data, grad._data[batch]) + xp.add.at(out, self.input._data, grad._data.sum(axis=0)) case _: raise NotImplementedError( f"{grad.ndim=}, {self.input.ndim=} are not supported" @@ -420,11 +420,11 @@ def forward(self, tensor: Tensor): self.input = tensor if tensor.is_vector: - self._out = tensor.xp.stack( - [self.weights._data[idx] for idx in tensor._data] + self._out = self.weights._data[tensor._data.flatten()].reshape( + tensor._data.shape + (-1,) ) else: - self._out = self.weights[tensor._data] + self._out = self.weights._data[tensor._data] result = to_tensor(self._out, is_vector=tensor.is_vector) result.args = (tensor, self.weights) diff --git a/tests/test_layers.py b/tests/test_layers.py index e4d9e49..3f9fcaa 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -7,7 +7,7 @@ from tricycle.layers import ( # noqa: E501 Dense, Dropout, - Embedding, + EmbeddingV2, LayerNorm, RMSNorm, Sequential, @@ -66,7 +66,7 @@ def test_dropout(): # sourcery skip: square-identity assert in_tensor.grad is not None assert in_tensor.grad.shape == in_tensor.shape - coef = 1 / (1-dropout_prob) + coef = 1 / (1 - dropout_prob) correct_grad = np.full(in_tensor.shape, coef) correct_grad[zero_x_idx, zero_y_idx] = 0 @@ -102,7 +102,7 @@ def test_embedding(): dtype=int, ) - embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape) + embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape) weights = np.indices((vocab_size * out_shape,)).reshape( vocab_size, out_shape ) @@ -135,7 +135,7 @@ def test_embedding_vectorised(): dtype=np.int8, ).to_vector() - embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape) + embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape) weights = np.indices((vocab_size * out_shape,)).reshape( vocab_size, out_shape )