In [138]:
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.common.exceptions import WebDriverException
import matplotlib.pyplot as plt
import time
from os import listdir
import random
from PIL import Image, ImageOps

#### Parameters

In [109]:
STILL_ALIVE_REWARD = 1
DEAD_REWARD = -1

CROP_SHAPE = (750, 539, 1)
RESIZE_WIDTH = 180
RESIZE_HEIGHT = 137

#### Browser functions

In [119]:
def selectLevel(l):
    xpath = '/html/body/section/div[2]/nav/p[' + str(l+1) + ']'
    
    level = browser.find_element_by_xpath(xpath)
    level.click()
    
    
def clickBoard():        
    xpath = '/html/body/section/div[2]/div'
    board = browser.find_element_by_xpath(xpath)

    try:
        board.click()
    except WebDriverException:
        return
    
    
def getScreen():
    global index
    fileName = 'data/shots/screen_' + str(index) + '.png'
    
    state = getState()
    if state == 'playing' or state == 'paused':
    
        ss = browser.get_screenshot_as_file(fileName)

        index += 1

    return fileName
    
    
def getState():
    xpath = '/html/body/section/div[2]'
    state = browser.find_element_by_xpath(xpath)
    c = state.get_attribute('class')
    
    return c.split(' ')[-1]


def getScore():
    state = getState()
    
    if state == 'playing' or state == 'paused':
        
        xpath = '/html/body/section/div[2]/p[1]/span'
        score = browser.find_element_by_xpath(xpath)

        if not score.text.isnumeric():
            return 0
        return int(score.text)
    
    return 0


def makeMove(m):
    xpath = '/html/body/section/div[2]/div'
    board = browser.find_element_by_xpath(xpath)

    if m == 1:
        browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_RIGHT)
    elif m == 2:
        browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_DOWN)
    elif m == 3:
        browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_LEFT)
    elif m == 4:
        browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_UP)
        

def makeRandomMove():
    r = random.randint(0, 4)
    makeMove(r)
    
    return r

#### Preprocess functions

In [61]:
folder = 'data/shots/'


def cropImgOld(fileName):
    ss = plt.imread(fileName)

    ss = ss[120:659, 273:1023, :]
    
    plt.imsave(fileName, ss)
        

def preprocImg(fileName):
    '''Preprocess a screenshot.'''
    im = Image.open(fileName)    
    
    # Crop board
    im = im.crop((273, 120, 1023, 659))
    
    # Grayscale
    im = ImageOps.grayscale(im)
    
    # Binarization
    t = 127
    im = im.point(lambda x: 255 if x > t else 0)
    
    # Resize
    im = im.resize((RESIZE_WIDTH, RESIZE_HEIGHT))
    
    return im
    
    
def preprocAll():
    '''Preprocess all screenshots in the data folder.'''
    files = listdir(folder)
    for f in files:
        fileName = folder + f
        im = preprocImg(fileName)
        im.save(fileName)

In [112]:
preprocAll()

#### Play game

In [141]:
url = 'https://playsnake.org/'

browser = webdriver.Chrome(executable_path='D:/Libraries/Drivers/chromedriver_win32/chromedriver.exe')
# browser = webdriver.Chrome(executable_path='D:/Biblioteci/Python/chromedriver_win32/chromedriver.exe')
browser.get(url) 


index = 0
history = dict()

maxGames = 10

for g in range(maxGames):
    
    # Select the level
    selectLevel(1)
    time.sleep(2)
    
    score = 0
    
    while getState() == 'playing':
        clickBoard()
        s = getScreen()

        clickBoard()
        m = makeRandomMove()

        clickBoard()
        s2 = getScreen()
        r = max(STILL_ALIVE_REWARD, getScore() - score)
        history['screen_'+str(index-1)+'.png'] = [g, m, r, score]
        score = getScore()
        
        clickBoard()
        
    history['screen_'+str(index-1)+'.png'][2] = DEAD_REWARD
    clickBoard()
    time.sleep(1)
    
print(history)

browser.quit()

preprocAll()

{'screen_61.png': [2, 1, 1, 0], 'screen_59.png': [2, 3, 1, 0], 'screen_13.png': [1, 1, 1, 0], 'screen_127.png': [5, 4, 1, 0], 'screen_178.png': [8, 4, 1, 0], 'screen_133.png': [5, 3, 1, 0], 'screen_85.png': [2, 4, 1, 0], 'screen_1.png': [0, 4, 1, 0], 'screen_11.png': [0, 2, -1, 0], 'screen_105.png': [3, 3, -1, 0], 'screen_19.png': [1, 2, 1, 0], 'screen_180.png': [8, 4, -1, 0], 'screen_176.png': [8, 3, 1, 0], 'screen_194.png': [9, 3, -1, 0], 'screen_51.png': [2, 1, 1, 0], 'screen_101.png': [3, 0, 1, 0], 'screen_159.png': [7, 4, 1, 0], 'screen_109.png': [4, 4, 1, 0], 'screen_23.png': [1, 2, 1, 0], 'screen_31.png': [2, 2, 1, 0], 'screen_35.png': [2, 3, 1, 0], 'screen_186.png': [9, 3, 1, 0], 'screen_41.png': [2, 1, 1, 0], 'screen_99.png': [3, 3, 1, 0], 'screen_17.png': [1, 3, 1, 0], 'screen_141.png': [5, 2, -1, 0], 'screen_43.png': [2, 3, 1, 0], 'screen_117.png': [4, 1, 1, 0], 'screen_67.png': [2, 3, 1, 0], 'screen_57.png': [2, 1, 1, 0], 'screen_49.png': [2, 1, 1, 0], 'screen_93.png': [3, 

#### Q model

In [97]:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from keras.regularizers import l2
from keras.optimizers import Adam, Adadelta, RMSprop
import keras.losses as losses
from keras.backend import set_image_data_format

Intai fac un model care doar prezice daca un board reprezinta un joc care s-a terminat (lovit de zid) sau nu

In [149]:
import numpy as np

set_image_data_format('channels_first')

def readData():
    files = listdir(folder)
    
    labels = np.zeros(len(files))    
    data = []
    
    for i, f in enumerate(files):
        im = Image.open(folder + f)
        
        if f in history:
            labels[i] = history[f][2]
        else:
            labels[i] = STILL_ALIVE_REWARD
            
        data.append([np.asarray(im.convert("L"))])
        
    data = np.asarray(data)
    return data, labels

        
def nn_model():
    model = Sequential()
    
    model.add(Conv2D(10, (5, 5), input_shape=(1, RESIZE_HEIGHT, RESIZE_WIDTH), activation='relu', name='conv1'))
    model.add(MaxPool2D(pool_size=(2,2), name='pool1'))
    
    model.add(Conv2D(5, (5, 5), activation='relu', name='conv2'))
    model.add(MaxPool2D(pool_size=(2,2), name='pool2'))
    
    model.add(Flatten(name='flat'))
    model.add(Dense(100, activation='relu', name='dense1'))
    model.add(Dense(1, activation='softmax', name='dense2'))
    
    model.compile(loss='binary_crossentropy',
                 optimizer = 'adam',
                 metrics=['binary_accuracy'])
    
    return model
    
model = nn_model()

data, labels = readData()
print(data.shape)
print(labels)

model.fit(data, labels, epochs=1, batch_size=8, shuffle=True)

(195, 1, 137, 180)
[ 1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1. -1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1417ba11390>

In [156]:
model.evaluate(data[:10], labels[:10])



[3.188477039337158, 0.8999999761581421]

In [157]:
print(model.predict(data[:30]))
print(labels[:10])

[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
[ 1.  1.  1.  1.  1.  1.  1.  1. -1.  1.]
