<a href="https://colab.research.google.com/github/heewonLEE2/Data-Ai-Colab/blob/main/Mnist_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# A-1. 기본 임포트
import os, random, json, time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [None]:
# A-2. Fix random seeds for reproducibility
def set_seed(seed=42):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
    (all) CUDA generators with the same ``seed``, then forces cuDNN to use
    deterministic algorithms and disables its benchmark-based autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade a little speed for run-to-run reproducibility of convolutions.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

In [None]:
# A-3. Device: use the GPU when one is available, otherwise the CPU
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')
device

In [None]:
# B-1. Transforms (with normalization) & light augmentation (rotation)
# Shared tail of both pipelines: PIL image -> float tensor, then normalize
# with mean 0.1307 / std 0.3081. ToTensor and Normalize hold no mutable
# state, so the same instances can safely appear in both Compose objects.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]

# Training: random rotation within +/-10 degrees as light augmentation (curbs overfitting).
transform_train = transforms.Compose([transforms.RandomRotation(10), *_to_normalized_tensor])

# Evaluation: deterministic pipeline, no augmentation.
transform_eval = transforms.Compose(list(_to_normalized_tensor))

In [None]:
# B-2. Download MNIST (train split gets augmentation, test split does not)
root = './data'
_mnist_common = dict(root=root, download=True)
train_full = datasets.MNIST(train=True, transform=transform_train, **_mnist_common)
test_ds = datasets.MNIST(train=False, transform=transform_eval, **_mnist_common)

In [None]:
# B-3. Train/validation split (90:10), reproducible via a fixed generator seed
val_ratio = 0.1
val_len = int(len(train_full) * val_ratio)
train_len = len(train_full) - val_len
train_ds, val_ds = random_split(train_full, [train_len, val_len], generator=torch.Generator().manual_seed(42))

# Validation must be evaluated WITHOUT augmentation. Both Subsets returned by
# random_split wrap the very same dataset object (train_full), so assigning
# `val_ds.dataset.transform = ...` would also strip the rotation augmentation
# from train_ds. Instead, re-wrap the same validation indices around a second
# MNIST training-set instance that uses the eval transform.
train_full_eval = datasets.MNIST(root=root, train=True, download=True, transform=transform_eval)
val_ds = torch.utils.data.Subset(train_full_eval, val_ds.indices)

In [None]:
# B-4. Data loaders
batch_size = 128
# Pinned host memory only speeds up host->GPU copies. On a CPU-only runtime it
# has no effect and PyTorch emits a warning, so enable it only for CUDA.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

In [None]:
tuple(map(len, (train_ds, val_ds, test_ds)))  # (train, val, test) sample counts

In [None]:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small VGG-style CNN for single-channel 28x28 images.

    Two conv stages (each: two 3x3 conv + ReLU, then 2x2 max-pool) shrink the
    spatial size 28 -> 14 -> 7. An MLP head with dropout follows.
    ``features`` and ``classifier`` are kept as flat ``nn.Sequential``
    containers so state_dict keys (``features.0.weight``,
    ``classifier.1.weight``, ...) stay stable for saved checkpoints.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # conv -> ReLU -> conv -> ReLU -> 2x2 max-pool (halves H and W)
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(
            *conv_stage(1, 32),   # 28x28 -> 14x14
            *conv_stage(32, 64),  # 14x14 -> 7x7
        )
        flat_dim = 64 * 7 * 7  # channels x height x width after the second pool
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 128), nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to unnormalized class scores (logits)."""
        return self.classifier(self.features(x))

# Re-create the model and sanity-check it
model = SimpleCNN()
model = model.to(device)  # Module.to moves parameters in place and returns the module
print(hasattr(model, "forward"))  # should print True
print(model)  # print the architecture

In [None]:
# Loss, optimizer and learning-rate schedule
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,  # L2-style weight decay applied by Adam
)
# Multiply the learning rate by 0.7 every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.7,
)

In [None]:
def run_epoch(loader, train=True):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode and the optimizer
    takes one step per batch. With ``train=False`` the model is put in eval
    mode and gradient tracking is disabled. Relies on the notebook-level
    ``model``, ``criterion``, ``optimizer`` and ``device``. The returned loss
    is a per-sample mean (each batch's mean loss is weighted by its size).
    """
    model.train(train)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch = inputs.size(0)
        with torch.set_grad_enabled(train):
            logits = model(inputs)
            loss = criterion(logits, targets)
            if train:
                # Clearing grads after the forward pass is equivalent:
                # the forward pass never reads .grad.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        loss_sum += loss.item() * batch
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(loader):
    """Return ``(mean_loss, accuracy)`` on ``loader`` without updating the model."""
    return run_epoch(loader, train=False)

In [None]:
# Early-stopping configuration
best_val_acc = 0.0
patience = 3        # epochs without a val-acc improvement before stopping
patience_cnt = 0
num_epochs = 12
history = {key: [] for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}

os.makedirs('results', exist_ok=True)

In [None]:
# Train with per-epoch validation, LR scheduling, checkpointing and early stopping.
for epoch in range(1, num_epochs + 1):
    t0 = time.time()
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    val_loss, val_acc = evaluate(val_loader)
    scheduler.step()  # StepLR is advanced once per epoch

    for key, value in (('train_loss', tr_loss), ('train_acc', tr_acc),
                       ('val_loss', val_loss), ('val_acc', val_acc)):
        history[key].append(value)

    print(f"[{epoch:02d}] train loss {tr_loss:.4f} acc {tr_acc:.4f} | val loss {val_loss:.4f} acc {val_acc:.4f} | {time.time()-t0:.1f}s")

    # No strict improvement in validation accuracy: spend one unit of patience.
    if not val_acc > best_val_acc:
        patience_cnt += 1
        if patience_cnt >= patience:
            print("Early stopping triggered.")
            break
        continue

    # New best: reset patience and checkpoint the weights.
    best_val_acc = val_acc
    patience_cnt = 0
    torch.save(model.state_dict(), 'results/best_model.pt')

In [None]:
# Load the best checkpoint (saved whenever validation accuracy improved) and
# evaluate it on the held-out test set.
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs. It is the default from PyTorch 2.6; passing it
# explicitly gives the same behaviour on earlier 2.x releases, which instead
# warn about the upcoming default change.
model.load_state_dict(torch.load('results/best_model.pt', map_location=device, weights_only=True))
test_loss, test_acc = evaluate(test_loader)
print(f"Test loss {test_loss:.4f}, Test acc {test_acc:.4f}")

# Confusion matrix
import itertools
from sklearn.metrics import confusion_matrix, classification_report

# Collect ground-truth labels and model predictions over the whole test set.
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images.to(device)).argmax(dim=1)
        y_true += labels.cpu().tolist()
        y_pred += preds.cpu().tolist()

cm = confusion_matrix(y_true, y_pred)
report = classification_report(y_true, y_pred, digits=4)
print(report)

# Save the confusion-matrix plot.
# Row/column i is assumed to be digit i. sklearn sorts the labels present in
# y_true/y_pred, and every MNIST digit occurs in the test set.
n_classes = cm.shape[0]  # derived from cm instead of hard-coding 10
plt.figure(figsize=(6,6))
plt.imshow(cm, interpolation='nearest')
plt.title('Confusion Matrix (MNIST)')
plt.colorbar()
tick_marks = np.arange(n_classes)
plt.xticks(tick_marks, tick_marks)
plt.yticks(tick_marks, tick_marks)
# Diagonal counts dwarf the off-diagonal errors, so on colour alone the errors
# are invisible. Write each cell's count on the image. Text is black on bright
# (high-count) cells and white on dark ones, which suits a dark-to-bright
# colormap such as matplotlib's default.
thresh = cm.max() / 2.0
for i, j in np.ndindex(cm.shape):
    plt.text(j, i, str(int(cm[i, j])), ha='center', va='center', fontsize=7,
             color='black' if cm[i, j] > thresh else 'white')
plt.xlabel('Predicted'); plt.ylabel('True')
plt.tight_layout()
plt.savefig('results/confusion_matrix.png', dpi=150)
plt.close()

# Save learning curves: one figure for loss, one for accuracy.
_curve_specs = (
    (('train_loss', 'val_loss'), 'Loss', 'results/loss_curve.png'),
    (('train_acc', 'val_acc'), 'Accuracy', 'results/acc_curve.png'),
)
for series_keys, plot_title, out_path in _curve_specs:
    plt.figure()
    for series_key in series_keys:
        plt.plot(history[series_key], label=series_key)
    plt.legend()
    plt.title(plot_title)
    plt.xlabel('epoch')
    plt.savefig(out_path, dpi=150)
    plt.close()

# Save a 4x4 grid with up to 16 misclassified test samples.
max_show = 16  # must equal the number of panels in the 4x4 grid below
mis_imgs, mis_true, mis_pred = [], [], []
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device); y = y.to(device)
        p = model(x).argmax(1)
        mask = (p != y)
        if mask.any():
            mis_imgs.append(x[mask][:max_show].cpu())
            mis_true += y[mask][:max_show].cpu().tolist()
            mis_pred += p[mask][:max_show].cpu().tolist()
        if len(mis_true) >= max_show:
            break

if mis_imgs:
    grid = torch.cat(mis_imgs, dim=0)[:max_show]
    grid = (grid * 0.3081 + 0.1307).clamp(0,1)  # undo Normalize((0.1307,), (0.3081,))
    # A strong model can make fewer than 16 test errors. Indexing every panel
    # would then raise IndexError, so only the first n_show panels get an
    # image and the rest are left blank.
    n_show = grid.size(0)
    fig, axes = plt.subplots(4,4, figsize=(6,6))
    for i, ax in enumerate(axes.flat):
        if i < n_show:
            ax.imshow(grid[i][0], cmap='gray')
            ax.set_title(f"T:{mis_true[i]} / P:{mis_pred[i]}")
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('results/misclassified.png', dpi=150)
    plt.close()

# Save metrics as JSON
metrics = {
    'best_val_acc': float(best_val_acc),
    'test_acc': float(test_acc),
    'test_loss': float(test_loss),
    'params': sum(p.numel() for p in model.parameters() if p.requires_grad),
    'batch_size': batch_size,
    'epochs': len(history['train_loss']),  # epochs actually run (early stopping may cut this short)
    # Read hyperparameters back from the live optimizer/scheduler instead of
    # repeating literals, so the record cannot drift from what was trained.
    # `defaults` holds the constructor values, unaffected by StepLR's decay.
    'optimizer': type(optimizer).__name__,
    'lr': optimizer.defaults['lr'],
    'weight_decay': optimizer.defaults['weight_decay'],
    'scheduler': f"{type(scheduler).__name__}({scheduler.step_size}, gamma={scheduler.gamma})",
}
with open('results/metrics.json', 'w', encoding='utf-8') as f:
    json.dump(metrics, f, indent=2)
metrics

In [None]:
import csv
# One row per run, appended to a CSV so runs can be compared over time.
exp_row = {
    'timestamp': int(time.time()),
    'seed': 42,  # keep in sync with the set_seed(...) call at the top of the notebook
    'epochs': metrics['epochs'],
    'batch_size': batch_size,
    # Take hyperparameters from `metrics` rather than repeating literals.
    'lr': metrics['lr'],
    'weight_decay': metrics['weight_decay'],
    'best_val_acc': metrics['best_val_acc'],
    'test_acc': metrics['test_acc'],
}
csv_path = 'results/experiments.csv'
# Write the header when the file is new OR exists but is empty (for example,
# created by an interrupted earlier run). Otherwise, append under the existing header.
needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
with open(csv_path, 'a', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=list(exp_row.keys()))
    if needs_header:
        writer.writeheader()
    writer.writerow(exp_row)
print(f"Appended to {csv_path}")