In [1]:
import os
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.optimizers import Adam

### Creating Initial Enviornment

In [2]:
# Creates new tiny benchmark scenario in fully observable mode with vectorized action space with
# and a flat observations
env = gym.make('nasim:Tiny-v2')

### Testing Envirnment (Untrained)

In [3]:
#Runs ten episodes in the enviornment to test its functionality
episodes = 10
for episode in range(episodes):
    state = env.reset()
    done = False
    score = 0
    #While running episode displays state after each action
    while not done:
        env.render_state()
        action = env.action_space.sample()
        print(type(action))
        state, reward, done, info = env.step(action)
        score += reward
    #Prints out the episode number and the associated score for the episode
    print(f'Episode:{episode + 1} Score:{score}')
env.close()

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |    False    |    True   |    True    |  0.0  |       0.0       |  0.0   |  True | True |  True  |
|  (2, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  1.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value |

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  1.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  1.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  1.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  1.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  1.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value |

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |    False    |    True   |    True    |  0.0  |       0.0       |  0.0   |  True | True |  True  |
|  (2, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  1.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |    True   |    True    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |     True    |    True   |    True    |  0.0  |       0.0       |  2.0   |  True | True |  True  |
|  (2, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  1.0   |  True | True |  True  |
|  (3, 0) |     True    |    True   |    True    | 100.0 |       0.0       |  2.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
<class 'numpy.ndarray'>
State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Acces

In [4]:
env.reset()

array([  0.,   1.,   0.,   0.,   1.,   0.,   1.,   1.,   0.,   0.,   0.,
         1.,   1.,   1.,   0.,   0.,   1.,   0.,   1.,   0.,   0.,   0.,
       100.,   0.,   0.,   1.,   1.,   1.,   0.,   0.,   0.,   1.,   1.,
         0.,   0.,   0., 100.,   0.,   0.,   1.,   1.,   1.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.], dtype=float32)

### Allowing For Models To Be Saved (Code Found From https://github.com/nicknochnack/MarioRL)
    

In [5]:
#Allowing-For-Models-To-Be-Saved# Class allowing for models to be 
# automatically saved at a set number of steps
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [6]:
# Create directories to save our models and logs to
save_dir = './training/'
log_dir = './logs/'

In [7]:
callback = TrainAndLoggingCallback(check_freq = 5000, save_path = save_dir)

### Agent Creation and Training

#### PPO Agent Creation and Training

In [8]:
# Creates a PPO agent that uses an MLP policy within the enviornment passed to it
ppo_agent = PPO('MlpPolicy', env, verbose = 1, tensorboard_log = log_dir, learning_rate = 0.000001,
               n_steps = 512)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [9]:
ppo_agent.learn(total_timesteps = 10000, callback = callback)

Logging to ./logs/PPO_10
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 222      |
|    ep_rew_mean     | 15.5     |
| time/              |          |
|    fps             | 268      |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 512      |
---------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 240           |
|    ep_rew_mean          | -2.5          |
| time/                   |               |
|    fps                  | 241           |
|    iterations           | 2             |
|    time_elapsed         | 4             |
|    total_timesteps      | 1024          |
| train/                  |               |
|    approx_kl            | 1.0660151e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -3.58         |
|    explained_vari

-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 187           |
|    ep_rew_mean          | 45.9          |
| time/                   |               |
|    fps                  | 206           |
|    iterations           | 11            |
|    time_elapsed         | 27            |
|    total_timesteps      | 5632          |
| train/                  |               |
|    approx_kl            | 5.3670956e-06 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -3.58         |
|    explained_variance   | 0.000278      |
|    learning_rate        | 1e-06         |
|    loss                 | 281           |
|    n_updates            | 100           |
|    policy_gradient_loss | -0.000234     |
|    value_loss           | 571           |
-------------------------------------------
------------------------------------------
| rollout/                |      

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 187          |
|    ep_rew_mean          | 44.7         |
| time/                   |              |
|    fps                  | 210          |
|    iterations           | 20           |
|    time_elapsed         | 48           |
|    total_timesteps      | 10240        |
| train/                  |              |
|    approx_kl            | 9.825453e-08 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -3.58        |
|    explained_variance   | 0.000431     |
|    learning_rate        | 1e-06        |
|    loss                 | 520          |
|    n_updates            | 190          |
|    policy_gradient_loss | -1.4e-05     |
|    value_loss           | 1.06e+03     |
------------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x2306e340e50>

#### PPO Model Testing

In [10]:
# Running the enviornment for 10 episodes similar to earlier however now
# the action produces is a prediction from the trained model based on the state provided to it
episodes = 10
for episode in range(episodes):
    state = env.reset()
    done = False
    score = 0
    #While running episode displays state after each action
    while not done:
        env.render_state()
        action = ppo_agent.predict(state)
        action = np.array(action)
        state, reward, done, info = env.step(action)
        score += reward
    #Prints out the episode number and the associated score for the episode
    print(f'Episode:{episode + 1} Score:{score}')
env.close()

State:
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
| Address | Compromised | Reachable | Discovered | Value | Discovery Value | Access | linux | ssh  | tomcat |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+
|  (1, 0) |    False    |    True   |    True    |  0.0  |       0.0       |  0.0   |  True | True |  True  |
|  (2, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
|  (3, 0) |    False    |   False   |   False    | 100.0 |       0.0       |  0.0   |  True | True |  True  |
+---------+-------------+-----------+------------+-------+-----------------+--------+-------+------+--------+


  action = np.array(action)


TypeError: only integer scalar arrays can be converted to a scalar index

#### A2C Agent Creation and Training

In [None]:
a2c_agent = A2C('MlpPolicy', env, verbose = 1, tensorboard_log = log_dir, learning_rate = 0.000001,
               n_steps = 512)

In [None]:
a2c_agent.learn(total_timesteps = 1000000, callback = callback)