In [66]:
from typing import Union, List
import numpy as np

# Call the int's __add__ method explicitly; this is the method that the
# `+` operator invokes for `a + 4`.
a = 3
a.__add__(4)

7

In [67]:
# The same explicit __add__ call on a NumPy array adds 4 to every element.
a = np.array([2, 3, 1, 0])
a.__add__(4)

array([6, 7, 5, 4])

In [68]:
# The `+` operator yields the same array as the explicit a.__add__(4) call.
a + 4

array([6, 7, 5, 4])

In [69]:
# Gradient Accumulation
Numberable = Union[float, int]


def ensure_number(num: Union[Numberable, "NumberWithGrad"]) -> "NumberWithGrad":
    """Return ``num`` as a ``NumberWithGrad``.

    An existing ``NumberWithGrad`` is returned unchanged (the very same
    object), so it keeps its place in the computation graph. Anything else
    is wrapped in a fresh ``NumberWithGrad`` with no parents.
    """
    if isinstance(num, NumberWithGrad):
        return num
    return NumberWithGrad(num)


class NumberWithGrad:
    """A scalar that records how it was computed so gradients can flow back.

    Attributes:
        num: The forward value.
        grad: Gradient accumulated by ``backward``; ``None`` until the first
            backward pass reaches this node.
        depends_on: The operand nodes this value was computed from (empty
            for values created directly).
        creation_op: ``"add"``, ``"mul"``, or ``""`` for values created
            directly.
    """

    def __init__(
        self,
        num: Numberable,
        depends_on: Union[List["NumberWithGrad"], None] = None,
        creation_op: str = "",
    ):
        self.num = num
        self.grad = None
        self.depends_on = depends_on or []
        self.creation_op = creation_op

    def __add__(self, other: Union[Numberable, "NumberWithGrad"]) -> "NumberWithGrad":
        """Return a new node for ``self + other``, recording both operands."""
        # Wrap once so the node used for the value is the one stored as parent.
        other = ensure_number(other)
        return NumberWithGrad(
            self.num + other.num, depends_on=[self, other], creation_op="add"
        )

    def __radd__(self, other: Numberable) -> "NumberWithGrad":
        """Support ``plain_number + node`` by delegating to ``__add__``."""
        return self + other

    def __mul__(self, other: Union[Numberable, "NumberWithGrad"] = None) -> "NumberWithGrad":
        """Return a new node for ``self * other``, recording both operands."""
        other = ensure_number(other)
        return NumberWithGrad(
            self.num * other.num, depends_on=[self, other], creation_op="mul"
        )

    def __rmul__(self, other: Numberable) -> "NumberWithGrad":
        """Support ``plain_number * node`` by delegating to ``__mul__``."""
        return self * other

    def backward(self, backward_grad: Numberable = None) -> None:
        """Accumulate a gradient into this node and push it to its parents.

        Args:
            backward_grad: Gradient of the final output with respect to this
                node arriving along one path. ``None`` marks this node as the
                output itself: its ``grad`` is set to 1 and 1 is propagated.

        Each call adds ``backward_grad`` to ``self.grad`` and forwards only
        that incoming contribution (scaled by the local derivative) to the
        parents. Forwarding the running total instead would re-send earlier
        contributions every time a node is reached by another path.
        """
        if backward_grad is None:
            backward_grad = 1
            self.grad = 1
        elif self.grad is None:
            self.grad = backward_grad
        else:
            # Rebind rather than `+=` so a caller-supplied object is never
            # modified in place.
            self.grad = self.grad + backward_grad

        if self.creation_op == "add":
            # d(x + y)/dx = d(x + y)/dy = 1: pass the gradient through as-is.
            for parent in self.depends_on:
                parent.backward(backward_grad)
        elif self.creation_op == "mul":
            # d(x * y)/dx = y and d(x * y)/dy = x.
            left, right = self.depends_on
            left.backward(right.num * backward_grad)
            right.backward(left.num * backward_grad)

In [70]:
# Build a small chain: b = a * 4, then c = b + 3.
a = NumberWithGrad(3)
b = a * 4
c = b + 3
# Forward values held by each node.
a.num, b.num, c.num

(3, 12, 15)

In [71]:
# Backpropagate from c; called with no argument, backward() sets c.grad to 1.
c.backward()
a.grad, b.grad, c.grad
# a.grad = dc/da = dc/db * db/da = 1 * 4 = 4

(4, 1, 1)

In [78]:
# a feeds two branches (a -> b -> c and a -> d) that are multiplied into e.
a = NumberWithGrad(3)
b = a * 4
c = b + 3
d = a + 2
e = c * d
# e = (a * 4 + 3) * (a + 2) = 4a**2 + 11a + 6
a.num, b.num, c.num, d.num, e.num


(3, 12, 15, 5, 75)

In [79]:
# Backpropagate from e; a receives one gradient contribution per branch.
e.backward()
a.grad, b.grad, c.grad, d.grad, e.grad
# d.grad = de/dd = c = 15
# c.grad = de/dc = d = 5
# b.grad = de/db = de/dc * dc/db = 5 * 1 = 5
# Three equivalent ways to arrive at a.grad:
# 1. a.grad = de/da = d(c*d)/da = dc/da*d + dd/da*c = dc/db*db/da*d + dd/da*c = 1*4*5 + 1*15 = 20 + 15 = 35
# 2. a.grad = de/da = 8a + 11 = 35
# 3. a.grad = de/dc * dc/da + de/dd * dd/da = 5 * 4 + 15 * 1 = 35

(35, 5, 5, 15, 1)