In [1]:
import gym

import os

import numpy as np

from datetime import datetime

import json


# This wrapper extracts Hand Crafted Features from gym observations
class HCFgymWrapper(gym.ObservationWrapper):
    """Observation wrapper that records Hand Crafted Features (HCF).

    Every observation handed to ``observation`` is passed to each function in
    ``FuncList``.  The results are accumulated in ``self.Outputs`` (a dict
    keyed by the function's ``__name__``) and written to a JSON file inside
    ``<DstDir>/results`` when the wrapper is closed.  The observation itself
    is returned unchanged.

    Parameters
    ----------
    env : the environment to wrap.
    FuncList : iterable of callables; each is called as ``func(obs)`` and its
        result must be convertible with ``np.asarray(...).tolist()``.
        Functions sharing the same ``__name__`` share one output list.
    DstDir : directory under which the ``results`` folder is created.
        ``None`` (the default) means the current working directory at the
        time the wrapper is constructed.
    """

    def __init__(self, env, FuncList, DstDir=None):
        super().__init__(env)
        self.FuncList = FuncList
        # Resolve the default here rather than in the signature: a signature
        # default of os.getcwd() is frozen when the class is defined and goes
        # stale if the working directory changes afterwards.
        self.DstDir = os.getcwd() if DstDir is None else DstDir
        self.resultsDir = os.path.join(self.DstDir, "results")
        # makedirs creates DstDir (including missing parents) and the results
        # folder in one call and is a no-op when they already exist.
        os.makedirs(self.resultsDir, exist_ok=True)
        self.Outputs = {func.__name__: [] for func in self.FuncList}

    def observation(self, obs):
        """Record every feature of ``obs`` and return ``obs`` unmodified."""
        for func in self.FuncList:
            # np.asarray lets an extractor return an array, a list/tuple or a
            # scalar; .tolist() turns the result into JSON-serialisable types.
            feature = np.asarray(func(obs)).tolist()
            # setdefault keeps working if FuncList was extended after __init__.
            self.Outputs.setdefault(func.__name__, []).append(feature)
        return obs

    def close(self):
        """Dump the collected features to JSON, then close the wrapped env.

        The file is named ``HandCraftedFeatures<timestamp>.json``.  The time
        part of the timestamp uses ``-`` as separator because ``:`` is not a
        valid character in file names on Windows.
        """
        now = datetime.now().strftime("%m.%d.%Y_%H-%M-%S")
        results_file = os.path.join(self.resultsDir, "HandCraftedFeatures" + now + ".json")
        try:
            with open(results_file, "w") as write_file:
                json.dump(self.Outputs, write_file)
        finally:
            # Always release the underlying environment, even if the dump
            # failed; any exception from the dump still propagates.
            result = super().close()
        return result

In [2]:
import cv2

def get_padel_position(obs=None, method=cv2.TM_CCOEFF_NORMED):
    """
    Locate the paddle in a frame by template matching and return its centre.

    Only a fixed strip of the frame is searched: rows 75..77, columns 5..78
    of channel 0 (``obs[75:78, 5:79, 0]``).  A 3x5 paddle template is matched
    against that strip and the centre of the best match is returned in the
    coordinates of the full (uncropped) frame.

    obs: np.ndarray indexed as obs[row, column, channel], dtype uint8 or
         float32, large enough that the strip above has shape (3, 74).
         NOTE(review): an earlier description said the *last* stacked frame is
         the relevant one, but channel 0 is what is read here -- confirm the
         channel ordering of the observations.
    method: one of cv2.TM_CCOEFF, cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR,
            cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED.

    returns: np.ndarray of two floats, [x, y] of the template centre.
    raises: AssertionError if obs is None, the strip has the wrong shape or
            dtype, or method is not one of the values listed above.
    """
    # Validate before cropping; checking after the slice could never trigger
    # because indexing None already fails with a TypeError.
    assert obs is not None
    assert method in [cv2.TM_CCOEFF, cv2.TM_CCOEFF_NORMED, cv2.TM_CCORR,
                      cv2.TM_CCORR_NORMED, cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]

    strip = obs[75:78, 5:79, 0]
    assert strip.shape[0] == 3 and strip.shape[1] == 74
    assert strip.dtype in [np.uint8, np.float32]
    # The slice is a strided view; hand OpenCV a contiguous buffer.
    strip = np.ascontiguousarray(strip)

    # matchTemplate needs image and template of the same type, so build the
    # template in the strip's dtype instead of always using uint8.
    padel_filter = np.array(([44, 44, 44, 44, 44],
                             [110, 110, 110, 110, 110],
                             [22, 22, 22, 22, 22]), dtype=strip.dtype)
    w, h = padel_filter.shape[::-1]

    res = cv2.matchTemplate(strip, padel_filter, method)
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
    # For the squared-difference methods the best match is the minimum,
    # for the correlation methods it is the maximum.
    if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
        best_loc = min_loc
    else:
        best_loc = max_loc

    # best_loc is the top-left corner (x, y) of the match inside the strip:
    # move to the template centre, then undo the crop offset (5 columns,
    # 75 rows) to express it in full-frame coordinates.
    center = np.asarray(best_loc, dtype=float) + np.array([w/2, h/2])
    center += [5, 75]
    return center

In [3]:
# Feature extractors the wrapper will apply to every observation.
FuncList = [get_padel_position]

In [4]:
# Show the registered feature extractors as the cell output.
FuncList

[<function __main__.get_padel_position>]

In [5]:
# Create the Breakout environment.
env=gym.make("Breakout-v4")
# Wrap it so the functions in FuncList are applied to its observations.
# DstDir is not passed, so the wrapper's default destination is used.
Wrap=HCFgymWrapper(env,FuncList=FuncList)

In [6]:
# Roll out T randomly sampled actions, resetting whenever an episode ends.
T = 10
s_t = Wrap.reset()
for t in range(T):
    a_t = Wrap.action_space.sample()
    # NOTE(review): this unpacks 4 values from step(); newer gym/gymnasium
    # releases return 5 (terminated and truncated) -- confirm the installed version.
    observations, rewards, dones, infos = Wrap.step(a_t)
    if dones:
        s_t = Wrap.reset()

In [7]:
# Closing the wrapper writes the collected features to a JSON file in its
# results directory before closing the underlying environment.
Wrap.close()