In [2]:
import sys
!{sys.executable} -m pip install tqdm

Collecting tqdm
  Downloading https://files.pythonhosted.org/packages/91/55/8cb23a97301b177e9c8e3226dba45bb454411de2cbd25746763267f226c2/tqdm-4.28.1-py2.py3-none-any.whl (45kB)
[K    100% |████████████████████████████████| 51kB 1.3MB/s ta 0:00:011
[?25hInstalling collected packages: tqdm
Successfully installed tqdm-4.28.1


In [6]:
import gym
import imageio
from itertools import chain
import math
from threading import Thread
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from queue import Queue
import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): a blanket 'ignore' also hides genuinely useful warnings (e.g.
# deprecations); consider filtering by category or module instead.
warnings.filterwarnings('ignore')



class RLEnvironment(object):
    """Abstract interface for environments that PPO can be trained against.

    Concrete wrappers must override both ``reset`` and ``step``; the base
    implementations only raise ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        """Put the environment back into a starting state.

        Returns:
            observation (np.ndarray)
        """
        raise NotImplementedError()

    def step(self, x):
        """Advance the environment by one action.

        ``x`` has the same format as the action emitted by the policy network.

        Returns:
            observation (np.ndarray), reward (float), terminal (boolean)
        """
        raise NotImplementedError()


class EnvironmentFactory(object):
    """Abstract builder of environment objects; subclasses override ``new``."""

    def __init__(self):
        super().__init__()

    def new(self):
        """Build and return a fresh environment instance."""
        raise NotImplementedError()


class ExperienceDataset(Dataset):
    """Dataset view over many rollouts, flattened into a single list of steps.

    Args:
        experience: iterable of rollouts; every rollout is itself an iterable
            of per-step records. The records are stored as-is, in order.
    """

    def __init__(self, experience):
        super(ExperienceDataset, self).__init__()
        self._exp = [step for rollout in experience for step in rollout]
        self._length = len(self._exp)

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        return self._exp[index]


def multinomial_likelihood(dist, idx):
    """Pick, for every row of ``dist``, the probability of the action in ``idx[:, 0]``.

    Args:
        dist: tensor of shape [batch, num_actions] with per-action probabilities.
        idx: tensor of shape [batch, k] whose first column holds action indices
            (any numeric dtype; it is converted to long).

    Returns:
        Tensor of shape [batch, 1].
    """
    row_index = torch.arange(dist.shape[0], device=dist.device)
    col_index = idx.long()[:, 0]
    picked = dist[row_index, col_index]
    return picked.unsqueeze(1)


def get_log_p(data, mu, sigma):
    """Element-wise log-density of ``data`` under Normal(mu, sigma**2).

    This returns the log-likelihood itself (log p), NOT its negation; negate
    the result if a negative-log-likelihood loss is wanted.

    Args:
        data, mu, sigma: tensors broadcastable against each other. ``sigma`` is
            the standard deviation; only ``sigma ** 2`` enters the formula, so
            its sign does not matter.

    Returns:
        -0.5 * log(2 * pi * sigma^2) - (data - mu)^2 / (2 * sigma^2)
    """
    variance = sigma ** 2
    # log(sqrt(2*pi*var)) == 0.5 * log(2*pi*var); avoids an extra sqrt.
    return -0.5 * torch.log(2 * math.pi * variance) - (data - mu) ** 2 / (2 * variance)


def ppo(env_factory, policy, value, likelihood_fn, embedding_net=None, epochs=1000, rollouts_per_epoch=100,
        max_episode_length=200, gamma=0.99, policy_epochs=5, batch_size=256, epsilon=0.2, environment_threads=1,
        data_loader_threads=1, device=torch.device('cpu'), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01, gif_name='',
        gif_epochs=0, csv_file='latest_run.csv'):
    """Train ``policy`` and ``value`` with PPO's clipped surrogate objective.

    Every epoch:
      1. ``environment_threads`` threads, each driving its own environment,
         together collect ``rollouts_per_epoch`` episodes (each at most
         ``max_episode_length`` steps) with the current policy.
      2. The collected steps are shuffled into mini-batches of ``batch_size``
         and ``policy_epochs`` passes of Adam are run on
         ``policy_loss + value_loss``.
      3. One row ``avg_reward, value_loss, policy_loss`` is appended to
         ``csv_file``; the losses are the mini-batch averages of the last pass.

    Args:
        env_factory: object whose ``new()`` returns an environment with
            ``reset()`` and ``step(action)`` (see RLEnvironment).
        policy: module; ``policy(state)`` returns ``(action_dist, action)`` and
            ``policy(state, False)`` returns only ``action_dist``.
        value: module mapping a batch of states to predicted returns; its
            output is compared by MSE against returns of shape [batch, 1].
        likelihood_fn: ``likelihood_fn(action_dist, action)`` -> likelihood of
            each taken action (e.g. ``multinomial_likelihood``).
        embedding_net: optional module applied to states before ``policy`` and
            ``value``; its parameters are optimised together with theirs.
        epsilon: PPO clip range; the probability ratio is clamped to
            [1 - epsilon, 1 + epsilon].
        environment_threads: number of rollout threads / environments.
        data_loader_threads: ``num_workers`` passed to the DataLoader.
        gif_name, gif_epochs: when ``gif_epochs`` is non-zero, every
            ``gif_epochs`` epochs the first collected rollout is written to
            ``gif_name + '<epoch>.gif'``.
        csv_file: path of the CSV log; it is truncated at the start of the run.
        epochs, rollouts_per_epoch, max_episode_length, gamma, policy_epochs,
        batch_size, device, lr, betas, weight_decay: usual training knobs.
    """
    # Truncate the CSV log and write its header. The trailing newline is
    # required: data rows are appended after it, one per line.
    with open(csv_file, 'w') as f:
        f.write('avg_reward, value_loss, policy_loss\n')

    # Move networks to the correct device
    policy = policy.to(device)
    value = value.to(device)

    # All networks are optimised jointly by a single Adam instance.
    params = chain(policy.parameters(), value.parameters())
    if embedding_net is not None:
        embedding_net = embedding_net.to(device)
        params = chain(params, embedding_net.parameters())

    optimizer = optim.Adam(params, lr=lr, betas=betas, weight_decay=weight_decay)
    value_criteria = nn.MSELoss()

    # Clipping bounds for the probability ratio.
    ppo_lower_bound = 1 - epsilon
    ppo_upper_bound = 1 + epsilon

    loop = tqdm(total=epochs, position=0, leave=False)
    try:
        # One environment per thread; the first `remainder` threads run one
        # extra rollout so the total is exactly rollouts_per_epoch.
        environments = [env_factory.new() for _ in range(environment_threads)]
        rollouts_per_thread, remainder = divmod(rollouts_per_epoch, environment_threads)
        rollout_nums = ([rollouts_per_thread + 1] * remainder
                        + [rollouts_per_thread] * (environment_threads - remainder))

        for e in range(epochs):
            # Run the environments
            experience_queue = Queue()
            reward_queue = Queue()
            threads = [Thread(target=_run_envs, args=(environments[i],
                                                      embedding_net,
                                                      policy,
                                                      experience_queue,
                                                      reward_queue,
                                                      rollout_nums[i],
                                                      max_episode_length,
                                                      gamma,
                                                      device)) for i in range(environment_threads)]
            for x in threads:
                x.start()
            for x in threads:
                x.join()

            # Every worker has been joined, so reading the queues' internal
            # deques directly is race-free.
            rollouts = list(experience_queue.queue)
            avg_r = sum(reward_queue.queue) / reward_queue.qsize()
            loop.set_description('avg reward: % 6.2f' % (avg_r))

            # Make gifs
            if gif_epochs and e % gif_epochs == 0:
                _make_gif(rollouts[0], gif_name + '%d.gif' % e)

            # Update the policy
            experience_dataset = ExperienceDataset(rollouts)
            data_loader = DataLoader(experience_dataset, num_workers=data_loader_threads, batch_size=batch_size,
                                     shuffle=True,
                                     pin_memory=False)
            avg_policy_loss = 0
            avg_val_loss = 0
            for _ in range(policy_epochs):
                avg_policy_loss = 0
                avg_val_loss = 0
                # `reward` is collated but unused: the update uses the discounted return `ret`.
                for state, old_action_dist, old_action, reward, ret in data_loader:
                    state = _prepare_tensor_batch(state, device)
                    old_action_dist = _prepare_tensor_batch(old_action_dist, device)
                    old_action = _prepare_tensor_batch(old_action, device)
                    ret = _prepare_tensor_batch(ret, device).unsqueeze(1)

                    optimizer.zero_grad()

                    # If there is an embedding net, carry out the embedding
                    if embedding_net is not None:
                        state = embedding_net(state)

                    # Probability ratio pi_new(a|s) / pi_old(a|s) for the stored
                    # actions; the old distribution was recorded at rollout time.
                    current_action_dist = policy(state, False)
                    current_likelihood = likelihood_fn(current_action_dist, old_action)
                    old_likelihood = likelihood_fn(old_action_dist, old_action)
                    ratio = (current_likelihood / old_likelihood)

                    # Calculate the value loss
                    expected_returns = value(state)
                    val_loss = value_criteria(expected_returns, ret)

                    # Clipped surrogate loss. The advantage uses a detached
                    # value baseline so the policy term does not train `value`.
                    advantage = ret - expected_returns.detach()
                    lhs = ratio * advantage
                    rhs = torch.clamp(ratio, ppo_lower_bound, ppo_upper_bound) * advantage
                    policy_loss = -torch.mean(torch.min(lhs, rhs))

                    # For logging
                    avg_val_loss += val_loss.item()
                    avg_policy_loss += policy_loss.item()

                    # Backpropagate
                    loss = policy_loss + val_loss
                    loss.backward()
                    optimizer.step()

                # Log info
                avg_val_loss /= len(data_loader)
                avg_policy_loss /= len(data_loader)
                loop.set_description(
                    'avg reward: % 6.2f, value loss: % 6.2f, policy loss: % 6.2f' % (avg_r, avg_val_loss, avg_policy_loss))
            with open(csv_file, 'a+') as f:
                f.write('%6.2f, %6.2f, %6.2f\n' % (avg_r, avg_val_loss, avg_policy_loss))
            print()
            loop.update(1)
    finally:
        # Release the progress bar even when training is interrupted.
        loop.close()


def _calculate_returns(trajectory, gamma):
    current_return = 0
    for i in reversed(range(len(trajectory))):
        state, action_dist, action, reward = trajectory[i]
        ret = reward + gamma * current_return
        trajectory[i] = (state, action_dist, action, reward, ret)
        current_return = ret


def _run_envs(env, embedding_net, policy, experience_queue, reward_queue, num_rollouts, max_episode_length,
              gamma, device):
    """Collect ``num_rollouts`` episodes from ``env`` using the current ``policy``.

    Used as a worker-thread target by ``ppo``. For each episode it:
      * puts a list of ``(state, action_dist, action, reward, ret)`` tuples on
        ``experience_queue``, where ``ret`` is the discounted return added by
        ``_calculate_returns`` with discount ``gamma``;
      * puts the undiscounted sum of rewards on ``reward_queue``.
    An episode ends at a terminal step or after ``max_episode_length`` steps.
    """
    for _ in range(num_rollouts):
        current_rollout = []
        s = env.reset()
        episode_reward = 0
        for _ in range(max_episode_length):
            # Acting needs no gradients: the stored distribution is converted to
            # numpy and the update step recomputes the forward pass. no_grad
            # (thread-local) avoids building an autograd graph on every step.
            with torch.no_grad():
                input_state = _prepare_numpy(s, device)
                if embedding_net is not None:
                    input_state = embedding_net(input_state)
                action_dist, action = policy(input_state)
            action_dist, action = action_dist[0], action[0]  # Remove the batch dimension
            s_prime, r, t = env.step(action)

            # The environment contract promises a float reward; coerce (e.g.
            # numpy scalars) instead of printing a debug message.
            r = float(r)

            current_rollout.append((s, action_dist.cpu().detach().numpy(), action, r))
            episode_reward += r
            if t:
                break
            s = s_prime
        _calculate_returns(current_rollout, gamma)
        experience_queue.put(current_rollout)
        reward_queue.put(episode_reward)


def _prepare_numpy(ndarray, device):
    return torch.from_numpy(ndarray).float().unsqueeze(0).to(device)


def _prepare_tensor_batch(tensor, device):
    return tensor.detach().float().to(device)


def _make_gif(rollout, filename):
    """Write channel 0 of every stored state in ``rollout`` to ``filename`` as a GIF.

    Element 0 of each step record is indexed as ``[:, :, 0]`` and scaled by 255.
    NOTE(review): this presumes image-like [H, W, C] states with values in
    [0, 1]. The CartPole states in this notebook feed ``nn.Linear(state_dim, ...)``
    and are presumably 1-D, so this only suits image-based environments.
    """
    frames = ((step[0][:, :, 0] * 255).astype(np.uint8) for step in rollout)
    with imageio.get_writer(filename, mode='I', duration=1 / 30) as writer:
        for frame in frames:
            writer.append_data(frame)

class CartPoleEnvironmentFactory(EnvironmentFactory):
    """EnvironmentFactory whose ``new`` builds a brand-new CartPoleEnvironment each call."""

    def __init__(self):
        super().__init__()

    def new(self):
        """Construct and return a new CartPoleEnvironment."""
        environment = CartPoleEnvironment()
        return environment


class CartPoleEnvironment(RLEnvironment):
    """RLEnvironment adapter around gym's ``'CartPole-v0'``.

    Discards the info dict from gym's 4-tuple step result so ``step`` matches
    the (observation, reward, terminal) contract of RLEnvironment.
    NOTE(review): relies on the pre-0.26 gym API (4-tuple ``step``, bare
    observation from ``reset``); confirm against the installed gym version.
    """

    def __init__(self):
        super().__init__()
        self._env = gym.make('CartPole-v0')

    def step(self, action):
        """Apply ``action``, an np.ndarray of shape [1] and type np.uint8.

        Returns observation (np.ndarray), r (float), t (boolean)
        """
        observation, reward, terminal, _info = self._env.step(action.item())
        return observation, reward, terminal

    def reset(self):
        """Start a new episode; returns the first observation (np.ndarray)."""
        return self._env.reset()


class CartPolePolicyNetwork(nn.Module):
    """Policy network for CartPole: an MLP producing a categorical action distribution."""

    def __init__(self, state_dim=4, action_dim=2):
        super(CartPolePolicyNetwork, self).__init__()
        self._net = nn.Sequential(
            nn.Linear(state_dim, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, action_dim)
        )
        self._softmax = nn.Softmax(dim=1)

    def forward(self, x, get_action=True):
        """Compute action probabilities and optionally sample one action per row.

        Args:
            x: tensor of shape [batch, state_dim].
            get_action: when False, only the probabilities are returned.

        Returns:
            probs: tensor [batch, action_dim] whose rows sum to 1; and, when
            ``get_action`` is True, also ``actions``: np.ndarray [batch, 1] of
            sampled action indices (np.uint8 for up to 256 actions, np.int64
            beyond that).
        """
        scores = self._net(x)
        probs = self._softmax(scores)

        if not get_action:
            return probs

        # np.random.multinomial validates sum(pvals[:-1]) <= 1 in float64, and
        # float32 softmax rows can exceed that by rounding error when there are
        # more than two actions (raising ValueError). Upcast and renormalise.
        probs_np = probs.detach().cpu().numpy().astype(np.float64)
        probs_np /= probs_np.sum(axis=1, keepdims=True)

        # uint8 can only represent 256 action indices; widen beyond that.
        action_dtype = np.uint8 if probs_np.shape[1] <= 256 else np.int64
        actions = np.empty((probs_np.shape[0], 1), dtype=action_dtype)
        for i, row in enumerate(probs_np):
            action_one_hot = np.random.multinomial(1, row)
            actions[i, 0] = np.argmax(action_one_hot)
        return probs, actions


class CartPoleValueNetwork(nn.Module):
    """State-value estimator for CartPole: MLP with three ReLU hidden layers of width 10."""

    def __init__(self, state_dim=4):
        super(CartPoleValueNetwork, self).__init__()
        widths = [state_dim, 10, 10, 10, 1]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """Map observations of shape [batch, state_dim] to values of shape [batch, 1]."""
        return self._net(x)


def main():
    """Train PPO on CartPole with this notebook's default hyper-parameters."""
    ppo(
        CartPoleEnvironmentFactory(),
        CartPolePolicyNetwork(),
        CartPoleValueNetwork(),
        multinomial_likelihood,
        epochs=1000,
        rollouts_per_epoch=100,
        max_episode_length=200,
        gamma=0.99,
        policy_epochs=5,
        batch_size=256,
    )


# Inside a notebook kernel __name__ is '__main__', so running this cell starts
# training immediately (as the output below shows).
if __name__ == '__main__':
    main()

avg reward:  23.75, value loss:  308.64, policy loss: -14.39:   0%|          | 1/1000 [00:01<18:07,  1.09s/it]




avg reward:  23.57, value loss:  308.00, policy loss: -14.15:   0%|          | 2/1000 [00:02<17:42,  1.06s/it]




avg reward:  25.64, value loss:  390.12, policy loss: -15.49:   0%|          | 3/1000 [00:03<17:34,  1.06s/it]




avg reward:  26.98, value loss:  299.14, policy loss: -12.50:   0%|          | 4/1000 [00:04<17:38,  1.06s/it]




avg reward:  34.87, value loss:  320.50, policy loss:  -6.95:   0%|          | 5/1000 [00:05<19:14,  1.16s/it]




avg reward:  39.24, value loss:  332.84, policy loss:  -0.83:   1%|          | 6/1000 [00:07<21:11,  1.28s/it]




avg reward:  64.91, value loss:  419.26, policy loss:  -0.54:   1%|          | 7/1000 [00:09<26:41,  1.61s/it]




avg reward:  103.49, value loss:  393.65, policy loss:   0.27:   1%|          | 8/1000 [00:13<37:20,  2.26s/it]




avg reward:  146.47, value loss:  321.88, policy loss:  -0.09:   1%|          | 9/1000 [00:18<52:36,  3.19s/it]




avg reward:  177.74, value loss:  292.21, policy loss:   0.05:   1%|          | 10/1000 [00:25<1:08:51,  4.17s/it]




avg reward:  184.26, value loss:  335.05, policy loss:  -0.09:   1%|          | 11/1000 [00:31<1:20:48,  4.90s/it]




avg reward:  196.17, value loss:  315.48, policy loss:  -0.03:   1%|          | 12/1000 [00:38<1:31:00,  5.53s/it]




avg reward:  188.19, value loss:  297.25, policy loss:  -0.05:   1%|▏         | 13/1000 [00:45<1:37:23,  5.92s/it]




avg reward:  197.51, value loss:  271.16, policy loss:   0.01:   1%|▏         | 14/1000 [00:52<1:43:17,  6.29s/it]




avg reward:  183.51, value loss:  282.42, policy loss:   0.03:   2%|▏         | 15/1000 [00:59<1:44:54,  6.39s/it]




avg reward:  197.20, value loss:  253.54, policy loss:   0.17:   2%|▏         | 16/1000 [01:06<1:48:00,  6.59s/it]




avg reward:  198.87, value loss:  280.57, policy loss:   0.02:   2%|▏         | 17/1000 [01:13<1:51:45,  6.82s/it]




avg reward:  198.20, value loss:  258.06, policy loss:   0.07:   2%|▏         | 18/1000 [01:20<1:53:21,  6.93s/it]




avg reward:  195.69, value loss:  246.79, policy loss:   0.02:   2%|▏         | 19/1000 [01:27<1:53:32,  6.94s/it]




avg reward:  194.93, value loss:  202.84, policy loss:   0.10:   2%|▏         | 20/1000 [01:34<1:53:53,  6.97s/it]




avg reward:  194.60, value loss:  235.41, policy loss:  -0.06:   2%|▏         | 21/1000 [01:41<1:53:48,  6.98s/it]




avg reward:  200.00, value loss:  281.97, policy loss:  -0.07:   2%|▏         | 22/1000 [01:49<1:54:20,  7.02s/it]




avg reward:  199.19, value loss:  279.58, policy loss:   0.10:   2%|▏         | 23/1000 [01:56<1:54:34,  7.04s/it]




avg reward:  198.66, value loss:  221.64, policy loss:  -0.06:   2%|▏         | 24/1000 [02:03<1:55:06,  7.08s/it]




avg reward:  199.63, value loss:  283.45, policy loss:   0.02:   2%|▎         | 25/1000 [02:10<1:56:08,  7.15s/it]




avg reward:  197.87, value loss:  256.76, policy loss:  -0.08:   3%|▎         | 26/1000 [02:17<1:56:35,  7.18s/it]




avg reward:  192.35, value loss:  154.14, policy loss:  -0.02:   3%|▎         | 27/1000 [02:24<1:55:39,  7.13s/it]




avg reward:  198.02, value loss:  219.53, policy loss:   0.01:   3%|▎         | 28/1000 [02:32<1:55:42,  7.14s/it]




avg reward:  195.97, value loss:  176.44, policy loss:   0.02:   3%|▎         | 29/1000 [02:39<1:55:22,  7.13s/it]




avg reward:  197.34, value loss:  228.24, policy loss:  -0.04:   3%|▎         | 30/1000 [02:46<1:55:40,  7.16s/it]




avg reward:  197.61, value loss:  197.96, policy loss:   0.05:   3%|▎         | 31/1000 [02:53<1:55:29,  7.15s/it]




avg reward:  195.29, value loss:  290.91, policy loss:   0.06:   3%|▎         | 32/1000 [03:00<1:54:30,  7.10s/it]




avg reward:  189.65, value loss:  293.21, policy loss:  -0.09:   3%|▎         | 33/1000 [03:07<1:53:58,  7.07s/it]




avg reward:  196.38, value loss:  249.33, policy loss:   0.02:   3%|▎         | 34/1000 [03:14<1:54:20,  7.10s/it]




avg reward:  198.26, value loss:  219.99, policy loss:   0.02:   4%|▎         | 35/1000 [03:22<1:55:55,  7.21s/it]




avg reward:  198.55, value loss:  228.89, policy loss:  -0.01:   4%|▎         | 36/1000 [03:29<1:56:33,  7.25s/it]




avg reward:  199.01, value loss:  241.78, policy loss:   0.01:   4%|▎         | 37/1000 [03:39<2:07:38,  7.95s/it]




avg reward:  199.74, value loss:  261.69, policy loss:  -0.18:   4%|▍         | 38/1000 [03:48<2:14:58,  8.42s/it]




avg reward:  199.81, value loss:  278.59, policy loss:   0.07:   4%|▍         | 39/1000 [03:58<2:21:13,  8.82s/it]




avg reward:  199.49, value loss:  261.97, policy loss:  -0.00:   4%|▍         | 40/1000 [04:07<2:24:39,  9.04s/it]




avg reward:  196.38, value loss:  201.16, policy loss:  -0.07:   4%|▍         | 41/1000 [04:17<2:27:01,  9.20s/it]




avg reward:  199.20, value loss:  283.37, policy loss:  -0.03:   4%|▍         | 42/1000 [04:27<2:29:28,  9.36s/it]




avg reward:  200.00, value loss:  268.01, policy loss:   0.14:   4%|▍         | 43/1000 [04:36<2:30:58,  9.47s/it]




avg reward:  197.73, value loss:  291.46, policy loss:   0.09:   4%|▍         | 43/1000 [04:45<2:30:58,  9.47s/it]Process Process-261:
KeyboardInterrupt
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()


KeyboardInterrupt: 

  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)


In [4]:
!nvidia-smi

Sat Nov 24 13:13:48 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 390.48                 Driver Version: 390.48                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  TITAN X (Pascal)    Off  | 00000000:01:00.0  On |                  N/A |
| 23%   38C    P8    17W / 250W |  12146MiB / 12192MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN X (Pascal)    Off  | 00000000:02:00.0 Off |                  N/A |
| 23%   35C    P8    16W / 250W |   6614MiB / 12196MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

In [3]:
# Ask PyTorch's CUDA caching allocator to release cached, currently unused blocks.
# NOTE(review): memory held by live tensors is not freed by this; the nvidia-smi
# output above shows GPU 0 nearly full, which this call alone may not resolve.
torch.cuda.empty_cache()