In [None]:
%matplotlib notebook

import gym
from logger import Plotter
import numpy as np
from replay import ReplayBuffer
import torch
from torch import nn

# Seed every RNG the run touches (torch init/sampling, numpy, the gym env and
# its action-space sampler) so Restart & Run All reproduces the same curves.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# One panel per metric; the fourth series receives the pair of twin-Q losses
# (the state-value network is not used in this notebook).
plotter = Plotter('Return', 'Explained Variance', 'Policy Loss', 'Q Loss')

# create environment
env = gym.make('Pendulum-v0')
env.seed(SEED)
env.action_space.seed(SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

REPLAY_SIZE = 1000000
BATCH_SIZE = 100  # transitions sampled from the replay buffer per gradient step
replay_buffer = ReplayBuffer(REPLAY_SIZE, obs_dim, act_dim)

HIDDEN_SIZE = 32


def mlp(in_dim, out_dim, activation=nn.ReLU, output_activation=nn.Identity):
    """Build a perceptron with two hidden layers of width ``HIDDEN_SIZE``.

    Layout: Linear -> activation -> Linear -> activation -> Linear -> output_activation.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        activation: module class instantiated after each hidden Linear layer.
        output_activation: module class instantiated after the final Linear layer.

    Returns:
        An ``nn.Sequential`` with six modules in the order above.
    """
    widths = [in_dim, HIDDEN_SIZE, HIDDEN_SIZE, out_dim]
    last_idx = len(widths) - 2
    layers = []
    for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(output_activation() if idx == last_idx else activation())
    return nn.Sequential(*layers)


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy for a box action space.

    The MLP ``pi`` emits ``2 * act_dim`` values per observation: even columns
    are the means and odd columns the log standard deviations of a Gaussian
    over the pre-squash action ``u``. Emitted actions are ``hi * tanh(u)``, so
    the action space is treated as symmetric (``lo`` is kept for reference only).
    """

    # Clamp range for the log standard deviation so exp() stays finite and non-zero.
    LOG_STD_MIN = -20.0
    LOG_STD_MAX = 2.0
    # Keeps atanh finite when inverting actions that lie exactly on the bounds.
    _ATANH_EPS = 1e-6
    _LOG_2 = 0.6931471805599453  # math.log(2)

    def __init__(self, obs_dim, act_dim, lo, hi):
        super().__init__()
        self.pi = mlp(obs_dim, 2 * act_dim)
        self.lo = lo
        self.hi = hi

    def _gaussian_params(self, batch):
        """Return ``(means, log_stds)``, each ``(batch, act_dim)``, for a 2-D batch."""
        output = self.pi(batch)
        log_stds = output[:, 1::2].clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)
        return output[:, ::2], log_stds

    def _tanh_log_det(self, u):
        """Sum over action dims of ``log(1 - tanh(u)**2)``, computed without log(0)."""
        return (2.0 * (self._LOG_2 - u - nn.functional.softplus(-2.0 * u))).sum(dim=-1)

    def act(self, batch, noise=True):
        """Select actions for a batch of observations.

        Args:
            batch: observations; a 1-D tensor is treated as a batch of one.
            noise: if True, sample via the reparameterisation trick; if False,
                return the deterministic (squashed mean) action.

        Returns:
            ``(actions, logp)``. ``actions`` is ``hi * tanh(u)`` with shape
            ``(batch, act_dim)`` (the batch dim is kept for 1-D input).
            ``logp`` is the log-density of those actions, shape ``(batch,)``,
            including the tanh change-of-variables correction. For
            ``noise=False`` the second element is the placeholder ``1``.
        """
        if batch.dim() == 1:
            batch = batch.unsqueeze(0)
        means, log_stds = self._gaussian_params(batch)
        if not noise:
            # Squash the mean exactly like sampled actions so both share one action space.
            return self.hi * torch.tanh(means), 1

        gaussian = torch.distributions.Normal(means, log_stds.exp())
        u = gaussian.rsample()
        # Independent action dims: joint log-prob is the sum over the last axis.
        logp = gaussian.log_prob(u).sum(dim=-1) - self._tanh_log_det(u)
        return self.hi * torch.tanh(u), logp

    def log_prob(self, states, actions):
        """Log-density under the current policy of given (squashed, scaled) actions.

        Args:
            states: observations, ``(batch, obs_dim)`` or a single 1-D observation.
            actions: actions in ``[-hi, hi]``; reshaped to ``(batch, act_dim)``.

        Returns:
            Tensor of shape ``(batch,)``, consistent with the ``logp`` from ``act``.
        """
        if states.dim() == 1:
            states = states.unsqueeze(0)
        means, log_stds = self._gaussian_params(states)
        squashed = (actions.reshape(means.shape) / self.hi).clamp(
            -1.0 + self._ATANH_EPS, 1.0 - self._ATANH_EPS)
        u = torch.atanh(squashed)
        gaussian = torch.distributions.Normal(means, log_stds.exp())
        return gaussian.log_prob(u).sum(dim=-1) - self._tanh_log_det(u)


class Critic(nn.Module):
    """Q-function approximator mapping an (observation, action) pair to a scalar.

    Observations and actions are concatenated along dim 1, so both must be
    2-D ``(batch, features)`` tensors; the result has shape ``(batch,)``.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        joint_dim = obs_dim + act_dim
        self.q = mlp(joint_dim, 1)

    def forward(self, obs, acts):
        state_action = torch.cat([obs, acts], dim=1)
        q_values = self.q(state_action)
        return q_values.squeeze(1)

# Online networks: the squashed-Gaussian policy and the twin critics.
policy_net = Actor(obs_dim, act_dim, env.action_space.low.item(), env.action_space.high.item())
q_net_1 = Critic(obs_dim, act_dim)
q_net_2 = Critic(obs_dim, act_dim)

# Target critics are only moved by update_target_net (Polyak averaging), so
# their parameters are excluded from autograd.
target_q_net_1 = Critic(obs_dim, act_dim)
target_q_net_2 = Critic(obs_dim, act_dim)
for frozen_critic in (target_q_net_1, target_q_net_2):
    frozen_critic.q.requires_grad_(False)


def update_target_net(coeff):
    """Blend online critic weights into the target critics in place.

    Each target parameter becomes ``(1 - coeff) * target + coeff * online``;
    ``coeff=1`` copies the online weights, a small ``coeff`` gives a slow average.
    """
    critic_pairs = ((q_net_1.q, target_q_net_1.q), (q_net_2.q, target_q_net_2.q))
    with torch.no_grad():
        for online, target in critic_pairs:
            for src, dst in zip(online.parameters(), target.parameters()):
                dst.data.mul_(1 - coeff).add_(coeff * src.data)


# Start the targets as copies of the freshly initialised online critics.
update_target_net(1)

# --- Optimisation ---
LR = 3e-4
policy_optimiser, q_optimiser_1, q_optimiser_2 = (
    torch.optim.Adam(module.parameters(), lr=LR)
    for module in (policy_net, q_net_1, q_net_2)
)

# --- SAC objective ---
ALPHA = 0.2               # weight of the log-prob (entropy) term in targets and policy loss
DISCOUNT = 0.99           # reward discount in the Bellman target
SMOOTHING_COEFF = 0.005   # Polyak coefficient passed to update_target_net

# --- Schedule (all counted in environment steps) ---
UPDATE_AFTER = 1000       # no gradient updates before this many steps
UPDATE_EVERY = 50         # steps between update phases; also gradient steps per phase
START_STEPS = 1000        # actions come from env.action_space.sample() until then
STEPS_PER_EPOCH = 4000
EPOCHS = 100
steps = 0
NUM_TEST_EPISODES = 8


def discounted_cumsum(rewards, discount):
    """Return discounted rewards-to-go: ``out[i] = sum_k discount**k * rewards[i + k]``."""
    rewards_to_go = np.array(rewards, dtype=np.float64)
    for i in reversed(range(len(rewards_to_go) - 1)):
        rewards_to_go[i] += discount * rewards_to_go[i + 1]
    return rewards_to_go


def sac_update():
    """Run one SAC gradient step on both critics and the policy, then soft-update targets.

    Returns:
        ``(policy_loss, q_loss_1, q_loss_2)`` as Python floats.
    """
    states, actions, rewards, next_states, terminals = replay_buffer.sample(BATCH_SIZE)

    # Soft Bellman target from the frozen target critics; no gradients are needed.
    with torch.no_grad():
        next_actions, next_logps = policy_net.act(next_states)
        next_q = torch.min(target_q_net_1(next_states, next_actions),
                           target_q_net_2(next_states, next_actions))
        y = rewards + DISCOUNT * (1 - terminals) * (next_q - ALPHA * next_logps)

    q_loss_1 = ((y - q_net_1(states, actions)) ** 2).mean()
    q_optimiser_1.zero_grad()
    q_loss_1.backward()
    q_optimiser_1.step()

    q_loss_2 = ((y - q_net_2(states, actions)) ** 2).mean()
    q_optimiser_2.zero_grad()
    q_loss_2.backward()
    q_optimiser_2.step()

    # Policy step. Gradients this leaves on the critics are cleared by their
    # zero_grad() at the start of the next update.
    sample_actions, sample_logps = policy_net.act(states)
    q_min = torch.min(q_net_1(states, sample_actions), q_net_2(states, sample_actions))
    policy_loss = (ALPHA * sample_logps - q_min).mean()
    policy_optimiser.zero_grad()
    policy_loss.backward()
    policy_optimiser.step()

    update_target_net(SMOOTHING_COEFF)
    return policy_loss.item(), q_loss_1.item(), q_loss_2.item()


def evaluate_policy(num_episodes):
    """Roll out the deterministic policy and measure return and critic fit.

    Returns:
        ``(returns, explained_variances)``, one entry per episode. Explained
        variance compares ``q_net_1``'s estimates along the trajectory with
        the observed discounted rewards-to-go.
    """
    returns, explained_variances = [], []
    for _ in range(num_episodes):
        q_values, episode_rewards = [], []
        obs, done = env.reset(), False
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32)
            with torch.no_grad():
                action, _ = policy_net.act(obs_t, noise=False)
                q_values.append(q_net_1(obs_t.unsqueeze(0), action).item())
            # act() keeps a batch dim; the env expects a plain (act_dim,) array.
            obs, reward, done, _ = env.step(action.squeeze(0).numpy())
            episode_rewards.append(reward)

        rewards_to_go = discounted_cumsum(episode_rewards, DISCOUNT)
        residual_var = (np.array(q_values) - rewards_to_go).var()
        explained_variances.append(1 - residual_var / rewards_to_go.var())
        returns.append(sum(episode_rewards))
    return returns, explained_variances


for epoch in range(EPOCHS):
    # Losses from the last update of this epoch (NaN if no update ran).
    policy_loss = q_loss_1 = q_loss_2 = float('nan')
    # Start each epoch from a fresh episode: evaluation below reuses `env`.
    done = True
    for _ in range(STEPS_PER_EPOCH):
        steps += 1

        obs = torch.as_tensor(env.reset(), dtype=torch.float32) if done else next_obs
        if steps > START_STEPS:
            # Drop act()'s batch dim so every stored action has shape (act_dim,).
            action = policy_net.act(obs)[0].detach().squeeze(0)
        else:
            action = torch.as_tensor(env.action_space.sample(), dtype=torch.float32)

        next_obs, reward, done, info = env.step(action.numpy())
        next_obs = torch.as_tensor(next_obs, dtype=torch.float32)

        # A time-limit cut-off is not a true terminal state: keep bootstrapping
        # from next_obs for those transitions (every Pendulum episode ends this way).
        terminal = done and not info.get('TimeLimit.truncated', False)
        replay_buffer.store((obs, action, reward, next_obs, terminal))

        if steps < UPDATE_AFTER or steps % UPDATE_EVERY:
            continue

        for _ in range(UPDATE_EVERY):
            policy_loss, q_loss_1, q_loss_2 = sac_update()

    returns, explained_variances = evaluate_policy(NUM_TEST_EPISODES)
    plotter.update(epoch,
                   (np.mean(returns), min(returns), max(returns)),
                   np.mean(explained_variances),
                   policy_loss,
                   (q_loss_1, q_loss_2))

In [None]:
# Watch the trained deterministic policy until the kernel is interrupted.
obs = env.reset()
try:
    while True:
        env.render()
        with torch.no_grad():
            action, _ = policy_net.act(torch.as_tensor(obs, dtype=torch.float32), noise=False)
        # act() keeps a batch dim; the env expects a plain (act_dim,) array.
        obs, _, done, _ = env.step(action.squeeze(0).numpy())

        if done:
            obs = env.reset()
except KeyboardInterrupt:
    pass  # Interrupting the kernel is the intended way to stop the viewer.
finally:
    # Always release the render window; any other error still propagates.
    env.close()