Skip to content

Commit

Permalink
Added test for mean_squared_error
Browse files Browse the repository at this point in the history
  • Loading branch information
bclarkson-code committed Jan 14, 2024
1 parent f7e8b1f commit 939372d
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np

from tricycle_v2.loss import mean_squared_error
from tricycle_v2.tensor import to_tensor
from tricycle_v2.ops import repeat


def test_can_mean_square_error():
y_true = to_tensor([[0, 0, 1], [0, 1, 0], [1 / 3, 1 / 3, 1 / 3]])
y_pred = to_tensor([[0, 0, 1], [0, 0, 1], [0, 0, 1]])

mse = mean_squared_error(y_true, y_pred)

assert mse.shape == (3,)
assert np.allclose(mse, np.array([0, 2 / 3, 2 / 9]))


def test_can_linear_regression():
np.random.seed(42)

x = np.linspace(-10, 10, 201)
y = x * 2 + 1 + np.random.normal(loc=0, scale=0.01, size=201)

x = to_tensor(x.reshape(-1, 1))
y = to_tensor(y)

slope = to_tensor([0.01])
intercept = to_tensor(0.01)

for _ in range(100):
repeated_slope = repeat("i->ji", slope, (x.shape[0],))
repeated_intercept = repeat("i->ji", intercept, (x.shape[0],))

y_pred = x * repeated_slope + repeated_intercept
loss = mean_squared_error(y, y_pred)

loss.backward()
breakpoint()

slope -= slope.grad
intercept -= intercept.grad

0 comments on commit 939372d

Please sign in to comment.