From c2ee2ed2d0936d087d3fe6c1ce13394990cd0b87 Mon Sep 17 00:00:00 2001
From: bclarkson-code <57139598+bclarkson-code@users.noreply.github.com>
Date: Wed, 22 May 2024 22:09:50 +0100
Subject: [PATCH] Optimised cross entropy loss (#60)

---
 src/tricycle/attention.py           |  10 +--
 src/tricycle/configs.py             |   4 +-
 src/tricycle/dataset.py             |  12 +++-
 src/tricycle/loss.py                | 107 ++++++++++++++++++----------
 src/tricycle/tensor.py              |   8 ++-
 tests/test_binary.py                |   8 ++-
 tests/test_layers.py                |   6 +-
 tests/test_loss.py                  |  13 ++--
 tests/test_model_matches_pytorch.py |  30 ++++----
 tests/test_optimisers.py            |  12 ++--
 tests/test_simple_neural_network.py |   4 +-
 tests/test_vectorise.py             |  26 -------
 train_smol_gpt.py                   |   6 +-
 13 files changed, 138 insertions(+), 108 deletions(-)

diff --git a/src/tricycle/attention.py b/src/tricycle/attention.py
index 7933ead..bbd8076 100644
--- a/src/tricycle/attention.py
+++ b/src/tricycle/attention.py
@@ -1,4 +1,5 @@
 from math import sqrt
+
 import numpy as np
 
 from tricycle import CUPY_ENABLED
@@ -83,8 +84,8 @@ def backward(self, grad: Tensor):
             (self.batch_size, self.context_window, self.embedding_dim * 3)
         )
         self._grad[:, :, : self.embedding_dim] = query
-        self._grad[:, :, self.embedding_dim: self.embedding_dim * 2] = key
-        self._grad[:, :, self.embedding_dim * 2:] = value
+        self._grad[:, :, self.embedding_dim : self.embedding_dim * 2] = key
+        self._grad[:, :, self.embedding_dim * 2 :] = value
 
         return to_tensor(self._grad)
 
@@ -96,8 +97,8 @@ def forward(self, tensor: Tensor):
         # split the input into 3 peices
         self._input = tensor
         query = tensor[:, :, : self.embedding_dim]
-        key = tensor[:, :, self.embedding_dim: self.embedding_dim * 2]
-        value = tensor[:, :, self.embedding_dim * 2:]
+        key = tensor[:, :, self.embedding_dim : self.embedding_dim * 2]
+        value = tensor[:, :, self.embedding_dim * 2 :]
 
         # Figure out how big everything is
         self.batch_size = key._data.shape[0]
@@ -163,4 +164,5 @@ def to_gpu(self, device: int):
     def from_gpu(self):
         if CUPY_ENABLED:
             import cupy as cp
+
             self._mask = cp.asnumpy(self._mask)
diff --git a/src/tricycle/configs.py b/src/tricycle/configs.py
index eae3fe5..5438629 100644
--- a/src/tricycle/configs.py
+++ b/src/tricycle/configs.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 
 class GPTConfig:
@@ -67,7 +67,7 @@ class SmolGPTConfig(GPTConfig):
     def dict(self) -> dict[str, int | float | str | bool]:
         out = {}
         for k, v in SmolGPTConfig.__dict__.items():
-            if k.startswith('__'):
+            if k.startswith("__"):
                 continue
 
             if callable(v):
diff --git a/src/tricycle/dataset.py b/src/tricycle/dataset.py
index ea6b11a..827ee95 100644
--- a/src/tricycle/dataset.py
+++ b/src/tricycle/dataset.py
@@ -81,8 +81,16 @@ def __getitem__(self, idx: int):
         batch_outputs = np.vstack([self.outputs[i] for i in indices])
 
         if self._to_tensor:
-            batch_inputs = to_tensor(batch_inputs, is_vector=self.is_vector)
-            batch_outputs = to_tensor(batch_outputs, is_vector=self.is_vector)
+            batch_inputs = to_tensor(
+                batch_inputs,
+                is_vector=self.is_vector,
+                dtype=batch_outputs.dtype,
+            )
+            batch_outputs = to_tensor(
+                batch_outputs,
+                is_vector=self.is_vector,
+                dtype=batch_outputs.dtype,
+            )
         return batch_inputs, batch_outputs
 
     def to_tensor(self):
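Note on the dataset change above: forwarding an explicit dtype into to_tensor keeps integer targets as integers instead of silently promoting them to float32, so they can be used as indices by the reworked CrossEntropy below. A rough sketch of the intended behaviour (import path and variable names here are illustrative, not taken from the patch):

    import numpy as np

    from tricycle.tensor import to_tensor

    targets = np.array([3, 1, 4], dtype=np.int64)    # class indices, not one-hot rows
    batch = to_tensor(targets, dtype=targets.dtype)  # dtype forwarded, as in __getitem__
    # batch._data.dtype is expected to stay int64 under the to_tensor change below
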
diff --git a/src/tricycle/loss.py b/src/tricycle/loss.py
index 7ec36f2..8835911 100644
--- a/src/tricycle/loss.py
+++ b/src/tricycle/loss.py
@@ -19,7 +19,7 @@ def mean_square_error(y_true: Tensor, y_pred: Tensor):
     return square_error.mean()
 
 
-class CrossEntropy(Op):
+class CrossEntropy_(Op):
     REALLY_SMALL_NUMBER = 1e-8
     REALLY_BIG_NUMBER = 1e8
 
@@ -148,60 +148,91 @@ def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
         return result
 
 
-class BinaryCrossEntropyV2(Op):
+class CrossEntropy(Op):
     """
     Calculate cross entropy loss, given logits and target indices (as opposed
     to one-hot encoded tensors)
     """
 
-    REALLY_SMALL_NUMBER = 1e-8
-    REALLY_BIG_NUMBER = 1e8
-
-    def backward(self, grad: Tensor) -> Tensor:
-        xp = grad.xp
-
-        self._y_pred[xp.arange(self._n_inputs), self._y_true] -= 1
-        self._y_pred /= self._n_inputs
-        self._grad = self._y_pred.reshape(self._original_shape) * grad._data
-
-        return to_tensor(self._grad, is_vector=self._input_vector)
+    def log_softmax(self, tensor: Tensor):
+        xp = tensor.xp
+        x_max = xp.max(tensor._data, axis=-1, keepdims=True)
+        log_sum_exp = x_max + xp.log(
+            xp.sum(xp.exp(tensor._data - x_max), axis=-1, keepdims=True)
+        )
+        return tensor._data - log_sum_exp
 
     def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
-        # sourcery skip: assign-if-exp, reintroduce-else
         """
         Calculate the cross entropy loss
        """
         xp = y_pred.xp
 
-        # flatten to simplify multiple inputs
-        self._original_shape = y_pred.shape
-        self._input_vector = y_pred.is_vector
-        out_dim = y_pred.shape[-1]
-        self._y_true = y_true._data.reshape(-1)
-        y_pred_f = y_pred._data.reshape(-1, out_dim)
-        self._n_inputs = y_pred_f.shape[0]
-
-        # we scale values by the largest value in each vector
-        # for numeric stability
-        max_vals = xp.max(y_pred_f, axis=-1, keepdims=True)
-        scaled = y_pred_f - max_vals
-
-        log_probs = scaled - xp.log(
-            xp.sum(xp.exp(scaled), axis=-1, keepdims=True)
-        )
-        self._y_pred = xp.exp(log_probs)
-
-        corrected_log_probs = -log_probs[
-            xp.arange(self._n_inputs), self._y_true
-        ]
-        self._out = corrected_log_probs.sum() / self._n_inputs
+        # Calculate log softmax
+        log_softmax_pred = self.log_softmax(y_pred)
 
-        # TODO: fuse normalising and calculation together
+        # Cache for backward pass
+        self._y_true = y_true._data
+        self._log_softmax_pred = log_softmax_pred
+
+        ndim = log_softmax_pred.ndim
+
+        if ndim == 3:
+            batch_indices = xp.arange(y_true.shape[0], dtype=int)
+            token_indices = xp.arange(y_true.shape[1], dtype=int)
+            loss = -log_softmax_pred[
+                batch_indices[:, None], token_indices, y_true._data
+            ]
+        elif ndim == 2:
+            indices = xp.arange(y_true.shape[0], dtype=int)
+            loss = -log_softmax_pred[indices, y_true._data]
+        elif ndim == 1:
+            loss = -log_softmax_pred[y_true._data]
+        else:
+            raise NotImplementedError(
+                f"CrossEntropy does not yet support predictions with ndim: {ndim}"
+            )
+
+        # Mean loss over all elements
+        loss = loss.mean()
+
+        self._out = loss
         result = to_tensor(self._out, is_vector=False)
         result.back_fns = (self.backward,)
-        # y_true never requires grad so we dont calculate gradients for it
         result.args = (y_pred,)
         result.name = "cross_entropy"
 
         return result
+
+    def backward(self, grad: Tensor) -> Tensor:
+        xp = grad.xp
+        ndim = self._log_softmax_pred.ndim
+
+        if ndim == 3:
+            batch_indices = xp.arange(self._y_true.shape[0], dtype=int)
+            token_indices = xp.arange(self._y_true.shape[1], dtype=int)
+            grad_output = xp.exp(self._log_softmax_pred)
+            grad_output[
+                batch_indices[:, None], token_indices, self._y_true
+            ] -= 1
+            grad_output *= grad._data / (
+                self._y_true.shape[0] * self._y_true.shape[1]
+            )
+
+        elif ndim == 2:
+            indices = xp.arange(self._y_true.shape[0], dtype=int)
+            grad_output = xp.exp(self._log_softmax_pred)
+            grad_output[indices, self._y_true] -= 1
+            grad_output *= grad._data / self._y_true.shape[0]
+        elif ndim == 1:
+            grad_output = xp.exp(self._log_softmax_pred)
+            grad_output[self._y_true] -= 1
+            grad_output *= grad._data
+        else:
+            raise NotImplementedError(
+                f"CrossEntropy does not yet support predictions with ndim: {ndim}"
+            )
+
+        self._grad = grad_output
+        return to_tensor(self._grad, is_vector=grad.is_vector)
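For reference, the class above is the standard log-softmax formulation of cross entropy: the forward pass averages -log softmax(logits)[target] over every prediction, and the backward pass is (softmax(logits) - one_hot(target)) / N, scaled by the incoming gradient. A minimal NumPy sketch of the 2-D case, intended only as a sanity check against the Op above (none of these names are part of the library):

    import numpy as np

    def cross_entropy_reference(logits, targets):
        # logits: (n, vocab) floats; targets: (n,) integer class indices
        x_max = logits.max(axis=-1, keepdims=True)
        log_probs = logits - x_max - np.log(
            np.exp(logits - x_max).sum(axis=-1, keepdims=True)
        )

        n = logits.shape[0]
        loss = -log_probs[np.arange(n), targets].mean()

        # d(loss)/d(logits) = (softmax - one_hot) / n
        grad = np.exp(log_probs)
        grad[np.arange(n), targets] -= 1
        grad /= n
        return loss, grad

    loss, grad = cross_entropy_reference(
        np.random.random((4, 7)), np.random.randint(0, 7, size=4)
    )
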
diff --git a/src/tricycle/tensor.py b/src/tricycle/tensor.py
index 251d25f..4206709 100644
--- a/src/tricycle/tensor.py
+++ b/src/tricycle/tensor.py
@@ -492,7 +492,7 @@ def to_tensor(
     requires_grad: bool = True,
     is_vector: bool = False,
     _id: int | None = None,
-    dtype: np.dtype = np.float32,
+    dtype: np.dtype | None = None,
     **kwargs,
 ) -> Tensor:
     """
@@ -504,9 +504,13 @@ def to_tensor(
 
         if isinstance(tensor_like, Tensor):
             array = tensor_like._data
-        elif isinstance(tensor_like, cupy.ndarray):
+        elif isinstance(tensor_like, (np.ndarray, cupy.ndarray)):
            array = tensor_like
+            if dtype is not None:
+                array = array.astype(dtype)
         else:
+            if dtype is None:
+                dtype = np.float32
             array = np.asarray(tensor_like, dtype=dtype, **kwargs)
 
     elif isinstance(tensor_like, Tensor):
diff --git a/tests/test_binary.py b/tests/test_binary.py
index d3ae21c..2298025 100644
--- a/tests/test_binary.py
+++ b/tests/test_binary.py
@@ -86,8 +86,12 @@ def test_can_bmul():
 
 
 def test_can_bdiv():
-    in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True)
-    in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_vector=True)
+    in_tensor_1 = to_tensor(
+        np.arange(12).reshape(3, 4), is_vector=True, dtype=float
+    )
+    in_tensor_2 = to_tensor(
+        np.arange(1, 13).reshape(3, 4), is_vector=True, dtype=float
+    )
 
     out_tensor = BinaryDivide()(in_tensor_1, in_tensor_2)
 
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 3f9fcaa..3d5b54b 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -7,7 +7,7 @@
 from tricycle.layers import (  # noqa: E501
     Dense,
     Dropout,
-    EmbeddingV2,
+    Embedding,
     LayerNorm,
     RMSNorm,
     Sequential,
@@ -102,7 +102,7 @@ def test_embedding():
         dtype=int,
     )
 
-    embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
+    embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
     weights = np.indices((vocab_size * out_shape,)).reshape(
         vocab_size, out_shape
     )
@@ -135,7 +135,7 @@ def test_embedding_vectorised():
         dtype=np.int8,
     ).to_vector()
 
-    embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
+    embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
     weights = np.indices((vocab_size * out_shape,)).reshape(
         vocab_size, out_shape
     )
diff --git a/tests/test_loss.py b/tests/test_loss.py
index edfdfbb..0dbacc6 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -27,8 +27,8 @@ def test_can_mean_square_error():
 
 
 def test_can_CrossEntropy():
-    y_true = to_tensor([0, 0, 1])
-    y_pred = to_tensor([0, 0, 0])
+    y_true = to_tensor([1], dtype=int)
+    y_pred = to_tensor([[0, 0, 0]])
 
     loss = CrossEntropy()(y_true, y_pred)
 
@@ -40,15 +40,18 @@ def test_CrossEntropy_vectorised():
     n_tokens = 5
     vocab_size = 7
 
-    y_true = np.random.random((batch_size, n_tokens, vocab_size))
+    y_true = np.random.randint(0, vocab_size, size=(batch_size, n_tokens))
     y_pred = np.random.random((batch_size, n_tokens, vocab_size))
 
-    y_true = to_tensor(y_true).to_vector()
+    y_true = to_tensor(y_true, dtype=int).to_vector()
     y_pred = to_tensor(y_pred).to_vector()
 
     loss = CrossEntropy()(y_true, y_pred)
 
-    assert loss.shape == (batch_size, n_tokens)
+    assert loss.shape == ()
+
+
+# TODO: write a proper backprop test for these loss functions
 
 
 def test_can_single_linear_regression_step():
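The updated tests above show the new calling convention: targets are integer class indices rather than one-hot tensors, and the returned loss is already the scalar mean. A short sketch of old versus new usage (shapes arbitrary, import paths assumed):

    import numpy as np

    from tricycle.loss import CrossEntropy
    from tricycle.tensor import to_tensor

    # before: one-hot targets, e.g. to_tensor([0, 0, 1]) against to_tensor([0, 0, 0])
    # after: one integer index per prediction, and a scalar mean loss
    y_true = to_tensor(np.random.randint(0, 7, size=(3, 5)), dtype=int).to_vector()
    y_pred = to_tensor(np.random.random((3, 5, 7))).to_vector()

    loss = CrossEntropy()(y_true, y_pred)  # mean over batch and tokens
    assert loss.shape == ()
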
diff --git a/tests/test_model_matches_pytorch.py b/tests/test_model_matches_pytorch.py
index b70f306..d62f5a6 100644
--- a/tests/test_model_matches_pytorch.py
+++ b/tests/test_model_matches_pytorch.py
@@ -296,26 +296,27 @@ def test_tricycle_softmax_matches_pytorch(in_shape, is_vector):
 
 
 @given(tensor_shape(), st.booleans())
-@example(in_shape=[32, 2], is_vector=False)
+@example(in_shape=[2, 2, 4], is_vector=False)
 def test_crossentropy_matches(in_shape, is_vector):
     y_pred = build_tensor(in_shape, is_vector)
-    y_true = copy(y_pred)
-    y_true._data = y_pred.xp.zeros_like(y_true._data)
-    one_idx = y_pred.xp.random.choice(range(in_shape[-1]))
-    match len(in_shape):
-        case 1:
-            y_true[one_idx] = 1
-        case 2:
-            y_true[:, one_idx] = 1
+    y_true = np.random.randint(0, in_shape[-1], size=in_shape[:-1])
+    y_true = to_tensor(y_true, is_vector=is_vector, dtype=int)
 
     assume(np.isfinite(y_pred._data).all())
 
     tr_out = CrossEntropy()(y_true, y_pred).from_vector()
     if len(in_shape) > 1:
         tr_out = tr_out.mean()
-    p_y_pred = torch.tensor(y_pred._data, requires_grad=True)
-    p_y_true = torch.tensor(y_true._data, requires_grad=False)
-    p_out = torch.nn.functional.cross_entropy(
+    if len(in_shape) == 1:
+        p_y_pred = copy(y_pred._data)
+    if len(in_shape) == 2:
+        p_y_pred = copy(y_pred._data)
+    if len(in_shape) == 3:
+        p_y_pred = copy(y_pred._data).transpose(0, -1, 1)
+    p_y_pred = torch.tensor(p_y_pred, requires_grad=True)
+    p_y_true = torch.tensor(y_true._data, dtype=torch.long)
+
+    p_out = torch.nn.CrossEntropyLoss()(
         input=p_y_pred,
         target=p_y_true,
     )
@@ -325,4 +326,7 @@ def test_crossentropy_matches(in_shape, is_vector):
     tr_out.backward()
     p_out.backward()
 
-    assert y_pred.grad.close_to(p_y_pred.grad.detach().numpy())
+    p_grad = p_y_pred.grad.detach().numpy()
+    if len(in_shape) == 3:
+        p_grad = p_grad.transpose(0, -1, 1)
+    assert y_pred.grad.close_to(p_grad)
diff --git a/tests/test_optimisers.py b/tests/test_optimisers.py
index 748f968..8ed2568 100644
--- a/tests/test_optimisers.py
+++ b/tests/test_optimisers.py
@@ -18,7 +18,7 @@ def test_can_train_simple_neural_network_no_wd():
     np.random.seed(42)
     X, y = load_iris(return_X_y=True)
     # one hot encode y
-    y = np.eye(3)[y.astype(int)]
+    y = y.astype(int)
 
     # create a dataset
     ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE)
@@ -40,7 +40,7 @@ def test_can_train_simple_neural_network_no_wd():
             break
 
         y_pred = model(x)
-        loss = loss_fn(y, y_pred).from_vector().e("a->") / BATCH_SIZE
+        loss = loss_fn(y, y_pred)
         loss.backward()
 
         losses.append(loss)
@@ -60,7 +60,7 @@ def test_can_train_simple_neural_network_wd():
     np.random.seed(42)
     X, y = load_iris(return_X_y=True)
     # one hot encode y
-    y = np.eye(3)[y.astype(int)]
+    y = y.astype(int)
 
     # create a dataset
     ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE)
@@ -82,7 +82,7 @@ def test_can_train_simple_neural_network_wd():
             break
 
         y_pred = model(x)
-        loss = loss_fn(y, y_pred).from_vector().e("a->") / BATCH_SIZE
+        loss = loss_fn(y, y_pred)
         loss.backward()
 
         losses.append(loss)
@@ -102,7 +102,7 @@ def test_can_train_simple_neural_network_momentum():
     np.random.seed(42)
     X, y = load_iris(return_X_y=True)
     # one hot encode y
-    y = np.eye(3)[y.astype(int)]
+    y = y.astype(int)
 
     # create a dataset
     ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE)
@@ -124,7 +124,7 @@ def test_can_train_simple_neural_network_momentum():
             break
 
         y_pred = model(x)
-        loss = loss_fn(y, y_pred).from_vector().e("a->") / BATCH_SIZE
+        loss = loss_fn(y, y_pred)
         loss.backward()
 
         losses.append(loss)
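One detail of the PyTorch comparison above: torch.nn.CrossEntropyLoss expects the class dimension second, i.e. (batch, classes, tokens), while tricycle keeps logits as (batch, tokens, classes); that is why the 3-D case transposes the logits before the call and the gradient back afterwards. A standalone sketch of the same layout shuffle (shapes chosen only for illustration):

    import numpy as np
    import torch

    logits = np.random.random((2, 5, 7)).astype(np.float32)  # (batch, tokens, vocab)
    targets = np.random.randint(0, 7, size=(2, 5))            # (batch, tokens)

    p_out = torch.nn.CrossEntropyLoss()(
        input=torch.tensor(logits.transpose(0, 2, 1)),   # -> (batch, vocab, tokens)
        target=torch.tensor(targets, dtype=torch.long),
    )
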
diff --git a/tests/test_simple_neural_network.py b/tests/test_simple_neural_network.py
index c39ec71..dd0538d 100644
--- a/tests/test_simple_neural_network.py
+++ b/tests/test_simple_neural_network.py
@@ -31,7 +31,7 @@ def test_can_train_simple_neural_network():
     X, y = load_iris(return_X_y=True)
 
     # one hot encode y
-    y = np.eye(3)[y.astype(int)]
+    y = y.astype(int)
 
     # create a dataset
     ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE)
@@ -55,7 +55,7 @@ def test_can_train_simple_neural_network():
             break
 
         y_pred = model(x_in)
-        loss = loss_fn(y_out, y_pred).from_vector().e("a->") / BATCH_SIZE
+        loss = loss_fn(y_out, y_pred)
         loss.backward()
 
         losses.append(loss)
diff --git a/tests/test_vectorise.py b/tests/test_vectorise.py
index c62caa5..50fb582 100644
--- a/tests/test_vectorise.py
+++ b/tests/test_vectorise.py
@@ -87,32 +87,6 @@ def test_can_vectorise_mse():
     assert output_vector.close_to(correct_output)
 
 
-def test_can_vectorise_cross_entropy():
-    y_true = to_tensor([0, 0, 1, 0])
-    input_1 = to_tensor(np.arange(1, 5))
-    input_2 = to_tensor(np.arange(2, 6))
-    input_3 = to_tensor(np.arange(3, 7))
-
-    output_1 = CrossEntropy()(y_true, input_1)
-    output_2 = CrossEntropy()(y_true, input_2)
-    output_3 = CrossEntropy()(y_true, input_3)
-
-    input_y_true = to_tensor(np.array([y_true._data] * 3))
-    input_vector = to_tensor(
-        np.array([input_1._data, input_2._data, input_3._data])
-    )
-    correct_output = to_tensor(
-        np.array([output_1._data, output_2._data, output_3._data])
-    )
-
-    input_y_true = vectorise(input_y_true)
-    input_vector = vectorise(input_vector)
-    output_vector = CrossEntropy()(input_y_true, input_vector)
-    output_vector = unvectorise(output_vector)
-
-    assert output_vector.close_to(correct_output)
-
-
 def test_can_vectorise_softmax():
     input_1 = to_tensor(np.arange(1, 5))
     input_2 = to_tensor(np.arange(2, 6))
diff --git a/train_smol_gpt.py b/train_smol_gpt.py
index 533f32b..b560b17 100644
--- a/train_smol_gpt.py
+++ b/train_smol_gpt.py
@@ -24,7 +24,7 @@
 from inference import generate
 from tricycle.configs import SmolGPTConfig
 from tricycle.dataset import CausalLMDataset
-from tricycle.loss import BinaryCrossEntropy
+from tricycle.loss import CrossEntropy
 from tricycle.models import GPT
 from tricycle.optimisers import AdamW
 from tricycle.scheduler import lr_schedule
@@ -50,7 +50,7 @@
     .to_vector()
     .shuffle()
 )
-loss_fn = BinaryCrossEntropy()
+loss_fn = CrossEntropy()
 optimiser = AdamW(
     learning_rate=lr_schedule(
         0,
@@ -187,7 +187,7 @@ def get_sample(sample_text: str | None = None, n_samples: int = 50) -> str:
             mlflow.log_text(predicted, f"generated/{step}.txt")
 
         # checkpoint
-        avg_loss = xp.mean(losses[step-config.eval_interval:step])
+        avg_loss = xp.mean(losses[step - config.eval_interval : step])
         if avg_loss < best_loss:
             Path("models").mkdir(exist_ok=True)
             with open(f"models/model_{unique_id}.pkl", "wb") as f:
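Because the loss now returns the mean over every prediction, callers no longer reduce it by hand: the old loss_fn(y, y_pred).from_vector().e("a->") / BATCH_SIZE pattern in the tests and in train_smol_gpt.py collapses to a plain loss_fn(y, y_pred). A condensed, self-contained sketch of the resulting call pattern, with random stand-ins for the model output and the batch from CausalLMDataset (the names below are illustrative, not part of the patch):

    import numpy as np

    from tricycle.loss import CrossEntropy
    from tricycle.tensor import to_tensor

    batch_size, context_window, vocab_size = 4, 8, 32
    loss_fn = CrossEntropy()

    # stand-ins for one batch of token indices and the model's logits
    targets = to_tensor(
        np.random.randint(0, vocab_size, size=(batch_size, context_window)), dtype=int
    ).to_vector()
    logits = to_tensor(
        np.random.random((batch_size, context_window, vocab_size))
    ).to_vector()

    loss = loss_fn(targets, logits)  # already a scalar mean: no manual reduction
    loss.backward()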