In [1]:
!pip install pettingzoo[mpe] tensorflow matplotlib imageio ipython pygame

Collecting pettingzoo[mpe]
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting pygame
  Downloading pygame-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting jedi>=0.16 (from ipython)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading pygame-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m45.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pettingzoo-1.24.3-py3-none-any.whl (847 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m847.8/847.8 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pygame, jedi, pettingzoo
  Attempting uninstall: pygame
    Found existing

In [None]:
# -*- coding: utf-8 -*-
# --- Installation ---
# !pip install pettingzoo[mpe] tensorflow matplotlib imageio ipython pygame

import os
import sys
import time
import io
import base64
from collections import deque, defaultdict
import warnings

# Suppress DeprecationWarnings from PettingZoo MPE C API (optional)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow as tf
import numpy as np
from pettingzoo.mpe import simple_spread_v3
import matplotlib.pyplot as plt
from IPython.display import HTML, display
import imageio
import pygame

# Pick the compute device: first GPU when TensorFlow can see one, otherwise the CPU.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("GPU is not available. CPU will be used.")
    device = '/CPU:0'
else:
    print("GPU is available and will be used.")
    try:
        # Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enabled GPU memory growth.")
    except RuntimeError as growth_error:
        # Report the failure and carry on with the GPU anyway.
        print(growth_error)
    device = '/GPU:0'

# --- TensorBoard Setup ---
log_dir = "masac_v2_tensorflow_logs_simple_spread_fixed_alpha_beta"
# exist_ok=True makes the cell idempotent on re-run and avoids the check-then-create race.
os.makedirs(log_dir, exist_ok=True)
summary_writer = tf.summary.create_file_writer(log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")

# --- Environment Setup ---
def make_env(env_name="simple_spread_v3", continuous_actions=True, max_cycles=25, render_mode=None,
             num_agents=3, local_ratio=0.5):
    """Create a PettingZoo parallel environment and reset it once.

    Args:
        env_name: Environment identifier; only "simple_spread_v3" is supported.
        continuous_actions: Forwarded to ``simple_spread_v3.parallel_env``.
        max_cycles: Forwarded to ``simple_spread_v3.parallel_env``.
        render_mode: Forwarded to ``simple_spread_v3.parallel_env``.
        num_agents: Passed as ``N``; defaults to the previously hard-coded 3.
        local_ratio: Passed as ``local_ratio``; defaults to the previously
            hard-coded 0.5.

    Returns:
        The environment, already reset once.

    Raises:
        ValueError: If ``env_name`` is not a supported environment.
    """
    if env_name != "simple_spread_v3":
        raise ValueError(f"Unsupported MPE environment name: {env_name}")

    env = simple_spread_v3.parallel_env(
        N=num_agents,
        local_ratio=local_ratio,
        max_cycles=max_cycles,
        continuous_actions=continuous_actions,
        render_mode=render_mode,
    )
    # A failing reset is allowed to propagate: the previous fallback simply repeated the
    # identical call (despite claiming to drop options), so it could not recover anything.
    env.reset()
    return env

# --- Replay Buffer ---
class MultiAgentReplayBuffer:
    """Fixed-capacity ring buffer of joint (all-agent) transitions.

    Per-agent observations and actions are flattened and concatenated in
    ``agent_ids`` order, so each stored row holds the joint observation,
    joint action, per-agent rewards, joint next observation and per-agent
    done flags of one environment step.
    """

    def __init__(self, capacity, agent_ids, obs_dims, action_dims):
        """Allocate storage.

        Args:
            capacity: Maximum number of transitions kept; the oldest are overwritten.
            agent_ids: Iterable of agent identifiers; its order fixes the column layout.
            obs_dims: Mapping agent_id -> flattened observation size.
            action_dims: Mapping agent_id -> flattened action size.
        """
        self.capacity = capacity
        self.agent_ids = list(agent_ids)
        self.num_agents = len(self.agent_ids)
        self.obs_dims = obs_dims
        self.action_dims = action_dims
        self.total_obs_dim = sum(self.obs_dims[agent_id] for agent_id in self.agent_ids)
        self.total_action_dim = sum(self.action_dims[agent_id] for agent_id in self.agent_ids)
        self.obs_buffer = np.zeros((capacity, self.total_obs_dim), dtype=np.float32)
        self.action_buffer = np.zeros((capacity, self.total_action_dim), dtype=np.float32)
        self.reward_buffer = np.zeros((capacity, self.num_agents), dtype=np.float32)
        self.next_obs_buffer = np.zeros((capacity, self.total_obs_dim), dtype=np.float32)
        self.done_buffer = np.zeros((capacity, self.num_agents), dtype=np.float32)
        self.ptr = 0   # next write position
        self.size = 0  # number of valid rows

    def _dict_to_concatenated(self, data_dict, dim_map):
        """Flatten and concatenate per-agent entries in ``self.agent_ids`` order.

        An agent missing from ``data_dict`` (or whose entry cannot be converted
        to a float32 array) contributes zeros of that agent's OWN size from
        ``dim_map``, so agents with different dimensions still line up.

        Returns:
            1-D float32 array.
        """
        pieces = []
        for agent_id in self.agent_ids:
            fallback = np.zeros(dim_map[agent_id], dtype=np.float32)
            agent_data = data_dict.get(agent_id, fallback)
            if not isinstance(agent_data, np.ndarray):
                try:
                    agent_data = np.array(agent_data, dtype=np.float32)
                except (TypeError, ValueError):
                    agent_data = fallback
            pieces.append(agent_data.flatten())
        if not pieces:
            return np.zeros(sum(dim_map.values()), dtype=np.float32)
        return np.concatenate(pieces, axis=0).astype(np.float32)

    def push(self, obs_dict, action_dict, reward_dict, next_obs_dict, done_dict):
        """Store one joint transition.

        The transition is dropped (with a ``RuntimeWarning``) when any agent
        lacks an observation or next observation, when the concatenated arrays
        do not match the buffer layout, or when conversion raises
        ``KeyError``/``ValueError``. Missing rewards default to 0.0 and missing
        done flags to False.

        Returns:
            True if the transition was stored, False if it was dropped.
        """
        for agent_id in self.agent_ids:
            if obs_dict.get(agent_id) is None or next_obs_dict.get(agent_id) is None:
                warnings.warn(
                    "Replay buffer dropped a transition: an agent has no observation or next observation.",
                    RuntimeWarning, stacklevel=2)
                return False

        try:
            obs_concat = self._dict_to_concatenated(obs_dict, self.obs_dims)
            next_obs_concat = self._dict_to_concatenated(next_obs_dict, self.obs_dims)
            action_concat = self._dict_to_concatenated(action_dict, self.action_dims)
            rewards = np.array([reward_dict.get(agent_id, 0.0) for agent_id in self.agent_ids], dtype=np.float32)
            dones = np.array([done_dict.get(agent_id, False) for agent_id in self.agent_ids], dtype=np.float32)
        except (KeyError, ValueError) as err:
            # Best-effort storage: keep running, but do not hide the problem.
            warnings.warn(
                f"Replay buffer dropped a transition: {type(err).__name__} while assembling it.",
                RuntimeWarning, stacklevel=2)
            return False

        if obs_concat.shape != (self.total_obs_dim,) or \
           next_obs_concat.shape != (self.total_obs_dim,) or \
           action_concat.shape != (self.total_action_dim,):
            warnings.warn(
                "Replay buffer dropped a transition: concatenated shapes do not match the buffer layout.",
                RuntimeWarning, stacklevel=2)
            return False

        self.obs_buffer[self.ptr] = obs_concat
        self.action_buffer[self.ptr] = action_concat
        self.reward_buffer[self.ptr] = rewards
        self.next_obs_buffer[self.ptr] = next_obs_concat
        self.done_buffer[self.ptr] = dones
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
        return True

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(obs, actions, rewards, next_obs, dones)`` of float32 arrays
            whose first axis is ``min(len(self), batch_size)``, or None when
            that count is zero.
        """
        actual_batch_size = min(self.size, batch_size)
        if actual_batch_size == 0:
            return None
        indices = np.random.choice(self.size, actual_batch_size, replace=False)
        return (
            self.obs_buffer[indices],
            self.action_buffer[indices],
            self.reward_buffer[indices],
            self.next_obs_buffer[indices],
            self.done_buffer[indices]
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

# --- Network Architectures ---
def create_actor_network(state_dim, action_dim, hidden_units):
    """Build a Gaussian policy network.

    Two ReLU hidden layers of ``hidden_units`` feed two linear heads of size
    ``action_dim``: the mean and the log standard deviation, the latter
    clipped to [-20, 2].

    Returns:
        A ``tf.keras.Model`` mapping a ``(batch, state_dim)`` input to
        ``[mu, log_std]``.
    """
    def hidden_layer():
        # Fresh initializer instance per layer, as in the original construction.
        return tf.keras.layers.Dense(hidden_units, activation='relu',
                                     kernel_initializer=tf.keras.initializers.HeUniform())

    def output_head():
        return tf.keras.layers.Dense(action_dim,
                                     kernel_initializer=tf.keras.initializers.GlorotUniform())

    obs_in = tf.keras.layers.Input(shape=(state_dim,))
    features = hidden_layer()(obs_in)
    features = hidden_layer()(features)
    mean_out = output_head()(features)
    raw_log_std = output_head()(features)
    bounded_log_std = tf.keras.layers.Lambda(lambda t: tf.clip_by_value(t, -20, 2))(raw_log_std)
    return tf.keras.Model(inputs=obs_in, outputs=[mean_out, bounded_log_std])

def create_critic_network(total_obs_dim, total_action_dim, hidden_units):
    """Build a centralized Q-network.

    The joint observation and joint action inputs are concatenated, passed
    through two ReLU hidden layers of ``hidden_units``, and reduced to a
    single linear output.

    Returns:
        A ``tf.keras.Model`` taking ``[obs, action]`` and producing a
        ``(batch, 1)`` Q-value.
    """
    joint_obs = tf.keras.layers.Input(shape=(total_obs_dim,))
    joint_action = tf.keras.layers.Input(shape=(total_action_dim,))
    features = tf.keras.layers.Concatenate()([joint_obs, joint_action])
    for _ in range(2):
        features = tf.keras.layers.Dense(
            hidden_units,
            activation='relu',
            kernel_initializer=tf.keras.initializers.HeUniform(),
        )(features)
    q_out = tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.GlorotUniform()
    )(features)
    return tf.keras.Model(inputs=[joint_obs, joint_action], outputs=q_out)

# --- Log Probability Helper ---
@tf.function(reduce_retracing=True)
def gaussian_log_prob(x, mu, log_std):
    """Log-density of ``x`` under a diagonal Gaussian N(mu, exp(log_std)^2).

    Per-dimension log-densities are summed over axis 1 and the reduced axis
    is kept, so a ``(batch, dim)`` input gives a ``(batch, 1)`` result.
    """
    # Small offsets keep the division and the logarithm away from zero.
    sigma = tf.exp(log_std) + 1e-7
    variance = tf.square(sigma)
    quadratic_term = -0.5 * tf.square(x - mu) / variance
    normalizer = 0.5 * tf.math.log(2.0 * np.pi * tf.maximum(variance, 1e-14))
    return tf.reduce_sum(quadratic_term - normalizer, axis=1, keepdims=True)

# --- Multi-Agent SAC v2 (MA-SAC) Agent ---
class MASACAgentV2:
    """Multi-agent SAC-style agent: one actor per agent, twin centralized critics.

    Each actor sees only its own slice of the joint observation; both critics
    (and their target copies) score the concatenated observations and actions
    of all agents. ``alpha`` (weight on an agent's own log-probability) and
    ``beta`` (weight on the log-probability of its action under the other
    agents' policies) are fixed constants.

    NOTE(review): ``get_actions`` returns actions rescaled to each action
    space's [low, high] range, whereas ``_update_networks`` feeds the critics
    raw tanh actions in [-1, 1]. If the replay buffer is filled with the
    rescaled environment actions (the push call is outside this class, so
    confirm), the critics are trained and queried on two different action
    scales.
    """

    def __init__(self, env, agent_ids, obs_dims, action_dims, action_spaces,
                 hidden_units=256, actor_lr=3e-4, critic_lr=3e-4,
                 tau=0.005, gamma=0.99, buffer_capacity=1000000, batch_size=256,
                 initial_alpha=0.2, initial_beta=0.1,
                 target_entropy_scale=1.0, target_ce_scale=0.1,
                 gradient_clip_norm=1.0):
        """Build networks, optimizers, the replay buffer and slicing tables.

        Args:
            env: Accepted for interface compatibility; not used by this class.
            agent_ids: Iterable of agent identifiers; its order fixes the
                layout of concatenated observations/actions.
            obs_dims: Mapping agent_id -> observation size.
            action_dims: Mapping agent_id -> action size.
            action_spaces: Mapping agent_id -> space exposing ``low``/``high`` arrays.
            hidden_units: Width of every hidden layer.
            actor_lr, critic_lr: Adam learning rates.
            tau: Soft-update coefficient for the target critics.
            gamma: Discount factor.
            buffer_capacity: Replay buffer capacity.
            batch_size: Minibatch size used by ``update``.
            initial_alpha, initial_beta: Fixed loss weights (never trained).
            target_entropy_scale, target_ce_scale: Only used to fill
                ``target_entropies`` / ``target_cross_entropies``, which no
                method of this class reads.
            gradient_clip_norm: Global-norm clip threshold, or None to disable.
        """
        self.agent_ids = list(agent_ids)
        self.num_agents = len(self.agent_ids)
        self.obs_dims = obs_dims
        self.action_dims = action_dims
        self.action_spaces = action_spaces
        self.tau = tau
        self.gamma = gamma
        self.batch_size = batch_size
        self.gradient_clip_norm = gradient_clip_norm

        self.total_obs_dim = sum(self.obs_dims[agent_id] for agent_id in self.agent_ids)
        self.total_action_dim = sum(self.action_dims[agent_id] for agent_id in self.agent_ids)

        # --- Actors (Decentralized) ---
        self.actors = {}
        self.actor_optimizers = {}
        self.target_entropies = {}
        self.target_cross_entropies = {}
        for agent_id in self.agent_ids:
            action_dim = self.action_dims[agent_id]
            self.actors[agent_id] = create_actor_network(obs_dims[agent_id], action_dim, hidden_units)
            # NOTE(review): writes a private Keras attribute; whether `.name` reflects it
            # depends on the Keras version — confirm.
            self.actors[agent_id]._name = f"actor_{agent_id}"
            self.actor_optimizers[agent_id] = tf.keras.optimizers.Adam(learning_rate=actor_lr)
            self.target_entropies[agent_id] = tf.constant(-float(action_dim) * target_entropy_scale, dtype=tf.float32)
            self.target_cross_entropies[agent_id] = tf.constant(-float(action_dim) * target_ce_scale, dtype=tf.float32)

        # --- Critic (Centralized) ---
        self.critic_1 = create_critic_network(self.total_obs_dim, self.total_action_dim, hidden_units)
        self.critic_2 = create_critic_network(self.total_obs_dim, self.total_action_dim, hidden_units)
        self.target_critic_1 = create_critic_network(self.total_obs_dim, self.total_action_dim, hidden_units)
        self.target_critic_2 = create_critic_network(self.total_obs_dim, self.total_action_dim, hidden_units)
        self.target_critic_1.set_weights(self.critic_1.get_weights())
        self.target_critic_2.set_weights(self.critic_2.get_weights())
        self.critic_1_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
        self.critic_2_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

        # --- Fixed Alpha and Beta ---
        self.alpha = tf.constant(initial_alpha, dtype=tf.float32, name="alpha")
        self.beta = tf.constant(initial_beta, dtype=tf.float32, name="beta")

        # --- Replay Buffer ---
        self.replay_buffer = MultiAgentReplayBuffer(buffer_capacity, self.agent_ids, obs_dims, action_dims)

        # --- Column ranges of each agent inside the concatenated arrays ---
        self.agent_obs_indices = {}
        self.agent_action_indices = {}
        start_obs = 0
        start_action = 0
        for agent_id in self.agent_ids:
            end_obs = start_obs + self.obs_dims[agent_id]
            end_action = start_action + self.action_dims[agent_id]
            self.agent_obs_indices[agent_id] = (start_obs, end_obs)
            self.agent_action_indices[agent_id] = (start_action, end_action)
            start_obs = end_obs
            start_action = end_action

    def _get_agent_obs(self, concatenated_obs, agent_id):
        """Return the columns of ``concatenated_obs`` that belong to ``agent_id``."""
        start, end = self.agent_obs_indices[agent_id]
        return concatenated_obs[:, start:end]

    def _get_agent_action(self, concatenated_action, agent_id):
        """Return the columns of ``concatenated_action`` that belong to ``agent_id``."""
        start, end = self.agent_action_indices[agent_id]
        return concatenated_action[:, start:end]

    def _zero_action(self, agent_id):
        """Fallback action: zeros of the agent's action size."""
        return np.zeros(self.action_dims[agent_id], dtype=np.float32)

    def _apply_gradients(self, optimizer, grads, variables):
        """Clip (if configured) and apply ``grads``; skip entirely if any gradient is missing."""
        if grads is None or any(g is None for g in grads):
            return
        if self.gradient_clip_norm is not None:
            grads, _ = tf.clip_by_global_norm(grads, self.gradient_clip_norm)
        optimizer.apply_gradients(zip(grads, variables))

    @tf.function
    def _get_action_from_actor(self, actor_model, state_tensor, evaluate):
        """Return a tanh-squashed action: the mean when ``evaluate`` is true, otherwise a sample."""
        mu, log_std = actor_model(state_tensor, training=not evaluate)
        tf.debugging.check_numerics(mu, f"Actor mu output for {actor_model.name}")
        tf.debugging.check_numerics(log_std, f"Actor log_std output for {actor_model.name}")
        if evaluate:
            raw_action = mu
        else:
            std = tf.exp(log_std)
            epsilon = tf.random.normal(shape=tf.shape(mu))
            raw_action = mu + std * epsilon
        tf.debugging.check_numerics(raw_action, f"Raw action before tanh for {actor_model.name}")
        action_tanh = tf.tanh(raw_action)
        tf.debugging.check_numerics(action_tanh, f"Action after tanh for {actor_model.name}")
        return action_tanh

    def get_actions(self, obs_dict, evaluate=False):
        """Compute one action per agent, rescaled to that agent's action-space bounds.

        Args:
            obs_dict: Mapping agent_id -> observation (array-like) or None.
            evaluate: If true, use the policy mean instead of sampling.

        Returns:
            Dict agent_id -> float32 action. An agent gets an all-zero action
            when its observation is missing, unconvertible or contains NaN, or
            when the policy output is non-finite.
        """
        actions_dict = {}
        with tf.device(device):
            for agent_id in self.agent_ids:
                # Start from the fallback; it is overwritten only when every check passes.
                actions_dict[agent_id] = self._zero_action(agent_id)
                if agent_id not in obs_dict or obs_dict[agent_id] is None:
                    continue
                try:
                    state = np.asarray(obs_dict[agent_id], dtype=np.float32)[None, ...]
                except (TypeError, ValueError) as err:
                    print(f"Warning: could not convert observation for agent {agent_id}: {err}")
                    continue
                if np.any(np.isnan(state)):
                    print(f"Warning: NaN detected in input observation for agent {agent_id}")
                    continue
                state_tensor = tf.convert_to_tensor(state, dtype=tf.float32)

                try:
                    action_tanh = self._get_action_from_actor(self.actors[agent_id], state_tensor, evaluate)
                    action_tanh_np = action_tanh.numpy()[0]
                except tf.errors.InvalidArgumentError as e:
                    # Raised by the check_numerics guards inside the actor call.
                    print(f"!!! Numerical stability error during action selection for agent {agent_id}: {e} !!!")
                    continue
                if np.any(np.isnan(action_tanh_np)):
                    print(f"Warning: NaN detected in action_tanh_np for agent {agent_id}. Replacing with zeros.")
                    continue

                # Affine map from tanh range [-1, 1] to [low, high].
                action_space = self.action_spaces[agent_id]
                action_low = action_space.low.astype(np.float32)
                action_high = action_space.high.astype(np.float32)
                scaled_action = action_low + (action_tanh_np + 1.0) * 0.5 * (action_high - action_low)
                scaled_action = np.clip(scaled_action, action_low, action_high)
                if np.any(np.isnan(scaled_action)):
                    print(f"Warning: NaN detected in final scaled_action for agent {agent_id}. Replacing with zeros.")
                    continue
                actions_dict[agent_id] = scaled_action
        return actions_dict

    @tf.function(reduce_retracing=True)
    def _get_sampled_actions_and_log_probs(self, actor_model, agent_obs):
        """Sample a reparameterized tanh action and return it with its log-probability."""
        mu, log_std = actor_model(agent_obs, training=True)
        std = tf.exp(log_std)
        epsilon = tf.random.normal(shape=tf.shape(mu))
        action_raw = mu + std * epsilon
        action_tanh = tf.tanh(action_raw)
        log_prob_raw = gaussian_log_prob(action_raw, mu, log_std)
        # Change-of-variables correction for the tanh squashing.
        log_prob_tanh = log_prob_raw - tf.reduce_sum(tf.math.log(1.0 - tf.square(action_tanh) + 1e-7), axis=1, keepdims=True)
        tf.debugging.check_numerics(action_tanh, "Sampled action tanh")
        tf.debugging.check_numerics(log_prob_tanh, "Sampled action log_prob")
        return action_tanh, log_prob_tanh

    @tf.function(reduce_retracing=True)
    def _get_log_prob_under_policy(self, policy_actor_model, eval_agent_obs, action_tanh):
        """Log-probability of ``action_tanh`` under ``policy_actor_model`` given ``eval_agent_obs``."""
        mu, log_std = policy_actor_model(eval_agent_obs, training=True)
        # Keep atanh finite at the boundaries, and use the SAME clipped value in the
        # Jacobian term so both parts of the density refer to one action.
        action_tanh_clipped = tf.clip_by_value(action_tanh, -1.0 + 1e-7, 1.0 - 1e-7)
        action_raw = tf.atanh(action_tanh_clipped)
        log_prob_raw = gaussian_log_prob(action_raw, mu, log_std)
        log_prob_tanh = log_prob_raw - tf.reduce_sum(
            tf.math.log(1.0 - tf.square(action_tanh_clipped) + 1e-7), axis=1, keepdims=True)
        tf.debugging.check_numerics(log_prob_tanh, "Cross-policy log_prob")
        return log_prob_tanh

    @tf.function
    def _update_networks(self, batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones):
        """Run one gradient step on both critics and then on every actor.

        Returns:
            ``(critic_loss, {agent_id: actor_loss}, 0.0, 0.0)``; the last two
            entries are constant placeholders for alpha/beta losses.
        """
        # ---------------------- Critic target (no gradients needed) ----------------------
        # Built outside the tape so the actor/target-critic forward passes are not recorded.
        # The sampled log-probability is discarded, so the target carries no entropy term.
        next_actions_list = []
        for agent_id in self.agent_ids:
            agent_next_obs = self._get_agent_obs(batch_next_obs, agent_id)
            action_tanh_next, _ = self._get_sampled_actions_and_log_probs(
                self.actors[agent_id], agent_next_obs
            )
            next_actions_list.append(action_tanh_next)
        next_actions_concat = tf.concat(next_actions_list, axis=1)
        tf.debugging.check_numerics(next_actions_concat, "Critic next actions")
        target_q1 = self.target_critic_1([batch_next_obs, next_actions_concat], training=False)
        target_q2 = self.target_critic_2([batch_next_obs, next_actions_concat], training=False)
        target_q = tf.minimum(target_q1, target_q2)
        tf.debugging.check_numerics(target_q, "Critic target Q")
        # Shared team signal: mean reward over agents, done if any agent is done.
        mean_reward = tf.reduce_mean(batch_rewards, axis=1, keepdims=True)
        shared_done = tf.reduce_max(batch_dones, axis=1, keepdims=True)
        target_y = tf.stop_gradient(mean_reward + self.gamma * (1.0 - shared_done) * target_q)
        tf.debugging.check_numerics(target_y, "Critic target y")

        # ---------------------- Critic Update ----------------------
        with tf.GradientTape(persistent=True) as critic_tape:
            q1 = self.critic_1([batch_obs, batch_actions], training=True)
            q2 = self.critic_2([batch_obs, batch_actions], training=True)
            tf.debugging.check_numerics(q1, "Critic Q1")
            tf.debugging.check_numerics(q2, "Critic Q2")
            critic_1_loss = tf.reduce_mean(tf.square(q1 - target_y))
            critic_2_loss = tf.reduce_mean(tf.square(q2 - target_y))
            critic_loss = critic_1_loss + critic_2_loss
            tf.debugging.check_numerics(critic_loss, "Critic total loss")
        critic_1_grads = critic_tape.gradient(critic_1_loss, self.critic_1.trainable_variables)
        critic_2_grads = critic_tape.gradient(critic_2_loss, self.critic_2.trainable_variables)
        del critic_tape
        self._apply_gradients(self.critic_1_optimizer, critic_1_grads, self.critic_1.trainable_variables)
        self._apply_gradients(self.critic_2_optimizer, critic_2_grads, self.critic_2.trainable_variables)

        # ---------------------- Actor Update ----------------------
        actor_loss_list = []
        with tf.GradientTape(persistent=True) as actor_tape:
            agent_actions_tanh_list = []
            agent_log_probs_tanh_list = []
            for agent_id in self.agent_ids:
                agent_obs = self._get_agent_obs(batch_obs, agent_id)
                action_tanh, log_prob_tanh = self._get_sampled_actions_and_log_probs(
                    self.actors[agent_id], agent_obs
                )
                agent_actions_tanh_list.append(action_tanh)
                agent_log_probs_tanh_list.append(log_prob_tanh)
            current_actions_concat = tf.concat(agent_actions_tanh_list, axis=1)
            # Only critic_1 scores the freshly sampled joint action.
            q_new = self.critic_1([batch_obs, current_actions_concat], training=True)
            tf.debugging.check_numerics(q_new, "Actor update Q_new")
            for i, agent_id_i in enumerate(self.agent_ids):
                log_prob_i = agent_log_probs_tanh_list[i]
                action_i = agent_actions_tanh_list[i]
                obs_i = self._get_agent_obs(batch_obs, agent_id_i)
                # Sum of log-probs of agent i's action under every OTHER agent's policy,
                # each evaluated on agent i's own observation.
                log_prob_j_sum = tf.zeros_like(log_prob_i)
                for j, agent_id_j in enumerate(self.agent_ids):
                    if i == j:
                        continue
                    log_prob_j_sum += self._get_log_prob_under_policy(
                        policy_actor_model=self.actors[agent_id_j],
                        eval_agent_obs=obs_i,
                        action_tanh=action_i
                    )
                tf.debugging.check_numerics(log_prob_i, f"Actor log_prob_i {i}")
                tf.debugging.check_numerics(log_prob_j_sum, f"Actor log_prob_j_sum {i}")
                actor_loss_i_batch = -q_new + self.alpha * log_prob_i + self.beta * log_prob_j_sum
                actor_loss_i = tf.reduce_mean(actor_loss_i_batch)
                tf.debugging.check_numerics(actor_loss_i, f"Actor loss {i}")
                actor_loss_list.append(actor_loss_i)

        actor_losses_dict_for_return = {}
        for i, agent_id in enumerate(self.agent_ids):
            actor_loss_tensor = actor_loss_list[i]
            actor_losses_dict_for_return[agent_id] = actor_loss_tensor
            actor_vars = self.actors[agent_id].trainable_variables
            # Each loss is differentiated only w.r.t. its own actor's variables.
            grads = actor_tape.gradient(actor_loss_tensor, actor_vars)
            self._apply_gradients(self.actor_optimizers[agent_id], grads, actor_vars)
        del actor_tape
        return critic_loss, actor_losses_dict_for_return, tf.constant(0.0), tf.constant(0.0)

    def update(self):
        """Sample a batch, update the networks, then soft-update the target critics.

        Returns:
            ``(critic_loss, actor_losses, alpha_loss, beta_loss)``, or four
            Nones when the buffer holds fewer than ``batch_size`` transitions
            or the update raised a numerical-stability error.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None, None, None, None
        sample_result = self.replay_buffer.sample(self.batch_size)
        if sample_result is None:
            return None, None, None, None
        batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = sample_result
        try:
            update_result = self._update_networks(
                batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones
            )
        except tf.errors.InvalidArgumentError as e:
            print(f"\n!!! Numerical stability error during agent update: {e} !!!")
            return None, None, None, None
        if update_result is None:
            return None, None, None, None
        critic_loss, actor_losses, alpha_loss, beta_loss = update_result
        self.update_target_networks()
        return critic_loss, actor_losses, alpha_loss, beta_loss

    @tf.function(reduce_retracing=True)
    def update_target_networks(self):
        """Polyak-average the critics into their targets: target <- tau*source + (1-tau)*target."""
        for target_var, source_var in zip(self.target_critic_1.trainable_variables, self.critic_1.trainable_variables):
            target_var.assign(self.tau * source_var + (1.0 - self.tau) * target_var)
        for target_var, source_var in zip(self.target_critic_2.trainable_variables, self.critic_2.trainable_variables):
            target_var.assign(self.tau * source_var + (1.0 - self.tau) * target_var)

    def save_weights(self, prefix):
        """Save actor and critic weights as ``*.weights.h5`` files under directory ``prefix``.

        Errors are reported with ``print`` and not re-raised.
        """
        try:
            os.makedirs(prefix, exist_ok=True)
            for agent_id, actor in self.actors.items():
                actor.save_weights(os.path.join(prefix, f"actor_{agent_id}.weights.h5"))
            self.critic_1.save_weights(os.path.join(prefix, "critic1.weights.h5"))
            self.critic_2.save_weights(os.path.join(prefix, "critic2.weights.h5"))
            print(f"Weights saved successfully to prefix: {prefix}")
        except Exception as e:
            print(f"Error saving weights with prefix {prefix}: {e}")

    def load_weights(self, prefix):
        """Load actor and critic weights from ``prefix`` and copy the critics into their targets.

        Errors are reported with ``print`` and not re-raised, so after a
        failure the networks may be only partially loaded.
        """
        try:
            for agent_id, actor in self.actors.items():
                actor.load_weights(os.path.join(prefix, f"actor_{agent_id}.weights.h5"))
            self.critic_1.load_weights(os.path.join(prefix, "critic1.weights.h5"))
            self.critic_2.load_weights(os.path.join(prefix, "critic2.weights.h5"))
            self.target_critic_1.set_weights(self.critic_1.get_weights())
            self.target_critic_2.set_weights(self.critic_2.get_weights())
            print(f"Weights loaded successfully from prefix: {prefix}")
        except Exception as e:
            print(f"Error loading weights from prefix {prefix}: {e}")

# --- Training Function ---
def train_masac_v2(env_name="simple_spread_v3", hidden_units=128, actor_lr=1e-4, critic_lr=1e-3,
                   tau=0.01, gamma=0.99, buffer_capacity=100000, batch_size=512,
                   num_episodes=10000, max_steps_per_episode=100,
                   initial_alpha=0.2, initial_beta=0.1, target_entropy_scale=1.0, target_ce_scale=0.1,
                   start_steps=5000, update_every=1, log_interval=50,
                   gradient_clip_norm=1.0):
    """Trains a MASACAgentV2 on the environment built by ``make_env``.

    For the first ``start_steps`` environment steps, actions are sampled from
    each agent's action space; afterwards they come from
    ``agent.get_actions(..., evaluate=False)``. Every transition is pushed to
    ``agent.replay_buffer``. Once ``total_steps >= start_steps`` and the buffer
    holds at least ``batch_size`` transitions, ``agent.update()`` is called
    every ``update_every`` steps.

    This function relies on the module-level ``summary_writer`` and ``log_dir``
    (they are not parameters and are not created here).

    Args:
        env_name: Environment name forwarded to ``make_env``.
        hidden_units, actor_lr, critic_lr, tau, gamma, buffer_capacity,
        batch_size, initial_alpha, initial_beta, target_entropy_scale,
        target_ce_scale, gradient_clip_norm: Forwarded to ``MASACAgentV2``
            (``batch_size`` also gates when updates may start). The startup
            banner reports alpha and beta as fixed.
        num_episodes: Number of episodes to run.
        max_steps_per_episode: Passed to ``make_env`` as ``max_cycles`` and also
            enforced locally as a hard per-episode step cap.
        start_steps: Number of initial steps that use sampled actions.
        update_every: Environment steps between calls to ``agent.update()``.
        log_interval: TensorBoard scalars are written every ``log_interval``
            successful updates.

    Returns:
        A 6-tuple ``(episode_rewards_history, placeholder, placeholder,
        total_training_time, agent, log_dir)``. ``episode_rewards_history`` is a
        ``deque(maxlen=100)`` of per-episode rewards summed over agents; the two
        placeholders are empty ``deque(maxlen=100)`` objects kept so callers
        that unpack six values continue to work.
    """
    env = make_env(env_name=env_name, continuous_actions=True, max_cycles=max_steps_per_episode)
    # try/finally guarantees the environment is released even when something
    # escapes this function (e.g. KeyboardInterrupt from interrupting the cell,
    # or a failure while constructing the agent).
    try:
        agent_ids = env.possible_agents
        obs_spaces = env.observation_spaces
        action_spaces = env.action_spaces
        obs_dims = {agent_id: obs_spaces[agent_id].shape[0] for agent_id in agent_ids}
        action_dims = {agent_id: action_spaces[agent_id].shape[0] for agent_id in agent_ids}
        agent = MASACAgentV2(env, agent_ids, obs_dims, action_dims, action_spaces,
                             hidden_units, actor_lr, critic_lr,
                             tau, gamma, buffer_capacity, batch_size,
                             initial_alpha, initial_beta, target_entropy_scale, target_ce_scale,
                             gradient_clip_norm=gradient_clip_norm)
        episode_rewards_history = deque(maxlen=100)
        total_steps = 0
        updates_performed = 0
        start_time = time.time()
        print(f"Starting training (V2 Objective) for {num_episodes} episodes...")
        print(f"Settings: batch_size={batch_size}, start_steps={start_steps}, update_every={update_every}")
        print(f"LRs: actor={actor_lr}, critic={critic_lr}")
        print(f"Fixed alpha={initial_alpha}, beta={initial_beta}, tau={tau}, gamma={gamma}, grad_clip={gradient_clip_norm}")
        for episode in range(num_episodes):
            try:
                observations, _ = env.reset()
                episode_reward_sum = 0
                terminations = {agent_id: False for agent_id in agent_ids}
                truncations = {agent_id: False for agent_id in agent_ids}
                steps_in_ep = 0
                while not (any(terminations.values()) or any(truncations.values())):
                    if total_steps < start_steps:
                        # Warm-up phase: fill the buffer with sampled actions.
                        actions = {agent_id: action_spaces[agent_id].sample() for agent_id in agent.agent_ids}
                    else:
                        valid_obs = {k: v for k, v in observations.items() if k in agent.agent_ids and v is not None}
                        actions = agent.get_actions(valid_obs, evaluate=False)
                    next_observations, rewards, terminations, truncations, _ = env.step(actions)
                    # Only genuine terminations are stored as `done`. A time-limit
                    # truncation ends the rollout but the state is not terminal, so
                    # marking it done would cut off value bootstrapping there.
                    # NOTE(review): presumably agent.update() masks the bootstrap
                    # term with this flag - confirm against the update code.
                    dones_dict = {agent_id: bool(terminations.get(agent_id, False)) for agent_id in agent.agent_ids}
                    reward_dict_ordered = {agent_id: rewards.get(agent_id, 0.0) for agent_id in agent.agent_ids}
                    action_dict_ordered = {agent_id: actions.get(agent_id, np.zeros(agent.action_dims[agent_id])) for agent_id in agent.agent_ids}
                    obs_dict_ordered = {agent_id: observations.get(agent_id) for agent_id in agent.agent_ids}
                    next_obs_dict_ordered = {agent_id: next_observations.get(agent_id) for agent_id in agent.agent_ids}
                    agent.replay_buffer.push(obs_dict_ordered, action_dict_ordered, reward_dict_ordered, next_obs_dict_ordered, dones_dict)
                    observations = next_observations
                    episode_reward_sum += sum(rewards.values())
                    steps_in_ep += 1
                    total_steps += 1
                    ready_to_update = (
                        total_steps >= start_steps
                        and total_steps % update_every == 0
                        and len(agent.replay_buffer) >= batch_size
                    )
                    if ready_to_update:
                        update_result = agent.update()
                        if update_result is not None:
                            updates_performed += 1
                            if updates_performed % log_interval == 0:
                                _log_update_scalars(update_result, episode_rewards_history, total_steps)
                    # Local hard cap, independent of the environment's own truncation.
                    if steps_in_ep >= max_steps_per_episode:
                        truncations = {agent_id: True for agent_id in agent.agent_ids}
                        break
                episode_rewards_history.append(episode_reward_sum)
                avg_reward_100 = np.mean(episode_rewards_history) if episode_rewards_history else 0.0
                if (episode + 1) % 10 == 0:
                    elapsed_time = time.time() - start_time
                    print(f"Ep {episode+1}: Steps={steps_in_ep}, Avg R (100)={avg_reward_100:.2f}, "
                          f"Buffer={len(agent.replay_buffer)}, Tot Steps={total_steps}, Time={elapsed_time:.1f}s")
            except Exception as e:
                # Best effort: report the failed episode and move on to the next one.
                print(f"\n!!! Error during episode {episode+1}: {e} !!!")
                import traceback
                traceback.print_exc()
                continue
        total_training_time = time.time() - start_time
        print(f"\nTraining finished after {num_episodes} episodes in {total_training_time:.2f} seconds.")
        final_avg_reward = np.mean(episode_rewards_history) if episode_rewards_history else 0.0
        print(f"Final Avg Reward (last 100 episodes): {final_avg_reward:.2f}")
    finally:
        env.close()
    # The two empty deques are placeholders that keep the 6-tuple shape callers unpack.
    return episode_rewards_history, deque(maxlen=100), deque(maxlen=100), total_training_time, agent, log_dir


def _log_update_scalars(update_result, episode_rewards_history, step):
    """Writes losses and the rolling average reward to the module-level ``summary_writer``."""
    critic_loss, actor_losses, _, _ = update_result
    with summary_writer.as_default(step=step):
        tf.summary.scalar('Loss/Critic_Loss', critic_loss)
        if isinstance(actor_losses, dict):
            for ag_id, loss in actor_losses.items():
                # Losses may be tensors or plain numbers; only scalars are logged.
                loss_val = loss.numpy() if hasattr(loss, 'numpy') else loss
                if np.isscalar(loss_val):
                    tf.summary.scalar(f'Loss/Actor_{ag_id}', loss_val)
        if episode_rewards_history:
            avg_rew = np.mean(episode_rewards_history)
            tf.summary.scalar('Reward/Avg_Episode_Reward_100', avg_rew)


if __name__ == "__main__":
    print(f"TensorFlow Version: {tf.__version__}")
    # --- Hyperparameters ---
    # This guard runs at top level, so these become module-level names.
    ENV_NAME = "simple_spread_v3"
    HIDDEN_UNITS = 128
    ACTOR_LR = 3e-4
    CRITIC_LR = 1e-3
    TAU = 0.01
    GAMMA = 0.99
    BUFFER_CAPACITY = 200000
    BATCH_SIZE = 512
    NUM_EPISODES = 5000
    MAX_STEPS_PER_EPISODE = 100
    INITIAL_ALPHA = 0.2
    INITIAL_BETA = 0.1
    TARGET_ENTROPY_SCALE = 1.0
    TARGET_CE_SCALE = 0.1
    START_STEPS = 2000
    UPDATE_EVERY = 1
    LOG_INTERVAL = 100
    GRADIENT_CLIP_NORM = 1.0
    # --- Start Training ---
    tf.get_logger().setLevel('ERROR')
    # Gather the keyword arguments once so the call site stays short.
    train_kwargs = dict(
        env_name=ENV_NAME,
        hidden_units=HIDDEN_UNITS,
        actor_lr=ACTOR_LR,
        critic_lr=CRITIC_LR,
        tau=TAU,
        gamma=GAMMA,
        buffer_capacity=BUFFER_CAPACITY,
        batch_size=BATCH_SIZE,
        num_episodes=NUM_EPISODES,
        max_steps_per_episode=MAX_STEPS_PER_EPISODE,
        initial_alpha=INITIAL_ALPHA,
        initial_beta=INITIAL_BETA,
        target_entropy_scale=TARGET_ENTROPY_SCALE,
        target_ce_scale=TARGET_CE_SCALE,
        start_steps=START_STEPS,
        update_every=UPDATE_EVERY,
        log_interval=LOG_INTERVAL,
        gradient_clip_norm=GRADIENT_CLIP_NORM,
    )
    results = None
    try:
        results = train_masac_v2(**train_kwargs)
    except Exception as main_exception:
        # Report the failure but keep the cell alive; `results` stays None.
        print(f"\n!!! Critical Error during training: {main_exception} !!!")
        import traceback
        traceback.print_exc()
    if not results:
        print("Training did not complete successfully or was interrupted.")
    else:
        # Slots 2 and 3 of the returned tuple are unused here.
        episode_rewards_hist, _, _, training_time, trained_agent, save_log_dir = results
        print("Training run completed.")
        trained_agent.save_weights(os.path.join(save_log_dir, "final_model"))

GPU is available and will be used.
Enabled GPU memory growth.
TensorBoard logs will be saved to: masac_v2_tensorflow_logs_simple_spread_fixed_alpha_beta
TensorFlow Version: 2.18.0




Starting training (V2 Objective) for 5000 episodes...
Settings: batch_size=512, start_steps=2000, update_every=1
LRs: actor=0.0003, critic=0.001
Fixed alpha=0.2, beta=0.1, tau=0.01, gamma=0.99, grad_clip=1.0
Ep 10: Steps=100, Avg R (100)=-282.89, Buffer=1000, Tot Steps=1000, Time=0.9s
Ep 20: Steps=100, Avg R (100)=-299.63, Buffer=2000, Tot Steps=2000, Time=11.6s
Ep 30: Steps=100, Avg R (100)=-622.04, Buffer=3000, Tot Steps=3000, Time=35.1s
Ep 40: Steps=100, Avg R (100)=-664.35, Buffer=4000, Tot Steps=4000, Time=57.9s
Ep 50: Steps=100, Avg R (100)=-660.60, Buffer=5000, Tot Steps=5000, Time=80.8s
Ep 60: Steps=100, Avg R (100)=-656.95, Buffer=6000, Tot Steps=6000, Time=104.2s
Ep 70: Steps=100, Avg R (100)=-663.90, Buffer=7000, Tot Steps=7000, Time=127.3s
Ep 80: Steps=100, Avg R (100)=-672.27, Buffer=8000, Tot Steps=8000, Time=150.2s
Ep 90: Steps=100, Avg R (100)=-670.02, Buffer=9000, Tot Steps=9000, Time=173.4s
Ep 100: Steps=100, Avg R (100)=-656.37, Buffer=10000, Tot Steps=10000, Time=19