In [1]:
import numpy as np
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm.notebook import tqdm
from utils.wrn import WideResNet
import utils.attacks as attacks
from utils.detector import Detector, gram_margin_loss
from tqdm.notebook import tqdm

## Args

In [2]:
parser = argparse.ArgumentParser(description='Trains a CIFAR Classifier',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', '-d', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='Choose between CIFAR-10, CIFAR-100.')
parser.add_argument('--model', '-m', type=str, default='wrn',
                    choices=['allconv', 'wrn'], help='Choose architecture.')
# Optimization options
parser.add_argument('--epochs', '-e', type=int, default=100, help='Number of epochs to train.')
parser.add_argument('--learning_rate', '-lr', type=float, default=0.1, help='The initial learning rate.')
parser.add_argument('--batch_size', '-b', type=int, default=128, help='Batch size.')
parser.add_argument('--test_bs', type=int, default=256)
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=0.0005, help='Weight decay (L2 penalty).')
# WRN Architecture
parser.add_argument('--layers', default=40, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=2, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.0, type=float, help='dropout probability')
# Checkpoints
parser.add_argument('--save', '-s', type=str, default='./snapshots/baseline', help='Folder to save checkpoints.')
parser.add_argument('--load', '-l', type=str, default='', help='Checkpoint path to resume / test.')
parser.add_argument('--test', '-t', action='store_true', help='Test only flag.')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--gpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
args = parser.parse_args(["--save", "checkpoints/", "--gpu", "2", "--test_bs", "128"])

## Initialization

In [3]:
state = {k: v for k, v in args._get_kwargs()}
print(state)

torch.manual_seed(1)
np.random.seed(1)

# # mean and standard deviation of channels of CIFAR-10 images
# mean = [x / 255 for x in [125.3, 123.0, 113.9]]
# std = [x / 255 for x in [63.0, 62.1, 66.7]]

train_transform = trn.Compose([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4),
                               trn.ToTensor()])
test_transform = trn.Compose([trn.ToTensor()])

train_data = dset.CIFAR10('~/datasets/cifarpy', train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10('~/datasets/cifarpy', train=False, transform=test_transform, download=True)
num_classes = 10

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)



normalize = trn.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
detector_data_transform = trn.Compose([trn.ToTensor(), normalize])

data_train = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=True, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))

data_test = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=False, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))


{'droprate': 0.0, 'decay': 0.0005, 'save': 'checkpoints/', 'batch_size': 128, 'dataset': 'cifar10', 'widen_factor': 2, 'gpu': 2, 'prefetch': 2, 'test_bs': 128, 'model': 'wrn', 'test': False, 'ngpu': 1, 'momentum': 0.9, 'epochs': 100, 'learning_rate': 0.1, 'layers': 40, 'load': ''}
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    net.load_state_dict(torch.load("benchmark_ckpts/cifar10_reg_training_99.pt"))
    start_epoch = 80

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, args.dataset + '_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break

    if start_epoch == 0:
        assert False, "could not resume"

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    torch.cuda.set_device(args.gpu)
    net.cuda()
    torch.cuda.manual_seed(1)

# cudnn.benchmark = True  # fire on all cylinders

optimizer = torch.optim.SGD(
    net.parameters(), state['learning_rate'], momentum=state['momentum'],
    weight_decay=state['decay'], nesterov=True)

def cosine_annealing(step, total_steps, lr_max, lr_min):
    return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))


scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_annealing(
        step,
        args.epochs * len(train_loader),
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / args.learning_rate))


adversary = attacks.PGD(epsilon=8./255, num_steps=10, step_size=2./255).cuda()

## Training

In [5]:
def train():
    net.train()  # enter train mode
    loss_avg, loss_gram_avg = 0.0, 0.0
    i = 0
    nan_counter = 0
    
#     mrgn_pwr = 1.6 if "margin_log10" not in state else state["margin_log10"]
#     if "increase_margin" in state and state["increase_margin"]:
#         mrgn_pwr += 0.1
#         print("Increased Margin Power to:", mrgn_pwr)

    mrgn_pwr = 2
    
    for bx, by in tqdm(train_loader):
        bx, by = bx.cuda(), by.cuda()

        adv_bx = adversary(net, bx, by)
                        
        # forward
        logits_reg, feats_reg = net.gram_forward(bx * 2 - 1)
        logits_adv, feats_adv = net.gram_forward(adv_bx * 2 - 1)

        # backward
        optimizer.zero_grad()

        loss_reg = F.cross_entropy(logits_reg, by)
#         loss_adv = F.cross_entropy(logits_adv, by)
        
        loss_gram = gram_margin_loss(feats_reg, feats_adv, margin=50).cuda()
        loss = 7/8 * loss_reg + 1/100 * loss_gram
                
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        i += 1
        
        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2
        loss_gram_avg = loss_gram_avg * 0.8 + float(loss_gram) * 0.2
    
    state["increase_margin"] = (loss_gram_avg < 0.1)
    state["margin_log10"] = mrgn_pwr
    state['train_loss'] = loss_avg
    state["gram_train_loss"] = loss_gram_avg
    
    print("Train Loss:", state["train_loss"])
    print("Margin Log10:", state["margin_log10"])
    print("Train Gram: ", state["gram_train_loss"])

# test function
def test(detector = None):
    net.eval()
    loss_avg, loss_reg_avg, loss_adv_avg, loss_gram_avg, auroc_avg = 0.0, 0.0, 0.0, 0.0, 0.0
    loss, loss_reg, loss_adv, loss_gram = 0.0, 0.0, 0.0, 0.0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.cuda(), target.cuda()
            
            adv_data = adversary(net, data, target)

            # forward
            output = net(data * 2 - 1)
#             adv_output = net(adv_data * 2 - 1)
            
#             loss_reg = F.cross_entropy(output, target)
#             loss_adv = F.cross_entropy(adv_output, target)
            
#             loss = loss_reg
            
            auroc = detector.compute_ood_deviations_batch(adv_data * 2 -1)
            auroc = auroc["AUROC"]
            
#             loss_gram = gram_margin_loss(feats_reg, feats_adv, margin=10e4)
            
            # accuracy
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data).sum().item()

            # test loss average
#             loss_reg_avg += float(loss_reg)
#             loss_adv_avg += float(loss_adv)
#             loss_avg += float(loss)
#             loss_gram_avg += float(loss_gram)
            auroc_avg += auroc
            
            
    state['test_loss_reg'] = loss_reg_avg / len(test_loader)
    state['test_loss_adv'] = loss_adv_avg / len(test_loader)
    state['test_loss'] = loss_avg / len(test_loader)
    state['test_accuracy'] = correct / len(test_loader.dataset)
    state['gram_loss'] = loss_gram_avg / len(test_loader)
    state["gram_auroc"] = auroc_avg / len(test_loader)

In [None]:
# Make save directory
if not os.path.exists(args.save):
    os.makedirs(args.save)
if not os.path.isdir(args.save):
    raise Exception('%s is not a dir' % args.save)

with open(os.path.join(args.save, args.dataset + '_' + args.model +
                                  '_baseline_training_results.csv'), 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%),gram_auroc\n')

print('Beginning Training!\n')

# Main loop
for epoch in range(start_epoch, args.epochs):
    state['epoch'] = epoch

    begin_epoch = time.time()
    
    print("1. Training")
    train()
    print("2. Initializing Detector")
    
    if epoch % 3 == 0:
        net.eval()
        detector = Detector(net, data_train, data_test, args.test_bs, pbar=None)
        print("3. Testing")
        try:
            test(detector)
        except Exception as e:
            print("Failed test")
            print(e)

        # Save model
        torch.save(net.state_dict(),
                   os.path.join(args.save, args.dataset + '_' + args.model +
                                '_baseline_epoch_' + str(epoch) + '.pt'))
        # Let us not waste space and delete the previous model
        prev_path = os.path.join(args.save, args.dataset + '_' + args.model +
                                 '_baseline_epoch_' + str(epoch - 1) + '.pt')
        if os.path.exists(prev_path): os.remove(prev_path)

        # Show results

        with open(os.path.join(args.save, args.dataset + '_' + args.model +
                                          '_baseline_training_results.csv'), 'a') as f:
            f.write('%03d,%05d,%0.6f,%0.5f,%0.2f,%0.2f\n' % (
                (epoch + 1),
                time.time() - begin_epoch,
                state['train_loss'],
                state['test_loss'],
                100 - 100. * state['test_accuracy'],
                state["gram_auroc"],
            ))

        # # print state with rounded decimals
#         print({k: round(v, 4) if isinstance(v, float) else v for k, v in state.items()})

        print('Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f} | Gram Auroc {5:.2f}'.format(
            (epoch + 1),
            int(time.time() - begin_epoch),
            state['train_loss'],
            state['test_loss'],
            100 - 100. * state['test_accuracy'],
            state["gram_auroc"])
        )


Beginning Training!

1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.947191717819259
Margin Log10: 2
Train Gram:  2.0428198118910355
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch   1 | Time   397 | Train Loss 1.9472 | Test Loss 0.000 | Test Error 86.02 | Gram Auroc 0.23
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.9973685502417031
Margin Log10: 2
Train Gram:  1.788934030370331
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.0404372489124696
Margin Log10: 2
Train Gram:  7.58515318851087
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.100500642223905
Margin Log10: 2
Train Gram:  8.171594458508729
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch   4 | Time   386 | Train Loss 2.1005 | Test Loss 0.000 | Test Error 89.45 | Gram Auroc 0.18
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.2222308210969115
Margin Log10: 2
Train Gram:  32.63298554672202
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.949803385229639
Margin Log10: 2
Train Gram:  15.717395102516914
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 3.4450450457239796
Margin Log10: 2
Train Gram:  146.51592833628516
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch   7 | Time   390 | Train Loss 3.4450 | Test Loss 0.000 | Test Error 85.34 | Gram Auroc 0.15
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.7135113895606136
Margin Log10: 2
Train Gram:  76.03718452571725
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.0516225574631903
Margin Log10: 2
Train Gram:  17.444074760447425
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.0309511823819775
Margin Log10: 2
Train Gram:  1.8562851441134716
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  10 | Time   387 | Train Loss 2.0310 | Test Loss 0.000 | Test Error 87.47 | Gram Auroc 0.08
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.8456369592924087
Margin Log10: 2
Train Gram:  1.6877544245965037
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  19 | Time   394 | Train Loss 1.8456 | Test Loss 0.000 | Test Error 79.02 | Gram Auroc 0.12
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 2.0387623198386082
Margin Log10: 2
Train Gram:  19.54728057455789
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.8428542701336617
Margin Log10: 2
Train Gram:  0.8481859021340199
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.4776912820143437
Margin Log10: 2
Train Gram:  0.12128716857361604
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.4015467320258694
Margin Log10: 2
Train Gram:  0.40851813149762645
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  31 | Time   401 | Train Loss 1.4015 | Test Loss 0.000 | Test Error 66.65 | Gram Auroc 0.49
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.2455859948557624
Margin Log10: 2
Train Gram:  0.20063388129074405
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.1962628866995428
Margin Log10: 2
Train Gram:  0.002035389113255092
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.3773871281116534
Margin Log10: 2
Train Gram:  0.22694594787391723
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.1867556863358244
Margin Log10: 2
Train Gram:  0.019150805002701943
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.1991106886371252
Margin Log10: 2
Train Gram:  0.012660574284822618
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  43 | Time   397 | Train Loss 1.1991 | Test Loss 0.000 | Test Error 53.34 | Gram Auroc 0.73
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.3644825775139615
Margin Log10: 2
Train Gram:  0.26819912960498804
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.2884313894548283
Margin Log10: 2
Train Gram:  0.002935735699939712
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.2165960177860538
Margin Log10: 2
Train Gram:  0.177653174338657
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  55 | Time   398 | Train Loss 1.2166 | Test Loss 0.000 | Test Error 65.77 | Gram Auroc 0.56
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.3989866716841912
Margin Log10: 2
Train Gram:  0.0
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  64 | Time   401 | Train Loss 1.3990 | Test Loss 0.000 | Test Error 62.34 | Gram Auroc 0.66
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.2733864616114765
Margin Log10: 2
Train Gram:  8.157407302199084e-24
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.292757249017559
Margin Log10: 2
Train Gram:  0.051552880328620404
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.1860623070750762
Margin Log10: 2
Train Gram:  8.545808303202166e-06
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.200173621262127
Margin Log10: 2
Train Gram:  0.029457273579347406
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  76 | Time   406 | Train Loss 1.2002 | Test Loss 0.000 | Test Error 52.88 | Gram Auroc 0.55
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.1500596701509087
Margin Log10: 2
Train Gram:  6.139614616648818e-06
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 1.2360476712139585
Margin Log10: 2
Train Gram:  20.150649922344236
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.9693582993501133
Margin Log10: 2
Train Gram:  5.690826892137478e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.9274245833607028
Margin Log10: 2
Train Gram:  1.583186007981256e-05
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


Epoch  88 | Time   402 | Train Loss 0.9274 | Test Loss 0.000 | Test Error 42.43 | Gram Auroc 0.43
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [15]:
test(detector)

HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))




In [16]:
state

{'batch_size': 128,
 'dataset': 'cifar10',
 'decay': 0.0005,
 'droprate': 0.0,
 'epoch': 0,
 'epochs': 100,
 'gpu': 2,
 'gram_auroc': 0.9671555478639238,
 'gram_loss': 0.0,
 'layers': 40,
 'learning_rate': 0.1,
 'load': '',
 'model': 'wrn',
 'momentum': 0.9,
 'ngpu': 1,
 'prefetch': 2,
 'save': 'checkpoints/',
 'test': False,
 'test_accuracy': 0.9145,
 'test_bs': 128,
 'test_loss': 0.0,
 'test_loss_adv': 0.0,
 'test_loss_reg': 0.0,
 'widen_factor': 2}