# Data & Helper Functions

Initializing some helper functions that would be used later during training

In [3]:
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math
import shutil

import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision

import numpy as np
import random

# fixing random seeds
random_seed = 6
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)

_, term_width = shutil.get_terminal_size()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


# Model

Here, we define our model - DLA class. In our experiments, we use the DLA (https://arxiv.org/abs/1707.06484) architecture with ZigZag-aligned modifications, specifically with a modified first layer.

In [4]:
'''Simplified version of DLA in PyTorch.
Note this implementation is not identical to the original paper version.
But it seems works fine.
See dla.py for the original paper version.
Reference:
    Deep Layer Aggregation. https://arxiv.org/abs/1707.06484
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=1, padding=(kernel_size - 1) // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, xs):
        x = torch.cat(xs, 1)
        out = F.relu(self.bn(self.conv(x)))
        return out


class Tree(nn.Module):
    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        super(Tree, self).__init__()
        self.root = Root(2*out_channels, out_channels)
        if level == 1:
            self.left_tree = block(in_channels, out_channels, stride=stride)
            self.right_tree = block(out_channels, out_channels, stride=1)
        else:
            self.left_tree = Tree(block, in_channels,
                                  out_channels, level=level-1, stride=stride)
            self.right_tree = Tree(block, out_channels,
                                   out_channels, level=level-1, stride=1)

    def forward(self, x):
        out1 = self.left_tree(x)
        out2 = self.right_tree(out1)
        out = self.root([out1, out2])
        return out

class F1(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(F1, self).__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True)
        )

        self.layer3 = Tree(block,  32,  64, level=1, stride=1)
        self.layer4 = Tree(block,  64, 128, level=2, stride=2)

    def forward(self, x):
        out = self.base(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out

class F2(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(F2, self).__init__()

        self.y_hat_fc = nn.Sequential(
            nn.Linear(10, 128),
            nn.LeakyReLU()
        )
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes + 128)

    def forward(self, x, y):
        out = x + self.y_hat_fc(y)[..., None, None]
        out = self.layer5(out)
        out = self.layer6(out)
        z = F.avg_pool2d(out, 4)
        z = z.view(z.size(0), -1)
        z = self.linear(z)
        return z[:, :10], z[:, 10:]

class DLA(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(DLA, self).__init__()
        self.f1 = F1(block, num_classes)
        self.f2 = F2(block, num_classes)

    def forward(self, x, y):
        z = self.f1(x)
        y_pred, z_pred  = self.f2(z, y)
        return y_pred, z_pred

# Initialize Training

Here, we define the optimization parameters and the functions for **train** and **test**. As described in the paper, the first inference of our method is performed with a "blank" additional input, while the second inference incorporates the class labels.

In [5]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import tqdm

class Args:

  def __init__(self):
    self.lr = 0.001
    self.resume = False
    self.checkpoint = "cifar_iter"
    self.batch_size = 56

args = Args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Model
device = "cuda"
print('==> Building model..')
net = DLA().to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)

batch_size = 64
num_workers = 64

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(f'./checkpoint/{args.checkpoint}.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training
def train(epoch):

    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    t = tqdm.trange(len(trainset) // args.batch_size + 1, desc='Current Loss = ', leave=True)

    for _, (batch_idx, (inputs, targets)) in zip(t, enumerate(trainloader)):

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size, _, H, W = inputs.shape

        x1 = torch.cat([inputs], dim=1)

        inputs = torch.cat([x1], dim=0)
        targets = torch.cat([targets], dim=0)

        optimizer.zero_grad()

        mask = torch.rand(1) > 0.8

        # creating "blank" inputs for the first inferense
        # in this case, it's uniform distribution over 10 classes
        y_0 = F.one_hot(targets, num_classes=10).float() if mask else torch.ones(batch_size, 10).to(device) * 0.1
        z = net.module.f1(inputs)

        # first inference with "blank" input
        y_1, z1 = net.module.f2(z, y_0)

        # creating additional input for the second inference
        y_2, z2 = net.module.f2(z + z1[..., None, None], y_1.softmax(-1))

        loss_supervised_1 = criterion(y_1, targets)
        loss_supervised_2 = criterion(y_2, targets)

        loss = loss_supervised_1 + loss_supervised_2
        loss.backward()
        optimizer.step()

        inputs, targets = inputs.to(device), targets.to(device)

        train_loss += loss.item()
        _, predicted = y_2.max(1)
        if not mask:
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        if total != 0:
            t.set_description(f"Epoch {epoch} Current Loss = {round(100.*correct/total, 3)}", refresh=True)

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size, _, H, W = inputs.shape
            inputs = torch.cat([inputs], dim=1)

            # creating "blank" inputs for the first inferense
            # in this case, it's uniform distribution over 10 classes
            y_0 = torch.ones(batch_size, 10).to(device) * 0.1

            z = net.module.f1(inputs)

             # first inference with "blank" input
            y_1, z1 = net.module.f2(z, y_0)

            # creating additional input for the second inference
            y_2, z2 = net.module.f2(z + z1[..., None, None], y_1.softmax(-1))

            loss = criterion(y_2, targets)

            test_loss += loss.item()
            _, predicted = y_1.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total

    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, f'./checkpoint/{args.checkpoint}.pth')
        best_acc = acc

    print("BEST ACCURACY: ", best_acc, "Current Accuracy: ", acc)

==> Preparing data..
==> Building model..
Files already downloaded and verified




Files already downloaded and verified


# Training

The next cell trains the model based on ZigZag. For convenience, a pretrained checkpoint is provided in the following section.

In [None]:
for epoch in range(start_epoch, start_epoch + 70):
    train(epoch)
    test(epoch)
    scheduler.step()

for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1

for epoch in range(start_epoch, start_epoch + 10):
    train(epoch)
    test(epoch)
    scheduler.step()

# Evaluation

In [6]:
# Creating the model
device = "cuda"
print('==> Building model..')
net = DLA().to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)

==> Building model..


### Download pretrained weights

In [7]:
!gdown 1VeGGbAdfBvVyOybZwbQ7bSC24Bbli4X9

Downloading...
From (original): https://drive.google.com/uc?id=1VeGGbAdfBvVyOybZwbQ7bSC24Bbli4X9
From (redirected): https://drive.google.com/uc?id=1VeGGbAdfBvVyOybZwbQ7bSC24Bbli4X9&confirm=t&uuid=952e5a9f-1ed3-4608-8aa6-d5551487ab8d
To: /content/cifar_iter.pth
100% 61.0M/61.0M [00:00<00:00, 147MB/s]


In [8]:
checkpoint = "./cifar_iter.pth"

# Load checkpoint.
print('==> Resuming from checkpoint..')
checkpoint = torch.load(checkpoint)
net.load_state_dict(checkpoint['net'])

==> Resuming from checkpoint..


  checkpoint = torch.load(checkpoint)


<All keys matched successfully>

# Noisy Inputs Evaluation

In this subsection, we apply a substantial amount of Gaussian noise to the input images. We observe a significant drop in performance for the original (non-optimized) model, whereas the ITTT-optimized model performs considerably better.

In [9]:
class AddGaussianNoise(object):
    def __init__(self, std, mean=0.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    AddGaussianNoise(0.5)
])

ood_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=64, shuffle=False, num_workers=num_workers, drop_last=True)

Files already downloaded and verified


### Vanilla Model

The vanilla model achieves only around 30% accuracy on noisy inputs, compared to over 90% on clean data.

In [10]:
import tqdm

def test_vanilla(net, dataloader):
    global best_acc
    criterion = nn.CrossEntropyLoss()
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size, _, H, W = inputs.shape

            inputs = torch.cat([inputs], dim=1)

            y_0 = torch.ones(batch_size, 10).to(device) * 0.1

            z = net.module.f1(inputs)
            y_1, z1 = net.module.f2(z, y_0)
            y_2, z2 = net.module.f2(z + z1[..., None, None], y_1.softmax(-1))

            loss = criterion(y_2, targets)

            test_loss += loss.item()
            _, predicted = y_1.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = round(100.*correct/total, 2)
    print(f"\nAccuracy: {acc}%")

In [11]:
test_vanilla(net, ood_loader)

100%|██████████| 156/156 [00:26<00:00,  5.85it/s]


Accuracy: 31.47%





### ITTT Optimized Model

The model optimized with Idempotent Test-Time Training achieves approximately 40% higher accuracy on noisy data.

In [14]:
def js_div(p, q):
    """Function that computes distance between two predictions"""
    m = 0.5 * (p + q)
    return 0.5 * (F.kl_div(torch.log(p), m, reduction='batchmean') +
                  F.kl_div(torch.log(q), m, reduction='batchmean'))

def ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y, n_classes=10):
  """Function that runs test-time training on one batch 'x'"""

  f_ttt.load_state_dict(f.state_dict())  # reset f_ttt to f
  f_ttt.train()
  for step in range(n_steps):
    y_0 = torch.ones(batch_size, 10).to(device) * 0.1

    z = f_ttt.f1(x)
    y_1, z1 = f_ttt.f2(z, y_0)
    y_2, z2 = f_ttt.f2(z + z1[..., None, None], y_1.softmax(-1))

    loss_unsupervised_y = js_div(y_1.softmax(-1), y_2.softmax(-1))
    loss_unsupervised_z = (z1 - z2).pow(2).mean()
    loss = loss_unsupervised_y

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return y_1, y_2


def ttt(f, test_loader, n_steps, lr):
  """Running test-time training over the whole test dataloader"""

  f_ttt = deepcopy(f)
  f.eval()
  optimizer = optim.Adam(f_ttt.parameters(), lr=lr)
  test_loss_1, correct_1 = 0, 0
  test_loss_2, correct_2 = 0, 0

  for ind, (data, target) in tqdm.tqdm(enumerate(test_loader), total=len(test_loader)):
    x, y = data.to(device), target.to(device)
    y_hat_1, y_hat_2 = ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y)

    test_loss_1 += F.nll_loss(y_hat_1.log(), y, size_average=False).item()
    test_loss_2 += F.nll_loss(y_hat_2.log(), y, size_average=False).item()

    pred_1 = y_hat_1.data.max(1, keepdim=True)[1]
    pred_2 = y_hat_2.data.max(1, keepdim=True)[1]

    correct_1 += pred_1.eq(y.data.view_as(pred_1)).sum()
    correct_2 += pred_2.eq(y.data.view_as(pred_2)).sum()

  acc = round(100. * int(correct_2) / len(test_loader.dataset), 2)
  print(f"\nAccuracy: {acc}%")

In [13]:
from copy import deepcopy
ttt(net.module, ood_loader, n_steps=1, lr=1e-3)

100%|██████████| 156/156 [01:12<00:00,  2.16it/s]


Accuracy: 68.96%



