In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import models.cifar as models

import numpy as np
import scipy
from statsmodels.stats.proportion import proportion_confint

import math
import time

In [None]:
def sample_under_noise(model, x, n, sigma, batch=100,
                       mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Classify ``n`` Gaussian-noise-perturbed copies of a single image.

    Noise is added to the un-normalized image and the per-channel normalization
    ``(img - mean) / std`` is applied afterwards, so ``model`` receives normalized
    inputs. (The previous version called ``transforms.Normalize`` but discarded
    its result, so no normalization was actually applied.)

    Args:
        model: callable mapping a ``(b, C, H, W)`` tensor to logits whose class
            dimension is dim 1.
        x: the image, shaped ``(1, C, H, W)`` or ``(C, H, W)``.
        n: number of noisy samples to draw.
        sigma: standard deviation of the isotropic Gaussian noise.
        batch: number of noisy images pushed through ``model`` per forward pass.
        mean, std: per-channel normalization constants (CIFAR-10 values by default).

    Returns:
        ``np.ndarray`` of shape ``(n,)`` holding the predicted class of each sample.
    """
    image_shape = x.shape[-3:]
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    preds = []
    with torch.no_grad():  # pure inference: don't build an autograd graph
        for start in range(0, n, batch):
            size = min(batch, n - start)
            # Draw noise per mini-batch so at most `batch` noisy images live in memory
            # (drawing all n at once needs ~1.2 GB for n=100000 CIFAR images).
            noise = torch.randn((size, *image_shape), dtype=x.dtype, device=x.device)
            noisy = (x + sigma * noise - mean_t) / std_t
            logits = model(noisy)
            preds.append(torch.argmax(logits, dim=1))
    if not preds:  # n == 0
        return np.empty(0, dtype=np.int64)
    return torch.cat(preds).cpu().numpy()

In [None]:
def batched_sample_under_noise(model, x, n, sigma,
                               mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Classify ``n`` independent noisy copies of every image in the batch ``x``.

    Each of the ``n`` rounds draws FRESH Gaussian noise for the whole batch, adds
    it to the un-normalized images, normalizes with ``(img - mean) / std`` and
    runs one forward pass. (The previous version drew the noise once, so all
    ``n`` predictions per image were identical, and discarded the normalization.)

    Args:
        model: callable mapping a ``(B, C, H, W)`` tensor to logits whose class
            dimension is dim 1.
        x: batch of images, shaped ``(B, C, H, W)``.
        n: number of noise draws per image.
        sigma: standard deviation of the isotropic Gaussian noise.
        mean, std: per-channel normalization constants (CIFAR-10 values by default).

    Returns:
        ``np.ndarray`` of shape ``(B, n)``; column ``j`` holds the predictions for
        the ``j``-th noise draw.
    """
    if n == 0:
        return np.empty((x.shape[0], 0), dtype=np.int64)
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    preds = []
    with torch.no_grad():  # pure inference: don't build an autograd graph
        for _ in range(n):
            noisy = (x + sigma * torch.randn_like(x) - mean_t) / std_t
            logits = model(noisy)
            preds.append(torch.argmax(logits, dim=1).unsqueeze(-1))
    return torch.cat(preds, dim=1).cpu().numpy()

In [None]:
def predict(model, sigma, x, n=1000, alpha=0.001):
    """Randomized-smoothing PREDICT for a single image: return a label or abstain.

    Draws ``n`` noisy predictions, takes the most frequent class ``cA`` with count
    ``nA`` and the runner-up count ``nB``, and returns ``cA`` only if a two-sided
    binomial test of ``nA`` successes in ``nA + nB`` trials at p = 0.5 has
    p-value <= ``alpha``.

    Args:
        model: classifier passed through to ``sample_under_noise``.
        sigma: Gaussian noise standard deviation.
        x: the image (a batch of one).
        n: number of noisy samples (must be >= 1).
        alpha: significance level of the binomial test.

    Returns:
        The predicted class index (int), or -1 to abstain. This matches the
        convention of ``batched_predict`` and ``certify`` and the ``pred == -1``
        check in the evaluation cells; the previous version returned ``None``.
    """
    preds = sample_under_noise(model, x, n, sigma).astype(int)
    # minlength=2 guarantees a runner-up bin even if only class 0 was predicted.
    counts = np.bincount(preds, minlength=2)
    cA = int(np.argmax(counts))
    nA = int(counts[cA])
    # Second-largest count. (argpartition(counts, 2)[:2] picked the two SMALLEST
    # counts, and raised when fewer than 3 classes were present.)
    nB = int(np.partition(counts, -2)[-2])
    binomtest = getattr(scipy.stats, 'binomtest', None)
    if binomtest is not None:
        p_value = binomtest(nA, nA + nB, 0.5).pvalue
    else:  # SciPy < 1.7 only has binom_test (removed in SciPy 1.12)
        p_value = scipy.stats.binom_test(nA, nA + nB, 0.5)
    if p_value <= alpha:
        return cA
    return -1

In [None]:
def batched_predict(model, sigma, x, n=1000, alpha=0.001):
    """Randomized-smoothing PREDICT for every image in the batch ``x``.

    For each image, takes the most frequent class ``cA`` (count ``nA``) and the
    runner-up count ``nB`` over ``n`` noisy predictions, and keeps ``cA`` only if
    a two-sided binomial test of ``nA`` successes in ``nA + nB`` trials at
    p = 0.5 has p-value <= ``alpha``.

    Args:
        model: classifier passed through to ``batched_sample_under_noise``.
        sigma: Gaussian noise standard deviation.
        x: batch of images.
        n: number of noise draws per image (must be >= 1).
        alpha: significance level of the binomial test.

    Returns:
        list with one entry per image: the predicted class index (int), or -1 to abstain.
    """
    preds = batched_sample_under_noise(model, x, n, sigma).astype(int)  # (B, n)
    binomtest = getattr(scipy.stats, 'binomtest', None)
    results = []
    for row in preds:
        # minlength=2 guarantees a runner-up bin even if only class 0 was predicted.
        counts = np.bincount(row, minlength=2)
        cA = int(np.argmax(counts))
        nA = int(counts[cA])
        # Second-largest count (argpartition(counts, 2)[:2] picked the two SMALLEST).
        nB = int(np.partition(counts, -2)[-2])
        if binomtest is not None:
            p_value = binomtest(nA, nA + nB, 0.5).pvalue
        else:  # SciPy < 1.7 only has binom_test (removed in SciPy 1.12)
            p_value = scipy.stats.binom_test(nA, nA + nB, 0.5)
        results.append(cA if p_value <= alpha else -1)
    return results

In [None]:
def certify(model, sigma, x, n0=100, n=100000, alpha=0.001):
    """Randomized-smoothing CERTIFY for a single image.

    1. Selection: guess the top class ``cA`` from ``n0`` noisy predictions.
    2. Estimation: on ``n`` fresh noisy predictions, count how often ``cA`` wins
       and take the one-sided (1 - alpha) Clopper-Pearson lower bound ``pA``.
       This is the lower end of the two-sided (1 - 2*alpha) 'beta' interval.
    3. If ``pA > 0.5``, certify ``cA`` with radius ``sigma * norm.ppf(pA)``
       (the L2 radius of Gaussian randomized smoothing); otherwise abstain.

    Args:
        model: classifier passed through to ``sample_under_noise``.
        sigma: Gaussian noise standard deviation.
        x: the image (a batch of one).
        n0: selection sample size.
        n: estimation sample size.
        alpha: failure probability of the certificate.

    Returns:
        ``(cA, radius)``, or ``(-1, 0)`` to abstain.
    """
    preds0 = sample_under_noise(model, x, n0, sigma).astype(int)
    cA = int(np.argmax(np.bincount(preds0)))
    preds = sample_under_noise(model, x, n, sigma)
    # Count matches directly: indexing np.bincount(preds)[cA] raised IndexError
    # whenever cA exceeded every class predicted in the estimation sample.
    nA = int(np.count_nonzero(preds == cA))
    pA = proportion_confint(nA, n, alpha=2 * alpha, method='beta')[0]
    if pA > 0.5:
        return cA, sigma * scipy.stats.norm.ppf(pA)
    return -1, 0

In [None]:
def batched_certify(model, sigma, x, n0=100, n=100000, alpha=0.001):
    """Randomized-smoothing CERTIFY for every image in the batch ``x``.

    Per image: pick the top class ``cA`` from ``n0`` noisy predictions. Then take
    the one-sided (1 - alpha) Clopper-Pearson lower bound ``pA`` on how often
    ``cA`` wins among ``n`` fresh noisy predictions. Certify radius
    ``sigma * norm.ppf(pA)`` if ``pA > 0.5``, else abstain.

    Fixes relative to the previous version:

    * ``np.bincount`` is now applied per image (it raised on the 2-D sample array).
    * ``cA`` is chosen separately for each image.
    * The estimation sample uses the batched sampler.
    * Matches are counted directly instead of indexing a bincount, which could
      raise IndexError.

    Args:
        model: classifier passed through to ``batched_sample_under_noise``.
        sigma: Gaussian noise standard deviation.
        x: batch of images.
        n0: selection sample size per image.
        n: estimation sample size per image.
        alpha: failure probability of each certificate.

    Returns:
        list with one ``[cA, radius]`` per image; ``[-1, 0]`` means abstain.
    """
    preds0 = batched_sample_under_noise(model, x, n0, sigma).astype(int)  # (B, n0)
    preds = batched_sample_under_noise(model, x, n, sigma).astype(int)    # (B, n)
    results = []
    for row0, row in zip(preds0, preds):
        cA = int(np.argmax(np.bincount(row0)))
        nA = int(np.count_nonzero(row == cA))
        pA = proportion_confint(nA, n, alpha=2 * alpha, method='beta')[0]
        if pA > 0.5:
            results.append([cA, sigma * scipy.stats.norm.ppf(pA)])
        else:
            results.append([-1, 0])
    return results

In [None]:
# Build the project's CIFAR ResNet (depth 110, 'BasicBlock' blocks) with 10 output classes.
model = models.resnet(num_classes=10, depth=110, block_name='BasicBlock')

In [None]:
saved_model = '/pasteur/results/jeff-results/pretrained-models/models/cifar10/resnet110/noise_0.12/checkpoint.pth.tar'
# Checkpoints under 'pretrained-models' store weights with a '1.' key prefix that
# must be stripped before loading. Always define the flag: it was previously left
# undefined (NameError later) for any other path.
pretrained = 'pretrained-models' in saved_model

In [None]:
# The model is never moved to a GPU in this notebook, so load all tensors onto the
# CPU; this also works on machines without CUDA when the checkpoint was saved from GPU.
checkpoint = torch.load(saved_model, map_location='cpu')

In [None]:
STATE_DICT_KEY = 'state_dict'  # checkpoint entry that holds the network weights
state_dict = checkpoint[STATE_DICT_KEY]

In [None]:
if pretrained:
    # Every network weight in these checkpoints is keyed '1.<name>' (presumably the
    # network is module 1 of an nn.Sequential, e.g. after a normalization layer --
    # TODO confirm). Keep only those entries, with the prefix removed.
    # Build a new dict instead of deleting keys from `state_dict` in place: the
    # in-place version removed every key on a second run of this cell.
    prefix = '1.'
    model_state_dict = {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
else:
    model_state_dict = state_dict
# Inference only: eval mode makes any BatchNorm/Dropout layers use inference
# behaviour instead of per-batch statistics of the noisy samples.
model.eval()
model.load_state_dict(model_state_dict)

In [None]:
# model = nn.DataParallel(model)

In [None]:
# Test-time preprocessing only converts images to tensors. There is no normalization
# here because the sampling helpers normalize after adding noise.
transform_test = transforms.Compose([transforms.ToTensor()])

In [None]:
# CIFAR-10 test split from the local copy (no download), iterated in fixed order.
# One image per batch: the unbatched predict/certify helpers take a single image.
testset = datasets.CIFAR10(
    root='/pasteur/data',
    train=False,
    download=False,
    transform=transform_test,
)
testloader = data.DataLoader(testset, batch_size=1, shuffle=False)

In [None]:
# test predict: smoothed prediction on the first MAX_BATCHES test images, one at a time
MAX_BATCHES = 2  # quick smoke test; set to None to run the whole test set
pred_acc = 0
abstain = 0
num_seen = 0
t0 = time.time()
for batch_idx, (inputs, targets) in enumerate(testloader):
    pred = predict(model, sigma=0.12, x=inputs, n=1000, alpha=0.001)
    num_seen += 1
    # Treat both -1 and None as abstention so the tally is right whichever
    # convention `predict` uses (None was never counted before).
    if pred is None or pred == -1:
        abstain += 1
    elif pred == targets.item():
        pred_acc += 1
    if MAX_BATCHES is not None and num_seen >= MAX_BATCHES:
        break
t1 = time.time()
print('Total time:', t1 - t0)
print('Average time:', (t1 - t0)/num_seen)
print('Abstain precent:', abstain/num_seen)
num_answered = num_seen - abstain
print('Predicted accuracy:', pred_acc/num_answered if num_answered else float('nan'))

In [None]:
# test certify: certify the first MAX_BATCHES test images, one at a time
MAX_BATCHES = 2  # quick smoke test; set to None to run the whole test set
results = []
all_targets = []
t0 = time.time()
for batch_idx, (inputs, targets) in enumerate(testloader):
    pred, radius = certify(model, sigma=0.12, x=inputs, n=1000, alpha=0.001)
    results.append([pred, radius])
    all_targets.extend(targets.cpu().numpy().tolist())  # extend in place, no list copies
    if MAX_BATCHES is not None and len(results) >= MAX_BATCHES:
        break
t1 = time.time()
print('Total time:', t1 - t0)
print('Average time:', (t1 - t0)/len(results))

In [None]:
# Tally certify results: abstention rate and accuracy among non-abstained images.
cert_acc = 0
abstain = 0
for (pred, _radius), target in zip(results, all_targets):
    if pred == -1:
        abstain += 1
    elif pred == target:
        cert_acc += 1
num_answered = len(results) - abstain
print('Abstain precent:', abstain/len(results))
# Report cert_acc (this previously printed the stale pred_acc from the predict test).
print('Predicted accuracy:', cert_acc/num_answered if num_answered else float('nan'))

In [None]:
# Same CIFAR-10 test split, now in batches of 8 for the batched helpers.
testset = datasets.CIFAR10(
    root='/pasteur/data',
    train=False,
    download=False,
    transform=transform_test,
)
testloader = data.DataLoader(testset, batch_size=8, shuffle=False)

In [None]:
# test batched predict: smoothed predictions for one batch of test images
results = []
all_targets = []
t0 = time.time()
for batch_idx, (inputs, targets) in enumerate(testloader):
    results.extend(batched_predict(model, sigma=0.12, x=inputs, n=1000, alpha=0.001))
    all_targets.extend(targets.cpu().numpy())
    if batch_idx == 0:  # smoke test: stop after the first batch
        break
t1 = time.time()
print('Total time:', t1 - t0)
print('Average time:', (t1 - t0)/len(results))

In [None]:
# Tally batched-predict results: abstention rate and accuracy among non-abstained images.
pred_acc = 0
abstain = 0
for pred, target in zip(results, all_targets):
    if pred == -1:
        abstain += 1
    elif pred == target:
        pred_acc += 1
num_answered = len(results) - abstain
print('Abstain precent:', abstain/len(results))
# Guard against every image abstaining (previously a ZeroDivisionError).
print('Predicted accuracy:', pred_acc/num_answered if num_answered else float('nan'))

In [None]:
# test batched certify: certify one batch of test images
results = []
all_targets = []
t0 = time.time()
for batch_idx, (inputs, targets) in enumerate(testloader):
    # Call batched_certify (this previously called `predict`, whose scalar result
    # cannot be concatenated to a list nor unpacked as (pred, radius) later).
    results.extend(batched_certify(model, sigma=0.12, x=inputs, n=1000, alpha=0.001))
    all_targets.extend(targets.cpu().numpy())
    if batch_idx == 0:  # smoke test: stop after the first batch
        break
t1 = time.time()
print('Total time:', t1 - t0)
print('Average time:', (t1 - t0)/len(results))

In [None]:
# Tally batched-certify results: abstention rate and accuracy among non-abstained images.
cert_acc = 0
abstain = 0
for (pred, _radius), target in zip(results, all_targets):
    if pred == -1:
        abstain += 1
    elif pred == target:
        cert_acc += 1
num_answered = len(results) - abstain
print('Abstain precent:', abstain/len(results))
# Guard against every image abstaining (previously a ZeroDivisionError).
print('Predicted accuracy:', cert_acc/num_answered if num_answered else float('nan'))