<a href="https://colab.research.google.com/github/darkraithromb/next-platform-starter/blob/main/uav.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install gymnasium[all]
!pip install stable-baselines3[extra]
!pip install --upgrade stable-baselines3
import shimmy
import gym
from gym import spaces
import numpy as np
from stable_baselines3 import PPO

class UAVHandoverEnv(gym.Env):
    def __init__(self):
        super(UAVHandoverEnv, self).__init__()


        self.action_space = spaces.Discrete(2)

        self.observation_space = spaces.Box(low=np.array([-np.inf, -np.inf, -np.inf, 0, 0, -np.inf]),
                                            high=np.array([np.inf, np.inf, np.inf, 10, 45, 0]),
                                            dtype=np.float32)

        self.state = None


        self.base_stations = self.initialize_base_stations()

    def initialize_base_stations(self):

        bs_positions = []
        for i in range(45):
            bs_positions.append([np.random.uniform(0, 6), np.random.uniform(0, 6), 0.3])
        return np.array(bs_positions)

    def reset(self):

        self.state = np.array([0, 0, 0, np.random.uniform(1, 3), 0, -70])
        return self.state

    def compute_rssi(self, position, base_station_id):

        bs_position = self.base_stations[base_station_id]
        distance = np.linalg.norm(position - bs_position)
        rssi = -20 * np.log10(distance + 1) - 70
        return rssi

    def step(self, action):
        x, y, z, velocity, current_bs_id, rssi = self.state


        if action == 1:
            new_bs_id = np.random.randint(0, len(self.base_stations))
            rssi = self.compute_rssi(np.array([x, y, z]), new_bs_id)
            penalty = -1
        else:

            rssi = self.compute_rssi(np.array([x, y, z]), int(current_bs_id))
            new_bs_id = current_bs_id
            penalty = 0


        x += np.random.uniform(-0.1, 0.1) * velocity
        y += np.random.uniform(-0.1, 0.1) * velocity

        x = np.clip(x, 0, 6)
        y = np.clip(y, 0, 6)


        reward = rssi / -100
        reward += penalty


        self.state = np.array([x, y, z, velocity, new_bs_id, rssi])


        done = rssi < -90

        return self.state, reward, done, {}

    def render(self, mode='human'):

        pass


env = UAVHandoverEnv()


model = PPO("MlpPolicy", env, verbose=1)

model.learn(total_timesteps=10000)

model.save("ppo_uav_handover")


obs = env.reset()
for _ in range(1000):
    action, _ = model.predict(obs)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
    env.render()


Collecting gymnasium[all]
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium[all])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Collecting shimmy<1.0,>=0.1.0 (from shimmy[atari]<1.0,>=0.1.0; extra == "all"->gymnasium[all])
  Downloading Shimmy-0.2.1-py3-none-any.whl.metadata (2.3 kB)
Collecting box2d-py==2.3.5 (from gymnasium[all])
  Downloading box2d-py-2.3.5.tar.gz (374 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.4/374.4 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting swig==4.* (from gymnasium[all])
  Downloading swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.6 kB)
Collecting mujoco-py<2.2,>=2.1 (from gymnasium[all])
  Downloading mujoco_py-2.1.2.14-py3-none-any.whl.metadata (669 bytes)
Collecting cython<3 (from gymnasium[all])
  Downloading Cython-0.29.37-

  from jax import xla_computation as _xla_computation


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




-----------------------------
| time/              |      |
|    fps             | 1039 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 716         |
|    iterations           | 2           |
|    time_elapsed         | 5           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.020455036 |
|    clip_fraction        | 0.53        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.674      |
|    explained_variance   | -0.0239     |
|    learning_rate        | 0.0003      |
|    loss                 | 1.03        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0604     |
|    value_loss           | 5.09        |
-----------------------------------------
----------------------------------