# Deep Reinforcement Learning Agent

Hilfreiche Erklärungen am Beispiel CartPole:
- https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
- https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial

## Abhängigkeiten installieren

In [None]:
%pip install typing-extensions==4.5.0
%pip install ipython==8.23.0
# The two pins above resolve a dependency conflict between tf-agents and IPython:
# tf-agents requires typing-extensions==4.5.0, whereas
# IPython 8.25.0 requires typing-extensions>=4.6.
# Downgrading IPython to 8.23.0 keeps it compatible with typing-extensions 4.5.0.

#%pip install tensorflow-cpu
%pip install tf-agents[reverb]
%pip install tf-keras

In [None]:
import os
# Keep using Keras 2 (tf-keras) rather than Keras 3 (keras). Set this before
# TensorFlow is imported so that the setting is in place when it is read.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

In [None]:
import numpy as np
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
from tf_agents.policies import boltzmann_policy
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import common


## Umgebung definieren

In [None]:
class LoginEnv(py_environment.PyEnvironment):
    """Login-protection environment for a TF-Agents DQN agent.

    State (all exposed through the observation, scaled to ``[0, 1]``):
        * ``incorrect_password`` -- whether the current attempt used a wrong
          password.
        * ``time_between_attempts`` -- seconds since the previous attempt,
          sampled from ``0..3599``.
        * ``incorrect_password_count`` -- number of wrong attempts so far,
          sampled from ``0..10``.
        * ``last_action`` -- the lock action that preceded this attempt.

    Actions:
        0 = do not lock, 1 = lock for 30 s, 2 = lock for 1 min,
        3 = lock for 3 min, 4 = lock permanently.

    Rewards are ``+1`` for a decision the rule table considers right, ``-1``
    for a wrong one (both end the episode) and ``0`` for an undecided one
    (the episode continues).
    """

    # Scaling factors that map the raw state onto the [0, 1] observation range.
    _MAX_TIME_BETWEEN_ATTEMPTS = 3599  # largest value of randint(0, 3600)
    _MAX_INCORRECT_PASSWORD_COUNT = 10  # largest value of randint(0, 11)
    _MAX_ACTION = 4

    def __init__(self):
        super().__init__()

        # Four float features, each normalised to [0, 1].
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(4,), dtype=np.float32, minimum=0.0, maximum=1.0,
            name='observation')

        # One discrete action out of five (see class docstring).
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=self._MAX_ACTION,
            name='action')

        # Internal state; overwritten with a fresh random sample by _reset().
        self.incorrect_password = False
        self.time_between_attempts = np.random.randint(0, 3600)  # 0 s up to just under 1 h
        self.incorrect_password_count = 0
        self.last_action = 0
        self._episode_ended = False

    def action_spec(self):
        """Return the spec describing valid actions."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec describing observations."""
        return self._observation_spec

    def _observation(self):
        """Build the float32 observation vector from the current state."""
        return np.array([
            float(self.incorrect_password),
            self.time_between_attempts / self._MAX_TIME_BETWEEN_ATTEMPTS,
            self.incorrect_password_count / self._MAX_INCORRECT_PASSWORD_COUNT,
            self.last_action / self._MAX_ACTION,
        ], dtype=np.float32)

    def _reset(self):
        """Sample a new random login situation and return the first time step."""
        self.incorrect_password = bool(np.random.choice([True, False]))
        self.time_between_attempts = int(np.random.randint(0, 3600))  # 0 s up to just under 1 h
        self.incorrect_password_count = int(np.random.randint(0, 11))
        # A previous lock (30 s or 1 min) only makes sense after a failed attempt.
        self.last_action = (
            int(np.random.randint(1, 3)) if self.incorrect_password_count > 0 else 0)
        self._episode_ended = False
        return ts.restart(self._observation())

    def _evaluate(self, action):
        """Return ``(reward, done)`` for ``action`` in the current state."""
        count = self.incorrect_password_count
        rapid_retry = self.time_between_attempts <= 3

        if action == 4:  # permanent lock: judged on the failure count only
            return (1, True) if count >= 10 else (-1, True)

        if not self.incorrect_password:
            # Correct password: letting the user in is right, any lock is wrong.
            return (1, True) if action == 0 else (-1, True)

        if action == 0:  # do not lock
            return (-1, True) if (rapid_retry or count >= 10) else (0, False)
        if action == 1:  # lock for 30 s
            return (1, True) if (rapid_retry or 3 < count <= 6) else (0, False)
        if action == 2:  # lock for 1 min
            return (1, True) if 6 < count <= 9 else (0, False)

        # action == 3: lock for 3 min
        # NOTE(review): for an integer count ``9 < count < 10`` can never be
        # true, so this action is never rewarded -- confirm the intended range.
        return (1, True) if 9 < count < 10 else (0, False)

    def _step(self, action):
        """Apply ``action`` and return the resulting ``TimeStep``.

        Raises:
            ValueError: if ``action`` is outside ``0..4``.
        """
        if self._episode_ended:
            # The previous step terminated the episode; start a new one
            # instead of judging an action against stale state.
            return self.reset()

        action = int(action)
        if not 0 <= action <= self._MAX_ACTION:
            raise ValueError(f'action must be in 0..{self._MAX_ACTION}, got {action}')

        reward, done = self._evaluate(action)
        observation = self._observation()

        if done:
            self._episode_ended = True
            return ts.termination(observation, reward=reward)

        # NOTE(review): the state is not advanced on a non-terminal step, so a
        # deterministic policy may repeat the same undecided action forever.
        return ts.transition(observation, reward=reward, discount=1.0)


In [None]:
# Create the Python environment and wrap it in a TFPyEnvironment for use with the TF-Agents agent below.
py_env = LoginEnv()
env = tf_py_environment.TFPyEnvironment(py_env)

## Deep Reinforcement Learning Agent (DQN-Agent) definieren

In [None]:
# Build the Q-network with two fully connected layers of 64 units each.
fc_layer_params = (64, 64)
q_net = q_network.QNetwork(
    env.observation_spec(),
    env.action_spec(),
    fc_layer_params=fc_layer_params)

# Configure the DQN agent: Adam optimizer, element-wise squared TD-error loss,
# and a counter variable that tracks the number of training steps.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

# Initialize the agent.
agent.initialize()

## Policy definieren

In [None]:
# Wrap the agent's policy in a Boltzmann policy; no temperature is passed, so the library default applies.
policy = boltzmann_policy.BoltzmannPolicy(agent.policy)

## Metriken und Auswertung

In [None]:
def compute_average_return(environment, policy, num_episodes=10):
    """Run ``policy`` for ``num_episodes`` episodes and return the mean return.

    Args:
        environment: object providing ``reset()`` and ``step(action)``, both
            returning a time step with ``is_last()`` and ``reward``.
        policy: object whose ``action(time_step)`` result has an ``action``
            attribute.
        num_episodes: number of episodes to average over.

    Returns:
        The sum of all episode returns divided by ``num_episodes``; its type
        follows the type of the rewards delivered by ``environment``.
    """

    def _play_episode():
        # Play one episode to its last step and add up the rewards received.
        step = environment.reset()
        collected = 0.0
        while not step.is_last():
            chosen = policy.action(step)
            step = environment.step(chosen.action)
            collected = collected + step.reward
        return collected

    return_sum = 0.0
    for _ in range(num_episodes):
        return_sum = return_sum + _play_episode()

    # Open question from the original author: return `.numpy()[0]` of this value instead?
    return return_sum / num_episodes

## Wiederholungspuffer

In [None]:
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.specs import tensor_spec

# Get the data specs from the agent; they describe one collected item.
data_spec = agent.collect_data_spec

# Table signature: the collect data spec with an added outer (sequence) dimension.
replay_buffer_signature = tensor_spec.add_outer_dim(
    tensor_spec.from_spec(data_spec))

# Create a replay buffer table: uniform sampling, FIFO eviction, and sampling
# allowed as soon as a single item is present.
table_name = 'replay_buffer'
table = reverb.Table(
    table_name,
    max_size=10000,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

# Create a replay buffer server hosting the table
server = reverb.Server([table])

# Create a replay buffer client connected to the local server
client = reverb.Client(f'localhost:{server.port}')

# Create the TF-Agents replay buffer on top of the local server. Items are
# sequences of two steps, matching the (state, next state) pairs DQN trains on.
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=server)

# Create a replay buffer dataset.
# reverb.ReplayDataset does not accept `client`, `sequence_length` or
# `num_parallel_calls`, so the dataset is built through the replay buffer.
# NOTE(review): the former max_in_flight_samples_per_worker=2 setting is not
# carried over; confirm whether an equivalent option is needed here.
replay_buffer_dataset = replay_buffer.as_dataset(
    num_steps=2,
    num_parallel_calls=1)

# Print the data specs
print(data_spec)

## Datensammlung

## Training des Agenten

## Visualisierung