In [1]:
import os
import urllib.request
import tarfile
import pickle
import numpy as np
from common.ResNet20 import ResNet20
from common.functions import softmax

def download_cifar100(save_path='cifar-100-python'):
    """Download and unpack the CIFAR-100 python archive unless it is already present.

    Args:
        save_path: Directory expected to hold the extracted dataset. If this
            path already exists, nothing is downloaded. Otherwise the archive
            is unpacked into the parent directory of ``save_path``. The
            archive is expected to unpack into a ``cifar-100-python`` folder
            (TODO confirm), so ``save_path`` should end with that name.

    Returns:
        None.
    """
    if os.path.exists(save_path):
        print("CIFAR-100 이미 존재")
        return

    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'

    # Unpack next to save_path instead of always into the current working
    # directory; otherwise a non-default save_path is never created and the
    # existence check above would trigger a fresh download on every call.
    # For the default save_path this resolves to '.', i.e. the old behavior.
    target_dir = os.path.dirname(os.path.normpath(save_path)) or '.'
    os.makedirs(target_dir, exist_ok=True)
    filename = os.path.join(target_dir, 'cifar-100-python.tar.gz')

    print("CIFAR-100 다운로드 중...")
    try:
        urllib.request.urlretrieve(url, filename)
        with tarfile.open(filename, 'r:gz') as tar:
            # Use the 'data' extraction filter when this Python provides it,
            # so archive members cannot escape target_dir.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=target_dir, filter='data')
            else:
                tar.extractall(path=target_dir)
    finally:
        # Remove the (possibly partial) archive even when the download or
        # the extraction fails.
        if os.path.exists(filename):
            os.remove(filename)
    print("다운로드 완료")

def load_batch(filepath):
    """Read one pickled CIFAR-100 batch file.

    Args:
        filepath: Path to a pickle file holding a dict with the byte-string
            keys ``b'data'``, ``b'fine_labels'`` and ``b'coarse_labels'``.

    Returns:
        Tuple ``(images, fine, coarse)``: ``images`` is the ``b'data'`` array
        reshaped to ``(-1, 3, 32, 32)``, cast to float32 and divided by 255;
        ``fine`` and ``coarse`` are the label sequences as NumPy arrays.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only read trusted files.
    with open(filepath, 'rb') as handle:
        batch = pickle.load(handle, encoding='bytes')

    raw_pixels = batch[b'data']
    fine = np.array(batch[b'fine_labels'])
    coarse = np.array(batch[b'coarse_labels'])

    # One row per image -> (N, 3, 32, 32), then scale by 1/255.
    images = raw_pixels.reshape(-1, 3, 32, 32).astype(np.float32) / 255.0
    return images, fine, coarse

In [2]:
def load_cifar100_full(data_dir='./cifar-100-python', valid_ratio=0.1):
    """Load CIFAR-100 and split a validation set off the front of the training set.

    Args:
        data_dir: Directory containing the ``train`` and ``test`` batch files.
        valid_ratio: Fraction of the training samples (truncated to an int
            count) taken from the start of the training set for validation.

    Returns:
        A 9-tuple ordered as
        ``(x_train, x_valid, x_test,
           coarse_train, coarse_valid, coarse_test,
           fine_train, fine_valid, fine_test)``.
    """
    train_images, train_fine, train_coarse = load_batch(os.path.join(data_dir, 'train'))
    test_images, test_fine, test_coarse = load_batch(os.path.join(data_dir, 'test'))

    # The first `cut` training samples form the validation split; the rest
    # stay in the training split. No shuffling is performed here.
    cut = int(train_images.shape[0] * valid_ratio)
    valid_part = slice(None, cut)
    train_part = slice(cut, None)

    return (
        train_images[train_part], train_images[valid_part], test_images,
        train_coarse[train_part], train_coarse[valid_part], test_coarse,
        train_fine[train_part], train_fine[valid_part], test_fine,
    )

In [3]:
def restore_model_parameters(model, model_state):
    """Copy saved weights and biases from ``model_state`` into ``model``.

    The stem convolution and the final layer are read from the named keys
    ``conv1_*`` and ``fc_*``. Every convolution found inside the blocks of
    ``layer1``, ``layer2`` and ``layer3`` is read from ``{i}_W`` / ``{i}_b``,
    where ``i`` counts those convolutions in visiting order (per block:
    ``conv1``, ``conv2``, then ``shortcut`` when the block has one).

    Args:
        model: Object exposing ``conv1``, ``fc`` and the list attributes
            ``layer1``, ``layer2``, ``layer3``.
        model_state: Mapping from key strings to parameter values.
    """
    stem, head = model.conv1, model.fc
    stem.W, stem.b = model_state['conv1_W'], model_state['conv1_b']
    head.W, head.b = model_state['fc_W'], model_state['fc_b']

    residual_blocks = model.layer1 + model.layer2 + model.layer3
    # Lazily walk the convolutions so lookup and assignment stay interleaved.
    conv_layers = (
        getattr(block, attr)
        for block in residual_blocks
        for attr in ('conv1', 'conv2', 'shortcut')
        if hasattr(block, attr)
    )
    for position, conv in enumerate(conv_layers):
        conv.W = model_state[f'{position}_W']
        conv.b = model_state[f'{position}_b']

def restore_bn_params(model, state):
    """Copy saved batch-norm parameters and running statistics into ``model``.

    BN layers are numbered in visiting order: for each block of ``layer1``,
    ``layer2`` and ``layer3`` the order is ``bn1``, ``bn2``, then
    ``bn_shortcut`` when the block has one. The model-level ``bn1`` comes
    last. Layer number ``i`` is filled from the keys ``{i}_gamma``,
    ``{i}_beta``, ``{i}_running_mean`` and ``{i}_running_var``.

    Args:
        model: Object exposing ``bn1`` and the list attributes ``layer1``,
            ``layer2``, ``layer3``.
        state: Mapping from key strings to the saved values.
    """
    def _fill(bn_layer, slot):
        # Load the four saved quantities for the BN layer numbered `slot`.
        bn_layer.gamma = state[f'{slot}_gamma']
        bn_layer.beta = state[f'{slot}_beta']
        bn_layer.running_mean = state[f'{slot}_running_mean']
        bn_layer.running_var = state[f'{slot}_running_var']

    def _bn_layers():
        # Yield BN layers in the numbering order used by the saved state.
        for block in model.layer1 + model.layer2 + model.layer3:
            yield block.bn1
            yield block.bn2
            if hasattr(block, 'bn_shortcut'):
                yield block.bn_shortcut
        yield model.bn1

    for slot, bn_layer in enumerate(_bn_layers()):
        _fill(bn_layer, slot)

In [4]:
def evaluate_model(model, x, y_true, batch_size=100):
    """Predict class indices for ``x`` in mini-batches and score them against ``y_true``.

    Args:
        model: Object with a ``predict(batch)`` method returning per-class
            scores; assumed to be 2-D ``(batch, classes)`` since the argmax
            is taken over axis 1 -- TODO confirm.
        x: Input array; batches are sliced along its first axis.
        y_true: Ground-truth class indices, one per row of ``x``.
        batch_size: Number of samples passed to ``model.predict`` at once.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        Tuple ``(preds, acc)``: the predicted class index for every sample
        and the fraction of predictions equal to ``y_true``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``y_true`` does
            not have one label per sample in ``x``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Without this check a length mismatch either broadcasts silently or
    # compares as a single scalar, producing a meaningless accuracy.
    if len(y_true) != x.shape[0]:
        raise ValueError(
            f"x has {x.shape[0]} samples but y_true has {len(y_true)} labels"
        )

    preds = np.concatenate([
        np.argmax(softmax(model.predict(x[start:start + batch_size])), axis=1)
        for start in range(0, x.shape[0], batch_size)
    ])
    acc = np.sum(preds == y_true) / len(y_true)
    return preds, acc

In [5]:
# Download CIFAR-100 (skipped when the dataset directory already exists),
# then load the train / validation / test splits.
download_cifar100()
# The unpacking order follows load_cifar100_full's return tuple: images first,
# then coarse labels, then fine labels -- each as (train, valid, test).
(x_train, x_valid, x_test,
 t_train_coarse, t_valid_coarse, t_test_coarse,
 t_train_fine, t_valid_fine, t_test_fine) = load_cifar100_full()

CIFAR-100 이미 존재


In [6]:
# Evaluate each individually trained model on the test split (fine labels).
# Keys are display names; values are the checkpoint files to load.
model_files = {
    "crop": "crop_epoch_10.pkl",
    "crop+flip": "crop+flip_epoch_10.pkl",
    "crop+flip+cutout": "crop+flip+cutout_epoch_10.pkl"
}

for name, file in model_files.items():
    model = ResNet20()
    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only open checkpoint files from a trusted source.
    with open(file, "rb") as f:
        checkpoint = pickle.load(f)
    # The "model" entry is passed to both restore helpers below.
    model_state = checkpoint["model"]
    restore_bn_params(model, model_state)
    restore_model_parameters(model, model_state)
    _, acc = evaluate_model(model, x_test, t_test_fine)
    # The score is computed on x_test / t_test_fine, so it is reported as
    # test accuracy (the label previously said "Valid").
    print(f"{name} model - [Fine Label] Test Accuracy: {acc:.4f}")

crop [Fine Label] Valid Accuracy: 0.4826
crop+flip [Fine Label] Valid Accuracy: 0.5061
crop+flip+cutout [Fine Label] Valid Accuracy: 0.4944


In [8]:
def predict_softmax(model, x, batch_size=100):
    """Run ``model`` over ``x`` in mini-batches and return softmax outputs.

    Args:
        model: Object with a ``predict(batch)`` method whose output is fed
            to ``softmax``.
        x: Input array; batches are sliced along its first axis.
        batch_size: Number of samples passed to ``model.predict`` at once.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        The per-batch softmax outputs stacked vertically, in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    return np.vstack([
        softmax(model.predict(x[start:start + batch_size]))
        for start in range(0, x.shape[0], batch_size)
    ])
    
# Softmax-averaging ensemble: collect every model's softmax output on the
# test images.
probs_list = []

for name, file in model_files.items():
    model = ResNet20()
    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only open checkpoint files from a trusted source.
    with open(file, "rb") as f:
        checkpoint = pickle.load(f)
    model_state = checkpoint["model"]
    restore_bn_params(model, model_state)
    restore_model_parameters(model, model_state)
    probs = predict_softmax(model, x_test)
    probs_list.append(probs)

# Unweighted mean of the softmax outputs across models, then pick the
# highest-scoring class per sample.
ensemble_probs = np.mean(probs_list, axis=0)
ensemble_preds = np.argmax(ensemble_probs, axis=1)
ensemble_acc = np.sum(ensemble_preds == t_test_fine) / len(t_test_fine)

# The score is computed on x_test / t_test_fine, so it is reported as test
# accuracy (the label previously said "Valid").
print(f"[Fine Label] Ensemble Test Accuracy: {ensemble_acc:.4f}")

[Fine Label] Ensemble Valid Accuracy: 0.5750


In [9]:
# 1. Weighted average ensemble: each model's softmax output is scaled by a
#    fixed per-model weight and the scaled outputs are summed.

# Per-model weights, using the same keys as model_files (0.2 + 0.5 + 0.3 = 1.0).
weights = {
    "crop": 0.2,
    "crop+flip": 0.5,
    "crop+flip+cutout": 0.3
}

probs_list = []
for name, file in model_files.items():
    model = ResNet20()
    with open(file, "rb") as f:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # checkpoint files from a trusted source.
        checkpoint = pickle.load(f)
        model_state = checkpoint["model"]
    restore_bn_params(model, model_state)
    restore_model_parameters(model, model_state)
    probs = predict_softmax(model, x_test)
    # Scale here so that the plain sum below yields the weighted average.
    probs_list.append(weights[name] * probs)

ensemble_probs = np.sum(probs_list, axis=0)
ensemble_preds = np.argmax(ensemble_probs, axis=1)
# Fraction of test samples whose predicted fine label equals the true label.
ensemble_acc = np.sum(ensemble_preds == t_test_fine) / len(t_test_fine)
print(f"[Fine Label] Weighted Ensemble Accuracy: {ensemble_acc:.4f}")

[Fine Label] Weighted Ensemble Accuracy: 0.5688


In [10]:
# 2. Logit averaging ensemble: average the raw (pre-softmax) logits across models, then apply softmax

def predict_logits(model, x, batch_size=100):
    """Run ``model`` over ``x`` in mini-batches and return its raw outputs.

    Args:
        model: Object with a ``predict(batch)`` method.
        x: Input array; batches are sliced along its first axis.
        batch_size: Number of samples passed to ``model.predict`` at once.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        The per-batch ``model.predict`` outputs stacked vertically, in input
        order. No softmax is applied.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    return np.vstack([
        model.predict(x[start:start + batch_size])
        for start in range(0, x.shape[0], batch_size)
    ])

# Collect every model's raw (pre-softmax) outputs on the test images.
logits_list = []
for name, file in model_files.items():
    model = ResNet20()
    with open(file, "rb") as f:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # checkpoint files from a trusted source.
        checkpoint = pickle.load(f)
        model_state = checkpoint["model"]
    restore_bn_params(model, model_state)
    restore_model_parameters(model, model_state)
    logits = predict_logits(model, x_test)
    logits_list.append(logits)

# Average the logits across models first, then apply softmax once to the mean.
ensemble_logits = np.mean(logits_list, axis=0)
ensemble_probs = softmax(ensemble_logits)
ensemble_preds = np.argmax(ensemble_probs, axis=1)
# Fraction of test samples whose predicted fine label equals the true label.
ensemble_acc = np.sum(ensemble_preds == t_test_fine) / len(t_test_fine)
print(f"[Fine Label] Logit Averaging Ensemble Accuracy: {ensemble_acc:.4f}")

[Fine Label] Logit Averaging Ensemble Accuracy: 0.5796
