### Installation

In [None]:
!pip install git+https://github.com/DLR-RM/stable-baselines3.git@40e0b9d
!pip install gymnasium
!git clone https://github.com/muhd-umer/rl-wireless.git

# Colab path
%cd /content/rl-wireless

### Necessary Imports

In [None]:
import numpy as np
import gymnasium as gym
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
from agents import DQNAgent
import time

### Registering and Testing the Environment

In [None]:
# Set the parameters.
# Top-level assignments in a notebook are already module globals, so `make_env`
# below can read them without any `global` declaration.
N = 7
M = 32
K = 10
Ns = 500
# NOTE(review): asd_degs is not passed to the environment in this notebook.
asd_degs = [
    30,
]
# Power grid passed to the env as min_P / max_P / num_P.
# Units presumably dB(m) - TODO confirm against network.MassiveMIMOEnv.
min_P = -20
max_P = 23
num_P = 10
dtype = np.float32
seed = 0

# Register and create the environment.
# Registration is guarded so re-running this cell doesn't re-register the id
# (gymnasium warns when an existing registry entry is overridden).
if "MassiveMIMO-v0" not in gym.registry:
    gym.register(id="MassiveMIMO-v0", entry_point="network:MassiveMIMOEnv")

# Standalone instance for quick inspection; the training below builds its own
# environments through `make_env`.
env = gym.make(
    "MassiveMIMO-v0",
    N=N,
    M=M,
    K=K,
    Ns=Ns,
    min_P=min_P,
    max_P=max_P,
    num_P=num_P,
    dtype=dtype,
)


In [None]:
def make_env(env_id, seed, **env_kwargs):
    """Return a zero-argument factory that builds one monitored environment.

    The factory form is what ``gym.vector.SyncVectorEnv`` expects.

    Args:
        env_id: Registered Gymnasium environment id, e.g. ``"MassiveMIMO-v0"``.
        seed: Seed applied to the environment's action and observation spaces.
        **env_kwargs: Optional overrides for the ``gym.make`` constructor
            arguments. Any argument not given falls back to the notebook-level
            configuration (``N``, ``M``, ``K``, ``Ns``, ``min_P``, ``max_P``,
            ``num_P``, ``dtype``), read when the factory is called.

    Returns:
        A callable that creates the environment wrapped in
        ``RecordEpisodeStatistics``.
    """

    def thunk():
        # Read the notebook config lazily (when the env is built, as before),
        # then apply caller overrides on top.
        make_kwargs = dict(
            N=N,
            M=M,
            K=K,
            Ns=Ns,
            min_P=min_P,
            max_P=max_P,
            num_P=num_P,
            dtype=dtype,
        )
        make_kwargs.update(env_kwargs)

        env = gym.make(env_id, **make_kwargs)
        env = RecordEpisodeStatistics(env)
        # This seeds only the spaces' own samplers (e.g. action_space.sample()).
        # The environment's internal RNG is seeded via reset(seed=...).
        # NOTE(review): confirm the agent passes a seed to envs.reset().
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


In [None]:
# A single environment wrapped in a synchronous vector env (batch of size 1).
envs = gym.vector.SyncVectorEnv(
    env_fns=[make_env("MassiveMIMO-v0", seed)],
)

# The DQN agent picks from a finite action set, so fail fast on anything else.
assert isinstance(envs.single_action_space, gym.spaces.Discrete), (
    "Only discrete action space is supported."
)

### Training DQN Agent

In [None]:
# Number of environment steps driven by the training loop below.
total_timesteps = 500_000
# Passed to the agent as its target update frequency; presumably measured in
# environment steps - TODO confirm against agents.DQNAgent.
target_freq = Ns

agent = DQNAgent(envs, target_freq=target_freq)


In [None]:
start_time = time.time()

for global_step in range(total_timesteps):
    agent.get_actions(global_step)
    next_obs, b_reward, b_terminated, b_truncated, b_info = agent.envs.step(
        agent.actions
    )

    # SyncVectorEnv (gymnasium 0.28/0.29 semantics) already auto-resets any
    # sub-env whose episode ended during this step. For those envs, `next_obs`
    # is the first observation of the NEW episode, and the true last observation
    # is in b_info["final_observation"].
    # Calling reset() again here would put the env in a different state than
    # `next_obs` (which becomes agent.obs below), so it is not done.
    real_next_obs = next_obs.copy()
    for idx, truncated in enumerate(b_truncated):
        if truncated:
            # A time-limit truncation is not terminal; its target should
            # bootstrap from the real final observation, not the reset one.
            real_next_obs[idx] = b_info["final_observation"][idx]

    # Store only true terminations as "done". Storing truncations as done would
    # suppress bootstrapping on them (assuming agent.train masks the target with
    # the stored done flags, as standard DQN does - confirm in agents.DQNAgent).
    agent.replay_buffer.add(
        agent.obs, real_next_obs, agent.actions, b_reward, b_terminated, b_info
    )

    agent.obs = next_obs
    agent.train(global_step, start_time)
