In [10]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [11]:
# Make the project root (the parent of the notebook's working directory)
# importable, so the local `models` and `datasets` packages resolve below.
# The membership check keeps the cell idempotent on re-run.
ROOT = os.path.abspath(os.pardir)
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [12]:
# Restrict this process to the first GPU. CUDA reads this variable when it is
# initialized, so it only takes effect if no CUDA call has happened yet in this
# kernel (importing torch alone does not initialize CUDA).
os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})

In [13]:
from models.backbone import ResNetBackbone
from models.head import LinearClassifier

In [14]:
from datasets.corrupted import CIFAR10C, CORRUPTIONS
from datasets.transforms import TestAugment

In [15]:
# Model = backbone feature extractor + linear classification head.
# The head's input width is taken from the backbone's `out_channels`.
BACKBONE_ARCH = 'resnet50'
NUM_CLASSES = 10  # CIFAR-10

encoder = ResNetBackbone(BACKBONE_ARCH, data='cifar10')
classifier = LinearClassifier(encoder.out_channels, num_classes=NUM_CLASSES)

In [22]:
# Linear-evaluation checkpoint to test (MoCo-pretrained ResNet-50 backbone).
# Alternative previously used here (CLAPP-pretrained backbone):
#   "../checkpoints/cifar10/linear_clapp/resnet50/2020-11-10_13:46:37/ckpt.100.pth.tar"
CKPT = "../checkpoints/cifar10/linear_moco/resnet50/2020-11-11_11:35:22/ckpt.best.pth.tar"
assert os.path.isfile(CKPT), f"checkpoint not found: {CKPT}"

In [23]:
# Restore both parts of the model from the same checkpoint file: the backbone
# is loaded with key 'backbone', the linear head with key 'classifier'.
for _module, _key in ((encoder, 'backbone'), (classifier, 'classifier')):
    _module.load_weights_from_checkpoint(CKPT, key=_key)
del _module, _key

In [24]:
# Chain backbone and head into a single inference model. `Module.eval()` returns
# the module itself, switching layers such as BatchNorm/Dropout (if present) to
# inference behaviour.
net = nn.Sequential(encoder, classifier).eval()

In [25]:
%%time

for c in CORRUPTIONS:
    test_trans = TestAugment(size=32, data='cifar10')
    test_set = CIFAR10C('../data/cifar10-c', corruption=c, transform=test_trans)
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4, pin_memory=False)
    
    num_correct = 0
    
    with torch.no_grad():
        for _, batch in enumerate(test_loader):
            
            net.cuda()
            x, y = batch['x'].cuda(), batch['y'].cuda()
            logits = net(x)
            _, pred = logits.data.max(dim=1)
            num_correct += pred.eq(y.data).sum().item()
    
    acc = num_correct / len(test_set)
    print(f"CIFAR10-C ({c}): {acc*100:.2f}%", end='\n')

CIFAR10-C (gaussian_noise): 64.10%
CIFAR10-C (shot_noise): 71.41%
CIFAR10-C (impulse_noise): 64.74%
CIFAR10-C (defocus_blur): 86.95%
CIFAR10-C (glass_blur): 65.98%
CIFAR10-C (motion_blur): 81.13%
CIFAR10-C (zoom_blur): 87.51%
CIFAR10-C (snow): 82.37%
CIFAR10-C (frost): 82.97%
CIFAR10-C (fog): 80.06%
CIFAR10-C (brightness): 91.01%
CIFAR10-C (contrast): 86.20%
CIFAR10-C (elastic_transform): 84.50%
CIFAR10-C (pixelate): 85.39%
CIFAR10-C (jpeg_compression): 84.64%
CPU times: user 4min 22s, sys: 4.01 s, total: 4min 26s
Wall time: 4min 27s


In [21]:
%%time

for c in CORRUPTIONS:
    test_trans = TestAugment(size=32, data='cifar10')
    test_set = CIFAR10C('../data/cifar10-c', corruption=c, transform=test_trans)
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4, pin_memory=False)
    
    num_correct = 0
    
    with torch.no_grad():
        for _, batch in enumerate(test_loader):
            
            net.cuda()
            x, y = batch['x'].cuda(), batch['y'].cuda()
            logits = net(x)
            _, pred = logits.data.max(dim=1)
            num_correct += pred.eq(y.data).sum().item()
    
    acc = num_correct / len(test_set)
    print(f"CIFAR10-C ({c}): {acc*100:.2f}%", end='\n')

CIFAR10-C (gaussian_noise): 65.88%
CIFAR10-C (shot_noise): 72.28%
CIFAR10-C (impulse_noise): 65.56%
CIFAR10-C (defocus_blur): 87.56%
CIFAR10-C (glass_blur): 63.85%
CIFAR10-C (motion_blur): 79.77%
CIFAR10-C (zoom_blur): 87.70%
CIFAR10-C (snow): 82.35%
CIFAR10-C (frost): 83.82%
CIFAR10-C (fog): 80.16%
CIFAR10-C (brightness): 91.37%
CIFAR10-C (contrast): 88.26%
CIFAR10-C (elastic_transform): 85.05%
CIFAR10-C (pixelate): 84.23%
CIFAR10-C (jpeg_compression): 85.16%
CPU times: user 4min 23s, sys: 4.17 s, total: 4min 27s
Wall time: 4min 28s
