### Installation

In [None]:
import os
!git clone --branch=main https://github.com/muhd-umer/rl-wireless.git rl-wireless
assert os.path.exists('./rl-wireless'), "No rl-wireless folder found."
%cd ./rl-wireless

!pip install -r requirements.txt

### Necessary Imports

In [None]:
import warnings
import numpy as np
import gymnasium as gym
from network import MassiveMIMOEnv
import ray
from ray import air, tune
from ray.rllib.utils.framework import try_import_torch
from ray.tune.registry import get_trainable_cls

# disable warnings
warnings.filterwarnings("ignore")

### Registering the Environment

In [None]:
# Environment parameters.
# These are ordinary notebook-level names, so later cells can read them
# directly; a `global` declaration at the top level has no effect and is
# therefore omitted.
# NOTE(review): the meaning of each parameter is defined by MassiveMIMOEnv in
# the `network` module — confirm there before changing values.
N = 7
M = 32
K = 10
Ns = 10
asd_degs = [
    30,
]
min_P = -20
max_P = 23
num_P = 10
dtype = np.float32
seed = 0
# NOTE(review): `asd_degs` and `seed` are defined here but are not forwarded
# to gym.make below — confirm whether the environment should receive them.

# Register the environment class under a gymnasium id, then build one instance
# with the parameters above (gym.make forwards the keyword arguments to the
# registered entry point).
gym.register(id="MassiveMIMO-v0", entry_point=MassiveMIMOEnv)

env = gym.make(
    "MassiveMIMO-v0",
    N=N,
    M=M,
    K=K,
    Ns=Ns,
    min_P=min_P,
    max_P=max_P,
    num_P=num_P,
    dtype=dtype,
)

In [None]:
from ray.tune.registry import register_env


def _make_massive_mimo_env(env_config):
    """Create a new MassiveMIMOEnv from the notebook-level parameters.

    RLlib invokes the registered creator whenever it needs an environment, so
    the creator must build a fresh instance on every call instead of handing
    out one shared, already-constructed object.

    Args:
        env_config: Environment config supplied by RLlib. It is not used here;
            all parameters come from the notebook-level settings.

    Returns:
        A newly constructed MassiveMIMOEnv.
    """
    # The class is constructed directly rather than via gym.make because the
    # gymnasium id is only registered in this (driver) process.
    return MassiveMIMOEnv(
        N=N,
        M=M,
        K=K,
        Ns=Ns,
        min_P=min_P,
        max_P=max_P,
        num_P=num_P,
        dtype=dtype,
    )


# register the predefined scenario with RLlib
register_env("MassiveMIMO-v0", _make_massive_mimo_env)

### Training with PPO Agent

In [None]:
# Start the local Ray runtime with a fixed CPU/GPU budget. Re-running the cell
# is tolerated (ignore_reinit_error) and worker logs are not forwarded to the
# notebook output (log_to_driver=False).
ray_init_kwargs = dict(
    num_cpus=4,
    num_gpus=1,
    include_dashboard=False,
    ignore_reinit_error=True,
    log_to_driver=False,
)
ray.init(**ray_init_kwargs)

In [None]:
# Start from PPO's default configuration and layer the notebook's settings on
# top, one builder call at a time.
config = get_trainable_cls("PPO").get_default_config()  # RLlib algorithm to use
config = config.environment("MassiveMIMO-v0")
config = config.framework("torch")
config = config.resources(num_gpus=0.5, num_gpus_per_worker=0.0)
config = config.rollouts(num_rollout_workers=1, num_envs_per_worker=1)
# The learning rate is a Tune grid-search spec, so each listed value is tried.
config = config.training(lr=tune.grid_search([0.005, 0.003, 0.001, 0.0001]))

# Stopping criteria handed to Tune's RunConfig.
stop = dict(timesteps_total=100000)

In [None]:
# Assemble the Tune experiment piece by piece, then launch it. Output is
# written under ./results.
param_space = config.to_dict()
run_config = air.RunConfig(stop=stop, local_dir="./results")
tuner = tune.Tuner("PPO", param_space=param_space, run_config=run_config)
results = tuner.fit()

In [None]:
%tensorboard --logdir results --port 6006