In [2]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from iot_env import CyberBattleIoT, is_unsafe_state
from cyberbattle.agents.baseline.agent_dql import DeepQLearnerPolicy
from cyberbattle.agents.baseline.agent_wrapper import train
from project import SACAgent  # your custom SAC agent

ModuleNotFoundError: No module named 'cyberbattle'

In [None]:
# Hyperparameters
EPISODES = 50  # episodes per training run (SafeSAC runs this many in each of its two phases)
MAX_STEPS = 1000  # per-episode step cap used by the SAC / SafeSAC loops
SAFE_THRESHOLD = 0.5  # For SafeSAC filtering: actions scoring below this are resampled
SEED = 42  # global RNG seed so repeated runs are comparable

# Seed every global RNG the agents may draw from; without this the reward
# curves differ from run to run and the comparison is not reproducible.
# NOTE(review): CyberBattleIoT may keep its own RNG — seed it too if it exposes one.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def is_unsafe_state(state, risky_keywords=("BabyMonitor", "DoorLock", "Thermostat")):
    """Heuristically flag ``state`` as unsafe if it mentions a sensitive IoT device.

    The check is a plain substring search over ``str(state)``. It therefore
    works for any state with a meaningful string form, but it can also match
    incidental text (e.g. "DoorLockController").

    NOTE(review): this definition shadows the ``is_unsafe_state`` imported from
    ``iot_env`` in the imports cell; keep only one of the two.

    Args:
        state: Environment observation; only its ``str()`` form is inspected.
        risky_keywords: Substrings whose presence marks the state unsafe. The
            default is an immutable tuple, so it cannot be mutated between calls.

    Returns:
        bool: True if any keyword occurs in ``str(state)``, else False.
    """
    state_str = str(state)
    return any(keyword in state_str for keyword in risky_keywords)

In [None]:
def run_dql():
    """Train a Deep Q-Learning baseline on a fresh CyberBattleIoT environment.

    Returns:
        The ``'rewards'`` entry of the stats object returned by ``train``.

    NOTE(review): confirm that cyberbattle's ``DeepQLearnerPolicy`` accepts these
    keyword arguments (``neural_net``, ``environment``) and that ``train`` exists in
    ``agent_wrapper`` — TODO verify against the installed cyberbattle version.
    NOTE(review): ``MAX_STEPS`` is not passed to ``train``, so DQL episodes may not
    be capped the same way as the SAC runs — confirm for a fair comparison.
    """
    print("Training DQL...")
    env = CyberBattleIoT()
    dql_kwargs = dict(
        neural_net=None,
        environment=env,
        gamma=0.9,
        learning_rate=0.01,
        replay_memory_size=10000,
        target_update=20,
    )
    policy = DeepQLearnerPolicy(**dql_kwargs)
    training_stats = train(env, policy, episodes=EPISODES)
    return training_stats['rewards']

In [None]:
def run_sac():
    """Train a fresh SAC agent on a new CyberBattleIoT environment.

    Runs ``EPISODES`` episodes of at most ``MAX_STEPS`` steps each. Every
    transition is pushed to the agent's replay buffer and immediately followed
    by one ``agent.update_parameters()`` call.

    Returns:
        list: the summed (undiscounted) reward of each episode, in order.
    """
    print("Training SAC...")
    env = CyberBattleIoT()
    agent = SACAgent(env)
    episode_totals = []
    for _episode in range(EPISODES):
        obs = env.reset()
        total = 0
        steps_taken = 0
        finished = False
        # Stop at the step cap, or as soon as the environment reports done.
        while steps_taken < MAX_STEPS and not finished:
            act = agent.select_action(obs)
            next_obs, step_reward, finished, _info = env.step(act)
            agent.replay_buffer.push(obs, act, step_reward, next_obs, finished)
            agent.update_parameters()
            obs = next_obs
            total += step_reward
            steps_taken += 1
        episode_totals.append(total)
    return episode_totals

In [None]:
def _select_safe_action(agent, state, state_key, q_safe, threshold, max_resamples):
    """Sample actions from ``agent`` until one is not flagged unsafe by ``q_safe``.

    An action counts as flagged when its recorded score for ``state_key`` is below
    ``threshold``; pairs never seen default to 1.0 (safe). At most
    ``max_resamples + 1`` samples are drawn. If every sample is flagged, the
    candidate with the highest score is returned instead of looping forever.
    Without this bound, a deterministic or collapsed policy would hang the run.

    Returns:
        tuple: ``(action, fell_back)``; ``fell_back`` is True when no sampled
        action cleared the threshold.
    """
    best_action, best_score = None, float("-inf")
    for _ in range(max(0, max_resamples) + 1):
        action = agent.select_action(state)
        score = q_safe.get((state_key, action), 1.0)
        if score >= threshold:
            return action, False
        if score > best_score:
            best_action, best_score = action, score
    return best_action, True


def run_safe_sac(max_resamples=100):
    """Two-phase SafeSAC: pretrain SAC while recording a tabular safety critic, then finetune.

    Phase 1 trains a fresh SAC agent on a new CyberBattleIoT environment for
    ``EPISODES`` episodes of at most ``MAX_STEPS`` steps. For every
    ``(str(state), action)`` pair taken, it records whether the resulting state
    was unsafe according to ``is_unsafe_state`` (0.0 = unsafe, 1.0 = safe).
    Phase 2 keeps training the same agent for another ``EPISODES`` episodes,
    resampling any proposed action whose recorded score is below ``SAFE_THRESHOLD``.

    Args:
        max_resamples: Extra samples drawn per step when the policy proposes a
            flagged action. After that, the least-unsafe candidate is used so a
            policy that keeps proposing the same unsafe action cannot hang the run.

    Returns:
        tuple[list, list]: per-episode total rewards for pretraining and finetuning.
    """
    print("Pretraining SAC for SafeSAC...")
    env = CyberBattleIoT()
    agent = SACAgent(env)
    # Tabular stand-in for a safety critic: (str(state), action) -> score.
    # NOTE(review): actions are used as dict keys, so they must be hashable —
    # confirm SACAgent.select_action's return type.
    q_safe = {}
    pretrain_rewards = []

    # Pretraining phase
    for _ in range(EPISODES):
        state = env.reset()
        ep_reward = 0
        for _ in range(MAX_STEPS):
            action = agent.select_action(state)
            next_state, reward, done, _info = env.step(action)
            key = (str(state), action)
            # Conservative update: once a pair has led to an unsafe state it stays
            # flagged, even if a later visit of the same pair happens to be safe.
            if is_unsafe_state(next_state):
                q_safe[key] = 0.0
            else:
                q_safe.setdefault(key, 1.0)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            agent.update_parameters()
            state = next_state
            ep_reward += reward
            if done:
                break
        pretrain_rewards.append(ep_reward)

    print("Finetuning with SafeSAC constraints...")
    finetune_rewards = []
    fallback_steps = 0
    for _ in range(EPISODES):
        state = env.reset()
        ep_reward = 0
        for _ in range(MAX_STEPS):
            # Only the sampled action is checked, so there is no per-step scan of
            # the whole action space and no dependence on a Discrete `.n`.
            action, fell_back = _select_safe_action(
                agent, state, str(state), q_safe, SAFE_THRESHOLD, max_resamples
            )
            fallback_steps += fell_back
            next_state, reward, done, _info = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            agent.update_parameters()
            state = next_state
            ep_reward += reward
            if done:
                break
        finetune_rewards.append(ep_reward)

    if fallback_steps:
        print(f"Warning: {fallback_steps} finetuning step(s) found no action above "
              f"SAFE_THRESHOLD after resampling; used the least-unsafe candidate.")
    return pretrain_rewards, finetune_rewards

In [None]:
def plot_results(dql, sac, safe_pre, safe_fine, path="training_results.png"):
    """Overlay per-episode reward curves for all agents on one figure, save it, then show it.

    Args:
        dql, sac, safe_pre, safe_fine: Sequences of per-episode total rewards,
            plotted against their index (episode number).
        path: File the figure is written to; pass None to skip saving.
    """
    fig, ax = plt.subplots()
    curves = {
        "DQL": dql,
        "SAC": sac,
        "SafeSAC Pretrain": safe_pre,
        "SafeSAC Finetune": safe_fine,
    }
    for label, rewards in curves.items():
        ax.plot(rewards, label=label)
    ax.set(xlabel="Episode", ylabel="Total Reward", title="Agent Reward Comparison")
    ax.legend()
    # Explicit True: a bare grid() toggles, which turns the grid OFF under styles
    # that enable it by default.
    ax.grid(True)
    # Save before show(): some backends close the figure once it has been shown.
    if path is not None:
        fig.savefig(path)
    plt.show()

In [None]:
if __name__ == "__main__":
    # Train each agent in turn; every run_* call builds its own fresh environment.
    dql_rewards = run_dql()
    sac_rewards = run_sac()
    (safe_pre, safe_fine) = run_safe_sac()

    # Overlay all four reward curves on one figure (plot_results also saves it to disk).
    plot_results(
        dql=dql_rewards,
        sac=sac_rewards,
        safe_pre=safe_pre,
        safe_fine=safe_fine,
    )