In [77]:
from pathlib import Path
from pystk2_gymnasium import AgentSpec
from functools import partial
import torch
import inspect
from bbrl.agents.gymnasium import ParallelGymAgent, make_env
from bbrl.agents import Agents, TemporalAgent
from bbrl.workspace import Workspace
from bbrl.agents import Agent
import gymnasium as gym
import time

import gym
import numpy as np
from sac_torch import Agent
from utils import plot_learning_curve

# Project-local modules (imported by plain module name, not with relative-import syntax)
from actors import Actor
from pystk_actor import env_name, get_wrappers, player_name

In [2]:
class SamplingActor(Agent):
    """Actor that writes a fixed, hard-coded action at every time step.

    Despite its name, random sampling from ``action_space`` is currently
    disabled: ``forward`` always emits the same action tensor.

    NOTE(review): ``Agent`` is imported twice at the top of the notebook
    (``bbrl.agents`` and ``sac_torch``); the later import wins on a fresh
    run — confirm which base class is intended here.
    """

    def __init__(self, action_space: gym.Space):
        """Keep a reference to ``action_space`` (not used by ``forward`` at the moment)."""
        super().__init__()
        self.action_space = action_space

    def forward(self, t: int):
        """Store the constant action under the ``("action", t)`` key via ``self.set``."""
        # Sampling was previously done with
        # torch.LongTensor([self.action_space.sample()]); it is replaced here
        # by a constant action with a leading batch dimension of 1.
        fixed_action = torch.tensor([[4, 1, 0, 1, 1, 0, 6]])
        self.set(("action", t), fixed_action)

In [3]:
class ContinuousObservationWrapper(gym.ObservationWrapper):
    """Expose only the ``'continuous'`` entry of a dict-style observation.

    Both the advertised observation space and every observation returned by
    the wrapped environment are reduced to their ``'continuous'`` component.
    """

    def __init__(self, env):
        """Wrap ``env`` and narrow the observation space to its ``'continuous'`` sub-space."""
        super().__init__(env)
        full_space = env.observation_space
        self.observation_space = full_space['continuous']

    def observation(self, observation):
        """Return the ``'continuous'`` component of ``observation``."""
        continuous_part = observation['continuous']
        return continuous_part

In [78]:
# (1) Setup the environment

# Environment id used by create_env(); this rebinds the `env_name`
# imported from pystk_actor at the top of the notebook.
env_name = "supertuxkart/flattened_continuous_actions-v0"
# Number of environment instances built below.
n_envs = 4
# Number of outer training iterations used by the training cell.
n_steps = 150


def create_env():
    """Build one environment and restrict its observations to the 'continuous' part."""
    wrapper_list = get_wrappers()
    # use_ai=False for using the "action" line of the workspace
    player_spec = AgentSpec(use_ai=False, name=player_name)
    base_env = make_env(
        env_name,
        wrappers=wrapper_list,
        render_mode=None,  # human for video, else None
        autoreset=True,
        agent=player_spec,
    )
    return ContinuousObservationWrapper(base_env)


# One independent environment per slot.
envs = list(create_env() for _ in range(n_envs))

..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..
..:: Antarctica Rendering Engine 2.0 ::..


In [80]:
# Training loop: at every outer step each environment is advanced by one
# transition, the transition is stored through agent.remember(), and one
# agent.learn() call is made (unless a checkpoint is being evaluated).
agent = Agent(input_dims=envs[0].observation_space.shape, env=envs[0],
                n_actions=envs[0].action_space.shape[0])

# best_score is only referenced by the commented-out checkpointing code below;
# score_history is not read anywhere in this cell.
best_score = envs[0].reward_range[0]
score_history = []
load_checkpoint = False

if load_checkpoint:
    agent.load_models()
    # NOTE(review): the environments are created with render_mode=None and
    # recent gym/gymnasium versions pick the render mode at construction
    # time — confirm that render(mode=...) is accepted here.
    envs[0].render(mode='human')

# reset() returns (observation, info); only the observation is kept.
observations = [envs[i].reset()[0] for i in range(n_envs)]
# Cumulative reward per environment (never reset inside the loop).
scores = [0]*n_envs
scores_history = [[] for _ in range(n_envs)]

# Accumulated wall-clock time (seconds) spent in each phase.
t_choose = 0
t_step = 0
t_learn = 0

# The learning curve is written below to "plots/stk_scores.png"; make sure
# the target directory exists before the first plot.
Path("plots").mkdir(parents=True, exist_ok=True)

print("beginning of training")

# The step index and the timer use distinct names: previously both were
# called `t`, so the timer overwrote the step index and the periodic
# plot / final close conditions below could never be true.
for step in range(n_steps):
    for i in range(n_envs):
        tic = time.perf_counter()
        action = agent.choose_action(observations[i])
        t_choose += time.perf_counter() - tic

        tic = time.perf_counter()
        observation_, reward, terminated, truncated, info = envs[i].step(action)
        t_step += time.perf_counter() - tic

        done = terminated or truncated
        scores[i] += reward
        agent.remember(observations[i], action, reward, observation_, done)

        tic = time.perf_counter()
        if not load_checkpoint:
            agent.learn()
        t_learn += time.perf_counter() - tic

        observations[i] = observation_
        scores_history[i].append(scores[i])

        # avg_score = np.mean(scores_history[i][-100:])
        # if avg_score > best_score:
        #     best_score = avg_score
        #     if not load_checkpoint:
        #         agent.save_models()

        # Every 50th step: plot this environment's score history. All
        # environments write to the same file, so the last one plotted wins.
        if step % 50 == 49:
            plot_learning_curve(list(range(step + 1)), scores_history[i], "plots/stk_scores.png")
        # Release each environment once its final step has been processed.
        if step == n_steps - 1:
            envs[i].close()

print(f"{t_choose = } | {t_step = } | {t_learn = }")

beginning of training
t_choose = 0.33148837089538574 | t_step = 9.448336839675903 | t_learn = 4.844192743301392


In [None]:
# (2) Learn

# NOTE(review): `env` and `env_agent` are not defined in the cells visible
# here; this cell presumably relies on state from an earlier session —
# confirm before running on a fresh kernel.
actor = SamplingActor(env.action_space)
agent_chain = Agents(env_agent, actor)
temporal_agent = TemporalAgent(agent_chain)

# Run the chained agents for 200 steps, recording into a fresh workspace.
workspace = Workspace()
temporal_agent(workspace, t=0, n_steps=200)

# (3) Save the actor state
# The checkpoint is written next to the module that defines get_wrappers.
mod_path = Path(inspect.getfile(get_wrappers)).parent
checkpoint_file = mod_path / "pystk_actor.pth"
torch.save(actor.state_dict(), checkpoint_file)

In [52]:
# Display the 'continuous' sub-space of the environment's observation space.
# NOTE(review): `env` is not defined in the cells visible here — confirm where it comes from.
env.observation_space['continuous']

Box([  0.   0. -inf -inf -inf -inf -inf   0. -inf -inf -inf -inf -inf -inf
 -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf
 -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf  -1.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0. -inf -inf -inf -inf
 -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf
 -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf   0.   0.
   0.   0.   0.   0.   0. -inf -inf -inf], [inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf  1. inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf], (92,), float32)

In [39]:
# Display the actions recorded in the workspace.
# To print every recorded variable instead:
# for key in workspace.keys():
#     print(key, workspace[key])
workspace["action"]

tensor([[[4, 1, 0,  ..., 1, 0, 6]],

        [[3, 1, 0,  ..., 1, 1, 3]],

        [[1, 0, 1,  ..., 0, 1, 5]],

        ...,

        [[1, 1, 1,  ..., 0, 0, 1]],

        [[1, 0, 0,  ..., 0, 1, 1]],

        [[1, 1, 1,  ..., 1, 0, 4]]])