In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces
import networkx as nx
import matplotlib.pyplot as plt
import os
from datetime import datetime
import random

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Episode sizing. MFRLFullEnv reads MAX_STEPS (and ENERGY_COEFF) as module
# globals at step time, so rebinding them later changes environment behaviour.
MAX_STEPS = 300
MAX_EPISODES = 200

# Optimisation / loss weights (consumed by training code outside this view).
GAMMA = 0.98
LEARNING_RATE = 1e-4
ENERGY_COEFF = 1
ENTROPY_COEFF = 0.01
CRITIC_COEFF = 0.5
MAX_GRAD_NORM = 0.5

# Network I/O sizes.
N_OBSERVATIONS = 4
N_ACTIONS = 2

print_interval = 10

# Number of devices and their arrival probabilities: NODE_N evenly spaced
# values strictly inside (0, 1) -- the endpoints 0 and 1 are sliced off.
NODE_N = 5
arrival_rate = np.linspace(0, 1, NODE_N + 2)[1:-1].tolist()


class MFRLFullEnv(gym.Env):
    """Multi-device random-access environment that tracks Age of Information (AoI).

    Each step, every device whose action is 1 attempts a transmission with
    probability ``arrival_rate[i]``. If exactly one device attempts, it succeeds
    and its age resets to 0; if several attempt, every attempting device is in
    collision. All ages grow by ``1 / MAX_STEPS`` per step, so over one episode
    they reach at most 1 (up to float rounding).

    Reward: every step returns ``-ENERGY_COEFF * (#attempts) / n``. On the
    terminal step (``counter == MAX_STEPS``) the bonus
    ``mean(1 - max_aoi) * MAX_STEPS`` is added.

    NOTE(review): collisions are decided over ALL devices (one shared channel);
    ``topology.adjacency_matrix`` is only consulted by ``get_adjacent_nodes``.
    Confirm this is the intended interference model.

    Args:
        agent: object exposing ``topology`` (with ``n`` and ``adjacency_matrix``)
            and ``arrival_rate`` (one probability per device).
    """

    def __init__(self, agent):
        super().__init__()
        self.n = agent.topology.n
        self.topology = agent.topology
        self.arrival_rate = agent.arrival_rate
        self.counter = 0
        self.age = np.zeros(self.n)
        self.max_aoi = np.zeros(self.n)
        # Per-device one-hot state: column 0 = idle, 1 = success, 2 = collision.
        self.states = np.zeros((self.n, 3))
        self.states[:, 0] = 1  # every device starts idle

        self.observation_space = spaces.Dict({
            "devices": spaces.Tuple([
                spaces.Dict({
                    "state": spaces.MultiBinary(3),  # one-hot {idle:0, success:1, collision:2}
                    "age": spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
                }) for _ in range(self.n)
            ]),
            # The counter runs 0..MAX_STEPS inclusive (it equals MAX_STEPS in the
            # terminal observation), so the space needs MAX_STEPS + 1 values.
            "counter": spaces.Discrete(MAX_STEPS + 1),
        })

        # One binary transmit / stay-silent decision per device.
        self.action_space = spaces.MultiBinary(self.n)

    def _get_obs(self):
        """Build the observation dict from the current state.

        State rows are copied so callers cannot mutate the env's internal arrays.
        """
        return {
            "devices": tuple(
                {
                    "state": self.states[i].copy(),
                    "age": np.array([self.age[i]], dtype=np.float32),
                }
                for i in range(self.n)
            ),
            "counter": self.counter,
        }

    def reset(self, seed=None, options=None):
        """Reset counters, ages and device states.

        Args:
            seed: forwarded to ``gym.Env.reset`` to (re)seed ``self.np_random``,
                which drives the packet-arrival draws in ``step``.
            options: accepted for compatibility with the Gymnasium ``reset``
                signature (wrappers pass it); currently unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        self.counter = 0
        self.age = np.zeros(self.n)
        self.max_aoi = np.zeros(self.n)
        self.states = np.zeros((self.n, 3))
        self.states[:, 0] = 1  # all devices idle
        return self._get_obs(), {}

    def set_all_actions(self, actions):
        """Store a joint action vector (read by ``idle_check``)."""
        self.all_actions = np.array(actions)

    def get_maxaoi(self):
        """Return the per-device maximum age observed so far this episode."""
        return self.max_aoi

    def set_max_aoi(self, max_aoi):
        """Store ``max_aoi`` in ``self.max_aoi_set``.

        NOTE(review): this writes a *different* attribute from ``self.max_aoi``
        (the one ``get_maxaoi`` returns and ``step`` updates). Confirm intent.
        """
        self.max_aoi_set = max_aoi

    def idle_check(self):
        """Return True if none of the devices in ``self.adj_ids`` chose to transmit.

        NOTE(review): ``self.adj_ids`` is never assigned in this class, and
        ``self.all_actions`` only exists after ``set_all_actions``; calling this
        otherwise raises AttributeError.
        """
        return all(self.all_actions[self.adj_ids] == 0)

    def get_adjacent_nodes(self, *args):
        """Return indices of nodes adjacent to ``args[0]`` (or to ``self.id``).

        NOTE(review): ``self.id`` is never assigned in this class, so calling
        this without an argument raises AttributeError unless it is set externally.
        """
        if len(args) > 0:
            return np.where(self.topology.adjacency_matrix[args[0]] == 1)[0]
        else:
            return np.where(self.topology.adjacency_matrix[self.id] == 1)[0]

    def step(self, action):
        """Advance the environment by one slot.

        Args:
            action: length-``n`` sequence of 0/1 transmit decisions.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action`` does not contain exactly ``n`` entries.
        """
        chose_tx = np.asarray(action).reshape(-1) == 1
        if chose_tx.shape != (self.n,):
            raise ValueError(f"expected {self.n} actions, got {chose_tx.size}")

        self.counter += 1
        self.age += 1 / MAX_STEPS

        # A device attempts only if it chose to transmit AND a packet arrived
        # this slot (Bernoulli(arrival_rate[i])). Drawing from self.np_random
        # makes episodes reproducible via reset(seed=...).
        has_packet = np.asarray(self.arrival_rate, dtype=float) > self.np_random.random(self.n)
        transmitting = chose_tx & has_packet
        n_tx = int(transmitting.sum())

        # Energy cost per attempt; written as 0.0 - x so zero attempts yield 0.0, not -0.0.
        energy_reward = 0.0 - ENERGY_COEFF * n_tx

        new_states = np.zeros((self.n, 3))
        new_states[~transmitting, 0] = 1  # silent devices are idle
        if n_tx == 1:
            new_states[transmitting, 1] = 1  # lone transmitter succeeds
            self.age[transmitting] = 0
        elif n_tx > 1:
            new_states[transmitting, 2] = 1  # every transmitter collides
        self.states = new_states
        self.max_aoi = np.maximum(self.age, self.max_aoi)

        terminated = self.counter >= MAX_STEPS

        total_reward = energy_reward / self.n
        if terminated:
            # Episode-end bonus: lower peak AoI across devices => larger reward.
            total_reward += np.mean(1 - self.max_aoi) * MAX_STEPS

        truncated = False
        return self._get_obs(), total_reward, terminated, truncated, {}

In [12]:
import numpy as np

# Minimal stand-in agent for exercising MFRLFullEnv without a real policy.
class DummyAgent:
    """Test double exposing only the attributes MFRLFullEnv reads.

    Attributes:
        topology: ``DummyAgent.Topology`` instance with ``n`` devices.
        arrival_rate: ndarray of ``n`` probabilities drawn from the global
            NumPy RNG, uniform on [0.2, 0.8).
    """

    class Topology:
        """``n``-node topology whose adjacency matrix is the identity.

        Each node is adjacent only to itself (self-loops, no links between
        distinct nodes).
        """

        def __init__(self, n):
            self.n = n
            self.adjacency_matrix = np.identity(n)

    def __init__(self, n):
        topology_cls = type(self).Topology
        self.topology = topology_cls(n)
        self.arrival_rate = np.random.uniform(low=0.2, high=0.8, size=n)

# Smoke-test MFRLFullEnv with a random policy on a short episode.
# MFRLFullEnv reads the module-level MAX_STEPS / ENERGY_COEFF at step time, so
# they are overridden here only for this demo and restored afterwards;
# otherwise every later cell would silently inherit the demo values.
_saved_hparams = (MAX_STEPS, ENERGY_COEFF)
MAX_STEPS = 10       # short episode for the demo
ENERGY_COEFF = 0.1   # demo energy penalty per transmission attempt

try:
    # Create environment
    n_devices = 3  # Example: 3 devices
    dummy_agent = DummyAgent(n_devices)
    env = MFRLFullEnv(dummy_agent)

    # Reset environment
    obs, info = env.reset()
    print("\n[Initial Observation]")
    print(obs)

    # Step through the environment until the episode ends
    done = False
    step_count = 0

    while not done:
        action = np.random.randint(2, size=n_devices)  # uniform random 0/1 per device
        obs, reward, terminated, truncated, info = env.step(action)
        # A Gymnasium episode ends on either flag.
        done = terminated or truncated

        print(f"\n[Step {step_count}]")
        print(f"Action Taken: {action}, Arrival Rates: {dummy_agent.arrival_rate}")
        print(f"Observation: {obs}")
        print(f"Reward: {reward}")
        print(f"Terminated: {terminated}")

        step_count += 1

    print("\n[Episode Finished]")
finally:
    # Restore the notebook-wide hyperparameters. Note: `env` still reads the
    # globals, so reusing it after this cell runs full-length episodes.
    MAX_STEPS, ENERGY_COEFF = _saved_hparams



[Initial Observation]
{'devices': ({'state': array([1., 0., 0.]), 'age': array([0.], dtype=float32)}, {'state': array([1., 0., 0.]), 'age': array([0.], dtype=float32)}, {'state': array([1., 0., 0.]), 'age': array([0.], dtype=float32)}), 'counter': 0}

[Step 0]
Action Taken: [0 1 0], Arrival Rates: [0.41302092 0.71561128 0.77463762]
Observation: {'devices': ({'state': array([1., 0., 0.]), 'age': array([0.1], dtype=float32)}, {'state': array([0., 1., 0.]), 'age': array([0.], dtype=float32)}, {'state': array([1., 0., 0.]), 'age': array([0.1], dtype=float32)}), 'counter': 1}
Reward: -0.03333333333333333
Terminated: False

[Step 1]
Action Taken: [0 1 0], Arrival Rates: [0.41302092 0.71561128 0.77463762]
Observation: {'devices': ({'state': array([1., 0., 0.]), 'age': array([0.2], dtype=float32)}, {'state': array([1., 0., 0.]), 'age': array([0.1], dtype=float32)}, {'state': array([1., 0., 0.]), 'age': array([0.2], dtype=float32)}), 'counter': 2}
Reward: 0.0
Terminated: False

[Step 2]
Action