In [1]:
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.ppo import MlpPolicy as PPOMlp

from yawning_titan.envs.generic.core.blue_interface import BlueInterface
from yawning_titan.envs.generic.core.network_interface import NetworkInterface
from yawning_titan.game_modes.game_mode import GameMode
from yawning_titan.networks.node import Node
from yawning_titan.envs.generic.core.action_loops import ActionLoop
from yawning_titan.yawning_titan_run import YawningTitanRun
from yawning_titan.envs.generic.helpers.eval_printout import EvalPrintout
from yawning_titan import AGENTS_DIR, PPO_TENSORBOARD_LOGS_DIR
from yawning_titan.game_modes.game_mode_db import default_game_mode, GameModeDB
from yawning_titan.networks.network import Network
from yawning_titan.networks.network_db import default_18_node_network, NetworkDB

from adaptive_red import AdaptiveRed
from multiagent_env import MultiAgentEnv

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

from typing import Tuple, Union, Optional, Dict, List
from logging import Logger, getLogger
from uuid import uuid4

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Handles to the YAWNING-TITAN game-mode and network databases; the cells below
# list their entries and fetch one of each by UUID.
gdb = GameModeDB()
ndb = NetworkDB()

In [3]:
# List the stored game modes (name, author, locked, uuid — see the table below).
# NOTE(review): the positional `True` presumably requests the detailed table; confirm against GameModeDB.show.
gdb.show(True)

name                    author              locked    uuid
----------------------  ------------------  --------  ------------------------------------
DCBO Agent Config       dstl/YAWNING-TITAN  True      bac2cb9d-b24b-426c-88a5-5edd0c2de413
Default Game Mode       dstl/YAWNING-TITAN  True      900a704f-6271-4994-ade7-40b74d3199b1
Low skill red           dstl/YAWNING-TITAN  True      3ccd9988-8781-4c3e-9c75-44cc987ae6af
Ransomware_base                             False     26edf1f7-c71d-4564-89d8-0eeee1659afc
Ransomware_pessimistic                      False     c19751f4-577a-49dc-b697-52eded404309
Ransomware_naive                            False     034552b9-a971-42fc-8553-f50efc7a70a7
Ransomware_lsr                              False     3674f421-9450-4e1b-83ac-1dab0e597ed5


In [4]:
# List the stored networks (name, author, locked, uuid — see the table below).
# NOTE(review): the positional `True` presumably requests the detailed table; confirm against NetworkDB.show.
ndb.show(True)

name                     author              locked    uuid
-----------------------  ------------------  --------  ------------------------------------
Default 18-node network  dstl/YAWNING-TITAN  True      b3cd9dfd-b178-415d-93f0-c9e279b3c511
Dcbo base network        dstl/YAWNING-TITAN  True      47cb9f49-b53d-44f8-9a7b-3d74cf2ec1b0
Test Star Network        Erick Galinkin      False     8d6912ce-fc5a-4620-9cf0-2393b4e2b5ca
50 Node Mesh             Erick Galinkin      False     3b921390-cd7b-41c5-8120-5e9ac587d2f2


In [5]:
# Fetch the scenario by UUID (values taken from the tables printed above):
#   26edf1f7-... -> "Ransomware_base" game mode
#   3b921390-... -> "50 Node Mesh" network
game_mode = gdb.get("26edf1f7-c71d-4564-89d8-0eeee1659afc")
network = ndb.get("3b921390-cd7b-41c5-8120-5e9ac587d2f2")

In [6]:
# Single interface over the chosen game mode and network; the red agent, blue
# agent and environment below are all built on this same object.
network_interface = NetworkInterface(game_mode, network)

In [7]:
# Red side uses the local AdaptiveRed implementation; blue side uses
# YAWNING-TITAN's BlueInterface. Both wrap the shared network_interface.
red = AdaptiveRed(network_interface)
blue = BlueInterface(network_interface)

In [8]:
# Two-player environment: both sides submit an action each step
# (see `env.step(red_action, blue_action)` further down).
env = MultiAgentEnv(red_agent=red, 
                    blue_agent=blue, 
                    network_interface=network_interface)

In [9]:
# Seed torch so network weight initialisation and PPO action sampling are
# reproducible under Restart & Run All (torch.manual_seed covers CPU and CUDA).
# NOTE(review): the environment / AdaptiveRed may draw from other RNGs
# (`random`, numpy) that this does not seed — confirm and seed those as well if so.
SEED = 0
torch.manual_seed(SEED)

# Run on the first CUDA GPU when available, otherwise on CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [10]:
class RolloutBuffer:
    """Step-by-step storage for one PPO rollout.

    Every field is a plain list that callers append to. ``clear`` empties the
    lists in place instead of rebinding them, so any outside reference to a
    field keeps pointing at the live (now empty) list.
    """

    def __init__(self):
        for field_name in ("actions", "states", "logprobs", "rewards", "state_values", "is_terminals"):
            setattr(self, field_name, [])

    def clear(self):
        """Empty every field in place."""
        for store in (self.actions, self.states, self.logprobs,
                      self.rewards, self.state_values, self.is_terminals):
            store.clear()

In [11]:
class ActorCritic(nn.Module):
    """Policy (actor) and value (critic) MLPs that read the same state vector.

    Both networks have two hidden layers of 64 units with Tanh activations.
    For a continuous action space the actor ends in Tanh and its output is the
    mean of a Gaussian whose per-dimension variance is held in ``action_var``.
    For a discrete action space it ends in Softmax and its output is the
    probability vector of a Categorical. The critic outputs a single value.
    """

    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            # Variance is std**2, stored as a plain tensor (not an nn.Parameter).
            self.action_dim = action_dim
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)
            output_activation = nn.Tanh()
        else:
            output_activation = nn.Softmax(dim=-1)

        # The layer order fixes the state_dict keys (actor.0 / actor.2 / actor.4, ...)
        # that saved checkpoints rely on. Building the actor before the critic keeps
        # weight initialisation consuming the RNG in a fixed order.
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            output_activation,
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

    def set_action_std(self, new_action_std):
        """Replace the Gaussian variance with ``new_action_std ** 2`` (continuous policies only)."""
        if not self.has_continuous_action_space:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")
            return
        self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)

    def forward(self):
        raise NotImplementedError

    def act(self, state):
        """Sample one action for ``state``.

        Returns ``(action, log_prob, state_value)``, all detached from the graph.
        """
        if not self.has_continuous_action_space:
            dist = Categorical(self.actor(state))
        else:
            mean = self.actor(state)
            covariance = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(mean, covariance)

        sampled = dist.sample()
        log_prob = dist.log_prob(sampled)
        value = self.critic(state)
        return sampled.detach(), log_prob.detach(), value.detach()

    def evaluate(self, state, action):
        """Score previously taken ``action``s under the current policy.

        Returns ``(log_probs, state_values, entropy)`` with gradients attached.
        """
        if not self.has_continuous_action_space:
            dist = Categorical(self.actor(state))
        else:
            mean = self.actor(state)
            covariance = torch.diag_embed(self.action_var.expand_as(mean)).to(device)
            dist = MultivariateNormal(mean, covariance)
            # With a single action dimension the trailing axis may have been squeezed
            # away (e.g. by torch.squeeze in PPO.update); restore it.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)

        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        values = self.critic(state)
        return log_probs, values, entropy

In [12]:
class PPO:
    """Minimal clipped-objective PPO agent with a single rollout buffer.

    NOTE: this class shadows the ``PPO`` imported from ``stable_baselines3`` in
    the first cell; after this cell runs, ``PPO`` refers to this implementation.

    Args:
        state_dim: length of the flat state vector fed to the networks.
        action_dim: number of discrete actions (or continuous action dimensions).
        K_epochs: optimisation passes over the buffer per ``update`` call.
        lr_actor: Adam learning rate for the actor parameters.
        lr_critic: Adam learning rate for the critic parameters.
        gamma: discount factor for the Monte Carlo returns.
        eps_clip: clipping range for the probability ratio.
        has_continuous_action_space: Gaussian policy if True, Categorical otherwise.
        action_std_init: initial Gaussian std (continuous policies only).
        value_coef: weight of the critic MSE term in the loss.
        entropy_coef: weight of the entropy bonus subtracted from the loss.
    """

    def __init__(self, state_dim, action_dim, K_epochs=20, lr_actor=0.0003, lr_critic=0.001, gamma=0.99, eps_clip=0.2, has_continuous_action_space=False, action_std_init=0.6, value_coef=0.5, entropy_coef=0.01):
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

        self.buffer = RolloutBuffer()

        # `policy` is the network being optimised. `policy_old` collects rollouts
        # and is re-synced from `policy` at the end of every update().
        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': lr_critic}
                    ])

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the Gaussian std on both networks (continuous policies only)."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Lower the Gaussian std by ``action_std_decay_rate``, floored at ``min_action_std``."""
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)

        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")
        print("--------------------------------------------------------------------------------------------")

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record the step in the buffer.

        The state tensor, action, log-probability and critic value are appended
        to ``self.buffer``. The caller must append the matching reward and
        terminal flag before calling ``update``.

        Returns:
            A flattened NumPy array for continuous policies, or a Python scalar
            (``Tensor.item()``) for discrete policies.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob, state_val = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        return action.item()

    @staticmethod
    def _discounted_returns(rewards, is_terminals, gamma):
        """Compute Monte Carlo discounted returns, resetting the running sum at terminal steps.

        Returns are accumulated back-to-front with ``append`` and reversed once,
        so the cost is linear in the buffer length.
        """
        returns = []
        running = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                running = 0
            running = reward + (gamma * running)
            returns.append(running)
        returns.reverse()
        return returns

    @staticmethod
    def _normalize_returns(returns):
        """Standardise a 1-D returns tensor to zero mean and unit std.

        Tensors with fewer than two elements are returned unchanged.
        ``Tensor.std`` applies Bessel's correction, so a single sample would
        yield NaN and poison every subsequent gradient step.
        """
        if returns.numel() < 2:
            return returns
        return (returns - returns.mean()) / (returns.std() + 1e-7)

    def update(self):
        """Run ``K_epochs`` clipped-PPO gradient steps on the buffered rollout, then reset it.

        Normalised Monte Carlo returns serve as value targets. ``returns - V_old(s)``
        serves as the advantage. Afterwards the trained weights are copied into
        ``policy_old`` and the buffer is cleared.

        Raises:
            ValueError: if the buffer is empty, or if the numbers of states,
                rewards and terminal flags differ. Without this check, ``zip``
                would silently truncate and tensor broadcasting could silently
                misalign the targets.
        """
        n_states = len(self.buffer.states)
        n_rewards = len(self.buffer.rewards)
        n_terminals = len(self.buffer.is_terminals)
        if not (n_states == n_rewards == n_terminals):
            raise ValueError(
                f"Inconsistent rollout buffer: {n_states} states, "
                f"{n_rewards} rewards, {n_terminals} terminal flags"
            )
        if n_rewards == 0:
            raise ValueError("PPO.update() called with an empty rollout buffer")

        # Monte Carlo estimate of returns, normalised for a stable value target.
        rewards = torch.tensor(
            self._discounted_returns(self.buffer.rewards, self.buffer.is_terminals, self.gamma),
            dtype=torch.float32,
        ).to(device)
        rewards = self._normalize_returns(rewards)

        # Stack per-step tensors. squeeze() drops the singleton axes produced by
        # calling act() on one state at a time.
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)

        advantages = rewards.detach() - old_state_values.detach()

        for _ in range(self.K_epochs):
            # Re-score the stored actions under the current (trainable) policy.
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Match state_values dimensions with the rewards tensor.
            state_values = torch.squeeze(state_values)

            # Probability ratio pi_theta / pi_theta_old.
            ratios = torch.exp(logprobs - old_logprobs.detach())

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # Clipped policy objective + weighted value loss - weighted entropy bonus.
            loss = (
                -torch.min(surr1, surr2)
                + self.value_coef * self.MseLoss(state_values, rewards)
                - self.entropy_coef * dist_entropy
            )

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def save(self, checkpoint_path):
        """Write ``policy_old``'s state_dict to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a checkpoint written by ``save`` into both ``policy_old`` and ``policy``.

        The file is read once. ``map_location`` keeps the deserialised storages
        where they were loaded, and ``load_state_dict`` copies the values into
        each network's existing parameters, so the two modules do not share
        tensors. ``torch.load`` unpickles the file, so only load trusted
        checkpoints.
        """
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)

In [13]:
def _get_new_ppo(env: MultiAgentEnv, interface, **ppo_kwargs) -> PPO:
    """Build a fresh PPO agent sized for ``env`` and one side's interface.

    The state dimension is ``env.observation_space.shape[0]``. For an
    ``AdaptiveRed`` interface (including subclasses) the action count is
    ``len(interface.action_dict)``. For any other interface it is
    ``env.action_space.n``.

    Args:
        env: the multi-agent environment whose observation space sizes the networks.
        interface: the red or blue agent interface this PPO agent will drive.
        **ppo_kwargs: forwarded unchanged to ``PPO`` (e.g. ``K_epochs``, ``gamma``).

    Raises:
        ValueError: if ``interface`` is None. ValueError subclasses Exception,
            so existing ``except Exception`` handlers still match.
    """
    if interface is None:
        raise ValueError("Must supply interface")
    obs_space = env.observation_space.shape[0]
    # isinstance (not an exact type check) so AdaptiveRed subclasses are sized by
    # their own action_dict instead of falling through to the env's action space.
    if isinstance(interface, AdaptiveRed):
        action_space = len(interface.action_dict)
    else:
        action_space = env.action_space.n
    return PPO(obs_space, action_space, **ppo_kwargs)

In [14]:
# One independent PPO learner per side. Both share the env's observation size;
# the action count comes from AdaptiveRed.action_dict (red) or env.action_space.n (blue).
red_agent = _get_new_ppo(env, red)
blue_agent = _get_new_ppo(env, blue)

In [15]:
# Initial observation; the same observation vector is fed to both agents below.
state = env.reset()

In [16]:
# Each call samples from that agent's policy_old and appends state/action/log-prob/value
# to that agent's own buffer.
# NOTE(review): the matching reward and terminal flag must also be appended to each
# buffer after env.step, otherwise PPO.update cannot run.
red_action = red_agent.select_action(state)
blue_action = blue_agent.select_action(state)

In [17]:
# Red's sampled discrete action: an integer in range(len(red.action_dict)).
red_action

2

In [18]:
# Blue's sampled discrete action: an integer in range(env.action_space.n).
blue_action

60

In [19]:
# Advance one step with both sides' actions. The result shown below is a 5-tuple,
# presumably (observation, red reward, blue reward, done, info) — TODO confirm
# against MultiAgentEnv.step before wiring rewards into the agents' buffers.
env.step(red_action, blue_action)

(array([0., 1., 1., ..., 0., 0., 0.], dtype=float32),
 -4.99,
 -0.01,
 False,
 {'initial_state': {'3d208d21-c737-4165-b0a8-b3a38b9976e1': 0,
   'cf44d970-ff5c-4b1e-81f7-9727726395dc': 0,
   'fdcef68b-7f8e-4fdb-8beb-45468e95a52f': 0,
   '37cc3860-8223-4cd3-9ae2-b608c1dc9430': 0,
   '59e06fa4-2804-47b5-bae2-02fd335709e2': 0,
   'e92a8bd2-93ff-476d-9653-2dd64d24bc2a': 0,
   '2290a52f-b96d-4224-b38f-604036790ee9': 0,
   'cc6d3868-3dd5-413d-8174-51c238916056': 0,
   '14ee6488-5dd1-45b2-a2a1-88f95fcc218e': 0,
   '137c3b16-87b8-458d-a3ab-768a3e89bcdf': 0,
   '81100995-f972-40e7-b7c4-9ef2f9dd0c23': 0,
   'bfb4b33d-ac97-4915-b5f9-7b6a120c90bf': 0,
   '63477212-ad2b-490d-b109-b6049f03b663': 0,
   '4a4cd23d-53e8-43c9-aace-783d749a1980': 0,
   '564fdcba-2aec-4dd0-b369-df37a4efc7f4': 0,
   '38323875-ac0d-42dc-961f-c588715715f5': 0,
   '4acc637e-7ce9-417a-9762-8a4f4072ebeb': 0,
   'd246cc68-20f5-4b13-9388-7ae8f2627060': 0,
   '731526dd-e382-4a66-953a-dba9d0634d4c': 0,
   '29f3d803-65fd-4f8f-9941-e96