In [16]:
from __future__ import annotations
from typing import Optional

class Tensor:
    """A scalar value that remembers how it was computed, for reverse-mode autodiff.

    Tensors produced by ``_add``, ``_sub`` and ``_mul`` (directly or via the ``+``, ``-``
    and ``*`` operators) record their inputs in ``args`` and the partial derivative of
    their value w.r.t. each input in ``local_derivatives``. ``backward()`` walks that
    graph and stores the resulting gradients in each node's ``derivative``.
    """

    def __init__(self, value: float, name: Optional[str] = None):
        self.value = float(value)

        # Inputs this Tensor was computed from, forming the computation graph.
        # None for leaf Tensors (constructed directly rather than by an operation).
        self.args: Optional[tuple[Tensor, ...]] = None

        # Partial derivatives of this Tensor's value w.r.t. each entry of ``args``, in the
        # same order; None for leaves. backward() reads only their ``.value``.
        self.local_derivatives: Optional[tuple[Tensor, ...]] = None

        # Gradient w.r.t. this Tensor, filled in by backward() as a plain float (not a
        # Tensor); None until backward() reaches this node, or after clear_derivatives().
        self.derivative: Optional[float] = None

        # Optional human-readable label, shown by __repr__.
        self.name = name

    def build_topo(self, order: list[Tensor], visited: set[Tensor]) -> None:
        """Append this Tensor's graph to ``order`` in topological order (inputs first).

        Args:
            order: list that receives every not-yet-visited node reachable from self;
                each node is appended only after all of its ``args``.
            visited: nodes already placed; they are skipped. Updated in place.

        An explicit stack is used instead of recursion so that deep graphs (for example
        a long chain of additions) cannot exceed Python's recursion limit. The resulting
        order is the same depth-first post-order a recursive walk would produce.
        """
        if self in visited:
            return
        visited.add(self)
        # Each entry pairs a node with an iterator over the inputs still to explore.
        stack = [(self, iter(self.args or ()))]
        while stack:
            node, pending = stack[-1]
            for arg in pending:
                if arg not in visited:
                    visited.add(arg)
                    stack.append((arg, iter(arg.args or ())))
                    break
            else:
                # Every input of `node` has been placed, so `node` may follow them.
                stack.pop()
                order.append(node)

    def backward(self) -> None:
        """Backpropagate from this Tensor, filling in ``derivative`` across its graph.

        Seeds d(self)/d(self) = 1.0, then visits nodes in reverse topological order and
        adds ``node.derivative * local_derivative.value`` into each input's ``derivative``
        (chain rule). A node used by several parents receives the sum of their
        contributions.

        Intermediate (non-leaf) nodes are reset before propagating. Without this, values
        left by an earlier backward() would be re-propagated and double-count into the
        leaves. Leaf Tensors keep accumulating across calls; call clear_derivatives()
        first to get fresh gradients.
        """
        order: list[Tensor] = []
        visited: set[Tensor] = set()
        self.build_topo(order, visited)

        for node in order:
            if node.args is not None:
                node.derivative = None
        self.derivative = 1.0

        for node in reversed(order):
            if node.args is None:
                continue  # leaf: nothing to propagate further
            for arg, local_derivative in zip(node.args, node.local_derivatives):
                contrib = node.derivative * local_derivative.value
                if arg.derivative is None:
                    arg.derivative = contrib
                else:
                    arg.derivative += contrib

    def clear_derivatives(self) -> None:
        """Reset ``derivative`` to None on every node reachable from this Tensor."""
        order: list[Tensor] = []
        visited: set[Tensor] = set()
        self.build_topo(order, visited)
        for node in order:
            node.derivative = None

    def __repr__(self) -> str:
        return f"Tensor(name={self.name}, value={self.value})"

    @staticmethod
    def _coerce(x: object) -> Optional[Tensor]:
        """Return ``x`` as a Tensor (wrapping int/float constants), or None if unsupported."""
        if isinstance(x, Tensor):
            return x
        if isinstance(x, (int, float)):
            return Tensor(x)
        return None

    # Arithmetic operators. Plain int/float operands are wrapped in a leaf Tensor;
    # any other operand type yields NotImplemented so Python can try the other side.

    def __add__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _add(self, other)

    def __radd__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _add(other, self)

    def __sub__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _sub(self, other)

    def __rsub__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _sub(other, self)

    def __mul__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _mul(self, other)

    def __rmul__(self, x: Tensor | int | float):
        other = Tensor._coerce(x)
        return NotImplemented if other is None else _mul(other, self)


def _add(a: Tensor, b: Tensor):
    """Create the sum node a + b.

    Both partial derivatives of a + b are 1, so the node stores two constant Tensors.
    """
    out = Tensor(a.value + b.value)
    out.args, out.local_derivatives = (a, b), (Tensor(1), Tensor(1))
    return out

def _sub(a: Tensor, b: Tensor):
    """Create the difference node a - b.

    d(a - b)/da = 1 and d(a - b)/db = -1, stored as constant Tensors.
    """
    out = Tensor(a.value - b.value)
    out.args, out.local_derivatives = (a, b), (Tensor(1), Tensor(-1))
    return out

def _mul(a: Tensor, b: Tensor):
    """Create the product node a * b.

    By the product rule d(ab)/da = b and d(ab)/db = a, so the input Tensors
    themselves (in swapped order) serve as the local derivatives.
    """
    out = Tensor(a.value * b.value)
    out.args, out.local_derivatives = (a, b), (b, a)
    return out

def test(want: any, got: any):
    indicator = "✅" if want == got else "❌" 
    print(f"{indicator}: want {want}, got {got}")

    

In [2]:
# Forward-only sanity checks of the primitive ops; pass want/got by keyword so the
# printed line reads the right way round.
x = Tensor(5)
y = Tensor(6)
test(want=11.0, got=_add(x, y).value)
test(want=-1.0, got=_sub(x, y).value)
test(want=30.0, got=_mul(x, y).value)

✅: want 11.0, got 11
✅: want -1.0, got -1
✅: want 30.0, got 30


In [3]:
# Setup shared by the next cells: a single multiplication node, output = a * b.
# Those cells check that
#   - the forward pass records the graph (args) and the local derivatives, and
#   - the backward pass fills in the global gradients d(output)/da and d(output)/db.
a, b = Tensor(3), Tensor(4)
output = _mul(a, b)

In [4]:
# Forward-pass checks: the graph edges and local derivatives recorded by _mul.
test(want=(a, b), got=output.args)
test(want=(b, a), got=output.local_derivatives)

✅: want (Tensor(name=None, value=3.0), Tensor(name=None, value=4.0)), got (Tensor(name=None, value=3.0), Tensor(name=None, value=4.0))
✅: want (Tensor(name=None, value=4.0), Tensor(name=None, value=3.0)), got (Tensor(name=None, value=4.0), Tensor(name=None, value=3.0))


In [5]:
# Backward pass: reset any stale gradients, then propagate from `output`.
# For output = a * b we expect d(output)/da = b and d(output)/db = a.
# Derivatives are stored as plain floats, so they are compared with the inputs' .value.
output.clear_derivatives()
output.backward()
test(got=a.derivative, want=b.value)
test(got=b.derivative, want=a.value)

✅: want 4.0, got 4.0
✅: want 3.0, got 3.0


In [6]:
# Forward and backward pass on a fresh graph, then check both the recorded graph and
# the gradients (for output = a * b: d(output)/da = b, d(output)/db = a).
a, b = Tensor(3), Tensor(4)
output = _mul(a, b)
output.clear_derivatives()
output.backward()

test(want=(a, b), got=output.args)
test(want=(b, a), got=output.local_derivatives)
test(want=b.value, got=a.derivative)
test(want=a.value, got=b.derivative)

✅: want (Tensor(name=None, value=3.0), Tensor(name=None, value=4.0)), got (Tensor(name=None, value=3.0), Tensor(name=None, value=4.0))
✅: want (Tensor(name=None, value=4.0), Tensor(name=None, value=3.0)), got (Tensor(name=None, value=4.0), Tensor(name=None, value=3.0))
✅: want 4.0, got 4.0
✅: want 3.0, got 3.0


In [7]:
f"{output.derivative!r}"  # backward() seeds the root with d(output)/d(output) = 1.0

'1.0'

In [8]:
# Topological sort check for Z = (X + C) * X: every node must be listed after its
# inputs, and X (used twice) must appear only once.
X = Tensor(2, name="X")
C = Tensor(3, name="C")
Z = _mul(_add(X, C), X)

order: list["Tensor"] = []
visited: set["Tensor"] = set()
Z.build_topo(order, visited)
print(order)

[Tensor(name=X, value=2.0), Tensor(name=C, value=3.0), Tensor(name=None, value=5.0), Tensor(name=None, value=10.0)]


In [9]:
# Chained Tensor additions associate left to right: ((C + X) + C) + X.
X, C = Tensor(2, name="X"), Tensor(3, name="C")
ZZ = ((C + X) + C) + X
print(ZZ)

Tensor(name=None, value=10.0)


In [10]:
# Addition with Tensor operands both ways, then with a plain int on the
# right (__add__) and on the left (__radd__).
X, C = Tensor(5, name="X"), Tensor(3, name="C")
Z1, Z2 = X + C, C + X
Z3, Z4 = X + 2, 2 + X
print(Z1, Z2, Z3, Z4)

Tensor(name=None, value=8.0) Tensor(name=None, value=8.0) Tensor(name=None, value=7.0) Tensor(name=None, value=7.0)


In [11]:
# Subtraction with Tensor operands both ways, then with a plain int on the
# right (__sub__) and on the left (__rsub__).
X, C = Tensor(5, name="X"), Tensor(3, name="C")
Z1, Z2 = X - C, C - X
Z3, Z4 = X - 2, 2 - X
print(Z1, Z2, Z3, Z4)

Tensor(name=None, value=2.0) Tensor(name=None, value=-2.0) Tensor(name=None, value=3.0) Tensor(name=None, value=-3.0)


In [12]:
# Multiplication with Tensor operands both ways, then with a plain int on the
# right (__mul__) and on the left (__rmul__).
X, C = Tensor(5, name="X"), Tensor(3, name="C")
Z1, Z2 = X * C, C * X
Z3, Z4 = X * 2, 2 * X
print(Z1, Z2, Z3, Z4)

Tensor(name=None, value=15.0) Tensor(name=None, value=15.0) Tensor(name=None, value=10.0) Tensor(name=None, value=10.0)


In [13]:
# Gradient of an expression built with operator overloading:
#   Z = 5 + X * (X + C) = X^2 + C*X + 5, evaluated at X = 5, C = 3
#   Z = 45,  dZ/dX = 2X + C = 13,  dZ/dC = X = 5
X, C = Tensor(5, name="X"), Tensor(3, name="C")
Z = 5 + X * (X + C)
Z.backward()  # X and C were just created, so there are no stale derivatives to clear

print("Z value")
test(want=45.0, got=Z.value)
print("dZ/dX")
test(want=13.0, got=X.derivative)
print("dZ/dC")
test(want=5.0, got=C.derivative)


Z value
✅: want 45.0, got 45.0
dZ/dX
✅: want 13.0, got 13.0
dZ/dC
✅: want 5.0, got 5.0


In [14]:
# 1) y = x, expressed as x * 1 so that backward() has an operation to go through.
#    dy/dx = 1
x = Tensor(5, name="x")
y = x * Tensor(1)
y.clear_derivatives()
y.backward()
print("dy/dx")
test(want=1.0, got=x.derivative)

# 2) y = x * x = x^2, with x used twice so both contributions must be summed.
#    dy/dx = 2x = 6 at x = 3
x = Tensor(3, name="x")
y = x * x
y.clear_derivatives()
y.backward()
print("dy/dx")
test(want=6.0, got=x.derivative)

# 3) y = (x + c) * x = x^2 + c*x
#    dy/dx = 2x + c = 13 and dy/dc = x = 5 at x = 5, c = 3
x, c = Tensor(5, name="x"), Tensor(3, name="c")
y = (x + c) * x
y.clear_derivatives()
y.backward()
print("dy/dx")
test(want=13.0, got=x.derivative)
print("dy/dc")
test(want=5.0, got=c.derivative)


dy/dx
✅: want 1.0, got 1.0
dy/dx
✅: want 6.0, got 6.0
dy/dx
✅: want 13.0, got 13.0
dy/dc
✅: want 5.0, got 5.0


In [15]:
# A leaf Tensor is its own root: backward() seeds d(x)/d(x) = 1.0.
# Clear x's own graph (not the previous cell's `y`) so this cell does not depend on
# state left over from earlier cells.
x = Tensor(5, name="x")
x.clear_derivatives()
x.backward()
test(want=1.0, got=x.derivative)


✅: want 1.0, got 1.0
