In [1]:
# the code is partially based on https://github.com/Zinoex/bound_propagation/blob/main/examples/fashion_mnist.py 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

from torchvision import datasets, transforms
# from tensorboardX import SummaryWriter

use_cuda = True
device = torch.device("mps" if use_cuda else "cpu")
batch_size = 64

np.random.seed(42)
torch.manual_seed(42)

transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Lambda(torch.flatten)
    ])


## Dataloaders
train_dataset = datasets.MNIST('mnist_data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('mnist_data/', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [25]:
from bound_propagation import BoundModelFactory, HyperRectangle

# set up the network
class Network(nn.Sequential):
    def __init__(self):
        in_size = 28*28
        classes = 10

        super().__init__(
            nn.Linear(in_size, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, classes)
        )

net = Network().to(device)

factory = BoundModelFactory()
net = factory.build(net)

In [3]:
# train the network on clean data
def clean_train(model, num_epochs):
    learning_rate = 0.0001

    opt = optim.Adam(params=net.parameters(), lr=learning_rate)

    ce_loss = torch.nn.CrossEntropyLoss()

    tot_steps = 0

    for epoch in range(1,num_epochs+1):
        t1 = time.time()

        for _, (x_batch, y_batch) in enumerate(train_loader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            tot_steps += 1
            opt.zero_grad()
            out = model(x_batch)
            batch_loss = ce_loss(out, y_batch)
            batch_loss.backward()
            opt.step()
    
        tot_test, tot_acc = 0.0, 0.0
        for _, (x_batch, y_batch) in enumerate(test_loader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            out = model(x_batch)
            pred = torch.max(out, dim=1)[1]
            acc = pred.eq(y_batch).sum().item()
            tot_acc += acc
            tot_test += x_batch.size()[0]
        t2 = time.time()

        print('Epoch %d: Accuracy %.5lf [%.2lf seconds]' % (epoch, tot_acc/tot_test, t2-t1))

In [60]:
clean_train(net, 20)
torch.save(net.state_dict(), 'weights_clean.pt')

Epoch 1: Accuracy 0.86100 [8.93 seconds]
Epoch 2: Accuracy 0.89270 [6.33 seconds]
Epoch 3: Accuracy 0.90480 [6.19 seconds]
Epoch 4: Accuracy 0.91400 [6.14 seconds]
Epoch 5: Accuracy 0.92110 [6.11 seconds]
Epoch 6: Accuracy 0.92520 [6.11 seconds]
Epoch 7: Accuracy 0.92970 [6.11 seconds]
Epoch 8: Accuracy 0.93560 [6.09 seconds]
Epoch 9: Accuracy 0.93610 [6.21 seconds]
Epoch 10: Accuracy 0.94140 [6.30 seconds]
Epoch 11: Accuracy 0.94390 [6.04 seconds]
Epoch 12: Accuracy 0.94560 [6.15 seconds]
Epoch 13: Accuracy 0.94680 [6.11 seconds]
Epoch 14: Accuracy 0.94980 [6.22 seconds]
Epoch 15: Accuracy 0.95190 [6.31 seconds]
Epoch 16: Accuracy 0.95280 [6.27 seconds]
Epoch 17: Accuracy 0.95320 [6.11 seconds]
Epoch 18: Accuracy 0.95410 [6.19 seconds]
Epoch 19: Accuracy 0.95530 [6.13 seconds]
Epoch 20: Accuracy 0.95810 [6.12 seconds]


### Interval Analysis

In [14]:
# inputs: images, model, perturbation sizes
# return: the robustness of the inputs
def ibp_analysis(train_loader, net, eps):
    with tqdm(train_loader, unit="batch") as tepoch:
        running_correct = 0
        num_samples = 0
        for i, (x, target) in enumerate(tepoch):
            x, target = x.to(device), target.to(device)
            input_bounds = HyperRectangle.from_eps(x, eps)
            ibp_bounds = net.ibp(input_bounds)
            lower, upper = ibp_bounds.lower, ibp_bounds.upper

            # expand the classes with the same size of bounds [batch_size, 10]
            classes = torch.arange(10, device=device).unsqueeze(0).expand(len(target), -1)
            # get mask with 1 of the correct class and 0 rest of them
            mask = (classes == target.unsqueeze(-1)).to(dtype=lower.dtype)

            upper_sel = (upper * (1-mask)).max(dim=-1)[0] # get the higest upper bounds for wrong classes
            lower_sel = lower[classes == target.unsqueeze(-1)] 

            running_correct += (lower_sel > upper_sel).cpu().float().sum().item()
            num_samples += len(x)
    
    print(f"Robustness: {running_correct / num_samples} with eps {eps}")

In [31]:
eps =np.linspace(0.01, 0.1, 10)
net.load_state_dict(torch.load('weights_clean.pt'))
test_clean(net, test_loader)
for ep in eps:
    ibp_analysis(test_loader, net, ep)


[TEST]
Accuracy: 0.958


100%|██████████| 157/157 [00:02<00:00, 73.40batch/s]


Robustness: 0.0 with eps 0.01


100%|██████████| 157/157 [00:02<00:00, 75.24batch/s]


Robustness: 0.0 with eps 0.020000000000000004


100%|██████████| 157/157 [00:02<00:00, 76.81batch/s]


Robustness: 0.0 with eps 0.030000000000000006


100%|██████████| 157/157 [00:02<00:00, 68.17batch/s]


Robustness: 0.0 with eps 0.04000000000000001


100%|██████████| 157/157 [00:02<00:00, 70.58batch/s]


Robustness: 0.0 with eps 0.05000000000000001


100%|██████████| 157/157 [00:02<00:00, 69.95batch/s]


Robustness: 0.0 with eps 0.06000000000000001


100%|██████████| 157/157 [00:01<00:00, 78.94batch/s]


Robustness: 0.0 with eps 0.07


100%|██████████| 157/157 [00:02<00:00, 74.23batch/s]


Robustness: 0.0 with eps 0.08


100%|██████████| 157/157 [00:01<00:00, 78.72batch/s]


Robustness: 0.0 with eps 0.09000000000000001


100%|██████████| 157/157 [00:01<00:00, 78.83batch/s]

Robustness: 0.0 with eps 0.1





### Robust Training

In [43]:
def adversarial_logit(y_hat, y):
    batch_size = y.size(0)
    classes = torch.arange(10, device=y.device).unsqueeze(0).expand(batch_size, -1)
    mask = (classes == y.unsqueeze(-1)).to(dtype=y_hat.lower.dtype)

    # Take upper bound for logit of all but the correct class where you take the lower bound
    adversarial_logit = (1 - mask) * y_hat.upper + mask * y_hat.lower

    return adversarial_logit

def robust_training(train_loader, net, eps, num_epoch):
    criterion = torch.nn.CrossEntropyLoss().to(device)
    k = 1.0
    for _ in range(num_epoch):
        running_loss = 0.0
        running_cross_entropy = 0.0
        with tqdm(train_loader, unit="batch") as tepoch:
            running_correct = 0
            num_samples = 0
            learning_rate = 5e-4

            opt = optim.Adam(params=net.parameters(), lr=learning_rate)

            # ce_loss = torch.nn.CrossEntropyLoss()
            for i, (x, target) in enumerate(tepoch):
                x, target = x.to(device), target.to(device)

                opt.zero_grad()

                y_hat = net(x)
                cross_entropy = criterion(y_hat, target)


                input_bounds = HyperRectangle.from_eps(x, eps)
                ibp_bounds = net.ibp(input_bounds)

                logit = adversarial_logit(ibp_bounds, target)

                loss = k * cross_entropy + (1 - k) * criterion(logit, target)

                # lower, upper = ibp_bounds.lower, ibp_bounds.upper

                # # expand the classes with the same size of bounds [batch_size, 10]
                # classes = torch.arange(10, device=device).unsqueeze(0).expand(len(target), -1)
                # # get mask with 1 of the correct class and 0 rest of them
                # mask = (classes == target.unsqueeze(-1)).to(dtype=lower.dtype)

                # upper_sel = (upper * (1-mask)).max(dim=-1)[0] # get the higest upper bounds for wrong classes
                # lower_sel = lower[classes == target.unsqueeze(-1)] 

                # loss = (upper_sel - lower_sel).mean()

                loss.backward()
                opt.step()

                # running_correct += (lower_sel > upper_sel).cpu().float().sum().item()
                # num_samples += len(x)
                running_loss += loss.item()
                running_cross_entropy += cross_entropy.item()

                tepoch.set_postfix({'loss_all':running_loss / (i+1), 'loss_clean': running_cross_entropy / (i+1)})
        
        k = max(k - 0.1, 0.5)
        # ibp_analysis(train_loader, net, eps)


    # print(f"Robustness: {running_correct / num_samples} with eps {eps}")

In [44]:
def test_clean(net, test_loader):
    # print('')
    # print('[TEST]')

    correct = 0
    num_samples = 0
    for i, (X, y) in enumerate(test_loader):
        X, y = X.to(device), y.to(device)

        y_hat = net(X)

        predicted = torch.argmax(y_hat, 1)
        correct += (predicted == y).sum().item()

        num_samples += len(y)

    print(f'Accuracy: {correct / num_samples :.3f}')

In [45]:
eps =np.linspace(0.01, 0.1, 10)
# net.load_state_dict(torch.load('weights_clean.pt'))
for ep in eps:
    net = Network().to(device)

    factory = BoundModelFactory()
    net = factory.build(net)

    robust_training(train_loader, net, ep, 20)

    # testing on robustness
    print('robustness:')
    ibp_analysis(test_loader, net, ep)

    # testing on accuracy
    print('accuracy:')
    test_clean(net, test_loader)

100%|██████████| 938/938 [00:21<00:00, 43.22batch/s, loss_all=0.539, loss_clean=0.539]
100%|██████████| 938/938 [00:21<00:00, 42.83batch/s, loss_all=0.555, loss_clean=0.316]
100%|██████████| 938/938 [00:22<00:00, 42.42batch/s, loss_all=0.338, loss_clean=0.225]
100%|██████████| 938/938 [00:22<00:00, 42.48batch/s, loss_all=0.295, loss_clean=0.186]
100%|██████████| 938/938 [00:22<00:00, 42.53batch/s, loss_all=0.269, loss_clean=0.161]
100%|██████████| 938/938 [00:22<00:00, 42.42batch/s, loss_all=0.253, loss_clean=0.146]
100%|██████████| 938/938 [00:22<00:00, 42.33batch/s, loss_all=0.227, loss_clean=0.134]
100%|██████████| 938/938 [00:22<00:00, 42.40batch/s, loss_all=0.208, loss_clean=0.123]
100%|██████████| 938/938 [00:22<00:00, 42.38batch/s, loss_all=0.194, loss_clean=0.116]
100%|██████████| 938/938 [00:22<00:00, 42.17batch/s, loss_all=0.183, loss_clean=0.109]
100%|██████████| 938/938 [00:22<00:00, 42.21batch/s, loss_all=0.174, loss_clean=0.104] 
100%|██████████| 938/938 [00:22<00:00, 42.

robustness:


100%|██████████| 157/157 [00:02<00:00, 78.31batch/s]


Robustness: 0.9223 with eps 0.01
accuracy:
Accuracy: 0.972


100%|██████████| 938/938 [00:20<00:00, 45.27batch/s, loss_all=0.543, loss_clean=0.543]
100%|██████████| 938/938 [00:20<00:00, 45.08batch/s, loss_all=0.831, loss_clean=0.432]
100%|██████████| 938/938 [00:20<00:00, 45.05batch/s, loss_all=0.442, loss_clean=0.276]
100%|██████████| 938/938 [00:20<00:00, 45.02batch/s, loss_all=0.375, loss_clean=0.218]
100%|██████████| 938/938 [00:20<00:00, 44.97batch/s, loss_all=0.347, loss_clean=0.189]
100%|██████████| 938/938 [00:20<00:00, 45.02batch/s, loss_all=0.333, loss_clean=0.171]
100%|██████████| 938/938 [00:20<00:00, 44.99batch/s, loss_all=0.298, loss_clean=0.155]
100%|██████████| 938/938 [00:21<00:00, 44.26batch/s, loss_all=0.275, loss_clean=0.144]
100%|██████████| 938/938 [00:21<00:00, 43.95batch/s, loss_all=0.259, loss_clean=0.136]
100%|██████████| 938/938 [00:21<00:00, 44.25batch/s, loss_all=0.245, loss_clean=0.129]
100%|██████████| 938/938 [00:20<00:00, 44.79batch/s, loss_all=0.234, loss_clean=0.123]
100%|██████████| 938/938 [00:20<00:00, 44.7

robustness:


100%|██████████| 157/157 [00:01<00:00, 79.97batch/s]


Robustness: 0.9097 with eps 0.020000000000000004
accuracy:
Accuracy: 0.970


100%|██████████| 938/938 [00:20<00:00, 45.08batch/s, loss_all=0.556, loss_clean=0.556]
100%|██████████| 938/938 [00:20<00:00, 44.95batch/s, loss_all=1.04, loss_clean=0.498]
100%|██████████| 938/938 [00:20<00:00, 44.94batch/s, loss_all=0.504, loss_clean=0.299]
100%|██████████| 938/938 [00:20<00:00, 45.04batch/s, loss_all=0.442, loss_clean=0.244]
100%|██████████| 938/938 [00:20<00:00, 44.88batch/s, loss_all=0.423, loss_clean=0.218]
100%|██████████| 938/938 [00:20<00:00, 44.96batch/s, loss_all=0.417, loss_clean=0.202]
100%|██████████| 938/938 [00:20<00:00, 44.73batch/s, loss_all=0.38, loss_clean=0.188] 
100%|██████████| 938/938 [00:20<00:00, 44.98batch/s, loss_all=0.353, loss_clean=0.175]
100%|██████████| 938/938 [00:20<00:00, 44.80batch/s, loss_all=0.333, loss_clean=0.167]
100%|██████████| 938/938 [00:20<00:00, 44.87batch/s, loss_all=0.317, loss_clean=0.159]
100%|██████████| 938/938 [00:20<00:00, 44.99batch/s, loss_all=0.305, loss_clean=0.153]
100%|██████████| 938/938 [00:20<00:00, 44.72

robustness:


100%|██████████| 157/157 [00:02<00:00, 78.22batch/s]


Robustness: 0.8732 with eps 0.030000000000000006
accuracy:
Accuracy: 0.964


100%|██████████| 938/938 [00:20<00:00, 45.05batch/s, loss_all=0.545, loss_clean=0.545]
100%|██████████| 938/938 [00:21<00:00, 43.36batch/s, loss_all=1.26, loss_clean=0.597]
100%|██████████| 938/938 [00:22<00:00, 42.48batch/s, loss_all=0.604, loss_clean=0.347]
100%|██████████| 938/938 [00:22<00:00, 42.60batch/s, loss_all=0.54, loss_clean=0.284] 
100%|██████████| 938/938 [21:17<00:00,  1.36s/batch, loss_all=0.528, loss_clean=0.259]  
100%|██████████| 938/938 [16:36<00:00,  1.06s/batch, loss_all=0.527, loss_clean=0.245]  
100%|██████████| 938/938 [00:41<00:00, 22.72batch/s, loss_all=0.485, loss_clean=0.228]
100%|██████████| 938/938 [00:49<00:00, 19.02batch/s, loss_all=0.455, loss_clean=0.215]
100%|██████████| 938/938 [00:22<00:00, 42.02batch/s, loss_all=0.431, loss_clean=0.204]
100%|██████████| 938/938 [00:22<00:00, 42.01batch/s, loss_all=0.411, loss_clean=0.195]
100%|██████████| 938/938 [00:22<00:00, 42.43batch/s, loss_all=0.396, loss_clean=0.189]
100%|██████████| 938/938 [00:22<00:00, 4

robustness:


100%|██████████| 157/157 [00:02<00:00, 77.89batch/s]


Robustness: 0.8617 with eps 0.04000000000000001
accuracy:
Accuracy: 0.956


100%|██████████| 938/938 [00:21<00:00, 44.48batch/s, loss_all=0.537, loss_clean=0.537]
100%|██████████| 938/938 [00:21<00:00, 44.46batch/s, loss_all=1.4, loss_clean=0.65]  
100%|██████████| 938/938 [00:21<00:00, 44.30batch/s, loss_all=0.693, loss_clean=0.386]
100%|██████████| 938/938 [00:21<00:00, 44.35batch/s, loss_all=0.629, loss_clean=0.324]
100%|██████████| 938/938 [00:21<00:00, 44.40batch/s, loss_all=0.62, loss_clean=0.295] 
100%|██████████| 938/938 [00:21<00:00, 44.52batch/s, loss_all=0.631, loss_clean=0.28] 
100%|██████████| 938/938 [00:21<00:00, 44.61batch/s, loss_all=0.591, loss_clean=0.267]
100%|██████████| 938/938 [00:21<00:00, 44.38batch/s, loss_all=0.564, loss_clean=0.256]
100%|██████████| 938/938 [00:22<00:00, 41.55batch/s, loss_all=0.545, loss_clean=0.248]
100%|██████████| 938/938 [00:21<00:00, 44.31batch/s, loss_all=0.529, loss_clean=0.242]
100%|██████████| 938/938 [00:21<00:00, 44.47batch/s, loss_all=0.515, loss_clean=0.237]
100%|██████████| 938/938 [00:21<00:00, 43.35

robustness:


100%|██████████| 157/157 [00:02<00:00, 64.72batch/s]


Robustness: 0.8135 with eps 0.05000000000000001
accuracy:
Accuracy: 0.942


100%|██████████| 938/938 [00:22<00:00, 42.06batch/s, loss_all=0.558, loss_clean=0.558]
100%|██████████| 938/938 [00:22<00:00, 42.11batch/s, loss_all=1.62, loss_clean=0.729]
100%|██████████| 938/938 [00:22<00:00, 42.32batch/s, loss_all=0.808, loss_clean=0.458]
100%|██████████| 938/938 [00:21<00:00, 43.84batch/s, loss_all=0.74, loss_clean=0.384] 
100%|██████████| 938/938 [00:21<00:00, 44.15batch/s, loss_all=0.733, loss_clean=0.352]
100%|██████████| 938/938 [00:21<00:00, 43.91batch/s, loss_all=0.739, loss_clean=0.334]
100%|██████████| 938/938 [00:21<00:00, 43.88batch/s, loss_all=0.682, loss_clean=0.314]
100%|██████████| 938/938 [00:21<00:00, 44.02batch/s, loss_all=0.646, loss_clean=0.301]
100%|██████████| 938/938 [00:21<00:00, 43.26batch/s, loss_all=0.619, loss_clean=0.29] 
100%|██████████| 938/938 [00:21<00:00, 44.17batch/s, loss_all=0.598, loss_clean=0.283]
100%|██████████| 938/938 [00:21<00:00, 44.28batch/s, loss_all=0.581, loss_clean=0.276]
100%|██████████| 938/938 [00:21<00:00, 44.25

robustness:


100%|██████████| 157/157 [00:02<00:00, 77.74batch/s]


Robustness: 0.7981 with eps 0.06000000000000001
accuracy:
Accuracy: 0.935


100%|██████████| 938/938 [00:21<00:00, 44.42batch/s, loss_all=0.567, loss_clean=0.567]
100%|██████████| 938/938 [00:21<00:00, 44.07batch/s, loss_all=1.78, loss_clean=0.847]
100%|██████████| 938/938 [00:21<00:00, 44.35batch/s, loss_all=0.901, loss_clean=0.515]
100%|██████████| 938/938 [00:21<00:00, 44.02batch/s, loss_all=0.819, loss_clean=0.422]
100%|██████████| 938/938 [00:21<00:00, 44.12batch/s, loss_all=0.817, loss_clean=0.393]
100%|██████████| 938/938 [00:21<00:00, 44.30batch/s, loss_all=0.835, loss_clean=0.382]
100%|██████████| 938/938 [00:21<00:00, 43.16batch/s, loss_all=0.78, loss_clean=0.365] 
100%|██████████| 938/938 [00:21<00:00, 44.25batch/s, loss_all=0.743, loss_clean=0.353]
100%|██████████| 938/938 [00:22<00:00, 41.30batch/s, loss_all=0.717, loss_clean=0.344]
100%|██████████| 938/938 [00:21<00:00, 43.75batch/s, loss_all=0.697, loss_clean=0.336]
100%|██████████| 938/938 [00:21<00:00, 43.60batch/s, loss_all=0.681, loss_clean=0.331]
100%|██████████| 938/938 [00:21<00:00, 43.72

robustness:


100%|██████████| 157/157 [00:02<00:00, 77.42batch/s]


Robustness: 0.7297 with eps 0.07
accuracy:
Accuracy: 0.914


100%|██████████| 938/938 [00:21<00:00, 44.41batch/s, loss_all=0.549, loss_clean=0.549]
100%|██████████| 938/938 [00:21<00:00, 44.28batch/s, loss_all=1.98, loss_clean=0.98] 
100%|██████████| 938/938 [00:21<00:00, 44.11batch/s, loss_all=1.04, loss_clean=0.633]
100%|██████████| 938/938 [00:21<00:00, 43.57batch/s, loss_all=0.941, loss_clean=0.522]
100%|██████████| 938/938 [00:21<00:00, 44.18batch/s, loss_all=0.933, loss_clean=0.479]
100%|██████████| 938/938 [00:21<00:00, 42.69batch/s, loss_all=0.953, loss_clean=0.461]
100%|██████████| 938/938 [00:21<00:00, 43.40batch/s, loss_all=0.899, loss_clean=0.439]
100%|██████████| 938/938 [00:22<00:00, 41.86batch/s, loss_all=0.861, loss_clean=0.42] 
100%|██████████| 938/938 [00:21<00:00, 44.17batch/s, loss_all=0.831, loss_clean=0.404]
100%|██████████| 938/938 [00:22<00:00, 41.79batch/s, loss_all=0.805, loss_clean=0.392]
100%|██████████| 938/938 [00:22<00:00, 41.70batch/s, loss_all=0.782, loss_clean=0.382]
100%|██████████| 938/938 [00:21<00:00, 42.96b

robustness:


100%|██████████| 157/157 [00:02<00:00, 78.37batch/s]


Robustness: 0.7323 with eps 0.08
accuracy:
Accuracy: 0.912


100%|██████████| 938/938 [00:21<00:00, 44.26batch/s, loss_all=0.55, loss_clean=0.55]  
100%|██████████| 938/938 [00:21<00:00, 44.22batch/s, loss_all=2.11, loss_clean=1.01] 
100%|██████████| 938/938 [00:21<00:00, 43.96batch/s, loss_all=1.1, loss_clean=0.69]  
100%|██████████| 938/938 [00:21<00:00, 44.18batch/s, loss_all=1.05, loss_clean=0.607]
100%|██████████| 938/938 [00:21<00:00, 44.15batch/s, loss_all=1.07, loss_clean=0.568]
100%|██████████| 938/938 [00:21<00:00, 43.05batch/s, loss_all=1.12, loss_clean=0.564]
100%|██████████| 938/938 [00:21<00:00, 43.20batch/s, loss_all=1.06, loss_clean=0.546]
100%|██████████| 938/938 [00:21<00:00, 44.08batch/s, loss_all=1.01, loss_clean=0.527]
100%|██████████| 938/938 [00:21<00:00, 44.02batch/s, loss_all=0.967, loss_clean=0.505]
100%|██████████| 938/938 [00:21<00:00, 44.09batch/s, loss_all=0.926, loss_clean=0.478]
100%|██████████| 938/938 [00:21<00:00, 43.97batch/s, loss_all=0.891, loss_clean=0.459]
100%|██████████| 938/938 [00:21<00:00, 44.14batch/

robustness:


100%|██████████| 157/157 [00:02<00:00, 77.69batch/s]


Robustness: 0.6855 with eps 0.09000000000000001
accuracy:
Accuracy: 0.886


100%|██████████| 938/938 [00:21<00:00, 42.71batch/s, loss_all=0.557, loss_clean=0.557]
100%|██████████| 938/938 [00:21<00:00, 42.80batch/s, loss_all=2.27, loss_clean=1.04] 
100%|██████████| 938/938 [00:21<00:00, 43.99batch/s, loss_all=1.15, loss_clean=0.689]
100%|██████████| 938/938 [00:21<00:00, 44.21batch/s, loss_all=1.08, loss_clean=0.581]
100%|██████████| 938/938 [00:21<00:00, 44.11batch/s, loss_all=1.1, loss_clean=0.545] 
100%|██████████| 938/938 [00:21<00:00, 44.15batch/s, loss_all=1.15, loss_clean=0.547]
100%|██████████| 938/938 [00:21<00:00, 43.85batch/s, loss_all=1.1, loss_clean=0.532] 
100%|██████████| 938/938 [00:21<00:00, 44.07batch/s, loss_all=1.06, loss_clean=0.521]
100%|██████████| 938/938 [00:21<00:00, 44.19batch/s, loss_all=1.03, loss_clean=0.509]
100%|██████████| 938/938 [00:21<00:00, 44.21batch/s, loss_all=1.01, loss_clean=0.499]
100%|██████████| 938/938 [00:21<00:00, 44.11batch/s, loss_all=0.983, loss_clean=0.491]
100%|██████████| 938/938 [00:21<00:00, 44.11batch/s,

robustness:


100%|██████████| 157/157 [00:02<00:00, 77.69batch/s]


Robustness: 0.6775 with eps 0.1
accuracy:
Accuracy: 0.874
