In [None]:
import numpy as np
from numpy import ndarray
from typing import Dict, List, Tuple

'''assert是程序员用于保证程序的正确性，
不是用于检查使用者输入参数是否合法，
正式的代码中应该使用if-raise来检查用户输入是否正确'''


def assert_same_shape(arr: ndarray, arr_grad: ndarray):
    '''
    Assert that ``arr`` and ``arr_grad`` have identical shapes.

    This is an internal consistency check (for example an array versus
    the gradient computed for it), not validation of user input: like
    any ``assert`` it is skipped entirely when Python runs with ``-O``.

    Raises AssertionError naming both shapes when they differ;
    otherwise returns None.
    '''
    # Message template, only formatted when the assertion actually fails.
    template = (
        "\n"
        "        Two ndarrays must have the same shape; \n"
        "        instead, the first ndarray has shape {0} \n"
        "        and the second ndarray has shape {1}.\n"
        "        "
    )
    # Keep the comparison inside the assert so nothing is evaluated under -O.
    assert arr.shape == arr_grad.shape, template.format(
        tuple(arr.shape), tuple(arr_grad.shape))


class Operation:
    '''
    Base class for one step of a neural network's computation.

    Subclasses supply ``_output`` (the forward computation) and
    ``_input_grad`` (the gradient with respect to the input).
    ``forward`` caches its input and result on the instance as
    ``self.input_`` and ``self.output``; ``backward`` uses those cached
    arrays to shape-check the gradients flowing through it.
    '''

    def __init__(self):
        # Nothing to set up yet; forward() creates input_ and output.
        pass

    def forward(self, input_: ndarray):
        '''
        Cache ``input_`` as ``self.input_``, compute the result with
        ``_output``, cache it as ``self.output`` and return it.
        '''
        self.input_ = input_
        result = self._output()
        self.output = result
        return result

    def _output(self) -> ndarray:
        '''
        Forward computation; every subclass must override this.

        Called with no arguments by ``forward`` right after
        ``self.input_`` has been set.
        '''
        raise NotImplementedError()

    def _input_grad(self, output_grad: ndarray) -> ndarray:
        '''Gradient w.r.t. the input; every subclass must override this.'''
        raise NotImplementedError()

    def backward(self, output_grad: ndarray) -> ndarray:
        '''
        Propagate ``output_grad`` back to this operation's input.

        Checks that ``output_grad`` matches the shape of ``self.output``,
        computes the input gradient with ``_input_grad``, stores it as
        ``self.input_grad``, checks it matches the shape of
        ``self.input_`` and returns it. Must be called after ``forward``,
        since it reads the arrays cached there.
        '''
        # Incoming gradient must line up with what forward() produced.
        assert_same_shape(output_grad, self.output)
        grad = self._input_grad(output_grad)
        self.input_grad = grad
        # Outgoing gradient must line up with what forward() received.
        assert_same_shape(self.input_, grad)
        return grad


class ParamOperation(Operation):
    '''
    An Operation that additionally holds a parameter array.

    Subclasses implement ``_output`` and ``_input_grad`` as for any
    Operation, plus ``_param_grad`` for the gradient with respect to
    ``self.params``.
    '''

    def __init__(self, params: ndarray) -> None:
        '''Store ``params`` on the instance as ``self.params``.'''
        # __init__ always returns None; the old "-> ndarray" was wrong.
        super().__init__()
        self.params = params

    def _param_grad(self, output_grad: ndarray) -> ndarray:
        '''Gradient w.r.t. ``self.params``; every subclass must override this.'''
        raise NotImplementedError()

    def backward(self, output_grad: ndarray) -> ndarray:
        '''
        Propagate ``output_grad`` to both the input and the parameters.

        The input-gradient path (shape check of ``output_grad`` against
        ``self.output``, ``_input_grad``, storing ``self.input_grad`` and
        checking it against ``self.input_``) is delegated to
        ``Operation.backward`` instead of being duplicated here. The
        parameter gradient is then computed with ``_param_grad``, stored
        as ``self.param_grad`` and checked against ``self.params``.

        Returns the gradient with respect to the input; the parameter
        gradient is only available via ``self.param_grad``.
        '''
        input_grad = super().backward(output_grad)
        self.param_grad = self._param_grad(output_grad)
        assert_same_shape(self.param_grad, self.params)
        return input_grad
    
class WeightMultiply(ParamOperation)