In [31]:
from machin.frame.algorithms import DDPGPer
from machin.utils.logging import default_logger as logger
import torch as t
import torch.nn as nn
import gym

from drl4dypm.env import *

In [6]:
# configurations
env = gym.make("Pendulum-v0")
observe_dim = 3
action_dim = 1
action_range = 2
max_episodes = 100
max_steps = 200
noise_param = (0, 0.2)
noise_mode = "normal"
solved_reward = -150
solved_repeat = 5

In [3]:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, action_range):
        super(Actor, self).__init__()

        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_dim)
        self.action_range = action_range

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        a = t.tanh(self.fc3(a)) * self.action_range
        return a


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

        self.fc1 = nn.Linear(state_dim + action_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, state, action):
        state_action = t.cat([state, action], 1)
        q = t.relu(self.fc1(state_action))
        q = t.relu(self.fc2(q))
        q = self.fc3(q)
        return q

In [4]:
actor = Actor(observe_dim, action_dim, action_range)
actor_t = Actor(observe_dim, action_dim, action_range)
critic = Critic(observe_dim, action_dim)
critic_t = Critic(observe_dim, action_dim)

ddpg_per = DDPGPer(actor, actor_t, critic, critic_t,
                   t.optim.Adam,
                   nn.MSELoss(reduction='sum'))




In [5]:
episode, step, reward_fulfilled = 0, 0, 0
smoothed_total_reward = 0

In [8]:
cols = ['episode','reward','critic_loss','actor_loss']
line = '|'.join([f'{col:<12}' for col in cols])
print(line)


while episode < max_episodes:
    episode += 1
    total_reward = 0
    terminal = False
    step = 0
    state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)

    while not terminal and step <= max_steps:
        step += 1
        with t.no_grad():
            old_state = state
            # agent model inference
            action = ddpg_per.act_with_noise(
                        {"state": old_state},
                        noise_param=noise_param,
                        mode=noise_mode
                    )
            state, reward, terminal, _ = env.step(action.numpy())
            state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
            total_reward += reward[0]

            ddpg_per.store_transition({
                "state": {"state": old_state},
                "action": {"action": action},
                "next_state": {"state": state},
                "reward": reward[0],
                "terminal": terminal or step == max_steps
            })
    
    
    # update, update more if episode is longer, else less
    for _ in range(step):
        actor_loss, critic_loss = ddpg_per.update()
    
    
    smoothed_total_reward = (smoothed_total_reward * 0.9 +
                         total_reward * 0.1)
    
    line = f'{episode:<12}|' + '|'.join([f'{col:<12.4f}' for col in [smoothed_total_reward, critic_loss, actor_loss,]])
    
    print(line)

The framework is not responsible for any un-matching device issues caused by this operation.[0m
The framework is not responsible for any un-matching device issues caused by this operation.[0m
The framework is not responsible for any un-matching device issues caused by this operation.[0m


episode     |reward      |critic_loss |actor_loss  
2           |-281.1770   |5.7275      |-3.2711     
3           |-376.9814   |1.2157      |-8.4578     
4           |-415.6200   |0.1798      |-13.3615    
5           |-518.8913   |0.7134      |-19.1177    
6           |-622.6979   |0.3976      |-24.6760    
7           |-715.9733   |1.3380      |-31.8608    
8           |-799.7304   |2.0065      |-35.4326    
9           |-875.5328   |1.8675      |-41.7894    
10          |-932.5437   |2.8006      |-48.7506    
11          |-949.0209   |1.1675      |-53.2349    
12          |-980.1068   |3.5369      |-58.1922    
13          |-983.2745   |3.1872      |-65.7600    
14          |-989.6856   |2.7915      |-69.6690    
15          |-980.1767   |5.6098      |-70.2010    
16          |-986.2443   |11.5356     |-72.8416    
17          |-1031.0578  |4.2731      |-70.2140    
18          |-1042.6977  |1.3655      |-71.8187    
19          |-1043.5948  |1.6224      |-84.8745    
20          