In [1]:
import numpy as np
from mynn import op as nn

batchsize = 16

# test 1 # 
# input = np.random.randn(batchsize, 10)
# l1 = nn.Linear(10, 20)
# l1(input).shape
# grad = np.random.randn(batchsize, 20)
# l1.backward(grad).shape

# test 2 #
def test2(verbose=False):
    iter_time = 10000
    N = 16*iter_time
    X_data = np.random.randn(N, 10)
    W_gt = np.random.randn(10, 20)
    b_gt = np.random.randn(20)
    Y_data = X_data@W_gt + b_gt
    l2 = nn.Linear(10, 20)
    for i in range(iter_time):
        X = X_data[i*16:(i+1)*16, :]
        Y_gt = Y_data[i*16:(i+1)*16, :]
        Y_pred = l2(X)
        loss = np.linalg.norm(Y_gt-Y_pred)
        if verbose:
            print(loss)
        grad = -(Y_gt-Y_pred)
        l2.backward(grad)
        for key in l2.grads.keys():
            l2.params[key] -= 1*l2.grads[key]
            if np.any(np.isnan(l2.params[key])):
                print(l2.grads[key])
                print(grad)
                raise ValueError("Need to break!!!") 
        # print([np.linalg.norm(l2.params[key]) for key in l2.grads.keys()])
    print(f"the residual norm is {np.linalg.norm(l2.W-W_gt), np.linalg.norm(l2.b-b_gt)}")

def test3(verbose=False):
    iter_time = 200000
    N = 16*iter_time
    X_data = np.random.randn(N, 10)
    W_gt = np.random.randn(10, 20)
    b_gt = np.random.randn(20)
    Y_data = X_data@W_gt + b_gt
    l1 = nn.Linear(10, 10)
    relu = nn.ReLU()
    l2 = nn.Linear(10, 20)
    for i in range(iter_time):
        X = X_data[i*16:(i+1)*16, :]
        Y_gt = Y_data[i*16:(i+1)*16, :]
        Y_pred = l2(relu(l1(X)))
        loss = np.linalg.norm(Y_gt-Y_pred)
        if verbose:
            print(loss)
        grad = -(Y_gt-Y_pred)

        # passing the grad to l2!
        passing_grad = l2.backward(grad)
        print(f"norm of the grad is {np.linalg.norm(grad)}")
        for key in l2.grads.keys():
            l2.params[key] -= 0.01*l2.grads[key]
            if np.any(np.isnan(l2.params[key])):
                print(l2.grads[key])
                print(f"grad is {grad}")
                raise ValueError("l2 Need to break!!!") 
            
        # passing the grad to relu!
        passing_grad = relu.backward(passing_grad)
        # no params to optimize for relu!

        # passing the grad to l1!
        l1.backward(passing_grad)
        print(np.linalg.norm(passing_grad))
        for key in l1.grads.keys():
            l1.params[key] -= 0.01*l1.grads[key]
            if np.any(np.isnan(l1.params[key])):
                print(l1.grads[key])
                print(grad)
                raise ValueError("l1 Need to break!!!")        

def test4():
    con1 = nn.Conv2D(in_channels=3, out_channels=6, kernel_size=4)
    con2 = nn.Conv2D(in_channels=6, out_channels=12, kernel_size=5, stride=2)
    X = np.random.rand(16, 3, 32, 32)
    print(con2(con1(X)).shape)
    grad = np.zeros((16, 12, 13, 13))
    con2.backward(grad)

def test5():
    predict = np.random.rand(5, 10)
    lable = np.array([2, 1, 4, 3, 6])
    loss = nn.MultiCrossEntropyLoss()
    print(loss(predicts=predict, labels=lable))
    loss.backward()
    print(loss.grads)

if __name__ == "__main__":
    test5()

2.4704513749915966
[[ 0.08010519  0.09514217 -0.27446077  0.09705294  0.07291801  0.1262416
   0.12853563  0.1381744   0.10687867  0.06588128]
 [ 0.14426688 -0.13856374  0.07343536  0.07732372  0.08723299  0.06565907
   0.10860828  0.08513169  0.10661724  0.17211681]
 [ 0.11519213  0.05843734  0.05659112  0.07515481 -0.48571856  0.13614891
   0.09966839  0.11968198  0.13456721  0.11096082]
 [ 0.11087483  0.104292    0.09989931 -0.43177009  0.12091445  0.05666207
   0.12156603  0.13936521  0.09501132  0.06301911]
 [ 0.14572514  0.07599145  0.10150195  0.08129988  0.06561162  0.16228628
  -0.1209137   0.12373146  0.07167487  0.09853964]]


In [6]:
a = np.random.rand(10, 10)
np.exp(a)/np.exp(a).mean(axis=1,)

array([[1.58902126, 0.84982927, 0.69148555, 0.79715837, 0.86140474,
        0.8477196 , 1.16029094, 1.37419002, 0.75400499, 0.89468823],
       [1.52677724, 0.70761547, 0.79329229, 0.95238444, 1.56455597,
        1.71611036, 0.66965285, 1.67495572, 1.56626208, 1.02444736],
       [1.47392092, 0.57078069, 1.0675779 , 1.10062624, 1.34603732,
        1.80266524, 0.72181436, 1.3427381 , 1.29061461, 0.7701871 ],
       [0.82875284, 1.02353493, 1.14056256, 1.06572518, 1.71726292,
        1.44589808, 1.58676943, 0.94848372, 0.7140915 , 0.86007213],
       [0.73105419, 0.65518691, 0.59307837, 1.15775071, 1.47799324,
        0.9675821 , 0.68992062, 0.72303462, 1.084457  , 0.54069814],
       [0.71945183, 0.84364523, 1.00120862, 0.64481229, 0.78763621,
        1.10130791, 1.61132444, 0.96329192, 0.66052967, 0.53621996],
       [0.82698881, 1.19407872, 0.68600773, 1.05863821, 0.70824527,
        0.92217503, 0.79053727, 1.61909287, 0.74721811, 0.67159098],
       [1.48766046, 0.58227073, 0.6166551