# Tiny grad: minimal reverse-mode automatic differentiation on scalars

In [96]:
from functools import reduce
import operator

class BaseGrad:
    """Common base class for nodes of the computation graph.

    Every node carries a ``grad`` accumulator that ``backward`` adds into.
    Nodes compose with ``+``, ``*`` and unary ``-`` into ``Expression`` nodes.
    Plain ``int``/``float`` operands are wrapped in ``Value`` leaves.
    """

    def __init__(self):
        self.grad = 0

    @staticmethod
    def _as_node(b):
        """Return *b* as a graph node, or None if it is not a supported operand.

        Nodes pass through unchanged; int/float (bool included, as a subclass
        of int) are wrapped in a ``Value`` leaf.
        """
        if isinstance(b, BaseGrad):
            return b
        if isinstance(b, (int, float)):
            return Value(b)
        return None

    def __add__(self, b):
        """Return the node ``self + b``; *b* may be a node or an int/float."""
        other = self._as_node(b)
        if other is None:
            # Let Python try b.__radd__ and then raise TypeError.
            return NotImplemented
        return Expression('+', self, other)

    def __radd__(self, b):
        """Support ``number + node`` (and therefore ``sum()`` over nodes)."""
        other = self._as_node(b)
        if other is None:
            return NotImplemented
        # Addition is commutative; self first mirrors __rmul__ below.
        return Expression('+', self, other)

    def __mul__(self, b):
        """Return the node ``self * b``; *b* may be a node or an int/float."""
        other = self._as_node(b)
        if other is None:
            return NotImplemented
        return Expression('*', self, other)

    def __neg__(self):
        """Return the unary negation node ``-self``."""
        return Expression('-', self)

    def __rmul__(self, b):
        """Support ``number * node``; the node stays the first operand."""
        other = self._as_node(b)
        if other is None:
            return NotImplemented
        return Expression('*', self, other)

    def zero_grad(self):
        """Reset this node's own gradient accumulator (not its operands')."""
        self.grad = 0

class Value(BaseGrad):
    """Leaf of the computation graph: wraps the constant ``v`` and collects its gradient."""

    def __init__(self, v):
        super(Value, self).__init__()
        self.v = v

    def __repr__(self):
        """Render as ``Value(<v>)``."""
        return f'Value({self.v})'

    @property
    def data(self):
        """The wrapped constant (a read-only alias of ``v``)."""
        return self.v

    def backward(self, dydo=1):
        """Add the upstream gradient *dydo* into ``grad``.

        For a leaf o = x the local derivative do/dx is 1. The chain rule
        dy/dx = dy/do * do/dx therefore contributes ``dydo * 1``.
        """
        local_derivative = 1
        self.grad += dydo * local_derivative

class Expression(BaseGrad):
    """Interior node of the computation graph: an operator applied to operand nodes.

    Supported operators (see ``data`` and ``backward``):
      * ``'+'`` -- sum of all operands,
      * ``'*'`` -- product of all operands,
      * ``'-'`` -- unary negation (exactly one operand).
    """

    def __init__(self, operator, *operands):
        # NOTE: this parameter shadows the ``operator`` module only inside
        # __init__; the other methods still see the module.
        super().__init__()
        self.operator = operator
        self.operands = operands

    def __repr__(self):
        """Return the expression in reverse Polish notation, e.g. ``Value(1) Value(2) +``."""
        return f'{" ".join(map(str, self.operands))} {self.operator}'

    def _unsupported(self):
        """Build the error raised for an operator/operand-count combination we cannot handle."""
        return NotImplementedError(
            f'unsupported expression: operator {self.operator!r} '
            f'with {len(self.operands)} operand(s)'
        )

    @property
    def data(self):
        """Evaluate the expression by recursively evaluating every operand.

        Raises:
            NotImplementedError: for an unsupported operator / operand count.
        """
        if self.operator == '+':
            return reduce(operator.add, [o.data for o in self.operands], 0)
        elif self.operator == '*':
            return reduce(operator.mul, [o.data for o in self.operands], 1)
        elif self.operator == '-' and len(self.operands) == 1:
            return -self.operands[0].data
        raise self._unsupported()

    def backward(self, dydo=1):
        """Accumulate the upstream gradient and propagate it to the operands.

        Args:
            dydo: dy/d(self), the gradient of the final output with respect
                to this node (1 when called on the output itself).

        Gradients are added with ``+=``. A node reached through several paths
        therefore receives the sum of all contributions. Reset ``grad`` (e.g.
        with ``zero_grad``) on each node before running a fresh pass.

        Raises:
            NotImplementedError: for an unsupported operator / operand count.
        """
        # Chain rule: dy/dx = dy/do * do/dx for every operand x of this node o.
        self.grad += dydo
        if self.operator == '+':
            # d(a + b + ...)/da = 1
            for o in self.operands:
                o.backward(dydo)
        elif self.operator == '*':
            # d(a * b * ...)/da = product of all *other* operands. It is
            # computed explicitly rather than as self.data / a.data, which
            # divides by zero whenever an operand is 0. Operand values are
            # evaluated once instead of once per operand.
            values = [o.data for o in self.operands]
            for i, o in enumerate(self.operands):
                dodx = reduce(operator.mul, values[:i] + values[i + 1:], 1)
                o.backward(dydo * dodx)
        elif self.operator == '-' and len(self.operands) == 1:
            # d(-a)/da = -1
            self.operands[0].backward(-dydo)
        else:
            raise self._unsupported()

# Demo: c = (-a) * b with a = 4 and b = 2 (unary minus binds tighter than *).
# Expected gradients: dc/da = -b = -2.0 and dc/db = -a = -4.0.
a = Value(4.0)
b = Value(2.0)
# Alternative expressions to try in place of the line below:
#   c = a + a + a + b   -> dc/da = 3, dc/db = 1
#   c = a * 3           -> dc/da = 3 (the int 3 is wrapped in a Value leaf)
# c = a + a + a + b
# c = a * 3
c = -a * b
# TODO: support binary subtraction (__sub__) and division (__truediv__);
# Python has no __minus__ dunder.
c.backward()
# Bare tuple as the last line of a notebook cell so Jupyter displays it;
# it has no effect when this code is run as a plain script.
a.grad, b.grad, c.grad, c

(-2.0, -4.0, 1, Value(4.0) - Value(2.0) *)