In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


In [2]:
import numpy as np
import torch
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import DummyVecEnv

In [3]:
import numpy as np
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete, Dict, Box

class CogSatDSAEnv(gym.Env):
    def __init__(self, env_config=None, render_mode=None):
        super(CogSatDSAEnv, self).__init__()
        
        # Set static LEO/GEO counts, or pull from env_config if using Ray
        self.n_leo = env_config.get("n_leo", 3) if env_config else 3
        self.n_leo_users = env_config.get("n_leo_users", 2) if env_config else 2
        self.n_geo = env_config.get("n_geo", 2) if env_config else 2
        self.n_geo_users = env_config.get("n_geo_users", 1) if env_config else 1

        # Register env spec if missing (useful for SB3 compatibility)
        if not hasattr(self, 'spec') or self.spec is None:
            self.spec = gym.envs.registration.EnvSpec("CogSatDSAEnv-v1")

        # Action space: 4 devices picking from 11 channels (0 = no transmission)
        self.action_space = MultiDiscrete([11] * 4)

        # Observation space structure
        self.observation_space = Dict({
            "utc_time": Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.int64),
            "leo_pos": Box(low=-np.inf, high=np.inf, shape=(self.n_leo,), dtype=np.float64),
            "leo_rssi": Box(low=-np.inf, high=np.inf, shape=(self.n_leo * self.n_leo_users,), dtype=np.float64),
            "leo_sinr": Box(low=-np.inf, high=np.inf, shape=(self.n_leo * self.n_leo_users,), dtype=np.float64),
            "geo_rssi": Box(low=-np.inf, high=np.inf, shape=(self.n_geo * self.n_geo_users,), dtype=np.float64),
            "geo_sinr": Box(low=-np.inf, high=np.inf, shape=(self.n_geo * self.n_geo_users,), dtype=np.float64),
        })

        self.terminated = False
        self.current_step = 0

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.current_step = 0
        self.terminated = False

        obs = {
            "utc_time": np.array([0], dtype=np.int64),
            "leo_pos": np.random.randn(self.n_leo).astype(np.float32),
            "leo_rssi": np.random.randn(self.n_leo * self.n_leo_users).astype(np.float32),
            "leo_sinr": np.random.randn(self.n_leo * self.n_leo_users).astype(np.float32),
            "geo_rssi": np.random.randn(self.n_geo * self.n_geo_users).astype(np.float32),
            "geo_sinr": np.random.randn(self.n_geo * self.n_geo_users).astype(np.float32),
        }
        return obs, {}

    def step(self, action):
        self.current_step += 1
        reward = 0.0  # Placeholder
        self.terminated = self.current_step >= 100

        obs = {
            "utc_time": np.array([self.current_step], dtype=np.int64),
            "leo_pos": np.random.randn(self.n_leo).astype(np.float32),
            "leo_rssi": np.random.randn(self.n_leo * self.n_leo_users).astype(np.float32),
            "leo_sinr": np.random.randn(self.n_leo * self.n_leo_users).astype(np.float32),
            "geo_rssi": np.random.randn(self.n_geo * self.n_geo_users).astype(np.float32),
            "geo_sinr": np.random.randn(self.n_geo * self.n_geo_users).astype(np.float32),
        }

        return obs, reward, self.terminated, False, {}

    def render(self):
        pass  # Optional: plot satellite movement, SINR, channel use, etc.

    def close(self):
        pass


In [4]:

env = CogSatDSAEnv()
obs, _ = env.reset()
print("Initial observation:", obs)

action = env.action_space.sample()
obs, reward, done, truncated, _ = env.step(action)
print("Step observation:", obs)


Initial observation: {'utc_time': array([0], dtype=int64), 'leo_pos': array([-1.8378863, -0.3772603, -1.2108759], dtype=float32), 'leo_rssi': array([ 1.0761946 , -0.13622265, -2.2043018 , -1.771516  , -1.3955628 ,
       -1.1850308 ], dtype=float32), 'leo_sinr': array([-0.87400585, -1.1427749 ,  0.78125614,  0.33952934,  0.6669274 ,
        0.9472205 ], dtype=float32), 'geo_rssi': array([-0.32214814,  0.7790566 ], dtype=float32), 'geo_sinr': array([-1.4544635 ,  0.55104905], dtype=float32)}
Step observation: {'utc_time': array([1], dtype=int64), 'leo_pos': array([0.81929195, 0.6205581 , 0.8094295 ], dtype=float32), 'leo_rssi': array([-0.40166116,  1.5716503 , -1.2595519 ,  0.6821898 , -1.1111343 ,
        1.5264995 ], dtype=float32), 'leo_sinr': array([-0.84844214, -0.77274036,  0.49185777, -1.1529508 ,  0.02108317,
       -0.45947865], dtype=float32), 'geo_rssi': array([1.0121659 , 0.25623253], dtype=float32), 'geo_sinr': array([ 0.42437437, -1.2079464 ], dtype=float32)}


In [5]:


dummy_env = DummyVecEnv([lambda: env])
# Extract the shape of the observation space, which is a Dict
obs_shape = {key: space.shape for key, space in env.observation_space.spaces.items()}
model = A2C("MultiInputPolicy", dummy_env, verbose=0, device="cpu")

n_steps = 5


In [6]:
import torch
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.buffers import DictRolloutBuffer

# === Globals ===
step_count = 0
obs_last = None
action_last = None
value_last = None
log_prob_last = None

# You need to define these before use
# env = YourCustomEnv()
# dummy_env = YourCustomEnv()
# model = A2C("MultiInputPolicy", dummy_env, ...)
# n_steps = 5  # or any number you choose

buffer = DictRolloutBuffer(
    buffer_size=n_steps,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=model.device,
    gamma=model.gamma,
    gae_lambda=model.gae_lambda
)


def reset_env():
    global obs_last
    obs_last = dummy_env.reset()
    return obs_last


def preprocess_obs(obs):
    """Convert dict observation to torch tensor dict with proper dimensions."""
    tensor_dict = {}
    for key, value in obs.items():
        # Convert numpy array to tensor and ensure it's float32
        tensor_value = torch.as_tensor(value).float().to(model.device)
        # Add batch dimension if needed (SB3 expects batch dimension)
        if len(tensor_value.shape) == 1:
            tensor_value = tensor_value.unsqueeze(0)
        tensor_dict[key] = tensor_value
    return tensor_dict

def get_action(obs):
    global obs_last, action_last, value_last, log_prob_last
    obs_last = obs

    with torch.no_grad():
        obs_tensor = preprocess_obs(obs)
        # SB3 expects flattened observations for MultiInputPolicy
        flattened_obs = model.policy.obs_to_tensor(obs_tensor)[0]
        action_tensor, value_tensor, log_prob_tensor = model.policy.forward(flattened_obs)

    action_last = action_tensor
    value_last = value_tensor
    log_prob_last = log_prob_tensor

    return action_tensor.cpu().numpy().squeeze()  # Remove batch dimension for env


def my_step(action):
    next_obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    return next_obs, reward, done


def compute_a2c_loss(policy, rollout_data, value_coef=0.5, entropy_coef=0.01):
    observations = rollout_data.observations
    actions = rollout_data.actions
    returns = rollout_data.returns
    advantages = rollout_data.advantages
    old_log_probs = rollout_data.old_log_prob

    # Get action distribution and value predictions
    dist = policy.get_distribution(observations)
    value_preds = policy.predict_values(observations)

    # Log probs and entropy from the current policy
    new_log_probs = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    # Actor loss
    policy_loss = -(advantages * new_log_probs).mean()

    # Critic loss
    value_loss = torch.nn.functional.mse_loss(returns, value_preds)

    # Total loss
    total_loss = policy_loss + value_coef * value_loss - entropy_coef * entropy
    return total_loss


def store_transition(reward, done, next_obs):
    global step_count, obs_last, action_last, value_last, log_prob_last

    reward = np.array([reward], dtype=np.float32)
    done = np.array([done], dtype=bool)

    buffer.add(
        obs_last,
        action_last,
        reward,
        done,
        value_last,
        log_prob_last
    )

    step_count += 1
    obs_last = next_obs

    if step_count % n_steps == 0:
        with torch.no_grad():
            next_obs_tensor = preprocess_obs(next_obs)
            last_val = model.policy.predict_values(next_obs_tensor)

        buffer.compute_returns_and_advantage(last_val, dones=done)

        model.policy.train()
        model.policy.optimizer.zero_grad()

        for rollout_data in buffer.get(batch_size=None):
            loss = compute_a2c_loss(model.policy, rollout_data)
            loss.backward()

        model.policy.optimizer.step()
        buffer.reset()


def save_model(path="a2c_satellite"):
    model.save(path)


def load_model(path="a2c_satellite"):
    global model
    model = A2C.load(path)
    model.set_env(dummy_env)
    return True


In [7]:
def train_multiple_episodes(n_episodes=100):

    for episode in range(n_episodes):
        global step_count
        obs = reset_env()
        done = False
        episode_reward = 0
        while not done:
            action = get_action(obs)
            next_obs, reward, done = my_step(action)
            store_transition(reward, done, next_obs)
            obs = next_obs
            print(f"Episode {episode + 1}: Reward = {episode_reward:.2f}, Total Steps = {step_count}")
            episode_reward += reward
            if done:
                break  # Optional, since the loop exits on `done` anyway
        step_count = 0  # Reset step count for the next episode
        print(f"Episode {episode + 1} finished with total reward: {episode_reward:.2f}")


In [8]:
train_multiple_episodes(10)

Episode 1: Reward = 0.00, Total Steps = 1
Episode 1: Reward = 0.00, Total Steps = 2
Episode 1: Reward = 0.00, Total Steps = 3
Episode 1: Reward = 0.00, Total Steps = 4
Episode 1: Reward = 0.00, Total Steps = 5
Episode 1: Reward = 0.00, Total Steps = 6
Episode 1: Reward = 0.00, Total Steps = 7
Episode 1: Reward = 0.00, Total Steps = 8
Episode 1: Reward = 0.00, Total Steps = 9
Episode 1: Reward = 0.00, Total Steps = 10
Episode 1: Reward = 0.00, Total Steps = 11
Episode 1: Reward = 0.00, Total Steps = 12
Episode 1: Reward = 0.00, Total Steps = 13
Episode 1: Reward = 0.00, Total Steps = 14
Episode 1: Reward = 0.00, Total Steps = 15
Episode 1: Reward = 0.00, Total Steps = 16
Episode 1: Reward = 0.00, Total Steps = 17
Episode 1: Reward = 0.00, Total Steps = 18
Episode 1: Reward = 0.00, Total Steps = 19
Episode 1: Reward = 0.00, Total Steps = 20
Episode 1: Reward = 0.00, Total Steps = 21
Episode 1: Reward = 0.00, Total Steps = 22
Episode 1: Reward = 0.00, Total Steps = 23
Episode 1: Reward = 

  value_loss = torch.nn.functional.mse_loss(returns, value_preds)


Episode 1: Reward = 0.00, Total Steps = 60
Episode 1: Reward = 0.00, Total Steps = 61
Episode 1: Reward = 0.00, Total Steps = 62
Episode 1: Reward = 0.00, Total Steps = 63
Episode 1: Reward = 0.00, Total Steps = 64
Episode 1: Reward = 0.00, Total Steps = 65
Episode 1: Reward = 0.00, Total Steps = 66
Episode 1: Reward = 0.00, Total Steps = 67
Episode 1: Reward = 0.00, Total Steps = 68
Episode 1: Reward = 0.00, Total Steps = 69
Episode 1: Reward = 0.00, Total Steps = 70
Episode 1: Reward = 0.00, Total Steps = 71
Episode 1: Reward = 0.00, Total Steps = 72
Episode 1: Reward = 0.00, Total Steps = 73
Episode 1: Reward = 0.00, Total Steps = 74
Episode 1: Reward = 0.00, Total Steps = 75
Episode 1: Reward = 0.00, Total Steps = 76
Episode 1: Reward = 0.00, Total Steps = 77
Episode 1: Reward = 0.00, Total Steps = 78
Episode 1: Reward = 0.00, Total Steps = 79
Episode 1: Reward = 0.00, Total Steps = 80
Episode 1: Reward = 0.00, Total Steps = 81
Episode 1: Reward = 0.00, Total Steps = 82
Episode 1: 