In [5]:
import numpy as np

In [144]:
class DenseLayer(object):
    
    def __init__(self, input_dim, output_dim, use_bias=True):
        sq = np.sqrt(1. / input_dim)
        self.use_bias = use_bias
        self.weights = np.random.uniform(-sq, sq, (output_dim, input_dim))
        if use_bias == True:
            self.bias = np.random.uniform(-sq, sq, output_dim)
        else:
            self.bias = np.zeros((output_dim))
        
    def forward(self, X_in):
        return np.tensordot(X_in, self.weights.T, axes=(-1, 0)) + self.bias
    
    def backward(self, dEdY, X_in):
        # dEdW = dEdY * dYdW = dEdY * X
        # dEdb = dEdY * dYdb = dEdY
        # dEdX = dEdY * dYdX = dEdY * W
        axis = tuple(range(len(X.shape) - 1))
        dEdW = np.tensordot(dEdY,X_in, axes=((axis), (axis)))
        dEdB = np.sum(dEdY, axis=(axis))
        dEdX = np.tensordot(dEdY, self.weights, axes=(-1, 0))
        
        return dEdX, dEdW, dEdB
    
    def refresh(self, dEdW, dEdB, learning_rate):
        self.weights = self.weights - learning_rate * dEdW
        if self.use_bias == True:
            self.bias = self.bias - learning_rate * dEdB

In [139]:
dense = DenseLayer(3, 1, True)

X = np.array([[0.2, 0.5, 0.3],
[0.2, 0.4, 0.4],
[0.3, 0.1, 0.6],
[0.2, 0.3, 0.5],
[0.5, 0.3, 0.2]])

Y = np.array([[1.], [1.], [1.], [1.], [1.]])

In [145]:
num_iter = 300

for i in range(num_iter):
    Y_c= dense.forward(X)
    dEdY = Y - Y_c
    _, dEdW, dEdB = dense.backward(dEdY, X)
    dense.refresh(dEdW, dEdB, 0.1)
    print(dense.weights)

[[-1.54068466e+133 -1.76332808e+133 -2.21614504e+133]]
[[-2.57367756e+133 -2.94559818e+133 -3.70201828e+133]]
[[-4.29926797e+133 -4.92055263e+133 -6.18413466e+133]]
[[-7.18182624e+133 -8.21966770e+133 -1.03304518e+134]]
[[-1.19970722e+134 -1.37307620e+134 -1.72567773e+134]]
[[-2.00408275e+134 -2.29369157e+134 -2.88270415e+134]]
[[-3.34777321e+134 -3.83155793e+134 -4.81548964e+134]]
[[-5.59237658e+134 -6.40052761e+134 -8.04416242e+134]]
[[-9.34193383e+134 -1.06919312e+135 -1.34375845e+135]]
[[-1.56054812e+135 -1.78606201e+135 -2.24471696e+135]]
[[-2.60685902e+135 -2.98357467e+135 -3.74974701e+135]]
[[-4.35469681e+135 -4.98399146e+135 -6.26386438e+135]]
[[-7.27441882e+135 -8.32564078e+135 -1.04636384e+136]]
[[-1.21517459e+136 -1.39077875e+136 -1.74792624e+136]]
[[-2.02992064e+136 -2.32326327e+136 -2.91986977e+136]]
[[-3.39093479e+136 -3.88095677e+136 -4.87757393e+136]]
[[-5.66447698e+136 -6.48304720e+136 -8.14787278e+136]]
[[-9.46237587e+136 -1.08297782e+137 -1.36108303e+137]]
[[-1.58066