In [16]:
%config ZMQInteractiveShell.ast_node_interactivity = 'all'
import numpy as np

In [12]:
def softmax(x):
    orig_shape = x.shape
    if len(orig_shape) > 1:
        x -= x.max(axis=1, keepdims=True)
        x = np.exp(x)
        x = x / x.sum(axis=1, keepdims=True)
    else:
        x -= x.max()
        x = np.exp(x)
        x = x / x.sum()
    return x

In [17]:
test2 = softmax(np.array([[1001,1002],[3,4]]))
test2

array([[0.26894142, 0.73105858],
       [0.26894142, 0.73105858]])

In [20]:
def sigmoid(x):
    s = 1 / (1 + np.exp(-x))
    return s

def sigmoid_grad(s):
    ds = s*(1-s)
    return ds

In [21]:
x = np.array([[1, 2], [-1, -2]])
f = sigmoid(x)
g = sigmoid_grad(f)
f
g

array([[0.73105858, 0.88079708],
       [0.26894142, 0.11920292]])

array([[0.19661193, 0.10499359],
       [0.19661193, 0.10499359]])

In [11]:
def forward_backward_prop(X, label, params, dimensions):
    ofs = 0
    Dx, H, Dy = (dimensions[0], dimensions[1], dimensions[2])

    W1 = np.reshape(params[ofs:ofs+ Dx * H], (Dx, H))
    ofs += Dx * H
    b1 = np.reshape(params[ofs:ofs + H], (1, H))
    ofs += H
    W2 = np.reshape(params[ofs:ofs + H * Dy], (H, Dy))
    ofs += H * Dy
    b2 = np.reshape(params[ofs:ofs + Dy], (1, Dy))
    
    h = sigmoid(np.dot(X, W1)+b1)
    yhat = softmax(np.dot(h, W2)+b2)
    
    cost = np.sum(-np.log(yhat[label==1]))/X.shape[0]
    d3 = (yhat - label)/X.shape
    gradW2 = np.dot(h.T, d3)
    granb2 = np.sum(d3, 0, keepdims=True)
    
    dh = np.dot(d3, W2.T)
    grad_h = sigmoid_grad(h) * dh
    gradW1 = np.dot(X.T, grad_h)
    gradb1 = np.sum(grad_h, 0, keepdims=True)
    grad = np.concatenate((gradW1.flatten(), gradb1.flatten(), gradW2.flatten(), gradb2.flatten()))
    return grad, cost

array([[1002],
       [   4]])

In [27]:
import random
random.getstate()

(3,
 (2147483648,
  2491384097,
  3415681212,
  3896813050,
  2881236146,
  1826527951,
  1611563305,
  1487305906,
  645837433,
  2886943891,
  1020829296,
  772909907,
  3237145675,
  804726640,
  2518979481,
  3352502301,
  3947969917,
  3055960811,
  3798267424,
  374870830,
  1001738657,
  3538117892,
  3030069055,
  3354489862,
  3143389810,
  1045793392,
  1263023092,
  271718630,
  1171441766,
  2275231880,
  3825832458,
  3632964361,
  3766200192,
  2499983236,
  288949405,
  361894633,
  135948571,
  1700462513,
  712160629,
  3357018828,
  1311201973,
  2769727167,
  1030532920,
  4077770039,
  616655565,
  1007227544,
  2002910266,
  381528278,
  3209273276,
  4139856766,
  1201627805,
  2573836771,
  914775810,
  1237451994,
  1718767004,
  3223911362,
  1233808036,
  4002799551,
  2831599059,
  149821846,
  2555618664,
  864144404,
  3561512692,
  3332328944,
  2973053615,
  869336360,
  3281910058,
  92474570,
  4094505598,
  2355317547,
  281499210,
  1655198725,
  1539