In [3]:
!pip install gymnasium[mujoco]
!pip install stable-baselines3



In [11]:
import gymnasium as gym
import numpy as np
from multiprocessing import Process, Queue
from stable_baselines3 import PPO
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os

# Select the rendering backend through MUJOCO_GL ("egl" is the alternative).
os.environ["MUJOCO_GL"] = "glfw"

env_name = 'Ant-v5'
env = gym.make(env_name, render_mode='rgb_array')

# Custom network layout: ReLU activations, with separately sized hidden
# layers for the policy head ("pi") and the value-function head ("vf").
policy_kwargs = dict(
    activation_fn=nn.ReLU,
    net_arch=dict(pi=[64, 128], vf=[32, 64]),
)
model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)

In [12]:
# Display the policy keyword arguments stored on the model to confirm the
# custom activation and layer sizes were picked up.
model.policy_kwargs

{'activation_fn': torch.nn.modules.activation.ReLU,
 'net_arch': {'pi': [64, 128], 'vf': [32, 64]}}

In [13]:
# Display the policy's parameter mapping (name -> tensor). Every tensor is
# shown, so the rendered output is long.
model.policy.state_dict()

OrderedDict([('log_std', tensor([0., 0., 0., 0., 0., 0., 0., 0.])),
             ('mlp_extractor.policy_net.0.weight',
              tensor([[ 0.0751,  0.1961, -0.1624,  ...,  0.0180, -0.0207, -0.0952],
                      [-0.0009,  0.0155, -0.0358,  ...,  0.1907,  0.0068,  0.1008],
                      [ 0.0327, -0.0926, -0.1695,  ...,  0.0517, -0.0656,  0.2301],
                      ...,
                      [-0.0640, -0.0491, -0.1461,  ...,  0.0579, -0.0094, -0.0487],
                      [-0.1815,  0.0672,  0.0532,  ...,  0.1817, -0.0266, -0.0990],
                      [-0.1356, -0.0981,  0.0645,  ...,  0.1444,  0.1475, -0.0584]])),
             ('mlp_extractor.policy_net.0.bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [14]:
# Compact view of the policy network: one line per parameter tensor,
# giving its name followed by its shape.
policy_params = model.policy.state_dict()
for param_name, param_tensor in policy_params.items():
    print(param_name, param_tensor.shape)

log_std torch.Size([8])
mlp_extractor.policy_net.0.weight torch.Size([64, 105])
mlp_extractor.policy_net.0.bias torch.Size([64])
mlp_extractor.policy_net.2.weight torch.Size([128, 64])
mlp_extractor.policy_net.2.bias torch.Size([128])
mlp_extractor.value_net.0.weight torch.Size([32, 105])
mlp_extractor.value_net.0.bias torch.Size([32])
mlp_extractor.value_net.2.weight torch.Size([64, 32])
mlp_extractor.value_net.2.bias torch.Size([64])
action_net.weight torch.Size([8, 128])
action_net.bias torch.Size([8])
value_net.weight torch.Size([1, 64])
value_net.bias torch.Size([1])


In [15]:
# Train the PPO agent, requesting 1000 environment timesteps. The call's
# return value (the PPO object) is displayed as the cell output.
# NOTE(review): PPO gathers experience in whole rollouts, so the number of
# steps actually run may exceed 1000 — confirm against the model's n_steps.
model.learn(total_timesteps=1000)

<stable_baselines3.ppo.ppo.PPO at 0x2061058ba90>

In [16]:
# evaluation
observation, info = env.reset()
for _ in range(1000):
    action = model.predict(observation, deterministic=True)[0]
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        env.reset()
env.close()