# In [None]:
# coding: utf-8
import os
import sys

# Jupyter notebooks do not define __file__, so the project root is resolved
# relative to the current working directory instead (assumes the notebook is
# launched from a direct subdirectory of the project root — TODO confirm).
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Re-executing the cell must not keep appending the same entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)
import numpy as np
import pickle
from dataset.mnist import load_mnist  # ← これで解決する


# Fix NumPy's global random seed so any randomness in this notebook is reproducible.
np.random.seed(seed=42)


def sigmoid(x):
    """Apply the logistic function 1 / (1 + exp(-x)) element-wise.

    Args:
        x: A NumPy scalar/array (or Python number) of pre-activations.

    Returns:
        The sigmoid of ``x``, with the same shape as ``x``.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator


# softmax function (defined locally)
def softmax(x):
    """Return softmax probabilities for ``x``.

    A 2-D input is treated as a batch and normalised row by row (axis 1).
    Any other input is normalised over all of its elements. The maximum is
    subtracted before exponentiating so large logits cannot overflow
    ``np.exp``.

    Args:
        x: Array-like of logits. Lists are accepted and converted with
            ``np.asarray``.

    Returns:
        A new array with the same shape as ``x``; the input is not modified.
    """
    x = np.asarray(x)
    # Row-wise for a 2-D batch, otherwise over every element.
    axis = 1 if x.ndim == 2 else None
    shifted = x - np.max(x, axis=axis, keepdims=True)
    # Exponentiate once and reuse the result for numerator and denominator.
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def get_data():
    """Load MNIST via ``load_mnist`` and return only the test split.

    ``load_mnist`` is called with ``normalize=True`` and ``flatten=True``;
    the training split it returns is unpacked and then discarded.

    Returns:
        A ``(x_test, t_test)`` pair.
    """
    train_split, test_split = load_mnist(normalize=True, flatten=True)
    _x_train, _t_train = train_split
    x_test, t_test = test_split
    return x_test, t_test


def init_network():
    """Unpickle the pre-trained parameters stored in ``sample_weight.pkl``.

    The path is relative to the current working directory. The file is
    unpickled as-is, so it must come from a trusted source.

    Returns:
        The unpickled object. ``predict`` indexes it with the keys
        "W1".."W3" and "b1".."b3".
    """
    with open("sample_weight.pkl", "rb") as weight_file:
        return pickle.load(weight_file)


def predict(network, x):
    """Run the 3-layer network forward and return softmax outputs.

    Two hidden layers use ``sigmoid``; the output layer uses ``softmax``.

    Args:
        network: Mapping with weight matrices "W1", "W2", "W3" and biases
            "b1", "b2", "b3".
        x: Input array; presumably a batch of flattened MNIST images of
            shape (batch, 784) — TODO confirm against the weight shapes.

    Returns:
        The softmax of the final layer's pre-activations.
    """
    weights = [network[key] for key in ("W1", "W2", "W3")]
    biases = [network[key] for key in ("b1", "b2", "b3")]

    hidden = x
    for weight, bias in zip(weights[:2], biases[:2]):
        hidden = sigmoid(np.dot(hidden, weight) + bias)

    logits = np.dot(hidden, weights[2]) + biases[2]
    return softmax(logits)


# Benchmark: accuracy and inference time on the MNIST test set for several
# batch sizes, printed as a table and plotted on twin y-axes.
import time

import matplotlib.pyplot as plt

x, t = get_data()
network = init_network()

# Candidate batch sizes; the last one feeds the whole test set as one batch.
batch_sizes = [1, 10, 100, 1000, len(x)]

# Header and rows use identical column widths so the table lines up.
header = f"{'BatchSize':>10} | {'Accuracy':>8} | {'Time (s)':>8}"
print(header)
print("-" * len(header))

# Per-batch-size results, collected for plotting.
batch_list = []
accuracy_list = []
time_list = []

for batch_size in batch_sizes:
    # perf_counter is monotonic and high-resolution, which suits timing
    # short loops better than time.time().
    start_time = time.perf_counter()
    accuracy_cnt = 0

    for i in range(0, len(x), batch_size):
        x_batch = x[i : i + batch_size]
        y_batch = predict(network, x_batch)
        p = np.argmax(y_batch, axis=1)
        # assumes t holds integer class labels (not one-hot) — TODO confirm
        accuracy_cnt += np.sum(p == t[i : i + batch_size])

    elapsed = time.perf_counter() - start_time
    accuracy = accuracy_cnt / len(x)

    print(f"{batch_size:>10} | {accuracy:>8.4f} | {elapsed:>8.4f}")
    batch_list.append(batch_size)
    accuracy_list.append(accuracy)
    time_list.append(elapsed)

# Plot accuracy (left axis) and elapsed time (right axis) against batch size.
fig, ax1 = plt.subplots()
# Set the title before tight_layout() so the layout reserves room for it.
ax1.set_title("Batch Size vs Accuracy and Time")

ax1.set_xlabel("Batch Size")
ax1.set_ylabel("Accuracy", color="tab:red")
ax1.plot(batch_list, accuracy_list, marker="o", color="tab:red", label="Accuracy")
ax1.tick_params(axis="y", labelcolor="tab:red")
ax1.set_xscale("log")
ax1.set_xticks(batch_list)
# Show plain numbers (1, 10, 100, ...) instead of 10^k on the log x-axis.
ax1.get_xaxis().set_major_formatter(plt.ScalarFormatter())

ax2 = ax1.twinx()
ax2.set_ylabel("Time (s)", color="tab:blue")
ax2.plot(
    batch_list,
    time_list,
    marker="s",
    linestyle="--",
    color="tab:blue",
    label="Time (s)",
)
ax2.tick_params(axis="y", labelcolor="tab:blue")

fig.tight_layout()
plt.savefig("batch_vs_time_and_accuracy.png")
plt.show()

# Stale cell output from an earlier version that used __file__:
# NameError: name '__file__' is not defined