# Tiny grad

In [1]:
from functools import reduce
import operator
import math

class BaseGrad:
    """Base class for nodes of the scalar expression graph.

    Provides the arithmetic operators that build ``Expression`` nodes and the
    ``grad`` accumulator that ``backward`` implementations add into.  Operands
    that are not graph nodes are wrapped in ``Value`` before being attached.
    """

    def __init__(self):
        self.zero_grad()

    @staticmethod
    def _wrap(x):
        """Return *x* as a graph node, wrapping non-node operands in ``Value``."""
        return x if isinstance(x, BaseGrad) else Value(x)

    def __add__(self, b):
        return Expression('+', self, self._wrap(b))

    # Addition is commutative, so the reflected form reuses __add__
    # (``3 + x`` is built as ``x + 3``).
    __radd__ = __add__

    def __pow__(self, b):
        return Expression('^', self, self._wrap(b))

    def __rpow__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the base.
        return Expression('^', self._wrap(b), self)

    def __neg__(self):
        return Expression('-', self)

    def __truediv__(self, b):
        return Expression('/', self, self._wrap(b))

    def __rtruediv__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the numerator.
        return Expression('/', self._wrap(b), self)

    def __sub__(self, b):
        return Expression('-', self, self._wrap(b))

    def __rsub__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the minuend.
        return Expression('-', self._wrap(b), self)

    def __mul__(self, b):
        return Expression('*', self, self._wrap(b))

    # Multiplication is commutative, so the reflected form reuses __mul__.
    __rmul__ = __mul__

    def zero_grad(self):
        """Reset the accumulated gradient of this node to zero."""
        self.grad = 0

class Value(BaseGrad):
    """Leaf node of the graph that stores a number in ``v``."""

    def __init__(self, v):
        super().__init__()
        self.v = v

    def __repr__(self):
        return 'Value({})'.format(self.v)

    @property
    def data(self):
        """The stored number."""
        return self.v

    def backward(self, dydo=1) -> BaseGrad:
        """Accumulate the upstream gradient ``dydo`` into ``grad``.

        Nothing is returned; the result is the updated ``grad`` attribute.
        """
        # Chain rule: dy/dx = dy/do * do/dx, and do/dx is 1 for a leaf.
        local_derivative = 1
        self.grad = self.grad + dydo * local_derivative

class Expression(BaseGrad):
    """Interior graph node: an operator applied to operand nodes.

    ``operator`` is one of ``'+'``, ``'*'``, ``'-'``, ``'/'`` or ``'^'``.
    ``'-'`` with a single operand is negation; with several operands it
    subtracts the rest from the first.  ``'/'`` divides the first operand by
    each of the rest.  ``'^'`` raises the first operand to the second.
    """

    def __init__(self, operator, *operands):
        super().__init__()
        self.operator = operator
        self.operands = operands

    def __repr__(self):
        """Return the expression in reverse Polish notation."""
        return f'{" ".join(map(str, self.operands))} {self.operator}'

    @property
    def data(self):
        """Numeric value of the expression, recomputed from the operands on each access.

        Raises:
            NotImplementedError: if the operator is not supported.
        """
        if self.operator == '+':
            return reduce(operator.add, [o.data for o in self.operands], 0)
        elif self.operator == '*':
            return reduce(operator.mul, [o.data for o in self.operands], 1)
        elif self.operator == '-' and len(self.operands) == 1:
            return -self.operands[0].data
        elif self.operator == '-' and len(self.operands) > 1:
            return reduce(operator.add,
                          [-o.data for o in self.operands[1:]],
                          self.operands[0].data)
        elif self.operator == '/':
            return reduce(operator.mul,
                          [1 / o.data for o in self.operands[1:]],
                          self.operands[0].data)
        elif self.operator == '^':
            return self.operands[0].data ** self.operands[1].data

        raise NotImplementedError()

    def backward(self, dydo=1) -> None:
        """Propagate the upstream gradient ``dydo`` to every operand.

        ``dydo`` is d(output)/d(this node).  It is accumulated into this
        node's ``grad`` and, multiplied by the local derivative with respect
        to each operand (chain rule), passed on to that operand's ``backward``.

        Raises:
            NotImplementedError: if the operator has no backward rule.
            ZeroDivisionError: for ``'/'`` when a denominator evaluates to 0.
        """
        # The node's own gradient is the upstream gradient, not a constant 1.
        self.grad += dydo
        op = self.operator
        operands = self.operands
        if op == '+':
            # d(a + b)/da = 1
            for o in operands:
                o.backward(dydo)
        elif op == '*':
            # d(a * b * c)/da = b * c.  Multiply the *other* operands rather
            # than dividing the product by ``a``, which fails when a == 0.
            values = [o.data for o in operands]
            for i, o in enumerate(operands):
                others = values[:i] + values[i + 1:]
                o.backward(dydo * reduce(operator.mul, others, 1))
        elif op == '-' and len(operands) == 1:
            # d(-a)/da = -1
            operands[0].backward(-dydo)
        elif op == '-' and len(operands) > 1:
            # d(a - b - c)/da = 1; for b and c it is -1.
            operands[0].backward(dydo)
            for o in operands[1:]:
                o.backward(-dydo)
        elif op == '/':
            # d(a/b/c)/da = 1/(b*c); d(a/b/c)/db = -(a/b/c)/b, likewise for c.
            # The numerator rule avoids dividing by ``a`` so that a == 0 works.
            values = [o.data for o in operands]
            if values:
                denominator = reduce(operator.mul, values[1:], 1)
                quotient = values[0] / denominator
                operands[0].backward(dydo / denominator)
                for o, v in zip(operands[1:], values[1:]):
                    o.backward(dydo * (-quotient / v))
        elif op == '^':
            # d(a^b)/da = b * a^(b-1); d(a^b)/db = a^b * ln(a)
            base, exponent = operands
            x, n = base.data, exponent.data
            base.backward(dydo * n * x ** (n - 1))
            # ln(a) is only defined for a > 0; for other bases the exponent
            # receives no gradient instead of failing, so that constant
            # exponents (e.g. x ** 2 with negative x) still differentiate.
            if x > 0:
                exponent.backward(dydo * self.data * math.log(x))
        else:
            raise NotImplementedError(
                f'backward is not implemented for operator {self.operator!r} '
                f'with operands {self.operands!r}'
            )

# Demo: build a small graph from two leaves, run backward, inspect the results.
a = Value(4.0)
b = Value(2.0)
c = a + a  # rebound below; only the last graph assigned to c is differentiated
# c = a + a + a + b
# c = a * 3
# c = -a * b
c = a / b
# c = 3 * (3 + (a - b) * 2)

# backward() uses its default upstream gradient of 1.
c.backward()
# Last expression of the cell: gradients of a, b, c, then c's repr and value.
a.grad, b.grad, c.grad, c, c.data

(0.5, -1.0, 1, Value(4.0) Value(2.0) /, 2.0)

# Symbolic version

In [2]:
from functools import reduce
import operator
import math

class BaseGrad:
    """Base class for nodes of the symbolic expression graph.

    Provides the arithmetic operators that build ``Expression`` nodes and the
    ``grad`` accumulator, which is itself exposed as a graph node.  Operands
    that are not graph nodes are wrapped in ``Value`` before being attached.
    """

    def __init__(self):
        self.zero_grad()

    @staticmethod
    def _wrap(x):
        """Return *x* as a graph node, wrapping non-node operands in ``Value``."""
        return x if isinstance(x, BaseGrad) else Value(x)

    @property
    def grad(self):
        """Accumulated gradient as a graph node.

        A plain number (such as the 0 stored by ``zero_grad``) is wrapped in
        ``Value`` on read.
        """
        if isinstance(self._grad, (int, float)):
            return Value(self._grad)
        return self._grad

    @grad.setter
    def grad(self, n):
        """Store a new gradient.

        Raises:
            TypeError: if *n* is neither a graph node nor an int/float.
        """
        # Explicit check instead of ``assert`` so it still runs under ``-O``;
        # plain numbers are accepted because the getter wraps them anyway.
        if not isinstance(n, (BaseGrad, int, float)):
            raise TypeError(
                f'grad must be a BaseGrad or a number, got {type(n).__name__}'
            )
        self._grad = n

    def __add__(self, b):
        return Expression('+', self, self._wrap(b))

    # Addition is commutative, so the reflected form reuses __add__
    # (``3 + x`` is built as ``x + 3``).
    __radd__ = __add__

    def __pow__(self, b):
        return Expression('^', self, self._wrap(b))

    def __rpow__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the base.
        return Expression('^', self._wrap(b), self)

    def __neg__(self):
        return Expression('-', self)

    def __truediv__(self, b):
        return Expression('/', self, self._wrap(b))

    def __rtruediv__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the numerator.
        return Expression('/', self._wrap(b), self)

    def __sub__(self, b):
        return Expression('-', self, self._wrap(b))

    def __rsub__(self, b):
        # Reflected form: ``b`` is the left operand, i.e. the minuend.
        return Expression('-', self._wrap(b), self)

    def __mul__(self, b):
        return Expression('*', self, self._wrap(b))

    # Multiplication is commutative, so the reflected form reuses __mul__.
    __rmul__ = __mul__

    def zero_grad(self):
        """Reset the accumulated gradient of this node to zero."""
        self._grad = 0

class Value(BaseGrad):
    """Leaf node of the symbolic graph.

    ``name`` and ``v`` both start as the constructor argument; ``v`` may be
    reassigned later, in which case the repr shows ``name=v``.
    """

    def __init__(self, v):
        super().__init__()
        self.v = v
        self.name = v

    @property
    def has_numeric_value(self):
        """True when ``v`` is an int or a float."""
        return isinstance(self.v, (int, float))

    def __repr__(self):
        label = f'{self.name}={self.v}' if self.name != self.v else f'{self.name}'
        return f'Value({label})'

    def simplify(self) -> BaseGrad:
        """Return ``self``: a leaf has nothing to simplify."""
        return self

    @property
    def data(self) -> float:
        """The stored number; asserts that ``v`` is numeric."""
        assert self.has_numeric_value, f'.data not supported for Variable {self} without numeric .v attributes'
        return self.v

    def backward(self, dydo=1) -> BaseGrad:
        """Accumulate the upstream gradient ``dydo`` into ``grad``.

        Nothing is returned; the result is the updated ``grad`` attribute.
        """
        # Chain rule: dy/dx = dy/do * do/dx, and do/dx is 1 for a leaf.
        unit = Value(1)
        self.grad = self.grad + dydo * unit

class Expression(BaseGrad):
    """Interior node of the symbolic graph: an operator applied to operand nodes.

    ``operator`` is one of ``'+'``, ``'*'``, ``'-'``, ``'/'``, ``'^'`` or
    ``'ln'``.  ``'-'`` with a single operand is negation; with several
    operands it subtracts the rest from the first.  ``'/'`` divides the first
    operand by each of the rest.  ``'^'`` raises the first operand to the
    second.  ``'ln'`` is the natural logarithm of its single operand and is
    used by the backward rule of ``'^'``.
    """

    def __init__(self, operator, *operands):
        super().__init__()
        self.operator = operator
        self.operands = operands

    def __repr__(self):
        """Return the expression in reverse Polish notation."""
        return f'{" ".join(map(str, self.operands))} {self.operator}'

    @staticmethod
    def _product(factors):
        """Return a node for the product of *factors* (``Value(1)`` if empty)."""
        if not factors:
            return Value(1)
        if len(factors) == 1:
            return factors[0]
        return Expression('*', *factors)

    @property
    def has_numeric_value(self):
        """True when every operand reports a numeric value."""
        return all(o.has_numeric_value for o in self.operands)

    def simplify(self) -> BaseGrad:
        """Fold fully numeric sub-expressions into single ``Value`` nodes.

        Operands are simplified recursively.  If all of them end up numeric,
        the expression is evaluated and returned as a ``Value``; otherwise a
        new ``Expression`` over the simplified operands is returned.  ``self``
        is not modified.
        """
        # TODO: algebraic simplification (e.g. dropping ``* 1`` and ``+ 0``).
        new_operands = [o.simplify() for o in self.operands]
        new_expr = Expression(self.operator, *new_operands)
        if all(o.has_numeric_value for o in new_operands):
            return Value(new_expr.data)
        return new_expr

    @property
    def data(self) -> float:
        """Numeric value of the expression, recomputed from the operands on each access.

        Raises:
            NotImplementedError: if the operator is not supported.
        """
        if self.operator == '+':
            return reduce(operator.add, [o.data for o in self.operands], 0)
        elif self.operator == '*':
            return reduce(operator.mul, [o.data for o in self.operands], 1)
        elif self.operator == '-' and len(self.operands) == 1:
            return -self.operands[0].data
        elif self.operator == '-' and len(self.operands) > 1:
            return reduce(operator.add,
                          [-o.data for o in self.operands[1:]],
                          self.operands[0].data)
        elif self.operator == '/':
            return reduce(operator.mul,
                          [1 / o.data for o in self.operands[1:]],
                          self.operands[0].data)
        elif self.operator == '^':
            return self.operands[0].data ** self.operands[1].data
        elif self.operator == 'ln':
            return math.log(self.operands[0].data)

        raise NotImplementedError()

    def backward(self, dydo=None) -> None:
        """Propagate the upstream gradient ``dydo`` to every operand.

        ``dydo`` is d(output)/d(this node), given as a graph node or a plain
        number; ``None`` (the default) means ``Value(1)``.  It is accumulated
        into this node's ``grad`` and, multiplied by the symbolic local
        derivative with respect to each operand (chain rule), passed on to
        that operand's ``backward``.

        Raises:
            NotImplementedError: if the operator has no backward rule.
        """
        # A fresh seed per call avoids sharing one Value instance as a
        # default argument across all calls.
        if dydo is None:
            dydo = Value(1)
        elif not isinstance(dydo, BaseGrad):
            dydo = Value(dydo)

        # The node's own gradient is the upstream gradient, not a constant 1.
        self.grad += dydo
        op = self.operator
        operands = self.operands
        if op == '+':
            # d(a + b)/da = 1
            for o in operands:
                o.backward(dydo)
        elif op == '*':
            # d(a * b * c)/da = b * c.  Built from the *other* operands rather
            # than as ``self / a``, which cannot be evaluated when a == 0.
            for i, o in enumerate(operands):
                others = operands[:i] + operands[i + 1:]
                o.backward(dydo * self._product(others))
        elif op == '-' and len(operands) == 1:
            # d(-a)/da = -1
            operands[0].backward(-dydo)
        elif op == '-' and len(operands) > 1:
            # d(a - b - c)/da = 1; for b and c it is -1.
            operands[0].backward(dydo)
            for o in operands[1:]:
                o.backward(-dydo)
        elif op == '/':
            # d(a/b/c)/da = 1/(b*c); d(a/b/c)/db = -(a/b/c)/b, likewise for c.
            # The numerator rule is ``1/b/c`` rather than ``self / a`` so it
            # can still be evaluated when a == 0.
            if operands:
                operands[0].backward(
                    dydo * Expression('/', Value(1), *operands[1:])
                )
            for o in operands[1:]:
                o.backward(dydo * -(self / o))
        elif op == '^':
            # d(a^b)/da = b * a^(b-1); d(a^b)/db = a^b * ln(a)
            base, exponent = operands
            base.backward(dydo * exponent * base ** (exponent - 1))
            exponent.backward(dydo * self * Expression('ln', base))
        elif op == 'ln':
            # d(ln a)/da = 1/a
            operands[0].backward(dydo / operands[0])
        else:
            raise NotImplementedError(
                f'backward is not implemented for operator {self.operator!r} '
                f'with operands {self.operands!r}'
            )

# Demo: same graph as the numeric version, using the symbolic classes.
a = Value(4.0)
b = Value(2.0)
# Each assignment below rebinds c; only the final ``a / b`` graph is differentiated.
c = a + a
c = a + a + a + b
c = a * 3
c = -a * b
c = a / b
# c = 3 * (3 + (a - b) * 2)

c.backward()
# Gradients are graph nodes here, so ``.data`` is used to evaluate them to numbers.
a.grad.data, b.grad.data, c.grad.data, c, c.data

(0.5, -1.0, 1, Value(4.0) Value(2.0) /, 2.0)

In [3]:
# A leaf created from a string keeps that string as its name; assigning a
# number to ``v`` afterwards makes the repr show ``name=v``.
d = Value('a')
d.v = 1
d

Value(a=1)

In [4]:
# Collapse the gradient expression accumulated on ``a`` into one numeric Value.
a.grad.simplify()

Value(0.5)