# In [None]:
import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils

# Import your BallSortCraneEnvironment here
# from your_environment_file import BallSortCraneEnvironment

# 1. Environments: one instance to train on and a separate one to evaluate on,
#    each wrapped in a TFPyEnvironment.
train_py_env, eval_py_env = BallSortCraneEnvironment(), BallSortCraneEnvironment()
train_env, eval_env = (
    tf_py_environment.TFPyEnvironment(py_env)
    for py_env in (train_py_env, eval_py_env)
)

# 2. Q-network built from the training environment's observation and action
#    specs; this network is handed to the DQN agent below.
fc_layer_params = (100, 50)  # widths of the two fully connected layers

q_net = q_network.QNetwork(
    input_tensor_spec=train_env.observation_spec(),
    action_spec=train_env.action_spec(),
    fc_layer_params=fc_layer_params,
)

# 3. DQN agent: Adam optimizer, element-wise squared TD-error loss, and a
#    counter variable (starting at 0) that the training loop reads as its step.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
)
agent.initialize()

# 4. Replay buffer holding the experience the agent collects. Its data spec
#    comes from the agent and its batch size from the training environment.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=100_000,
)

# 5. Function to collect data from the environment
def collect_step(environment, policy, buffer):
    """Take one policy-chosen step in `environment` and record it in `buffer`.

    Args:
        environment: Object providing `current_time_step()` and `step(action)`.
        policy: Object whose `action(time_step)` result has an `.action` field.
        buffer: Object providing `add_batch(trajectory)`.

    Returns:
        None. The transition is stored in `buffer` as a side effect.
    """
    before = environment.current_time_step()
    decision = policy.action(before)
    after = environment.step(decision.action)
    # Pack (time step, policy step, next time step) into one trajectory
    # and append it to the buffer.
    buffer.add_batch(trajectory.from_transition(before, decision, after))

def collect_data(env, policy, buffer, steps):
    """Call `collect_step` `steps` times in a row.

    Args:
        env: Environment passed through to `collect_step`.
        policy: Policy passed through to `collect_step`.
        buffer: Replay buffer passed through to `collect_step`.
        steps: Number of single-step collections to perform.

    Returns:
        None.
    """
    for _step_index in range(steps):
        collect_step(env, policy, buffer)

# Fill the replay buffer with some experience from a random policy before
# any training happens.
initial_collect_steps = 1000
random_policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
)
collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

# 6. Expose the replay buffer as a dataset: batches of 64 samples, each
#    sample spanning 2 steps, with 3 elements prefetched.
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2,
    num_parallel_calls=3,
)
dataset = dataset.prefetch(3)

iterator = iter(dataset)

# 7. Training loop: each iteration collects one step with the agent's collect
#    policy, then performs one training update on a sampled batch.
num_iterations = 20000

for _ in range(num_iterations):
    collect_data(train_env, agent.collect_policy, replay_buffer, steps=1)

    # The dataset yields (experience, info) pairs; only experience is used.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()

    # Only report the loss when the step counter is a multiple of 1000.
    if step % 1000 != 0:
        continue
    print('step = {0}: loss = {1}'.format(step, train_loss))

# 8. Evaluating the agent's performance
def compute_avg_return(environment, policy, num_episodes=10):
    """Run `policy` for `num_episodes` episodes and return the mean episode return.

    Each episode starts from `environment.reset()` and continues until the
    current time step reports `is_last()`. Rewards are summed per episode,
    the per-episode sums are added up, and the total is divided by
    `num_episodes`.

    Args:
        environment: Object providing `reset()` and `step(action)`; both
            return a time step with `is_last()` and `reward`.
        policy: Object whose `action(time_step)` result has an `.action` field.
        num_episodes: Number of episodes to run.

    Returns:
        Element 0 of the averaged return's `.numpy()` array.
        NOTE(review): indexing with `[0]` presumably assumes a batch of one
        environment -- confirm against the environment's batch size.
    """
    return_sum = 0.0
    for _ in range(num_episodes):
        current = environment.reset()
        ep_return = 0.0

        while True:
            if current.is_last():
                break
            chosen = policy.action(current)
            current = environment.step(chosen.action)
            ep_return += current.reward

        return_sum += ep_return

    return (return_sum / num_episodes).numpy()[0]

# Evaluate the agent's `policy` on the separate evaluation environment and
# report the mean return over the chosen number of episodes.
num_eval_episodes = 10
avg_return = compute_avg_return(
    environment=eval_env,
    policy=agent.policy,
    num_episodes=num_eval_episodes,
)
print('Average Return = ', avg_return)
