In [3]:
import numpy as np

In [23]:
class Scalar:
    """A scalar node in a reverse-mode automatic-differentiation graph.

    Every arithmetic operation on ``Scalar`` objects returns a new ``Scalar``
    that records its operands (``_prev_scalars``), a short label for the
    operation (``_op``) and a closure (``_backward``) which pushes this node's
    gradient onto its operands via the chain rule. Calling ``backward()`` on a
    result fills in ``.grad`` on every node it depends on.

    Plain ``int``/``float`` operands are wrapped in ``Scalar`` automatically,
    so expressions such as ``2*a + 3`` work from either side.
    """

    def __init__(self, data, prev_scalars=(), op=''):
        """Create a graph node.

        Args:
            data: the numeric value held by this node.
            prev_scalars: operand nodes this node was computed from
                (empty for leaf nodes).
            op: human-readable label of the producing operation, used by ``repr``.
        """
        self.data = data
        # Stored as a set: a repeated operand (e.g. ``a * a``) appears once
        # here, but its gradient is still accumulated twice by ``_backward``.
        self._prev_scalars = set(prev_scalars)

        self._op = op
        # Accumulated d(output)/d(self); populated by ``backward()``.
        self.grad = 0
        # Leaf nodes have no operands to propagate into.
        self._backward = lambda: None

    def backward(self):
        """Backpropagate from this node, accumulating ``.grad`` on all ancestors.

        Seeds ``self.grad`` with 1, then runs each node's local ``_backward``
        in reverse topological order so a node's gradient is complete before
        it is pushed to its operands. Gradients are accumulated with ``+=``,
        so calling this twice on the same graph adds the results together.
        """
        # Iterative post-order DFS. Unlike a recursive walk it cannot hit
        # Python's recursion limit on deep graphs (e.g. ``sum`` of thousands
        # of Scalars), and the ``visited`` set gives O(1) membership checks
        # instead of scanning the output list.
        topo_order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            scalar, operands_done = stack.pop()
            if operands_done:
                # Every operand has already been emitted, so emit this node.
                topo_order.append(scalar)
                continue
            if scalar in visited:
                continue
            visited.add(scalar)
            # Re-push this node so it is emitted after all its operands.
            stack.append((scalar, True))
            for prev_scalar in scalar._prev_scalars:
                if prev_scalar not in visited:
                    stack.append((prev_scalar, False))

        self.grad = 1
        for scalar in reversed(topo_order):
            scalar._backward()

    def __repr__(self):
        return f'Scalar[{self.data}, {self._op}]'

    def __add__(self, other):
        """Return ``self + other``; d/dself = d/dother = 1."""
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data + other.data, (self, other), '<Add>')

        def backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = backward

        return out

    def __mul__(self, other):
        """Return ``self * other``; each operand's gradient is scaled by the other's value."""
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data * other.data, (self, other), '<Mul>')

        def backward():
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out._backward = backward

        return out

    def __pow__(self, other):
        """Return ``self ** other`` for a constant ``int``/``float`` exponent.

        ``Scalar`` exponents are not supported (the assertion below rejects them).
        """
        assert isinstance(other, (int, float)), "only supporting int/float powers for now"
        out = Scalar(self.data ** other, (self,), f'<Pow{other}>')

        def backward():
            # d(x**n)/dx = n * x**(n-1). For n == 0 the derivative is exactly
            # 0; skipping it avoids evaluating 0**-1 (ZeroDivisionError) at x == 0.
            if other != 0:
                self.grad += out.grad * other * self.data**(other-1)

        out._backward = backward

        return out

    def __neg__(self):
        """Return ``-self`` (implemented as multiplication by -1)."""
        return self * -1

    def __radd__(self, other):
        """Support ``number + Scalar``."""
        return self + other

    def __sub__(self, other):
        """Return ``self - other`` as ``self + (-other)``."""
        return self + (-other)

    def __rsub__(self, other):
        """Support ``number - Scalar``."""
        return other + (-self)

    def __rmul__(self, other):
        """Support ``number * Scalar``."""
        return self * other

    def __truediv__(self, other):
        """Return ``self / other`` as ``self * other**-1``."""
        return self * other**-1

    def __rtruediv__(self, other):
        """Support ``number / Scalar`` as ``other * self**-1``."""
        return other * self**-1

In [25]:
# Worked example: z = 2x + 3y with x = 2a + 3b and y = 5a^2 + 3b^3.
a = Scalar(2)
b = Scalar(3)

x = 2*a + 3*b
y = 5*(a**2) + 3*(b**3)

z = 2*x + 3*y
print(z)

z.backward()

print(a.grad, b.grad)

# Validate autodiff against the hand-derived expressions:
#   dz/da = 2*2 + 3*(10*a)    = 4 + 30*a
#   dz/db = 2*3 + 3*(9*b**2)  = 6 + 27*b**2
assert z.data == 2*(2*a.data + 3*b.data) + 3*(5*a.data**2 + 3*b.data**3)
assert a.grad == 4 + 30*a.data
assert b.grad == 6 + 27*b.data**2

Scalar[329, <Add>]
64 249
