In [None]:
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch import optim

import numpy as np


class RolloutBuffer:
    """Per-step storage for a single on-policy rollout.

    Every attribute is a plain Python list holding one entry per environment
    step. ``clear`` empties the lists in place, so any outside reference to
    one of them keeps pointing at the (now empty) live list.
    """

    def __init__(self):
        self.actions = []       # sampled action tensors
        self.states = []        # observation tensors fed to the policy
        self.logprobs = []      # log-probabilities of the sampled actions
        self.rewards = []       # scalar rewards returned by the environment
        self.state_values = []  # critic estimates at sampling time
        self.dones = []         # episode-termination flags

    def clear(self):
        """Empty every stored list in place (the list objects are reused)."""
        stores = (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.state_values,
            self.dones,
        )
        for store in stores:
            store.clear()

class ActorCriticPPO(nn.Module):
    """Separate actor and critic MLPs for discrete-action PPO.

    ``actor`` maps a state to a softmax distribution over ``action_size``
    actions; ``critic`` maps a state to a single value estimate. Both use two
    tanh-activated hidden layers of width ``hidden_size`` (no shared weights).
    The module is not meant to be called directly: use ``act`` while
    collecting rollouts and ``evaluate`` while optimising.
    """

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()

        def hidden_layers():
            # Fresh hidden stack for each head; the two heads share shape, not weights.
            return [
                nn.Linear(state_size, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, hidden_size),
                nn.Tanh(),
            ]

        # actor: state -> action probabilities
        self.actor = nn.Sequential(
            *hidden_layers(),
            nn.Linear(hidden_size, action_size),
            nn.Softmax(dim=-1),
        )
        # critic: state -> scalar value estimate
        self.critic = nn.Sequential(
            *hidden_layers(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self):
        # Deliberately unsupported: callers must pick act() or evaluate().
        raise NotImplementedError

    def act(self, state):
        """Sample an action for ``state``.

        Returns:
            (action, log_prob, value), all detached from the autograd graph.
        """
        distribution = Categorical(self.actor(state))
        sampled = distribution.sample()
        log_prob = distribution.log_prob(sampled)
        value = self.critic(state)
        return tuple(t.detach() for t in (sampled, log_prob, value))

    def evaluate(self, state, action):
        """Score previously taken ``action``s under the current networks.

        Returns:
            (log_probs, state_values, entropy), still attached to the graph so
            they can be used to compute a differentiable loss.
        """
        distribution = Categorical(self.actor(state))
        log_probs = distribution.log_prob(action)
        entropy = distribution.entropy()
        values = self.critic(state)
        return log_probs, values, entropy


class PPO:
    """Clipped-objective PPO agent for environments with a discrete action space.

    Two copies of the actor-critic network are kept: ``policy`` is optimised,
    ``policy_old`` is the frozen behaviour policy used to collect rollouts.
    After each ``update`` the optimised weights are copied into ``policy_old``
    and the rollout buffer is cleared.
    """

    def __init__(self, env, obs_type, max_steps):
        """Build networks, optimiser and rollout buffer.

        Args:
            env: environment exposing ``observation_space.shape``,
                ``action_space.n``, ``reset()`` and ``step(action)`` returning
                ``(next_state, reward, done)``.
            obs_type: ``'vector'`` (1-D observation) or ``'pixel'``
                (multi-dimensional observation that is flattened).
            max_steps: episode length; the policy is updated every
                ``5 * max_steps`` environment steps.

        Raises:
            ValueError: if ``obs_type`` is neither ``'vector'`` nor ``'pixel'``.
        """
        self.env = env

        obs_shape = env.observation_space.shape
        if obs_type == 'vector':
            self.state_size = obs_shape[0]
        elif obs_type == 'pixel':
            # observations are flattened before entering the networks
            self.state_size = int(np.prod(obs_shape))
        else:
            raise ValueError(
                "obs_type must be 'vector' or 'pixel', got {!r}".format(obs_type))

        self.action_size = env.action_space.n
        self.hidden_size = 64

        self.gamma = 0.9
        self.update_timestep = max_steps * 5   # update policy every n timesteps
        self.epochs = 40                       # update policy for epochs
        self.eps_clip = 0.1                    # clip parameter for PPO
        self.lr_actor = 0.001                  # learning rate for actor network
        self.lr_critic = 0.01                  # learning rate for critic network

        self.buffer = RolloutBuffer()

        self.policy = ActorCriticPPO(self.state_size, self.action_size, self.hidden_size)
        self.optimizer = torch.optim.Adam([
                        {'params': self.policy.actor.parameters(), 'lr': self.lr_actor},
                        {'params': self.policy.critic.parameters(), 'lr': self.lr_critic}
                    ])

        self.policy_old = ActorCriticPPO(self.state_size, self.action_size, self.hidden_size)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state):
        """Sample an action from ``policy_old`` and record it in the buffer.

        Args:
            state: observation (array-like or tensor); it is converted to a
                flat float32 tensor.

        Returns:
            The sampled action as a Python int.
        """
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32).flatten()
            action, action_logprob, state_val = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        return action.item()

    def compute_returns(self):
        """Return discounted, normalised returns for the buffered rewards.

        Returns are accumulated backwards; the running return is reset at each
        terminal flag so it never leaks across episode boundaries. With two or
        more entries the result is standardised to zero mean / unit std; a
        single entry is returned unnormalised (its std is undefined).
        """
        returns = []
        running = 0.0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.dones)):
            if is_terminal:
                running = 0.0
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()  # built back-to-front; O(n) instead of repeated insert(0, ...)

        returns = torch.tensor(returns, dtype=torch.float32)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        return returns

    def update(self):
        """Run ``self.epochs`` clipped-PPO gradient steps on the buffered rollout.

        Afterwards ``policy_old`` is synchronised with ``policy`` and the
        buffer is cleared. An empty buffer is a no-op.
        """
        if not self.buffer.states:
            return

        returns = self.compute_returns()
        entropy_weight = 0.01

        # reshape(-1) (not squeeze) keeps the batch dimension even for a 1-step buffer
        old_states = torch.stack(self.buffer.states, dim=0).detach()
        old_actions = torch.stack(self.buffer.actions, dim=0).reshape(-1).detach()
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).reshape(-1).detach()
        old_state_values = torch.stack(self.buffer.state_values, dim=0).reshape(-1).detach()

        advantages = returns.detach() - old_state_values

        for _ in range(self.epochs):
            # evaluate old actions and values under the current policy
            logprobs, state_values, entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.reshape(-1)  # (N, 1) -> (N,) to match returns
            ratios = torch.exp(logprobs - old_logprobs)  # pi_theta / pi_theta_old

            # clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = (-torch.min(surr1, surr2)
                    + 0.5 * self.MseLoss(state_values, returns)
                    - entropy_weight * entropy)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    def train(self, model, episodes, max_training_timesteps):
        """Interact with ``self.env`` and update ``model`` until the step budget is spent.

        Args:
            model: agent that selects actions and owns the buffer (normally
                ``self``; kept as a parameter for backward compatibility).
            episodes: maximum number of steps per episode (inner-loop bound).
            max_training_timesteps: stop starting new episodes once more than
                this many environment steps have been taken.

        Every step, including an episode's terminal one, counts towards the
        update schedule. Each episode's total reward is appended to
        ``self.episode_rewards``.

        Raises:
            ValueError: if ``episodes`` < 1 (the loop would never advance).
        """
        if episodes < 1:
            raise ValueError('episodes must be >= 1, got {}'.format(episodes))

        episode = 0
        self.episode_rewards = []
        time_step = 0
        while time_step <= max_training_timesteps:

            state = torch.tensor(self.env.reset(), dtype=torch.float).flatten()
            total_rewards = 0
            for _ in range(episodes):
                action = model.select_action(state)
                next_state, reward, done = self.env.step(action)
                state = torch.tensor(next_state, dtype=torch.float).flatten()

                model.buffer.rewards.append(reward)
                model.buffer.dones.append(done)
                total_rewards += reward

                # count the step before checking termination so terminal
                # transitions also advance the update schedule
                time_step += 1

                if time_step % self.update_timestep == 0:
                    model.update()

                if time_step % 1000 == 0:
                    print("Episode : {} \t\t Timestep : {} \t\t Total Reward : {}".format(episode, time_step, total_rewards))

                if done:
                    break

            self.episode_rewards.append(total_rewards)
            episode += 1


       

  and should_run_async(code)


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib
import torch

# matplotlib.use('TkAgg')  # 'Qt5Agg') # 'TkAgg'
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
from gym import spaces

# Horizontal paddle displacement per discrete action index: 0 = left, 1 = idle, 2 = right.
ACTION_EFFECTS = (-1, 0, 1)  # left, idle, right.
# Observation formats accepted by Catch(observation_type=...).
OBSERVATION_TYPES = ['pixel', 'vector']


class Catch:
    ''' Grid "Catch" game: balls drop from the top row and a paddle in the bottom row tries to catch them. '''

    def __init__(self, rows: int = 7, columns: int = 7, speed: float = 1.0,
                 max_steps: int = 250, max_misses: int = 10,
                 observation_type: str = 'pixel', seed=None,
                 ):
        """ Arguments:
        rows: the number of rows in the environment grid.
        columns: number of columns in the environment grid.
        speed: speed of dropping new balls. At 1.0 (default), we drop a new ball whenever the last one drops from the bottom.
        max_steps: number of steps after which the environment terminates.
        max_misses: number of missed balls after which the environment terminates (when this happens before 'max_steps' is reached).
        observation_type: type of observation, either 'vector' or 'pixel'.
              - 'vector': observation is a vector of length 3:  [x_paddle,x_lowest_ball,y_lowest_ball]
                (ball coordinates are -1 when no ball is on screen)
              - 'pixel': observation is an array of size [columns x rows x 2], indexed as [x, y, channel], with a one-hot
              indicator for the paddle location in the first channel and a one-hot indicator for every present ball in
              the second channel.
        seed: environment seed.

        Raises:
            ValueError: for an unknown observation_type or a non-positive speed.
        """
        if observation_type not in OBSERVATION_TYPES:
            raise ValueError('Invalid "observation_type". Needs to be in  {}'.format(OBSERVATION_TYPES))
        if speed <= 0.0:
            raise ValueError('Dropping "speed" should be larger than 0.0')

        # store arguments
        self._rng = np.random.RandomState(seed)
        self.rows = rows
        self.columns = columns
        self.speed = speed
        self.max_steps = max_steps
        self.max_misses = max_misses
        self.observation_type = observation_type

        # interval (in steps) towards the next drop; never below 1, kept integral
        # so the countdown in step() is an exact integer comparison
        self.drop_interval = max(1, int(rows // speed))
        # only speed > 1.0 shortens the interval below `rows`, allowing several balls at once
        if speed > 1.0 and observation_type == 'vector':
            print(
                'Warning: You use speed > 1.0, which means there may be multiple balls in the screen at the same time. ' +
                'However, with observation_type = vector, only the xy location of *lowest* ball is visible to the agent' +
                ' (to ensure a fixed length observation vector)')

        # Initialize counter
        self.total_timesteps = None
        self.fig = None
        self.action_space = spaces.Discrete(3, )
        if self.observation_type == 'vector':
            # _get_state imputes (-1, -1) for the ball when none is present
            self.observation_space = spaces.Box(low=np.array((0, -1, -1)),
                                                high=np.array((self.columns, self.columns, self.rows)), dtype=int)
        elif self.observation_type == 'pixel':
            # must match the (columns, rows, 2) array built in _get_state
            self.observation_space = spaces.Box(low=np.zeros((self.columns, self.rows, 2)),
                                                high=np.ones((self.columns, self.rows, 2)), dtype=int)

    def reset(self):
        ''' Reset the problem to empty board with paddle in the middle bottom and a first ball on a random location in the top row '''
        # reset all counters
        self.total_timesteps = 0
        self.total_reward = 0
        self.r = '-'
        self.missed_balls = 0
        self.time_till_next_drop = self.drop_interval
        self.terminal = False

        # initialized problem
        self.paddle_xy = [self.columns // 2, 0]  # paddle in the bottom middle
        self.balls_xy = []  # empty the current balls
        self._drop_new_ball()  # add the first ball
        return self._get_state()  # first state

    def step(self, a):
        ''' Forward the environment one step based on provided action a.

        Returns (state, reward, terminal).
        Raises ValueError if called before reset() or after the episode terminated.
        '''
        # Check whether step is even possible
        if self.total_timesteps is None:
            raise ValueError("You need to reset() the environment before you can call step()")
        if self.terminal:
            raise ValueError("Environment has terminated, you need to call reset() first")

        # Move the paddle based on the chosen action
        self.paddle_xy[0] = np.clip(self.paddle_xy[0] + ACTION_EFFECTS[a], 0, self.columns - 1)

        # Drop all balls one step down
        for ball in self.balls_xy:
            ball[1] -= 1

        # Check whether lowest ball dropped from the bottom
        if len(self.balls_xy) > 0:  # there is a ball present
            if self.balls_xy[0][1] < 0:  # the lowest ball reached below the bottom
                del self.balls_xy[0]

        # Check whether we need to drop a new ball
        self.time_till_next_drop -= 1
        if self.time_till_next_drop == 0:
            self._drop_new_ball()
            self.time_till_next_drop = self.drop_interval

        # Compute rewards
        if (len(self.balls_xy) == 0) or (self.balls_xy[0][1] != 0):  # no ball present at bottom row
            r = 0.0
        elif self.balls_xy[0][0] == self.paddle_xy[0]:  # ball and paddle location match, caught a ball
            r = 1.0
        else:  # missed the ball
            r = -1.0
            self.missed_balls += 1

        # Compute termination
        self.total_timesteps += 1
        self.terminal = (self.total_timesteps == self.max_steps) or (self.missed_balls == self.max_misses)

        self.r = r
        self.total_reward += r
        return self._get_state(), r, self.terminal

    def render(self, step_pause=0.3):
        ''' Render the current environment situation.

        Raises ValueError if called before reset().
        '''
        if self.total_timesteps is None:
            raise ValueError("You need to reset() the environment before you render it")

        # In first call initialize figure
        if self.fig is None:
            self._initialize_plot()

        # Colour every cell: green = caught, yellow = paddle, red = missed ball,
        # white = falling ball, black = empty
        for x in range(self.columns):
            for y in range(self.rows):
                if self.paddle_xy == [x, y]:  # hit the agent location
                    if [x, y] in self.balls_xy:  # agent caught a ball
                        self.patches[x][y].set_color('g')
                    else:
                        self.patches[x][y].set_color('y')
                elif [x, y] in self.balls_xy:  # hit a ball location without agent
                    if y == 0:  # missed the ball
                        self.patches[x][y].set_color('r')
                    else:  # just a ball
                        self.patches[x][y].set_color('w')
                else:  # empty spot
                    self.patches[x][y].set_color('k')

        self.label.set_text(
            'Reward:  {:<5}            Total reward:  {:<5}     \nTotal misses: {:>2}/{:<2}     Timestep: {:>3}/{:<3}'.format(
                self.r, self.total_reward, self.missed_balls, self.max_misses, self.total_timesteps, self.max_steps))

        # Draw figure
        plt.pause(step_pause)

    def _initialize_plot(self):
        ''' initializes the catch environment figure '''
        self.fig, self.ax = plt.subplots()
        self.fig.set_figheight(self.rows)
        self.fig.set_figwidth(self.columns)
        self.ax.set_aspect('equal', adjustable='box')
        self.ax.set_xlim([0, self.columns])
        self.ax.set_ylim([0, self.rows])
        self.ax.axes.xaxis.set_visible(False)
        self.ax.axes.yaxis.set_visible(False)

        # patches[x][y] is the unit square with lower-left corner (x, y)
        self.patches = [[None] * self.rows for _ in range(self.columns)]
        for x in range(self.columns):
            for y in range(self.rows):
                self.patches[x][y] = Rectangle((x, y), 1, 1, linewidth=0.0, color='k')
                self.ax.add_patch(self.patches[x][y])

        self.label = self.ax.text(0.01, self.rows + 0.2, '', fontsize=20, c='k')

    def _drop_new_ball(self):
        ''' drops a new ball at a random column of the top row '''
        self.balls_xy.append([self._rng.randint(self.columns), self.rows - 1])

    def _get_state(self):
        ''' Returns the current agent observation (see __init__ for the two formats) '''
        if self.observation_type == 'vector':
            if len(self.balls_xy) > 0:  # balls present
                s = np.append(self.paddle_xy[0], self.balls_xy[0]).astype('float32')  # paddle x and lowest ball xy
            else:
                # no balls, impute (-1,-1) in state for no ball present
                s = np.append(self.paddle_xy[0], [-1, -1]).astype('float32')
        elif self.observation_type == 'pixel':
            s = np.zeros((self.columns, self.rows, 2), dtype=np.float32)
            s[self.paddle_xy[0], self.paddle_xy[1], 0] = 1.0  # set paddle indicator in first slice
            for ball in self.balls_xy:
                s[ball[0], ball[1], 1] = 1.0  # set ball indicator(s) in second slice
        else:
            raise ValueError('observation_type not recognized, needs to be in {}'.format(OBSERVATION_TYPES))
        return s


if __name__ == '__main__':

    # Single source of truth for settings shared by the environment and the agent.
    obs_type = 'pixel'
    max_steps = 300
    episodes = 2000  # passed to PPO.train as the per-episode step cap
    max_training_timesteps = int(1e5)  # break training loop if timesteps > max_training_timesteps

    env = Catch(rows=7, columns=7, speed=1.0, max_steps=max_steps, max_misses=10,
                observation_type=obs_type, seed=None)

    ppo_agent = PPO(env, obs_type, max_steps)
    ppo_agent.train(ppo_agent, episodes, max_training_timesteps)




Episode : 11 		 Timestep : 1000 		 Total Reward : -6.0
Episode : 24 		 Timestep : 2000 		 Total Reward : 0.0
Episode : 36 		 Timestep : 3000 		 Total Reward : 0.0
Episode : 48 		 Timestep : 4000 		 Total Reward : -4.0
Episode : 60 		 Timestep : 5000 		 Total Reward : -3.0
Episode : 72 		 Timestep : 6000 		 Total Reward : -3.0
Episode : 94 		 Timestep : 8000 		 Total Reward : -1.0
Episode : 105 		 Timestep : 9000 		 Total Reward : -4.0
Episode : 116 		 Timestep : 10000 		 Total Reward : -7.0
Episode : 128 		 Timestep : 11000 		 Total Reward : -4.0
Episode : 140 		 Timestep : 12000 		 Total Reward : -2.0
Episode : 149 		 Timestep : 13000 		 Total Reward : -6.0
Episode : 160 		 Timestep : 14000 		 Total Reward : -2.0
Episode : 171 		 Timestep : 15000 		 Total Reward : -1.0
Episode : 181 		 Timestep : 16000 		 Total Reward : -7.0
Episode : 191 		 Timestep : 17000 		 Total Reward : -1.0
Episode : 202 		 Timestep : 18000 		 Total Reward : -6.0
Episode : 212 		 Timestep : 19000 		 Total Rewar