In [59]:
import gym
import numpy as np
import numpy.typing as npt

from typing import List, Tuple, Literal, Any, Optional, cast
from utils.agent import Agent
from utils.algorithm import AlgorithmInterface
from utils.preprocess import PreprocessInterface


In [60]:
# Cap on steps per episode; kept tiny so the smoke-test episode in this
# notebook finishes after a handful of steps.
MAX_EPISODE_STEPS = 3

env = gym.make('StarGunner-v0')
env.reset()
# NOTE(review): this writes a private attribute of the object returned by
# gym.make (presumably its time-limit wrapper) -- confirm it is still honoured
# by the installed gym version.
env._max_episode_steps = MAX_EPISODE_STEPS
# Kept as the cell's final expression so the notebook actually displays it;
# in the middle of the cell the value was silently discarded.
env.action_space

In [61]:
# Raw frame returned by the environment; shape is (210, 160, 3).
Observation = npt.NDArray[np.uint8]
# Index of a discrete action.
Action = int
# Preprocessed state; left opaque because the preprocessing decides its form.
State = Any
# The environment reports rewards as floats (e.g. 0.0), not ints. Type
# checkers still accept an int wherever a float is expected.
Reward = float


In [62]:
class RandomAlgorithm(AlgorithmInterface[State, Action]):
    """Algorithm that picks uniformly at random among a fixed set of actions.

    The state is never inspected and nothing is learned: both hooks
    (`after_step`, `on_termination`) are deliberate no-ops.
    """

    def __init__(self, n_actions: int):
        """Remember how many actions exist; they are numbered 0 .. n_actions - 1."""
        self.n_actions = n_actions

    def allowed_actions(self, state: State) -> List[Action]:
        """Return every action index; `state` is ignored."""
        return list(range(self.n_actions))

    def take_action(self, state: State) -> Action:
        """Draw one of the allowed actions uniformly at random.

        `np.random.choice` yields a NumPy integer scalar, so the result is
        converted to a built-in `int` to honour the declared return type.
        """
        return int(np.random.choice(self.allowed_actions(state)))

    def after_step(
        self,
        sa: Tuple[State, Action],
        episode: List[Tuple[State, Optional[Action], Optional[Reward]]],
    ) -> None:
        """No-op hook: a random policy has nothing to update."""
        pass

    def on_termination(
        self, episode: List[Tuple[State, Optional[Action], Optional[Reward]]]
    ) -> None:
        """No-op hook: a random policy has nothing to update."""
        pass


class DummyPreprocess(PreprocessInterface[Observation, Action, State]):
    """Placeholder preprocessing that throws observations away.

    Every state it produces is `None`; actions and rewards are passed through.
    """

    def transform_one(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> State:
        """Return the placeholder state (`None`) regardless of the history."""
        return cast(State, None)

    def transform_many(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> List[Tuple[State, Optional[Action], Optional[Reward]]]:
        """Return the history with each observation replaced by the placeholder state.

        Previously this returned `None` despite the declared list return type,
        so any consumer iterating over the result would fail.
        """
        # NOTE(review): assumes the interface expects one output entry per
        # history entry -- confirm against PreprocessInterface.
        return [(cast(State, None), action, reward) for (_, action, reward) in h]


In [63]:
# Take the number of actions from the environment itself rather than
# hard-coding 18, so the algorithm stays in sync if the environment changes.
agent = Agent(env, RandomAlgorithm(int(env.action_space.n)), DummyPreprocess())
agent.reset()


In [64]:
# Step the agent until it reports that the episode has ended.
end = False
while not end:
    o, end, episode = agent.step()
    # env.render()

# The loop only exits once `end` is true, so `episode` holds the finished
# episode here; show each entry as (observation shape, action, reward).
print([(obs.shape, action, reward) for (obs, action, reward) in cast(Any, episode)])



[((210, 160, 3), 10, 0.0), ((210, 160, 3), 4, 0.0), ((210, 160, 3), 5, 0.0), ((210, 160, 3), None, None)]
