<a href="https://colab.research.google.com/github/nimamt/machine_learning/blob/master/pytorch/reinforcement/TD3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
from torch.autograd import Variable

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Replay buffer / sampling
BUFFER_SIZE = 1_000_000  # maximum number of transitions kept in the replay deque
BATCH_SIZE = 256         # transitions sampled per learning step

# Optimisation
LR_ACTOR = 1e-4          # Adam learning rate for the actor
LR_CRITIC = 1e-3         # Adam learning rate for the critic
GAMMA = 0.99             # discount factor used in the TD target
TAU = 1e-3               # interpolation factor of the soft target-network update

# Target-policy smoothing and delayed updates
POLICY_NOISE = 0.2       # std of the Gaussian noise added to target actions
NOISE_CLIP = 0.5         # bound applied to that noise before it is added
DELAY_STEPS = 2          # actor/targets are updated every DELAY_STEPS learn steps

# NOTE(review): the two constants below are not read by any code visible in
# this cell (STD_NOISE only appears in commented-out code) -- confirm usage.
STD_NOISE = 0.3
STEPS = 1000

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class ActorNet(torch.nn.Module):
    """Deterministic policy network.

    Three fully connected layers (state_size -> 256 -> 128 -> action_size)
    with ReLU activations on the hidden layers and a tanh on the output, so
    every action component lies in [-1, 1]. All layers are created directly
    on the module-level ``device``.

    Args:
        state_size: number of input features (observation dimension).
        action_size: number of output features (action dimension).
    """

    def __init__(self, state_size, action_size):
        super().__init__()

        self.fc1 = torch.nn.Linear(state_size, 256, device=device)
        self.fc2 = torch.nn.Linear(256, 128, device=device)
        self.fc3 = torch.nn.Linear(128, action_size, device=device)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the layer weights from uniform distributions.

        The hidden layers use the bounds returned by ``fc1_init`` and
        ``fc2_init``; the output layer uses a fixed small range so the
        initial tanh pre-activations stay near zero. Biases are left at
        their ``torch.nn.Linear`` defaults.
        """
        low, high = self.fc1_init()
        self.fc1.weight.data.uniform_(low, high)
        low, high = self.fc2_init()
        self.fc2.weight.data.uniform_(low, high)
        self.fc3.weight.data.uniform_(-3e-3, 3e-3)

    def fc1_init(self):
        """Return the ``(low, high)`` uniform bounds for ``fc1``'s weights.

        NOTE(review): ``weight.size()[0]`` is the first dimension of the
        weight tensor, which for ``torch.nn.Linear`` is ``out_features``
        (256 here), not the fan-in. Kept as-is to preserve the existing
        initialisation; confirm whether fan-in was intended.
        """
        lim = 1. / np.sqrt(self.fc1.weight.data.size()[0])
        return (-lim, lim)

    def fc2_init(self):
        """Return the ``(low, high)`` uniform bounds for ``fc2``'s weights.

        Same convention as ``fc1_init`` (bound derived from ``size()[0]``).
        """
        lim = 1. / np.sqrt(self.fc2.weight.data.size()[0])
        return (-lim, lim)

    def forward(self, x):
        """Map a batch of states to actions in [-1, 1].

        Args:
            x: tensor whose last dimension equals ``state_size``.

        Returns:
            Tensor with last dimension ``action_size``, squashed by tanh.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return torch.tanh(self.fc3(x))


class CriticNet(torch.nn.Module):
    """Twin Q-value network.

    Holds two independent heads that each map a concatenated
    (state, action) vector through 256 and 128 hidden units to a single
    value. Layers ``fc1``-``fc3`` form the first head and ``fc4``-``fc6``
    the second. All layers are created on the module-level ``device``.
    """

    def __init__(self, state_size, action_size):
        super(CriticNet, self).__init__()

        in_features = state_size + action_size

        # First Q head.
        self.fc1 = torch.nn.Linear(in_features, 256, device=device)
        self.fc2 = torch.nn.Linear(256, 128, device=device)
        self.fc3 = torch.nn.Linear(128, 1, device=device)

        # Second Q head (same architecture, separate parameters).
        self.fc4 = torch.nn.Linear(in_features, 256, device=device)
        self.fc5 = torch.nn.Linear(256, 128, device=device)
        self.fc6 = torch.nn.Linear(128, 1, device=device)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all six weight matrices from uniform distributions.

        Both heads share the bounds computed from the first head's layers;
        the two output layers use a fixed (-3e-3, 3e-3) range. Layers are
        visited in fc1..fc6 order.
        """
        final_bounds = (-3e-3, 3e-3)
        plan = (
            (self.fc1, self.fc1_init),
            (self.fc2, self.fc2_init),
            (self.fc3, lambda: final_bounds),
            (self.fc4, self.fc1_init),
            (self.fc5, self.fc2_init),
            (self.fc6, lambda: final_bounds),
        )
        for layer, bounds_fn in plan:
            low, high = bounds_fn()
            layer.weight.data.uniform_(low, high)

    def fc1_init(self):
        """Return ``(-lim, lim)`` with ``lim = 1 / sqrt(fc1.weight.size()[0])``."""
        bound = 1. / np.sqrt(self.fc1.weight.data.size()[0])
        return (-bound, bound)

    def fc2_init(self):
        """Return ``(-lim, lim)`` with ``lim = 1 / sqrt(fc2.weight.size()[0])``."""
        bound = 1. / np.sqrt(self.fc2.weight.data.size()[0])
        return (-bound, bound)

    @staticmethod
    def _run_head(features, hidden_a, hidden_b, out_layer):
        """Apply one Q head: two ReLU hidden layers, then a linear output."""
        hidden = F.relu(hidden_a(features))
        hidden = F.relu(hidden_b(hidden))
        return out_layer(hidden)

    def forward(self, state, action):
        """Return the pair ``(Q1, Q2)`` for a batch of state/action pairs.

        ``state`` and ``action`` are concatenated along dimension 1, so both
        are expected to be 2-D with matching first dimension.
        """
        joint = torch.cat([state, action], 1).to(device)
        q_first = self._run_head(joint, self.fc1, self.fc2, self.fc3)
        q_second = self._run_head(joint, self.fc4, self.fc5, self.fc6)
        return q_first, q_second

    def Q1(self, state, action):
        """Return only the first head's value for a batch of state/action pairs."""
        joint = torch.cat([state, action], 1).to(device)
        return self._run_head(joint, self.fc1, self.fc2, self.fc3)

class DDPG:
    """Actor-critic agent with twin critics, target smoothing and delayed updates.

    Despite the class name, ``learn`` takes the minimum of two target
    critics, adds clipped noise to the target action and updates the actor
    and target networks only every ``DELAY_STEPS`` critic updates (i.e. the
    TD3 scheme).

    The actor network outputs actions in [-1, 1]; this class scales them by
    ``action_high`` both when acting and when learning, so the critic always
    sees actions in the environment's scale.

    Args:
        state_dim: observation dimension.
        action_dim: action dimension.
        action_high: scalar upper bound of the (symmetric) action range.
    """

    def __init__(self, state_dim, action_dim, action_high):
        self.actor = ActorNet(state_dim, action_dim)
        self.actor_target = ActorNet(state_dim, action_dim)
        self.critic = CriticNet(state_dim, action_dim)
        self.critic_target = CriticNet(state_dim, action_dim)

        # Targets must start as exact copies of the online networks;
        # otherwise the first TD targets come from unrelated random weights
        # and, with a small TAU, take a very long time to converge to them.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.memory = deque(maxlen=BUFFER_SIZE)
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)
        self.action_high = action_high
        self.step = 0  # number of learn() calls that performed a critic update

    def act(self, state):
        """Return the deterministic policy action for a single state.

        Args:
            state: 1-D array-like observation.

        Returns:
            NumPy array of shape ``(action_dim,)`` scaled by ``action_high``
            and clipped to ``[-action_high, action_high]``. No exploration
            noise is added here.
        """
        state = torch.from_numpy(np.asarray(state, dtype=np.float32)).to(device).unsqueeze(0)
        self.actor.eval()
        with torch.no_grad():
            action = self.actor(state)
        self.actor.train()
        return np.clip(action.cpu().numpy()[0] * self.action_high, -self.action_high, self.action_high)

    def memorize(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def learn(self):
        """Run one learning step on a random minibatch.

        Does nothing until the buffer holds at least ``BATCH_SIZE``
        transitions. Otherwise updates the critic every call and the actor
        plus target networks every ``DELAY_STEPS`` calls.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        self.step += 1

        batch = random.sample(self.memory, BATCH_SIZE)
        state_batch = torch.from_numpy(np.array([arr[0] for arr in batch])).float().to(device)
        action_batch = torch.from_numpy(np.array([arr[1] for arr in batch])).float().to(device)
        reward_batch = torch.from_numpy(np.array([arr[2] for arr in batch]).reshape(BATCH_SIZE, 1)).float().to(device)
        next_state_batch = torch.from_numpy(np.array([arr[3] for arr in batch])).float().to(device)
        done_batch = torch.from_numpy(np.array([arr[4] for arr in batch], dtype=np.uint8).reshape(BATCH_SIZE, 1)).float().to(device)

        # Stored actions were scaled by action_high in act(), so every actor
        # output fed to the critic below is scaled the same way.
        max_action = float(self.action_high)

        # Targets are constants for the critic regression: build them
        # without tracking gradients through the target networks.
        with torch.no_grad():
            next_actions = self.actor_target(next_state_batch) * max_action
            noise = (
                torch.randn_like(next_actions) * POLICY_NOISE * max_action
            ).clamp(-NOISE_CLIP * max_action, NOISE_CLIP * max_action)
            next_actions = (next_actions + noise).clamp(-max_action, max_action)
            q_next1, q_next2 = self.critic_target(next_state_batch, next_actions)
            q_targets = reward_batch + GAMMA * torch.min(q_next1, q_next2) * (1 - done_batch)

        # Update critic
        self.optimizer_critic.zero_grad()
        q_current1, q_current2 = self.critic(state_batch, action_batch)
        critic_loss = F.mse_loss(q_current1, q_targets) + F.mse_loss(q_current2, q_targets)
        critic_loss.backward()
        self.optimizer_critic.step()

        if self.step % DELAY_STEPS == 0:
            # Update actor: ascend the first critic head's value of the
            # policy's (scaled) action.
            self.optimizer_actor.zero_grad()
            policy_actions = self.actor(state_batch) * max_action
            actor_loss = -self.critic.Q1(state_batch, policy_actions).mean()
            actor_loss.backward()
            self.optimizer_actor.step()

            # Update target networks
            self.update_targets()

    def update_targets(self):
        """Soft-update both target networks towards the online networks.

        Each target parameter becomes ``TAU * online + (1 - TAU) * target``.
        """
        pairs = (
            (self.actor_target, self.actor),
            (self.critic_target, self.critic),
        )
        for target_net, online_net in pairs:
            for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)

# Training script: runs 149 episodes on Pendulum-v1 and prints each
# episode's undiscounted return.
if __name__ == "__main__":
    env = gym.make("Pendulum-v1")
    action_high = env.action_space.high[0]
    agent = DDPG(env.observation_space.shape[0], env.action_space.shape[0], action_high)
    scores = []
    for i_episode in range(1, 150):
        # gym >= 0.26 returns (observation, info); older versions return
        # the observation alone.
        reset_result = env.reset()
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        score = 0
        for t in range(1000):
            # agent.act() is deterministic, so add Gaussian exploration
            # noise here (scaled to the action range) and re-clip.
            action = agent.act(state)
            exploration = np.random.normal(0.0, STD_NOISE * action_high, size=action.shape)
            action = np.clip(action + exploration, -action_high, action_high)

            step_result = env.step(action)
            if len(step_result) == 5:
                # gym >= 0.26: separate terminated / truncated flags.
                next_state, reward, terminated, truncated, info = step_result
            else:
                # Older gym: a single done flag; the TimeLimit wrapper marks
                # time-outs in the info dict.
                next_state, reward, done, info = step_result
                truncated = bool(info.get("TimeLimit.truncated", False))
                terminated = bool(done) and not truncated

            # Store only true termination: a time-limit cut-off is not a
            # terminal state, so the TD target must still bootstrap from it.
            agent.memorize(state, action, reward, next_state, terminated)
            agent.learn()
            state = next_state
            score += reward
            if terminated or truncated:
                break
        scores.append(score)
        print("Episode {}: Score = {}".format(i_episode, score))

Episode 1: Score = -915.4854227131614
Episode 2: Score = -1371.2139301632176
Episode 3: Score = -1321.0586690929742
Episode 4: Score = -944.6033472877201
Episode 5: Score = -1601.682867510128
Episode 6: Score = -1517.3193491452625
Episode 7: Score = -1669.9879787643602
Episode 8: Score = -1678.0348837244017
Episode 9: Score = -1826.342537554729
Episode 10: Score = -1831.9836459000537
Episode 11: Score = -1599.1568174339661
Episode 12: Score = -1225.8292465213945
Episode 13: Score = -1609.0338192872075
Episode 14: Score = -1399.749496148025
Episode 15: Score = -1382.6626497437396
Episode 16: Score = -1353.101447095879
Episode 17: Score = -1550.1940531328553
Episode 18: Score = -1555.4553992842655
Episode 19: Score = -1433.084477257876
Episode 20: Score = -1479.9490777096285
Episode 21: Score = -1286.4736544913544
Episode 22: Score = -1396.6915768863096
Episode 23: Score = -1379.4575795217206
Episode 24: Score = -1211.79418358925
Episode 25: Score = -1476.2376071970546
Episode 26: Score 