In [None]:
import gym
import random
from advertorch.attacks import *
from utils import *
from net.discrete_net import *
import random as rd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def test_adv_img_atk(env, adv_atk, n_eval=100):
    """Perform the adversarial attacks"""
    targeted = adv_atk.targeted
    action_shape = env.action_space.shape or env.action_space.n
    state_shape = env.observation_space.shape or env.observation_space.n
    succ = 0
    obs = env.reset()
    for i in range(n_eval):
        act = [random.randint(0, action_shape-1)]
        if targeted:
          des_act = act
          while act == des_act:
              act = [rd.randint(0, np.prod(action_shape)-1)]
        t_obs = torch.FloatTensor(obs[np.newaxis, :]).to(device)
        if not targeted:
            t_act = torch.tensor(act).to(device)
            adv_obs = adv_atk.perturb(t_obs, t_act)
            """adv_obs_ = adv_obs.cpu().detach().numpy().astype(int)[0]
            imgplot = plt.imshow(obs)
            plt.show()
            imgplot = plt.imshow(adv_obs_)
            plt.show()
            input()"""
        else:
            t_des_act = torch.tensor(des_act).to(device)
            adv_obs = adv_atk.perturb(t_obs, t_des_act)

        y = adv_atk.predict(adv_obs)
        _, adv_actions = torch.max(y, 1)
        adv_act = adv_actions.cpu().detach().numpy()[0]
        if not targeted and adv_act != act:
            succ += 1
        if targeted and adv_act == des_act:
            succ += 1
        obs, rew, done, info = env.step(adv_act)
        if done:
            obs = env.reset()
    return succ / n_eval

In [None]:
def init_attack(env, device="cpu", attack_type="gsm", eps=0.3, targeted=False):
    """Build an adversarial attack around a newly constructed network.

    Args:
        env: environment whose observation and action spaces size the network.
        device: torch device the network is created on and moved to.
        attack_type: ``"gsm"`` for ``GradientSignAttack`` or ``"cw"`` for
            ``CarliniWagnerL2Attack``.
        eps: perturbation size forwarded to the gradient-sign attack; it is
            not passed to the ``"cw"`` attack.
        targeted: whether the attack should be targeted. Forwarded to both
            attack types.

    Returns:
        The constructed attack object.

    Raises:
        ValueError: if ``attack_type`` is neither ``"gsm"`` nor ``"cw"``.
    """
    # Validate first so a typo fails before any network is built.
    if attack_type not in ("gsm", "cw"):
        raise ValueError(
            f"method not supported! (attack_type={attack_type!r}, expected 'gsm' or 'cw')"
        )

    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n

    # NOTE(review): no checkpoint is loaded in this function, so unless Net
    # restores weights itself the attack runs against an untrained network.
    net = Net(2, state_shape, action_shape, device).to(device)
    net = NetAdapter(net).to(device)
    net.eval()

    # NOTE(review): upstream advertorch spells the clipping arguments
    # clip_min/clip_max; `clip` presumably belongs to a project-local
    # variant -- confirm.
    if attack_type == "gsm":
        return GradientSignAttack(net, targeted=targeted, eps=eps, clip=255.)
    # `targeted` used to be dropped here, so a targeted "cw" request silently
    # fell back to the attack's default.
    return CarliniWagnerL2Attack(
        net,
        np.prod(action_shape),
        confidence=1,
        targeted=targeted,
        max_iterations=10,
        clip=255.,
    )

In [None]:
# Use the GPU when torch reports one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Candidate environments and attack names; the experiment below picks one of
# each by index.
list_env = [
    gym.make(env_id)
    for env_id in ("CartPole-v0", "Breakout-v4", "MsPacman-v4", "Pong-v4")
]
list_atk = ["gsm", "cw"]

# 0.3 scaled by 255, the same 255. that init_attack passes to the attacks as `clip`.
eps = 0.3 * 255
targeted = True
env, atk = list_env[1], list_atk[1]

adv_atk = init_attack(env, device, attack_type=atk, eps=eps, targeted=targeted)

In [None]:
# Build the configuration this cell's label names rather than reusing whatever
# the configuration cell currently selects (Breakout / "cw" / targeted), so the
# label stays truthful on Restart & Run All.
# NOTE(review): eps and the untargeted setting are inferred from the label -- confirm.
pong_env = list_env[3]  # gym.make("Pong-v4")
pong_gsm_atk = init_attack(pong_env, device, attack_type="gsm", eps=eps, targeted=False)
res = test_adv_img_atk(pong_env, pong_gsm_atk)
print("Pong - GSM: ", res)

Pong - GSM:  0.9733333333333334


In [None]:
# Build the configuration this cell's label names rather than reusing whatever
# the configuration cell currently selects (Breakout / "cw"), so the label
# stays truthful on Restart & Run All.
# NOTE(review): eps is taken from the configuration cell -- confirm it matches the original run.
pacman_env = list_env[2]  # gym.make("MsPacman-v4")
pacman_gsm_atk = init_attack(pacman_env, device, attack_type="gsm", eps=eps, targeted=True)
res = test_adv_img_atk(pacman_env, pacman_gsm_atk)
print("Pacman - Targeted - GSM: ", res)

Pacman - Targeted - GSM:  0.87


In [None]:
# Build the gradient-sign attack this cell's label names; the configuration
# cell currently selects "cw", which would print a CW result under a GSM label
# on Restart & Run All.
# NOTE(review): eps is taken from the configuration cell -- confirm it matches the original run.
breakout_env = list_env[1]  # gym.make("Breakout-v4")
breakout_gsm_atk = init_attack(breakout_env, device, attack_type="gsm", eps=eps, targeted=True)
res = test_adv_img_atk(breakout_env, breakout_gsm_atk)
print("Breakout - Targeted - GSM: ", res)

Breakout - Targeted - GSM:  0.8333333333333334


In [None]:
# Evaluate the attack built in the configuration cell (env = list_env[1],
# atk = list_atk[1], targeted = True) and report its success rate.
res = test_adv_img_atk(env=env, adv_atk=adv_atk)
print("Breakout - Targeted - CW: ", res)

Breakout - Targeted - CW:  0.72
