In [2]:
import numpy as np
from typing import Callable, List, Tuple

# An ArrayFunc takes an ndarray as its only argument and returns an ndarray
ArrayFunc = Callable[[np.ndarray], np.ndarray]
# A Chain is a list of ArrayFuncs; the helpers below apply chain[0] first
Chain = List[ArrayFunc]


def derivative(func: ArrayFunc,
               x: np.ndarray,
               dx: float = 1e-3) -> np.ndarray:
    """Estimate the derivative of ``func`` at ``x`` with a central difference.

    Every element of ``x`` is shifted by ``+dx`` and by ``-dx``, ``func`` is
    evaluated on both shifted arrays, and the difference of the two results
    is divided by ``2 * dx``.

    Args:
        func: function to differentiate.
        x: point(s) at which the derivative is evaluated.
        dx: half-width of the finite-difference step.

    Returns:
        ``(func(x + dx) - func(x - dx)) / (2 * dx)``.
    """
    ahead = func(x + dx)
    behind = func(x - dx)
    return (ahead - behind) / (2 * dx)


def chain_length_2(chain: Chain, x: np.ndarray) -> np.ndarray:
    """Evaluate two functions in a row: ``chain[1](chain[0](x))``.

    Args:
        chain: exactly two functions; ``chain[0]`` is applied first.
        x: input passed to the first function.

    Returns:
        The result of applying ``chain[1]`` to ``chain[0](x)``.

    Raises:
        AssertionError: if ``chain`` does not contain exactly two functions.
    """
    assert len(chain) == 2, "Length of input 'chain' should be 2"
    f1 = chain[0]
    f2 = chain[1]
    return f2(f1(x))


# Sanity check: squaring the square root of these positive integers should
# give the inputs back (as floats).
input_x = np.array([1, 2, 3, 4, 5])
chain_length_2([np.sqrt, np.square], input_x)

array([1., 2., 3., 4., 5.])

In [2]:
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def leaky_relu(x: np.ndarray) -> np.ndarray:
    """Element-wise maximum of ``0.2 * x`` and ``x``.

    Positive entries pass through unchanged; negative entries are scaled
    by 0.2.
    """
    scaled = 0.2 * x
    return np.maximum(scaled, x)


def chain_derivative_2(chain: Chain, x: np.ndarray) -> np.ndarray:
    '''
    Chain rule for two nested functions, differentiated with respect to x:
    (f2(f1(x)))' = f2'(f1(x)) * f1'(x)

    chain[0] is the inner function f1 and chain[1] is the outer function f2.
    Both factors are estimated numerically with `derivative`.
    '''
    assert len(chain) == 2, "This function requires 'Chain' objects of length 2"
    inner, outer = chain[0], chain[1]

    # u = f1(x): the point at which the outer derivative is evaluated
    inner_value = inner(x)

    # f1'(x)
    inner_slope = derivative(inner, x)

    # f2'(u), evaluated at u = f1(x)
    outer_slope = derivative(outer, inner_value)

    return outer_slope * inner_slope


def chain_derivative_3(chain: Chain, x: np.ndarray) -> np.ndarray:
    '''
    Chain rule for three nested functions, differentiated with respect to x:
    (f3(f2(f1(x))))' = f3'(f2(f1(x))) * f2'(f1(x)) * f1'(x)

    chain[0] is the innermost function f1 and chain[2] the outermost f3.
    Each factor is estimated numerically with `derivative`.
    '''
    assert len(chain) == 3, "This function requires 'Chain' objects of length 3"
    inner, middle, outer = chain[0], chain[1], chain[2]

    # Forward pass: intermediate values at which the outer derivatives are taken
    inner_value = inner(x)                # f1(x)
    middle_value = middle(inner_value)    # f2(f1(x))

    # Local derivative of each link, evaluated at that link's own input
    inner_slope = derivative(inner, x)                 # f1'(x)
    middle_slope = derivative(middle, inner_value)     # f2'(f1(x))
    outer_slope = derivative(outer, middle_value)      # f3'(f2(f1(x)))

    return outer_slope * middle_slope * inner_slope


# Evaluate both chain-rule helpers on the same grid of points.
input_x = np.arange(-3, 3, 0.8)
print(input_x)
# d/dx of square(sigmoid(x))
c2 = chain_derivative_2([sigmoid, np.square], input_x)
print(c2)
# d/dx of square(sigmoid(leaky_relu(x))); leaky_relu leaves positive inputs
# unchanged, so for x > 0 this is expected to match c2.
c3 = chain_derivative_3([leaky_relu, sigmoid, np.square], input_x)
print(c3)

[-3.  -2.2 -1.4 -0.6  0.2  1.   1.8  2.6]
[0.00428509 0.01791525 0.06278086 0.1621365  0.27218603 0.28746967
 0.20892382 0.11981735]
[0.0324273  0.03733761 0.04221259 0.04683478 0.27218603 0.28746967
 0.20892382 0.11981735]


In [5]:
def multiple_inputs_addition(x: np.ndarray,
                             y: np.ndarray,
                             sigma: ArrayFunc) -> float:
    """Forward pass of a two-input function: ``sigma(x + y)``.

    Args:
        x: first input.
        y: second input; must have the same shape as ``x``.
        sigma: function applied to the element-wise sum.

    Returns:
        Whatever ``sigma`` returns for ``x + y``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape
    return sigma(x + y)


def multiple_inputs_add_backward(x: np.ndarray,
                                 y: np.ndarray,
                                 sigma: ArrayFunc) -> Tuple[np.ndarray, np.ndarray]:
    """Derivatives of ``sigma(x + y)`` with respect to ``x`` and to ``y``.

    Args:
        x: first input.
        y: second input; must have the same shape as ``x``.
        sigma: function applied to ``a = x + y``.

    Returns:
        A 2-tuple ``(ds/dx, ds/dy)``. Because ``da/dx`` and ``da/dy`` are both
        1 for an addition, both entries equal ``ds/da``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.

    Note:
        ``ds/da`` is estimated with ``derivative``, which shifts every element
        of ``a`` by the same step at once. The result is therefore only a
        per-element gradient when ``sigma`` acts element-wise.
    """
    # Same guard as the forward function; without it broadcasting would
    # silently return gradients whose shape differs from an input's shape.
    assert x.shape == y.shape

    # Recompute the intermediate sum that sigma is evaluated at
    a = x + y

    # ds/da, evaluated at a
    ds_da = derivative(sigma, a)

    # a = x + y, so both partial derivatives of a are 1
    da_dx, da_dy = 1, 1

    return ds_da * da_dx, ds_da * da_dy


# With np.sum as sigma, the forward pass returns the scalar sum of x + y.
x = np.array([1, 2, 3, 4, 5])
y = np.array([3, 4, 5, 6, 7])
multiple_inputs_addition(x, y, np.sum)

40

In [None]:
def matmul_forward(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Forward pass of a matrix multiplication: ``np.dot(X, W)``.

    Args:
        X: left operand; its second dimension must equal ``W.shape[0]``.
        W: right operand.

    Returns:
        The product ``np.dot(X, W)``.

    Raises:
        AssertionError: if ``X.shape[1] != W.shape[0]``; the message reports
            both sizes.
    """
    # The message must be an f-string, otherwise the {…} placeholders are
    # printed literally instead of the offending dimensions.
    assert X.shape[1] == W.shape[0], (
        "For matrix multiplication, the number of columns in the first array "
        "should match the number of rows in the second; instead the number of "
        f"columns in the first array is {X.shape[1]} and the number of rows in "
        f"the second array is {W.shape[0]}."
    )
    return np.dot(X, W)

def matmul_backward_first(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Derivative of ``np.dot(X, W)`` with respect to ``X``.

    Returns ``W`` with its two axes swapped. ``X`` is accepted but not used
    in the computation.
    """
    return np.transpose(W, axes=(1, 0))

def matrix_forward_extra(X: np.ndarray, W: np.ndarray, sigma:ArrayFunc) -> np.ndarray:
    """Forward pass of ``sigma(np.dot(X, W))``.

    Args:
        X: left operand; its second dimension must equal ``W.shape[0]``.
        W: right operand.
        sigma: function applied to the matrix product.

    Returns:
        ``sigma`` applied to ``np.dot(X, W)``.

    Raises:
        AssertionError: if ``X.shape[1] != W.shape[0]``.
    """
    assert X.shape[1] == W.shape[0]
    product = np.dot(X, W)
    return sigma(product)

def matrix_function_backward_1(X: np.ndarray, W: np.ndarray, sigma:ArrayFunc) -> np.ndarray:
    """Derivative of ``sigma(np.dot(X, W))`` with respect to ``X``.

    Applies the chain rule: the derivative of ``sigma`` at ``N = np.dot(X, W)``
    (estimated numerically with ``derivative``) is matrix-multiplied by the
    derivative of ``N`` with respect to ``X``, which is ``W`` with its two
    axes swapped.

    Args:
        X: left operand; its second dimension must equal ``W.shape[0]``.
        W: right operand.
        sigma: function applied to the matrix product.

    Returns:
        ``np.dot(dS_dN, dN_dX)``.
        NOTE(review): this has the same shape as ``X`` only when ``sigma``
        preserves the shape of its input; confirm against callers.

    Raises:
        AssertionError: if ``X.shape[1] != W.shape[0]``.
    """
    assert X.shape[1] == W.shape[0]

    # Intermediate product at which sigma's derivative is evaluated
    N = np.dot(X, W)

    # dS/dN; the forward value sigma(N) itself is not needed for the gradient
    dS_dN = derivative(sigma, N)

    # dN/dX
    dN_dX = np.transpose(W, (1, 0))

    return np.dot(dS_dN, dN_dX)

In [5]:
# Demo: on a 2-D array, swapping axes (1, 0) and swapping axes (0, 1) are the
# same operation, so both calls print the same 6x3 transpose of the 3x6 matrix.
# NOTE(review): no random seed is set, so the printed values differ on each run.
rd  = np.random.rand(3, 6)
print(rd)
print(np.swapaxes(rd, 1, 0))
print(np.swapaxes(rd, 0, 1))

[[0.36040529 0.74915431 0.91171237 0.57924666 0.82155248 0.25016635]
 [0.85657514 0.72022414 0.78123028 0.64038329 0.93329273 0.05486043]
 [0.80934961 0.54389149 0.9552091  0.35741833 0.47891503 0.29096131]]
[[0.36040529 0.85657514 0.80934961]
 [0.74915431 0.72022414 0.54389149]
 [0.91171237 0.78123028 0.9552091 ]
 [0.57924666 0.64038329 0.35741833]
 [0.82155248 0.93329273 0.47891503]
 [0.25016635 0.05486043 0.29096131]]
[[0.36040529 0.85657514 0.80934961]
 [0.74915431 0.72022414 0.54389149]
 [0.91171237 0.78123028 0.9552091 ]
 [0.57924666 0.64038329 0.35741833]
 [0.82155248 0.93329273 0.47891503]
 [0.25016635 0.05486043 0.29096131]]
