<a href="https://colab.research.google.com/github/mgo-city/INM707/blob/main/notebooks/INM707_Lab_06.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium

In [None]:
pip install -U "ray[rllib]"

In [None]:
!pip install git+https://github.com/mgo-city/INM707-lab06.git

In [None]:
from dungeon.dungeon import Dungeon, index_to_actions

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

In [None]:
class DungeonEnv(gym.Env):
    """Gymnasium wrapper around the ``Dungeon`` game from the lab package.

    Exposes the dungeon through the standard Gymnasium API (``reset`` /
    ``step`` / ``render``) so that RL libraries such as RLlib can train on it.

    Parameters
    ----------
    render_mode : str, dict or None
        Only text (``"ansi"``) rendering is implemented, so the environment
        always renders in ``"ansi"`` mode regardless of this value.
        RLlib constructs environment classes as ``DungeonEnv(env_config)``,
        i.e. it passes its ``env_config`` dict as the first positional
        argument. When a dict is received here it is treated as that config
        and its ``"size"`` entry (if present) overrides ``size``.
    size : int
        Side length of the square dungeon grid (default 5).
    """

    metadata = {"render_modes": ["ansi"]}

    def __init__(self, render_mode=None, size=5):
        # RLlib passes its env_config (a dict subclass) positionally, so it lands
        # in `render_mode`. Without this, {"env_config": {"size": 40}} would be
        # silently ignored and every RLlib worker would use the default size.
        if isinstance(render_mode, dict):
            size = render_mode.get("size", size)

        self.size = size  # Side length of the dungeon grid.

        # Observations are dictionaries with two integer arrays:
        #  - "relative_coordinates": shape (2,), values in [-size, size]
        #  - "surroundings":         shape (5, 5), values in [0, 4]
        # NOTE(review): the meaning of these fields is defined by the external
        # `Dungeon` class — presumably the agent's offset to its goal and the
        # local 5x5 view around the agent; confirm against the dungeon package.
        self.observation_space = spaces.Dict(
            {
                "relative_coordinates": spaces.Box(-size, size, shape=(2,), dtype=int),
                "surroundings": spaces.Box(0, 4, shape=(5, 5), dtype=int),
            }
        )

        # 4 discrete actions, translated to Dungeon moves via `index_to_actions`.
        self.action_space = spaces.Discrete(4)

        # render() only produces a text representation, so the mode is fixed.
        self.render_mode = "ansi"

        self.clock = None  # Not used by this environment; kept for compatibility.

        self._dungeon_env = Dungeon(size)
        self._dungeon_env.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Returns
        -------
        tuple
            ``(observation, info)`` as required by the Gymnasium API; ``info``
            is always an empty dict and ``options`` is ignored.
        """
        # Seeds self.np_random, as Gymnasium requires.
        # NOTE(review): the seed is not forwarded to the wrapped Dungeon, so the
        # episode layout is presumably not reproducible from `seed` — confirm.
        super().reset(seed=seed)

        obs = self._dungeon_env.reset()

        return obs, {}

    def step(self, action):
        """Apply ``action`` (an index into ``index_to_actions``) to the dungeon.

        Returns
        -------
        tuple
            ``(observation, reward, terminated, truncated, info)``. The
            dungeon's ``done`` flag is reported as ``terminated``;
            ``truncated`` is always False and ``info`` is always empty.
        """
        # int() accepts plain ints as well as numpy scalars / 0-d arrays that
        # RL libraries commonly pass as actions.
        act = index_to_actions[int(action)]

        observations, reward, done = self._dungeon_env.step(act.name)

        return observations, reward, done, False, {}

    def render(self):
        """Return the dungeon as a multi-line string with the agent drawn in.

        The agent's cell is overwritten with the value 4 on a copy of the grid;
        every cell is mapped through ``Dungeon.dict_map_display``, padded to
        width 2, and each row is terminated by a newline.
        """
        grid = self._dungeon_env.dungeon.copy()
        grid[self._dungeon_env.position_agent[0], self._dungeon_env.position_agent[1]] = 4

        display = self._dungeon_env.dict_map_display
        rows = (
            "".join("{0:2}".format(display[grid[r, c]]) for c in range(self.size))
            for r in range(self.size)
        )
        return "".join(row + "\n" for row in rows)


In [None]:
# A large standalone environment for interactively inspecting the dungeon.
dungeon_env = DungeonEnv(render_mode=None, size=100)


In [None]:
# Gymnasium's API requires reset() before the first step().
observation, info = dungeon_env.reset()

# render() returns the text grid instead of printing it; as a non-final
# expression its value was silently discarded, so print it explicitly.
print(dungeon_env.render())

# Take one action; the (obs, reward, terminated, truncated, info) tuple is displayed.
dungeon_env.step(1)

({'relative_coordinates': array([80,  8]),
  'surroundings': array([[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]], dtype=int8)},
 -1,
 False,
 False,
 {})

In [None]:
import ray
from ray.rllib.algorithms import ppo, dqn
from ray.tune.logger import pretty_print

In [None]:
# Restart Ray so re-running this cell does not clash with an existing instance.
ray.shutdown()
ray.init()

# Bounded training: the original `while True` loop never terminated, so
# "Restart & Run All" hung and the kernel had to be interrupted by hand.
N_TRAIN_ITERATIONS = 50

# env_config is forwarded by RLlib to the environment's constructor.
dqn_config = dqn.DQNConfig().environment(env=DungeonEnv, env_config={"size": 40})
algo = dqn_config.build()

mean_rewards = []
for iteration in range(N_TRAIN_ITERATIONS):
    metrics = algo.train()
    mean_reward = metrics['episode_reward_mean']
    mean_rewards.append(mean_reward)
    print(f"iteration {iteration:3d}: mean episode reward = {mean_reward}")


2023-03-29 09:49:54,517	INFO worker.py:1553 -- Started a local Ray instance.
2023-03-29 09:49:55,882	INFO algorithm_config.py:2899 -- Your framework setting is 'tf', meaning you are using static-graph mode. Set framework='tf2' to enable eager execution with tf2.x. You may also then want to set eager_tracing=True in order to reach similar execution speed as with static-graph mode.
2023-03-29 09:49:55,916	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


-129.22727272727272
-90.43
-53.1588785046729
-16.818791946308725
-8.548022598870057
3.975103734439834
7.466666666666667
12.91304347826087
14.160112359550562
16.214833759590793
17.646924829157175
16.68421052631579
17.707207207207208
16.004878048780487
17.611494252873563
17.71461187214612
17.439080459770114
18.45274725274725
17.53211009174312
16.953161592505854
17.72645739910314
16.656324582338904
17.301149425287356
18.334792122538293
16.546762589928058
16.546762589928058
17.762013729977117
17.35198135198135
16.785714285714285
18.59090909090909
17.931662870159453
16.45145631067961
17.550925925925927
17.62700228832952
17.54691075514874
17.143187066974598
17.904017857142858
17.208333333333332
16.89125295508274
17.210772833723652
17.172897196261683
18.250554323725055
17.97566371681416
17.041570438799077
17.494226327944574
17.807256235827666
16.962962962962962
16.95294117647059
17.99772727272727
17.76233183856502
17.483069977426638
17.185446009389672
17.348729792147807
16.406698564593302
18.