# Trial with Graph Convolutional Network

In [1]:
import os
# Limit OpenMP to a single thread per process (NOTE(review): presumably to avoid
# CPU oversubscription when many worker processes run at once).
os.environ.update(OMP_NUM_THREADS="1")

## 方針

### 教師あり学習

まずは教師あり学習により強いポリシーを獲得。

### 自己対戦による訓練

* Q関数が相手によって変わってしまう。
    * On-policyで自己対戦にすればOK。Off-policyにしたい場合は、Q関数が大きく違わない過去エピソードとすべき。
        * Rainbow
            * PFRLに実装あり。まずはこれ？
        * MuZero
            * 実装少し大変かも。だが自己対戦による実績あるため性能は出るかも？
    * Policyを持つアルゴリズムならOK。相手ごとにQ関数を学習すれば良い。この場合も自己対戦かつOn-policy。
        * A3C、PPOなど
            * PFRLに実装あり。これもトライ？
            
### 方策／Q関数モデル

* Graph Neural Network
    * あたかもそれぞれの選手が行動判断／価値判断しているようなモデルにする。選択されたアクションは、その時Activeな選手の最善手とする。
    * 初めは、特徴量は、絶対位置座標で、完全グラフを用いる。

## Actions

### Default action set

The default action set consists of 19 actions:

*   Idle actions

    *   `action_idle` = 0, a no-op action; sticky actions are not affected (the player keeps his current directional movement, etc.).

*   Movement actions

    *   `action_left` = 1, run to the left, sticky action.
    *   `action_top_left` = 2, run to the top-left, sticky action.
    *   `action_top` = 3, run to the top, sticky action.
    *   `action_top_right` = 4, run to the top-right, sticky action.
    *   `action_right` = 5, run to the right, sticky action.
    *   `action_bottom_right` = 6, run to the bottom-right, sticky action.
    *   `action_bottom` = 7, run to the bottom, sticky action.
    *   `action_bottom_left` = 8, run to the bottom-left, sticky action.

*   Passing / Shooting

    *   `action_long_pass` = 9, perform a long pass to the player on your team. Player to pass the ball to is auto-determined based on the movement direction.
    *   `action_high_pass` = 10, perform a high pass, similar to `action_long_pass`.
    *   `action_short_pass` = 11, perform a short pass, similar to `action_long_pass`.
    *   `action_shot` = 12, perform a shot, always in the direction of the opponent's goal.

*   Other actions

    *   `action_sprint` = 13, start sprinting, sticky action. Player moves faster, but has worse ball handling.
    *   `action_release_direction` = 14, reset current movement direction.
    *   `action_release_sprint` = 15, stop sprinting.
    *   `action_sliding` = 16, perform a slide (effective when not having a ball).
    *   `action_dribble` = 17, start dribbling (effective when having a ball), sticky action. Player moves slower, but it is harder to take over the ball from him.
    *   `action_release_dribble` = 18, stop dribbling.

### V2 action set

It is an extension of the default action set:

*   `action_builtin_ai` = 19, let the game's built-in AI generate an action.

In [2]:
# Install:
# Kaggle environments.
#!git clone https://github.com/Kaggle/kaggle-environments.git
#!cd kaggle-environments && pip install .

# GFootball environment.
#!apt-get update -y
#!apt-get install -y libsdl2-gfx-dev libsdl2-ttf-dev

# Make sure that the Branch in git clone and in wget call matches !!
#!git clone -b v2.3 https://github.com/google-research/football.git
#!mkdir -p football/third_party/gfootball_engine/lib

#!wget https://storage.googleapis.com/gfootball/prebuilt_gameplayfootball_v2.3.so -O football/third_party/gfootball_engine/lib/prebuilt_gameplayfootball.so
#!cd football && GFOOTBALL_USE_PREBUILT_SO=1 pip3 install .

## Install

In [3]:
# ------------------ install torch_geometric begin -----------------
# Install PyTorch Geometric and its compiled extensions only when it is missing.
# The extension wheels are specific to the CUDA and torch versions, so both are detected first.
try:
    import torch_geometric
except ImportError:  # only a missing package should trigger the install
    import subprocess
    import sys
    import torch

    # `nvcc -V` prints "... release X.Y, VX.Y.Z"; extract the "X.Y" part.
    nvcc_stdout = str(subprocess.check_output(['nvcc', '-V']))
    tmp = nvcc_stdout[nvcc_stdout.rfind('release') + len('release') + 1:]
    cuda_version = tmp[:tmp.find(',')]
    cuda = {
        '9.2': 'cu92',
        '10.1': 'cu101',
        '10.2': 'cu102',
    }
    if cuda_version not in cuda:
        raise RuntimeError(
            'Unsupported CUDA version {!r}; expected one of {}.'.format(cuda_version, sorted(cuda))
        )
    CUDA = cuda[cuda_version]

    # The wheel index is keyed by "major.minor.0", so zero the last version component.
    TORCH = torch.__version__.split('.')
    TORCH[-1] = '0'
    TORCH = '.'.join(TORCH)
    wheel_index = 'https://pytorch-geometric.com/whl/torch-' + TORCH + '.html'

    # Use the kernel's own interpreter so the packages land in the environment this notebook runs in.
    pip_install = [sys.executable, '-m', 'pip', 'install']
    for package in ('torch-scatter', 'torch-sparse', 'torch-cluster', 'torch-spline-conv'):
        subprocess.run(pip_install + [package + '==latest+' + CUDA, '-f', wheel_index])
    subprocess.run(pip_install + ['torch-geometric'])

    # Bind the module name just like the successful `try` branch does.
    import torch_geometric
# ------------------ install torch_geometric end -----------------

In [4]:
import os
import cv2
import sys
import glob 
import random
import json
import pickle
import copy
import imageio
import pathlib
import collections
from collections import deque, namedtuple
import numpy as np
import pandas as pd
import argparse
from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from joblib import Parallel, delayed
sns.set()
%matplotlib inline

from gym import spaces
from tqdm import tqdm
from logging import getLogger, StreamHandler, FileHandler, DEBUG, INFO
from typing import Union, Callable, List, Tuple, Iterable, Any, Dict
from dataclasses import dataclass
from IPython.display import Image, display
sns.set()

# PyTorch
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.distributions as D

# PyTorch geometric
from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.nn import RGCNConv
from torch_geometric.data import InMemoryDataset

from torch_scatter import scatter_max, scatter_sum, scatter_mean

# Env
import gym
import gfootball
import gfootball.env as football_env
from gfootball.env import observation_preprocessing
from gfootball.env.wrappers import Simple115StateWrapper

## Config

In [5]:
# Check we can use GPU
print(torch.cuda.is_available())

# set gpu id
if torch.cuda.is_available(): 
    # NOTE: it is not number of gpu but id which start from 0
    gpu = 0
else:
    # cpu=>-1
    gpu = -1

True


In [6]:
# set logger
def logger_config(filepath='./result.log'):
    """Return the module logger, emitting DEBUG+ records to stderr and to `filepath`.

    `getLogger(__name__)` returns the same logger object on every call. Handlers
    left by a previous call are therefore closed and removed first. This makes
    re-running the cell idempotent instead of duplicating every log line.

    Args:
        filepath: log file path (opened in append mode by FileHandler).

    Returns:
        The configured logging.Logger.
    """
    logger = getLogger(__name__)
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    handler = StreamHandler()
    handler.setLevel("DEBUG")
    logger.setLevel("DEBUG")
    logger.addHandler(handler)
    logger.propagate = False  # keep records out of the root logger's handlers

    file_handler = FileHandler(filepath)
    logger.addHandler(file_handler)
    return logger

logger = logger_config()

In [7]:
# Fix every random seed this notebook controls directly.
#
# NOTE(review): likely reasons this alone does not make episode rewards reproducible
# (not verified):
#   * the GFootball environments are created without a seed (train_seed / test_seed
#     below are not passed to them in the visible cells), so the game engine's own
#     randomness is not controlled;
#   * episodes are collected in joblib worker processes, whose RNG states are not
#     set by this function;
#   * some CUDA kernels are non-deterministic even with cuDNN in deterministic mode.
def seed_everything(seed=1234):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all also covers GPUs other than the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may choose different algorithms from run to run.
    torch.backends.cudnn.benchmark = False

# Global seed for this run.
seed = 5046
seed_everything(seed)

# Set different random seeds for train and test envs.
train_seed = seed
test_seed = 2 ** 31 - 1 - seed

## Environment

In [8]:
# Self-play training env: the 11 vs 11 Kaggle scenario with one agent-controlled
# player per side. Observations use the flat `simple115v2` float-vector representation.
env = football_env.create_environment(
    env_name='11_vs_11_kaggle',
    stacked=False,
    representation='simple115v2',  # flat float vector (not the SMM minimap)
    rewards='scoring, checkpoints',
    write_goal_dumps=False,
    write_full_episode_dumps=False,
    render=False,
    write_video=False,
    dump_frequency=1,
    logdir='./',
    extra_players=None,
    number_of_left_players_agent_controls=1,
    number_of_right_players_agent_controls=1,
)

In [9]:
# Scratch check: reset the training env once to inspect the initial observation(s).
array = env.reset()

In [10]:
# Scratch check: take one step with action 1 (`action_left`) for both controlled players.
obs, _, _, info = env.step([1,1])

## GCN-Policy

The first feature of every graph node is a type flag that distinguishes left players, right players, and the ball:

* left players: 0
* right players: 1
* ball: 2

In [11]:
# Relation id of every ordered (source flag, destination flag) edge, where the
# flags are 0 = left player, 1 = right player, 2 = ball. The id is
# 3 * source + destination. Ball -> ball is omitted because the graph has a
# single ball node and no self-loops.
relations_dict = {
    (float(src), float(dst)): 3 * src + dst
    for src in range(3)
    for dst in range(3)
    if (src, dst) != (2, 2)
}

In [12]:
def create_graph_from_observation(state, action=None, reverse=False, reverse_y=False):
    """Convert one `simple115v2` observation vector into a fully connected graph.

    Node layout (23 nodes): rows 0-10 are the left players, rows 11-21 the right
    players and row 22 the ball. Each node row has 9 features:
    [type flag (0 left / 1 right / 2 ball), x, y, z, dx, dy, dz,
     ball-owned-by-this-team flag, active-player flag].
    Player vectors are padded with z = 0. Only left players carry the
    active-player flag.

    Args:
        state: observation vector; indices 0-107 are read.
        action: optional action id, stored as the graph label `y`.
        reverse: mirror the pitch along x (this also swaps the ownership flags).
        reverse_y: mirror the pitch along y.

    Returns:
        Data with `x` (23, 9) float, `edge_index` (2, 506) long and
        `edge_type` (506,) long relation ids from `relations_dict`.
    """
    # Work on a private copy. The ball slices below are reshaped views, and the
    # in-place mirroring must not write into the caller's observation.
    array = np.array(state, dtype=np.float64)

    player_z = np.zeros((11, 1))
    left_coordinations = np.concatenate([array[:22].reshape(11, 2), player_z], axis=-1)
    left_directions = np.concatenate([array[22:44].reshape(11, 2), player_z], axis=-1)
    right_coordinations = np.concatenate([array[44:66].reshape(11, 2), player_z], axis=-1)
    right_directions = np.concatenate([array[66:88].reshape(11, 2), player_z], axis=-1)
    ball_coordination = array[88:91].reshape([1, 3])
    ball_direction = array[91:94].reshape([1, 3])
    ball_ownership = array[94:97]  # none, left, right
    active_player = array[97:108].reshape([11, 1])

    vector_blocks = (
        left_coordinations, left_directions,
        right_coordinations, right_directions,
        ball_coordination, ball_direction,
    )
    if reverse:
        # Mirror along x; the team that owns the ball swaps sides as well.
        for vectors in vector_blocks:
            vectors[:, 0] *= -1.0
        ball_ownership = ball_ownership[[0, 2, 1]]

    if reverse_y:
        # Mirror along y.
        for vectors in vector_blocks:
            vectors[:, 1] *= -1.0

    # Node features
    left_features = np.concatenate([
        0*np.ones((11, 1)),
        left_coordinations,
        left_directions,
        ball_ownership[1]*np.ones((11, 1)),
        active_player,
    ], axis=-1)
    right_features = np.concatenate([
        1*np.ones((11, 1)),
        right_coordinations,
        right_directions,
        ball_ownership[2]*np.ones((11, 1)),
        np.zeros((11, 1)),
    ], axis=-1)
    ball_features = np.concatenate([
        2*np.ones((1, 1)),
        ball_coordination,
        ball_direction,
        np.zeros((1, 1)),
        np.zeros((1, 1)),
    ], axis=-1)

    features = np.concatenate([left_features, right_features, ball_features], axis=0)

    # Fully connected directed edges without self-loops, grouped by destination:
    # (1, 0), (2, 0), ..., (0, 1), (2, 1), ...
    num_nodes = len(features)
    sources, destinations = np.meshgrid(np.arange(num_nodes), np.arange(num_nodes))
    off_diagonal = sources != destinations
    edge_index = np.stack([sources[off_diagonal], destinations[off_diagonal]])
    # Relation id of each edge, from the (source flag, destination flag) pair.
    edge_flags = features[edge_index, 0].tolist()
    relations = [relations_dict[pair] for pair in zip(*edge_flags)]

    # numpy array to torch tensor
    features = torch.tensor(features, dtype=torch.float).contiguous()
    edge_index = torch.tensor(edge_index, dtype=torch.long).contiguous()
    relations = torch.tensor(relations, dtype=torch.long).contiguous()

    if action is None:
        return Data(x=features, edge_index=edge_index, edge_type=relations)
    return Data(
        x=features,
        edge_index=edge_index,
        edge_type=relations,
        y=torch.tensor(action, dtype=torch.long),
    )

In [13]:
class Policy(torch.nn.Module):
    """Per-player action policy on the player/ball graph.

    One relational graph convolution mixes information between nodes, and a
    two-layer MLP maps each node embedding to action logits. Only the rows of
    the active node(s) are kept, giving one categorical distribution per active node.

    Args:
        num_features: width of each node feature row. This includes the leading
            type flag and the trailing active flag, neither of which is fed to the network.
        num_relations: number of edge relation types.
        num_actions: number of discrete actions (19 = default GFootball action set).
        hidden_dim: width of the graph-conv output and of the hidden layer.
    """

    def __init__(self, num_features, num_relations, num_actions=19, hidden_dim=256):
        super().__init__()
        # The type flag (first column) and the active flag (last column) are dropped in forward.
        self.conv1 = RGCNConv(num_features - 2, hidden_dim, num_relations=num_relations)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)

    def forward(self, data):
        """Return a Categorical over actions with one row per active node.

        `data` must provide `x`, `edge_index` and `edge_type`. The last column
        of `x` is the active-player flag.
        """
        x, edge_index, edge_type = data.x, data.edge_index, data.edge_type
        is_active = x[:, -1].bool()
        x = x[:, 1:-1]

        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)

        return D.Categorical(logits=logits[is_active])

In [14]:
class Value(torch.nn.Module):
    """State-value network on the player/ball graph.

    One relational graph convolution and one hidden layer embed every node. The
    embeddings are max-pooled per graph and mapped to a scalar value.

    Args:
        num_features: width of each node feature row, including the leading
            type flag and the trailing active flag (both are dropped).
        num_relations: number of edge relation types.
        hidden_dim: width of the graph-conv output and of the hidden layer.
    """

    def __init__(self, num_features, num_relations, hidden_dim=256):
        super().__init__()
        self.conv1 = RGCNConv(num_features - 2, hidden_dim, num_relations=num_relations)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, data):
        """Return values of shape (num_graphs, 1) for a batch, or (1,) for a single graph."""
        x, edge_index, edge_type = data.x, data.edge_index, data.edge_type
        x = x[:, 1:-1]

        x = F.relu(self.conv1(x, edge_index, edge_type))
        x = F.relu(self.fc1(x))

        # Max-pool the node embeddings of each graph. Depending on the PyG
        # version, a single graph may have no `batch` attribute or one that is
        # None, so test the value rather than the attribute's existence.
        batch = getattr(data, 'batch', None)
        if batch is not None:
            x, _ = scatter_max(x, batch, dim=0)
        else:
            x = torch.max(x, dim=0).values

        return self.fc2(x)

## Load initial policy (Supervised)

In [15]:
# Node rows carry 9 features: type flag, xyz position, xyz direction,
# ball-ownership flag and active-player flag.
policy = Policy(
    num_features=9,
    num_relations=len(relations_dict),
)

In [16]:
# Supervised-pretrained weights used as the starting point for RL.
policy_path = 'policy.pt'
# map_location='cpu' lets a checkpoint saved on a GPU also load on CPU-only
# machines; load_state_dict copies the weights into the module's own parameters.
policy.load_state_dict(torch.load(policy_path, map_location='cpu'))

<All keys matched successfully>

## Reinforcement Learning classes and functions

In [17]:
class Episodes:
    """Column-oriented transition buffer holding one list (or tensor) per field.

    `state`, `action`, `next_state`, `reward` and `done` always exist (as empty
    lists when not given). Every field must be exactly as long as `state`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        self.episodes = dict(kwargs)
        for required_key in ('state', 'action', 'next_state', 'reward', 'done'):
            self.episodes.setdefault(required_key, [])
        expected_length = len(self.episodes['state'])
        for name, column in self.episodes.items():
            if len(column) != expected_length:
                raise Exception('The length of {} is invalid.'.format(name))

    def append(self, **kwargs):
        """Append one transition; each keyword names an existing field."""
        for name, item in kwargs.items():
            self.episodes[name].append(item)

    def add_new_key(self, key, value):
        """Add a new full-length field; refuses to overwrite an existing one."""
        if len(value) != len(self):
            raise Exception('The length of {} is invalid.'.format(key))
        if key in self.episodes:
            raise Exception('The key {} is already defined.'.format(key))
        self.episodes[key] = value

    def update(self, key, value):
        """Set or overwrite a full-length field."""
        if len(value) != len(self):
            raise Exception('The length of {} is invalid.'.format(key))
        self.episodes[key] = value

    def __len__(self):
        return len(self.episodes['state'])

    @property
    def num_episodes(self):
        """Number of steps flagged `done`."""
        return np.sum(self.episodes['done'])

    @property
    def total_rewards(self):
        """Summed reward of each finished episode, in order of completion."""
        totals = []
        running_total = 0
        for reward, done in zip(self.episodes['reward'], self.episodes['done']):
            running_total += reward
            if done:
                totals.append(running_total)
                running_total = 0
        return totals

    def __repr__(self):
        template = 'Episodes\n* number of episodes: {}\n* number of steps: {}\n* total rewards: {}\n* keys: {}'
        return template.format(
            self.num_episodes,
            len(self),
            self.total_rewards,
            ', '.join(self.episodes.keys()),
        )

    def __getitem__(self, key):
        return self.episodes[key]

    def sample(self, size, sequential=True):
        """Return a random contiguous window of `size` steps as a new Episodes.

        Only contiguous (sequential) sampling is implemented.
        """
        if not sequential:
            raise NotImplementedError
        start = random.randint(0, len(self) - size)
        window = slice(start, start + size)
        return Episodes(**{name: column[window] for name, column in self.episodes.items()})

#### Episodic functions

In [18]:
def compute_returns(episodes, gamma, infinite_horizon=True, update=False):
    """Compute discounted returns and store them in the 'return' field.

    The trajectory is walked backwards. With `infinite_horizon` the walk starts
    from `values[-1]`, and every step flagged `done` takes its own value
    estimate as its return, so `done` acts as a bootstrap/truncation point.
    Without it, returns are plain reward-to-go that restart after each `done`.

    Args:
        episodes: Episodes with 'reward', 'done' and 'value' fields.
        gamma: discount factor.
        infinite_horizon: bootstrap from the value estimates as described above.
        update: overwrite an existing 'return' field instead of adding one.
    """
    rewards = episodes['reward']
    dones = episodes['done']
    values = episodes['value']

    returns = torch.empty(len(rewards), dtype=torch.float32)
    with torch.no_grad():
        running = infinite_horizon * float(values[-1])
        for t in range(len(rewards) - 1, -1, -1):
            if infinite_horizon and dones[t]:
                running = float(values[t])
            else:
                running = rewards[t] + (1 - dones[t]) * gamma * running
            returns[t] = running

    store = episodes.update if update else episodes.add_new_key
    store('return', returns)

In [19]:
def compute_values(episodes, value_net, device='cpu', update=False):
    """Evaluate `value_net` on every state and store the flat result as 'value'.

    Autograd stays enabled, so the stored tensor can feed the value loss directly.

    Args:
        episodes: Episodes whose 'state' field holds graph objects.
        value_net: module mapping a batched graph to one value per graph.
        device: device the batch and the network are moved to before the forward pass.
        update: overwrite an existing 'value' field instead of adding one.
    """
    batch = Batch.from_data_list(episodes.episodes['state'])
    batch.to(device)
    value_net.to(device)

    store = episodes.update if update else episodes.add_new_key
    store('value', value_net(batch).flatten())

In [20]:
def compute_advantages(episodes, gamma, lambda_=1, device='cpu', update=False):
    """Generalized Advantage Estimation (GAE); stores the result as 'advantage'.

    For every step t, walking backwards:
        delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t
        A_t     = delta_t + gamma * lambda_ * (1 - done_t) * A_{t+1}
    where V after the final step is taken as 0. The `(1 - done_t)` factor in
    the second line stops an advantage from leaking backwards across an
    episode (or sampled-window) boundary into the preceding trajectory.

    Values are read detached. Advantages are constants in the policy loss, and
    looping over a grad-tracking tensor would add an autograd node per step.

    Args:
        episodes: Episodes with 'reward', 'done' and 'value' fields.
        gamma: discount factor.
        lambda_: GAE lambda (1 = Monte-Carlo-style, 0 = one-step TD).
        device: device of the returned float32 advantage tensor.
        update: overwrite an existing 'advantage' field instead of adding one.
    """
    rewards = episodes['reward']
    dones = episodes['done']
    values = episodes['value'].detach().flatten().cpu().tolist()
    values.append(0.0)  # value after the final step

    advantages = [0.0] * len(rewards)
    last_advantage = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        delta = float(rewards[t]) + not_done * gamma * values[t + 1] - values[t]
        last_advantage = delta + gamma * lambda_ * not_done * last_advantage
        advantages[t] = last_advantage
    advantages = torch.tensor(advantages, dtype=torch.float32, device=device)

    if update:
        episodes.update('advantage', advantages)
    else:
        episodes.add_new_key('advantage', advantages)

In [21]:
def concat_episodes(episodes_list):
    """Concatenate Episodes objects field by field, in list order.

    The field names come from the first element, and every element must
    provide them. Returns a new Episodes; the inputs are not modified.
    """
    field_names = episodes_list[0].episodes.keys()
    merged = {
        name: [item for chunk in episodes_list for item in chunk.episodes[name]]
        for name in field_names
    }
    return Episodes(**merged)

In [22]:
def get_value_loss(episodes, device='cpu'):
    """Mean squared error between predicted values and (detached) return targets."""
    predicted = episodes['value'].to(device)
    target = episodes['return'].detach().to(device)
    return F.mse_loss(predicted, target)

In [23]:
def get_ppo_losses(
    episodes,
    policy,
    policy_old,
    clip=False,
    clip_param=0.2,
    compute_kl=True,
    compute_entropy=True,
    device='cpu',
):
    """PPO surrogate loss (plus optional KL and entropy terms) for a batch.

    With r = pi(a|s) / pi_old(a|s) and advantage A:
      * clip=False: policy_loss = -mean(r * A)
      * clip=True:  policy_loss = -mean(min(r * A, clamp(r, 1-eps, 1+eps) * A))

    Args:
        episodes: Episodes with 'state', 'action' and 'advantage' fields.
        policy: current policy (receives gradients).
        policy_old: frozen pre-update policy (evaluated without gradients).
        clip: use the clipped surrogate objective.
        clip_param: clipping range eps.
        compute_kl: also return 'kl' = mean KL(policy_old || policy).
        compute_entropy: also return 'entropy' = mean entropy of `policy`.
        device: device for the forward passes.

    Returns:
        dict with 'policy_loss' and, if requested, 'kl' and 'entropy'.
    """
    states = episodes['state']
    actions = torch.tensor(episodes['action']).to(device)
    advantages = episodes['advantage'].detach().to(device)
    batch_states = Batch.from_data_list(states).to(device)

    policy.to(device)
    policy_old.to(device)

    distribs = policy(batch_states)
    log_probs = distribs.log_prob(actions)
    with torch.no_grad():
        distribs_old = policy_old(batch_states)
        log_probs_old = distribs_old.log_prob(actions)
    prob_ratios = torch.exp(log_probs - log_probs_old)

    surrogate = prob_ratios * advantages
    if clip:
        # Pessimistic bound: compare the advantage-weighted unclipped and
        # clipped ratios, not the bare ratio against the weighted clipped one.
        clipped_surrogate = torch.clamp(prob_ratios, 1.0 - clip_param, 1.0 + clip_param) * advantages
        policy_loss = -torch.mean(torch.min(surrogate, clipped_surrogate))
    else:
        policy_loss = -torch.mean(surrogate)

    losses = {'policy_loss': policy_loss}

    if compute_kl:
        losses['kl'] = D.kl_divergence(distribs_old, distribs).mean()

    if compute_entropy:
        losses['entropy'] = distribs.entropy().mean()

    return losses

In [None]:
# Horizontal (x) mirror of the action ids: every left-pointing movement is
# swapped with its right-pointing counterpart. All other ids (0-19) map to themselves.
REVERSE_ACTION = {action_id: action_id for action_id in range(20)}
REVERSE_ACTION.update({1: 5, 5: 1, 2: 4, 4: 2, 6: 8, 8: 6})


def reverse_action(action):
    """Return the id of the x-mirrored counterpart of `action`."""
    return REVERSE_ACTION[action]

In [None]:
# Re-create the self-play training env used by get_episodes. Close any env
# still bound to `env` from an earlier cell first so it is released, not leaked.
if 'env' in globals():
    env.close()
env = football_env.create_environment(
    env_name='11_vs_11_kaggle',
    stacked=False,
    representation='simple115v2',
    rewards='scoring, checkpoints',
    write_goal_dumps=False,
    write_full_episode_dumps=False,
    render=False,
    write_video=False,
    dump_frequency=1,
    logdir='./',
    extra_players=None,
    number_of_left_players_agent_controls=1,
    number_of_right_players_agent_controls=1
)

In [None]:
# Evaluation env: one agent-controlled left player; no right-team player is
# agent-controlled (presumably the built-in "hard stochastic" AI plays that side).
env_eval = football_env.create_environment(
    env_name='11_vs_11_hard_stochastic',  # hard mode
    representation='simple115v2',
    stacked=False,
    rewards='scoring, checkpoints',
    number_of_left_players_agent_controls=1,
    number_of_right_players_agent_controls=0,
    extra_players=None,
    render=False,
    write_video=False,
    write_goal_dumps=False,
    write_full_episode_dumps=False,
    dump_frequency=1,
    logdir='./',
)

In [None]:
def get_episodes(policy):
    """Play one self-play game in the global `env` and record both sides.

    The same `policy` picks the actions of both agent-controlled players. Each
    step, the left and right observations are turned into graphs and sampled
    together in one batch.

    Args:
        policy: module mapping a batched graph to a Categorical with one row per
            active node.

    Returns:
        (episode_l, episode_r): Episodes for the left and the right player.
    """
    episode_l = Episodes()
    episode_r = Episodes()

    obs_l, obs_r = env.reset()
    next_state_l = create_graph_from_observation(obs_l)
    next_state_r = create_graph_from_observation(obs_r)
    done = False

    while not done:
        state_l, state_r = next_state_l, next_state_r
        # Use the `Batch` class imported from torch_geometric.data at the top.
        states = Batch.from_data_list([state_l, state_r])

        # Acting needs no gradients; no_grad avoids building an autograd graph
        # for the forward pass at every environment step.
        with torch.no_grad():
            # assumes exactly one active player per graph, i.e. two samples come back
            action_l, action_r = policy(states).sample().tolist()

        (obs_l, obs_r), (reward_l, reward_r), done, info = env.step([action_l, action_r])

        next_state_l = create_graph_from_observation(obs_l)
        next_state_r = create_graph_from_observation(obs_r)

        episode_l.append(
            state=state_l,
            action=action_l,
            reward=reward_l,
            next_state=next_state_l,
            done=done,
        )
        episode_r.append(
            state=state_r,
            action=action_r,
            reward=reward_r,
            next_state=next_state_r,
            done=done,
        )

    return episode_l, episode_r

In [None]:
def agent(obs):
    """Greedy action for one observation vector using the global `policy`.

    Returns a one-element list ``[action_id]``.
    """
    global policy

    # NOTE(review): a Kaggle submission receives a raw observation dict, which would
    # first need converting, e.g. Simple115StateWrapper.convert_observation(obs['players_raw'], True).
    graph = create_graph_from_observation(obs)
    with torch.no_grad():
        distrib = policy(graph)
    # `policy` returns a torch Categorical, which has no `.detach()`/`.numpy()`,
    # so take the most likely action from its logits.
    action = int(distrib.logits.argmax())
    return [action]

In [None]:
def evaluate_agent(policy):
    """Play one greedy episode in the global `env_eval` and return its summed reward."""
    collected = []
    observation = env_eval.reset()
    done = False
    while not done:
        with torch.no_grad():
            distrib = policy(create_graph_from_observation(observation))
        greedy_action = [int(distrib.logits.argmax())]
        observation, reward, done, _info = env_eval.step(greedy_action)
        collected.append(reward)
    return np.sum(collected)

In [None]:
gamma = 0.997                   # discount factor
target_kld = 0.001              # target KL(old || new) per update; drives early stopping and beta adaptation
entropy_reg = 0.0001            # weight of the entropy bonus in the policy loss
lr_value = 1e-4                 # Adam learning rate of the value network
lr_policy = 1e-4                # Adam learning rate of the policy network
rollout = 100                   # length of each sampled contiguous window of transitions
batchsize = 500                 # number of windows sampled per update
num_parallel_episodes = 16      # self-play games collected per update
num_effective_episodes = 5001   # number of outer training iterations
num_iter = 10                   # gradient steps per update for each network
eval_num_parallel = 16          # evaluation episodes per evaluation round
gae_lambda = 0.999              # GAE lambda
# Fall back to the CPU instead of crashing on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Initial weight of the KL(old || new) penalty in the policy loss; the training
# loop multiplies or divides it by 1.5 depending on the measured KL.
beta = 1.0

In [None]:
# Critic reading the same 9-feature node rows and relation set as the policy.
value = Value(
    num_features=9,
    num_relations=len(relations_dict),
)

In [None]:
# One Adam optimizer per network, each with its own learning rate.
policy_optimizer = torch.optim.Adam(params=policy.parameters(), lr=lr_policy)
value_optimizer = torch.optim.Adam(params=value.parameters(), lr=lr_value)

In [None]:
# Per-update training metrics (one list per metric) and the evaluation reward curve.
log = {
    metric: []
    for metric in ('policy_loss', 'kl_divergence', 'entropy', 'beta', 'value_loss')
}
progress = []

In [None]:
# Checkpoint directories must exist before the first torch.save below
# (iteration 0 already triggers a save).
os.makedirs('./policies', exist_ok=True)
os.makedirs('./values', exist_ok=True)

for effective_episode_num in tqdm(range(num_effective_episodes)):

    # --- 1. Collect self-play games in CPU worker processes ------------------
    policy.eval()
    policy.to('cpu')
    print('Collecting episodes...')
    _episodes = Parallel(n_jobs=-1, backend='multiprocessing')(
        [delayed(get_episodes)(policy) for _ in range(num_parallel_episodes)]
    )
    # Each job returns (left-side Episodes, right-side Episodes); flatten them.
    episodes = []
    for episode in _episodes:
        episodes.extend(episode)
    episodes = concat_episodes(episodes)

    # --- 2. Sample fixed-length windows --------------------------------------
    print('Sampling episodes...')
    sampled_episodes = [episodes.sample(rollout) for _ in range(batchsize)]
    for sampled_episode in sampled_episodes:
        # Treat the end of each window as an episode boundary for the
        # return/advantage computations on the concatenated windows.
        sampled_episode['done'][-1] = True
    sampled_episodes = concat_episodes(sampled_episodes)

    policy.train()
    value.train()

    # --- 3. Targets for the policy update ------------------------------------
    print('Computing values...')
    compute_values(sampled_episodes, value, device=device)
    print('Computing returns...')
    compute_returns(sampled_episodes, gamma)
    print('Computing advantages...')
    compute_advantages(sampled_episodes, gamma, lambda_=gae_lambda, device=device)

    # Frozen pre-update policy for the probability ratios and the KL term.
    policy_old = copy.deepcopy(policy)

    # --- 4. Policy update: KL-penalised surrogate with early stopping --------
    print('Started training...')
    policy.train()
    for iter_num in range(num_iter):
        print('Policy iteration {} started...'.format(iter_num))
        ppo_losses = get_ppo_losses(
            sampled_episodes,
            policy,
            policy_old,
            clip=False,
            compute_kl=True,
            compute_entropy=True,
            device=device,
        )

        policy_loss = ppo_losses['policy_loss']
        kl_divergence = ppo_losses['kl']
        entropy = ppo_losses['entropy']

        loss = policy_loss + beta * kl_divergence - entropy_reg * entropy

        if kl_divergence > 1.5 * target_kld:
            print("Early stopping because of high KL-divergence.")
            break

        policy.zero_grad()
        loss.backward()
        policy_optimizer.step()

    # Adaptive KL penalty based on the last measured KL.
    if kl_divergence > 1.3 * target_kld:
        beta *= 1.5
    elif kl_divergence < 0.7 * target_kld:
        beta /= 1.5

    del sampled_episodes.episodes['value']
    del sampled_episodes.episodes['return']
    del sampled_episodes.episodes['advantage']
    torch.cuda.empty_cache()

    # --- 5. Value update: refresh values and returns before every step -------
    for iter_num in range(num_iter):
        print('Value iteration {} started...'.format(iter_num))
        print('Computing values...')
        compute_values(sampled_episodes, value, device=device, update=True)
        print('Computing returns...')
        compute_returns(sampled_episodes, gamma, update=True)
        print('Computing value loss...')
        value_loss = get_value_loss(sampled_episodes, device=device)
        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

    del sampled_episodes.episodes['value']
    del sampled_episodes.episodes['return']
    torch.cuda.empty_cache()

    # --- 6. Logging (float() copies each scalar to the host) -----------------
    log['policy_loss'].append(float(policy_loss.detach()))
    log['kl_divergence'].append(float(kl_divergence.detach()))
    log['entropy'].append(float(entropy.detach()))
    log['beta'].append(beta)
    log['value_loss'].append(float(value_loss.detach()))

    del ppo_losses, policy_loss, loss, kl_divergence, entropy, value_loss
    torch.cuda.empty_cache()

    policy.to('cpu')
    policy.eval()
    print('Training ended.')
    with open('training_log.json', 'wt') as f:
        json.dump(log, f)
    clear_output()

    # --- 7. Periodic evaluation in env_eval and reward-curve plot ------------
    if effective_episode_num % 10 == 0:
        print('Evaluating...')
        rewards_eval = Parallel(n_jobs=-1, backend='multiprocessing')(
            [delayed(evaluate_agent)(policy) for _ in range(eval_num_parallel)]
        )
        progress.append(float(np.mean(rewards_eval)))
        with open('reward_curve.json', 'wt') as f:
            json.dump(progress, f)
        plt.close()
        plt.cla()
        plt.clf()
        clear_output()
        plt.plot([i*10 for i in range(len(progress))], progress)
        plt.grid()
        plt.xlim(0, None)
        plt.xlabel('number of effective episodes')
        plt.ylabel('mean reward')
        plt.show()

    # --- 8. Periodic checkpoints ---------------------------------------------
    if effective_episode_num % 50 == 0:
        torch.save(policy.to('cpu').state_dict(), './policies/{}.pth'.format(effective_episode_num))
        torch.save(value.to('cpu').state_dict(), './values/{}.pth'.format(effective_episode_num))

    del sampled_episodes
    del episodes
    del _episodes
    torch.cuda.empty_cache()

  5%|▍         | 245/5001 [21:49:20<422:01:07, 319.44s/it]

Collecting episodes...
Sampling episodes...
Computing values...
Computing returns...
Computing advantages...
Started training...
Policy iteration 0 started...
Policy iteration 1 started...
Policy iteration 2 started...
Policy iteration 3 started...
Policy iteration 4 started...
Policy iteration 5 started...
Policy iteration 6 started...
Policy iteration 7 started...
Policy iteration 8 started...
Policy iteration 9 started...
Value iteration 0 started...
Computing advantages...
Computing returns...
Computing value loss...
Value iteration 1 started...
Computing advantages...
Computing returns...
Computing value loss...
Value iteration 2 started...
Computing advantages...
Computing returns...
Computing value loss...
Value iteration 3 started...
Computing advantages...
Computing returns...
Computing value loss...
Value iteration 4 started...
Computing advantages...
Computing returns...
Computing value loss...
Value iteration 5 started...
Computing advantages...
Computing returns...
Computi