# 3.6.3 バッチ処理

In [45]:
import sys,os

In [46]:
sys.path.append(os.pardir)  # make modules in the parent directory importable

In [47]:
from dataset.mnist import load_mnist
import numpy as np

In [48]:
def sigmoid(x):
    """Element-wise logistic sigmoid: 1 / (1 + exp(-x)).

    Accepts a scalar or a NumPy array; NumPy broadcasting means the result
    has the same shape as ``x``.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator


In [49]:
def softmax(a):
    """Softmax over the last axis of ``a``.

    For a 1-D input this returns a vector that sums to 1. For a 2-D batch
    (one sample per row) every row is normalised independently, so each
    row sums to 1. Reducing over the whole array instead would normalise
    across the entire batch and let one sample's large logits flatten the
    probabilities of every other sample.

    Parameters
    ----------
    a : array_like
        Scalar, vector, or batch of logits.

    Returns
    -------
    numpy.ndarray
        Probabilities with the same shape as ``a``.
    """
    a = np.asarray(a)
    # A 0-d input has no axis to reduce along; fall back to a full reduction.
    axis = -1 if a.ndim > 0 else None
    # Subtracting the per-sample maximum leaves the result unchanged but
    # keeps np.exp from overflowing on large logits.
    c = np.max(a, axis=axis, keepdims=True)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a, axis=axis, keepdims=True)
    return exp_a / sum_exp_a

In [50]:
def get_data():
    """Return the test split produced by ``load_mnist``.

    ``load_mnist`` is called with ``normalize=True``, ``flatten=True`` and
    ``one_hot_label=False``; only the second (test) pair of its result is
    returned, the training pair is discarded.
    NOTE(review): the flags presumably select scaled pixel values, flattened
    images and integer class labels -- confirm against dataset.mnist.

    Returns
    -------
    tuple
        ``(x_test, t_test)``.
    """
    _, (x_test, t_test) = load_mnist(
        normalize=True, flatten=True, one_hot_label=False
    )
    return x_test, t_test

In [51]:
import pickle

In [52]:
def init_network():
    """Load and return the object pickled in ``sample_weight.pkl``.

    The path is relative to the current working directory. The rest of the
    notebook indexes the result with the keys 'W1'..'W3' and 'b1'..'b3'.
    """
    # NOTE: pickle.load can execute arbitrary code while deserialising;
    # only use this with a weight file from a trusted source.
    with open("sample_weight.pkl", "rb") as weight_file:
        return pickle.load(weight_file)

In [53]:
# Load the test images (the labels are discarded here) and the pickled network,
# then pull out the three weight matrices so their shapes can be inspected.
x, _ = get_data()
network = init_network()
W1, W2, W3 = network['W1'], network['W2'], network['W3']



In [54]:
# Shape of the whole test-image array.
x.shape

(10000, 784)

In [55]:
# Shape of a single image taken from the array.
x[0].shape

(784,)

In [56]:
# Weight matrix applied first in predict(): input -> first hidden layer.
W1.shape

(784, 50)

In [57]:
# Weight matrix applied second: first hidden layer -> second hidden layer.
W2.shape

(50, 100)

In [58]:
# Weight matrix applied last: second hidden layer -> output.
W3.shape

(100, 10)

|X|W1|W2|W3| -> Y|
|-|--|--|--|-----|
|784 |784x50 |50x100 |100x10 |10|

yの結果は前章参照

In [59]:
def predict(network, x):
    """Run a forward pass through the three-layer network.

    Parameters
    ----------
    network : mapping
        Must provide weights 'W1', 'W2', 'W3' and biases 'b1', 'b2', 'b3'.
    x : numpy.ndarray
        Input passed to ``np.dot`` with 'W1'; either a single sample or a
        batch with one sample per row.

    Returns
    -------
    numpy.ndarray
        The result of ``softmax`` applied to the output-layer activations.
    """
    weights = [network[key] for key in ('W1', 'W2', 'W3')]
    biases = [network[key] for key in ('b1', 'b2', 'b3')]

    # The two hidden layers share the same form: affine transform + sigmoid.
    signal = x
    for weight, bias in zip(weights[:2], biases[:2]):
        signal = sigmoid(np.dot(signal, weight) + bias)

    # The output layer uses softmax in place of sigmoid.
    return softmax(np.dot(signal, weights[2]) + biases[2])

In [60]:
# Reload the test images, this time keeping the labels (t) for the accuracy check.
x, t = get_data()
network = init_network()

In [64]:
# Evaluate the network over the test set in slices of batch_size samples.
batch_size = 100  # batch size: number of samples passed to predict() at once
accuracy_cnt = 0 # running count of correctly classified samples
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1) # index of the highest-scoring class in each row
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

In [65]:
# Fraction of samples whose predicted class matched the label.
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

Accuracy:0.9352


In [68]:
# Toy example: with axis=1, argmax returns the index of the largest value in each row.
# Note that this rebinds x and y, replacing the values used in the cells above.
x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6], 
             [0.2, 0.5, 0.3], [0.8,0.1,0.1]])
y = np.argmax(x, axis=1)
print(y)

[1 2 1 0]


In [69]:
# With axis=0, argmax returns the index of the largest value in each column.
print(np.argmax(x,axis=0))

[3 0 1]


In [70]:
# Display the toy array for reference.
x

array([[ 0.1,  0.8,  0.1],
       [ 0.3,  0.1,  0.6],
       [ 0.2,  0.5,  0.3],
       [ 0.8,  0.1,  0.1]])

In [71]:
# Shape of the toy array as (rows, columns).
x.shape

(4, 3)

In [42]:
# Comparing two arrays with == is element-wise and yields a boolean array.
y = np.array([1, 2, 1, 0])
t = np.array([1, 2, 0, 0])
print(y==t)

[ True  True False  True]


In [43]:
# Summing the boolean array counts its True entries, i.e. the number of matches.
np.sum(y==t)

3