In [None]:
from CybORG.Agents import *
from pprint import pprint
from ipaddress import *
from CybORG import CybORG
from CybORG.Simulator.Scenarios import EnterpriseScenarioGenerator
from CybORG.Simulator.Actions import *
from CybORG.Agents.Wrappers import *

class CAGE4CustomRed(FiniteStateRedAgent):
    """Red agent built on ``FiniteStateRedAgent`` with custom weights.

    Overrides the host-state priority table and the per-state action
    probability matrix, and switches on the three output/priority flags
    set in ``__init__``.
    """

    def __init__(self, name=None, np_random=None, agent_subnets=None):
        """Forward all arguments to the base class, then set the flags."""
        super().__init__(name=name, np_random=np_random, agent_subnets=agent_subnets)
        # Presumably read by FiniteStateRedAgent -- confirm against the
        # base class.
        self.print_action_output = True
        self.print_obs_output    = True
        self.prioritise_servers  = True

    def last_turn_summary(self, observation, action, success):
        """Delegate to the base summary unless ``success`` is ``None``."""
        if success is not None:
            super().last_turn_summary(observation, action, success)

    def set_host_state_priority_list(self):
        """Return the weight assigned to each host-state code.

        The eight values add up to 100.  The meaning of the state codes is
        defined by ``FiniteStateRedAgent`` (not visible here).
        """
        return {
            'K': 28,
            'KD': 23,
            'S': 18,
            'SD': 14,
            'U': 9,
            'UD': 4,
            'R': 2,
            'RD': 2,
        }

    def state_transitions_probability(self):
        """Return one nine-entry weight row per host-state code.

        ``None`` marks a slot that carries no weight at all; the remaining
        entries of each row add up to 1.0.  The column order is defined by
        ``FiniteStateRedAgent`` -- confirm there before editing a row.
        """
        NA = None  # slot without a weight
        transitions = {
            'K':  [0.3, 0.3, 0.2, NA, NA, NA, 0.2, NA, NA],
            'KD': [NA, 0.6, 0.4, NA, NA, NA, NA, NA, NA],
            'S':  [0.2, NA, NA, 0.2, 0.5, NA, NA, 0.1, NA],
            'SD': [NA, NA, NA, 0.7, 0.2, NA, NA, 0.1, NA],
            'U':  [0.1, NA, NA, NA, NA, 0.7, NA, 0.2, 0.0],
            'UD': [NA, NA, NA, NA, NA, 1.0, NA, NA, 0.0],
            'R':  [0.3, NA, NA, NA, NA, NA, 0.5, 0.2, 0.0],
            'RD': [NA, NA, NA, NA, NA, NA, 0.5, 0.5, 0.0],
            'F':  [1.0, NA, NA, NA, NA, NA, NA, NA, 0.0],
        }
        return transitions

from CybORG.Agents.SimpleAgents.BaseAgent import BaseAgent
from CybORG.Agents.SimpleAgents.BlueReactAgent import *
from CybORG.Shared import Results
from CybORG.Simulator.Actions import Monitor, Remove, Restore, Analyse, DeployDecoy
from CybORG.Simulator.Actions.ConcreteActions import ControlTraffic
from CybORG.Simulator.Actions import Action

from inspect import signature
from typing import Dict, Any
from ipaddress import IPv4Address

from CybORG.Agents.SimpleAgents.BaseAgent import BaseAgent
from CybORG.Shared import Results
from CybORG.Shared.Enums import TernaryEnum
from CybORG.Simulator.Actions import Monitor, Analyse, Remove, Restore, DeployDecoy, Sleep
from CybORG.Simulator.Actions.ConcreteActions.ControlTraffic import AllowTrafficZone, BlockTrafficZone

class BlueDefenderAgent(BaseAgent):
    """Rule-based blue agent that moves flagged hosts through a fixed pipeline.

    Hosts are identified by the string form of their first interface IP.
    Each turn the agent updates ``host_states`` from the observation, queues
    every host flagged as compromised for ``Analyse`` and then issues the
    first pending action in this priority order: Restore, DeployDecoy,
    BlockTrafficZone, Remove, Analyse.  With nothing pending it issues
    ``Monitor``.  Follow-up work is scheduled from the previous turn's
    action: Analyse -> Remove, Remove -> Restore and DeployDecoy,
    Restore -> BlockTrafficZone.
    """

    def __init__(self, name='Blue'):
        super().__init__(name)
        # Every queue/set below holds host IP strings (keys of host_states).
        self.host_list        = []     # analysed hosts waiting for Remove
        self.analyse_queue    = []     # flagged hosts waiting for Analyse
        self.restore_queue    = []     # hosts waiting for Restore
        self.decoy_queue      = []     # hosts waiting for DeployDecoy
        self.control_queue    = []     # hosts waiting for BlockTrafficZone
        self.decoyed_hosts    = set()  # hosts a decoy was already requested for
        self.controlled_hosts = set()  # hosts a traffic block was already requested for
        self.last_action      = None   # label of the action issued last turn
        self.last_host        = None   # IP targeted last turn (None for Monitor)
        # ip -> {'hostname': hostname or None, 'compromised': bool}
        self.host_states: Dict[str, Dict[str, Any]] = {}
        # action class -> constructor parameters (see set_initial_values)
        self.action_params: Dict[Any, Any] = {}

    def set_initial_values(self, action_space, observation):
        """Cache constructor signatures of the available actions and ingest
        the first observation.

        Signatures are only cached when ``action_space`` is a dict with an
        ``'action'`` entry; otherwise they are resolved lazily in
        ``_choose_action``.
        """
        if isinstance(action_space, dict) and 'action' in action_space:
            self.action_params = {
                cls: signature(cls).parameters
                for cls in action_space['action'].keys()
            }
        self._process_new_observations(observation)

    def train(self, results: Results):
        """No-op: this agent does not learn."""
        pass

    def end_episode(self):
        """Drop all per-episode state by re-running ``__init__`` with the same name."""
        self.__init__(name=self.name)

    @staticmethod
    def _host_ip(info):
        """Return the first interface IP of an observation entry as a string.

        Returns ``None`` when ``info`` is not a dict, has no interfaces, or
        the first interface carries no ``'ip_address'``.
        """
        if not isinstance(info, dict):
            return None
        interfaces = info.get('Interface', [])
        if not interfaces:
            return None
        address = interfaces[0].get('ip_address')
        return None if address is None else str(address)

    def _process_new_observations(self, observation: dict):
        """Register hosts seen in ``observation`` and flag compromised ones.

        A host is flagged when one of its sessions belongs to an agent whose
        name starts with ``'red'`` or, failing that, when any listed process
        carries a ``'PID'`` key.  Entries without a usable IP are skipped.
        """
        for key, info in observation.items():
            if key == 'success' or not isinstance(info, dict):
                continue
            ip = self._host_ip(info)
            if ip is None:
                continue
            hostname = info.get('System info', {}).get('Hostname')
            state = self.host_states.setdefault(
                ip, {'hostname': hostname, 'compromised': False}
            )
            # Backfill a hostname that was missing when the host was first seen.
            if state['hostname'] is None and hostname is not None:
                state['hostname'] = hostname
            if any(sess.get('agent', '').startswith('red')
                   for sess in info.get('Sessions', [])):
                state['compromised'] = True
            elif 'Processes' in info and any('PID' in p for p in info['Processes']):
                state['compromised'] = True

    def _clean_queues(self, observation: dict):
        """Drop queued and tracked hosts that are absent from ``observation``.

        Visibility is computed on IP strings because that is what the queues
        and sets store; comparing against hostnames would empty every queue
        on every turn.
        """
        visible = set()
        for key, info in observation.items():
            if key == 'success':
                continue
            ip = self._host_ip(info)
            if ip is not None:
                visible.add(ip)
        for queue in (self.host_list, self.analyse_queue,
                      self.restore_queue, self.decoy_queue,
                      self.control_queue):
            queue[:] = [ip for ip in queue if ip in visible]
        self.decoyed_hosts    &= visible
        self.controlled_hosts &= visible

    def _choose_action(self, action_cls, hostname, action_space):
        """Instantiate ``action_cls`` for the host keyed by ``hostname``.

        ``hostname`` is the host's IP string (a key of ``host_states``), or
        ``None`` for actions that take no host.  Traffic-zone actions use the
        first two enabled subnets of the action space and fall back to
        ``Monitor`` when none is enabled.  For every other action each
        constructor parameter is filled from agent state or from the enabled
        options in ``action_space``; ``Sleep()`` is returned when a parameter
        cannot be filled.
        """
        def pick_session():
            # First advertised session; 0 when none is listed (assumed to be
            # the default session id -- confirm against the scenario).
            return next(iter(action_space.get('session', {})), 0)

        if action_cls in (BlockTrafficZone, AllowTrafficZone):
            subs = [s for s, ok in action_space.get('subnet', {}).items() if ok]
            if not subs:
                return Monitor(agent=self.name, session=pick_session())
            from_s = subs[0]
            to_s = subs[1] if len(subs) >= 2 else subs[0]
            return action_cls(
                agent=self.name,
                session=pick_session(),
                from_subnet=from_s,
                to_subnet=to_s
            )

        # Fall back to inspecting the class when nothing was cached for it
        # (e.g. after end_episode wiped action_params).
        param_names = self.action_params.get(action_cls)
        if param_names is None:
            param_names = signature(action_cls).parameters

        params = {}
        for name in param_names:
            if name == 'hostname':
                known = self.host_states.get(hostname, {}).get('hostname')
                if known is None:
                    return Sleep()
                params[name] = known
            elif name in ('ip address', 'ip_address'):
                try:
                    params[name] = IPv4Address(hostname)
                except ValueError:
                    # Not a parseable IPv4 address (includes None).
                    return Sleep()
            elif name == 'agent':
                params[name] = self.name
            elif name == 'session':
                params[name] = pick_session()
            else:
                opts = [v for v, ok in action_space.get(name, {}).items() if ok]
                if not opts:
                    return Sleep()
                # np_random is expected to be provided by BaseAgent -- confirm.
                params[name] = self.np_random.choice(opts)

        return action_cls(**params)

    def _schedule_follow_ups(self, observation: dict):
        """Queue the work implied by the previous turn's action."""
        if self.last_action == 'Monitor':
            for key, info in observation.items():
                if key == 'success' or not isinstance(info, dict):
                    continue
                ip = self._host_ip(info)
                host = info.get('System info', {}).get('Hostname')
                if ip is None or not host or host == 'User0':
                    continue
                if ip not in self.analyse_queue and self.host_states[ip]['compromised']:
                    self.analyse_queue.append(ip)

        if not self.last_host:
            return

        if self.last_action == 'Analyse':
            if self.host_states[self.last_host]['compromised']:
                self.host_list.append(self.last_host)
        elif self.last_action == 'Remove':
            self.restore_queue.append(self.last_host)
            if self.last_host not in self.decoyed_hosts:
                self.decoy_queue.append(self.last_host)
        elif self.last_action == 'Restore':
            self.control_queue.append(self.last_host)

    def get_action(self, observation: dict, action_space: dict):
        """Return the next action for this turn.

        Returns ``Sleep()`` while the previous action is still reported as
        ``IN_PROGRESS``; otherwise the first pending pipeline action, or
        ``Monitor`` when every queue is empty.  ``observation`` is not
        modified.
        """
        print(f"[BlueDefenderAgent] get_action called. observation type: {type(observation)}")
        print(f"observation: {observation}")
        # Work on a shallow copy so stripping the metadata keys does not
        # alter the caller's dict.
        observation = dict(observation)
        success = observation.pop('success', None)
        observation.pop('action', None)

        if self.last_action is not None and isinstance(success, TernaryEnum) and success.name == 'IN_PROGRESS':
            return Sleep()

        # A host whose Restore succeeded starts clean again; without this the
        # flag is never lowered and the host cycles through the pipeline
        # forever.  Done before ingesting the new observation so fresh
        # evidence can re-flag it.  (Assumes TernaryEnum has a TRUE member;
        # if not, the flag simply stays set.)
        if (self.last_action == 'Restore'
                and self.last_host in self.host_states
                and getattr(success, 'name', None) == 'TRUE'):
            self.host_states[self.last_host]['compromised'] = False

        self._clean_queues(observation)
        self._process_new_observations(observation)
        for ip, st in self.host_states.items():
            if st['compromised'] and ip not in self.analyse_queue and ip not in self.host_list:
                self.analyse_queue.append(ip)

        self._schedule_follow_ups(observation)

        # (queue, action class, label stored in last_action, set recording
        # the hosts the action was requested for) in priority order.
        pipeline = (
            (self.restore_queue, Restore, 'Restore', None),
            (self.decoy_queue, DeployDecoy, 'DeployDecoy', self.decoyed_hosts),
            (self.control_queue, BlockTrafficZone, 'BlockTrafficZone', self.controlled_hosts),
            (self.host_list, Remove, 'Remove', None),
            (self.analyse_queue, Analyse, 'Analyse', None),
        )
        for queue, action_cls, label, requested in pipeline:
            if not queue:
                continue
            ip = queue.pop(0)
            act = self._choose_action(action_cls, ip, action_space)
            if requested is not None:
                requested.add(ip)
            self.last_action, self.last_host = label, ip
            return act

        act = self._choose_action(Monitor, None, action_space)
        self.last_action, self.last_host = 'Monitor', None
        return act


In [None]:
def get_current_red_agents(observation):
    """Collect the distinct red agent names that own a session in ``observation``.

    Only dict-valued entries that contain a ``"Sessions"`` key are inspected;
    a session counts when its ``"agent"`` value contains ``"red_agent"``.
    The result is a list in no particular order.
    """
    names = set()
    for entry in observation.values():
        if not isinstance(entry, dict) or "Sessions" not in entry:
            continue
        owners = (sess.get("agent", "") for sess in entry["Sessions"])
        names.update(owner for owner in owners if "red_agent" in owner)
    return [*names]

def print_active_red_agents(cyborg):
    """Print, per agent in ``cyborg.agents``, the red agents its observation shows.

    Agents whose current observation contains no red session print nothing.
    """
    for agent in cyborg.agents:
        active = get_current_red_agents(cyborg.get_observation(agent=agent))
        if not active:
            continue
        print("Active red agents in the environment:", active)

In [None]:
# Build the scenario with the custom red and blue agents defined above.
sg = EnterpriseScenarioGenerator(
    blue_agent_class=BlueDefenderAgent,
    green_agent_class=EnterpriseGreenAgent,
    red_agent_class=CAGE4CustomRed,
    steps=200,
)
cyborg = CybORG(scenario_generator=sg, seed=1000)
red_agent_name = 'red_agent_0'
blue_agent_name = 'blue_agent_0'

# Reset exactly once.  A second reset would start a new episode and leave the
# red observation/action space captured below describing a discarded one.
reset = cyborg.reset(agent=blue_agent_name)
action_space_blue = cyborg.get_action_space(agent=blue_agent_name)
initial_obs_blue = reset.observation

# Read the red agent's view of the same episode without resetting again.
action_space = cyborg.get_action_space(agent=red_agent_name)
initial_obs = cyborg.get_observation(agent=red_agent_name)

#pprint(initial_obs)
#pprint(initial_obs_blue)
#pprint(action_space_blue)

# Advance the simulation 100 steps; every agent acts through its own class.
for _ in range(100):
    cyborg.parallel_step()