In [1]:
import torch
import torchvision
import timm
import random
import numpy as np
import os
import albumentations as A
import ttach as tta
from albumentations.pytorch import ToTensorV2

class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked with a keyword argument.

    The transform is called as ``transform(image=...)`` and the result is read
    back from the ``"image"`` key of the returned mapping, which is the calling
    convention used by the albumentations pipelines built in this notebook.
    """

    def __init__(self, root="./data/cifar10", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``torchvision.datasets.CIFAR10``."""
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        The image is taken straight from ``self.data`` (no PIL conversion is
        done here) and is passed through ``self.transform`` only when a
        transform was supplied.
        """
        sample = self.data[index]
        target = self.targets[index]

        if self.transform is None:
            return sample, target

        return self.transform(image=sample)["image"], target
    
def seed_all(seed):
    """Seed every random number generator this notebook touches.

    Seeds NumPy's global RNG, PyTorch (CPU generator plus the current and all
    CUDA devices) and Python's ``random`` module, stores the seed in the
    ``PYTHONHASHSEED`` environment variable, and sets the cuDNN
    ``deterministic`` flag to True and the ``benchmark`` flag to False.

    Args:
        seed: Value handed to each seeding call.
    """
    for apply_seed in (np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    for apply_seed in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)
    
# Use the first CUDA device when torch reports CUDA as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
seed_all(42)

# Test-time preprocessing: resize to 224x224, then convert to a torch tensor.
transform_test = A.Compose([
        A.Resize(224,224),
        ToTensorV2()
])

testset = Cifar10SearchDataset(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=16)

# Architectures that take part in the ensemble
model_names = ['resnet18']

models = []    # list of (model, score) tuples
model_path = './models/'
# Minimum score a checkpoint needs in order to be loaded.
# NOTE(review): the score is the test accuracy encoded in the file name,
# presumably in hundredths of a percent (9600 -> 96.00%) -- confirm.
threshold = 9600

# Load the saved model weights
for model_name in model_names:
    checkpoint_dir = os.path.join(model_path, model_name)
    for model_parameter in sorted(os.listdir(checkpoint_dir)):
        # The score is the 4 characters that precede the 4-character file
        # extension. Entries that do not follow this naming scheme (stray
        # files, hidden directories, ...) are skipped instead of aborting
        # the whole cell with a ValueError.
        try:
            score = int(model_parameter[-8:-4])
        except ValueError:
            print(f'Skipping {model_parameter!r}: no score found in file name')
            continue

        # Only load models at or above the accuracy threshold
        if score < threshold:
            continue

        model = timm.create_model(model_name, num_classes=10)
        # map_location lets checkpoints saved from a GPU load on whatever
        # `device` was selected, including the CPU fallback.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(os.path.join(checkpoint_dir, model_parameter), map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        model = model.to(device)
        models.append((model, score))
print(len(models))

18


In [None]:
# Evaluate the ensemble using the tta (test time augmentation) module
tta_transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        # Disabled alternatives:
        # tta.Add([1,2]),
        # tta.Multiply([1, 1.05, 1.1])
    ]
)

# Without any model the summed output would be meaningless, so fail loudly
# rather than report an accuracy computed from all-zero logits.
if not models:
    raise ValueError('`models` is empty: no checkpoint passed the threshold')

# Wrap every model once; the wrappers do not depend on the batch, so there is
# no reason to rebuild them inside the data loop.
tta_models = [tta.ClassificationTTAWrapper(entry[0], tta_transforms) for entry in models]

correct = 0
total = 0
seed_all(42)
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device).float(), labels.to(device)

        # Sum the outputs of all models. The accumulator takes its shape from
        # the first model output, so any batch size / class count works.
        outputs = None
        for tta_model in tta_models:
            model_output = tta_model(images)
            outputs = model_output if outputs is None else outputs + model_output

        _, predicted = torch.max(outputs, 1)   # final prediction: class with the largest summed output
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the ensemble on the 10000 test images: {(100 * correct / total)}')