Notebook for training on GPUs on Kaggle/Colab. It contains the same code as the repo, just reformatted to handle imports between files.

In [20]:
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from utils.render import render_env
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import numpy as np

def min_gt(seq, val):
    """
    Return the smallest item in seq that is greater than or equal to val.

    Despite the function's name, the comparison against ``val`` is inclusive:
    an item equal to ``val`` qualifies.

    :param seq: sequence of numbers supporting ``len()`` and integer indexing
                (e.g. a list or a 1-D numpy array)
    :param val: inclusive lower bound
    :return: the smallest qualifying item, or ``np.inf`` if seq is empty or
             no item satisfies ``val <= item < np.inf``
    """
    # Named ``smallest`` so the built-in ``min`` is not shadowed.
    smallest = np.inf
    # Scan back-to-front so that, among equal candidates, the same element
    # object is kept as with a descending index loop.
    for item in reversed(seq):
        if item >= val and item < smallest:
            smallest = item
    return smallest

class OrderedAgent:
    """
    Greedy policy: pick the explored branch whose minimum distance to the
    target is smallest.
    """

    def __init__(self):
        # Not read anywhere inside this class.
        self.action_size = 5

    def act(self, state, eps=0):
        """
        Choose an action from one agent's tree observation.

        :param state: input is the observation of the agent; its
                      ``childs[direction].dist_min_to_target`` is read for each
                      entry of ``TreeObsForRailEnv.tree_explored_actions_char``
        :param eps: accepted but not used by this policy
        :return: returns an action, computed as (position of the chosen
                 direction in ``tree_explored_actions_char``) + 1; when several
                 directions tie, the last one wins
        """
        distance = []
        for direction in TreeObsForRailEnv.tree_explored_actions_char:
            try:
                distance.append(state.childs[direction].dist_min_to_target)
            except (AttributeError, KeyError, TypeError):
                # Missing observation, missing branch, or a placeholder child
                # without a distance: treat the branch as unreachable.
                # Anything else (e.g. KeyboardInterrupt) now propagates.
                distance.append(np.inf)
        distance = np.array(distance)
        min_dist = min_gt(distance, 0)
        # If no branch has a finite distance, min_dist is inf and every
        # unreachable branch matches, so the last direction is returned.
        min_direction = np.where(distance == min_dist)
        # The last match equals the only match when there is exactly one, so a
        # single return covers both the tie and the unique case.
        return min_direction[0][-1] + 1
    
def run_episode(env, agent, render=True):
    """
    Run one episode in which only one agent moves at a time.

    At every step the first agent that is not yet done is asked for an action
    via ``agent.act``; every other agent is sent the stop action. Whenever at
    least one entry of ``dones`` is true, the done flags and all agent
    positions are printed.

    :param env: environment; must provide ``height``, ``width``, ``reset``,
                ``dones``, ``step``, ``get_num_agents``, ``get_agent_handles``
                and ``agents``
    :param agent: policy object exposing ``act(observation)``
    :param render: accepted for interface compatibility; currently ignored
                   (no rendering is performed)
    :return: the sum of all agents' rewards over all executed steps
    """
    # Action sent to every agent that is not the acting one.
    stop_action = 4

    max_steps = 100 * (env.height + env.width) - 1
    # Reset environment
    obs, _info = env.reset(regenerate_rail=False, regenerate_schedule=False)
    # Overrides the environment's own episode limit after the reset.
    env._max_episode_steps = 10000
    dones = env.dones
    score = 0
    # Run episode
    for step in range(max_steps):
        agent_ids = range(env.get_num_agents())

        # Only the first unfinished agent acts; finished agents are never
        # passed to agent.act (their observation may no longer be usable).
        acting_agent = next((a for a in agent_ids if not dones[a]), None)
        action_dict = {
            a: agent.act(obs[a]) if a == acting_agent else stop_action
            for a in agent_ids
        }

        # Environment step
        obs, all_rewards, dones, _ = env.step(action_dict)
        for agent_handle in env.get_agent_handles():
            score += all_rewards[agent_handle]

        if True in dones.values():
            print(dones)
            print({a: env.agents[a].position for a in range(env.get_num_agents())})

        if dones['__all__']:
            print('All done')
            break

    return score



# Assemble the environment from its parts, built in the same order as they are
# passed to RailEnv below.
# The rail generator runs with its defaults; options that were tried and left
# disabled:
#   max_num_cities=3            -> number of cities (= train stations)
#   grid_mode=False             -> distribute the cities evenly in a grid
#   max_rails_between_cities=2  -> max number of rails connecting to a city
#   max_rail_pairs_in_city=3    -> number of parallel tracks in cities
rail_generator = sparse_rail_generator()
line_generator = sparse_line_generator()
observation_builder = TreeObsForRailEnv(
    max_depth=2,
    predictor=ShortestPathPredictorForRailEnv(),
)

sparse_env = RailEnv(
    width=30,
    height=30,
    rail_generator=rail_generator,
    line_generator=line_generator,
    number_of_agents=3,
    obs_builder_object=observation_builder,
)

# Play one episode with the greedy ordered policy.
run_episode(sparse_env, OrderedAgent())

# from PIL import Image
# im = Image.open("happy.jpg")
# im.show()

{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: None, 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (21, 19), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (22, 19), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (22, 20), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (23, 20), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (24, 20), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (25, 20), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (26, 20), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (26, 19), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (26, 18), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (26, 17), 2: None}
{0: True, 1: False, 2: False, '__all__': False}
{0: None, 1: (26, 16), 2: None}
{0: True, 1: False, 2: False, '__all__': Fal

In [1]:
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator


# env, _env_dict = RailEnvPersister.load_new("./envs_config/train_envs/small_envs_50/Level_1.pkl")

In [11]:
# 40x40 environment with 8 agents and a fixed random seed.
env = RailEnv(
    width=40,
    height=40,
    rail_generator=sparse_rail_generator(),
    number_of_agents=8,
    random_seed=99,
)

In [14]:
import pickle

# Round-trip the environment through pickle: write it to the file 'data',
# then read it back into `y`. The `with` blocks close the file even if
# dump/load raises.
with open('data', 'wb') as f:
    pickle.dump(env, f)

with open('data', 'rb') as f:
    y = pickle.load(f)

In [16]:
# Display the random_seed attribute of `y`.
y.random_seed

99

In [2]:
# NOTE(review): bare attribute access -- `_seed` is a method and is not called
# here, so the cell only displays the bound method object.
env._seed

<bound method RailEnv._seed of <flatland.envs.rail_env.RailEnv object at 0x7feaff1bc710>>

In [9]:
import pickle

# Inspect the raw contents of a saved level. pickle.load can execute arbitrary
# code, so only use it on files from a trusted source.
with open("./envs_config/train_envs/small_envs_50/Level_1.pkl", 'rb') as file:
    data = pickle.load(file)

data.keys()


dict_keys(['grid', 'agents', 'malfunction', 'max_episode_steps', 'distance_map'])

In [3]:
# NOTE(review): in the visible notebook `_env_dict` is only assigned by the
# RailEnvPersister.load_new(...) line that is commented out, so this cell
# depends on leftover kernel state. It also displays the whole dictionary.
_env_dict

{'grid': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   16386,
   1025,
   1025,
   1025,
   1025,
   4608,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   32800,
   0,
   0,
   0,
   0,
   32800,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   32800,
   0,
   0,
   0,
   0,
   32872,
   4608,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
 

In [1]:
from flatland.contrib.interface import flatland_env

ModuleNotFoundError: No module named 'flatland.contrib'