In [1]:
# Import necessary libraries

import time
import random
import numpy as np
import copy
import json
from collections import Counter
from typing import Dict, Any, List, Tuple

import gym
from gym import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback, ProgressBarCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from baselines import logger

# Import Yawning Titan specific modules
from yawning_titan.envs.generic.core.blue_interface import BlueInterface
from yawning_titan.envs.generic.core.red_interface import RedInterface
from yawning_titan.envs.generic.generic_env import GenericNetworkEnv
from yawning_titan.envs.generic.core.action_loops import ActionLoop
from yawning_titan.envs.generic.core.network_interface import NetworkInterface
from yawning_titan.game_modes.game_mode import GameMode
from yawning_titan.networks.node import Node
from yawning_titan.networks.network import Network
from yawning_titan.envs.generic.core import reward_functions
from yawning_titan.envs.generic.helpers.eval_printout import EvalPrintout
from yawning_titan.envs.generic.helpers.graph2plot import CustomEnvGraph

# Import Reward Machine related modules
from reward_machines.reward_machine import RewardMachine
from reward_machines.rm_environment import RewardMachineEnv, RewardMachineWrapper
from reward_machines.reward_functions import ConstantRewardFunction
from reward_machines.reward_machine_utils import evaluate_dnf



In [2]:
# Build the game mode from its YAML definition.
# NOTE(review): 'CONFIG YAML FILE' looks like a placeholder — point it at the
# real game-mode YAML before running on a fresh kernel.
game_mode = GameMode().create_from_yaml(
    yaml='CONFIG YAML FILE',
    legacy=False,
    infer_legacy=True,
)
# Print the max_steps game rule that was loaded, as a quick sanity check.
print(game_mode.game_rules.max_steps.value)

# Create a network representation
network = Network()


def _place_node(label: str, position: List[int]) -> Node:
    """Create a node, set its position, and add it to ``network``.

    Args:
        label: Name passed to the ``Node`` constructor.
        position: ``[x, y]`` list assigned to the node's ``node_position``.

    Returns:
        The new node, already added to the module-level ``network``.
    """
    node = Node(label)
    node.node_position = position
    network.add_node(node)
    return node


# Define network nodes and their positions.
# Every node goes through _place_node so the create / position / add steps
# live in one place. Nodes are created top-down, layer by layer.

# External layer
router_1 = _place_node("Boundary router packet filter", [0, 600])
switch_1 = _place_node("External switch", [0, 500])
network_ids_1 = _place_node("Network IDS (DMZ)", [-200, 500])
dns_server_external = _place_node("DNS server esterno", [200, 500])

# DMZ layer
server_1 = _place_node("Main Firewall/VPN server/NAT", [-100, 400])
server_2 = _place_node("External Web Server/Host IDS", [100, 400])
switch_2 = _place_node("Internal Switch (DMZ)", [0, 300])
network_ids_2 = _place_node("Network IDS (DMZ Internal)", [-200, 300])

# Internal layer
internal_firewall = _place_node("Internal Firewall", [0, 200])
server_3 = _place_node("Database Server", [200, 200])
email_server = _place_node("Email Server/Host IDS", [-200, 200])
dns_server_internal = _place_node("DNS server interno", [400, 200])
web_proxy = _place_node("Web proxy server", [-400, 200])
switch_3 = _place_node("Internal Switch", [0, 100])
network_ids_3 = _place_node("Network IDS Internal", [-200, 100])

# Subnet layer
subnet1_router = _place_node("Router Subnet 1", [-300, 0])
subnet2_router = _place_node("Router Subnet 2", [-100, 0])
subnet3_router = _place_node("Router Subnet 3", [100, 0])
subnet4_router = _place_node("Router Subnet 4", [300, 0])

# Client/Server layer
client1_management = _place_node("Client1 Management", [-300, -100])
client1_hr = _place_node("Client1 HR", [-100, -100])
client1_it = _place_node("Client1 IT", [100, -100])
server_backup = _place_node("Server backup", [300, -100])

# Add edges to connect nodes.
# Each tuple is one add_edge call; the table is walked in order, so edges
# are added in the sequence listed here.
_topology_edges = (
    # Boundary router and the external switch fan-out
    (router_1, switch_1),
    (switch_1, server_1),
    (switch_1, server_2),
    (switch_1, network_ids_1),
    (switch_1, dns_server_external),
    # Main firewall down to the DMZ switch and its fan-out
    (server_1, switch_2),
    (switch_2, server_3),
    (switch_2, internal_firewall),
    (switch_2, email_server),
    (switch_2, dns_server_internal),
    (switch_2, web_proxy),
    (switch_2, network_ids_2),
    # Internal firewall down to the internal switch and the subnet routers
    (internal_firewall, switch_3),
    (switch_3, network_ids_3),
    (switch_3, subnet1_router),
    (switch_3, subnet2_router),
    (switch_3, subnet3_router),
    (switch_3, subnet4_router),
    # One host behind each subnet router
    (subnet1_router, client1_management),
    (subnet2_router, client1_hr),
    (subnet3_router, client1_it),
    (subnet4_router, server_backup),
)
for _edge_start, _edge_end in _topology_edges:
    network.add_edge(_edge_start, _edge_end)


# Flag the single entry node, then flag both high-value nodes.
router_1.entry_node = True
for _high_value in (server_3, server_backup):
    _high_value.high_value_node = True

# Print the network's node table in verbose form.
network.show(verbose=True)

200
UUID                                  Name                           High Value Node    Entry Node      Vulnerability  Position (x,y)
------------------------------------  -----------------------------  -----------------  ------------  ---------------  ----------------
7922ca36-aa54-4a19-a1c9-ad11e686a69b  Boundary router packet filter  False              True                     0.01  0.00, 600.00
e67c6569-67b1-4b51-97c2-02e5c134eb00  External switch                False              False                    0.01  0.00, 500.00
e603dd07-9631-4a7f-ab99-0cc1c4a8bf9c  Network IDS (DMZ)              False              False                    0.01  -200.00, 500.00
a152cc0d-c88c-4976-81ae-048b8f0facec  DNS server esterno             False              False                    0.01  200.00, 500.00
c588726a-4c4e-4fc7-836c-83572bc0415c  Main Firewall/VPN server/NAT   False              False                    0.01  -100.00, 400.00
99bd503a-b3c9-4e54-85df-4a7188763d50  External Web Server/

In [3]:
# Create NetworkInterface, RedInterface, and BlueInterface objects.
# The network interface is built from the loaded game mode and the network
# defined above; the red and blue interfaces are then both constructed from
# that same network interface instance.
network_interface = NetworkInterface(game_mode=game_mode, network=network)
red = RedInterface(network_interface)
blue = BlueInterface(network_interface)

In [4]:
# Assemble the environment from the red interface, the blue interface and
# the network interface they share.
# NOTE(review): show_metrics_every=10 is presumably the interval at which the
# environment prints summary metrics — confirm against GenericNetworkEnv.
YTenv = GenericNetworkEnv(red, blue, network_interface, show_metrics_every=10)




In [5]:
# Create the PPO agent (this only builds the model; training is a separate call).
# A fixed seed is passed to PPO so that training runs can be repeated:
# without it every Restart & Run All produces different rewards.
SEED = 42
agent = PPO(PPOMlp, YTenv, verbose=1, seed=SEED)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [6]:
# Train the agent. The step budget is a named constant so it is easy to
# tune, and the trailing semicolon stops the notebook from echoing the PPO
# object that learn() returns.
TOTAL_TIMESTEPS = 10000
agent.learn(total_timesteps=TOTAL_TIMESTEPS);

----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -2.01e+03 |
| time/              |           |
|    fps             | 754       |
|    iterations      | 1         |
|    time_elapsed    | 2         |
|    total_timesteps | 2048      |
----------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 200          |
|    ep_rew_mean          | -2.04e+03    |
| time/                   |              |
|    fps                  | 630          |
|    iterations           | 2            |
|    time_elapsed         | 6            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0007171217 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -4.26        |
|    explained_variance   | -0.000619    |
|    

<stable_baselines3.ppo.ppo.PPO at 0x15fbc4c70>

In [7]:
# Evaluate the trained agent over 10 episodes.
# The environment is wrapped in Monitor because evaluate_policy warns when it
# is not, and relies on the wrapper to report unmodified episode statistics.
mean_reward, std_reward = evaluate_policy(agent, Monitor(YTenv), n_eval_episodes=10)
# Leave the (mean, std) pair as the last expression so the cell displays it.
mean_reward, std_reward




(-1106.9238873004913, 3.757377493381295)