# Reinforcement Learning Model



In [1]:
pip install tensorflow


Defaulting to user installation because normal site-packages is not writeable
Collecting tensorflow
  Downloading tensorflow-2.19.0-cp312-cp312-win_amd64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-win_amd64.whl.metadata (5.3 kB)
Collecting opt-einsum>=2.3.2 (from tensorflow)
  Downloading opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Collecting termcolor>=1.1.0 (from tensorflow)
  Using cached termcolor-2.5.0-py3-none-any.whl.metadata



In [None]:
pip install stable-baselines3

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
pip install "shimmy>=2.0"

In [None]:
pip install gymnasium stable-baselines3 joblib numpy

In [None]:
pip install tensorboard

In [5]:
import os
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
import matplotlib.pyplot as plt
from security_env import SecurityEnv

# Check if tensorboard is installed
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("Tensorboard is not installed. Training logs will not be saved.")
    print("To install tensorboard, run: pip install tensorboard")

def create_env(render_mode=None):
    """Create and wrap the environment."""
    env = SecurityEnv(render_mode=render_mode)
    env = Monitor(env)
    env = DummyVecEnv([lambda: env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True)
    return env

def train_agent():
    """Train the PPO agent on the SecurityEnv."""
    print("Starting training process...")

    # Create directories for saving models and logs
    os.makedirs("models", exist_ok=True)
    os.makedirs("logs", exist_ok=True)

    # Create training and evaluation environments
    train_env = create_env()
    eval_env = create_env()

    # Create the agent with optimized hyperparameters
    model = PPO(
        "MlpPolicy",
        train_env,
        learning_rate=0.0003,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        clip_range_vf=None,
        ent_coef=0.00,
        vf_coef=0.5,
        max_grad_norm=0.5,
        tensorboard_log="./logs/" if TENSORBOARD_AVAILABLE else None,
        policy_kwargs=dict(
            net_arch=dict(
                pi=[64, 64],
                vf=[64, 64]
            )
        ),
        verbose=0
    )

    # Create callbacks
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path="./models/best_model",
        log_path="./logs",
        eval_freq=1000,
        deterministic=True,
        render=False
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=5000,
        save_path="./models/",
        name_prefix="security_model"
    )

    # Train the agent
    total_timesteps = 500  # Adjust based on your needs
    model.learn(
        total_timesteps=total_timesteps,
        callback=[eval_callback, checkpoint_callback],
        progress_bar=True
    )

    # Save the final model and normalization stats
    model.save("models/final_security_model")
    train_env.save("models/vec_normalize.pkl")

    print("\nTraining completed!")
    return model, train_env

def evaluate_agent(model, env, num_episodes=5):
    """Evaluate the trained agent."""
    print("\nEvaluating agent...")

    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        truncated = False
        total_reward = 0
        episode_steps = 0

        while not (done or truncated):
            action, _ = model.predict(obs, deterministic=True)

            result = env.step(action)
            if len(result) == 4:
                obs, reward, done, info = result
                truncated = False
            else:
                obs, reward, done, truncated, info = result

            if isinstance(reward, np.ndarray):
                reward = reward.item()

            total_reward += reward
            episode_steps += 1

            # Debugging: Check the type of info
            print(f"Type of info: {type(info)}")
            print(f"Info content: {info}")

            # Print step information
            print(f"\nEpisode {episode + 1}, Step {episode_steps}")
            print(f"Action: {action}")
            print(f"Reward: {reward:.2f}")
            print(f"Security Score: {info['security_score']:.2f}")
            print(f"Fatigue Score: {info['fatigue_score']:.2f}")
            print("Feature Values:")
            for name, value in info['feature_values'].items():
                print(f"  {name}: {value}")

        print(f"\nEpisode {episode + 1} completed:")
        print(f"Total Steps: {episode_steps}")
        print(f"Total Reward: {total_reward:.2f}")
        print("-" * 50) 

def plot_training_results(log_dir="./logs"):
    """Plot the training results."""
    import pandas as pd

    monitor_file = os.path.join(log_dir, "monitor.csv")
    if not os.path.exists(monitor_file):
        print(f"No log file found at {monitor_file}. Skipping plotting.")
        return

    # Read the training logs
    df = pd.read_csv(monitor_file, skiprows=1)

    # Plot episode rewards
    plt.figure(figsize=(12, 5))
    plt.plot(df['r'], label="Episode Reward")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Learning Curve")
    plt.legend()
    plt.savefig("logs/learning_curve.png")
    plt.close()

    # Plot moving average reward
    window_size = 10
    moving_avg = df['r'].rolling(window=window_size).mean()

    plt.figure(figsize=(12, 5))
    plt.plot(moving_avg, label=f"Moving Average Reward (window={window_size})", color='orange')
    plt.xlabel("Episode")
    plt.ylabel("Average Reward")
    plt.title("Moving Average of Episode Rewards")
    plt.legend()
    plt.savefig("logs/moving_average.png")
    plt.close()


def create_eval_env(render_mode=None):
    """Create a raw Gymnasium environment for evaluation."""
    env = SecurityEnv(render_mode=render_mode)
    return env

if __name__ == "__main__":
    # Train the agent
    model, env = train_agent()

    # Evaluate the agent using the raw environment to force Gymnasium API
    eval_env = create_eval_env(render_mode="human")
    evaluate_agent(model, eval_env)

    # Plot training results
    plot_training_results()

Starting training process...


Output()


Training completed!

Evaluating agent...
Type of info: <class 'dict'>

Episode 1, Step 1
Action: [0 0 0 0 3 2 1 4 3 2 4 1 2 0 1 0 0 1 1 0 1 0 1]
Reward: 3.62
Security Score: 14.00
Fatigue Score: 8.92
Feature Values:
  Level of familiarity with cybersecurity practices: 0.0
  Frequency of Password Changes: 1.0
  Difficulty Level: 1.0
  Effort Required: 1.0
  Perceived Importance: 4.0
  Frequency of MFA prompts: 3.0
  Difficulty Level MFA: 2.0
  Effort Required MFA: 5.0
  Perceived Importance of MFA: 4.0
  Which types of MFA do you encounter most often? (Select all that apply)_Authentication app (e.g., Google Authenticator, Microsoft Authenticator): 0.0
  Which types of MFA do you encounter most often? (Select all that apply)_Biometric verification (fingerprint, facial recognition): 1.0
  Which types of MFA do you encounter most often? (Select all that apply)_I do not use MFA: 0.0
  Which types of MFA do you encounter most often? (Select all that apply)_One-time passwords (OTP) via SMS o