In [4]:
import warnings
warnings.filterwarnings('ignore')
import gym
import matplotlib.pyplot as plt
from rl.core import Processor
import numpy as np
from PIL import Image
import time
import cv2

In [7]:
def envCheck(env_name, save_path=None):
    """
    Checks the current environment.

    Creates the environment, resets it, performs one step with action 1,
    displays the returned observation and prints the number of actions, the
    action meanings and the reward / done / info values of that step.

    Parameters:
    env_name - (str) name of the environment
    save_path - (str, optional) file path the displayed observation is also
                written to; nothing is saved when it is None (default)
    """
    env = gym.make(env_name)
    try:
        env.reset()
        action = 1
        #retrieve information
        # NOTE(review): expects the 4-tuple step API (obs, reward, done, info);
        # confirm against the installed gym version.
        observation, reward, done, info = env.step(action)
        plt.imshow(observation)
        # imsave needs a target file and the image data, so it is only
        # called when the caller actually asked for a file.
        if save_path is not None:
            plt.imsave(save_path, observation)
        plt.show()
        nb_actions = env.action_space.n
        act_means = env.unwrapped.get_action_meanings()
    finally:
        # release the environment even if one of the calls above fails
        env.close()
    #print information
    print(f"Number of Action: {nb_actions}")
    print()
    print(f"Action Meanings:\n{act_means}")
    print()
    print(f"Reward: {reward}\nDone: {done}\nInfo:{info}")

In [3]:
class AtariProcessor(Processor):
    """
    Standard processor class. This class doesn't capture any images.

    Observations are resized, converted to grayscale and returned as uint8
    arrays; state batches are converted to float32 and divided by 255;
    rewards are clipped to [-1, 1].
    """
    
    def __init__(self, input_shape=(84, 84), window_length=4):
        """
        Initializes the standard processor.

        Parameters:
        input_shape - (tuple) (height, width) of the processed observation,
                      stored as INPUT_SHAPE
        window_length - (int) stored as WINDOW_LENGTH; not used by this
                        class itself
        """
        # stored as a tuple so it compares equal to a numpy ``.shape``
        self.INPUT_SHAPE = tuple(input_shape)
        self.WINDOW_LENGTH = window_length
        
    def process_observation(self, observation):
        """
        This function processes the observation.

        Parameters:
        observation - image array with three dimensions
                      (height, width, channel)

        Returns:
        uint8 array of shape INPUT_SHAPE containing the resized grayscale
        image
        """
        # (height, width, channel)
        assert observation.ndim == 3  
        img = Image.fromarray(observation)
        
        # resize and convert to grayscale; PIL takes the size as
        # (width, height) while INPUT_SHAPE follows numpy's (height, width),
        # so the order is swapped to keep non-square shapes correct
        height, width = self.INPUT_SHAPE
        img = img.resize((width, height)).convert('L')
        #save image as array
        processed_observation = np.array(img)
        assert processed_observation.shape == self.INPUT_SHAPE
        return processed_observation.astype('uint8')

    def process_state_batch(self, batch):
        """
        This function processes the state batch.

        Converts the batch to float32 and divides it by 255.
        """
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        """
        This function clips the reward to the range [-1, 1].
        """
        return np.clip(reward, -1., 1.)

In [4]:
class AtariProcessorCapture(Processor):
    """
    This processor extension captures images.

    It processes observations, state batches and rewards and, while
    `save_image` is True, additionally writes every processed observation
    to disk as a PNG file.
    """

    def __init__(self, filepath, env_name, save_image=True,
                 input_shape=(84, 84), window_length=4):
        """
        Initializes the 'AtariProcessorCapture' class.

        Parameter:
        filepath(str): prefix the image file name is appended to; it is
            concatenated as-is, so a directory needs its trailing separator
        env_name(str): name of the environment, used as start of the file name
        save_image(bool): If True (default), the processor saves the image
        input_shape(tuple): (height, width) of the processed observation
        window_length(int): stored as WINDOW_LENGTH; not used by this class
        """
        self.save_image = save_image
        self.filepath = filepath
        self.env_name = env_name
        
        # stored as a tuple so it compares equal to a numpy ``.shape``
        self.INPUT_SHAPE = tuple(input_shape)
        self.WINDOW_LENGTH = window_length

        # running index of saved frames; keeps file names unique when several
        # observations are processed within the same second
        self._frame_count = 0

    def process_observation(self, observation):
        """
        This function processes the observation.

        Parameters:
        observation - image array with three dimensions
                      (height, width, channel)

        Returns:
        uint8 array of shape INPUT_SHAPE containing the resized grayscale
        image
        """
        assert observation.ndim == 3  # (height, width, channel)
        img = Image.fromarray(observation)
        # resize and convert to grayscale; PIL takes the size as
        # (width, height) while INPUT_SHAPE follows numpy's (height, width)
        height, width = self.INPUT_SHAPE
        img = img.resize((width, height)).convert('L')
        processed_observation = np.array(img)

        if self.save_image:
            # second-resolution timestamp plus a frame counter, so frames
            # captured within the same second do not overwrite each other
            path = self.filepath + self.env_name + '_{}_{:06d}.png'.format(
                int(time.time()), self._frame_count)
            cv2.imwrite(path, processed_observation)
            self._frame_count += 1

        assert processed_observation.shape == self.INPUT_SHAPE
        # saves storage in experience memory
        return processed_observation.astype('uint8')

    def process_state_batch(self, batch):
        """
        This function processes the state batch.

        Converts the batch to float32 and divides it by 255.
        """
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        """
        This function clips the reward to the range [-1, 1].
        """
        return np.clip(reward, -1., 1.)