---

In [524]:
import numpy as np
import uuid

In [794]:
class V:
    def __init__(self, data, ctx=None):
        self.data = data
        self.ctx = ctx
        uid = uuid.uuid4()
        if ctx is not None:
            self.id = f'{ctx.parent.__name__}({uid})'
        else:
            self.id = f'var({uid})'
            
    def assign(self, new_v):
        self.data = new_v.data

    def backward(self, d=1):
        if self.ctx is None:
            return Backward({self.id: d})

        inputs_ids = [i.id for i in self.ctx.inputs]
        raw_d = self.ctx.parent.backward(self.ctx, d)
        if type(raw_d) is not tuple and type(raw_d) is not list:
            raw_d = [raw_d]

        out = Backward(dict(zip(inputs_ids, raw_d), **{self.id: d}))
        for inp, inp_d in zip(self.ctx.inputs, raw_d):
            internal_backward = inp.backward(inp_d)
            del internal_backward.backward_dict[inp.id]
            out.extend_with(internal_backward)

        return out
        
    def __repr__(self):
        def recur(depth, v):
            indent = ' ' * depth
            if v.ctx is not None:
                parent_name = v.ctx.parent.__name__
                inputs = tuple(i.data for i in v.ctx.inputs)
                out = f'{indent}{v.data} = {parent_name}{inputs}\n'
                for c in v.ctx.inputs:
                    out += recur(depth + 2, c)
            else:
                out = f'{indent}var({v.data})\n'
            return out
        return recur(0, self)

In [795]:
class Op:
    @classmethod
    def apply(cls, *inputs):
        raw_inputs = [i.data for i in inputs]
        raw_outputs = cls.forward(*raw_inputs)
        return V(raw_outputs, ctx=Ctx(parent=cls, inputs=inputs))

In [813]:
class Backward:
    def __init__(self, backward_dict={}):
        self.backward_dict = backward_dict
        
    def __repr__(self):
        return f'Backward({self.backward_dict})'

    def extend_with(self, other_backward: Backward):
        for idx, value in other_backward.backward_dict.items():
            if idx in self.backward_dict:
                self.backward_dict[idx] += value
            else:
                self.backward_dict[idx] = value

    def wrt(self, *vs: V):
        result = []
        for v in vs:
            if v.id in self.backward_dict:
                result.append(V(self.backward_dict[v.id]))
            else:
                raise ValueError(f'var({v.id}) is not part of the computation')
        return result

In [814]:
class Ctx:
    def __init__(self, parent, inputs):
        self.parent = parent
        self.inputs = inputs
        self.raw_inputs = [i.data for i in inputs]
        self.raw_input = self.raw_inputs[0]

In [815]:
V(1).backward()

Backward({'var(b0135068-c919-4d5a-8c08-7cf07019c490)': 1})

In [816]:
class Add(Op):
    @staticmethod
    def forward(a, b):
        return a + b

    @staticmethod
    def backward(ctx, d):
        return d, d

class Sub(Op):
    @staticmethod
    def forward(a, b):
        return a - b

    @staticmethod
    def backward(ctx, d):
        return d, -d

class ReduceMean(Op):
    @staticmethod
    def forward(x):
        return np.mean(x)

    @staticmethod
    def backward(ctx, d):
        return d / ctx.raw_input.size * np.ones(ctx.raw_input.shape)

class Mul(Op):
    @staticmethod
    def forward(a, b):
        return a * b
    
    @staticmethod
    def backward(ctx, d):
        a, b = ctx.raw_inputs
        return b * d, a * d

class MM(Op):
    @staticmethod
    def forward(a, b):
        return a.dot(b)
    
    @staticmethod
    def backward(ctx, d):
        a, b = ctx.raw_inputs
        return d.dot(b.T), a.T.dot(d)

class ReLU(Op):
    @staticmethod
    def forward(x):
        return np.maximum(0, x)
    
    @staticmethod
    def backward(ctx, d):
        mask = (ctx.raw_input > 0).astype(np.uint8)
        return d * mask
    
class Sq(Op):
    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def backward(ctx, d):
        return 2 * ctx.raw_input * d

In [817]:
w, b = V(1), V(3)

In [818]:
x = V(2)
y = V(4)

In [819]:
out = Sq.apply(Add.apply(Mul.apply(x, w), b))
out

25 = Sq(5,)
  5 = Add(2, 3)
    2 = Mul(2, 1)
      var(2)
      var(1)
    var(3)

In [898]:
b = out.backward()
b.wrt(x)

[var(10)]

In [899]:
W1 = V(np.random.normal(size=(1, 10)))
b1 = V(np.zeros(10))

W2 = V(np.random.normal(size=(10, 1)))
b2 = V(np.zeros(1))

In [900]:
X = V(np.linspace(0, 1, 14)[..., np.newaxis])
Y = V(np.linspace(0, 1, 14)[..., np.newaxis] * 4.5 + 5.5)

In [901]:
X.data.shape, Y.data.shape

((14, 1), (14, 1))

In [902]:
def forward():
    A0 = X
    Z1 = Add.apply(MM.apply(A0, W1), b1)
    A1 = ReLU.apply(Z1)

    Z2 = Add.apply(MM.apply(A1, W2), b2)
    A2 = Z2
    return A2

In [903]:
def loss():
    return ReduceMean.apply(Sq.apply(Sub.apply(forward(), Y)))

In [904]:
for i in range(1000):
    l = loss()
    if i % 100 == 0:
        print(l.data)

    variables = W1, b1, W2, b2
    grads = l.backward().wrt(*variables)

    lr = V(0.01)
    for v, g in zip(variables, grads):
        new_v = Sub.apply(v, Mul.apply(lr, g))
        v.assign(new_v)

94.14267907481225
1.8615497031238397
1.2219706118787703
0.9164043148250246
0.6885032785575397
0.517288800158692
0.38865138619027195
0.29200303627367663
0.21938883077325794
0.1648320499755209
