In [10]:
import numpy as np

## Variables Are Like Tags

> 变量像个标签

In [11]:
class Variable:
    """A value in the computation graph: an ndarray plus its gradient.

    Attributes are plain instance fields:
      * ``data``    - the wrapped ``np.ndarray`` (or ``None``).
      * ``grad``    - gradient with respect to ``data``; ``None`` until
                      ``backward`` fills it in.
      * ``creator`` - the function object that produced this variable, or
                      ``None`` for a variable that was created directly.
    """

    def __init__(self, data):
        """Wrap ``data``.

        :param data: an ``np.ndarray`` or ``None``.
        :raises TypeError: if ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if data is not None and not isinstance(data, np.ndarray):
            raise TypeError('{} is not supported'.format(type(data)))

        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self, func):
        """Record ``func`` as the function that produced this variable."""
        self.creator = func

    def backward(self):
        """Propagate gradients from this variable back through its creators.

        If ``grad`` has not been set, it is seeded with ones shaped like
        ``data`` (the derivative of this variable with respect to itself).
        Each creator's ``backward`` is then called with its output's gradient
        and the result is stored on that creator's input.

        Calling this on a variable that has no creator only seeds ``grad``;
        there is nothing further to propagate.
        """
        if self.grad is None:
            # d(self)/d(self) == 1, with the same shape and dtype as the data.
            self.grad = np.ones_like(self.data)

        # A variable that was not produced by a function has no graph behind
        # it; without this guard the loop below would dereference None.
        if self.creator is None:
            return

        funcs = [self.creator]
        while funcs:
            f = funcs.pop()
            x, y = f.input, f.output
            x.grad = f.backward(y.grad)

            # Keep walking towards the graph's inputs.
            if x.creator is not None:
                funcs.append(x.creator)

In [12]:
import unittest

class TestVariable(unittest.TestCase):
    """
    Constructor contract exercised below:

        Variable(np.array(1.0))  -> accepted
        Variable(None)           -> accepted
        Variable(1.0)            -> rejected with TypeError
    """

    # NOTE: the test methods deliberately carry no docstrings; with
    # verbosity=2 unittest would print them in place of the method names.

    def test_init_with_ndarray(self):
        payload = np.array(1.0)
        wrapped = Variable(payload)
        self.assertIsInstance(wrapped.data, np.ndarray)
        self.assertEqual(wrapped.data, payload)

    def test_init_with_none(self):
        self.assertIsNone(Variable(None).data)

    def test_init_with_unsupported_type(self):
        # A bare Python float is not an ndarray and must be refused.
        with self.assertRaises(TypeError) as caught:
            Variable(1.0)
        self.assertIn('is not supported', str(caught.exception))

    def test_set_creator(self):
        # Any object is stored as-is; a string stands in for a function here.
        wrapped = Variable(np.array(1.0))
        wrapped.set_creator('creator_func')
        self.assertEqual(wrapped.creator, 'creator_func')

    # Tests for backward() are not included here; they need a concrete
    # function implementation to build a graph with.
    
# argv=[''] makes unittest ignore sys.argv (which, under a notebook kernel,
# holds the kernel's own launch arguments); exit=False stops unittest from
# calling sys.exit() once the run finishes.
unittest.main(argv=[''], verbosity=2, exit=False)

test_init_with_ndarray (__main__.TestVariable) ... ok
test_init_with_none (__main__.TestVariable) ... ok
test_init_with_unsupported_type (__main__.TestVariable) ... ok
test_set_creator (__main__.TestVariable) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.001s

OK


<unittest.main.TestProgram at 0x1a7ffb724c8>

In [13]:
def as_array(x):
    """Return ``x`` as an ndarray if it is a scalar, otherwise return it untouched.

    :param x: any value; scalars (as judged by ``np.isscalar``) are wrapped
        with ``np.array``.
    :return: an ``np.ndarray`` for scalar input, else the very same object.
    """
    return np.array(x) if np.isscalar(x) else x

In [14]:
class Function:
    """Base class for a one-input, one-output operation on ``Variable`` objects.

    Subclasses supply ``forward`` (raw data in, raw data out) and ``backward``
    (gradient of the output in, gradient of the input out).
    """

    def __call__(self, input_variable: Variable):
        """Apply the operation and link the result back to this function.

        :param input_variable: the ``Variable`` whose ``data`` is fed to ``forward``.
        :return: a new ``Variable`` holding the result, with this function
            recorded as its creator.
        """
        # forward() may hand back a scalar; as_array() re-wraps it before
        # it goes into a Variable.
        result = Variable(as_array(self.forward(input_variable.data)))
        result.set_creator(self)

        # Remember both ends so the backward pass can walk through this node.
        self.input = input_variable
        self.output = result
        return result

    def forward(self, input_data):
        """Compute the output data from ``input_data``; subclasses must override."""
        raise NotImplementedError

    def backward(self, output_data):
        """Compute the input's gradient; subclasses must override.

        :param output_data: the gradient stored on this function's output
            variable (what ``Variable.backward`` passes in).
        :return: the gradient to store on this function's input variable.
        """
        raise NotImplementedError

In [15]:
class Square(Function):
    """Element-wise square: ``y = x ** 2``."""

    def forward(self, input_data):
        """
        :param input_data: the input data, ``x``
        :return: ``x ** 2``
        """
        return input_data ** 2

    def backward(self, output_data):
        """
        :param output_data: gradient of the output, ``gy``
        :return: gradient of the input, ``gx = 2 * x * gy``
        """
        return 2 * self.input.data * output_data

In [16]:
class Exp(Function):
    """Element-wise exponential: ``y = exp(x)``."""

    def forward(self, input_data):
        """
        :param input_data: the input data, ``x``
        :return: ``np.exp(x)``
        """
        result = np.exp(input_data)
        return result

    def backward(self, output_data):
        """
        :param output_data: gradient of the output, ``gy``
        :return: gradient of the input, ``gx = exp(x) * gy``
        """
        return np.exp(self.input.data) * output_data

In [17]:
def square(x):
    """Apply a fresh ``Square`` function object to ``x`` and return its output."""
    op = Square()
    return op(x)


def exp(x):
    """Apply a fresh ``Exp`` function object to ``x`` and return its output."""
    op = Exp()
    return op(x)

In [18]:
# End-to-end check: y = (exp(x ** 2)) ** 2 evaluated at x = 0.5.
x = Variable(np.array(0.5))
y = square(exp(square(x)))
# Walk the graph from y back to x, filling in each variable's grad.
y.backward()
# Analytically dy/dx = 4 * x * exp(2 * x ** 2), i.e. 2 * exp(0.5) at x = 0.5.
print(x.grad)

3.297442541400256
