In [1]:
import json

import numpy as np
import pandas as pd

import pyinfusion
from pyinfusion import TransitionDynamic

#import tensorflow as tf
#import torch
#from collections import deque
#import random

#import QPLEXBuild
#import SingleAgentRL

## Preparing the Data by Action (No Repair vs. Repair)

In [2]:
def prepare_data(infusion_data, maintenance_data, num_actions = 2):
    """Split the infusion and maintenance logs by maintenance action (repair / no repair).

    A maintenance row is labelled as a repair (``repair == 1``) when its ``PartCost`` is
    not null, otherwise ``repair == 0``. Each work order (``WO_WO#``) takes the label of its
    first maintenance row, and that label is left-joined onto the infusion log.

    Parameters
    ----------
    infusion_data : pd.DataFrame
        Infusion log; must contain a ``WO_WO#`` column.
    maintenance_data : pd.DataFrame
        Maintenance log; must contain ``WO_WO#`` and ``PartCost``. It is NOT modified.
    num_actions : int, default 2
        Action labels ``0 .. num_actions - 1`` to split on.

    Returns
    -------
    dict[int, dict[str, pd.DataFrame]]
        ``{action: {'infusion_log': ..., 'maintenance_log': ...}}``. Infusion rows whose
        work order has no maintenance record get a NaN label and appear under no action.
    """
    # Label a copy so the caller's frame (often a filtered slice) is not silently mutated.
    maintenance_data = maintenance_data.assign(
        repair=np.where(maintenance_data.PartCost.notnull(), 1, 0))

    # One repair label per work order.
    # NOTE(review): 'first' uses only the first maintenance row, while the maintenance split
    # below is per row, so a work order with mixed rows can land under different actions in
    # the two logs. Confirm this is intended.
    wo_repair = maintenance_data.groupby('WO_WO#').agg({'repair': 'first'}).reset_index()

    # Left join: infusion rows without a matching work order get repair = NaN.
    infusion_data = infusion_data.merge(wo_repair, how='left', on='WO_WO#')

    return {
        action: {'infusion_log': infusion_data[infusion_data.repair == action],
                 'maintenance_log': maintenance_data[maintenance_data.repair == action]}
        for action in range(num_actions)
    }

# Single Agent

## Computing Transition Matrix - Single Agent

In [3]:
def action_transition_matrix(prepped_failure_info, num_actions, states, col_name='PCUSerialNumber'):
    """Compute the DTMC transition information separately for every action.

    For each action ``a`` in ``0 .. num_actions - 1``, a ``TransitionDynamic`` is built from
    ``prepped_failure_info[a]['infusion_log']`` and ``prepped_failure_info[a]['maintenance_log']``.
    Its ``system_environment(group_assets=True)`` result is stored under key ``a``.

    Returns
    -------
    dict
        ``{action: system_environment result}``. Later cells index that result by asset id
        and then ``'dtmc_matrix_pandas'`` / ``'dtmc_matrix'``.
    """
    def _environment_for(action):
        logs = prepped_failure_info[action]
        dynamic = TransitionDynamic(infusion_log=logs['infusion_log'],
                                    maintenance_log=logs['maintenance_log'],
                                    col_name=col_name,
                                    states=states)
        return dynamic.system_environment(group_assets=True)

    return {action: _environment_for(action) for action in range(num_actions)}

# Multi Agent 

## Compute Transition Matrix - Multi Agent

In [4]:
def select_agent(transition_matrix_, agents_id, num_actions):
    """Collect each agent's per-action transition matrix as a JSON string.

    Parameters
    ----------
    transition_matrix_ : dict
        ``{action: {agent_id: {'dtmc_matrix_pandas': pd.DataFrame, ...}}}``, as produced by
        ``action_transition_matrix``.
    agents_id : iterable
        Agent (PCU serial) ids to extract, in output order.
    num_actions : int
        Actions ``0 .. num_actions - 1`` are read for every agent.

    Returns
    -------
    list[dict[int, str]]
        One dict per agent mapping action -> matrix serialized with
        ``DataFrame.to_json(orient='records')``.

    Raises
    ------
    KeyError
        If an agent has no matrix for one of the actions.
    """
    agents_trans_env = []  # one {action: json_matrix} dict per agent
    for agent in agents_id:
        agents_trans_env.append({
            action: transition_matrix_[action][agent]['dtmc_matrix_pandas'].to_json(orient='records')
            for action in range(num_actions)
        })
    return agents_trans_env

## Compute Cost - Multi Agent
- Assume a fixed cost for each transition state, based on the generalized cost information (the same cost table is reused for every agent)

In [5]:
def cost_multi_agent(costs_single, num_agents):
    """Replicate a single-agent cost table for ``num_agents`` agents.

    Each agent gets an independent deep copy, so editing one agent's costs cannot
    change another agent's (or the original) table.

    Parameters
    ----------
    costs_single : dict
        Cost table ``{action: {state_index: cost}}`` shared by all agents.
    num_agents : int
        Number of copies to return.

    Returns
    -------
    list[dict]
        ``num_agents`` equal but independent copies of ``costs_single``.
    """
    import copy  # local import: `copy` is not in the notebook's import cell

    return [copy.deepcopy(costs_single) for _ in range(num_agents)]

# Obtain Agent Information

In [6]:
# Load the PCU failure information (infusion activity linked to work orders).
# NOTE(review): machine-specific absolute path; consider a configurable data directory.
pcu_failure_info = pd.read_csv(
    '/Users/mobolajishobanke/Desktop/Fall Research/NN_Class_Project/pcu_failure_information.csv'
)

# Preview the first two rows.
pcu_failure_info.head(n=2)

Unnamed: 0,PCUSerialNumber,WO_Requested,WO_WO#,WO_Type,ActiveStartTime,ActiveStopTime,TotalInfusionTime,TotalEqActiveTime
0,12828160,2020-08-19,682305,CEIN,2020-01-02 18:14:58,2020-08-15 09:07:06,40707750,52291662
1,12828160,2021-08-19,742583,CEIN,2020-08-18 22:37:09,2020-12-31 20:39:32,27605763,32650026


In [7]:
# Load the full 2005-2022 maintenance history.
# NOTE(review): machine-specific absolute path; consider a configurable data directory.
maintenance_data = pd.read_excel(
    '/Users/mobolajishobanke/Desktop/Fall Research/NN_Class_Project/maintenance_data_2005_2022.xlsx'
)

# Keep only the work-order types modelled as states: CECM, CEIN and HZARC.
maintenance_data = maintenance_data.loc[maintenance_data['WO_Type'].isin(['CECM', 'CEIN', 'HZARC'])]

maintenance_data.columns

Index(['Asset_Status', 'Asset_AssetID', 'Asset_Serial', 'Asset_AssetPK',
       'Asset_Classification', 'Asset_Model', 'Asset_Manufacturer',
       'Asset_InstallDate', 'WO_WO#', 'WO_Requested', 'WO_Closed', 'WO_Type',
       'WO_Type_Desc', 'WO_Substatus', 'WO_Problem', 'WO_Failure',
       'WO_Solution', 'PartID', 'PartName', 'PartCost', 'WO_Reason',
       'WO_LaborReport'],
      dtype='object')

In [8]:
# Split both logs by action label (0 = no repair, 1 = repair).
failure_info = prepare_data(
    infusion_data=pcu_failure_info,
    maintenance_data=maintenance_data,
)

# Expect one entry per action, each holding an infusion log and a maintenance log.
print(failure_info.keys())
print(failure_info[0].keys())

dict_keys([0, 1])
dict_keys(['infusion_log', 'maintenance_log'])


In [9]:
# Mean part cost per work-order type (mean() skips NaN part costs).
maint_costs = maintenance_data.groupby('WO_Type')[['PartCost']].mean()
maint_costs

Unnamed: 0_level_0,PartCost
WO_Type,Unnamed: 1_level_1
CECM,60.731572
CEIN,28.667771
HZARC,18.190441


In [10]:
# State space: the running state first, then each maintenance WO type in order of appearance.
states_ = ['infusing', *maintenance_data['WO_Type'].unique().tolist()]
states_

['infusing', 'CEIN', 'CECM', 'HZARC']

In [11]:
# Cost of each state under each action: costs[action][state_index], with state order
# following `states_` (0 = infusing, 1 = CEIN, 2 = CECM, 3 = HZARC).
# Action 1 (repair) uses the rounded mean PartCost per WO type from `maint_costs`.
# Action 0 (no repair) is free while infusing and for CEIN, and adds 100 to the CECM/HZARC
# repair costs: not repairing a system that broke down is assumed to be more expensive.
costs = dict([
    (0, {0: 0.0, 1: 0, 2: 160.73, 3: 118.19}),
    (1, {0: 0.0, 1: 28.67, 2: 60.73, 3: 18.19}),
])

In [12]:
# Per-action DTMC transition information for every PCU.
transition_matrices_all_agents = action_transition_matrix(
    prepped_failure_info=failure_info,
    num_actions=2,
    states=states_,
)

Transtion Matrix Computation: 100%|██████████| 1016/1016 [00:02<00:00, 376.38it/s]
Transtion Matrix Computation: 100%|██████████| 1381/1381 [00:03<00:00, 435.13it/s]


### Selecting a Sample Agent for the Value Iteration and DQN Implementations
 - use the agent with the most occurrences in the infusion data

In [13]:
# Rank PCUs by their number of infusion records and keep the five most frequent serials.
occurrence_count = (
    pcu_failure_info
    .groupby('PCUSerialNumber')
    .agg({'ActiveStartTime': 'count'})
    .sort_values(by='ActiveStartTime', ascending=False)
    .reset_index()
)
agents_serial = occurrence_count['PCUSerialNumber'].tolist()[:5]
occurrence_count.head(5)

Unnamed: 0,PCUSerialNumber,ActiveStartTime
0,12992579,7
1,13923991,6
2,13923356,6
3,14154795,5
4,14157251,5


In [14]:
# view agents serial: the five PCU serial numbers with the most infusion records,
# most frequent first (computed in the previous cell)
agents_serial

[12992579, 13923991, 13923356, 14154795, 14157251]

### Obtain Transition Matrix and Cost of Selected Agent

In [15]:
# Single-agent case: only the PCU with the most infusion records.
chosen_agent = agents_serial[:1]
num_actions = 2

agent_transition_matrix = select_agent(
    transition_matrix_=transition_matrices_all_agents,
    agents_id=chosen_agent,
    num_actions=num_actions,
)

# Show every chosen agent's per-action transition matrix (JSON records).
for agent in range(len(chosen_agent)):
    for action in range(num_actions):
        print(f'action: {action}:\n{agent_transition_matrix[agent][action]} \n')
    

action: 0:
[{"0.0":0.0,"1.0":0.3333333333,"2.0":0.6666666667},{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":0.6666666667,"1.0":0.0,"2.0":0.3333333333}] 

action: 1:
[{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":1.0,"1.0":0.0,"2.0":0.0},{"0.0":0.2,"1.0":0.4,"2.0":0.4}] 



### Store Agent Interaction Information in a JSON File for Future Import

In [16]:
# Decode the single agent's action-0 matrix from its JSON records back into a DataFrame.
pd.DataFrame(json.loads(agent_transition_matrix[0][0]))

Unnamed: 0,0.0,1.0,2.0
0,0.0,0.333333,0.666667
1,0.0,0.0,1.0
2,0.666667,0.0,0.333333


In [17]:
# Serialize (per-action transition matrices, cost table) of the single agent as JSON
# so later notebooks can reload them.
agent_information = json.dumps((agent_transition_matrix, costs))

with open('single_agent_historical_information.json', mode='w') as file:
    file.write(agent_information)

### Selecting Multiple Agents

In [17]:
# Multi-agent case: the five PCUs with the most infusion records.
chosen_agent = agents_serial[:5]
num_actions = 2

agent_transition_matrix = select_agent(
    transition_matrix_=transition_matrices_all_agents,
    agents_id=chosen_agent,
    num_actions=num_actions,
)

# Print every agent's per-action transition matrix (JSON records).
for agent in range(len(chosen_agent)):
    print('Agent ID: {}'.format(chosen_agent[agent]))
    for action in range(num_actions):
        print(f'action: {action}:\n{agent_transition_matrix[agent][action]} \n')
    

Agent ID: 12992579
action: 0:
[{"0.0":0.0,"1.0":0.3333333333,"2.0":0.6666666667},{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":0.6666666667,"1.0":0.0,"2.0":0.3333333333}] 

action: 1:
[{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":1.0,"1.0":0.0,"2.0":0.0},{"0.0":0.2,"1.0":0.4,"2.0":0.4}] 

Agent ID: 13923991
action: 0:
[{"0.0":0.0,"1.0":0.25,"2.0":0.75},{"0.0":1.0,"1.0":0.0,"2.0":0.0},{"0.0":0.4,"1.0":0.2,"2.0":0.4}] 

action: 1:
[{"0.0":0.0,"1.0":0.5,"2.0":0.5},{"0.0":1.0,"1.0":0.0,"2.0":0.0},{"0.0":0.0,"1.0":1.0,"2.0":0.0}] 

Agent ID: 13923356
action: 0:
[{"0.0":0.0,"2.0":1.0},{"0.0":1.0,"2.0":0.0}] 

action: 1:
[{"0.0":0.0,"1.0":0.5,"2.0":0.5},{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":0.6,"1.0":0.2,"2.0":0.2}] 

Agent ID: 14154795
action: 0:
[{"0.0":0.0,"1.0":0.3333333333,"2.0":0.6666666667},{"0.0":0.0,"1.0":0.0,"2.0":1.0},{"0.0":0.6666666667,"1.0":0.0,"2.0":0.3333333333}] 

action: 1:
[{"0.0":0.0,"1.0":0.5,"2.0":0.5},{"0.0":1.0,"1.0":0.0,"2.0":0.0},{"0.0":0.0,"1.0":1.0,"2.0":0.0}] 

Agent ID: 1415

In [18]:
# One cost table per selected agent (every agent uses the same costs).
multiagent_cost = cost_multi_agent(costs_single=costs, num_agents=len(chosen_agent))

In [30]:
# Keep agents 0, 1 and 3 of the five candidates, together with their cost tables.
idx_agent_idx = [0, 1, 3]

selected_multiagent_choice = [agent_transition_matrix[k] for k in idx_agent_idx]
costs_selected = [multiagent_cost[k] for k in idx_agent_idx]

In [29]:
# Inspect agent 1's action-1 (repair) transition matrix decoded from its JSON records.
json.loads(agent_transition_matrix[1][1])

[{'0.0': 0.0, '1.0': 0.5, '2.0': 0.5},
 {'0.0': 1.0, '1.0': 0.0, '2.0': 0.0},
 {'0.0': 0.0, '1.0': 1.0, '2.0': 0.0}]

In [20]:
# Cost table of agent 0 under action 1 (repair).
multiagent_cost[0][1]

{0: 0.0, 1: 28.67, 2: 60.73, 3: 18.19}

In [31]:
 # store information
agent_information = (selected_multiagent_choice, costs_selected)

# serializing infrmation for json
agent_information = json.dumps(agent_information)

with open('multi_agent_3___new.json', 'w') as file:
    file.write(agent_information)

In [None]:
# change pandas dataframe to json file
# NOTE(review): this cell holds only a comment; the DataFrame -> JSON conversion happens
# inside `select_agent` via `to_json(orient='records')`.


In [None]:
# Bundle the chosen agent ids, their per-action transition matrices and the cost table.
# NOTE(review): uses whatever `chosen_agent` / `agent_transition_matrix` hold at run time.
# After the multi-agent cells above these are the five-agent versions, not the single agent.
agent_information = (chosen_agent, agent_transition_matrix, costs)

In [None]:
# save the agent information as a json file
# (JSON converts the integer dict keys of `costs` into strings.)
with open('single_agent_information.json', 'w') as file:
    json.dump(agent_information, file)

In [None]:
## maintenance_dat
# NOTE(review): repeats the earlier `maint_costs` cell verbatim.
maint_costs =  maintenance_data.groupby('WO_Type').agg({'PartCost':'mean'})
maint_costs

In [None]:
# NOTE(review): repeats the earlier `costs` table verbatim (action -> state index -> cost).
costs = {
    0:{0: 0.0, 1: 0, 2: 160.73, 3: 118.19},
    1: {0: 0.0, 1: 28.67, 2: 60.73, 3: 18.19}
}

In [None]:
def cost_compute(maint_cost, encoded_states, num_actions, states, no_repair_penalty=100.0, decimals=2):
    """Build the per-action, per-state cost table ``{action: {state_index: cost}}``.

    The rules reproduce the hand-written ``costs`` table in this notebook:

    * state index 0 (the running/'infusing' state) costs 0 under every action;
    * action 0 (no repair) on 'CEIN' (planned maintenance) costs 0;
    * action 0 on any other state costs its mean part cost plus ``no_repair_penalty``
      (not repairing a broken-down system is assumed to be more expensive);
    * any action >= 1 (repair) costs the state's mean part cost.

    Parameters
    ----------
    maint_cost : pd.DataFrame | pd.Series | Mapping
        Mean part cost per WO type. A DataFrame must have a ``PartCost`` column indexed by
        WO type (e.g. ``maintenance_data.groupby('WO_Type').agg({'PartCost': 'mean'})``).
        A Series or mapping is indexed directly by WO type.
    encoded_states : any
        Unused; kept for interface compatibility.
    num_actions : int
        Number of actions (``0 .. num_actions - 1``).
    states : sequence of str
        State names. Index 0 is the running state; the others must be keys of ``maint_cost``.
    no_repair_penalty : float, default 100.0
        Extra cost added to breakdown states under action 0.
    decimals : int, default 2
        Costs are rounded to this many decimals.

    Returns
    -------
    dict[int, dict[int, float]]

    Raises
    ------
    KeyError
        If a state that needs a part cost is missing from ``maint_cost``.
    """
    # Accept the grouped DataFrame directly by reducing it to its PartCost column.
    if hasattr(maint_cost, 'columns'):
        maint_cost = maint_cost['PartCost']

    costs = {}
    for action in range(num_actions):
        action_costs = {}
        for idx, state in enumerate(states):
            if idx == 0 or (action == 0 and state == 'CEIN'):
                action_costs[idx] = 0.0
                continue
            cost = float(maint_cost[state])
            if action == 0:
                cost += no_repair_penalty
            action_costs[idx] = round(cost, decimals)
        costs[action] = action_costs
    return costs

In [None]:
# Full state space: the running ('infusing') state followed by every maintenance WO type.
states_ = ['infusing'] + maintenance_data.WO_Type.unique().tolist()

In [None]:
# Per-action logs from `failure_info`: action 0 = no repair, action 1 = repair.
infusion_0, maint_0 = (failure_info[0][key] for key in ('infusion_log', 'maintenance_log'))
infusion_1, maint_1 = (failure_info[1][key] for key in ('infusion_log', 'maintenance_log'))

In [None]:
# Rebuild the state space: 'infusing' followed by the maintenance WO types.
wo_types = maintenance_data['WO_Type'].unique().tolist()
states_ = ['infusing', *wo_types]
states_

In [None]:
# Transition dynamics for each action (0 = no repair, 1 = repair), grouped by the
# 'PCUSerialNumber' column.
transition_dynamic_0 = TransitionDynamic(
    infusion_log=infusion_0, maintenance_log=maint_0, col_name='PCUSerialNumber', states=states_,
)
transition_matrix_0 = transition_dynamic_0.system_environment(group_assets=True)

transition_dynamic_1 = TransitionDynamic(
    infusion_log=infusion_1, maintenance_log=maint_1, col_name='PCUSerialNumber', states=states_,
)
transition_matrix_1 = transition_dynamic_1.system_environment(group_assets=True)

In [None]:
# PCUs ranked by number of infusion records (same ranking as `occurrence_count` above).
pcu_failure_info.groupby('PCUSerialNumber').agg({'ActiveStartTime': 'count'}).sort_values(by= 'ActiveStartTime', ascending = False)

## Transition Dynamic for Sample Agent/PCU = 12992579

In [None]:
## extracting information for sample pcu
# Keys available in the action-0 transition information of PCU 12992579.
transition_matrix_0[12992579].keys()

In [None]:
# Per-action DTMC matrices (pandas form) of the sample PCU: 0 = no repair, 1 = repair.
transition_environment = {
    0 : transition_matrix_0[12992579]['dtmc_matrix_pandas'],
    1: transition_matrix_1[12992579]['dtmc_matrix_pandas']
}

In [None]:
# Action-0 probabilities in the row labelled 2 (presumably the from-state; selected by label via .loc).
transition_environment[0].loc[2].values

In [None]:
## maintenance_dat
# NOTE(review): repeats the earlier `maint_costs` cell verbatim.
maint_costs =  maintenance_data.groupby('WO_Type').agg({'PartCost':'mean'})
maint_costs

In [None]:
# Cost vectors for the sample-PCU experiments, one entry per state in `states_`
# (index 0 = 'infusing').
# Action 1 (repair): rounded mean PartCost of the repair-labelled rows of each WO type.
c1 = [0]

for i in states_[1:]:
    cost_i = round(maint_1[maint_1.WO_Type == i].PartCost.mean(), 2)
    c1.append(cost_i)

# Action 0 (no repair): free while infusing and for CEIN (index 1). Every later breakdown
# state costs its repair cost plus a 100 penalty. This matches the `costs` table and gives
# c0 the same length as c1 (only c1[2] + 100 would leave HZARC without an action-0 cost).
c0 = [0, 0] + [round(c + 100, 2) for c in c1[2:]]

print('c0 : {}'.format(c0))
print('c1 : {}'.format(c1))

In [None]:
# Environment for value iteration / DQN: the sample PCU's per-action DTMC matrices
# (the 'dtmc_matrix' entries, not the pandas versions) and the per-action cost vectors.
transition_environment = {
    action: matrices[12992579]['dtmc_matrix']
    for action, matrices in enumerate((transition_matrix_0, transition_matrix_1))
}

cost = {0: c0, 1: c1}

## Value Determination for Sample Agent - Run With Pyinfusion

In [None]:
# Optimal maintenance policy for the sample PCU via pyinfusion's value iteration.
optimal_policy = pyinfusion.value_iteration(
    environment=transition_environment,
    cost=cost,
    n_actions=2,
    tol=1e-16,
    states=states_,
)

optimal_policy

In [None]:
# Display the environment passed to value iteration above.
transition_environment

## Testing Online DQN

In [None]:
# Online DQN on the same environment and cost vectors.
dqn = pyinfusion.OnlineDQN(environment = transition_environment,
                          cost = cost,
                          num_states = len(states_),
                          num_actions = 2)

In [None]:
dqn.train()

In [None]:
# Trained Q-network of the online DQN.
mod = dqn.q_model

In [None]:
# One-hot encoding of every state index (row k = state k), used as model input below.
n = len(states_)
data = np.eye(n)

In [None]:
data

In [None]:
# Greedy action per state: the action with the smallest predicted value.
# NOTE(review): argmin presumably because the targets are costs to minimise; confirm against pyinfusion.
np.argmin(mod.predict(data), axis=1)

In [None]:
# Offline DQN on the same environment, trained immediately.
xer = pyinfusion.OfflineDQN(environment = transition_environment,
                          cost = cost,
                          num_states = len(states_),
                          num_actions = 2)

xer.train()

In [None]:
# Greedy (argmin) action per one-hot state from the offline DQN's predictor.
mod = xer.action_predictor
np.argmin(mod.predict(data), axis=1)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# Agent DQN for Q-value estimation
class AgentDQN(tf.keras.Model):
    """Per-agent Q-network mapping a local state to one Q-value per action.

    Architecture: Dense(64, relu) -> Dense(64, relu) -> Dense(action_dim) (linear).
    ``state_dim`` is accepted but not used by this class.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 64
        self.fc1 = layers.Dense(hidden_units, activation='relu')
        self.fc2 = layers.Dense(hidden_units, activation='relu')
        self.q_out = layers.Dense(action_dim)  # linear head: one Q-value per action

    def call(self, state):
        return self.q_out(self.fc2(self.fc1(state)))

# Transformation Network with Softplus activation to ensure positive weight outputs
class TransformationNetwork(tf.keras.Model):
    """State-conditioned positive affine transform ``w(s) * x + b(s)``.

    Two ReLU layers encode the global state ``s``. A softplus head gives a strictly positive
    weight ``w(s)`` and a linear head gives the bias ``b(s)``. Because ``w(s) > 0``, the output
    is monotonically increasing in the local value ``x``.
    """

    def __init__(self, hidden_dim=64):
        super(TransformationNetwork, self).__init__()
        self.fc1 = layers.Dense(hidden_dim, activation='relu')
        self.fc2 = layers.Dense(hidden_dim, activation='relu')
        self.w = layers.Dense(1, activation='softplus')  # softplus keeps the weight > 0
        self.b = layers.Dense(1)

    def call(self, global_state, local_value):
        """Return ``w(global_state) * local_value + b(global_state)`` with shape ``(batch, 1)``.

        ``local_value`` may be ``(batch,)`` or ``(batch, 1)``. It is reshaped to a column so the
        product with the ``(batch, 1)`` weight is element-wise instead of broadcasting to
        ``(batch, batch)``.
        """
        x = self.fc2(self.fc1(global_state))
        weight = self.w(x)
        bias = self.b(x)
        local_value = tf.reshape(tf.cast(local_value, weight.dtype), (-1, 1))
        return weight * local_value + bias

# QPLEX Model
class QPLEX(tf.keras.Model):
    """QPLEX-style mixer that combines per-agent Q-values into a joint Q-value.

    For agent ``i`` with chosen action ``a_i``, the local Q-values are split dueling-style into
    ``V_i = max_a Q_i(a)`` and ``A_i = Q_i(a_i) - V_i`` (so ``A_i <= 0``). Both go through
    per-agent transformation networks conditioned on ``global_state`` and are combined as
    ``Q_tot = sum_i [ V'_i + lambda_i(global_state) * A'_i ]``, with softplus (positive)
    ``lambda`` weights.
    """

    def __init__(self, state_dim, action_dim, num_agents, hidden_dim=64):
        super(QPLEX, self).__init__()
        self.num_agents = num_agents
        self.agents = [AgentDQN(state_dim, action_dim) for _ in range(num_agents)]
        self.value_transformers = [TransformationNetwork(hidden_dim) for _ in range(num_agents)]
        self.advantage_transformers = [TransformationNetwork(hidden_dim) for _ in range(num_agents)]
        self.lambda_weights = layers.Dense(num_agents, activation='softplus')  # Positive weights

    def call(self, states, global_state, actions):
        """Return the joint Q-value of the given joint action, shape ``(batch, 1)``.

        Parameters
        ----------
        states : tensor
            Per-agent observations; agent ``i`` reads ``states[:, i]``. Assumed to be
            ``(batch, num_agents, state_dim)``; TODO confirm against the replay-buffer layout.
        global_state : tensor
            ``(batch, global_dim)`` input to the transformation and lambda networks.
        actions : integer tensor
            Chosen action per agent, assumed ``(batch, num_agents)``.
        """
        # (batch, num_agents, action_dim)
        q_values = tf.stack([self.agents[i](states[:, i]) for i in range(self.num_agents)], axis=1)

        # Q_i(a_i) for each agent separately, via a one-hot mask over the action axis.
        action_mask = tf.one_hot(tf.cast(actions, tf.int32), depth=q_values.shape[-1], dtype=q_values.dtype)
        chosen_q = tf.reduce_sum(q_values * action_mask, axis=-1)  # (batch, num_agents)

        # Dueling split per agent: V_i = max_a Q_i(a), A_i = Q_i(a_i) - V_i <= 0.
        values = tf.reduce_max(q_values, axis=-1)  # (batch, num_agents)
        advantages = chosen_q - values  # (batch, num_agents)

        # Slice with i:i + 1 so each transformer receives a (batch, 1) column.
        transformed_values = tf.stack(
            [self.value_transformers[i](global_state, values[:, i:i + 1]) for i in range(self.num_agents)],
            axis=1)  # (batch, num_agents, 1)
        transformed_advantages = tf.stack(
            [self.advantage_transformers[i](global_state, advantages[:, i:i + 1]) for i in range(self.num_agents)],
            axis=1)  # (batch, num_agents, 1)

        lambda_weights = tf.expand_dims(self.lambda_weights(global_state), axis=-1)  # (batch, num_agents, 1)
        joint_q_value = tf.reduce_sum(transformed_values + lambda_weights * transformed_advantages, axis=1)
        return joint_q_value  # Shape: (batch_size, 1)

In [None]:
# Initialize environment and replay buffers
# NOTE(review): `MultiAgentEnvironment` and `ReplayBuffer` are not defined or imported in this
# notebook (the `QPLEXBuild` import at the top is commented out), so this cell fails on a
# fresh kernel. No random seed is set either.
num_agents = 2
state_dim = 4
action_dim = 3
# Random (3, 3, 3) array per agent. NOTE(review): rows are not normalised, so these are not
# probability distributions if the environment expects row-stochastic matrices; confirm.
transition_matrices = [np.random.rand(3, 3, 3) for _ in range(num_agents)]  # Random transition matrices

env = MultiAgentEnvironment(num_agents=num_agents, transition_matrices=transition_matrices)
# One independent replay buffer per agent.
buffers = [ReplayBuffer(max_size=1000) for _ in range(num_agents)]
qplex_model = QPLEX(state_dim=state_dim, action_dim=action_dim, num_agents=num_agents)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Training function
def train_step(states, actions, rewards, next_states, global_state, gamma=0.99):
    """Run one optimisation step of the module-level ``qplex_model`` with ``optimizer``.

    Minimises the mean squared TD error between ``Q_tot(states, actions)`` and the target
    ``rewards + gamma * Q_tot(next_states, actions)``. The target is wrapped in
    ``tf.stop_gradient`` so it is treated as a constant (semi-gradient TD); without it the
    optimizer also pulls the target toward the prediction.

    NOTE(review):
    - The target re-uses ``actions`` for ``next_states`` instead of the greedy next actions.
    - There is no terminal mask or target network.
    - ``rewards`` should broadcast against ``(batch, 1)``.
    Confirm all three.

    Returns
    -------
    Scalar loss tensor.
    """
    with tf.GradientTape() as tape:
        current_q_values = qplex_model(states, global_state, actions)
        next_q_values = qplex_model(next_states, global_state, actions)
        target_q_values = tf.stop_gradient(tf.cast(rewards, next_q_values.dtype) + gamma * next_q_values)
        loss = tf.reduce_mean(tf.square(target_q_values - current_q_values))

    grads = tape.gradient(loss, qplex_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, qplex_model.trainable_variables))
    return loss

# Main training loop
for episode in range(100):
    states = env.reset()
    for t in range(50):
        actions = [np.random.randint(action_dim) for _ in range(num_agents)]
        next_states, rewards = env.step(actions)
        
        for i in range(num_agents):
            buffers[i].add((states[i], actions[i], rewards[i], next_states[i]))
        
        # Sample from buffer and train
        if len(buffers[0]) >= 32:
            batch_states = np.array([buffer.sample(32)[0] for buffer in buffers])
            batch_actions = np.array([buffer.sample(32)[1] for buffer in buffers])
            batch_rewards = np.array([buffer.sample(32)[2] for buffer in buffers])
            batch_next_states = np.array([buffer.sample(32)[3] for buffer in buffers])
            global_state = np.random.random((32, state_dim))  # Simplified global state

            loss = train_step(batch_states, batch_actions, batch_rewards, batch_next_states, global_state)
            print(f"Episode {episode}, Step {t}, Loss: {loss.numpy()}")

        states = next_states

In [None]:
# Initialize the loss function
mse_loss = tf.keras.losses.MeanSquaredError()


def train_step(states, actions, rewards, next_states, global_state, gamma=0.99):
    """One optimisation step of ``qplex_model`` using Keras' built-in MSE loss.

    The target ``rewards + gamma * Q_tot(next_states, actions)`` is wrapped in
    ``tf.stop_gradient`` so only the current estimate is regressed toward it
    (semi-gradient TD). Uses the module-level ``qplex_model``, ``optimizer`` and ``mse_loss``.

    NOTE(review): the target re-uses ``actions`` for ``next_states`` instead of the greedy
    next actions, and there is no terminal mask or target network. Confirm this is intended.

    Returns
    -------
    Scalar loss tensor.
    """
    with tf.GradientTape() as tape:
        current_q_values = qplex_model(states, global_state, actions)
        next_q_values = qplex_model(next_states, global_state, actions)
        target_q_values = tf.stop_gradient(tf.cast(rewards, next_q_values.dtype) + gamma * next_q_values)
        loss = mse_loss(target_q_values, current_q_values)

    grads = tape.gradient(loss, qplex_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, qplex_model.trainable_variables))
    return loss

In [None]:
# Global state = all agents' observations concatenated along the first axis.
# NOTE(review): `agent_observations` is not defined anywhere in this notebook.
global_state = np.concatenate(list(agent_observations))

In [None]:
import numpy as np

# Epsilon-greedy schedule: start fully exploratory (1.0), decay multiplicatively by 0.995
# per episode, and never drop below the 0.01 floor.
initial_epsilon, min_epsilon, epsilon_decay = 1.0, 0.01, 0.995

# One independent epsilon per agent.
epsilon_values = [initial_epsilon] * num_agents

def select_action(agent_index, state, epsilon):
    """Pick an action for one agent with an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action in ``[0, action_dim)`` is returned.
    Otherwise the action with the largest Q-value from ``qplex_model.agents[agent_index]``
    evaluated on ``state`` is returned.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.randint(action_dim)
    batched_state = tf.convert_to_tensor([state], dtype=tf.float32)  # batch of one
    q_values = qplex_model.agents[agent_index](batched_state)
    return tf.argmax(q_values, axis=1).numpy()[0]

In [None]:
# Training loop with epsilon-greedy policy for exploration
# NOTE(review): `total_episodes`, `steps_per_episode` and `batch_size` are not defined in this
# notebook; set them before running.
for episode in range(total_episodes):
    # Reset environment and initialize rewards
    states = env.reset()
    episode_rewards = [0 for _ in range(num_agents)]

    for t in range(steps_per_episode):
        # Each agent picks its own action with its own epsilon.
        actions = [select_action(i, states[i], epsilon_values[i]) for i in range(num_agents)]

        # Take a step in the environment with the selected actions
        next_states, rewards = env.step(actions)

        # Track rewards and store experiences in the per-agent replay buffers.
        for i in range(num_agents):
            episode_rewards[i] += rewards[i]
            buffers[i].add((states[i], actions[i], rewards[i], next_states[i]))

        # Sample from buffer and train if buffer has enough samples
        if len(buffers[0]) >= batch_size:
            # One sample() per buffer; every field comes from that same draw so each
            # transition stays intact (one call per field mixed unrelated transitions).
            samples = [buffer.sample(batch_size) for buffer in buffers]
            batch_states = np.array([s[0] for s in samples])
            batch_actions = np.array([s[1] for s in samples])
            batch_rewards = np.array([s[2] for s in samples])
            batch_next_states = np.array([s[3] for s in samples])
            global_state = np.random.random((batch_size, state_dim))  # Simplified placeholder global state

            # Train the QPLEX model
            loss = train_step(batch_states, batch_actions, batch_rewards, batch_next_states, global_state)

        # Move to next state
        states = next_states

    # Decay epsilon for each agent after the episode
    for i in range(num_agents):
        epsilon_values[i] = max(min_epsilon, epsilon_values[i] * epsilon_decay)

    # Optionally, print the epsilon values and episode rewards
    if episode % 10 == 0:
        print(f"Episode {episode}, Epsilon: {epsilon_values}, Episode Rewards: {episode_rewards}")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

# Agent DQN for Q-value estimation
class AgentDQN(tf.keras.Model):
    """Per-agent Q-network: two 64-unit ReLU layers followed by a linear action head.

    NOTE(review): identical to the AgentDQN defined in an earlier cell; this redefinition
    shadows it. ``state_dim`` is not used by this class.
    """

    def __init__(self, state_dim, action_dim):
        super(AgentDQN, self).__init__()
        self.fc1 = layers.Dense(64, activation='relu')
        self.fc2 = layers.Dense(64, activation='relu')
        self.q_out = layers.Dense(action_dim)  # one Q-value per action

    def call(self, state):
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = layer(hidden)
        return self.q_out(hidden)

# Transformation Network with Softplus activation to ensure positive weights
class TransformationNetwork(tf.keras.Model):
    """State-conditioned positive affine transform ``w(s) * x + b(s)``.

    Two ReLU layers encode the global state ``s``. A softplus head gives a strictly positive
    weight ``w(s)`` and a linear head gives the bias ``b(s)``, so the output is monotonically
    increasing in the local value ``x``.
    """

    def __init__(self, hidden_dim=64):
        super(TransformationNetwork, self).__init__()
        self.fc1 = layers.Dense(hidden_dim, activation='relu')
        # `layers.Dense.Dense(...)` raised AttributeError; the layer class is `layers.Dense`.
        self.fc2 = layers.Dense(hidden_dim, activation='relu')
        self.w = layers.Dense(1, activation='softplus')  # Ensures positive output for weight
        self.b = layers.Dense(1)  # Bias term

    def call(self, global_state, local_value):
        """Return ``w(global_state) * local_value + b(global_state)`` with shape ``(batch, 1)``.

        ``local_value`` may be ``(batch,)`` or ``(batch, 1)``. It is reshaped to a column so the
        product with the ``(batch, 1)`` weight stays element-wise instead of broadcasting to
        ``(batch, batch)``.
        """
        x = self.fc2(self.fc1(global_state))
        weight = self.w(x)  # Softplus ensures weight is positive
        bias = self.b(x)
        local_value = tf.reshape(tf.cast(local_value, weight.dtype), (-1, 1))
        return weight * local_value + bias

# QPLEX Model with Global State Concatenation
class QPLEX(tf.keras.Model):
    """QPLEX-style mixer that builds the global state by concatenating the agents' states.

    Per agent ``i`` with chosen action ``a_i``: ``V_i = max_a Q_i(a)`` and
    ``A_i = Q_i(a_i) - V_i`` (``<= 0``). Both are transformed by state-conditioned positive
    affine networks and combined as ``Q_tot = sum_i [ V'_i + lambda_i(s) * A'_i ]``, with
    softplus (positive) ``lambda``.

    NOTE(review): ``call`` takes ``(states, actions)``, while ``train_step`` above calls the
    model as ``(states, global_state, actions)``. Align them before training with this class.
    """

    def __init__(self, state_dim, action_dim, num_agents, hidden_dim=64):
        super(QPLEX, self).__init__()
        self.num_agents = num_agents
        self.agents = [AgentDQN(state_dim, action_dim) for _ in range(num_agents)]
        self.value_transformers = [TransformationNetwork(hidden_dim) for _ in range(num_agents)]
        self.advantage_transformers = [TransformationNetwork(hidden_dim) for _ in range(num_agents)]
        self.lambda_weights = layers.Dense(num_agents, activation='softplus')  # Positive weights for monotonicity

    def call(self, states, actions):
        """Return the joint Q-value of the given joint action, shape ``(batch, 1)``.

        Parameters
        ----------
        states : tensor
            Agent ``i`` reads ``states[:, i]``. Assumed to be ``(batch, num_agents, state_dim)``;
            TODO confirm.
        actions : integer tensor
            Chosen action per agent, assumed ``(batch, num_agents)``.
        """
        # Concatenate all agents' states to form the global state: (batch, num_agents * state_dim)
        global_state = tf.concat([states[:, i] for i in range(self.num_agents)], axis=-1)

        # (batch, num_agents, action_dim)
        q_values = tf.stack([self.agents[i](states[:, i]) for i in range(self.num_agents)], axis=1)

        # Q_i(a_i) for each agent separately, via a one-hot mask over the action axis.
        action_mask = tf.one_hot(tf.cast(actions, tf.int32), depth=q_values.shape[-1], dtype=q_values.dtype)
        chosen_q = tf.reduce_sum(q_values * action_mask, axis=-1)  # (batch, num_agents)

        # Dueling split per agent.
        values = tf.reduce_max(q_values, axis=-1)  # (batch, num_agents)
        advantages = chosen_q - values  # (batch, num_agents), <= 0

        # Transform each agent's (batch, 1) value/advantage column using the global state.
        transformed_values = tf.stack(
            [self.value_transformers[i](global_state, values[:, i:i + 1]) for i in range(self.num_agents)],
            axis=1)  # (batch, num_agents, 1)
        transformed_advantages = tf.stack(
            [self.advantage_transformers[i](global_state, advantages[:, i:i + 1]) for i in range(self.num_agents)],
            axis=1)  # (batch, num_agents, 1)

        lambda_weights = tf.expand_dims(self.lambda_weights(global_state), axis=-1)  # (batch, num_agents, 1)
        joint_q_value = tf.reduce_sum(transformed_values + lambda_weights * transformed_advantages, axis=1)
        return joint_q_value  # Shape: (batch_size, 1)

In [None]:
# Training loop with epsilon-greedy policy and global state concatenation
# NOTE(review): `total_episodes`, `steps_per_episode` and `batch_size` are not defined in this
# notebook. `qplex_model` is still an instance of the QPLEX class created before the
# redefinition above, and `train_step` calls it as (states, global_state, actions).
for episode in range(total_episodes):
    # Reset environment and initialize rewards
    states = env.reset()
    episode_rewards = [0 for _ in range(num_agents)]

    for t in range(steps_per_episode):
        # Each agent picks its own action with its own epsilon.
        actions = [select_action(i, states[i], epsilon_values[i]) for i in range(num_agents)]

        # Take a step in the environment
        next_states, rewards = env.step(actions)

        # Store experiences in the replay buffer
        for i in range(num_agents):
            buffers[i].add((states[i], actions[i], rewards[i], next_states[i]))

        # Train the QPLEX model with the global state concatenation
        if len(buffers[0]) >= batch_size:
            # One sample() per buffer; every field comes from that same draw so each
            # transition stays intact (one call per field mixed unrelated transitions).
            samples = [buffer.sample(batch_size) for buffer in buffers]
            batch_states = np.array([s[0] for s in samples])
            batch_actions = np.array([s[1] for s in samples])
            batch_rewards = np.array([s[2] for s in samples])
            batch_next_states = np.array([s[3] for s in samples])

            # Concatenate batch of states for global state input.
            # NOTE(review): batch_states is stacked as (num_agents, batch, ...), so [:, i] slices
            # the batch axis, not the agent axis. Confirm the intended layout.
            batch_global_state = np.concatenate([batch_states[:, i] for i in range(num_agents)], axis=-1)

            # Train step with QPLEX
            loss = train_step(batch_states, batch_actions, batch_rewards, batch_next_states, batch_global_state)

        # Move to next state
        states = next_states

    # Decay epsilon and track rewards
    for i in range(num_agents):
        epsilon_values[i] = max(min_epsilon, epsilon_values[i] * epsilon_decay)