In [13]:
import gym
import numpy as np
import numpy.typing as npt

from typing import List, Tuple, Literal, Any, Optional, cast
from utils.agent import Agent
from tqdm.autonotebook import tqdm
from utils.algorithm import AlgorithmInterface
from utils.preprocess import PreprocessInterface
import torch
from torchvision import transforms
from torch import nn

In [14]:
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)



<torch._C.Generator at 0x7f755a7d6f10>

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')


Using cuda device


In [16]:
env = gym.make('StarGunner-v0')
env.seed(RANDOM_SEED)
env.reset()
print(env.action_space)
env._max_episode_steps = 1_8000
env.observation_space

Discrete(18)


Box([[[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 ...

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]], [[[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 ...

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
 

In [17]:
# shape is (210, 160, 3)
Observation = npt.NDArray[np.uint8]
Action = int

# shape is (4, 210, 160, 3)
State = torch.Tensor
Reward = int

Episode = List[Tuple[Observation, Optional[Action], Optional[Reward]]]


In [18]:
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        # self.first = nn.Sequential(nn.Conv2d(4, 32, (8, 8), 4), nn.ReLU()) self.second = nn.Sequential()
        self.net = nn.Sequential(
            nn.Conv2d(4, 32, (8, 8), 4), nn.ReLU(), nn.Conv2d(32, 64,  (4, 4), 2), nn.ReLU(), nn.Conv2d(64, 64, (3, 3), 1), nn.ReLU(), nn.Flatten(),  nn.Linear(7*7*64, 512), nn.Linear(512, 18)).to(device)

    def forward(self, x: State) -> torch.Tensor:
        rlt: torch.Tensor = self.net(x.to(device))
        return rlt.cpu()


In [19]:
class RandomAlgorithm(AlgorithmInterface[State, Action]):
    def __init__(self):

        self.times = 0
        self.last_action = None

    def allowed_actions(self, state: State) -> List[Action]:
        return list(range(18))

    def take_action(self, state: State) -> Action:
        self.times += 1

        if self.times % 10 == 0:
            act = np.random.choice(self.allowed_actions(state))
            self.last_action = act
            return act

        if self.last_action is not None:
            return self.last_action

        act = np.random.choice(self.allowed_actions(state))
        self.last_action = act
        return act

    def after_step(
        self,
        sa: Tuple[State, Action],
        episode: List[Tuple[State, Optional[Action], Optional[Reward]]],
    ):
        pass

    def on_termination(
        self, episode: List[Tuple[State, Optional[Action], Optional[Reward]]]
    ):
        pass


class NNAlgorithm(AlgorithmInterface[State, Action]):
    def __init__(self, nn: DQN, sigma: float):
        self.network = nn
        self.sigma = sigma

        self.times: int = 0

    def allowed_actions(self, state: State) -> List[Action]:
        return list(range(18))

    def take_action(self, state: State) -> Action:
        self.times += 1
        rand = np.random.random()
        if rand < self.sigma:
            return np.random.choice(self.allowed_actions(state))
        else:
            act_vals: torch.Tensor = self.network(state)
            maxi = torch.argmax(act_vals)
            return cast(Action, maxi)

    def after_step(
        self,
        sa: Tuple[State, Action],
        episode: List[Tuple[State, Optional[Action], Optional[Reward]]],
    ):
        pass

    def on_termination(
        self, episode: List[Tuple[State, Optional[Action], Optional[Reward]]]
    ):
        pass


class Preprocess(PreprocessInterface[Observation, Action, State]):
    def __init__(self):
        self.trfm = transforms.Compose(
            [transforms.ToTensor(), transforms.Grayscale(),
                transforms.Resize((84, 84))])

    def transform_one(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> State:

        # last_4 is (1-4, 210, 160, 3)
        last_4 = np.asarray([np.asarray(o)
                            for (o, _, _) in h[-4:]])

        while last_4.shape[0] < 4:
            last_4 = np.insert(last_4, 0, last_4[-1:], axis=0)

        assert last_4.shape == (4, 210, 160, 3)

        rlt = torch.stack([self.trfm(i)
                           for i in last_4]).squeeze(1).unsqueeze(0)
        assert rlt.shape == (1, 4,  84, 84)
        return rlt

    def transform_many(
        self, h: List[Tuple[Observation, Optional[Action], Optional[Reward]]]
    ) -> List[Tuple[State, Optional[Action], Optional[Reward]]]:
        return cast(Any, None)


In [20]:
agent = Agent(env, NNAlgorithm(DQN(), 1e-3), Preprocess())


In [21]:
EVALUATION_TIMES = 30

rwds: List[int] = []

for _ in tqdm(range(EVALUATION_TIMES)):
    agent.reset()

    end = False

    while not end:
        (o, end, episode) = agent.step()
        # env.render()
        if end:
            rwds.append(np.sum([r if r is not None else 0 for (_,
                                                               _, r) in cast(Episode, episode)]))


100%|██████████| 30/30 [02:45<00:00,  5.52s/it]


In [22]:
np.mean(rwds)

480.0