# **신경망의 추론 처리**

MNIST 데이터셋으로 추론(순전파)을 수행하는 신경망 구현

In [1]:
import sys, os
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
import pickle
import numpy as np

ModuleNotFoundError: ignored

In [None]:
def sigmoid(x):
    """Return the element-wise logistic sigmoid 1 / (1 + exp(-x)).

    Accepts a scalar or a NumPy array of any shape; NumPy applies the
    formula element-wise. For very negative inputs np.exp(-x) overflows to
    inf (NumPy emits a RuntimeWarning) and the result evaluates to 0.0,
    which is the correct limit.
    """
    # The original line was missing this closing parenthesis (SyntaxError).
    return 1 / (1 + np.exp(-x))

In [None]:
def softmax(a):
    """Return exp(a) normalized so that all elements sum to 1.

    The global maximum is subtracted before exponentiating. Softmax is
    invariant to adding the same constant to every input, so the result is
    mathematically unchanged. However, np.exp can no longer overflow to inf
    (which would turn the output into nan) when scores are large.

    Normalization is over *all* elements of ``a``, exactly as before; for a
    1-D score vector (a single sample) this is the usual softmax.
    """
    exp_a = np.exp(a - np.max(a))  # largest exponent is exp(0) == 1
    sum_exp_a = np.sum(exp_a)
    return exp_a / sum_exp_a

In [None]:
def get_data():
    """Load MNIST through ``load_mnist`` and hand back only the test split.

    ``load_mnist`` is called with normalize=True, flatten=True and
    one_hot_label=False; the training pair it returns is discarded.

    Returns:
        tuple: ``(x_test, t_test)`` — test inputs and their labels.
    """
    (_, _), (x_test, t_test) = load_mnist(
        normalize=True,
        flatten=True,
        one_hot_label=False,
    )
    return x_test, t_test

In [None]:
def init_network():
    """Unpickle and return the stored network parameters.

    Reads ``dataset/sample_weight.pkl`` (a relative path, so it is resolved
    against the current working directory). ``predict`` indexes the result
    with the keys 'W1'..'W3' and 'b1'..'b3'.
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only use a weight file obtained from a trusted source.
    with open("dataset/sample_weight.pkl", mode='rb') as weight_file:
        return pickle.load(weight_file)

In [None]:
def predict(network, x):
    """Run a forward pass of the three-layer network on ``x``.

    Layers 1 and 2 are affine transforms followed by ``sigmoid``; layer 3 is
    an affine transform followed by ``softmax``.

    Args:
        network: mapping with weight arrays 'W1', 'W2', 'W3' and bias arrays
            'b1', 'b2', 'b3' whose shapes chain together under np.dot.
        x: input sample(s); presumably one flattened image here — TODO
            confirm against the shapes stored in the weight file.

    Returns:
        The softmax of the output-layer scores.
    """
    W1, W2, W3 = (network[key] for key in ('W1', 'W2', 'W3'))
    b1, b2, b3 = (network[key] for key in ('b1', 'b2', 'b3'))

    hidden = sigmoid(np.dot(x, W1) + b1)       # layer 1 activation
    hidden = sigmoid(np.dot(hidden, W2) + b2)  # layer 2 activation
    scores = np.dot(hidden, W3) + b3           # layer 3 (output) scores
    return softmax(scores)

In [None]:
# Load the test split: x holds the input images, t the corresponding labels.
x, t = get_data()

In [None]:
# Inspect the array shapes of the test inputs and labels.
x.shape, t.shape

정확도 평가

In [None]:
# Load the stored weights/biases (indexed by predict() as 'W1'..'W3', 'b1'..'b3').
network = init_network()

In [None]:
#network['W1'].shape, network['b1'].shape

In [None]:
# Forward pass on the first test sample only.
y = predict(network, x[0])

In [None]:
# Shape of the softmax output for the single sample.
y.shape

In [None]:
# Raw softmax output; its values sum to 1.
y

In [None]:
# Largest softmax value, i.e. the score assigned to the predicted class.
y.max()

In [None]:
# Predicted class: index of the largest softmax value.
y.argmax()

In [None]:
# Summary line: true label, predicted class, and the winning softmax score.
title = 'Ans #{}, Pred.#{}, {:.4f}'.format(t[0], y.argmax(), y.max())
print(title)

In [None]:
# Count test samples the network classifies correctly; the predicted class is
# the index of the largest softmax output. zip pairs each input with its label
# directly instead of indexing both arrays through range(len(x)).
accuracy_cnt = 0
for x_i, t_i in zip(x, t):
    y = predict(network, x_i)

    p = np.argmax(y)

    if p == t_i:
        accuracy_cnt += 1

In [None]:
# Fraction of correctly classified test samples.
print("Accuracy:{}".format(float(accuracy_cnt) / len(x)))

In [None]:
# Number of test samples evaluated in the accuracy loop above.
len(x)

**무작위로 몇 개만 뽑아서 테스트해 보기**

In [None]:
# Draw 1000 *distinct* test indices. np.random.choice samples with replacement
# by default, which would let the same image be scored more than once.
idx_test = np.random.choice(len(x), 1000, replace=False)

In [None]:
# Display the sampled indices.
idx_test

In [None]:
# Subset the test data with the sampled indices.
x_select = x[idx_test]
# NOTE: y_select holds the *true labels* (taken from t), not predictions.
y_select = t[idx_test]

In [None]:
# First dimension should equal the number of sampled indices (1000).
x_select.shape

In [None]:
# Count correct predictions on the random subset; y_select holds the true
# labels. zip pairs each sample with its label instead of range(len(...)).
accuracy_cnt = 0
for x_i, label in zip(x_select, y_select):
    y = predict(network, x_i)

    p = np.argmax(y)

    if p == label:
        accuracy_cnt += 1

In [None]:
# Accuracy on the random subset (Python 3 true division yields a float).
accuracy_cnt / len(x_select)