In [1]:
! python3 -m pip install wandb -q

In [2]:
import os
import argparse
from argparse import Namespace
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from decimal import Decimal
import wandb
import sys

sys.path.append('../../')

from sopa.src.solvers.utils import create_solver
from sopa.src.models.utils import fix_seeds, RunningAverageMeter
from sopa.src.models.odenet_mnist.layers import MetaNODE
from sopa.src.models.odenet_mnist.utils import makedirs, learning_rate_with_decay
from sopa.src.models.odenet_mnist.data import get_mnist_loaders, inf_generator
from MegaAdversarial.src.attacks import (
    Clean,
    PGD,
    FGSM
)

In [3]:
def parse_solver_specs(spec):
    """Parse the ``--solvers`` string into a list of solver tuples.

    Solvers are separated by ``;`` and their fields by ``,``. Field types are
    positional: fields 0-1 are ``str``, field 2 is ``int``, field 3 is ``float`` and
    every later field is ``Decimal``. A 6-field spec therefore yields
    ``(method, parameterization, n_steps, step_size, u0, v0)``.

    Example: ``"rk4,u3,4,-1,0.3,-1"`` -> ``[('rk4', 'u3', 4, -1.0, Decimal('0.3'), Decimal('-1'))]``.
    """
    leading_casts = (str, str, int, float)
    return [
        tuple(leading_casts[idx](field) if idx < len(leading_casts) else Decimal(field)
              for idx, field in enumerate(item.split(',')))
        for item in spec.strip().split(';')
    ]


parser = argparse.ArgumentParser()
parser.add_argument('--network', type=str, choices=['resnet', 'odenet'], default='odenet')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'])
parser.add_argument('--activation', type=str, choices=['tanh', 'softplus', 'softsign', 'relu'], default='relu')
parser.add_argument('--in_channels', type=int, default=1)

parser.add_argument('--solvers',
                    type=parse_solver_specs,
                    default=None,
                    help='Each solver is represented with (method,parameterization,n_steps,step_size,u0,v0). '
                         'If the solver has only one parameter u0, set v0 to -1. '
                         'n_steps and step_size are mutually exclusive: only one of them can be != -1. '
                         'If n_steps = step_size = -1, the automatic time grid constructor is used. '
                         'For example, --solvers rk4,uv,2,-1,0.3,0.6;rk3,uv,-1,0.1,0.4,0.6;rk2,u,4,-1,0.3,-1')

parser.add_argument('--solver_mode', type=str, choices=['switch', 'ensemble', 'standalone'], default='standalone')
parser.add_argument('--val_solver_modes',
                    type=lambda s: s.strip().split(','),
                    default=['standalone'],
                    help='Solver modes to use for validation step')

parser.add_argument('--switch_probs', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="--switch_probs 0.8,0.1,0.1")
parser.add_argument('--ensemble_weights', type=lambda s: [float(item) for item in s.split(',')], default=None,
                    help="ensemble_weights 0.6,0.2,0.2")
parser.add_argument('--ensemble_prob', type=float, default=1.)

parser.add_argument('--noise_type', type=str, choices=['cauchy', 'normal'], default=None)
parser.add_argument('--noise_sigma', type=float, default=0.001)
parser.add_argument('--noise_prob', type=float, default=0.)
# NOTE(review): the boolean flags below use `type=eval`, which executes arbitrary
# Python taken from the command line. Acceptable for the hard-coded argv used in
# this notebook, but do not expose this parser to untrusted input.
parser.add_argument('--minimize_rk2_error', type=eval, default=False, choices=[True, False])

parser.add_argument('--nepochs_nn', type=int, default=50)
parser.add_argument('--nepochs_solver', type=int, default=0)
parser.add_argument('--nstages', type=int, default=1)

parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--base_lr', type=float, default=1e-5, help='base_lr for CyclicLR scheduler')
parser.add_argument('--max_lr', type=float, default=1e-3, help='max_lr for CyclicLR scheduler')
parser.add_argument('--step_size_up', type=int, default=2000, help='step_size_up for CyclicLR scheduler')
parser.add_argument('--cyclic_lr_mode', type=str, default='triangular2', help='mode for CyclicLR scheduler')
parser.add_argument('--lr_uv', type=float, default=1e-3)
parser.add_argument('--torch_dtype', type=str, default='float32')
parser.add_argument('--wandb_name', type=str, default='find_best_solver')

parser.add_argument('--data_root', type=str, default='./')
parser.add_argument('--save_dir', type=str, default='./')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--gpu', type=int, default=0)

parser.add_argument('--seed', type=int, default=502)
# Noise and adversarial attacks parameters:
parser.add_argument('--data_noise_std', type=float, default=0.,
                    help='Applies Norm(0, std) gaussian noise to each training batch')
parser.add_argument('--eps_adv_training', type=float, default=0.3,
                    help='Epsilon for adversarial training')
parser.add_argument(
    "--adv_training_mode",
    default="clean",
    choices=["clean", "fgsm", "at"],  # , "at_ls", "av", "fs", "nce", "nce_moco", "moco", "av_extra", "meta"],
    help='''Adversarial training method/mode, by default there is no adversarial training (clean).
        For further details see MegaAdversarial/train in this repository.
        '''
)
parser.add_argument('--use_wandb', type=eval, default=True, choices=[True, False])
parser.add_argument('--ss_loss', type=eval, default=False, choices=[True, False])
parser.add_argument('--ss_loss_reg', type=float, default=0.1)
# Evaluated once, when the parser is built (i.e. when this cell runs).
parser.add_argument('--timestamp', type=int, default=int(1e6 * time.time()))

parser.add_argument('--eps_adv_testing', type=float, default=0.3,
                    help='Epsilon for adversarial testing')
parser.add_argument('--adv_testing_mode',
                    default="clean",
                    choices=["clean", "fgsm", "at"],
                    help='''Adversarial testing mode''')

# Explicit argv so the notebook does not try to parse the Jupyter kernel's own arguments.
args = parser.parse_args(['--solvers', 'rk4,u3,4,-1,0.3,-1', '--seed', '902', '--adv_testing_mode', 'at', 
                          '--max_lr', '0.001', '--base_lr', '1e-05'])

In [4]:
makedirs(args.save_dir)

if args.use_wandb:
    # Initialise a wandb run and log every parsed CLI argument as run config.
    wandb.init(project=args.wandb_name, anonymous="allow", entity="sopa_node")
    wandb.config.update(args)
    # Additionally log the second-to-last field of the first solver spec
    # (u0 for a 6-field spec; parsed as a Decimal) as a plain float under 'u'.
    first_solver_u0 = args.solvers[0][-2]
    wandb.config.update({'u': float(first_solver_u0)})
    # Ensure the save dir and a per-run subdirectory named after wandb.run.path exist.
    run_save_dir = wandb.config.save_dir
    makedirs(run_save_dir)
    makedirs(os.path.join(run_save_dir, wandb.run.path))

[34m[1mwandb[0m: Currently logged in as: [33mtalgat[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.19 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


### Load the model

In [5]:
# Load a checkpoint: the file holds a whole pickled model (its repr is shown below),
# mapped onto the selected GPU, or onto the CPU when CUDA is unavailable.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which may refuse a pickled
# nn.Module like this one; confirm against the installed version.
if torch.cuda.is_available():
    device = f'cuda:{args.gpu}'
else:
    device = 'cpu'
checkpoint_name = './checkpoints/checkpoint_15444.pth'
model = torch.load(checkpoint_name, map_location=device)

In [6]:
# Display the loaded module tree to confirm the checkpoint's architecture.
model

MetaNODE(
  (downsampling_layers): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): GroupNorm(32, 64, eps=1e-05, affine=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (fc_layers): Sequential(
    (0): GroupNorm(32, 64, eps=1e-05, affine=True)
    (1): ReLU(inplace=True)
    (2): AdaptiveAvgPool2d(output_size=(1, 1))
    (3): Flatten()
    (4): Linear(in_features=64, out_features=10, bias=True)
  )
  (blocks): ModuleList(
    (0): MetaODEBlock(
      (rhs_func): ODEfunc(
        (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (conv1): ConcatConv2d(
          (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm2): GroupNorm(32, 64, eps=1e-05, affine

### Build a data loader

In [7]:
# MNIST loaders: unpacked as (train, test, train-eval). Only test_loader is used
# by the evaluation cells shown here; data_gen and batches_per_epoch are not
# referenced by the visible cells.
train_loader, test_loader, train_eval_loader = get_mnist_loaders(
    args.data_aug,
    args.batch_size,
    args.test_batch_size,
    data_root=args.data_root,
)
data_gen = inf_generator(train_loader)
batches_per_epoch = len(train_loader)

### Evaluate the model

In [8]:
# Resolve --torch_dtype explicitly. Previously an unsupported value left `dtype`
# undefined (NameError below) or silently reused a stale value from an earlier run.
TORCH_DTYPES = {'float32': torch.float32, 'float64': torch.float64}
if args.torch_dtype not in TORCH_DTYPES:
    raise ValueError(f"Unsupported --torch_dtype {args.torch_dtype!r}; "
                     f"expected one of {sorted(TORCH_DTYPES)}")
dtype = TORCH_DTYPES[args.torch_dtype]

# Instantiate every solver given via --solvers and freeze its parameters for evaluation.
# NOTE: the loop variable `solver` remains bound at notebook scope after this cell.
solvers = [create_solver(*solver_params, dtype=dtype, device=device) for solver_params in args.solvers]
for solver in solvers:
    solver.freeze_params()

# Options forwarded to the model together with `solvers`.
SOLVER_OPTION_KEYS = ('solver_mode', 'switch_probs', 'ensemble_prob', 'ensemble_weights')
solver_options = Namespace(**{key: getattr(args, key) for key in SOLVER_OPTION_KEYS})

In [None]:
def one_hot(x, K):
    """One-hot encode the labels in ``x`` as an int matrix with ``K`` columns.

    ``x`` must support ``x[:, None]`` (e.g. a 1-D numpy array of class ids).
    Row ``i`` has a 1 in column ``x[i]`` and 0 elsewhere; a label outside
    ``[0, K)`` yields an all-zero row.
    """
    as_column = x[:, None]
    class_ids = np.arange(K)[None, :]
    return np.array(as_column == class_ids, dtype=int)

def accuracy(model, dataset_loader, device, solvers=None, solver_options=None):
    """Top-1 classification accuracy of ``model`` over ``dataset_loader``.

    Args:
        model: classifier returning per-class scores; called as
            ``model(x, solvers, solver_options)`` when ``solvers`` is given,
            otherwise as ``model(x)``. It is switched to eval mode.
        dataset_loader: iterable of ``(x, y)`` batches, ``y`` holding integer
            class labels.
        device: device the inputs are moved to before the forward pass.
        solvers: optional solvers forwarded to the model.
        solver_options: optional solver options forwarded alongside ``solvers``.

    Returns:
        Fraction of correctly classified samples among those actually yielded
        by the loader (float in [0, 1]).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for x, y in dataset_loader:
            x = x.to(device)
            # Fix: test the `solvers` argument; the old code tested the
            # notebook-global `solver` leaked from the solver-freezing loop.
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
            predicted_class = np.argmax(out.cpu().numpy(), axis=1)
            # Labels are compared directly; no one-hot/argmax round trip needed.
            target_class = y.cpu().numpy()
            total_correct += int(np.sum(predicted_class == target_class))
            total_seen += len(target_class)
    if total_seen == 0:
        raise ValueError("dataset_loader yielded no samples")
    return total_correct / total_seen


def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None,
                         attack_config=None):
    """Top-1 accuracy of ``model`` on attacked batches from ``dataset_loader``.

    Each batch is first passed through the attack selected by
    ``args.adv_testing_mode`` ("clean" -> Clean, "fgsm" -> FGSM, "at" -> PGD),
    then classified without gradient tracking.

    Args:
        model: classifier returning per-class scores; called as
            ``model(x, solvers, solver_options)`` when ``solvers`` is given,
            otherwise as ``model(x)``. It is switched to eval mode.
        dataset_loader: iterable of ``(x, y)`` batches, ``y`` holding integer
            class labels.
        device: device inputs and labels are moved to.
        solvers: optional solvers, forwarded to the model and to the attack.
        solver_options: optional solver options, forwarded alongside ``solvers``.
        args: namespace providing ``adv_testing_mode``; required.
        attack_config: optional keyword arguments for the FGSM/PGD attack. When
            omitted, the notebook-level ``CONFIG_FGSM_TEST`` / ``CONFIG_PGD_TEST``
            dict is looked up at call time (the previous behaviour).

    Returns:
        Fraction of correctly classified samples among those actually yielded
        by the loader (float in [0, 1]).

    Raises:
        ValueError: if ``args`` is None, the attack mode is unknown, or the
            loader yields no samples.
    """
    if args is None:
        raise ValueError("adversarial_accuracy requires `args` with an `adv_testing_mode` attribute")
    model.eval()
    mode = args.adv_testing_mode
    if mode == "clean":
        test_attack = Clean(model)
    elif mode == "fgsm":
        # NOTE: CONFIG_FGSM_TEST is not defined in this notebook; pass
        # attack_config (or define the global) before using this mode.
        config = attack_config if attack_config is not None else CONFIG_FGSM_TEST
        test_attack = FGSM(model, mean=[0.], std=[1.], **config)
    elif mode == "at":
        config = attack_config if attack_config is not None else CONFIG_PGD_TEST
        test_attack = PGD(model, mean=[0.], std=[1.], **config)
    else:
        raise ValueError("Attack type not understood.")

    total_correct = 0
    total_seen = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # The attack runs outside torch.no_grad(): gradient-based attacks
        # presumably need input gradients.
        x, y = test_attack(x, y, {"solvers": solvers, "solver_options": solver_options})
        with torch.no_grad():
            # Fix: test the `solvers` argument, not the notebook-global `solver`.
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
        predicted_class = np.argmax(out.cpu().numpy(), axis=1)
        target_class = y.cpu().numpy()
        total_correct += int(np.sum(predicted_class == target_class))
        total_seen += len(target_class)
    if total_seen == 0:
        raise ValueError("dataset_loader yielded no samples")
    return total_correct / total_seen

In [None]:
# Clean (unattacked) test-set accuracy using the frozen solvers.
accuracy_test = accuracy(model, test_loader,
                         solvers=solvers,
                         solver_options=solver_options,
                         device=device)
accuracy_test

In [None]:
# PGD settings for the adversarial evaluation. eps now follows --eps_adv_testing
# (default 0.3, the value previously hard-coded here) so the CLI flag takes effect.
# NOTE(review): if PGD takes n_iter steps of size lr, 7 * 2/255 ~= 0.055 is far
# below eps=0.3, so the attack may never reach the eps-ball boundary — confirm.
CONFIG_PGD_TEST = {"eps": args.eps_adv_testing, "lr": 2.0 / 255, "n_iter": 7}
adv_accuracy_test = adversarial_accuracy(model, test_loader, device,
                                         solvers=solvers, solver_options=solver_options, args=args)
adv_accuracy_test