In [None]:
import importlib
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "training_agent.ipynb"

import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers import NormalizeObservation

import wandb

from keras.models import Model, Sequential
from keras.layers import Dense, Flatten
from keras.optimizers.legacy import Adam
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
from rl.callbacks import ModelIntervalCheckpoint, WandbLogger

import game


In [None]:
# Authenticate with Weights & Biases so the WandbLogger callback used during
# training can log the run. The result is shown as the cell output.
wandb_logged_in = wandb.login()
wandb_logged_in

In [None]:
# Re-import the local `game` module so edits to it take effect without a kernel restart.
importlib.reload(game)

# NOTE(review): keras-rl's fit()/test() loops are written against the classic gym
# API (reset() -> obs, step() -> 4-tuple). gymnasium envs return (obs, info) from
# reset() and a 5-tuple from step(); confirm game.Game matches what keras-rl expects.
env = game.Game()
# Optional: normalize observations with `env = NormalizeObservation(env)`.

print(env.observation_space.shape)
print(env.action_space.n)

def build_model(input_shape, output_shape, window_length=1, hidden_units=(32, 32, 32)):
    """Build the (uncompiled) Q-network used by the DQN agent.

    Args:
        input_shape: Shape of a single observation, e.g. ``env.observation_space.shape``.
        output_shape: Number of discrete actions; the network emits one linear
            Q-value per action.
        window_length: Number of stacked observations per agent state. Must equal
            the ``window_length`` of the ``SequentialMemory`` given to the agent,
            since keras-rl feeds states shaped ``(batch, window_length, *input_shape)``.
            Defaults to 1, matching the original hard-coded value.
        hidden_units: Width of each ReLU hidden layer, in order. The default
            reproduces the original three 32-unit layers.

    Returns:
        A ``keras.Sequential`` model (not compiled; ``DQNAgent.compile`` does that).
    """
    # tuple() lets callers pass a list as well as a gym shape tuple.
    net_layers = [Flatten(input_shape=(window_length,) + tuple(input_shape))]
    net_layers += [Dense(units, activation='relu') for units in hidden_units]
    # Linear head: Q-values are unbounded regression targets.
    net_layers.append(Dense(output_shape, activation='linear'))
    return Sequential(net_layers)

# --- Training hyperparameters ---------------------------------------------------
WINDOW_LENGTH = 1              # observations stacked per state; must match the model's window dim
MEMORY_LIMIT = 20_000          # replay-buffer capacity (transitions)
EPS_MAX = 1                    # epsilon at the start of annealing
EPS_MIN = 0.1                  # epsilon floor after annealing
EPS_TEST = 0.05                # epsilon used by dqn.test()
EPS_ANNEAL_STEPS = 1_000_000   # steps over which epsilon decays linearly EPS_MAX -> EPS_MIN
WARMUP_STEPS = 5_000           # steps collected before the agent starts training
GAMMA = 0.99                   # discount factor
TARGET_UPDATE = 10_000         # in keras-rl, a value >= 1 means a hard target-net copy every N steps
LEARNING_RATE = 1e-3
TRAIN_STEPS = 10_000_000
CHECKPOINT_INTERVAL = 100_000  # steps between intermediate weight snapshots
CHECKPOINT_PATH = 'dqn_weights_step{step}.h5f'  # '{step}' is filled in by ModelIntervalCheckpoint

input_shape = env.observation_space.shape
output_shape = env.action_space.n

model = build_model(input_shape, output_shape)
# summary() prints the table itself and returns None; wrapping it in print()
# emitted a stray "None" line.
model.summary()
print(model.input_shape)
print(model.output_shape)

# keras-rl feeds states shaped (batch, window_length, *obs_shape). Fail fast if
# the model's window dimension and the replay memory's window_length disagree.
assert model.input_shape[1] == WINDOW_LENGTH, (
    f"model window {model.input_shape[1]} != memory window_length {WINDOW_LENGTH}"
)

# Replay buffer and linearly annealed epsilon-greedy exploration.
memory = SequentialMemory(limit=MEMORY_LIMIT, window_length=WINDOW_LENGTH)
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(),
    attr='eps',
    value_max=EPS_MAX,
    value_min=EPS_MIN,
    value_test=EPS_TEST,
    nb_steps=EPS_ANNEAL_STEPS,
)

# Double DQN agent.
dqn = DQNAgent(
    model=model,
    policy=policy,
    nb_actions=output_shape,
    memory=memory,
    nb_steps_warmup=WARMUP_STEPS,
    gamma=GAMMA,
    target_model_update=TARGET_UPDATE,
    enable_double_dqn=True,
)
dqn.compile(Adam(learning_rate=LEARNING_RATE), metrics=['mae'])

callbacks = [
    WandbLogger(),
    # Periodic snapshots so a crash or interrupted kernel does not lose a
    # multi-million-step run (final weights are still saved in a later cell).
    ModelIntervalCheckpoint(CHECKPOINT_PATH, interval=CHECKPOINT_INTERVAL),
]

# Train the agent.
fit_history = dqn.fit(
    env,
    nb_steps=TRAIN_STEPS,
    action_repetition=1,
    callbacks=callbacks,
    verbose=1,
    visualize=False,
    nb_max_start_steps=0,
    start_step_policy=None,
    nb_max_episode_steps=None,
)


In [None]:
# dqn.load_weights('dqn_weights.h5f')

In [None]:
# Evaluate the trained agent for one rendered episode on the same `env` used
# for training. To evaluate on a freshly built environment instead, first run
# `importlib.reload(game); env = game.Game()` (plus NormalizeObservation if
# training used it).
test_kwargs = dict(
    nb_episodes=1,
    action_repetition=1,
    callbacks=None,
    visualize=True,
    nb_max_episode_steps=None,
    nb_max_start_steps=0,
    start_step_policy=None,
    verbose=2,
)
test_history = dqn.test(env, **test_kwargs)

In [None]:
# Persist the final trained weights, replacing any file already at this path.
final_weights_path = 'dqn_weights_run2.h5f'
dqn.save_weights(final_weights_path, overwrite=True)