In [7]:
import setup
import argparse
import os
import random
import time
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from typing import List, Tuple, Literal, Any, Optional, cast, Callable, Union, Iterable
from utils.env import PreprocessObservation, FrameStack, ToTensorEnv

from utils.env_sb3 import WarpFrame, MaxAndSkipEnv, NoopResetEnv, EpisodicLifeEnv

In [8]:
# Single (non-vectorised) Pong environment used for quick inspection.
# Wrappers are applied innermost first: WarpFrame, then ToTensorEnv, then a
# FrameStack that keeps 4 stacked observations.
env = FrameStack(
    ToTensorEnv(WarpFrame(gym.make('PongDeterministic-v4'))),
    num_stack=4,
)

In [9]:
def make_env(seed: int = 0) -> Callable[[], Any]:
    """Return a zero-argument factory that builds one preprocessed Pong env.

    The factory is what `gym.vector.SyncVectorEnv` expects: it is called
    later to create the environment, so nothing is constructed here.

    Args:
        seed: Value passed to `env.seed(...)` on the freshly created
            environment. Defaults to 0, which matches the previous
            hard-coded behaviour; pass a different value per environment
            to avoid every copy being seeded identically.

    Returns:
        A callable taking no arguments that creates 'PongDeterministic-v4',
        seeds it, and wraps it in WarpFrame, ToTensorEnv and
        FrameStack(num_stack=4), in that order.
    """
    def thunk():
        env = gym.make('PongDeterministic-v4')
        # Seed the raw environment before any wrapper is applied.
        env.seed(seed)
        env = WarpFrame(env)
        env = ToTensorEnv(env)
        env = FrameStack(env, num_stack=4)
        return env

    return thunk


In [10]:
# Vectorised environment built from four independent `make_env()` factories.
envs = gym.vector.SyncVectorEnv(
    [make_env() for _ in range(4)]
)
envs


SyncVectorEnv(4)

In [12]:
# Per-environment (action shape, observation shape) of the vectorised env.
envs.single_action_space.shape, envs.single_observation_space.shape

((), (4, 1, 84, 84))

In [68]:
# Type alias for an observation/state tensor; used in type annotations.
State = torch.Tensor

In [69]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    The trunk consumes a batch of 4 stacked 84x84 frames and produces a
    512-dimensional feature vector. Two heads sit on top of it:

    * ``actor``  -- a linear layer followed by a softmax over ``n_actions``.
    * ``critic`` -- a linear layer producing one scalar value per sample.

    Args:
        n_actions: Number of discrete actions the actor head outputs.
    """

    def __init__(self, n_actions: int):
        super().__init__()

        self.n_actions = n_actions

        # Spatial size per conv layer for 84x84 input: 84 -> 20 -> 9 -> 7,
        # hence the 7 * 7 * 64 flattened features fed to the linear layer.
        self.base = nn.Sequential(
            nn.Conv2d(4, 32, (8, 8), 4),
            nn.ReLU(),
            nn.Conv2d(32, 64, (4, 4), 2),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3, 3), 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(),
        )

        self.actor = nn.Sequential(
            nn.Linear(512, n_actions), nn.Softmax(dim=1))
        self.critic = nn.Linear(512, 1)

    def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute action probabilities and state values for a batch.

        Args:
            s: Batch of observations, either ``(N, 4, 84, 84)`` or
                ``(N, 4, 1, 84, 84)``. The second form is what the stacked
                environments in this notebook report per sample
                (``(4, 1, 84, 84)``); its singleton axis is dropped before
                the convolutions, which only accept 4-D batched input.

        Returns:
            ``(action, value)`` where ``action`` has shape
            ``(N, n_actions)`` and holds softmax probabilities, and
            ``value`` has shape ``(N, 1)``.
        """
        # Collapse the per-frame channel axis left by frame stacking so the
        # stack dimension acts as the Conv2d input channels.
        if s.dim() == 5 and s.size(2) == 1:
            s = s.squeeze(2)

        base = self.base(s)
        action = self.actor(base)
        value = self.critic(base)

        assert action.shape == (s.size(0), self.n_actions)
        assert value.shape == (s.size(0), 1)

        return (action, value)

In [71]:
# Initial batch of observations from the vectorised env. Bound to `obs`
# rather than `os` so the standard-library `os` module imported at the top
# of the notebook is not shadowed.
obs = envs.reset()
print(obs.shape)
obs[0]

(3, 4, 1, 84, 84)


array([[[[0.2509804 , 0.2509804 , 0.2509804 , ..., 0.41960785,
          0.41960785, 0.41960785],
         [0.41960785, 0.41960785, 0.41960785, ..., 0.41960785,
          0.41960785, 0.41960785],
         [0.41960785, 0.41960785, 0.41960785, ..., 0.41960785,
          0.41960785, 0.41960785],
         ...,
         [0.2901961 , 0.2901961 , 0.2901961 , ..., 0.2901961 ,
          0.2901961 , 0.2901961 ],
         [0.2901961 , 0.2901961 , 0.2901961 , ..., 0.2901961 ,
          0.2901961 , 0.2901961 ],
         [0.2901961 , 0.2901961 , 0.2901961 , ..., 0.2901961 ,
          0.2901961 , 0.2901961 ]]],


       [[[0.2509804 , 0.2509804 , 0.2509804 , ..., 0.41960785,
          0.41960785, 0.41960785],
         [0.41960785, 0.41960785, 0.41960785, ..., 0.41960785,
          0.41960785, 0.41960785],
         [0.41960785, 0.41960785, 0.41960785, ..., 0.41960785,
          0.41960785, 0.41960785],
         ...,
         [0.2901961 , 0.2901961 , 0.2901961 , ..., 0.2901961 ,
          0.2901961 , 0

In [58]:
# Reset the single (non-vectorised) env and display element 0 of the
# returned observation.
# NOTE(review): this cell's execution count (58) is lower than the cells
# above it (68-71), so it was run out of order -- confirm it still behaves
# the same under Restart & Run All.
env.reset()[0]

tensor([[[0.2510, 0.2510, 0.2510,  ..., 0.4196, 0.4196, 0.4196],
         [0.4196, 0.4196, 0.4196,  ..., 0.4196, 0.4196, 0.4196],
         [0.4196, 0.4196, 0.4196,  ..., 0.4196, 0.4196, 0.4196],
         ...,
         [0.2902, 0.2902, 0.2902,  ..., 0.2902, 0.2902, 0.2902],
         [0.2902, 0.2902, 0.2902,  ..., 0.2902, 0.2902, 0.2902],
         [0.2902, 0.2902, 0.2902,  ..., 0.2902, 0.2902, 0.2902]]])