In this notebook we show how to perform the forward pass through the Neural ODE using different types of Meta Solver regimes, namely
- Standalone
- Solver switching/smoothing
- Solver ensembling
- Model ensembling

In more detail, the regimes work as follows:
- **Standalone**
    - Use one solver during inference.
    - Applied during training/testing.
     
    
    
- **Solver switching / smoothing**
    - For each batch one solver is chosen from a group of solvers with finite (in switching regime) or infinite (in smoothing regime) number of members.
    - Applied during training.
    
    
- **Solver ensembling**
    - Use several solvers during inference.
    - Outputs of ODE Block (obtained with different solvers) are averaged before propagating through the next layer.
    - Applied during training/testing
    
    
- **Model ensembling**
    - Use several solvers during inference.
    - Model probabilities obtained via propagation with different solvers are averaged to get the final result.
    - Applied during training/testing
    

In [1]:
import os
# Enumerate CUDA devices by PCI bus id so that device indices are stable
# across runs and match the physical GPU numbering.
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
# Expose only GPU 4 to this process. The variable must be named
# CUDA_VISIBLE_DEVICES (plural); the singular spelling is ignored by CUDA.
# It must be set before torch initializes CUDA (torch is imported later).
os.environ['CUDA_VISIBLE_DEVICES'] = "4"

In [2]:
from argparse import Namespace
import torch
import numpy as np
import itertools
import wandb

import sys

sys.path.append('../../')
import sopa.src.models.odenet_cifar10.layers as cifar10_models
from sopa.src.models.odenet_cifar10.utils import *
from sopa.src.models.odenet_cifar10.data import get_cifar10_test_loader
from sopa.src.models.utils import fix_seeds
from sopa.src.solvers.utils import create_solver, noise_params, create_solver_ensemble_by_noising_params

from MegaAdversarial.src.attacks import (
    Clean,
    PGD,
    FGSM,
    Clean2Ensemble,
    FGSM2Ensemble,)

# Build a model

In [3]:
# Load a checkpoint (alternative: './checkpoints/fgsm_random_8_255.pth').
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint_name = './checkpoints/fgsm_random_8_255_smoothing_00125.pth'
checkpoint = torch.load(checkpoint_name)

# The checkpoint carries the run's 'wandb_config' dict; expose its entries as attributes.
config = Namespace(**checkpoint['wandb_config'])

In [4]:
# Solver configuration(s) the model was trained with, read from the checkpoint config
vars(config)['solvers']

[['rk2', 'u', 8, -1.0, 0.5, -1]]

In [5]:
# Initialize the Neural ODE model from the stored configuration.
# Each layer tuple is ordered (resblock, odeblock, bn1).
norm_layers = tuple(get_normalization(getattr(config, 'normalization_' + part))
                    for part in ('resblock', 'odeblock', 'bn1'))
param_norm_layers = tuple(get_param_normalization(getattr(config, 'param_normalization_' + part))
                          for part in ('resblock', 'odeblock', 'bn1'))
act_layers = tuple(get_activation(getattr(config, 'activation_' + part))
                   for part in ('resblock', 'odeblock', 'bn1'))

# config.network names a model constructor in the cifar10 layers module.
model = getattr(cifar10_models, config.network)(
    norm_layers,
    param_norm_layers,
    act_layers,
    config.in_planes,
    is_odenet=config.is_odenet,
)
model.load_state_dict(checkpoint['model'])

# Drop the checkpoint reference and release cached, unused GPU memory.
checkpoint = None
torch.cuda.empty_cache()

# Build a data loader

In [6]:
# CIFAR-10 test loader: fixed sample order (shuffle disabled), download disabled.
data_root = "/workspace/raid/data/datasets/cifar10"
test_loader = get_cifar10_test_loader(
    data_root=data_root,
    batch_size=32,
    shuffle=False,
    download=False,
    num_workers=1,
    pin_memory=False,
)
# Number of batches in the test loader
len(test_loader)

312

# Evaluate the model

In [7]:
def one_hot(x, K):
    """Encode a 1-D array of integer labels as an integer one-hot matrix.

    Entry ``[i, k]`` of the ``(len(x), K)`` result is 1 when ``x[i] == k`` and
    0 otherwise, so a label outside ``[0, K)`` yields an all-zero row.
    """
    classes = np.arange(K)
    return (x[:, np.newaxis] == classes[np.newaxis, :]).astype(int)

def accuracy(model, dataset_loader, device, solvers=None, solver_options=None, data_noise_std=None):
    """Return the fraction of samples in ``dataset_loader`` that ``model`` classifies correctly.

    Args:
        model: classifier module; it is put in eval mode and moved to ``device``.
        dataset_loader: iterable of ``(x, y)`` batches, ``y`` being a tensor of
            integer class labels (one per sample).
        device: device the input batches are moved to.
        solvers: optional list of solvers. When given, the model is called as
            ``model(x, solvers, solver_options)``; otherwise as ``model(x)``.
        solver_options: options passed to the model together with ``solvers``.
        data_noise_std: if set and greater than 1e-12, zero-mean Gaussian noise
            with this standard deviation is added to every input batch.

    Returns:
        Number of correct predictions divided by the number of samples the
        loader actually yielded (a partial last batch is counted exactly).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    model.to(device)
    total_correct = 0
    total = 0

    for x, y in dataset_loader:
        x = x.to(device)
        # The predicted class index is compared directly with the integer label.
        target_class = y.cpu().numpy()

        with torch.no_grad():
            if data_noise_std is not None and data_noise_std > 1e-12:
                x = x + data_noise_std * torch.randn_like(x)

            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
            predicted_class = np.argmax(out.cpu().numpy(), axis=1)

        total_correct += np.sum(predicted_class == target_class)
        # Count real samples instead of len(loader) * batch_size, which is
        # wrong whenever the last batch is smaller than batch_size.
        total += len(target_class)

    torch.cuda.empty_cache()
    if total == 0:
        raise ValueError("dataset_loader yielded no samples.")
    return total_correct / total

def adversarial_accuracy(model, dataset_loader, device, solvers=None, solver_options=None, args=None):
    """Return the accuracy of ``model`` on inputs perturbed by a test-time attack.

    Args:
        model: classifier module; it is put in eval mode and moved to ``device``.
        dataset_loader: iterable of ``(x, y)`` tensor batches with integer labels.
        device: device the batches are moved to.
        solvers: optional list of solvers, forwarded to the attack and to the
            model call ``model(x, solvers, solver_options)``.
        solver_options: options forwarded together with ``solvers``.
        args: namespace whose ``a_adv_testing_mode`` selects the attack:
            ``"clean"`` (``Clean``), ``"fgsm"`` (``FGSM`` configured by the
            module-level ``CONFIG_FGSM_TEST``) or ``"at"`` (``PGD`` configured
            by ``CONFIG_PGD_TEST``).

    Returns:
        Number of correct predictions on the attacked inputs divided by the
        number of samples the loader actually yielded.

    Raises:
        ValueError: if ``args`` / ``a_adv_testing_mode`` is missing, the mode is
            unknown, or the loader yields no samples.
    """
    mode = getattr(args, "a_adv_testing_mode", None)
    if mode is None:
        raise ValueError("args.a_adv_testing_mode must be set to 'clean', 'fgsm' or 'at'.")

    model.eval()
    model.to(device)

    if mode == "clean":
        test_attack = Clean(model)
    elif mode == "fgsm":
        test_attack = FGSM(model, **CONFIG_FGSM_TEST)
    elif mode == "at":
        test_attack = PGD(model, **CONFIG_PGD_TEST)
    else:
        raise ValueError("Attack type not understood.")

    total_correct = 0
    total = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # The attack runs outside torch.no_grad(), presumably because it needs
        # input gradients; it returns the (possibly perturbed) inputs and labels.
        x, y = test_attack(x, y, {"solvers": solvers, "solver_options": solver_options})
        target_class = y.cpu().numpy()

        with torch.no_grad():
            if solvers is not None:
                out = model(x, solvers, solver_options)
            else:
                out = model(x)
            predicted_class = np.argmax(out.cpu().numpy(), axis=1)

        total_correct += np.sum(predicted_class == target_class)
        # Count real samples; len(loader) * batch_size overcounts a partial last batch.
        total += len(target_class)

    torch.cuda.empty_cache()
    if total == 0:
        raise ValueError("dataset_loader yielded no samples.")
    return total_correct / total

In [8]:
# Per-channel CIFAR-10 normalization statistics, handed to the attacks below.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Test-time attack settings. eps (and PGD's step size lr) look like they are on
# the [0, 1] pixel scale -- TODO confirm against the attack implementation.
CONFIG_FGSM_TEST = dict(eps=8 / 255.0, mean=cifar10_mean, std=cifar10_std)
CONFIG_PGD_TEST = dict(eps=8 / 255.0, lr=2 / 255.0, n_iter=7,
                       mean=cifar10_mean, std=cifar10_std)

device = 'cuda'
dtype = torch.float32

# Standalone
- Use one solver during inference

In [9]:
# Create one RK2 solver ('u' parameterization, 8 steps) and freeze its params.
solver = create_solver(
    method='rk2', parameterization='u', n_steps=8, step_size=-1,
    u0=0.5, v0=-1, dtype=dtype, device=device,
)
solver.freeze_params()
val_solvers = [solver]

val_solver_options = Namespace(solver_mode='standalone')

In [10]:
# Clean (standard) accuracy with the standalone solver
accuracy_test = accuracy(model, test_loader, device,
                         solver_options=val_solver_options, solvers=val_solvers)
accuracy_test

0.8279246794871795

In [11]:
# Compute robust accuracy
# a_adv_testing_mode selects the attack: 'fgsm' for FGSM, 'at' for PGD, or
# 'clean' for no attack. It is forwarded to adversarial_accuracy below.
a_adv_testing_mode = 'fgsm'
adv_accuracy_test = adversarial_accuracy(model, test_loader, device,
                                         solvers=val_solvers, solver_options=val_solver_options,
                                         args=Namespace(a_adv_testing_mode=a_adv_testing_mode))
adv_accuracy_test

0.41616586538461536

# Solver switching
For each batch one solver is chosen from a group of solvers with finite (in switching regime) or infinite (in smoothing regime) number of members.

In [23]:
device = 'cuda'
dtype = torch.float32
# Two frozen RK2 solvers differing only in u0; in switch mode one of them is
# chosen for each batch.
val_solvers = [create_solver(method='rk2', parameterization='u', n_steps=8, step_size=-1,
                             u0=u0, v0=-1, dtype=dtype, device=device)
               for u0 in (0.5, 1.)]
for solver in val_solvers:
    solver.freeze_params()

# switch_probs presumably align with val_solvers (0.6 -> u0=0.5, 0.4 -> u0=1.0).
val_solver_options = Namespace(solver_mode='switch', switch_probs=[0.6, 0.4])

In [24]:
# Clean accuracy with per-batch solver switching
accuracy_test = accuracy(model, test_loader, device,
                         solver_options=val_solver_options, solvers=val_solvers)
accuracy_test

0.8277243589743589

In [25]:
# Robust accuracy under the attack selected by a_adv_testing_mode
# ('fgsm' for FGSM, 'at' for PGD, 'clean' for no attack).
adv_accuracy_test = adversarial_accuracy(model, test_loader, device,
                                         solvers=val_solvers, solver_options=val_solver_options,
                                         args=Namespace(a_adv_testing_mode=a_adv_testing_mode))
adv_accuracy_test

0.41616586538461536

# Solver Ensembling
- Use several solvers during inference.

- Outputs of ODE Block (obtained with different solvers) are averaged before propagating through the next layer.

In [15]:
# Two frozen RK2 solvers (u0 = 0.5 and 1.0) whose ODE-block outputs are
# averaged with the given ensemble_weights.
val_solvers = [
    create_solver(method='rk2', parameterization='u', n_steps=8, step_size=-1,
                  u0=u0, v0=-1, dtype=dtype, device=device)
    for u0 in (0.5, 1.)
]
for solver in val_solvers:
    solver.freeze_params()

val_solver_options = Namespace(solver_mode='ensemble',
                               ensemble_prob=1,
                               ensemble_weights=[0.6, 0.4])

In [16]:
# Clean accuracy with solver ensembling inside the ODE block
accuracy_test = accuracy(model, test_loader, device,
                         solver_options=val_solver_options, solvers=val_solvers)
accuracy_test

0.8278245192307693

In [17]:
# Robust accuracy under the attack selected by a_adv_testing_mode
# ('fgsm' for FGSM, 'at' for PGD, 'clean' for no attack).
adv_accuracy_test = adversarial_accuracy(model, test_loader, device,
                                         solvers=val_solvers, solver_options=val_solver_options,
                                         args=Namespace(a_adv_testing_mode=a_adv_testing_mode))
adv_accuracy_test

0.41626602564102566

# Model Ensembling
- Use several solvers during inference.

- Model probabilities obtained via propagation with different solvers are averaged to get the final result.

In [18]:
def accuracy_ensemble(models, dataset_loader, device, solvers_solver_options_arr=None, data_noise_std=None):
    """Return the accuracy of an ensemble that averages its members' softmax probabilities.

    Ensemble members are formed as follows:

    * ``solvers_solver_options_arr is None``: every model in ``models`` is a
      member and is called as ``model(x)``.
    * otherwise every dict in ``solvers_solver_options_arr`` is a member called
      as ``model(x, **options)``. The i-th dict is paired with ``models[i]``,
      and the first model is reused when there are fewer models than dicts
      (e.g. one network evaluated with several solvers).

    Args:
        models: non-empty sequence of classifier modules (put in eval mode).
        dataset_loader: iterable of ``(x, y)`` tensor batches with integer labels.
        device: device the input batches are moved to.
        solvers_solver_options_arr: optional list of keyword-argument dicts,
            e.g. ``{'solvers': [...], 'solver_options': ...}``.
        data_noise_std: if set and greater than 1e-12, Gaussian noise with this
            standard deviation is added to every input batch.

    Returns:
        Fraction of correctly classified samples among those the loader yielded.

    Raises:
        ValueError: if ``models`` is empty, there are more models than option
            dicts, or the loader yields no samples.
    """
    model_list = list(models)
    if not model_list:
        raise ValueError("models must contain at least one model.")
    if solvers_solver_options_arr is None:
        members = [(model, {}) for model in model_list]
    else:
        options_list = list(solvers_solver_options_arr)
        if len(model_list) > len(options_list):
            raise ValueError("Got more models than solver option dicts.")
        # Pad missing models with the first one; option dicts are never padded.
        members = [(model_list[i] if i < len(model_list) else model_list[0], options)
                   for i, options in enumerate(options_list)]

    for model in model_list:
        model.eval()
    total_correct = 0
    total = 0

    for x, y in dataset_loader:
        x = x.to(device)
        target_class = y.cpu().numpy()

        with torch.no_grad():
            if data_noise_std is not None and data_noise_std > 1e-12:
                x = x + data_noise_std * torch.randn_like(x)

            probs_sum = 0
            for model, options in members:
                logits = model(x, **options)
                # Explicit class dimension: softmax over the logits of each sample.
                probs_sum = probs_sum + torch.softmax(logits, dim=1).cpu().numpy()

        probs_ensemble = probs_sum / len(members)
        predicted_class = np.argmax(probs_ensemble, axis=1)
        total_correct += np.sum(predicted_class == target_class)
        # Count real samples; len(loader) * batch_size overcounts a partial last batch.
        total += len(target_class)

    if total == 0:
        raise ValueError("dataset_loader yielded no samples.")
    return total_correct / total


def adversarial_accuracy_ensemble(models, dataset_loader, device, solvers_solver_options_arr=None, args=None):
    """Return the accuracy of a probability-averaging ensemble on attacked inputs.

    Members are formed exactly as in ``accuracy_ensemble``:

    * with no ``solvers_solver_options_arr``, each model is called as ``model(x)``;
    * otherwise each option dict is a member called as ``model(x, **options)``,
      paired with ``models[i]`` and falling back to the first model when there
      are fewer models than dicts.

    Args:
        models: non-empty sequence of classifier modules (put in eval mode).
        dataset_loader: iterable of ``(x, y)`` tensor batches with integer labels.
        device: device the batches are moved to.
        solvers_solver_options_arr: optional list of keyword-argument dicts,
            also passed unchanged to the attack.
        args: namespace whose ``a_adv_testing_mode`` selects the attack:
            ``"clean"`` (``Clean2Ensemble``) or ``"fgsm"`` (``FGSM2Ensemble``
            configured by the module-level ``CONFIG_FGSM_TEST``).

    Returns:
        Fraction of correctly classified attacked samples among those the
        loader yielded.

    Raises:
        ValueError: if ``args`` / ``a_adv_testing_mode`` is missing, the mode is
            not supported for ensembles, ``models`` is empty, there are more
            models than option dicts, or the loader yields no samples.
    """
    mode = getattr(args, "a_adv_testing_mode", None)
    if mode is None:
        raise ValueError("args.a_adv_testing_mode must be set to 'clean' or 'fgsm'.")

    model_list = list(models)
    if not model_list:
        raise ValueError("models must contain at least one model.")
    if solvers_solver_options_arr is None:
        members = [(model, {}) for model in model_list]
    else:
        options_list = list(solvers_solver_options_arr)
        if len(model_list) > len(options_list):
            raise ValueError("Got more models than solver option dicts.")
        # Pad missing models with the first one; option dicts are never padded.
        members = [(model_list[i] if i < len(model_list) else model_list[0], options)
                   for i, options in enumerate(options_list)]

    for model in model_list:
        model.eval()

    if mode == "clean":
        test_attack = Clean2Ensemble(models)
    elif mode == "fgsm":
        test_attack = FGSM2Ensemble(models, **CONFIG_FGSM_TEST)
    else:
        raise ValueError("Attack type is not implemented for ensemble of models")

    total_correct = 0
    total = 0
    for x, y in dataset_loader:
        x, y = x.to(device), y.to(device)
        # The attack runs outside torch.no_grad(), presumably because it needs
        # input gradients; it returns the (possibly perturbed) inputs and labels.
        x, y = test_attack(x, y, solvers_solver_options_arr)
        target_class = y.cpu().numpy()

        with torch.no_grad():
            probs_sum = 0
            for model, options in members:
                logits = model(x, **options)
                # Explicit class dimension: softmax over the logits of each sample.
                probs_sum = probs_sum + torch.softmax(logits, dim=1).cpu().numpy()

        probs_ensemble = probs_sum / len(members)
        predicted_class = np.argmax(probs_ensemble, axis=1)
        total_correct += np.sum(predicted_class == target_class)
        # Count real samples; len(loader) * batch_size overcounts a partial last batch.
        total += len(target_class)

    if total == 0:
        raise ValueError("dataset_loader yielded no samples.")
    return total_correct / total

In [19]:
# Base solver for the model ensemble: one frozen RK2 solver.
solver = create_solver(method='rk2', parameterization='u', n_steps=8, step_size=-1,
                       u0=0.5, v0=-1, dtype=dtype, device=device)
solver.freeze_params()
val_solvers = [solver]

val_solver_options = Namespace(solver_mode='standalone')

# Derive `ensemble_size` solvers from the base solver via
# create_solver_ensemble_by_noising_params.
ensemble_size = 2
solver_ensemble = create_solver_ensemble_by_noising_params(
    val_solvers[0],
    ensemble_size=ensemble_size,
    kwargs_noise=dict(std=0.2, bernoulli_p=1., noise_type='normal'),
)

# One ensemble member per derived solver, each run in standalone mode.
solvers_solver_options_arr = [dict(solvers=[s], solver_options=val_solver_options)
                              for s in solver_ensemble]

tensor([0.7345], device='cuda:0') None


In [20]:
# Model ensembling: the single network is evaluated with every derived solver.
accuracy_test = accuracy_ensemble(
    [model], test_loader, device=device,
    solvers_solver_options_arr=solvers_solver_options_arr,
)
accuracy_test

0.8279246794871795

In [21]:
# Robust accuracy of the model ensemble under the attack selected by
# a_adv_testing_mode; only 'clean' and 'fgsm' are supported for ensembles.
adv_accuracy_test = adversarial_accuracy_ensemble([model], test_loader, device=device,
                                                  solvers_solver_options_arr=solvers_solver_options_arr,
                                                  args=Namespace(a_adv_testing_mode=a_adv_testing_mode))
adv_accuracy_test

  probs_ensemble = probs_ensemble + nn.Softmax()(logits)


0.41626602564102566