In [1]:
# Use_API2.ipynb

import math
import random
import time
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from AutoBatch import auto_batch

In [2]:
# Model definition
class SimpleNet(nn.Module):
    """Single fully connected layer applied to flattened inputs.

    Args:
        in_features: Number of values per sample once every dimension after
            the first has been flattened. Defaults to 28 * 28.
        num_classes: Number of output logits per sample. Defaults to 10.
    """

    def __init__(self, in_features=28 * 28, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (x.size(0), num_classes).

        Every dimension after the first is flattened. ``flatten`` is used
        instead of ``view`` because ``view`` raises on inputs whose memory
        layout is not compatible (e.g. transposed tensors) and on
        zero-sized batches, while ``flatten`` copies when it has to.
        """
        return self.fc(x.flatten(start_dim=1))


In [3]:
# Evaluation function
def evaluate(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    The model is switched to eval mode (and left in that mode) and is run
    without gradient tracking. The predicted class of each sample is the
    argmax of the model output along dim 1.

    Args:
        model: Module called as ``model(xb)`` for every batch.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``correct / total`` as a float, or 0.0 when the loader yields no
        samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            pred = model(xb).argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    # An empty loader would otherwise end in a ZeroDivisionError.
    return correct / total if total else 0.0

In [4]:
def run_experiment(N_values, fixed_batch_sizes, seed=None):
    """Train a fresh SimpleNet for one pass per (N, batch size) pair on MNIST.

    For every requested N a random subset of the MNIST training split is
    drawn. One model is then trained for a single pass over that subset for
    each fixed batch size and for the two sizes returned by ``auto_batch(N)``
    and ``auto_batch(N, 9)``. Each model is scored on the MNIST test split.

    Args:
        N_values: Iterable of requested training-subset sizes. A value
            larger than the training split is capped at the split size.
        fixed_batch_sizes: Iterable of batch sizes to try for every N.
        seed: Optional integer. When given, seeds the subset sampling and
            torch's global RNG so a run can be repeated. When None, the
            global RNG state is used as-is.

    Returns:
        pandas.DataFrame with one row per (N, B) run and the columns
        ``N`` (requested size), ``n_samples`` (size actually used), ``B``,
        ``is_auto_default``, ``is_auto_mode``, ``accuracy``, ``train_time``
        (seconds) and ``efficiency`` (accuracy / train_time).
    """
    transform = transforms.ToTensor()
    full = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test = datasets.MNIST(root="./data", train=False, transform=transform)

    if seed is not None:
        torch.manual_seed(seed)
        sampler = random.Random(seed)
    else:
        # Fall back to the module-level generator, as before.
        sampler = random

    # These do not depend on N or B, so build them once.
    loss_fn = nn.CrossEntropyLoss()
    test_loader = DataLoader(test, batch_size=256)

    results = []

    for N in N_values:
        # random.sample raises if asked for more items than exist, so cap N.
        n_samples = min(N, len(full))
        indices = sampler.sample(range(len(full)), n_samples)
        subset = Subset(full, indices)

        # NOTE(review): auto_batch receives the requested N, not n_samples;
        # the two differ when N exceeds the dataset size. Confirm intended.
        B_auto_default = auto_batch(N)
        B_auto_mode = auto_batch(N, 9)

        # De-duplicate (an auto size may equal a fixed one) and sort so the
        # run order is stable.
        batches = sorted({*fixed_batch_sizes, B_auto_default, B_auto_mode})

        for B in batches:
            model = SimpleNet()
            opt = optim.Adam(model.parameters(), lr=0.001)
            train_loader = DataLoader(subset, batch_size=B, shuffle=True)

            # perf_counter is monotonic and high resolution; time.time() can
            # jump when the system clock is adjusted.
            start = time.perf_counter()
            model.train()
            for xb, yb in train_loader:
                out = model(xb)
                loss = loss_fn(out, yb)
                opt.zero_grad()
                loss.backward()
                opt.step()
            elapsed = time.perf_counter() - start

            acc = evaluate(model, test_loader)
            results.append({
                "N": N,
                "n_samples": n_samples,
                "B": B,
                "is_auto_default": B == B_auto_default,
                "is_auto_mode": B == B_auto_mode,
                "accuracy": acc,
                "train_time": elapsed,
                # Guard against a zero reading on coarse timers.
                "efficiency": acc / elapsed if elapsed > 0 else float("nan"),
            })

    return pd.DataFrame(results)

In [None]:
# Run the experiment.
# run_experiment caps each N at the size of the training split, so values
# larger than the split train on the whole split.
N_values = [60_000, 120_000, 240_000]
batch_sizes = [64, 128, 256, 512, 1024]
df = run_experiment(N_values=N_values, fixed_batch_sizes=batch_sizes)

In [None]:
# Visualization
sns.set(style="whitegrid")


def label_mode(row):
    """Return the legend label describing how a row's batch size was chosen.

    Rows flagged by ``is_auto_default`` and/or ``is_auto_mode`` get the
    matching auto labels joined by "/"; all other rows are "manual".
    """
    label = []
    if row["is_auto_default"]:
        label.append("auto(0.5)")
    if row["is_auto_mode"]:
        label.append("auto(9)")
    return "/".join(label) if label else "manual"


for N in N_values:
    sub = df[df.N == N].copy()
    if sub.empty:
        # Nothing recorded for this N; an empty frame would give a NaN y-limit.
        continue

    sub["mode"] = sub.apply(label_mode, axis=1)

    fig, ax = plt.subplots(figsize=(8, 4))
    # Each batch size has exactly one mode, so dodging by hue would only
    # shrink the bars and shift them off their tick; keep them centred.
    sns.barplot(data=sub, x="B", y="efficiency", hue="mode", dodge=False, ax=ax)
    ax.set_title(f"Efficiency by Batch Size (N={N})")
    ax.set_ylabel("Accuracy / Train Time")
    ax.set_xlabel("Batch Size")
    ax.set_ylim(0, sub["efficiency"].max() * 1.1)
    ax.legend(title="Batch Mode")
    fig.tight_layout()
    plt.show()

display(df.sort_values(by=["N", "efficiency"], ascending=[True, False]))

<1차 결과 - a = 0.9>

	N	B	is_auto_default	is_auto_speed	accuracy	train_time	efficiency
1	10000	256	False	False	0.7962	1.607250	0.495380
2	10000	64	False	False	0.8642	1.773134	0.487386
3	10000	128	True	True	0.8340	1.775550	0.469714
0	10000	32	False	False	0.8785	1.957632	0.448756
7	30000	128	True	True	0.8850	3.392746	0.260851
5	30000	256	False	False	0.8622	3.335811	0.258468
6	30000	64	False	False	0.9001	3.538790	0.254352
4	30000	32	False	False	0.9068	4.866295	0.186343
11	60000	128	False	False	0.9020	7.293218	0.123677
9	60000	256	True	True	0.8910	7.737632	0.115152
8	60000	32	False	False	0.9166	8.281725	0.110677
10	60000	64	False	False	0.9096	8.822620	0.103099

<2차 결과 - a = 0.9>

	N	B	is_auto_default	is_auto_mode	accuracy	train_time	efficiency
3	10000	128	True	True	0.8379	1.238122	0.676751
2	10000	64	False	False	0.8630	1.277381	0.675601
1	10000	256	False	False	0.8043	1.241014	0.648099
0	10000	32	False	False	0.8808	1.442739	0.610505
5	30000	256	False	False	0.8635	3.449372	0.250335
7	30000	128	True	True	0.8826	3.611814	0.244365
6	30000	64	False	False	0.8989	3.923373	0.229114
4	30000	32	False	False	0.9066	4.250427	0.213296
11	60000	128	False	False	0.9043	6.823071	0.132536
9	60000	256	True	True	0.8910	6.765995	0.131688
10	60000	64	False	False	0.9107	7.243638	0.125724
8	60000	32	False	False	0.9147	7.856381	0.116428

<