In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym

In [2]:
# Shared environment instance used by the cells below (Gymnasium's Pendulum task).
env = gym.make(id="Pendulum-v1")

In [45]:
# Smoke test: take one environment step, then push the resulting observation
# through both networks to confirm the shapes line up.
# NOTE: Actor and Critic are defined in cells further down this notebook;
# those cells must have been executed before this one.
obs, info = env.reset()
# Gymnasium's step() returns (obs, reward, terminated, truncated, info) --
# `terminated` comes BEFORE `truncated`; unpacking them the other way round
# silently swaps the two flags.
obs, reward, terminated, truncated, info = env.step([1.0])
obs = torch.tensor(obs)

actor = Actor(env)
critic = Critic(env)

# Single (unbatched) observation -> action -> value estimate.
action = actor(obs)
value = critic(obs, action)


In [15]:
# Shape experiment: joining a (32, 1) column with a (32, 400) matrix.
# Concatenating along the default dim=0 is not possible here because the
# second dimensions differ (1 vs 400); dim=1 stacks them side by side.
a = torch.rand(32)[:, None]
action = torch.rand(32, 400)
print(a.shape, action.shape)
torch.cat([a, action], dim=1).shape


torch.Size([32, 1]) torch.Size([32, 400])


torch.Size([32, 401])

In [75]:
# Broadcasting experiment: a (64,) vector times a (64, 1) column does NOT
# multiply element-wise -- the shapes broadcast to a full (64, 64) matrix.
a = torch.ones(64)
b = torch.ones(64, 1)
print(a.shape)
print(b.shape)
c = b * a
c.shape

torch.Size([64])
torch.Size([64, 1])


torch.Size([64, 64])

In [4]:
class Actor(nn.Module):
    """Policy network: maps an observation to an action inside fixed bounds.

    Three linear layers (input -> 400 -> 300 -> output) with ReLU activations
    and a final tanh whose (-1, 1) output is linearly rescaled to
    [act_min, act_max].

    Args:
        env: Optional environment exposing ``observation_space.shape``,
            ``action_space.shape``, ``action_space.low`` and
            ``action_space.high``. Used to fill in any argument left as None.
        input_dim: Observation size. Defaults to ``env.observation_space.shape[0]``.
        output_dim: Action size. Defaults to ``env.action_space.shape[0]``.
        act_min: Lower action bound. Defaults to ``env.action_space.low[0]``,
            or -1.0 when no ``env`` is given.
        act_max: Upper action bound. Defaults to ``env.action_space.high[0]``,
            or 1.0 when no ``env`` is given.

    Raises:
        ValueError: If ``env`` is None and ``input_dim`` or ``output_dim`` is
            missing.

    Note:
        Only the first component of the env's bounds is read, so the same
        scalar range is applied to every action dimension.
    """

    def __init__(self, env=None, input_dim=None, output_dim=None,
                 act_min=None, act_max=None):
        super().__init__()
        if env is None and (input_dim is None or output_dim is None):
            raise ValueError(
                "Actor requires either `env` or both `input_dim` and `output_dim`"
            )
        # Explicit None checks: `x or default` would also discard a literal 0.
        if input_dim is None:
            input_dim = env.observation_space.shape[0]
        if output_dim is None:
            output_dim = env.action_space.shape[0]

        # Without an env, fall back to tanh's native range so the scaling in
        # forward() becomes the identity.
        if act_min is None:
            act_min = env.action_space.low[0] if env is not None else -1.0
        if act_max is None:
            act_max = env.action_space.high[0] if env is not None else 1.0
        # Store plain Python floats rather than numpy scalars so arithmetic
        # with tensors in forward() never goes through numpy's operators.
        self.act_min = float(act_min)
        self.act_max = float(act_max)

        self.l1 = nn.Linear(input_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, output_dim)

        for layer in (self.l1, self.l2, self.l3):
            nn.init.orthogonal_(layer.weight)

    def forward(self, x):
        """Return the bounded action for observation tensor ``x``.

        The trailing dimension of ``x`` must equal ``input_dim``; any leading
        (batch) dimensions are carried through by the linear layers.
        """
        a = F.relu(self.l1(x))
        a = F.relu(self.l2(a))
        # torch.tanh is the non-deprecated spelling of F.tanh.
        a = torch.tanh(self.l3(a))

        # Affine map from (-1, 1) to (act_min, act_max):
        # a = -1 -> act_min, a = +1 -> act_max.
        a = self.act_max * (1 + a) / 2 + self.act_min * (1 - a) / 2
        return a

In [34]:
class Critic(nn.Module):
    """Value network: maps an (observation, action) pair to one scalar.

    The observation goes through a 400-unit layer first; the action is then
    concatenated onto those features before the 300-unit layer and the final
    single-output layer.

    Args:
        env: Optional environment exposing ``observation_space.shape``,
            ``action_space.shape``, ``action_space.low`` and
            ``action_space.high``. Used to fill in any dimension left as None.
        state_dim: Observation size. Defaults to ``env.observation_space.shape[0]``.
        action_dim: Action size. Defaults to ``env.action_space.shape[0]``.

    Raises:
        ValueError: If ``env`` is None and ``state_dim`` or ``action_dim`` is
            missing.
    """

    def __init__(self, env=None, state_dim=None, action_dim=None):
        super().__init__()
        if env is None and (state_dim is None or action_dim is None):
            raise ValueError(
                "Critic requires either `env` or both `state_dim` and `action_dim`"
            )
        # Explicit None checks: `x or default` would also discard a literal 0.
        if state_dim is None:
            state_dim = env.observation_space.shape[0]
        if action_dim is None:
            action_dim = env.action_space.shape[0]

        # The bounds are not used by forward(); they are kept as attributes and
        # are None when no env was supplied.
        self.act_min = float(env.action_space.low[0]) if env is not None else None
        self.act_max = float(env.action_space.high[0]) if env is not None else None

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400 + action_dim, 300)
        self.l3 = nn.Linear(300, 1)

        for layer in (self.l1, self.l2, self.l3):
            nn.init.orthogonal_(layer.weight)

    def forward(self, obs, action):
        """Return the value estimate for ``obs`` and ``action``.

        ``obs`` and ``action`` must share their leading (batch) dimensions;
        the output has those same leading dimensions and a trailing size of 1.
        """
        a = F.relu(self.l1(obs))
        # Join along the FEATURE axis. The default dim=0 only happens to work
        # for unbatched 1-D inputs and fails for (batch, 400) + (batch, action_dim).
        a = torch.cat((a, action), dim=-1)
        a = F.relu(self.l2(a))
        a = self.l3(a)
        return a