Commit: corrected some bugs

bclarkson-code committed Jan 14, 2024
1 parent 939372d commit d73bc0c

Showing 7 changed files with 59 additions and 36 deletions.
src/tricycle_v2/binary.py (3 changes: 1 addition & 2 deletions)

@@ -32,8 +32,7 @@ def bsub(tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
     """
     assert tensor_1.shape == tensor_2.shape
 
-    tensor_2_neg = umul(tensor_2, -1)
-    return badd(tensor_1, tensor_2_neg)
+    return badd(tensor_1, umul(tensor_2, -1))
 
 
 def bmul(tensor_1: Tensor, tensor_2: Tensor) -> Tensor:
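For context, bsub is now expressed directly in terms of the two primitives the library already has: a scalar multiply by -1 followed by a binary add. A minimal sketch of the same identity in plain numpy (not the tricycle_v2 API):

import numpy as np

a = np.array([3.0, 5.0])
b = np.array([1.0, 2.0])

# badd(a, umul(b, -1)) computes a + (b * -1), which is exactly a - b
assert np.allclose(a - b, a + b * -1.0)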
src/tricycle_v2/loss.py (7 changes: 4 additions & 3 deletions)

@@ -1,15 +1,16 @@
 from string import ascii_lowercase
 
-from tricycle_v2.tensor import Tensor
+from tricycle_v2.binary import bsub
 from tricycle_v2.reduce import radd
+from tricycle_v2.tensor import Tensor
 
 
 def mean_squared_error(y_true: Tensor, y_pred: Tensor) -> Tensor:
     """
     Calcuate the mean square error along the final index of a tensor
     """
-    square_error = (y_true - y_pred) ** 2
+    square_error = (y_true - y_pred)**2
     indices = ascii_lowercase[: len(square_error.shape)]
     subscript = f"{indices}->{indices[:-1]}"
     total_error = radd(square_error, subscript)
-    return total_error / y_true.shape[-1]
+    return total_error / square_error.shape[-1]
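The returned mean now divides by square_error.shape[-1] rather than y_true.shape[-1]; the two agree whenever the inputs have matching shapes, but the new form depends only on the tensor actually being reduced. A quick numpy check of the reduction pattern, with np.einsum standing in for radd:

import numpy as np

y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.5, 2.0], [2.0, 4.0]])

square_error = (y_true - y_pred) ** 2
# "ab->a" sums away the final index, matching radd's subscript
total_error = np.einsum("ab->a", square_error)
mse = total_error / square_error.shape[-1]
assert np.allclose(mse, ((y_true - y_pred) ** 2).mean(axis=-1))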
src/tricycle_v2/ops.py (6 changes: 3 additions & 3 deletions)

@@ -42,12 +42,12 @@ def repeat(subscripts, tensor, out_shape):
 
     one_indices = ""
     one_shape = []
-    for i, out_idx in enumerate(output):
+    for size, out_idx in zip(out_shape, output):
         if out_idx not in index:
             one_indices += out_idx
-            one_shape.append(out_shape[i])
+            one_shape.append(size)
 
-    ones = to_tensor(np.ones(one_shape))
+    ones = to_tensor(np.ones(one_shape), requires_grad=False)
     new_subscript = f"{one_indices},{index}->{output}"
     return einsum(new_subscript, ones, tensor)
 
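The loop fix pairs each output index with its size directly via zip instead of indexing out_shape by position, and the broadcasting ones tensor is now kept out of the autodiff graph. The underlying einsum trick, sketched in plain numpy under the same "ones for output-only indices" idea:

import numpy as np

# repeat-style broadcast of a length-3 vector to shape (4, 3):
# indices that appear only in the output ("i") are supplied by a
# ones tensor, so "i,j->ij" tiles the input along the new axis
t = np.array([1.0, 2.0, 3.0])
ones = np.ones(4)
repeated = np.einsum("i,j->ij", ones, t)
assert repeated.shape == (4, 3)
assert np.allclose(repeated[2], t)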
src/tricycle_v2/reduce.py (6 changes: 4 additions & 2 deletions)

@@ -1,7 +1,7 @@
 import numpy as np
 
-from tricycle_v2.ops import _parse_subscripts, einsum, to_tensor
-from tricycle_v2.tensor import Tensor
+from tricycle_v2.ops import _parse_subscripts, einsum
+from tricycle_v2.tensor import Tensor, to_tensor
 
 
 def radd(tensor: Tensor, subscript: str):

@@ -53,6 +53,7 @@ def rmax(tensor: Tensor, subscript: str):
     indicator = (
         tensor == np.max(tensor, axis=tuple(reduce_along_axes), keepdims=True)
     ).astype(int)
+    indicator = to_tensor(indicator, requires_grad=False)
 
     new_subscript = f"{idx},{idx}->{output}"
 

@@ -79,6 +80,7 @@ def rmin(tensor: Tensor, subscript: str):
     indicator = (
         tensor == np.min(tensor, axis=tuple(reduce_along_axes), keepdims=True)
     ).astype(int)
+    indicator = to_tensor(indicator, requires_grad=False)
 
     new_subscript = f"{idx},{idx}->{output}"
 
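Wrapping indicator in to_tensor(..., requires_grad=False) makes the max/min mask an explicit constant in the graph, so gradient flows only through the positions where the extremum was attained. The mask itself, in plain numpy:

import numpy as np

x = np.array([[1.0, 5.0], [7.0, 2.0]])
# 1 where the row max is attained, 0 elsewhere; multiplying by this
# mask is what routes the gradient in rmax
indicator = (x == np.max(x, axis=1, keepdims=True)).astype(int)
assert (indicator == np.array([[0, 1], [1, 0]])).all()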
src/tricycle_v2/tensor.py (29 changes: 18 additions & 11 deletions)

@@ -1,7 +1,11 @@
+import logging
+from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 Op = Callable[..., "Tensor"]
 
 
@@ -22,6 +26,7 @@ class Tensor(np.ndarray):
     def backward(self):
         stack: List[Tuple[Tensor, List[Op]]] = [(self, [])]
         leaves: Dict[int, Tensor] = {}
+        adjecency_matrix = defaultdict(list)
 
         # Find every route to a differentiable parameter
         while stack:

@@ -38,11 +43,13 @@ def backward(self):
 
             else:
                 for arg, op in zip(current_node.args, current_node.back_fn):
+                    logger.info(f"{hash(current_node)=} {hash(arg)=} {op=}")
                     if not arg.requires_grad:
                         continue
 
                     new_gradient = current_gradient + [op]
                     stack.append((arg, new_gradient))
+                    adjecency_matrix[(hash(current_node))].append(hash(arg))
 
         # calculate the gradient for each parameter
         for leaf in leaves.values():

@@ -62,7 +69,7 @@ def __hash__(self) -> int:
         return id(self)
 
     def __add__(self, other):
-        if isinstance(other, np.ndarray):
+        if isinstance(other, np.ndarray) and not isinstance(other, Tensor):
             other = to_tensor(other)
         if np.isscalar(other):
             from tricycle_v2.unary import uadd

@@ -76,11 +83,10 @@ def __add__(self, other):
             raise NotImplementedError(f"Cannot add {type(self)} and {type(other)}")
 
     def __iadd__(self, other):
-        self = self + other
-        return self
+        return self + other
 
     def __sub__(self, other):
-        if isinstance(other, np.ndarray):
+        if isinstance(other, np.ndarray) and not isinstance(other, Tensor):
             other = to_tensor(other)
         if np.isscalar(other):
             from tricycle_v2.unary import usub

@@ -95,11 +101,10 @@ def __sub__(self, other):
             raise NotImplementedError(f"Cannot sub {type(self)} and {type(other)}")
 
     def __isub__(self, other):
-        self = self - other
-        return self
+        return self - other
 
     def __mul__(self, other):
-        if isinstance(other, np.ndarray):
+        if isinstance(other, np.ndarray) and not isinstance(other, Tensor):
             other = to_tensor(other)
         if np.isscalar(other):
             from tricycle_v2.unary import umul

@@ -115,11 +120,10 @@ def __mul__(self, other):
             raise NotImplementedError(f"Cannot mul {type(self)} and {type(other)}")
 
     def __imul__(self, other):
-        self = self * other
-        return self
+        return self * other
 
     def __truediv__(self, other):
-        if isinstance(other, np.ndarray):
+        if isinstance(other, np.ndarray) and not isinstance(other, Tensor):
             other = to_tensor(other)
         if np.isscalar(other):
             from tricycle_v2.unary import udiv

@@ -133,14 +137,17 @@ def __truediv__(self, other):
         else:
             raise NotImplementedError(f"Cannot divide {type(self)} and {type(other)}")
 
+    def __itruediv__(self, other):
+        return self / other
+
     def __floordiv__(self, _):
         raise NotImplementedError("Cannot floor divide")
 
     def __mod__(self, _):
         raise NotImplementedError("Cannot mod")
 
     def __pow__(self, other):
-        if isinstance(other, np.ndarray):
+        if isinstance(other, np.ndarray) and not isinstance(other, Tensor):
             other = to_tensor(other)
         if np.isscalar(other):
             from tricycle_v2.unary import upow
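The repeated isinstance change addresses the same pitfall in every operator: Tensor subclasses np.ndarray, so the old check also matched values that were already Tensors and re-wrapped them, discarding their graph metadata. A minimal sketch of the guard, using a stand-in subclass:

import numpy as np

class Tensor(np.ndarray):
    pass  # stand-in for tricycle_v2's Tensor

t = np.zeros(3).view(Tensor)
assert isinstance(t, np.ndarray)  # a Tensor is also an ndarray

# the corrected guard only converts plain arrays
needs_wrapping = isinstance(t, np.ndarray) and not isinstance(t, Tensor)
assert not needs_wrapping

The in-place operators were also simplified: rebinding self before returning it is equivalent to returning self + other directly, since Python uses the returned value to rebind the augmented-assignment target.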
src/tricycle_v2/unary.py (4 changes: 2 additions & 2 deletions)

@@ -66,7 +66,7 @@ def upow(tensor: Tensor, constant: float) -> Tensor:
     result = to_tensor(np.power(tensor, constant))
     result.args = (tensor,)
 
-    coeff = to_tensor(np.power(tensor, constant - 1))
+    coeff = to_tensor(np.power(tensor, constant - 1), requires_grad=False)
     coeff = umul(coeff, constant)
 
     assert coeff.shape == tensor.shape

@@ -101,7 +101,7 @@ def umax(tensor: Tensor, constant: float) -> Tensor:
 
     result = to_tensor(np.maximum(tensor, constant))
 
-    indicator = to_tensor((tensor > constant).astype(float))
+    indicator = to_tensor((tensor > constant).astype(float), requires_grad=False)
     indices = ascii_letters[: len(tensor.shape)]
     subscripts = f"{indices},{indices}->{indices}"
 
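Both changed lines mark backward-pass constants as non-differentiable. For upow the coefficient is the analytic derivative d/dx x**c = c * x**(c - 1); a quick numeric check in plain numpy:

import numpy as np

x = np.array([2.0, 3.0])
c = 3.0
# the multiplier upow applies to incoming gradients
coeff = c * np.power(x, c - 1)
assert np.allclose(coeff, [12.0, 27.0])  # 3 * x**2 at x = 2 and 3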
tests/test_loss.py (40 changes: 27 additions & 13 deletions)

@@ -1,8 +1,10 @@
 import numpy as np
+from matplotlib import pyplot as plt
 
 from tricycle_v2.loss import mean_squared_error
+from tricycle_v2.ops import repeat
+from tricycle_v2.reduce import radd
 from tricycle_v2.tensor import to_tensor
-from tricycle_v2.ops import repeat
 
 
 def test_can_mean_square_error():

@@ -18,24 +20,36 @@ def test_can_mean_square_error():
 def test_can_linear_regression():
     np.random.seed(42)
 
-    x = np.linspace(-10, 10, 201)
-    y = x * 2 + 1 + np.random.normal(loc=0, scale=0.01, size=201)
+    n = 10
+    learning_rate = 1e-2
+    x = np.linspace(-10, 10, n)
+    y = x * 2 + 1 + np.random.normal(loc=0, scale=0.01, size=n)
 
-    x = to_tensor(x.reshape(-1, 1))
-    y = to_tensor(y)
+    x = to_tensor(x.reshape(-1, 1), requires_grad=False, name="x")
+    y = to_tensor(y.reshape(-1, 1), requires_grad=False, name="y")
 
-    slope = to_tensor([0.01])
-    intercept = to_tensor(0.01)
+    slope = to_tensor([0.01], name="slope")
+    intercept = to_tensor([0.01], name="intercept")
 
     losses = []
     for _ in range(100):
-        repeated_slope = repeat("i->ji", slope, (x.shape[0],))
-        repeated_intercept = repeat("i->ji", intercept, (x.shape[0],))
+        repeated_slope = repeat("j->ij", slope, x.shape)
+        repeated_intercept = repeat("j->ij", intercept, x.shape)
 
         y_pred = x * repeated_slope + repeated_intercept
-        loss = mean_squared_error(y, y_pred)
+        mse = mean_squared_error(y, y_pred)
+        loss = radd(mse, "i->") / y.shape[0]
 
         losses.append(loss)
 
         loss.backward()
+        breakpoint()
 
-        slope -= slope.grad
-        intercept -= intercept.grad
+        slope = to_tensor(slope - slope.grad * learning_rate, name="slope")
+        intercept = to_tensor(
+            intercept - intercept.grad * learning_rate, name="intercept"
+        )
 
+    _, ax = plt.subplots()
+    ax.plot(losses)
+    ax.set_yscale("log")
+    plt.show()
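The reworked test is plain gradient descent on y = 2x + 1, with the parameters rebuilt via to_tensor each step so stale gradients are dropped. The same update written in raw numpy with the analytic MSE gradients (a sketch of the math, not the tricycle_v2 code path):

import numpy as np

n, learning_rate = 10, 1e-2
x = np.linspace(-10, 10, n).reshape(-1, 1)
y = x * 2 + 1

slope, intercept = np.array([0.01]), np.array([0.01])
for _ in range(100):
    error = x * slope + intercept - y
    grad_slope = 2 * (error * x).mean()  # d(mse)/d(slope)
    grad_intercept = 2 * error.mean()    # d(mse)/d(intercept)
    slope = slope - learning_rate * grad_slope
    intercept = intercept - learning_rate * grad_intercept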
