In [1]:
from matplotlib import pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [2]:
# 디바이스 설정
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

In [4]:
test_data = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)

Files already downloaded and verified


In [7]:
# 이 부분만 수정!!
from models import resnet
net = resnet.resnet18()
net.to(device)
net.load_state_dict(torch.load("./runs/resnet_18_base/savepoints/Sunday_06_October_2024_13h_21m_09s/ResNet_18_base-238-best.pth"))

  net.load_state_dict(torch.load("./runs/resnet_18_base/savepoints/Sunday_06_October_2024_13h_21m_09s/ResNet_18_base-238-best.pth"))


<All keys matched successfully>

In [8]:
net.eval()

ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (conv2_x): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_

##### **fine_to_superclass**

In [9]:
fine_to_superclass = {
    # aquatic mammals
    4: 0, 30: 0, 55: 0, 72: 0, 95: 0,
    
    # fish
    1: 1, 32: 1, 67: 1, 73: 1, 91: 1,
    
    # flowers
    54: 2, 62: 2, 70: 2, 82: 2, 92: 2,
    
    # food containers
    9: 3, 10: 3, 16: 3, 28: 3, 61: 3,
    
    # fruit and vegetables
    0: 4, 51: 4, 53: 4, 57: 4, 83: 4,
    
    # household electrical devices
    22: 5, 39: 5, 40: 5, 86: 5, 87: 5,
    
    # household furniture
    5: 6, 20: 6, 25: 6, 84: 6, 94: 6,
    
    # insects
    6: 7, 7: 7, 14: 7, 18: 7, 24: 7,
    
    # large carnivores
    3: 8, 42: 8, 43: 8, 88: 8, 97: 8,
    
    # large man-made outdoor things
    12: 9, 17: 9, 37: 9, 68: 9, 76: 9,
    
    # large natural outdoor scenes
    23: 10, 33: 10, 49: 10, 60: 10, 71: 10,
    
    # large omnivores and herbivores
    15: 11, 19: 11, 21: 11, 31: 11, 38: 11,
    
    # medium-sized mammals
    34: 12, 63: 12, 64: 12, 66: 12, 75: 12,
    
    # non-insect invertebrates
    26: 13, 45: 13, 77: 13, 79: 13, 99: 13,
    
    # people
    2: 14, 11: 14, 35: 14, 46: 14, 98: 14,
    
    # reptiles
    27: 15, 29: 15, 44: 15, 78: 15, 93: 15,
    
    # small mammals
    36: 16, 50: 16, 65: 16, 74: 16, 80: 16,
    
    # trees
    47: 17, 52: 17, 56: 17, 59: 17, 96: 17,
    
    # vehicles 1
    8: 18, 13: 18, 48: 18, 58: 18, 90: 18,
    
    # vehicles 2
    41: 19, 69: 19, 81: 19, 85: 19, 89: 19
}

##### **Get Accuracy**

In [10]:
def all_accuracy(net, test_loader, device):
    correct_1_fine = 0.0  # 세부 클래스 top-1 정확도
    correct_5_fine = 0.0  # 세부 클래스 top-5 정확도
    correct_1_super = 0.0  # 슈퍼 클래스 top-1 정확도
    total = 0

    with torch.no_grad():
        for n_iter, (image, label) in enumerate(test_loader):
            print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(test_loader)))
    
            image = image.cuda()
            label = label.cuda()
    
            output = net(image)
            _, pred = output.topk(5, 1, largest=True, sorted=True)
    
            label = label.view(label.size(0), -1).expand_as(pred)
            correct = pred.eq(label).float()
    
            # 세부 클래스 top-5 정확도
            correct_5_fine += correct[:, :5].sum()
            # 세부 클래스 top-1 정확도
            correct_1_fine += correct[:, :1].sum()
    
            # 슈퍼 클래스 변환 (contiguous()로 메모리 연속성 확보 후 view() 사용)
            target_super = torch.tensor([fine_to_superclass[t.item()] for t in label.contiguous().view(-1)], device=device)
            pred_super = torch.tensor([fine_to_superclass[p.item()] for p in pred.contiguous().view(-1)], device=device).view_as(pred)
    
            # target_super를 pred_super의 크기로 확장
            target_super = target_super.view(label.size(0), 5).expand_as(pred_super)
    
            # 슈퍼 클래스 top-1 정확도만 계산
            correct_super = pred_super.eq(target_super).float()
            correct_1_super += correct_super[:, :1].sum()
    
            total += label.size(0)
    
    # 세부 클래스 및 슈퍼 클래스 정확도 계산
    top1_acc_fine = correct_1_fine / total
    top5_acc_fine = correct_5_fine / total
    top1_acc_super = correct_1_super / total

    return top1_acc_fine, top5_acc_fine, top1_acc_super

In [11]:
acc = all_accuracy(net, test_loader, device)

iteration: 1	total 79 iterations
iteration: 2	total 79 iterations
iteration: 3	total 79 iterations
iteration: 4	total 79 iterations
iteration: 5	total 79 iterations
iteration: 6	total 79 iterations
iteration: 7	total 79 iterations
iteration: 8	total 79 iterations
iteration: 9	total 79 iterations
iteration: 10	total 79 iterations
iteration: 11	total 79 iterations
iteration: 12	total 79 iterations
iteration: 13	total 79 iterations
iteration: 14	total 79 iterations
iteration: 15	total 79 iterations
iteration: 16	total 79 iterations
iteration: 17	total 79 iterations
iteration: 18	total 79 iterations
iteration: 19	total 79 iterations
iteration: 20	total 79 iterations
iteration: 21	total 79 iterations
iteration: 22	total 79 iterations
iteration: 23	total 79 iterations
iteration: 24	total 79 iterations
iteration: 25	total 79 iterations
iteration: 26	total 79 iterations
iteration: 27	total 79 iterations
iteration: 28	total 79 iterations
iteration: 29	total 79 iterations
iteration: 30	total 79 

In [13]:
print("Top 1 Fine Class accuracy: {:.4f}".format(acc[0]))
print("Top 5 Fine Class accuracy: {:.4f}".format(acc[1]))
print("Top 1 Super Class accuracy: {:.4f}".format(acc[2]))

Top 1 Fine Class accuracy: 0.7581
Top 5 Fine Class accuracy: 0.9345
Top 1 Super Class accuracy: 0.8479
