In [8]:
import os
import torch
from cnnf.model_cifar import WideResNet
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from utils import *
from tqdm import tqdm
from train import train_adv, test, test_pgd

In [9]:
#args
class Arg:
    """Plain attribute bag used in place of an argparse.Namespace for this notebook."""


args = Arg()

# vars(args) is the instance __dict__, so each update() below sets the given
# attributes on `args` in keyword order (same result as one assignment per line).

# optimisation / data / runtime
vars(args).update(
    batch_size=256,
    test_batch_size=128,
    epochs=500,
    lr=0.05,
    power=0.9,        # exponent for the 'poly' LR schedule
    momentum=0.9,
    wd=5e-4,          # SGD weight decay
    grad_clip=True,
    dataset='cifar10',
    schedule='poly',  # 'poly' or 'cos'
    no_cuda=False,
    seed=0,
    log_interval=400,  # NOTE(review): interpreted by train_adv; confirm units (batches?)
)

# adversarial training
vars(args).update(
    # NOTE(review): inputs are normalised to [-1, 1], so if eps is applied in that
    # space 0.063 is ~8/255 of the raw [0, 1] range - confirm against train_adv.
    eps=0.063,
    eps_iter=0.02,
    nb_iter=7,
    clean='supclean',
)

# loss weighting
vars(args).update(
    mse_parameter=0.1,
    clean_parameter=1.0,
    res_parameter=0.1,
)

# model architecture and checkpointing
vars(args).update(
    layers=40,
    widen_factor=2,
    droprate=0.0,
    ind=5,
    max_cycles=2,
    save_model='CNNF_superes_cifar',
    model_dir='models',
)

In [10]:
#params cuda
# Use the GPU only when one is available AND the user has not disabled it via
# args.no_cuda (previously the flag was defined but never consulted).
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options for GPU runs. NOTE(review): not passed to the
# DataLoaders in this notebook, which set num_workers/pin_memory themselves.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Project helper from utils; presumably seeds torch (and possibly numpy/random)
# for reproducibility - see utils.seed_torch.
seed_torch(args.seed)

In [11]:
#transforms
# CIFAR-10 preprocessing. ToTensor scales pixel values to [0, 1]; Normalize with
# mean = std = 0.5 on each of the 3 channels then maps them to [-1, 1].
def _to_normalized_tensor():
    """Return fresh [ToTensor, Normalize] steps shared by the train and test pipelines."""
    return [transforms.ToTensor(), transforms.Normalize([0.5] * 3, [0.5] * 3)]


# Training: random horizontal flip, then a 32x32 random crop of the image padded
# by 4 pixels, followed by the shared tensor/normalisation steps.
train_transform_cifar = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4)]
    + _to_normalized_tensor())

# Evaluation: deterministic tensor/normalisation steps only.
test_transform_cifar = transforms.Compose(_to_normalized_tensor())

In [12]:
#model and dataset
# CIFAR-10 train/test splits stored under ./data (downloaded if missing).
train_data = datasets.CIFAR10(
    'data', train=True, transform=train_transform_cifar, download=True)
test_data = datasets.CIFAR10(
    'data', train=False, transform=test_transform_cifar, download=True)

# Pinned host memory only speeds up host->GPU copies, so enable it only on CUDA.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size,
    shuffle=True, num_workers=4, pin_memory=use_cuda)
# Evaluation metrics do not need a shuffled order; a fixed order keeps test
# passes reproducible across runs.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_batch_size,
    shuffle=False, num_workers=4, pin_memory=use_cuda)

num_classes = 10
# Pass num_classes instead of repeating the literal 10 so the two cannot drift
# apart (the second positional argument is presumably the class count).
model = WideResNet(args.layers, num_classes, args.widen_factor, args.droprate,
                   args.ind, args.max_cycles, args.res_parameter).to(device)

Files already downloaded and verified
Files already downloaded and verified


In [13]:
#optimizer and scheduler
# SGD with momentum and L2 weight decay.
optimizer = torch.optim.SGD(
    model.parameters(),
    args.lr,
    momentum=args.momentum,
    weight_decay=args.wd)

# Both schedules are defined over the total number of optimisation steps in the
# run, so the scheduler is presumably stepped once per batch inside train_adv
# (TODO confirm). Computed once so later edits to args.epochs cannot change it.
total_steps = args.epochs * len(train_loader)

# LambdaLR multiplies the base lr (args.lr) by lr_lambda(step).
lr_schedules = {
    # utils.get_lr: factor decaying from 1.0 towards 1e-5 (presumably cosine).
    'cos': lambda step: get_lr(step, total_steps, 1.0, 1e-5),
    # utils.lr_poly: polynomial decay with exponent args.power.
    'poly': lambda step: lr_poly(1.0, step, total_steps, args.power),
}
if args.schedule not in lr_schedules:
    # Fail loudly instead of silently falling back to 'poly' on a typo.
    raise ValueError("unknown args.schedule {!r}; expected one of {}".format(
        args.schedule, sorted(lr_schedules)))
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lr_schedules[args.schedule])

In [None]:
# Begin training
# Adversarially train for args.epochs epochs. Keep the checkpoint with the best
# clean test accuracy, snapshot + PGD-evaluate every CHECKPOINT_EVERY epochs,
# and save the final weights. Checkpointing is skipped when args.save_model is None.
CHECKPOINT_EVERY = 50  # epochs between periodic snapshots / PGD evaluations


def _save_checkpoint(suffix):
    """Save model weights to <model_dir>/<save_model><suffix>.pt and return the path.

    Creates args.model_dir first, since torch.save does not create missing directories.
    """
    os.makedirs(args.model_dir, exist_ok=True)
    path = os.path.join(args.model_dir, "{}{}.pt".format(args.save_model, suffix))
    torch.save(model.state_dict(), path)
    return path


best_acc = 0

for epoch in tqdm(range(args.epochs)):
    train_loss, train_acc = train_adv(
        args, model, device, train_loader, optimizer, scheduler, epoch,
        cycles=args.max_cycles, mse_parameter=args.mse_parameter,
        clean_parameter=args.clean_parameter, clean=args.clean)

    test_loss, test_acc = test(args, model, device, test_loader,
                               cycles=args.max_cycles, epoch=epoch)

    # Keep the weights with the best clean test accuracy seen so far.
    if test_acc > best_acc and args.save_model is not None:
        best_acc = test_acc
        _save_checkpoint("-best")

    # Periodic snapshot. The file name uses the 0-based epoch index
    # (e.g. "-epoch49" after the 50th epoch), matching earlier checkpoints.
    if (epoch + 1) % CHECKPOINT_EVERY == 0 and args.save_model is not None:
        _save_checkpoint("-epoch{}".format(epoch))
        pgd_acc = test_pgd(args, model, device, test_loader, epsilon=args.eps)
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write('pgd_acc test: {}'.format(pgd_acc))

# Save final model
if args.save_model is not None:
    _save_checkpoint("")

  0%|          | 0/500 [00:00<?, ?it/s]



  0%|          | 1/500 [03:30<29:14:38, 210.98s/it]


Test set: Average loss: 2.3038, Accuracy: 1000/10000 (10%)



  0%|          | 2/500 [07:01<29:08:50, 210.70s/it]


Test set: Average loss: 2.3028, Accuracy: 1000/10000 (10%)



  1%|          | 3/500 [10:32<29:05:14, 210.69s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  1%|          | 4/500 [14:01<28:58:50, 210.34s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  1%|          | 5/500 [17:31<28:54:07, 210.20s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  1%|          | 6/500 [21:01<28:49:44, 210.09s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  1%|▏         | 7/500 [24:31<28:44:49, 209.92s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  2%|▏         | 8/500 [28:00<28:40:15, 209.79s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  2%|▏         | 9/500 [31:30<28:36:37, 209.77s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  2%|▏         | 10/500 [35:00<28:32:48, 209.73s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  2%|▏         | 11/500 [38:30<28:29:22, 209.74s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  2%|▏         | 12/500 [41:59<28:26:06, 209.77s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  3%|▎         | 13/500 [45:29<28:21:49, 209.67s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  3%|▎         | 14/500 [48:58<28:18:02, 209.63s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  3%|▎         | 15/500 [52:28<28:14:33, 209.64s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  3%|▎         | 16/500 [55:58<28:10:52, 209.61s/it]


Test set: Average loss: 2.3027, Accuracy: 1000/10000 (10%)



  3%|▎         | 17/500 [59:27<28:07:14, 209.60s/it]


Test set: Average loss: 2.3028, Accuracy: 1000/10000 (10%)



  4%|▎         | 18/500 [1:02:57<28:04:08, 209.64s/it]


Test set: Average loss: 2.3029, Accuracy: 1000/10000 (10%)



  4%|▍         | 19/500 [1:06:26<28:00:25, 209.62s/it]


Test set: Average loss: 2.3030, Accuracy: 1000/10000 (10%)



  4%|▍         | 20/500 [1:09:56<27:57:30, 209.69s/it]


Test set: Average loss: 2.3032, Accuracy: 1000/10000 (10%)

