In [None]:
from cyberbattle.simulation.model import *
from cyberbattle.simulation.model import Precondition
import networkx as nx
import sys, logging
import cyberbattle.simulation.commandcontrol as commandcontrol
import cyberbattle.samples.toyctf.toy_ctf as ctf
import plotly.offline as plo
import random
import numpy as np
import torch

SEED = 100

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

plo.init_notebook_mode(connected=True)
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(levelname)s: %(message)s")
%matplotlib inline

In [None]:
# --- ws_entry: initial foothold ---
# Local exploit with an always-true precondition; reveals web1 to the attacker.
vuln_admin = VulnerabilityInfo(
    description="Weak Admin Password",
    type=VulnerabilityType.LOCAL,
    precondition=Precondition("true"),
    outcome=LeakedNodesId(["web1"]),
    reward_string="Found web1 host",
    cost=1.0,
)

# --- web1: HTTP service ---
# Discovery of app1, then its SSH admin credential.
vuln_web = VulnerabilityInfo(
    description="Web exploit",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedNodesId(["app1"]),
    reward_string="App server discovered",
    cost=1.0,
)
vuln_web_creds = VulnerabilityInfo(
    description="Web admin creds in config",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="app1", port="SSH", credential="admin:supersecret"),
    ]),
    reward_string="App server admin credentials leaked",
    cost=1.2,
)

# --- web1: GIT service ---
# Discovery of file1, then its SMB root credential.
vuln_git = VulnerabilityInfo(
    description="Git history leak",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:git"),
    outcome=LeakedNodesId(["file1"]),
    reward_string="File server discovered",
    cost=1.0,
)
vuln_git_creds = VulnerabilityInfo(
    description="Git leak of SMB root password",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:git"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="file1", port="SMB", credential="root:toor"),
    ]),
    reward_string="File server root password found",
    cost=1.2,
)

# --- file1 (also attached to ws4): SMB share ---
# Discovery of backup1, then its SMB root credential.
vuln_smb = VulnerabilityInfo(
    description="SMB Weak Share",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smb"),
    outcome=LeakedNodesId(["backup1"]),
    reward_string="Backup server discovered",
    cost=1.2,
)
vuln_smb_creds = VulnerabilityInfo(
    description="SMB share creds leak",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smb"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="backup1", port="SMB", credential="root:toor"),
    ]),
    reward_string="Backup server SMB root password leaked",
    cost=1.4,
)

# --- backup1: SMB share ---
# Discovery of sql1, then the SQL database credential.
vuln_backup = VulnerabilityInfo(
    description="Backup Exploit",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smb"),
    outcome=LeakedNodesId(["sql1"]),
    reward_string="SQL server found via backup",
    cost=1.3,
)
vuln_backup_creds = VulnerabilityInfo(
    description="Backup reveals SQL creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smb"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="sql1", port="SQL", credential="dbuser:sqldbpw"),
    ]),
    reward_string="SQL DB credentials extracted from backup",
    cost=1.5,
)

# --- sql1: SQL service ---
# Discovery of dc1 and web3, then the domain controller's RDP credential.
vuln_sql = VulnerabilityInfo(
    description="SQL Injection",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:sql"),
    outcome=LeakedNodesId(["dc1", "web3"]),
    reward_string="Domain controller found",
    cost=2.0,
)
vuln_sql_creds = VulnerabilityInfo(
    description="SQLi leaks DC creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:sql"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="dc1", port="RDP", credential="admin:supersecret"),
    ]),
    reward_string="Domain controller admin creds leaked",
    cost=2.2,
)

# --- dc1 (also attached to ws2): RDP service ---
# Discovery of mgmt1, then its SSH credential.
vuln_rdp = VulnerabilityInfo(
    description="RDP Brute Force",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:rdp"),
    outcome=LeakedNodesId(["mgmt1"]),
    reward_string="Management server discovered",
    cost=2.0,
)
vuln_rdp_creds = VulnerabilityInfo(
    description="RDP brute-forced creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:rdp"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="mgmt1", port="SSH", credential="admin:supersecret"),
    ]),
    reward_string="Management SSH creds leaked",
    cost=2.2,
)

# --- mgmt1: SSH management panel ---
# Discovery of the core services (web2, mail1, app2) and one credential for each.
# NOTE(review): both are gated on "service:ssh", a property mgmt1 does not
# declare in `nodes`, so as written they can never succeed — confirm.
vuln_mgmt = VulnerabilityInfo(
    description="Mgmt panel pivot",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:ssh"),
    outcome=LeakedNodesId(["web2", "mail1", "app2"]),
    reward_string="Core services discovered via mgmt",
    cost=2.0,
)
vuln_mgmt_creds = VulnerabilityInfo(
    description="Mgmt panel leaks creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:ssh"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="web2", port="HTTP", credential="admin:supersecret"),
        CachedCredential(node="mail1", port="SMTP", credential="mailadmin:mailpw"),
        CachedCredential(node="app2", port="SSH", credential="root:toor"),
    ]),
    reward_string="Core services creds leaked via mgmt",
    cost=2.2,
)

# --- mail1: SMTP service ---
# Leaks the locations and credentials of sql1 and app1.
vuln_mail = VulnerabilityInfo(
    description="Mail leak",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smtp"),
    outcome=LeakedNodesId(["sql1", "app1"]),
    reward_string="Mail system leaks database/app paths",
    cost=1.0,
)
vuln_mail_creds = VulnerabilityInfo(
    description="Mail leaks DB/app creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:smtp"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="sql1", port="SQL", credential="dbuser:sqldbpw"),
        CachedCredential(node="app1", port="SSH", credential="admin:supersecret"),
    ]),
    reward_string="Mail system leaks DB/app credentials",
    cost=1.2,
)

# --- web2: HTTP service ---
# Discovery of app2, then its SSH root credential.
vuln_web2 = VulnerabilityInfo(
    description="Web2 RCE",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedNodesId(["app2"]),
    reward_string="Remote code execution on web2",
    cost=1.3,
)
vuln_web2_creds = VulnerabilityInfo(
    description="Web2 config leaks creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="app2", port="SSH", credential="root:toor"),
    ]),
    reward_string="Leaked app2 SSH credentials from web2",
    cost=1.5,
)

# --- web3: HTTP service ---
# Discovery of mail1, then its SMTP credential.
vuln_web3 = VulnerabilityInfo(
    description="Web3 SSRF",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedNodesId(["mail1"]),
    reward_string="SSRF exploited web3, mail1 discovered",
    cost=1.4,
)
vuln_web3_creds = VulnerabilityInfo(
    description="Web3 backup credentials leak",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:http"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="mail1", port="SMTP", credential="mailadmin:mailpw"),
    ]),
    reward_string="Leaked mail1 SMTP credentials from web3",
    cost=1.6,
)

# --- app1 / app2: SSH services ---
# NOTE(review): every vulnerability below is gated on "service:ssh", which
# neither app1 nor app2 lists in its properties in `nodes`; as written they can
# never succeed — confirm whether the property is missing from the nodes.

# app1: local escalation that reveals sql1, plus the SQL credential.
vuln_app1 = VulnerabilityInfo(
    description="App1 PrivEsc",
    type=VulnerabilityType.LOCAL,
    precondition=Precondition("service:ssh"),
    outcome=LeakedNodesId(["sql1"]),
    reward_string="Privilege escalation on app1 led to SQL1",
    cost=1.5,
)
vuln_app1_creds = VulnerabilityInfo(
    description="App1 database creds",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:ssh"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="sql1", port="SQL", credential="dbuser:sqldbpw"),
    ]),
    reward_string="Leaked SQL db creds from app1",
    cost=1.7,
)

# app2: API exploit that reveals file1, plus its SMB credential.
vuln_app2 = VulnerabilityInfo(
    description="App2 API exploit",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:ssh"),
    outcome=LeakedNodesId(["file1"]),
    reward_string="API exploit on app2 revealed file server",
    cost=1.3,
)
vuln_app2_creds = VulnerabilityInfo(
    description="App2 file creds leak",
    type=VulnerabilityType.REMOTE,
    precondition=Precondition("service:ssh"),
    outcome=LeakedCredentials(credentials=[
        CachedCredential(node="file1", port="SMB", credential="root:toor"),
    ]),
    reward_string="Leaked file server SMB creds from app2",
    cost=1.5,
)

# --- Extra vulnerabilities (used by app2, sql1, ws3 and web2 in `nodes`) ---

# app2: local privilege escalation, available on any Linux node it is attached to.
vuln_app2_privesc = VulnerabilityInfo(
    description="SUID binary privilege escalation",
    type=VulnerabilityType.LOCAL,
    outcome=AdminEscalation(),
    precondition=Precondition("os:Linux"),
    cost=1.7,
    # Fixed: this vulnerability is attached to app2 ("App2PrivEsc"), not app1.
    reward_string="Local admin gained on app2"
)
# sql1: data exfiltration outcome (no node or credential leak).
vuln_sql_databreach = VulnerabilityInfo(
    description="Exfiltrate customer DB data",
    type=VulnerabilityType.REMOTE,
    outcome=CustomerData(),
    precondition=Precondition("service:sql"),
    cost=2.5,
    reward_string="Customer database exfiltrated"
)
# ws3: lateral movement outcome.
# NOTE(review): gated on "service:ssh", which ws3 does not list in its
# properties in `nodes`, so it cannot currently succeed — confirm.
vuln_ws3_lateral = VulnerabilityInfo(
    description="Lateral move via credential reuse",
    type=VulnerabilityType.REMOTE,
    outcome=LateralMove(),
    precondition=Precondition("service:ssh"),
    cost=1.3,
    reward_string="Lateral movement succeeded"
)
# web2: cheap probe that reveals the OS property of the target.
vuln_web2_probe = VulnerabilityInfo(
    description="HTTP banner probe",
    type=VulnerabilityType.REMOTE,
    outcome=ProbeSucceeded(discovered_properties=["os:Windows"]),
    precondition=Precondition("service:http"),
    cost=0.5,
    reward_string="Discovered OS of web2"
)




In [None]:
ALLOW = RulePermission.ALLOW

# Permissive baseline firewall: one ALLOW rule per port, in a fixed order.
# Most nodes below share this list for both incoming and outgoing traffic;
# sql1, dc1 and mgmt1 supply a narrower incoming list of their own.
default_allow_rules = [
    FirewallRule(port_name, ALLOW)
    for port_name in ("RDP", "SSH", "HTTPS", "HTTP", "SMB", "GIT", "SMTP", "SQL")
]

# Network topology: one NodeInfo per host, keyed by node id. Every node except
# sql1, dc1 and mgmt1 uses the permissive `default_allow_rules` firewall.
#
# NOTE(review): as the vulnerabilities are currently written, some nodes cannot
# be reached or owned — confirm whether this is intentional:
#   * ws2, ws3 and ws4 are never named by any LeakedNodesId outcome, so the
#     attacker has no way to discover them;
#   * ws3, app1, app2 and mgmt1 carry vulnerabilities gated on
#     Precondition("service:ssh"), but none of them lists "service:ssh" in its
#     properties. Adding it also requires adding "service:ssh" to
#     MY_IDENTIFIERS.properties (CyberBattleEnv appears to reject node
#     properties that are not declared in the identifiers);
#   * no vulnerability leaks a credential for web1 or web3.
nodes = {
    # ENTRY NODE
    # Attacker foothold: agent pre-installed and excluded from reimaging.
    "ws_entry": NodeInfo(
        services=[],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=0,
        properties=["os:Linux", "role:entry", "user:user"],
        vulnerabilities={
            "WeakAdmin": vuln_admin,
        },
        agent_installed=True,
        reimagable=False,
        owned_string="Accesso iniziale alla rete stabilito. Utente con privilegi limitati."
    ),

    # WORKSTATIONS
    "ws2": NodeInfo(
        services=[ListeningService("RDP", allowedCredentials=["admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=1.0,
        properties=["os:Windows", "role:workstation", "service:rdp"],
        vulnerabilities={
            "RDPBrute": vuln_rdp,
            "RDPBruteCreds": vuln_rdp_creds,
        },
        owned_string="Accesso RDP ottenuto. Nessun privilegio elevato disponibile."
    ),
    "ws3": NodeInfo(
        services=[ListeningService("SSH", allowedCredentials=["user:password1"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=1.0,
        properties=["os:Linux", "role:workstation", "user:user"],
        vulnerabilities={
            "WeakAdmin": vuln_admin,
            "LateralMov": vuln_ws3_lateral
        },
        owned_string="Compromesso account SSH. Possibilità di movimento laterale."
    ),
    "ws4": NodeInfo(
        services=[ListeningService("SMB", allowedCredentials=["root:toor"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=1.0,
        properties=["os:Windows", "role:workstation", "service:smb"],
        vulnerabilities={
            "SMBWeakShare": vuln_smb,
            "SMBCreds": vuln_smb_creds,
        },
        owned_string="Accesso SMB. Credenziali archiviate potenzialmente recuperabili."
    ),

    # WEB SERVERS
    # NOTE(review): HTTP and GIT are declared without allowedCredentials, so
    # web1 presumably cannot be owned by connecting to it — confirm.
    "web1": NodeInfo(
        services=[ListeningService("HTTP"), ListeningService("GIT")],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=2.0,
        properties=["os:Linux", "role:webserver", "service:http", "service:git"],
        vulnerabilities={
            "WebExploit": vuln_web,
            "WebAdminCreds": vuln_web_creds,
            "GitHistory": vuln_git,
            "GitSMBCreds": vuln_git_creds,
        },
        owned_string="Accesso HTTP e GIT acquisito. Storico commit e credenziali potenzialmente esposte."
    ),
    "web2": NodeInfo(
        # Fixed: "admin:supersecret" is the credential MgmtCreds leaks for
        # web2:HTTP; without it in the allow-list that leak was unusable.
        services=[ListeningService("HTTP", allowedCredentials=["user:password1", "admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=2.0,
        properties=["os:Windows", "role:webserver", "service:http"],
        vulnerabilities={
            "Web2RCE": vuln_web2,
            "Web2ConfigLeak": vuln_web2_creds,
            "Web2Probing": vuln_web2_probe
        },
        owned_string="Accesso HTTP con credenziali deboli. Possibile RCE rilevato."
    ),
    "web3": NodeInfo(
        services=[ListeningService("HTTP", allowedCredentials=["admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=2.0,
        properties=["os:Windows", "role:webserver", "service:http"],
        vulnerabilities={
            "Web3SSRF": vuln_web3,
            "Web3BackupCreds": vuln_web3_creds
        },
        owned_string="HTTP compromesso. Backup accessibile, credenziali scoperte."
    ),

    # APP SERVERS
    "app1": NodeInfo(
        services=[ListeningService("SSH", allowedCredentials=["admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=3.0,
        properties=["os:Linux", "role:appserver", "user:admin"],
        vulnerabilities={
            "App1PrivEsc": vuln_app1,
            "App1DBCreds": vuln_app1_creds
        },
        owned_string="Accesso root SSH. Escalation amministrativa completata."
    ),
    "app2": NodeInfo(
        services=[ListeningService("SSH", allowedCredentials=["root:toor"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=3.0,
        properties=["os:Linux", "role:appserver", "user:admin"],
        vulnerabilities={
            "App2APIExploit": vuln_app2,
            "App2FileCreds": vuln_app2_creds,
            "App2PrivEsc": vuln_app2_privesc
        },
        owned_string="Privilegi SYSTEM acquisiti. Backdoor persistente ipoteticamente installata."
    ),

    # FILE SERVER
    "file1": NodeInfo(
        services=[ListeningService("SMB", allowedCredentials=["root:toor"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=3.0,
        properties=["os:Windows", "role:fileserver", "service:smb"],
        vulnerabilities={
            "SMBWeakShare": vuln_smb,
            "SMBCreds": vuln_smb_creds,
        },
        owned_string="Accesso SMB come root. Dati di rete e backup accessibili."
    ),

    # MAIL SERVER
    "mail1": NodeInfo(
        # Fixed: "mailadmin:mailpw" is leaked for mail1:SMTP by MgmtCreds and
        # Web3BackupCreds, but the service was declared with no allowedCredentials.
        services=[ListeningService("SMTP", allowedCredentials=["mailadmin:mailpw"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=2.0,
        properties=["os:Linux", "role:mailserver", "service:smtp"],
        vulnerabilities={
            "MailLeak": vuln_mail,
            "MailLeakCreds": vuln_mail_creds,
        },
        owned_string="Accesso SMTP. Comunicazioni interne potenzialmente intercettate."
    ),

    # SQL SERVER
    # Incoming traffic restricted to the SQL port only.
    "sql1": NodeInfo(
        services=[ListeningService("SQL", allowedCredentials=["dbuser:sqldbpw", "admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=[FirewallRule("SQL", ALLOW)],
            outgoing=default_allow_rules
        ),
        value=7.0,
        properties=["os:Windows", "role:sqlserver", "service:sql"],
        vulnerabilities={
            "SQLi": vuln_sql,
            "SQLiCreds": vuln_sql_creds,
            "SQLiDataBreach": vuln_sql_databreach
        },
        owned_string="Credenziali DB usate per accedere al SQL Server. Dati sensibili estratti."
    ),

    # DOMAIN CONTROLLER
    # Incoming traffic restricted to the RDP port only.
    "dc1": NodeInfo(
        services=[ListeningService("RDP", allowedCredentials=["admin:supersecret", "shared:sharedpw"])],
        firewall=FirewallConfiguration(
            incoming=[FirewallRule("RDP", ALLOW)],
            outgoing=default_allow_rules
        ),
        value=10.0,
        properties=["os:Windows", "role:domaincontroller", "service:rdp", "user:admin"],
        vulnerabilities={
            "RDPBrute": vuln_rdp,
            "RDPBruteCreds": vuln_rdp_creds,
            "WeakAdmin": vuln_admin
        },
        owned_string="Accesso Domain Controller con privilegio SYSTEM. Controllo totale del dominio acquisito."
    ),

    # BACKUP SERVER
    "backup1": NodeInfo(
        services=[ListeningService("SMB", allowedCredentials=["root:toor"])],
        firewall=FirewallConfiguration(
            incoming=default_allow_rules,
            outgoing=default_allow_rules
        ),
        value=3.0,
        properties=["os:Windows", "role:backupserver", "service:smb"],
        vulnerabilities={
            "BackupExploit": vuln_backup,
            "BackupSQLCreds": vuln_backup_creds,
        },
        owned_string="Backup SMB compromesso. Dump SQL e configurazioni esposte."
    ),

    # MANAGEMENT SERVER
    # Incoming traffic restricted to the SSH port only.
    "mgmt1": NodeInfo(
        services=[ListeningService("SSH", allowedCredentials=["admin:supersecret"])],
        firewall=FirewallConfiguration(
            incoming=[FirewallRule("SSH", ALLOW)],
            outgoing=default_allow_rules
        ),
        value=10.0,
        properties=["os:Linux", "role:mgmt", "user:admin"],
        vulnerabilities={
            "MgmtExploit": vuln_mgmt,
            "MgmtCreds": vuln_mgmt_creds,
        },
        owned_string="Accesso root ottenuto su server di gestione. Presa completa sul network."
    ),
}


import boolean  # boolean.py: Precondition.expression is a boolean.Expression


def _precondition_symbols(vuln):
    """Return the distinct symbol names (e.g. "service:http") in a vulnerability's precondition.

    Uses ``Expression.get_symbols()`` so compound preconditions ("a & b", "a | b")
    are inspected too, not only single-symbol ones. Constant preconditions such as
    "true" are expected to contribute no symbols.
    """
    return sorted({str(sym.obj) for sym in vuln.precondition.expression.get_symbols()})


def check_network_coherence(node_map):
    """Return a list of "INCOHERENCY: ..." messages for a ``{node_id: NodeInfo}`` dict.

    Reported problems:
      * a ``service:<port>`` precondition whose port the node's incoming firewall
        does not allow (the original check, same message);
      * a precondition symbol missing from the node's ``properties`` — CyberBattleSim
        evaluates preconditions against node properties, so such a vulnerability
        can never be exploited on that node;
      * a leaked credential pointing at an unknown node, or that no listening
        service on its target node (matching port) accepts, making the leak useless.

    Only vulnerabilities attached to nodes are checked, not the environment's
    global vulnerability library.
    """
    messages = []
    for node_name, node in node_map.items():
        allowed_ports = {rule.port.lower() for rule in node.firewall.incoming}
        for vuln_name, vuln in node.vulnerabilities.items():
            for symbol in _precondition_symbols(vuln):
                if symbol.lower().startswith("service:"):
                    port = symbol.split(":", 1)[1].lower()
                    if port not in allowed_ports:
                        messages.append(f"INCOHERENCY: Node {node_name} has vulnerability '{vuln_name}' requiring port '{port}', but firewall does not allow it in!")
                if symbol not in node.properties:
                    messages.append(f"INCOHERENCY: Node {node_name} has vulnerability '{vuln_name}' requiring property '{symbol}', but the node does not declare it, so it can never be exploited!")

            if isinstance(vuln.outcome, LeakedCredentials):
                for cred in vuln.outcome.credentials:
                    target = node_map.get(cred.node)
                    if target is None:
                        messages.append(f"INCOHERENCY: Node {node_name} has vulnerability '{vuln_name}' leaking a credential for unknown node '{cred.node}'!")
                    elif not any(service.name == cred.port and cred.credential in service.allowedCredentials
                                 for service in target.services):
                        messages.append(f"INCOHERENCY: Node {node_name} has vulnerability '{vuln_name}' leaking credential '{cred.credential}' for {cred.node}:{cred.port}, but no service there accepts it!")
    return messages


for incoherency in check_network_coherence(nodes):
    print(incoherency)


In [None]:
# Directed edges drawn between hosts, listed one per line in their original order.
edges = [
    # entry point fan-out
    ("ws_entry", "ws2"),
    ("ws_entry", "web1"),
    ("ws_entry", "web2"),
    # workstation chain down to the file server
    ("ws2", "ws3"),
    ("ws3", "ws4"),
    ("ws4", "file1"),
    # web tier to application / database tier
    ("web1", "app1"),
    ("web1", "sql1"),
    ("web2", "app2"),
    ("web3", "sql1"),
    # file server neighbours
    ("file1", "backup1"),
    ("file1", "mail1"),
    # paths into the management server and domain controller
    ("app1", "mgmt1"),
    ("app2", "mgmt1"),
    ("sql1", "dc1"),
    ("backup1", "dc1"),
    ("mail1", "mgmt1"),
    ("dc1", "mgmt1"),
    ("sql1", "web3"),
]

In [None]:
# The simulation graph: each node carries its NodeInfo under the "data"
# attribute, and the hand-drawn edges are added on top.
graph = nx.DiGraph()
graph.add_nodes_from((node_id, {"data": info}) for node_id, info in nodes.items())
graph.add_edges_from(edges)
# NOTE(review): `network` is not used below; the Environment is built from `graph`.
network = create_network(nodes)

# Vocabulary of ports, properties and vulnerability ids exposed to the agent.
# List order is kept exactly as-is.
# NOTE(review): "service:ssh" is used by several preconditions but is not
# declared here (nor on any node).
MY_IDENTIFIERS = Identifiers(
    ports=["SSH", "RDP", "HTTP", "GIT", "SMB", "SQL", "SMTP"],
    properties=[
        "os:Linux",
        "os:Windows",
        "role:entry",
        "role:workstation",
        "role:webserver",
        "role:appserver",
        "role:mgmt",
        "role:sqlserver",
        "role:domaincontroller",
        "role:fileserver",
        "role:mailserver",
        "role:backupserver",
        "service:http",
        "service:git",
        "service:smb",
        "service:rdp",
        "service:sql",
        "service:smtp",
        "user:user",
        "user:admin",
    ],
    local_vulnerabilities=["WeakAdmin", "App1PrivEsc", "App2PrivEsc"],
    remote_vulnerabilities=[
        "WebExploit", "WebAdminCreds",
        "GitHistory", "GitSMBCreds",
        "SMBWeakShare", "SMBCreds",
        "SQLi", "SQLiCreds",
        "RDPBrute", "RDPBruteCreds",
        "MailLeak", "MailLeakCreds",
        "BackupExploit", "BackupSQLCreds",
        "MgmtExploit", "MgmtCreds",
        "Web2RCE", "Web2ConfigLeak",
        "Web3SSRF", "Web3BackupCreds",
        "App1DBCreds", "App2APIExploit", "App2FileCreds",
        "SQLiDataBreach", "LateralMov", "Web2Probing",
    ],
)

env = Environment(
    network=graph,
    # NOTE(review): in CyberBattleSim this library appears to hold *global*
    # vulnerabilities, applicable to every node whose properties satisfy the
    # precondition (the toy CTF sample passes an empty library). With these
    # entries, e.g. "WeakAdmin" (precondition "true") would apply everywhere.
    # Confirm this is intended. App2PrivEsc, SQLiDataBreach, LateralMov and
    # Web2Probing are only available through the nodes that carry them.
    vulnerability_library=dict(
        WebExploit=vuln_web,
        WebAdminCreds=vuln_web_creds,
        GitHistory=vuln_git,
        GitSMBCreds=vuln_git_creds,
        SMBWeakShare=vuln_smb,
        SMBCreds=vuln_smb_creds,
        SQLi=vuln_sql,
        SQLiCreds=vuln_sql_creds,
        RDPBrute=vuln_rdp,
        RDPBruteCreds=vuln_rdp_creds,
        MailLeak=vuln_mail,
        MailLeakCreds=vuln_mail_creds,
        BackupExploit=vuln_backup,
        BackupSQLCreds=vuln_backup_creds,
        MgmtExploit=vuln_mgmt,
        MgmtCreds=vuln_mgmt_creds,
        WeakAdmin=vuln_admin,
        Web2RCE=vuln_web2,
        Web2ConfigLeak=vuln_web2_creds,
        Web3SSRF=vuln_web3,
        Web3BackupCreds=vuln_web3_creds,
        App1PrivEsc=vuln_app1,
        App1DBCreds=vuln_app1_creds,
        App2APIExploit=vuln_app2,
        App2FileCreds=vuln_app2_creds,
    ),
    identifiers=MY_IDENTIFIERS,
)


In [None]:
# Draw the environment's network graph (the call returns nothing; ';' keeps the cell quiet).
env.plot_environment_graph();

In [None]:
from cyberbattle.agents.baseline.learner import *
from cyberbattle.agents.baseline.agent_tabularqlearning import *
import cyberbattle.agents.baseline.agent_wrapper as w
from cyberbattle._env.cyberbattle_env import *
from cyberbattle.agents.baseline.agent_dql import *
from cyberbattle.agents.baseline.learner import *
from cyberbattle._env.defender import *
from cyberbattle.agents.baseline.plotting import *

# Size bounds shared by the gym environment and the learner's feature encoding.
# Previously these were duplicated as literals, and the credential bound
# disagreed (20 for the env vs 1000 for the learner).
MAX_NODE_COUNT = len(nodes)     # 15 for the network defined above
MAX_TOTAL_CREDENTIALS = 20      # comfortably above the credentials the vulnerabilities can leak

# NOTE(review): no vulnerability leaks ws2/ws3/ws4, so own_atleast_percent=1.0
# looks unreachable as the network is written — confirm the intended goal.
env_gym = CyberBattleEnv(
    initial_environment=env,
    attacker_goal=AttackerGoal(own_atleast_percent=1.0),
    defender_agent=ScanAndReimageCompromisedMachines(probability=0.7, scan_capacity=3, scan_frequency=7),
    #defender_agent=ExternalRandomEvents(),
    #defender_constraint=DefenderConstraint(maintain_sla=0.80),
    maximum_node_count=MAX_NODE_COUNT,
    maximum_total_credentials=MAX_TOTAL_CREDENTIALS,
)

env_gym.action_space.seed(SEED)

# The learner must use the same bounds and identifiers as the environment it trains on.
ep = w.EnvironmentBounds.of_identifiers(
    maximum_node_count=MAX_NODE_COUNT,
    maximum_total_credentials=MAX_TOTAL_CREDENTIALS,
    identifiers=MY_IDENTIFIERS,
)


In [None]:
# Training budget.
EPISODE_COUNT = 50
ITERATIONS_PER_EPISODE = 1000

# CyberBattleSim's DeepQLearnerPolicy documents `target_update` as a frequency in
# *episodes*: the target network is synced at the end of an episode when
# episode % target_update == 0. The previous value (100) exceeded EPISODE_COUNT,
# so the target network was never refreshed during training. 10 gives several
# syncs over the run.
TARGET_UPDATE_EPISODES = 10
assert 0 < TARGET_UPDATE_EPISODES <= EPISODE_COUNT, "target network would never be synced"

policy = DeepQLearnerPolicy(
    ep=ep,
    gamma=0.95,
    replay_memory_size=10000,
    target_update=TARGET_UPDATE_EPISODES,
    batch_size=64,
    learning_rate=0.001,
)

# Epsilon-greedy exploration starting at 0.9 and decaying exponentially
# towards 0.10 (decay constant 5000).
results = epsilon_greedy_search(
    cyberbattle_gym_env=env_gym,
    environment_properties=ep,
    learner=policy,
    episode_count=EPISODE_COUNT,
    iteration_count=ITERATIONS_PER_EPISODE,
    epsilon=0.9,
    render=False,
    epsilon_exponential_decay=5000,
    epsilon_minimum=0.10,
    verbosity=w.Verbosity.Quiet,
    title="DQN su Custom Network"
)

plot_averaged_cummulative_rewards("Cumulative Reward (Averaged)", [results], show=True)
plot_all_episodes(results)
#plot_episodes_length([results])