**Load Module**

In [1]:
# Utils
import numpy as np
import random

# Torch
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset

from torchvision import transforms, datasets

**Seed & Device Setting**

In [2]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed: Integer seed passed to every generator.
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators; manual_seed_all is the call used for multi-GPU setups
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # False is the right choice for full reproducibility, but it may cost some performance
    cudnn.benchmark = False

# Fix every random generator before any data or model code runs
seed = 42
set_seed(seed)

# Device selection: first CUDA device when available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

**Set Hyperparameters**

In [3]:
# Settings read when the test DataLoader is built
config = dict(
    batch_size=128,
    pin_memory=True,
)

In [4]:
# Maps each fine class index (0-99) to its superclass index (0-19).
# Each tuple below lists the fine indices of one superclass; the tuple's
# position in the sequence is the superclass index.
fine_to_superclass = {
    fine_id: super_id
    for super_id, members in enumerate((
        (4, 30, 55, 72, 95),   # 0: aquatic mammals
        (1, 32, 67, 73, 91),   # 1: fish
        (54, 62, 70, 82, 92),  # 2: flowers
        (9, 10, 16, 28, 61),   # 3: food containers
        (0, 51, 53, 57, 83),   # 4: fruit and vegetables
        (22, 39, 40, 86, 87),  # 5: household electrical devices
        (5, 20, 25, 84, 94),   # 6: household furniture
        (6, 7, 14, 18, 24),    # 7: insects
        (3, 42, 43, 88, 97),   # 8: large carnivores
        (12, 17, 37, 68, 76),  # 9: large man-made outdoor things
        (23, 33, 49, 60, 71),  # 10: large natural outdoor scenes
        (15, 19, 21, 31, 38),  # 11: large omnivores and herbivores
        (34, 63, 64, 66, 75),  # 12: medium-sized mammals
        (26, 45, 77, 79, 99),  # 13: non-insect invertebrates
        (2, 11, 35, 46, 98),   # 14: people
        (27, 29, 44, 78, 93),  # 15: reptiles
        (36, 50, 65, 74, 80),  # 16: small mammals
        (47, 52, 56, 59, 96),  # 17: trees
        (8, 13, 48, 58, 90),   # 18: vehicles 1
        (41, 69, 81, 85, 89),  # 19: vehicles 2
    ))
    for fine_id in members
}

# Data Preprocessing

**Data Augmentation (수정 구간)**

In [5]:
# Three-value mean / std tuples handed to Normalize below
# (presumably per-channel statistics of the CIFAR-100 training split -- confirm)
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Evaluation pipeline: tensor conversion followed by normalization only
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=CIFAR100_TRAIN_MEAN, std=CIFAR100_TRAIN_STD),
    ]
)

**Define Dataloader**

In [8]:
# CIFAR-100 with train=False, downloaded into ./data and run through test_transform
test_data = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Unshuffled loader; batch size and pin_memory are read from config
test_loader = DataLoader(
    test_data,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=4,
    pin_memory=config["pin_memory"],
)

**Ensemble Model Test**

In [23]:
def ensemble_accuracy(net1, net2, test_loader, device, weights=(0.3, 0.7), superclass_map=None):
    """Score a weighted two-model ensemble on a labelled data loader.

    For every batch the two model outputs are combined as
    ``net1(x) * w1 + net2(x) * w2`` and compared against the labels.

    Args:
        net1: First model; called on each image batch.
        net2: Second model; called on each image batch.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward passes.
        weights: ``(w1, w2)`` multipliers applied to the outputs of ``net1``
            and ``net2``. Defaults to ``(0.3, 0.7)``.
        superclass_map: Mapping from fine class index to superclass index.
            When ``None``, the notebook-level ``fine_to_superclass`` is used.

    Returns:
        list: ``[top1_acc, top5_acc, superclass_acc]``, each the number of
        correct samples divided by the total number of samples.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
        KeyError: If a predicted or true class index is missing from
            ``superclass_map``.
    """
    if superclass_map is None:
        superclass_map = fine_to_superclass
    w1, w2 = weights

    # Put BOTH models in eval mode so the result does not depend on the mode
    # the caller left them in.
    net1.eval()
    net2.eval()

    top1_correct = 0
    top5_correct = 0
    superclass_correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net1(images) * w1 + net2(images) * w2

            # Top-1: highest-scoring class per sample
            _, preds = torch.max(outputs, 1)
            top1_correct += torch.sum(preds == labels).item()

            # Top-5: the label appears among the five highest-scoring classes
            _, top5_preds = outputs.topk(5, 1, True, True)
            top5_correct += torch.sum(top5_preds.eq(labels.view(-1, 1).expand_as(top5_preds))).item()

            # Superclass accuracy: convert each tensor to a Python list once
            # instead of calling .item() on every element.
            pred_ids = preds.tolist()
            label_ids = labels.tolist()
            superclass_correct += sum(
                superclass_map[p] == superclass_map[t]
                for p, t in zip(pred_ids, label_ids)
            )

            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    # Fine-class (top-1 / top-5) and superclass accuracies
    return [top1_correct / total, top5_correct / total, superclass_correct / total]

In [24]:
from models import resnet, shake_pyramidnet

# NOTE: adjust the checkpoint paths to match your environment
net1_path = "./runs/ensemble/resnet18_cutmix_50000-245-best.pth"
net2_path = "./runs/ensemble/shake_pyramidnet110_270-210-best.pth"
net1 = resnet.resnet18().to(device)
# NOTE(review): arguments are presumably (depth, alpha, num_classes) -- confirm against models.shake_pyramidnet
net2 = shake_pyramidnet.ShakePyramidNet(110,270,100).to(device)

# map_location=device lets the checkpoints load on whichever device was
# selected above (including the CPU fallback); weights_only=True restricts
# unpickling to tensor data, which is all load_state_dict needs.
net1.load_state_dict(torch.load(net1_path, map_location=device, weights_only=True))
net2.load_state_dict(torch.load(net2_path, map_location=device, weights_only=True))
net1.eval()
net2.eval()

acc = ensemble_accuracy(net1, net2, test_loader, device)
print("Top-1 accuracy: {:.4f}".format(acc[0]))
print("Top-5 accuracy: {:.4f}".format(acc[1]))
print("Super-Class accuracy: {:.4f}".format(acc[2]))
# "Total" is the sum of the three metrics above, not a single accuracy value
print("Total accuracy: {:.4f}".format(sum(acc)))

  net1.load_state_dict(torch.load(net1_path))
  net2.load_state_dict(torch.load(net2_path))


Top-1 accuracy: 0.8531
Top-5 accuracy: 0.9747
Super-Class accuracy: 0.9194
Total accuracy: 2.7472
