<a href="https://colab.research.google.com/github/darkraithromb/next-platform-starter/blob/main/uav.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install gymnasium[all]
!pip install stable-baselines3[extra]
import shimmy
import gym
from gym import spaces
import numpy as np
from stable_baselines3 import PPO

# Custom Environment for UAV Handover
class UAVHandoverEnv(gym.Env):
    def __init__(self):
        super(UAVHandoverEnv, self).__init__()

        # Action space: 0 -> Stay, 1 -> Handover
        self.action_space = spaces.Discrete(2)

        # Observation space: [x, y, z, velocity, current_base_station_id, RSSI]
        self.observation_space = spaces.Box(low=np.array([-np.inf, -np.inf, -np.inf, 0, 0, -np.inf]),
                                            high=np.array([np.inf, np.inf, np.inf, 10, 45, 0]),
                                            dtype=np.float32)

        # Initialize UAV state (position, velocity, base station, RSSI)
        self.state = None

        # Define base station positions (x, y, z)
        self.base_stations = self.initialize_base_stations()

    def initialize_base_stations(self):
        # Assume 45 base stations distributed in a 6x6x0.3 km³ area
        bs_positions = []
        for i in range(45):
            bs_positions.append([np.random.uniform(0, 6), np.random.uniform(0, 6), 0.3])
        return np.array(bs_positions)

    def reset(self):
        # Initialize the UAV state [x, y, z, velocity, base station ID, RSSI]
        self.state = np.array([0, 0, 0, np.random.uniform(1, 3), 0, -70])  # RSSI starts at -70 dBm
        return self.state

    def compute_rssi(self, position, base_station_id):
        # Example RSSI calculation using distance from base station (simple path loss model)
        bs_position = self.base_stations[base_station_id]
        distance = np.linalg.norm(position - bs_position)
        rssi = -20 * np.log10(distance + 1) - 70  # Simplified path loss model
        return rssi

    def step(self, action):
        x, y, z, velocity, current_bs_id, rssi = self.state

        # Apply action: Handover or Stay
        if action == 1:  # Handover to a different base station
            new_bs_id = np.random.randint(0, len(self.base_stations))
            rssi = self.compute_rssi(np.array([x, y, z]), new_bs_id)
            penalty = -1  # Penalty for handover
        else:
            # Stay connected, RSSI remains based on current base station
            rssi = self.compute_rssi(np.array([x, y, z]), int(current_bs_id))
            new_bs_id = current_bs_id
            penalty = 0  # No penalty for staying connected

        # Update UAV's position (moving in random directions for simplicity)
        x += np.random.uniform(-0.1, 0.1) * velocity
        y += np.random.uniform(-0.1, 0.1) * velocity

        # Keep the UAV within bounds
        x = np.clip(x, 0, 6)
        y = np.clip(y, 0, 6)

        # Reward function: balance stable connection and handover penalty
        reward = rssi / -100  # Normalize RSSI for reward
        reward += penalty

        # Update state
        self.state = np.array([x, y, z, velocity, new_bs_id, rssi])

        # Check if done (for example, if RSSI falls below a threshold)
        done = rssi < -90  # Consider connection lost if RSSI is too weak

        return self.state, reward, done, {}

    def render(self, mode='human'):
        # Optional: implement rendering logic for visualizing the UAV and base stations
        pass

# Instantiate environment
env = UAVHandoverEnv()

# Define and train the PPO agent
model = PPO("MlpPolicy", env, verbose=1)

# Train the model for a defined number of timesteps
model.learn(total_timesteps=10000)

# Save the trained model
model.save("ppo_uav_handover")

# Load the model (if needed)
# model = PPO.load("ppo_uav_handover")

# Test the trained model
obs = env.reset()
for _ in range(1000):  # Test for 1000 timesteps
    action, _ = model.predict(obs)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
    env.render()  # Optional: visualize the environment


Collecting shimmy<1.0,>=0.1.0 (from shimmy[atari]<1.0,>=0.1.0; extra == "all"->gymnasium[all])
  Downloading Shimmy-0.2.1-py3-none-any.whl.metadata (2.3 kB)
Collecting box2d-py==2.3.5 (from gymnasium[all])
  Using cached box2d-py-2.3.5.tar.gz (374 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting swig==4.* (from gymnasium[all])
  Using cached swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.6 kB)
Collecting mujoco-py<2.2,>=2.1 (from gymnasium[all])
  Using cached mujoco_py-2.1.2.14-py3-none-any.whl.metadata (669 bytes)
Collecting cython<3 (from gymnasium[all])
  Using cached Cython-0.29.37-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (3.1 kB)
Collecting mujoco>=2.3.3 (from gymnasium[all])
  Using cached mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
Collecting lz4>=3.1.0 (from gymnasium[all])
  Downloading lz4-4.3.3-cp310-cp310-manylinux_2_17_x86_64.man



-----------------------------
| time/              |      |
|    fps             | 1148 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 854         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.020376246 |
|    clip_fraction        | 0.538       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.674      |
|    explained_variance   | -0.0467     |
|    learning_rate        | 0.0003      |
|    loss                 | 1.22        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0618     |
|    value_loss           | 4.44        |
-----------------------------------------
----------------------------------

AttributeError: module 'stable_baselines3' has no attribute '__version__'