In [1]:
!pip install git+https://github.com/RobustBench/robustbench.git

Collecting git+https://github.com/RobustBench/robustbench.git
  Cloning https://github.com/RobustBench/robustbench.git to /tmp/pip-req-build-vkft6ou5
  Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git /tmp/pip-req-build-vkft6ou5
  Resolved https://github.com/RobustBench/robustbench.git to commit 46a91f44524133b2cd8f721ec7e73ecb63f17fc8
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack (from robustbench==1.1)
  Cloning https://github.com/fra31/auto-attack.git (to revision a39220048b3c9f2cca9a4d3a54604793c68eca7e) to /tmp/pip-install-g7gphgae/autoattack_ce7ac0e7a73a43eb921f816db5059706
  Running command git clone --filter=blob:none --quiet https://github.com/fra31/auto-attack.git /tmp/pip-install-g7gphgae/autoattack_ce7ac0e7a73a43eb921f816db5059706
  Running command git rev-parse -q --verify 'sha^a39220048b3c9f2c

In [2]:
from copy import deepcopy

import torch
import torch.nn as nn
import torch.jit


class Tent(nn.Module):
    """Wrap a model so it adapts by entropy minimization at test time.

    Each call to ``forward`` runs ``steps`` rounds of ``forward_and_adapt`` on
    the input batch, so the wrapped model is updated on every forward pass.
    With ``episodic=True`` the saved model/optimizer state is restored first.
    """
    def __init__(self, model, optimizer, steps=1, episodic=False):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.steps = steps
        assert steps > 0, "tent requires >= 1 step(s) to forward and update"
        self.episodic = episodic

        # Snapshot the initial states so reset() can restore them later.
        # note: for continual adaptation (never reset) this copy could be
        # skipped to save memory
        saved_states = copy_model_and_optimizer(self.model, self.optimizer)
        self.model_state, self.optimizer_state = saved_states

    def forward(self, x):
        """Adapt on batch ``x`` and return the outputs of the last step."""
        if self.episodic:
            self.reset()

        for _step in range(self.steps):
            outputs = forward_and_adapt(x, self.model, self.optimizer)

        return outputs

    def reset(self):
        """Restore the model and optimizer to the states saved at construction."""
        nothing_saved = (self.model_state is None
                         or self.optimizer_state is None)
        if nothing_saved:
            raise Exception("cannot reset without saved model/optimizer state")
        load_model_and_optimizer(self.model, self.optimizer,
                                 self.model_state, self.optimizer_state)


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy of the softmax distribution over dim 1, computed from logits."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)


@torch.enable_grad()  # the caller may be in a no-grad context; adaptation needs grads
def forward_and_adapt(x, model, optimizer):
    """Run one forward pass on ``x`` and take one entropy-minimization step.

    The loss is the batch mean of ``softmax_entropy`` over the model outputs.
    Gradients are cleared after the optimizer step. The returned outputs are
    the ones computed before the update.
    """
    logits = model(x)
    # adaptation objective: mean prediction entropy over the batch
    entropy = softmax_entropy(logits).mean(0)
    entropy.backward()
    optimizer.step()
    optimizer.zero_grad()
    return logits


def collect_params(model):
    """Gather the affine scale + shift parameters of every ``nn.BatchNorm2d``.

    Walks ``model.named_modules()`` and keeps the ``weight`` and ``bias``
    parameters of each 2-D batch norm layer.

    Returns:
        (params, names): two parallel lists, the parameters and their names
        formatted as "<module name>.<parameter name>".

    Note: other choices of parameterization are possible!
    """
    affine_names = ('weight', 'bias')  # weight is scale, bias is shift
    found = [
        (f"{module_name}.{param_name}", param)
        for module_name, module in model.named_modules()
        if isinstance(module, nn.BatchNorm2d)
        for param_name, param in module.named_parameters()
        if param_name in affine_names
    ]
    names = [name for name, _ in found]
    params = [param for _, param in found]
    return params, names


def copy_model_and_optimizer(model, optimizer):
    """Return deep copies of the model and optimizer state dicts.

    The copies are independent of the live objects, so they can be loaded back
    later to undo adaptation.
    """
    return deepcopy(model.state_dict()), deepcopy(optimizer.state_dict())


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Load previously copied state dicts back into the model and optimizer.

    The model state is loaded strictly and before the optimizer state.
    """
    model.load_state_dict(state_dict=model_state, strict=True)
    optimizer.load_state_dict(state_dict=optimizer_state)


def configure_model(model):
    """Prepare ``model`` for tent in place and return it.

    Switches to train mode, freezes every parameter, then re-enables gradients
    on ``nn.BatchNorm2d`` layers only and removes their running statistics.
    """
    # train mode, because tent optimizes the model to minimize entropy
    model.train()
    # freeze everything first, so only what tent updates is re-enabled below
    model.requires_grad_(False)
    batch_norms = [mod for mod in model.modules()
                   if isinstance(mod, nn.BatchNorm2d)]
    for bn in batch_norms:
        bn.requires_grad_(True)
        # without running stats the layer uses batch statistics in both
        # train and eval modes
        bn.track_running_stats = False
        bn.running_mean = None
        bn.running_var = None
    return model


def check_model(model):
    """Check model for compatibility with tent.

    Asserts, in order: the model is in train mode, at least one parameter
    requires grad, not every parameter requires grad, and the model contains
    at least one ``nn.BatchNorm2d`` layer.

    Raises:
        AssertionError: on the first check that fails.
    """
    assert model.training, "tent needs train mode: call model.train()"
    grad_flags = [param.requires_grad for param in model.parameters()]
    assert any(grad_flags), \
        "tent needs params to update: check which require grad"
    assert not all(grad_flags), \
        "tent should not update all params: check which require grad"
    contains_bn = any(isinstance(mod, nn.BatchNorm2d)
                      for mod in model.modules())
    assert contains_bn, "tent needs normalization for its optimization"

In [3]:
# -------------------------------- Model options ---------------------------- #

# Model name handed to robustbench's load_model
cfg_MODEL_ARCH = "Standard"

# Adaptation mode, one of (source, norm, tent):
# - source: baseline without adaptation
# - norm: test-time normalization
# - tent: test-time entropy minimization (ours)
cfg_MODEL_ADAPTATION = "source"

# Tent is online by default: updates persist across batches.
# Set to True to make adaptation episodic, resetting the model for each batch.
cfg_MODEL_EPISODIC = False

# ----------------------------- Corruption options -------------------------- #

# Dataset for evaluation
cfg_CORRUPTION_DATASET = "cifar10"

# Corruption types to evaluate, in order.
# See https://github.com/hendrycks/robustness for corruption details.
cfg_CORRUPTION_TYPE = [
    "gaussian_noise",
    "shot_noise",
    "impulse_noise",
    "defocus_blur",
    "glass_blur",
    "motion_blur",
    "zoom_blur",
    "snow",
    "frost",
    "fog",
    "brightness",
    "contrast",
    "elastic_transform",
    "pixelate",
    "jpeg_compression",
]

# Severity levels to evaluate, in this order (5 first)
cfg_CORRUPTION_SEVERITY = [5, 4, 3, 2, 1]

# Number of examples to evaluate (10000 for all samples in CIFAR-10)
cfg_CORRUPTION_NUM_EX = 10_000

# ------------------------------- Batch norm options ------------------------ #
# NOTE(review): the two BN settings are not read by any cell visible here.

# BN epsilon
cfg_BN_EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
cfg_BN_MOM = 0.1

# ------------------------------- Optimizer options ------------------------- #

# Number of updates per batch
cfg_OPTIM_STEPS = 1

# Learning rate
cfg_OPTIM_LR = 0.001

# Optimizer choice: Adam or SGD
cfg_OPTIM_METHOD = "Adam"

# First Adam beta (the second is fixed to 0.999 in setup_optimizer)
cfg_OPTIM_BETA = 0.9

# SGD momentum
cfg_OPTIM_MOMENTUM = 0.9

# SGD momentum dampening
cfg_OPTIM_DAMPENING = 0.0

# SGD Nesterov momentum
cfg_OPTIM_NESTEROV = True

# L2 regularization (weight decay)
cfg_OPTIM_WD = 0.0

# ------------------------------- Testing options --------------------------- #

# Batch size for evaluation (and updates for norm + tent)
cfg_TEST_BATCH_SIZE = 128

# --------------------------------- CUDNN options --------------------------- #

# Benchmark to select fastest CUDNN algorithms (best for fixed input sizes)
cfg_CUDNN_BENCHMARK = True

# ---------------------------------- Misc options --------------------------- #

# Optional description of a config
cfg_DESC = ""

# Random seed.
# Note that non-determinism is still present due to non-deterministic GPU ops
cfg_RNG_SEED = 1

# Output directory
cfg_SAVE_DIR = "./output"

# Data directory
cfg_DATA_DIR = "./data"

# Weight (checkpoint) directory
cfg_CKPT_DIR = "./ckpt"

# Log destination (in SAVE_DIR)
cfg_LOG_DEST = "log.txt"

# Log datetime
cfg_LOG_TIME = ""

In [4]:
import logging

import torch
import torch.optim as optim

from robustbench.data import load_cifar10c
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
from robustbench.utils import clean_accuracy as accuracy


# Logger named after the current module; the setup and evaluation functions
# in the following cells report through it.
logger = logging.getLogger(__name__)


In [5]:

def setup_source(model):
    """Prepare the unadapted source baseline.

    Switches ``model`` to eval mode, logs it at INFO level, and returns the
    same object.
    """
    model.eval()
    logger.info("model for evaluation: %s", model)
    return model

In [6]:
def setup_tent(model):
    """Set up tent adaptation.

    Configure the model for training + feature modulation by batch statistics,
    collect the parameters for feature modulation by gradient optimization,
    set up the optimizer, and then tent the model.

    Args:
        model: the source model; it is reconfigured in place by
            ``configure_model``.

    Returns:
        A ``Tent`` wrapper around ``model`` that uses ``cfg_OPTIM_STEPS``
        update steps per batch and the ``cfg_MODEL_EPISODIC`` reset policy.
    """
    model = configure_model(model)
    params, param_names = collect_params(model)
    optimizer = setup_optimizer(params)
    tent_model = Tent(model, optimizer,
                      steps=cfg_OPTIM_STEPS,
                      episodic=cfg_MODEL_EPISODIC)
    # Interpolate the values into the message. Passing them as a second
    # print() argument next to a "%s" placeholder printed a literal "%s".
    print(f"model for adaptation: {model}")
    print(f"params for adaptation: {param_names}")
    print(f"optimizer for adaptation: {optimizer}")
    return tent_model


def setup_optimizer(params):
    """Build the optimizer tent uses for test-time entropy minimization.

    The optimizer type is selected by ``cfg_OPTIM_METHOD`` ('Adam' or 'SGD')
    and its hyperparameters come from the other ``cfg_OPTIM_*`` settings.

    In principle, tent could make use of any gradient optimizer.
    In practice, we advise choosing Adam or SGD+momentum.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.

    For best results, try tuning the learning rate and batch size.

    Raises:
        NotImplementedError: if ``cfg_OPTIM_METHOD`` is neither 'Adam' nor 'SGD'.
    """
    method = cfg_OPTIM_METHOD
    if method == 'Adam':
        adam_options = dict(lr=cfg_OPTIM_LR,
                            betas=(cfg_OPTIM_BETA, 0.999),
                            weight_decay=cfg_OPTIM_WD)
        return optim.Adam(params, **adam_options)
    if method == 'SGD':
        sgd_options = dict(lr=cfg_OPTIM_LR,
                           momentum=cfg_OPTIM_MOMENTUM,
                           dampening=cfg_OPTIM_DAMPENING,
                           weight_decay=cfg_OPTIM_WD,
                           nesterov=cfg_OPTIM_NESTEROV)
        return optim.SGD(params, **sgd_options)
    raise NotImplementedError

In [7]:
def evaluate(description):
    """Evaluate the configured model on every corruption type and severity.

    Loads the model named by ``cfg_MODEL_ARCH``, wraps it according to
    ``cfg_MODEL_ADAPTATION``, then for each severity in
    ``cfg_CORRUPTION_SEVERITY`` and each type in ``cfg_CORRUPTION_TYPE`` loads
    the corrupted test data and prints the corruption type, the severity and
    the error (1 - accuracy).

    Args:
        description: free-text label for the run. Currently unused, because
            the config-loading call that consumed it is disabled here.

    Raises:
        ValueError: if ``cfg_MODEL_ADAPTATION`` is not "source", "norm" or
            "tent".
    """
    # configure model
    base_model = load_model(cfg_MODEL_ARCH, cfg_CKPT_DIR,
                            cfg_CORRUPTION_DATASET,
                            ThreatModel.corruptions).cuda()
    if cfg_MODEL_ADAPTATION == "source":
        logger.info("test-time adaptation: NONE")
        model = setup_source(base_model)
    elif cfg_MODEL_ADAPTATION == "norm":
        logger.info("test-time adaptation: NORM")
        # NOTE(review): setup_norm is not defined in the cells visible here;
        # confirm it exists before selecting "norm".
        model = setup_norm(base_model)
    elif cfg_MODEL_ADAPTATION == "tent":
        logger.info("test-time adaptation: TENT")
        model = setup_tent(base_model)
    else:
        # Fail early: previously `model` stayed unbound and the error only
        # surfaced later, after the first dataset had been loaded.
        raise ValueError(
            f"unknown cfg_MODEL_ADAPTATION {cfg_MODEL_ADAPTATION!r}: "
            "expected 'source', 'norm' or 'tent'")
    print("done")
    # evaluate on each severity and type of corruption in turn
    for severity in cfg_CORRUPTION_SEVERITY:
        for corruption_type in cfg_CORRUPTION_TYPE:
            # reset adaptation for each combination of corruption x severity
            # note: for evaluation protocol, but not necessarily needed
            reset = getattr(model, "reset", None)
            if callable(reset):
                try:
                    reset()
                except Exception:
                    # Resetting stays best-effort, but the failure is logged
                    # with its traceback instead of being silently dropped.
                    logger.warning("not resetting model", exc_info=True)
                else:
                    print("resetting model")
            else:
                # models without a reset() method (e.g. the plain source model)
                logger.warning("not resetting model")
            # the positional False is presumably the shuffle flag of
            # load_cifar10c -- TODO confirm against robustbench
            x_test, y_test = load_cifar10c(cfg_CORRUPTION_NUM_EX,
                                           severity, cfg_DATA_DIR, False,
                                           [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
            acc = accuracy(model, x_test, y_test, cfg_TEST_BATCH_SIZE)
            err = 1. - acc
            print(corruption_type)
            print(severity)
            print(err)



In [8]:
if __name__ == '__main__':
    # Run the evaluation with the settings currently held in the cfg_* globals.
    # The description is only a label; its stray leading double quote is removed.
    evaluate('CIFAR-10-C evaluation.')

Downloading ckpt/cifar10/corruptions/Standard.pt (gdrive_id=1t98aEuzeTL8P7Kpd5DIrCoCL21BNZUhC).


Downloading...
From (original): https://drive.google.com/uc?id=1t98aEuzeTL8P7Kpd5DIrCoCL21BNZUhC
From (redirected): https://drive.google.com/uc?id=1t98aEuzeTL8P7Kpd5DIrCoCL21BNZUhC&confirm=t&uuid=3db333d5-3a48-406e-a16a-79c3f6362b5d
To: /kaggle/working/ckpt/cifar10/corruptions/Standard.pt
100%|██████████| 292M/292M [00:03<00:00, 93.6MB/s] 


done
Starting download from https://zenodo.org/api/records/2535967/files/CIFAR-10-C.tar/content


44533it [03:40, 201.63it/s]                           


Download finished, extracting...
Downloaded and extracted.
gaussian_noise
5
0.7233
shot_noise
5
0.6571
impulse_noise
5
0.7292000000000001
defocus_blur
5
0.46940000000000004
glass_blur
5
0.5432
motion_blur
5
0.34750000000000003
zoom_blur
5
0.4202
snow
5
0.25070000000000003
frost
5
0.41300000000000003
fog
5
0.2601
brightness
5
0.09299999999999997
contrast
5
0.4669
elastic_transform
5
0.2659
pixelate
5
0.5845
jpeg_compression
5
0.30300000000000005
gaussian_noise
4
0.6738999999999999
shot_noise
4
0.5465
impulse_noise
4
0.5989
defocus_blur
4
0.22560000000000002
glass_blur
4
0.5681
motion_blur
4
0.25170000000000003
zoom_blur
4
0.29710000000000003
snow
4
0.19489999999999996
frost
4
0.2914
fog
4
0.10409999999999997
brightness
4
0.07189999999999996
contrast
4
0.16410000000000002
elastic_transform
4
0.2106
pixelate
4
0.39649999999999996
jpeg_compression
4
0.25849999999999995
gaussian_noise
3
0.6081
shot_noise
3
0.46509999999999996
impulse_noise
3
0.4263
defocus_blur
3
0.11019999999999996
glass_b

In [9]:
# Switch the next evaluation run to tent adaptation with a larger batch.
cfg_MODEL_ADAPTATION = "tent"
cfg_TEST_BATCH_SIZE = 200

# Optimizer settings for the tent run, restated explicitly.
cfg_OPTIM_METHOD = "Adam"
cfg_OPTIM_STEPS = 1
cfg_OPTIM_BETA = 0.9
cfg_OPTIM_LR = 0.001
cfg_OPTIM_WD = 0.0


In [None]:
if __name__ == '__main__':
    # Re-run the evaluation with the settings currently held in the cfg_* globals.
    # The description is only a label; its stray leading double quote is removed.
    evaluate('CIFAR-10-C evaluation.')

model for adaptation: %s WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=Fa