In [None]:
%run Dino.ipynb

In [None]:
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from datetime import datetime as dt
import numpy as np
import matplotlib.pyplot as plt

# Chrome launch configuration shared by every Agent created below: run
# without a visible window and request a 600x310 window size.
# NOTE(review): regular (non-headless) Chrome documents this switch as
# 'width,height'; confirm the 'x' separator is honored by the Chrome build
# in use.
options = webdriver.ChromeOptions()
for _chrome_arg in ('headless', 'window-size=600x310'):
    options.add_argument(_chrome_arg)

class Agent:
    """Play the T-Rex runner page in Chrome and let a ``Dino`` pick the moves.

    The agent owns a Selenium Chrome driver, reads the game canvas, turns
    each frame into a feature row and forwards the action chosen by
    ``self.dino`` to the page as key presses.

    ``Dino`` is not defined in this file; presumably it is provided by the
    ``%run Dino.ipynb`` cell -- confirm.
    """

    # Canvas area read by get_screenshot; must match the
    # getImageData(0, 0, 600, 150) call there.
    FRAME_WIDTH = 600
    FRAME_HEIGHT = 150

    # Crop applied by get_data: rows from CROP_TOP down and columns
    # CROP_LEFT..CROP_RIGHT-1 are kept; the crop is then split into an upper
    # band of SPLIT_ROW rows and a lower band holding the remaining rows.
    # NOTE(review): presumably this is the strip in front of the dino --
    # confirm against the rendered page.
    CROP_TOP = 60
    CROP_LEFT = 45
    CROP_RIGHT = 450
    SPLIT_ROW = 25

    # Page opened by setup() when no other URL is given.
    DEFAULT_URL = 'file:///home/kocur4d/projects/dinotrainer/trex/index.html'

    def __init__(self, debug=False):
        """Create an agent; the browser is not opened until ``setup()``.

        Parameters
        ---------
        debug (Default = False): When true, ``get_data`` also draws every
                                 captured frame with ``plt.imshow``.
        """
        self.debug = debug
        self.dino = Dino()
        self.alive_since = None
        self.died_at = None
        self.screenshots = []
        # Filled in by setup(); kept here so the attributes always exist.
        self.driver = None
        self.document = None

    def setup(self, url=DEFAULT_URL):
        """Launch Chrome with the module-level ``options`` and load the game.

        Parameters
        ---------
        url: Address of the game page. Defaults to the local copy this
             notebook was written against.
        """
        self.driver = webdriver.Chrome(options=options)
        self.driver.implicitly_wait(10)
        self.driver.get(url)
        # Key presses are sent to the root <html> element.
        self.document = self.driver.find_element(By.XPATH, '//html')
        return None

    def close(self):
        """Quit the Chrome driver, if one was started, and forget it."""
        driver = getattr(self, 'driver', None)
        if driver is not None:
            driver.quit()
        self.driver = None
        self.document = None

    def start(self):
        """Start (or restart) a run by pressing SPACE and stamping the time."""
        # Clear any previous death so isDead() and score() describe this run.
        self.died_at = None
        self.alive_since = dt.now()
        self.document.send_keys(Keys.SPACE)

    def score(self):
        """Return the survival time of the current run in seconds.

        Returns 0.0 if ``start()`` has not been called yet. While the run is
        still alive (``died_at`` not recorded) the time elapsed so far is
        returned; once ``isDead()`` has recorded the death, the result is
        ``died_at - alive_since``.
        """
        if self.alive_since is None:
            return 0.0
        end = self.died_at if self.died_at is not None else dt.now()
        return (end - self.alive_since).total_seconds()

    def show_gallery(self, cols=1, titles=None):
        """Display ``self.screenshots`` in a single matplotlib figure.

        Parameters
        ---------
        cols (Default = 1): First grid dimension handed to
                            ``fig.add_subplot`` (matplotlib treats it as the
                            number of rows, despite the parameter name); the
                            other dimension is
                            ceil(n_images / float(cols)).

        titles: List of titles corresponding to each image. Must have
                the same length as ``self.screenshots``. Defaults to
                'Image (1)', 'Image (2)', ...

        The images must be np.arrays compatible with plt.imshow. Does
        nothing when there are no screenshots.
        """
        images = self.screenshots
        assert (titles is None) or (len(images) == len(titles)), \
            'titles must have one entry per screenshot'
        n_images = len(images)
        if n_images == 0:
            # Nothing to draw; also avoids scaling the figure size by zero.
            return
        if titles is None:
            titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
        # add_subplot needs integer grid dimensions; np.ceil returns a float.
        other_dim = int(np.ceil(n_images / float(cols)))
        fig = plt.figure(figsize=(50, 4))
        for n, (image, title) in enumerate(zip(images, titles)):
            a = fig.add_subplot(cols, other_dim, n + 1)
            if image.ndim == 2:
                plt.gray()
            plt.imshow(image)
            a.set_title(title)
        fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)

    def isDead(self):
        """Return True once the page reports a crash.

        The first time ``Runner.instance_.crashed`` is seen to be true the
        current time is stored in ``died_at``; afterwards the page is no
        longer queried and True is returned directly.
        """
        if self.died_at is not None:
            return True
        if self.driver.execute_script("return Runner.instance_.crashed"):
            self.died_at = dt.now()
            return True
        return False

    def react(self):
        """Capture one frame, ask ``self.dino`` for an action and perform it."""
        X = self.get_data()
        action = self.dino.react(X)
        self.__react(action)

    def get_data(self):
        """Turn the current frame into a single feature row.

        The flat pixel list from ``get_screenshot`` is arranged as a
        FRAME_HEIGHT x FRAME_WIDTH array and cropped to rows ``CROP_TOP:``
        and columns ``CROP_LEFT:CROP_RIGHT``. Every pixel is reduced to
        1 (value > 0) or 0, and the ones are counted per column, separately
        for the upper ``SPLIT_ROW`` rows and for the rows below them.

        Returns
        ---------
        np.array of shape (1, 810) with the default sizes: the 405 upper
        column counts followed by the 405 lower column counts.
        """
        frame = np.asarray(self.get_screenshot()).reshape(
            self.FRAME_HEIGHT, self.FRAME_WIDTH)
        if self.debug:
            # Visual check of the raw frame; skipped by default so the
            # control loop does not pile images onto the current axes.
            plt.imshow(frame, interpolation='nearest')
        region = frame[self.CROP_TOP:, self.CROP_LEFT:self.CROP_RIGHT]
        occupied = np.where(region > 0, 1, 0)
        upper_counts = np.sum(occupied[:self.SPLIT_ROW], axis=0)
        lower_counts = np.sum(occupied[self.SPLIT_ROW:], axis=0)
        return np.append(upper_counts, lower_counts).reshape(1, -1)

    def __react(self, action):
        """Send the key press mapped to ``action``; unknown actions do nothing.

        'down' presses the DOWN arrow and 'up' presses SPACE.
        """
        actions = {
            'down': lambda: self.document.send_keys(Keys.DOWN),
            'up': lambda: self.document.send_keys(Keys.SPACE),
        }
        print('action', action)
        callback = actions.get(action, lambda: None)
        callback()

    def get_screenshot(self):
        """Return the first channel of the game canvas as a flat sequence.

        Reads the 600x150 area at the canvas origin with ``getImageData``
        and keeps every fourth value starting at index 0, i.e. one value per
        pixel (the red component of the RGBA ImageData layout), in row-major
        order.
        """
        data = self.driver.execute_script('return document.getElementsByClassName("runner-canvas")[1].getContext("2d").getImageData(0,0,600,150);')['data']
        return data[0::4]