In [1]:
import numpy as np

In [2]:
class Variable:
    """Minimal container that exposes a value through its ``data`` attribute."""

    def __init__(self, data):
        # Store the value exactly as given; nothing is copied or validated.
        self.data = data


class Function:
    """Base class for callables that turn one Variable into a new Variable.

    Subclasses supply the computation by overriding ``forward``.
    """

    def __call__(self, input):
        """Run ``forward`` on ``input.data`` and wrap the result in a Variable.

        Args:
            input: Object whose ``data`` attribute holds the value to process.

        Returns:
            A new Variable whose ``data`` is the value returned by ``forward``.
        """
        return Variable(self.forward(input.data))

    def forward(self, x):
        """Compute the result for the raw value ``x``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


In [3]:
class Exp(Function):
    """Function whose forward pass applies ``np.exp`` to the input value."""

    def forward(self, x):
        """Return ``np.exp(x)``."""
        exponentiated = np.exp(x)
        return exponentiated

In [4]:
class Square(Function):
    """Function whose forward pass squares the input value."""

    def forward(self, x):
        """Return ``x`` raised to the power of two."""
        squared = x ** 2
        return squared

In [5]:
# Create the three function objects that the next cell chains together:
# square, then exp, then square again.
A = Square()
B = Exp()
C = Square()

In [6]:
# Forward pass: feed 0.5 through A, B and C in sequence, keeping each
# intermediate Variable, then print the final value.
x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
print(y.data)

1.648721270700128


In [7]:
## Numerical differentiation

def numerical_diff(f, x, eps=1e-4):
    """Approximate the derivative of ``f`` at ``x`` with a central difference.

    Args:
        f: Callable that takes a Variable and returns an object with ``data``.
        x: Variable whose ``data`` is the point at which to differentiate.
        eps: Step size; ``f`` is evaluated at ``x.data - eps`` and
            ``x.data + eps``.

    Returns:
        ``(f(x + eps).data - f(x - eps).data) / (2 * eps)``.
    """
    lower = f(Variable(x.data - eps))
    upper = f(Variable(x.data + eps))
    return (upper.data - lower.data) / (2 * eps)

In [9]:
# Numerically estimate the derivative of Square at x = 2.0 and print it.
f = Square()
x = Variable(np.array(2.0))
dy = numerical_diff(f, x)
print(dy)

4.000000000004


In [10]:
## Derivative of a composite function
def f(x):
    """Apply Square, then Exp, then Square to ``x`` and return the result.

    Each call builds fresh function objects, so no state is shared between
    invocations.
    """
    squared = Square()(x)
    exponentiated = Exp()(squared)
    return Square()(exponentiated)

# Numerically estimate the derivative of the composite function f at x = 0.5.
x = Variable(np.array(0.5))
dy = numerical_diff(f, x)
print(dy)

3.2974426293330694


In [11]:
# However, numerical differentiation includes error (e.g. 4.000000000004 was printed above instead of exactly 4).

In [12]:
## Backpropagation theory

In [None]:
class Variable:
    """Value container that also tracks a gradient and the function that made it.

    Attributes set by ``__init__``:
        data: The wrapped value.
        grad: Gradient filled in by ``backward``; ``None`` until then.
        creator: Function object that produced this variable, or ``None``
            when the variable was not produced by a function.
    """

    def __init__(self, data):
        self.data = data
        # backward() reads both attributes, so they must always exist.
        self.grad = None
        self.creator = None

    def backward(self, retain_grad=False):
        """Propagate gradients from this variable back through its creators.

        Each function object reached must provide ``generation``, ``inputs``,
        ``outputs`` (an iterable of zero-argument callables that return the
        output variables, e.g. weak references) and ``backward(*grads)``.

        Args:
            retain_grad: When false (the default), the ``grad`` of every
                function output is reset to ``None`` once that function has
                been processed, so only variables that are not the output of
                a processed function keep their gradient. When true, all
                gradients are kept.
        """
        if self.grad is None:
            # Seed the starting gradient with ones shaped like the data.
            self.grad = np.ones_like(self.data)

        if self.creator is None:
            # Nothing produced this variable, so there is nothing to traverse.
            return

        funcs = []
        seen_set = set()

        def add_func(f):
            # Queue each function once; sorting ascending by generation lets
            # pop() below always take the highest-generation function first.
            if f not in seen_set:
                funcs.append(f)
                seen_set.add(f)
                funcs.sort(key=lambda func: func.generation)

        add_func(self.creator)

        while funcs:
            f = funcs.pop()
            gys = [output().grad for output in f.outputs]
            gxs = f.backward(*gys)

            # A function with a single input may return a bare gradient.
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Same variable used more than once: sum the contributions
                    # without modifying the existing gradient object in place.
                    x.grad = x.grad + gx

                # Keep walking upstream through whatever produced this input.
                if x.creator is not None:
                    add_func(x.creator)

            if not retain_grad:
                # Drop gradients of this function's outputs once consumed.
                for output in f.outputs:
                    output().grad = None
