Optimised cross entropy loss (#60)
bclarkson-code committed May 22, 2024
1 parent cdc124c commit c2ee2ed
Showing 13 changed files with 138 additions and 108 deletions.
10 changes: 6 additions & 4 deletions src/tricycle/attention.py
@@ -1,4 +1,5 @@
from math import sqrt

import numpy as np

from tricycle import CUPY_ENABLED
@@ -83,8 +84,8 @@ def backward(self, grad: Tensor):
(self.batch_size, self.context_window, self.embedding_dim * 3)
)
self._grad[:, :, : self.embedding_dim] = query
self._grad[:, :, self.embedding_dim: self.embedding_dim * 2] = key
self._grad[:, :, self.embedding_dim * 2:] = value
self._grad[:, :, self.embedding_dim : self.embedding_dim * 2] = key
self._grad[:, :, self.embedding_dim * 2 :] = value

return to_tensor(self._grad)

@@ -96,8 +97,8 @@ def forward(self, tensor: Tensor):
# split the input into 3 pieces
self._input = tensor
query = tensor[:, :, : self.embedding_dim]
key = tensor[:, :, self.embedding_dim: self.embedding_dim * 2]
value = tensor[:, :, self.embedding_dim * 2:]
key = tensor[:, :, self.embedding_dim : self.embedding_dim * 2]
value = tensor[:, :, self.embedding_dim * 2 :]

# Figure out how big everything is
self.batch_size = key._data.shape[0]
@@ -163,4 +164,5 @@ def to_gpu(self, device: int):
def from_gpu(self):
if CUPY_ENABLED:
import cupy as cp

self._mask = cp.asnumpy(self._mask)
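The reformatted slices above follow the fused-QKV layout used by the attention layer: query, key and value are packed side by side along the last axis of a single (batch, context, 3 * embedding_dim) tensor. A minimal NumPy sketch of that split, with hypothetical shapes chosen only for illustration:

import numpy as np

batch_size, context_window, embedding_dim = 2, 4, 8

# The attention input packs Q, K and V along the last axis.
fused_qkv = np.random.rand(batch_size, context_window, embedding_dim * 3)

# Split exactly as in the forward pass above.
query = fused_qkv[:, :, :embedding_dim]
key = fused_qkv[:, :, embedding_dim : embedding_dim * 2]
value = fused_qkv[:, :, embedding_dim * 2 :]

assert query.shape == (batch_size, context_window, embedding_dim)
assert key.shape == value.shape == query.shape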
4 changes: 2 additions & 2 deletions src/tricycle/configs.py
@@ -1,4 +1,4 @@
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass


class GPTConfig:
@@ -67,7 +67,7 @@ class SmolGPTConfig(GPTConfig):
def dict(self) -> dict[str, int | float | str | bool]:
out = {}
for k, v in SmolGPTConfig.__dict__.items():
if k.startswith('__'):
if k.startswith("__"):
continue

if callable(v):
12 changes: 10 additions & 2 deletions src/tricycle/dataset.py
@@ -81,8 +81,16 @@ def __getitem__(self, idx: int):
batch_outputs = np.vstack([self.outputs[i] for i in indices])

if self._to_tensor:
batch_inputs = to_tensor(batch_inputs, is_vector=self.is_vector)
batch_outputs = to_tensor(batch_outputs, is_vector=self.is_vector)
batch_inputs = to_tensor(
batch_inputs,
is_vector=self.is_vector,
dtype=batch_outputs.dtype,
)
batch_outputs = to_tensor(
batch_outputs,
is_vector=self.is_vector,
dtype=batch_outputs.dtype,
)
return batch_inputs, batch_outputs

def to_tensor(self):
107 changes: 69 additions & 38 deletions src/tricycle/loss.py
@@ -19,7 +19,7 @@ def mean_square_error(y_true: Tensor, y_pred: Tensor):
return square_error.mean()


class CrossEntropy(Op):
class CrossEntropy_(Op):
REALLY_SMALL_NUMBER = 1e-8
REALLY_BIG_NUMBER = 1e8

@@ -148,60 +148,91 @@ def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
return result


class BinaryCrossEntropyV2(Op):
class CrossEntropy(Op):
"""
Calculate cross entropy loss, given logits and target indices (as opposed
to one-hot encoded tensors)
"""

REALLY_SMALL_NUMBER = 1e-8
REALLY_BIG_NUMBER = 1e8

def backward(self, grad: Tensor) -> Tensor:
xp = grad.xp

self._y_pred[xp.arange(self._n_inputs), self._y_true] -= 1
self._y_pred /= self._n_inputs
self._grad = self._y_pred.reshape(self._original_shape) * grad._data

return to_tensor(self._grad, is_vector=self._input_vector)
def log_softmax(self, tensor: Tensor):
xp = tensor.xp
x_max = xp.max(tensor._data, axis=-1, keepdims=True)
log_sum_exp = x_max + xp.log(
xp.sum(xp.exp(tensor._data - x_max), axis=-1, keepdims=True)
)
return tensor._data - log_sum_exp

def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor:
# sourcery skip: assign-if-exp, reintroduce-else
"""
Calculate the cross entropy loss
"""
xp = y_pred.xp

# flatten to simplify multiple inputs
self._original_shape = y_pred.shape
self._input_vector = y_pred.is_vector
out_dim = y_pred.shape[-1]
self._y_true = y_true._data.reshape(-1)
y_pred_f = y_pred._data.reshape(-1, out_dim)
self._n_inputs = y_pred_f.shape[0]

# we scale values by the largest value in each vector
# for numeric stability
max_vals = xp.max(y_pred_f, axis=-1, keepdims=True)
scaled = y_pred_f - max_vals

log_probs = scaled - xp.log(
xp.sum(xp.exp(scaled), axis=-1, keepdims=True)
)
self._y_pred = xp.exp(log_probs)

corrected_log_probs = -log_probs[
xp.arange(self._n_inputs), self._y_true
]
self._out = corrected_log_probs.sum() / self._n_inputs
# Calculate log softmax
log_softmax_pred = self.log_softmax(y_pred)

# TODO: fuse normalising and calculation together
# Cache for backward pass
self._y_true = y_true._data
self._log_softmax_pred = log_softmax_pred

ndim = log_softmax_pred.ndim

if ndim == 3:
batch_indices = xp.arange(y_true.shape[0], dtype=int)
token_indices = xp.arange(y_true.shape[1], dtype=int)
loss = -log_softmax_pred[
batch_indices[:, None], token_indices, y_true._data
]
elif ndim == 2:
indices = xp.arange(y_true.shape[0], dtype=int)
loss = -log_softmax_pred[indices, y_true._data]
elif ndim == 1:
loss = -log_softmax_pred[y_true._data]
else:
raise NotImplementedError(
f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported"
)

# Mean loss over all elements
loss = loss.mean()

self._out = loss
result = to_tensor(self._out, is_vector=False)
result.back_fns = (self.backward,)

# y_true never requires grad so we don't calculate gradients for it
result.args = (y_pred,)
result.name = "cross_entropy"

return result

def backward(self, grad: Tensor) -> Tensor:
xp = grad.xp
ndim = self._log_softmax_pred.ndim

if ndim == 3:
batch_indices = xp.arange(self._y_true.shape[0], dtype=int)
token_indices = xp.arange(self._y_true.shape[1], dtype=int)
grad_output = xp.exp(self._log_softmax_pred)
grad_output[
batch_indices[:, None], token_indices, self._y_true
] -= 1
grad_output *= grad._data / (
self._y_true.shape[0] * self._y_true.shape[1]
)

elif ndim == 2:
indices = xp.arange(self._y_true.shape[0], dtype=int)
grad_output = xp.exp(self._log_softmax_pred)
grad_output[indices, self._y_true] -= 1
grad_output *= grad._data / self._y_true.shape[0]
elif ndim == 1:
grad_output = xp.exp(self._log_softmax_pred)
grad_output[self._y_true] -= 1
grad_output *= grad._data
else:
raise NotImplementedError(
f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported"
)

self._grad = grad_output
return to_tensor(self._grad, is_vector=grad.is_vector)
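The new CrossEntropy takes integer class indices rather than one-hot targets and uses a max-shifted log-softmax for numerical stability; its gradient is the familiar softmax(logits) minus a one-hot of the target, averaged over all positions. A self-contained NumPy sketch of the same forward/backward maths (not the Tricycle API, shown only to make the formulas concrete):

import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the per-row max so exp() cannot overflow.
    x_max = logits.max(axis=-1, keepdims=True)
    log_sum_exp = x_max + np.log(np.exp(logits - x_max).sum(axis=-1, keepdims=True))
    return logits - log_sum_exp

def cross_entropy(logits: np.ndarray, targets: np.ndarray):
    """Mean loss and gradient w.r.t. logits for 2D logits and 1D integer targets."""
    n = logits.shape[0]
    log_probs = log_softmax(logits)
    loss = -log_probs[np.arange(n), targets].mean()
    # Gradient of the mean loss: softmax minus one-hot, scaled by 1 / n.
    grad = np.exp(log_probs)
    grad[np.arange(n), targets] -= 1
    grad /= n
    return loss, grad

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
targets = np.array([0, 2])
loss, grad = cross_entropy(logits, targets)

The 3D branch in the diff applies the same formula per (batch, token) position and divides by batch_size * n_tokens instead of n.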
8 changes: 6 additions & 2 deletions src/tricycle/tensor.py
@@ -492,7 +492,7 @@ def to_tensor(
requires_grad: bool = True,
is_vector: bool = False,
_id: int | None = None,
dtype: np.dtype = np.float32,
dtype: np.dtype | None = None,
**kwargs,
) -> Tensor:
"""
@@ -504,9 +504,13 @@

if isinstance(tensor_like, Tensor):
array = tensor_like._data
elif isinstance(tensor_like, cupy.ndarray):
elif isinstance(tensor_like, (np.ndarray, cupy.ndarray)):
array = tensor_like
if dtype is not None:
array = array.astype(dtype)
else:
if dtype is None:
dtype = np.float32
array = np.asarray(tensor_like, dtype=dtype, **kwargs)

elif isinstance(tensor_like, Tensor):
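With the default dtype now None, to_tensor appears to preserve the dtype of an existing ndarray unless one is requested explicitly, while other inputs still fall back to float32. A rough illustration of that behaviour, inferred from the branch logic above (the expected dtypes are assumptions, not verified output):

import numpy as np
from tricycle.tensor import to_tensor

idx = to_tensor(np.arange(5))               # ndarray keeps its integer dtype
flt = to_tensor(np.arange(5), dtype=float)  # explicit cast, as the new tests do
lst = to_tensor([1, 2, 3])                  # plain sequences still default to float32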
8 changes: 6 additions & 2 deletions tests/test_binary.py
@@ -86,8 +86,12 @@ def test_can_bmul():


def test_can_bdiv():
in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True)
in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_vector=True)
in_tensor_1 = to_tensor(
np.arange(12).reshape(3, 4), is_vector=True, dtype=float
)
in_tensor_2 = to_tensor(
np.arange(1, 13).reshape(3, 4), is_vector=True, dtype=float
)

out_tensor = BinaryDivide()(in_tensor_1, in_tensor_2)

6 changes: 3 additions & 3 deletions tests/test_layers.py
@@ -7,7 +7,7 @@
from tricycle.layers import ( # noqa: E501
Dense,
Dropout,
EmbeddingV2,
Embedding,
LayerNorm,
RMSNorm,
Sequential,
@@ -102,7 +102,7 @@ def test_embedding():
dtype=int,
)

embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
weights = np.indices((vocab_size * out_shape,)).reshape(
vocab_size, out_shape
)
@@ -135,7 +135,7 @@ def test_embedding_vectorised():
dtype=np.int8,
).to_vector()

embedding_layer = EmbeddingV2(from_size=vocab_size, to_size=out_shape)
embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)
weights = np.indices((vocab_size * out_shape,)).reshape(
vocab_size, out_shape
)
13 changes: 8 additions & 5 deletions tests/test_loss.py
@@ -27,8 +27,8 @@ def test_can_mean_square_error():


def test_can_CrossEntropy():
y_true = to_tensor([0, 0, 1])
y_pred = to_tensor([0, 0, 0])
y_true = to_tensor([1], dtype=int)
y_pred = to_tensor([[0, 0, 0]])

loss = CrossEntropy()(y_true, y_pred)

@@ -40,15 +40,18 @@ def test_CrossEntropy_vectorised():
n_tokens = 5
vocab_size = 7

y_true = np.random.random((batch_size, n_tokens, vocab_size))
y_true = np.random.randint(0, vocab_size, size=(batch_size, n_tokens))
y_pred = np.random.random((batch_size, n_tokens, vocab_size))

y_true = to_tensor(y_true).to_vector()
y_true = to_tensor(y_true, dtype=int).to_vector()
y_pred = to_tensor(y_pred).to_vector()

loss = CrossEntropy()(y_true, y_pred)

assert loss.shape == (batch_size, n_tokens)
assert loss.shape == ()


# TODO: write a proper backprop test for these loss functions


def test_can_single_linear_regression_step():
30 changes: 17 additions & 13 deletions tests/test_model_matches_pytorch.py
@@ -296,26 +296,27 @@ def test_tricycle_softmax_matches_pytorch(in_shape, is_vector):


@given(tensor_shape(), st.booleans())
@example(in_shape=[32, 2], is_vector=False)
@example(in_shape=[2, 2, 4], is_vector=False)
def test_crossentropy_matches(in_shape, is_vector):
y_pred = build_tensor(in_shape, is_vector)
y_true = copy(y_pred)
y_true._data = y_pred.xp.zeros_like(y_true._data)
one_idx = y_pred.xp.random.choice(range(in_shape[-1]))
match len(in_shape):
case 1:
y_true[one_idx] = 1
case 2:
y_true[:, one_idx] = 1
y_true = np.random.randint(0, in_shape[-1], size=in_shape[:-1])
y_true = to_tensor(y_true, is_vector=is_vector, dtype=int)
assume(np.isfinite(y_pred._data).all())

tr_out = CrossEntropy()(y_true, y_pred).from_vector()
if len(in_shape) > 1:
tr_out = tr_out.mean()

p_y_pred = torch.tensor(y_pred._data, requires_grad=True)
p_y_true = torch.tensor(y_true._data, requires_grad=False)
p_out = torch.nn.functional.cross_entropy(
if len(in_shape) == 1:
p_y_pred = copy(y_pred._data)
if len(in_shape) == 2:
p_y_pred = copy(y_pred._data)
if len(in_shape) == 3:
p_y_pred = copy(y_pred._data).transpose(0, -1, 1)
p_y_pred = torch.tensor(p_y_pred, requires_grad=True)
p_y_true = torch.tensor(y_true._data, dtype=torch.long)

p_out = torch.nn.CrossEntropyLoss()(
input=p_y_pred,
target=p_y_true,
)
@@ -325,4 +325,7 @@ def test_crossentropy_matches(in_shape, is_vector):
tr_out.backward()
p_out.backward()

assert y_pred.grad.close_to(p_y_pred.grad.detach().numpy())
p_grad = p_y_pred.grad.detach().numpy()
if len(in_shape) == 3:
p_grad = p_grad.transpose(0, -1, 1)
assert y_pred.grad.close_to(p_grad)
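The transposes in the 3D case are there because torch.nn.CrossEntropyLoss expects the class dimension second, i.e. logits of shape (batch, vocab, tokens) with integer targets of shape (batch, tokens), whereas Tricycle keeps logits as (batch, tokens, vocab). A short sketch of that layout difference, with arbitrary shapes:

import numpy as np
import torch

batch_size, n_tokens, vocab_size = 2, 5, 7
logits = np.random.rand(batch_size, n_tokens, vocab_size).astype(np.float32)
targets = np.random.randint(0, vocab_size, size=(batch_size, n_tokens))

# Move the class axis to position 1 before handing the logits to PyTorch.
p_logits = torch.tensor(logits.transpose(0, 2, 1), requires_grad=True)
p_targets = torch.tensor(targets, dtype=torch.long)

loss = torch.nn.CrossEntropyLoss()(input=p_logits, target=p_targets)
loss.backward()

# Transpose the gradient back to (batch, tokens, vocab) to compare with Tricycle.
grad = p_logits.grad.detach().numpy().transpose(0, 2, 1)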