In [2]:
import torch
import torch.nn as nn

In [3]:
import gymnasium as gym

In [4]:
import numpy as np

In [5]:
# Build one environment per id; only the Breakout env receives extra
# observation wrappers.
env_ids = ["HalfCheetah-v4", "CartPole-v1", "BreakoutNoFrameskip-v4"]
env_list = []
for env_id in env_ids:
    env = gym.make(env_id)
    is_breakout = env_id == "BreakoutNoFrameskip-v4"
    if is_breakout:
        # Overwrite the declared dtype of the observation space before wrapping.
        # NOTE(review): the observations themselves are not converted here --
        # confirm this override is still required.
        env.observation_space.dtype = np.float32
        # Resize to 84x84, convert to grayscale, then stack 4 frames
        # (innermost wrapper is applied first).
        env = gym.wrappers.FrameStack(
            gym.wrappers.GrayScaleObservation(
                gym.wrappers.ResizeObservation(env, (84, 84))
            ),
            4,
        )
    env_list.append(env)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [6]:
# Observation-space shape of the wrapped Breakout env (index 2 in env_ids).
env_list[2].observation_space.shape

(4, 84, 84)

In [7]:
from gymnasium.spaces import Box, Discrete

In [8]:
import numpy as np

In [9]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place with orthogonal weights and a constant bias.

    Args:
        layer: Module exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (for example ``nn.Linear`` or ``nn.Conv2d``).
        std: Gain forwarded to ``torch.nn.init.orthogonal_``.
        bias_const: Value written to every element of the bias.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` carry ``bias is None``; skip the bias
    # initialisation for them instead of crashing.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

import torch.nn.functional as F

class CNN(nn.Module):
    """Convolutional encoder that maps an image stack to a flat feature vector.

    Four ReLU-activated convolutions are followed by global average pooling,
    so the output width depends only on ``feature_dim`` and not on the
    spatial size of the input.

    Args:
        in_channels: Number of channels of the input (default 4, matching the
            4-frame stack used in this notebook).
        feature_dim: Number of output features (default 256).
    """

    def __init__(self, in_channels=4, feature_dim=256):
        super().__init__()
        # Stored so that forward() flattens to the same width the last
        # convolution produces.
        self.feature_dim = feature_dim
        self.cnn1 = layer_init(nn.Conv2d(in_channels, 32, 8, stride=4, padding=1))
        self.cnn2 = layer_init(nn.Conv2d(32, 64, 4, stride=2, padding=1))
        self.cnn3 = layer_init(nn.Conv2d(64, 64, 3, stride=1, padding=1))
        self.cnn4 = layer_init(nn.Conv2d(64, feature_dim, 3, stride=1, padding=1))
        self.pooling = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape ``(-1, feature_dim)``.

        A ``(1, 4, 84, 84)`` input yields a ``(1, 256)`` output with the
        default arguments. Because of the ``view(-1, feature_dim)`` call, an
        input without a batch dimension also comes back with a leading
        dimension of 1.
        """
        for conv in (self.cnn1, self.cnn2, self.cnn3, self.cnn4):
            x = F.relu(conv(x))
        # Global average pooling collapses each feature map to a single value.
        x = self.pooling(x)
        return x.view(-1, self.feature_dim)

In [10]:
# Display the Breakout environment to show its wrapper stack.
env_list[2]

<FrameStack<GrayScaleObservation<ResizeObservation<OrderEnforcing<PassiveEnvChecker<AtariEnv<BreakoutNoFrameskip-v4>>>>>>>

In [11]:
# Reset the Breakout env; reset() returns the first observation and an info object.
img_obs, info = env_list[2].reset()

  logger.warn(


In [12]:
# Shape of the initial Breakout observation.
img_obs.shape

(4, 84, 84)

In [13]:
# Inspect the info object returned by reset().
info

{'lives': 5, 'episode_frame_number': 0, 'frame_number': 0}

In [14]:
# Instantiate the convolutional encoder defined above.
cnn= CNN()

In [15]:
# Observation shape before a batch dimension is added.
img_obs.shape

(4, 84, 84)

In [16]:
# Materialise the stacked-frame observation as a single ndarray before handing
# it to torch: calling torch.tensor() on it directly emitted a warning
# (presumably the slow "list of numpy.ndarrays" path -- TODO confirm).
# Then cast to float32 and add a leading batch dimension.
img_ten = torch.as_tensor(np.asarray(img_obs), dtype=torch.float32).unsqueeze(0)
img_ten.shape

  img_ten = torch.tensor(img_obs).unsqueeze(0).to(torch.float)


torch.Size([1, 4, 84, 84])

In [17]:
# Run the single-observation batch through the encoder and inspect the
# shape of the resulting features.
out = cnn(img_ten)
out.shape

torch.Size([1, 256])

In [18]:
class MultiHeadNet(nn.Module):
    """Network with a per-environment encoder and action head around one shared layer.

    For every ``(env_id, env)`` pair an observation encoder and an action
    decoder are created and registered under ``env_id``:

    * observation shapes with fewer than two dimensions get an ``nn.Linear``
      encoder over the flattened observation;
    * three-dimensional observation shapes get a ``CNN`` encoder;
    * ``Box`` action spaces get one output per action element, ``Discrete``
      action spaces one output per discrete action.

    Every encoder emits 256 features, which pass through the shared ``mlp``
    layer before the environment-specific decoder.

    Args:
        env_ids: Environment identifiers, used as keys of the module dicts.
        env_list: Environments aligned with ``env_ids``; only their
            ``observation_space`` and ``action_space`` are read here.

    Raises:
        ValueError: If an observation shape or action-space type is not one
            of the supported kinds listed above.
    """

    def __init__(self, env_ids, env_list) -> None:
        super().__init__()
        self.env_ids = env_ids
        self.env_list = env_list
        encoder_dict = dict()
        decoder_dict = dict()
        for env_id, env in zip(env_ids, env_list):
            obs_shape = env.observation_space.shape
            if len(obs_shape) < 2:
                # int(): np.prod returns a NumPy scalar (a float for an empty shape).
                obs_encoder = nn.Linear(int(np.prod(obs_shape)), 256)
            elif len(obs_shape) == 3:
                # NOTE(review): CNN() is built with its defaults, so the image
                # layout must match what CNN expects -- confirm per environment.
                obs_encoder = CNN()
            else:
                # Without this branch the encoder of the previous environment
                # would be silently reused (or a NameError raised on the first one).
                raise ValueError(
                    f"Unsupported observation shape {obs_shape} for env '{env_id}'"
                )

            action_space = env.action_space
            if isinstance(action_space, Box):
                act_dim = int(np.prod(action_space.shape))
            elif isinstance(action_space, Discrete):
                act_dim = int(action_space.n)
            else:
                # Same reasoning as above: never fall through with a stale act_dim.
                raise ValueError(
                    f"Unsupported action space {type(action_space).__name__} "
                    f"for env '{env_id}'"
                )

            encoder_dict[env_id] = obs_encoder
            decoder_dict[env_id] = nn.Linear(256, act_dim)
        self.encoder_dict = nn.ModuleDict(encoder_dict)
        self.mlp = nn.Linear(256, 256)
        self.decoder_dict = nn.ModuleDict(decoder_dict)

    def forward(self, env_id, action_space, x):
        """Compute the action-head output for one environment.

        Args:
            env_id: Key selecting the encoder and decoder to use.
            action_space: Unused; kept so existing call sites keep working.
            x: Observation tensor accepted by the encoder of ``env_id``.

        Returns:
            Output of the ``env_id`` decoder applied to the shared features.
        """
        # NOTE(review): no activation is applied between encoder, shared layer
        # and decoder, so for Linear encoders the whole path is a composition
        # of affine maps -- confirm this is intended.
        h = self.encoder_dict[env_id](x)
        h = self.mlp(h)
        out = self.decoder_dict[env_id](h)
        return out

In [27]:
# Build the multi-environment policy and make sure its floating-point
# parameters are float32.
policy = MultiHeadNet(env_ids, env_list).float()

In [28]:
# Reset every environment once and keep only the initial observation
# (the first element of what reset() returns).
observations = list(map(lambda e: e.reset()[0], env_list))

In [30]:
# One forward pass per environment. Each observation is first materialised as
# a single ndarray so that stacked-frame observations are not converted
# element by element by torch (presumably the slow path torch warns about --
# TODO confirm), then cast to float32.
actions = [
    policy(env_id, None, torch.as_tensor(np.asarray(obs), dtype=torch.float32))
    for env_id, obs in zip(env_ids, observations)
]

In [31]:
# Print the output shape produced for each environment, in env_ids order.
for a in actions:
    print(a.shape)

torch.Size([6])
torch.Size([2])
torch.Size([1, 4])
