In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/My\ Drive/adversarial_attacks_DRL

/content/drive/My Drive/adversarial_attacks_DRL


In [None]:
!pip install advertorch
!pip install tianshou

In [4]:
from advertorch.attacks import *
from atari_wrapper import wrap_deepmind
import copy
import torch
from drl_attacks.uniform_attack import uniform_attack_collector
from drl_attacks.strategically_timed_attack import strategically_timed_attack_collector
from utils import A2CPPONetAdapter

In [5]:
from advertorch.attacks.base import Attack
from drl_attacks.base_attack import base_attack_collector
import random as rd
import gym
import time
import torch
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.policy import BasePolicy


class targeted_uniform_attack_collector(base_attack_collector):
    """
    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param obs_adv_atk: an instance of the :class:`~advertorch.attacks.base.Attack`
        class implementing an image adversarial attack.
    :param perfect_attack: force adversarial attacks on observations to be
        always effective (ignore the ``adv`` param).
    :param atk_frequency: float, how frequently attacking env observations
    """
    def __init__(self,
                 policy: BasePolicy,
                 env: gym.Env,
                 obs_adv_atk: Attack,
                 perfect_attack: bool = False,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
                 atk_frequency: float = 1.
                 ):
        super().__init__(
            policy, env, obs_adv_atk, perfect_attack, device)

        self.atk_frequency = atk_frequency
        if self.obs_adv_atk is not None:
            self.obs_adv_atk.targeted = True

    def collect(self,
                n_step: int = 0,
                n_episode: int = 0,
                render: Optional[float] = None
                ) -> Dict[str, float]:
        assert (n_step and not n_episode) or (not n_step and n_episode), \
            "One and only one collection number specification is permitted!"
        self.reset_env()
        self.reset_attack()
        while True:
            if render:
                self.render()
                time.sleep(render)
            self.show_warning()
            self.predict_next_action()

            # START ADVERSARIAL ATTACK
            x = rd.uniform(0, 1)
            if x < self.atk_frequency:
                des_act = [rd.randint(0, self.action_space-1)]
                while des_act == self.data.act:
                  des_act = [rd.randint(0, self.action_space-1)]
                if not self.perfect_attack:
                    self.obs_attacks(des_act)
                else:
                    self.data.act = des_act
                if self.data.act == des_act:
                    self.succ_attacks += 1
                self.n_attacks += 1
            self.frames_count += 1
            # END ADVERSARIAL ATTACK

            self.perform_step()
            if self.check_end_attack(n_step, n_episode):
                break

        return self.get_attack_stats()

In [6]:
def make_atari_env_watch(env_name):
    return wrap_deepmind(env_name, frame_stack=4,
                         episode_life=False, clip_rewards=False)

In [7]:
def init_attack(env, policy_type="ppo", device="cpu", attack_type="gsm", eps=0.01, targeted=False):
    # load pretrained Pong-PPO policy 
    model_path = "log/" + env + "/" + policy_type + "/policy.pth"
    env = make_atari_env_watch(env)
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    policy, _ = torch.load(model_path, map_location=device)
    policy.to(device).init(device)

    # adapt PPO policy to Advertorch library
    if policy_type in ["a2c", "ppo"]:
        adv_net = A2CPPONetAdapter(copy.deepcopy(policy)).to(device)
        adv_net.eval()

    # define image adversarial attack
    if attack_type == "gsm":
        obs_adv_atk = GradientSignAttack(adv_net, eps=eps*255,
                                        clip_min=0, clip_max=255, targeted=targeted)
    elif attack_type == "cw":
        obs_adv_atk = CarliniWagnerL2Attack(adv_net, np.prod(action_shape),
                                            confidence=1, max_iterations=100,
                                            clip_min=0, clip_max=255,
                                            binary_search_steps=8,
                                            targeted=targeted)
    elif attack_type == "pgda":
        obs_adv_atk = PGDAttack(adv_net, eps=eps, targeted=targeted,
                                clip_min=0, clip_max=255, nb_iter=100,
                                eps_iter=0.01)
    return policy, obs_adv_atk

In [8]:
list_env = [ "PongNoFrameskip-v4", "BreakoutNoFrameskip-v4", "MsPacmanNoFrameskip-v4",
            "SeaquestNoFrameskip-v4", "QbertNoFrameskip-v4", "EnduroNoFrameskip-v4",
            "SpaceInvadersNoFrameskip-v4"]
list_atk = ["gsm", "cw"]
list_pol = ["a2c", "ppo"]

def test_untargeted(env, img_atk, policy, eps=0.01, episodes=10):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    eps = 255*eps
    targeted = False
    policy, obs_adv_atk = init_attack(env, policy_type=policy, device=device, attack_type=img_atk, eps=eps, targeted=targeted)
    env = make_atari_env_watch(env)
    collector = uniform_attack_collector(policy, env, obs_adv_atk, atk_frequency=0.5, device=device)
    test_adversarial_policy = collector.collect(n_episode=episodes)
    return test_adversarial_policy['succ_atks(%)']
  
def test_targeted(env, img_atk, policy, eps=0.01, episodes=10):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    eps = 255*eps
    targeted = True
    policy, obs_adv_atk = init_attack(env, policy_type=policy, device=device, attack_type=img_atk, eps=eps, targeted=targeted)
    env = make_atari_env_watch(env)
    collector = targeted_uniform_attack_collector(policy, env, obs_adv_atk, atk_frequency=0.2, device=device)
    test_adversarial_policy = collector.collect(n_episode=episodes)
    return test_adversarial_policy['succ_atks(%)']

# GSM - Targeted

In [None]:
print("PPO - Pong - GSM - Untargeted: ", test_untargeted("PongNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Pong - GSM - Untargeted:  1.0


In [None]:
print("PPO - Breakout - GSM - Untargeted: ", test_untargeted("BreakoutNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Breakout - GSM - Untargeted:  0.9785714285714285


In [None]:
print("PPO - MsPacman - GSM - Untargeted: ", test_untargeted("MsPacmanNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - MsPacman - GSM - Untargeted:  1.0


In [None]:
print("PPO - Seaquest - GSM - Untargeted: ", test_untargeted("SeaquestNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Seaquest - GSM - Untargeted:  1.0


In [None]:
print("PPO - Enduro - GSM - Untargeted: ", test_untargeted("EnduroNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Enduro - GSM - Untargeted:  1.0


In [None]:
print("PPO - SpaceInvaders - GSM - Untargeted: ", test_untargeted("SpaceInvadersNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - SpaceInvaders - GSM - Untargeted:  0.9951282884053264


In [None]:
print("PPO - Qbert - GSM - Untargeted: ", test_untargeted("QbertNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Qbert - GSM - Untargeted:  1.0


# GSM - Targeted

In [None]:
print("PPO - Pong - GSM - Targeted: ", test_targeted("PongNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Pong - GSM - Targeted:  0.5013093901982791


In [None]:
print("PPO - Breakout - GSM - Targeted: ", test_targeted("BreakoutNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Breakout - GSM - Targeted:  0.3951979234263465


In [None]:
print("PPO - MsPacman - GSM - Targeted: ", test_targeted("MsPacmanNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - MsPacman - GSM - Targeted:  0.3375870069605568


In [None]:
print("PPO - Seaquest - GSM - Targeted: ", test_targeted("SeaquestNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Seaquest - GSM - Targeted:  0.339137422984195


In [None]:
print("PPO - Enduro - GSM - Targeted: ", test_targeted("EnduroNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Enduro - GSM - Targeted:  0.4492622704004818


In [None]:
print("PPO - SpaceInvaders - GSM - Targeted: ", test_targeted("SpaceInvadersNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - SpaceInvaders - GSM - Targeted:  0.5375947995666306


In [None]:
print("PPO - Qbert - GSM - Targeted: ", test_targeted("QbertNoFrameskip-v4", "gsm", "ppo", eps=0.01, episodes=10))

PPO - Qbert - GSM - Targeted:  0.8012307692307692


# CW - Targeted

In [None]:
print("PPO - Pong - CW - Targeted: ", test_targeted("PongNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - Pong - CW - Targeted:  0.7198515769944341


In [None]:
print("PPO - Breakout - CW - Targeted: ", test_targeted("BreakoutNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - Breakout - CW - Targeted:  0.4701195219123506


In [None]:
print("PPO - MsPacman - CW - Targeted: ", test_targeted("MsPacmanNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - MsPacman - CW - Targeted:  0.29949238578680204


In [None]:
print("PPO - Seaquest - CW - Targeted: ", test_targeted("SeaquestNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - Seaquest - CW - Targeted:  0.4681081081081081


In [None]:
print("PPO - Enduro - CW - Targeted: ", test_targeted("EnduroNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - Enduro - CW - Targeted:  0.3360131010867947


In [10]:
print("PPO - SpaceInvaders - CW - Targeted: ", test_targeted("SpaceInvadersNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - SpaceInvaders - CW - Targeted:  0.2544642857142857


In [9]:
print("PPO - Qbert - CW - Targeted: ", test_targeted("QbertNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - Qbert - CW - Targeted:  0.40425531914893614


# CW - Untargeted

In [None]:
print("PPO - Pong - CW - Untargeted: ", test_untargeted("PongNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

PPO - Pong - CW - Untargeted:  0.8546895640686922


In [None]:
print("PPO - Breakout - CW - Untargeted: ", test_untargeted("BreakoutNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - Breakout - CW - Untargeted:  0.6911764705882353


In [None]:
print("PPO - MsPacman - CW - Untargeted: ", test_untargeted("MsPacmanNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - MsPacman - CW - Untargeted:  0.8028391167192429


In [None]:
print("PPO - Seaquest - CW - Untargeted: ", test_untargeted("SeaquestNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - Seaquest - CW - Untargeted:  0.926110833749376


In [None]:
print("PPO - Enduro - CW - Untargeted: ", test_untargeted("EnduroNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=1))

In [None]:
print("PPO - SpaceInvaders - CW - Untargeted: ", test_untargeted("SpaceInvadersNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - SpaceInvaders - CW - Untargeted:  0.7832669322709164


In [None]:
print("PPO - Qbert - CW - Untargeted: ", test_untargeted("QbertNoFrameskip-v4", "cw", "ppo", eps=0.01, episodes=2))

PPO - Qbert - CW - Untargeted:  0.8595271210013908


# PGDA - Targeted

In [14]:
print("PPO - Pong - PGDA - Targeted: ", test_targeted("PongNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=2))

PPO - Pong - PGDA - Targeted:  0.9857142857142858


In [13]:
print("PPO - Breakout - PGDA - Targeted: ", test_targeted("BreakoutNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=2))

PPO - Breakout - PGDA - Targeted:  0.8297872340425532


In [12]:
print("PPO - MsPacman - PGDA - Targeted: ", test_targeted("MsPacmanNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=2))

PPO - MsPacman - PGDA - Targeted:  0.36645962732919257


In [None]:
print("PPO - Seaquest - PGDA - Targeted: ", test_targeted("SeaquestNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=2))

PPO - Seaquest - PGDA - Targeted:  0.49728555917481


In [None]:
print("PPO - Enduro - PGDA - Targeted: ", test_targeted("EnduroNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=2))

PPO - Enduro - PGDA - Targeted:  0.35488611315209406


In [None]:
print("PPO - SpaceInvaders - PGDA - Targeted: ", test_targeted("SpaceInvadersNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=1))

PPO - SpaceInvaders - PGDA - Targeted:  0.6666666666666666


In [None]:
print("PPO - Qbert - PGDA - Targeted: ", test_targeted("QbertNoFrameskip-v4", "pgda", "ppo", eps=0.01, episodes=1))

PPO - Qbert - PGDA - Targeted:  0.5
