In [9]:
import gym
import numpy as np
import numpy.typing as npt

from typing import List, Tuple, Literal, Any, Optional, cast
from utils.agent import Agent
from tqdm.autonotebook import tqdm
from utils.algorithm import AlgorithmInterface
from utils.preprocess import PreprocessInterface
import torch
from torchvision import transforms


In [10]:
# One shared seed for every random number generator used in this notebook.
RANDOM_SEED = 0
# Seed NumPy's global RNG (used by np.random.choice in the random policy).
np.random.seed(RANDOM_SEED)
# Seed PyTorch's global RNG; as the last expression of the cell, its return
# value (the torch Generator) is what the cell displays.
torch.manual_seed(RANDOM_SEED)



<torch._C.Generator at 0x7efeb09ab970>

In [11]:
# Upper bound on the number of steps in one episode.
# NOTE(review): this is written to the private `_max_episode_steps` attribute,
# presumably the one read by gym's TimeLimit wrapper - confirm for the installed
# gym version.
MAX_EPISODE_STEPS = 18_000

env = gym.make('StarGunner-v0')
env.seed(RANDOM_SEED)
env.reset()
env._max_episode_steps = MAX_EPISODE_STEPS

In [12]:
# A single raw frame from the environment; shape is (210, 160, 3)
Observation = npt.NDArray[np.uint8]
# Index of a discrete action
Action = int

# Stack of preprocessed frames as returned by Preprocess.transform_one;
# shape is (4, 1, 84, 84)
State = torch.Tensor
Reward = int

# One (observation, action, reward) tuple per step; action and reward may be None
Episode = List[Tuple[Observation, Optional[Action], Optional[Reward]]]


In [13]:
class RandomAlgorithm(AlgorithmInterface[State, Action]):
    """Baseline policy that samples actions uniformly and repeats each one.

    A fresh action is drawn with ``np.random.choice`` on the first call to
    :meth:`take_action` and afterwards whenever the call counter is a multiple
    of ``action_repeat``; every other call returns the previously drawn action.
    The state is never inspected.
    """

    def __init__(self, n_actions: int, action_repeat: int = 10):
        """
        Args:
            n_actions: Number of discrete actions; valid actions are
                ``0 .. n_actions - 1``.
            action_repeat: A new action is drawn every time the number of
                calls to :meth:`take_action` is a multiple of this value.

        Raises:
            ValueError: If ``action_repeat`` is not positive.
        """
        if action_repeat <= 0:
            raise ValueError("action_repeat must be a positive integer")

        self.n_actions = n_actions
        self.action_repeat = action_repeat

        # Number of take_action calls made so far.
        self.times = 0
        # Most recently sampled action; None until the first call.
        self.last_action: Optional[Action] = None

    def allowed_actions(self, state: State) -> List[Action]:
        """Return every action index; the set does not depend on ``state``."""
        return list(range(self.n_actions))

    def take_action(self, state: State) -> int:
        """Return the action to play, resampling on a fixed schedule.

        Args:
            state: Current state; only forwarded to :meth:`allowed_actions`.

        Returns:
            The chosen action as a plain Python ``int``.
        """
        self.times += 1

        if self.last_action is None or self.times % self.action_repeat == 0:
            # np.random.choice yields a NumPy integer; convert it so the value
            # matches the declared return type.
            self.last_action = int(
                np.random.choice(self.allowed_actions(state)))

        return self.last_action

    def after_step(
        self,
        sa: Tuple[State, Action],
        episode: List[Tuple[State, Optional[Action], Optional[Reward]]],
    ) -> None:
        """No-op: the random policy does not learn from individual steps."""
        pass

    def on_termination(
        self, episode: List[Tuple[State, Optional[Action], Optional[Reward]]]
    ) -> None:
        """No-op: the random policy does not learn from finished episodes."""
        pass


class Preprocess(PreprocessInterface[Observation, Action, State]):
    """Builds a State from the most recent frames of an episode history.

    The latest ``_N_FRAMES`` observations are stacked, converted to grayscale
    and resized to 84x84.
    """

    # Number of consecutive frames stacked into one state.
    _N_FRAMES = 4

    def __init__(self):
        # Applied to the whole frame stack at once in transform_one.
        self.trfm = transforms.Compose(
            [transforms.Grayscale(), transforms.Resize((84, 84))]
        )

    def transform_one(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> State:
        """Convert the tail of a history into a single stacked state.

        Args:
            h: Episode history as (observation, action, reward) tuples, oldest
                first. Only the observations of the last ``_N_FRAMES`` entries
                are used.

        Returns:
            Tensor of shape (4, 1, 84, 84).

        Raises:
            ValueError: If ``h`` is empty, since there is no frame to pad with.
        """
        if not h:
            raise ValueError("cannot build a state from an empty history")

        # np.transpose reverses all axes, so each (210, 160, 3) observation
        # becomes (3, 160, 210) and the stack is (1-4, 3, 160, 210).
        # NOTE(review): this also swaps height and width; `o.transpose(2, 0, 1)`
        # would keep the frame orientation - confirm which one is intended.
        frames = np.asarray(
            [np.asarray(np.transpose(o)) for (o, _, _) in h[-self._N_FRAMES:]]
        )

        missing = self._N_FRAMES - frames.shape[0]
        if missing > 0:
            # Early in an episode fewer frames exist: pad at the front with
            # copies of the oldest available frame so the stack stays in
            # chronological order.
            padding = np.repeat(frames[:1], missing, axis=0)
            frames = np.concatenate([padding, frames], axis=0)

        assert frames.shape == (self._N_FRAMES, 3, 160, 210)

        rlt = self.trfm(torch.from_numpy(frames))
        assert rlt.shape == (self._N_FRAMES, 1, 84, 84)
        return rlt

    def transform_many(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> List[Tuple[State, Optional[Action], Optional[Reward]]]:
        """Placeholder: not implemented.

        Returns ``None``, cast to satisfy the declared return type.
        """
        return cast(Any, None)


In [14]:
# Size the random policy from the environment's own action space rather than a
# hard-coded count, so it stays valid if the environment is swapped.
agent = Agent(env, RandomAlgorithm(int(env.action_space.n)), Preprocess())


In [15]:
# Number of independent episodes to play.
EVALUATION_TIMES = 30

# Summed reward of each evaluation episode, in play order.
rwds: List[int] = []

for _ in tqdm(range(EVALUATION_TIMES)):
    agent.reset()

    # Keep stepping until the agent signals that the episode is over; `episode`
    # then holds whatever the final agent.step() call returned as history.
    end = False
    while not end:
        o, end, episode = agent.step()

    # Steps without a reward (None) contribute 0 to the episode total.
    step_rewards = [0 if r is None else r
                    for (_, _, r) in cast(Episode, episode)]
    rwds.append(np.sum(step_rewards))


100%|██████████| 30/30 [01:41<00:00,  3.38s/it]


In [17]:
# Average summed reward per episode over all evaluation runs.
np.mean(rwds)

670.0