In [19]:
from pystk2_gymnasium import AgentSpec
import gymnasium as gym
import time
import matplotlib.pyplot as plt
from sac_torch import AgentSac
from utils import plot_learning_curve

In [38]:
class SelectKeysWrapper(gym.ObservationWrapper):
    def __init__(self, env, selected_keys):
        super().__init__(env)
        self.selected_keys = selected_keys

        # Filter the original observation space
        new_obs_space = {key: env.observation_space[key] for key in selected_keys}
        self.observation_space = gym.spaces.Dict(new_obs_space)
        self.obs_len = ???

    def observation(self, observation):
        """Filter the observation to only return selected keys."""
        return {key: observation[key] for key in self.selected_keys}

In [42]:
# (1) Setup the environment

player_name = "smail_gaetan_kart"
env_name = "supertuxkart/flattened_continuous_actions-v0"
n_envs = 1
n_steps = 3000
max_episode_steps = 200

def create_env(render_mode):
    env = gym.make(
        env_name,
        render_mode=render_mode, # human for video, else None 
        agent=AgentSpec(use_ai=False, name=player_name), # use_ai=False for using the "action" line of the workspace
    )
    # env = ContinuousObservationWrapper(env)
    selected_keys = {'center_path', 'center_path_distance', 'front', 'velocity'}
    env = SelectKeysWrapper(env, selected_keys)
    # env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
    return env

In [43]:
envs = [create_env(None) for _ in range(n_envs-1)]+[create_env("human")]
agent = AgentSac(input_dims=envs[0].observation_space.shape, env=envs[0],
                n_actions=envs[0].action_space.shape[0])

score_history = []
load_checkpoint = False

# if load_checkpoint:
#     agent.load_models()
#     envs[0].render(mode='human')

observations = [envs[i].reset()[0] for i in range(n_envs)]
scores = [0]*n_envs
scores_history = [[] for _ in range(n_envs)]
center_path_distances = []

t_choose = 0
t_step = 0
t_learn = 0
print("beginning of training")

for k in range(n_steps):
    for i in range(n_envs):
        t = time.time()
        action = agent.choose_action(observations[i])
        t_choose += time.time()-t

        t = time.time()
        observation_, reward, terminated, truncated, info = envs[i].step(action)
        print(observation_)
        import sys; sys.exit()
        center_path_distance = observation_[5]
        reward -= abs(center_path_distance)
        t_step += time.time() - t

        done = terminated or truncated
        scores[i] += reward
        agent.remember(observations[i], action, reward, observation_, done)

        if done:
            envs[i] = create_env(envs[i].render_mode)
            observations[i] = envs[i].reset()[0]

        t = time.time()
        if not load_checkpoint:
            agent.learn()
        t_learn += time.time()-t

        observations[i] = observation_
        scores_history[i].append(scores[i])

        if k%50==49:
            plot_learning_curve(list(range(k+1)), scores_history[i], "plots/stk_scores.png")
        if k==n_steps-1:
            envs[i].close()

print(f"{t_choose = } | {t_step = } | {t_learn = }")

TypeError: 'NoneType' object is not subscriptable

..:: Antarctica Rendering Engine 2.0 ::..


In [44]:
envs[0].selected_keys

{'center_path', 'center_path_distance', 'front', 'velocity'}

In [24]:
envs[0].observation_space, envs[0].action_space.sample()

(Dict('attachment': Discrete(10), 'attachment_time_left': Box(0.0, inf, (1,), float32), 'aux_ticks': Box(0.0, inf, (1,), float32), 'center_path': Box(-inf, inf, (3,), float32), 'center_path_distance': Box(-inf, inf, (1,), float32), 'distance_down_track': Box(-inf, inf, (1,), float32), 'energy': Box(0.0, inf, (1,), float32), 'front': Box(-inf, inf, (3,), float32), 'items_position': Box(-inf, inf, (5, 3), float32), 'items_type': MultiDiscrete([7 7 7 7 7]), 'jumping': Discrete(2), 'karts_position': Box(-inf, inf, (5, 3), float32), 'max_steer_angle': Box(-1.0, 1.0, (1,), float32), 'paths_distance': Box(0.0, inf, (5, 2), float32), 'paths_end': Box(-inf, inf, (5, 3), float32), 'paths_start': Box(-inf, inf, (5, 3), float32), 'paths_width': Box(0.0, inf, (5, 1), float32), 'phase': Discrete(4), 'powerup': Discrete(11), 'shield_time': Box(0.0, inf, (1,), float32), 'skeed_factor': Box(0.0, inf, (1,), float32), 'velocity': Box(-inf, inf, (3,), float32)),
 {'acceleration': array([0.16754398], dtype

In [41]:
for env in envs:
    env.close()