In [2]:
import random, os, numpy as np, tensorflow as tf
from tf_graph import FlappyGraph
from config import *

In [3]:
def add_jumps_to_training(training_images, last_jumps):
    """Flatten every frame and append its last-jump value as an extra feature.

    Args:
        training_images: Sequence of games; each game is a sequence of image
            arrays, one per frame.
        last_jumps: Per-game, per-frame values indexed as ``last_jumps[i][j]``.
            The value for frame ``j`` of game ``i`` is appended to that
            frame's flattened pixels.

    Returns:
        An array with one entry per game. Each entry is a float32 array with
        one row per frame; a row is the flattened image followed by the
        last-jump value. If every game has the same shape the entries are
        stacked into one regular array. If games differ in length the result
        is a 1-D object array whose elements are the per-game float32 arrays.
    """
    print("Parsing data...")
    X_data = []
    # (game frames, height, width)
    for i, game in enumerate(training_images):
        frames = [np.append(image.ravel(), last_jumps[i][j]) for j, image in enumerate(game)]
        X_data.append(np.array(frames, dtype=np.float32))

    # Uniform games stack into a regular array exactly as np.array() would.
    if len({game.shape for game in X_data}) <= 1:
        return np.array(X_data)

    # Games of different lengths cannot be stacked. Build the object array
    # explicitly: implicit ragged-array creation is deprecated in NumPy 1.19
    # and raises ValueError from NumPy 1.24 on.
    ragged = np.empty(len(X_data), dtype=object)
    for i, game in enumerate(X_data):
        ragged[i] = game
    return ragged

In [4]:
# Build the model graph. The single constructor argument is computed from the
# canvas size and scale factor, plus 1 -- presumably the flattened size of a
# scaled frame plus the one last-jump value that add_jumps_to_training appends
# to every frame.
# NOTE(review): only the height term is rounded, not the width -- confirm this
# matches how the frames are actually resized.
flappy_graph = FlappyGraph(int((CANVAS_WIDTH * IMG_SCALE_FACTOR) * round(CANVAS_HEIGHT * IMG_SCALE_FACTOR)) + 1)
# Created after FlappyGraph so the initializer covers any variables it defined.
init = tf.global_variables_initializer()
# NOTE: `global` at module/cell scope is a no-op; `sess` is a global either way.
global sess
sess = tf.Session()
sess.run(init)

In [5]:
# NOTE: `global` at module/cell scope is a no-op; `saver` is a global either way.
global saver
# Module-level saver used by save_model() to write checkpoints.
saver = tf.train.Saver()
def save_model():
    """Write the current session's variables to ``MODEL_PATH``.

    Creates ``MODEL_DIR`` first if it is missing. Uses the module-level
    ``saver`` and ``sess``.
    """
    # exist_ok=True makes this safe when the directory already exists or is
    # created by another process between a check and the create.
    os.makedirs(MODEL_DIR, exist_ok=True)
    saver.save(sess, MODEL_PATH)

In [6]:
def train_iteration(num_games=5, steps_per_game=5, learning_rate=1e-3):
    """Load the recorded games from ``DATA_DIR`` and run training steps on them.

    Reads ``images.npy``, ``actions.npy``, ``adjusted_rewards.npy`` and
    ``last_jumps.npy``, builds the per-game input matrices with
    ``add_jumps_to_training`` and, for each game, runs ``flappy_graph.train_step``
    in the module-level ``sess``, printing the loss and ``new_prob`` after
    every step.

    Args:
        num_games: Maximum number of games (taken from the start of the data)
            to train on. Clamped to the number of games actually loaded so a
            short recording does not raise ``IndexError``.
        steps_per_game: Number of consecutive training steps run on each game.
        learning_rate: Value fed to ``flappy_graph.lr``.

    The defaults match the values that used to be hard-coded (5, 5, 1e-3).
    """
    print("Loading data...")
    training_images = np.load(os.path.join(DATA_DIR, "images.npy"))
    actions = np.load(os.path.join(DATA_DIR, "actions.npy"))
    rewards = np.load(os.path.join(DATA_DIR, "adjusted_rewards.npy"))
    last_jumps = np.load(os.path.join(DATA_DIR, "last_jumps.npy"))
    X_data = add_jumps_to_training(training_images = training_images, last_jumps = last_jumps)

    # Never index past the games that were actually recorded.
    games_to_train = min(num_games, len(X_data))
    fetches = [flappy_graph.new_prob, flappy_graph.train_step, flappy_graph.loss]

    for i in range(games_to_train):
        # The same game is fed on every step, so build its feed once.
        feed = {
            flappy_graph.inputs: X_data[i],
            flappy_graph.actions: actions[i],
            flappy_graph.rewards: rewards[i],
            flappy_graph.lr: learning_rate,
        }
        for step in range(steps_per_game):
            new_prob, _, train_loss = sess.run(fetches, feed_dict=feed)
            print("loss", train_loss, "new_prob: ", new_prob)

In [7]:
import run_agent

In [8]:
# Write a checkpoint before the first run_agent.run() call.
# NOTE(review): run_agent presumably restores from MODEL_PATH (the logged output
# shows "Restoring parameters from ./models/trained_flappy") -- confirm.
save_model()

In [None]:
# Ten rounds of: run the agent, train on the data files in DATA_DIR, checkpoint.
# NOTE(review): run_agent.run() presumably writes the .npy files that
# train_iteration() loads -- confirm against run_agent.
for i in range(10):
    run_agent.run()
    train_iteration()
    save_model()

INFO:tensorflow:Restoring parameters from ./models/trained_flappy
[[-2.84942579]] [[ 0.054711]]
[[-2.84942579]] [[ 0.054711]]
[[-2.84942579]] [[ 0.054711]]
[[-2.84938288]] [[ 0.05471322]]
[[-2.84940338]] [[ 0.05471216]]
[[-2.84938073]] [[ 0.05471334]]
[[-2.84950399]] [[ 0.05470696]]
[[-2.84944797]] [[ 0.05470986]]
[[-2.84951448]] [[ 0.05470642]]
[[-2.84937906]] [[ 0.05471342]]
[[-2.84933996]] [[ 0.05471545]]
[[-2.84929013]] [[ 0.05471802]]
[[-2.84939551]] [[ 0.05471257]]
[[-2.84938169]] [[ 0.05471329]]
[[-2.84929109]] [[ 0.05471798]]
[[-2.8492434]] [[ 0.05472044]]
[[-2.84920216]] [[ 0.05472257]]
[[-2.84920979]] [[ 0.05472218]]
[[-2.84930015]] [[ 0.05471751]]
[[-2.84900379]] [[ 0.05473283]]
[[-2.8490057]] [[ 0.05473274]]
[[-2.84895301]] [[ 0.05473546]]
[[-2.84890032]] [[ 0.05473819]]
[[-2.84904337]] [[ 0.05473079]]
[[-2.84918976]] [[ 0.05472321]]
[[-2.84922266]] [[ 0.05472151]]
[[-2.84931254]] [[ 0.05471686]]
[[-2.84933472]] [[ 0.05471572]]
[[-2.84933567]] [[ 0.05471567]]
[[-2.84925485]

[[-2.84882903]] [[ 0.05474188]]
[[-2.84902716]] [[ 0.05473163]]
[[-2.84897923]] [[ 0.05473411]]
[[-2.84906673]] [[ 0.05472958]]
[[-2.8491106]] [[ 0.05472731]]
[[-2.84895325]] [[ 0.05473545]]
[[-2.84877491]] [[ 0.05474468]]
[[-2.84885335]] [[ 0.05474062]]
[[-2.84901977]] [[ 0.05473201]]
[[-2.84908605]] [[ 0.05472858]]
[[-2.84907341]] [[ 0.05472923]]
[[-2.84903312]] [[ 0.05473132]]
[[-2.84923816]] [[ 0.05472071]]
[[-2.84887958]] [[ 0.05473926]]
[[-2.84887385]] [[ 0.05473956]]
[[-2.8484478]] [[ 0.05476161]]
[[-2.84846354]] [[ 0.0547608]]
[[-2.84844208]] [[ 0.05476191]]
[[-2.84840775]] [[ 0.05476368]]
[[-2.84825325]] [[ 0.05477168]]
[[-2.84827828]] [[ 0.05477038]]
[[-2.84840083]] [[ 0.05476404]]
[[-2.84848595]] [[ 0.05475963]]
[[-2.84844017]] [[ 0.054762]]
[[-2.84839821]] [[ 0.05476417]]
[[-2.84857345]] [[ 0.0547551]]
[[-2.84855461]] [[ 0.05475608]]
[[-2.84854579]] [[ 0.05475653]]
[[-2.84881878]] [[ 0.05474241]]
[[-2.84869528]] [[ 0.0547488]]
[[-2.84860921]] [[ 0.05475325]]
[[-2.84848952]]

[[-2.84883976]] [[ 0.05474132]]
[[-2.84875655]] [[ 0.05474563]]
[[-2.84875798]] [[ 0.05474555]]
[[-2.84849524]] [[ 0.05475915]]
[[-2.84842348]] [[ 0.05476286]]
[[-2.84864736]] [[ 0.05475128]]
[[-2.8486743]] [[ 0.05474988]]
[[-2.8486793]] [[ 0.05474963]]
[[-2.84892154]] [[ 0.05473709]]
[[-2.84894276]] [[ 0.05473599]]
[[-2.84883428]] [[ 0.05474161]]
[[-2.84878731]] [[ 0.05474404]]
[[-2.84889555]] [[ 0.05473843]]
[[-2.84910464]] [[ 0.05472762]]
[[-2.84917736]] [[ 0.05472386]]
[[-2.84915876]] [[ 0.05472482]]
[[-2.84917545]] [[ 0.05472396]]
[[-2.84922934]] [[ 0.05472117]]
[[-2.84897327]] [[ 0.05473442]]
[[-2.8488338]] [[ 0.05474163]]
[[-2.84862781]] [[ 0.05475229]]
[[-2.84864163]] [[ 0.05475157]]
[[-2.84857368]] [[ 0.05475509]]
[[-2.84854913]] [[ 0.05475636]]
[[-2.84844422]] [[ 0.05476179]]
[[-2.84846759]] [[ 0.05476058]]
[[-2.84839058]] [[ 0.05476457]]
[[-2.84851122]] [[ 0.05475832]]
[[-2.8485148]] [[ 0.05475814]]
[[-2.84843111]] [[ 0.05476247]]
[[-2.84853387]] [[ 0.05475715]]
[[-2.8485589

loss 0.00326899 new_prob:  [[ 0.945337    0.945337    0.945337   ...,  0.945337    0.945337    0.945337  ]
 [ 0.9453311   0.9453311   0.9453311  ...,  0.9453311   0.9453311
   0.9453311 ]
 [ 0.94533449  0.94533449  0.94533449 ...,  0.94533449  0.94533449
   0.94533449]
 ..., 
 [ 0.94529629  0.94529629  0.94529629 ...,  0.94529629  0.94529629
   0.94529629]
 [ 0.94529396  0.94529396  0.94529396 ...,  0.94529396  0.94529396
   0.94529396]
 [ 0.9452951   0.9452951   0.9452951  ...,  0.9452951   0.9452951
   0.9452951 ]]
loss 0.00326834 new_prob:  [[ 0.94534767  0.94534767  0.94534767 ...,  0.94534767  0.94534767
   0.94534767]
 [ 0.94534177  0.94534177  0.94534177 ...,  0.94534177  0.94534177
   0.94534177]
 [ 0.94534522  0.94534522  0.94534522 ...,  0.94534522  0.94534522
   0.94534522]
 ..., 
 [ 0.94530684  0.94530684  0.94530684 ...,  0.94530684  0.94530684
   0.94530684]
 [ 0.94530451  0.94530451  0.94530451 ...,  0.94530451  0.94530451
   0.94530451]
 [ 0.94530559  0.94530559  0.9453

[[-2.84986424]] [[ 0.05468833]]
[[-2.84977794]] [[ 0.0546928]]
[[-2.84981632]] [[ 0.05469081]]
[[-2.84976029]] [[ 0.05469371]]
[[-2.84998679]] [[ 0.054682]]
[[-2.8497138]] [[ 0.05469611]]
[[-2.8497901]] [[ 0.05469217]]
[[-2.84975982]] [[ 0.05469373]]
[[-2.8497653]] [[ 0.05469345]]
[[-2.84969425]] [[ 0.05469712]]
[[-2.84984255]] [[ 0.05468946]]
[[-2.84977341]] [[ 0.05469303]]
[[-2.85005212]] [[ 0.05467862]]
[[-2.85000396]] [[ 0.05468111]]
[[-2.8499074]] [[ 0.05468611]]
[[-2.85021186]] [[ 0.05467037]]
[[-2.85029793]] [[ 0.05466592]]
[[-2.85022068]] [[ 0.05466991]]
[[-2.85002899]] [[ 0.05467982]]
[[-2.85018706]] [[ 0.05467165]]
[[-2.85026646]] [[ 0.05466754]]
[[-2.8501513]] [[ 0.0546735]]
[[-2.84997129]] [[ 0.0546828]]
[[-2.84995675]] [[ 0.05468355]]
[[-2.85005474]] [[ 0.05467848]]
[[-2.85014296]] [[ 0.05467393]]
[[-2.85029006]] [[ 0.05466633]]
[[-2.85015035]] [[ 0.05467355]]
[[-2.85012841]] [[ 0.05467468]]
[[-2.84997964]] [[ 0.05468237]]
[[-2.85011935]] [[ 0.05467515]]
[[-2.85019469]] [[

[[-2.85027027]] [[ 0.05466735]]
[[-2.85022235]] [[ 0.05466983]]
[[-2.85009885]] [[ 0.0546762]]
[[-2.85016823]] [[ 0.05467262]]
[[-2.85005546]] [[ 0.05467845]]
[[-2.84993076]] [[ 0.0546849]]
[[-2.85001493]] [[ 0.05468054]]
[[-2.85003924]] [[ 0.05467929]]
[[-2.85004997]] [[ 0.05467873]]
[[-2.8504343]] [[ 0.05465887]]
[[-2.85044789]] [[ 0.05465817]]
[[-2.85037279]] [[ 0.05466205]]
[[-2.85030437]] [[ 0.05466558]]
[[-2.85035539]] [[ 0.05466295]]
[[-2.85007739]] [[ 0.05467732]]
[[-2.85009003]] [[ 0.05467666]]
[[-2.84995604]] [[ 0.05468359]]
[[-2.85006046]] [[ 0.05467819]]
[[-2.85012722]] [[ 0.05467474]]
[[-2.85017347]] [[ 0.05467235]]
[[-2.8501091]] [[ 0.05467568]]
[[-2.8499825]] [[ 0.05468222]]
[[-2.85002589]] [[ 0.05467998]]
[[-2.85002804]] [[ 0.05467987]]
[[-2.85008478]] [[ 0.05467694]]
[[-2.85000062]] [[ 0.05468129]]
[[-2.85015488]] [[ 0.05467331]]
[[-2.85031509]] [[ 0.05466503]]
[[-2.85032201]] [[ 0.05466467]]
[[-2.8501904]] [[ 0.05467148]]
[[-2.85025024]] [[ 0.05466838]]
[[-2.85032535]

[[-2.85060906]] [[ 0.05464984]]
[[-2.85059857]] [[ 0.05465038]]
[[-2.85056925]] [[ 0.0546519]]
[[-2.85039377]] [[ 0.05466097]]
[[-2.85029769]] [[ 0.05466593]]
[[-2.85046053]] [[ 0.05465752]]
[[-2.850564]] [[ 0.05465217]]
[[-2.85058737]] [[ 0.05465096]]
[[-2.85041714]] [[ 0.05465976]]
[[-2.8502593]] [[ 0.05466792]]
[[-2.85014224]] [[ 0.05467397]]
[[-2.85026622]] [[ 0.05466756]]
[[-2.85038781]] [[ 0.05466128]]
[[-2.85022688]] [[ 0.05466959]]
[[-2.85024881]] [[ 0.05466846]]
[[-2.85026026]] [[ 0.05466786]]
[[-2.85013533]] [[ 0.05467432]]
[[-2.85017705]] [[ 0.05467217]]
[[-2.85033894]] [[ 0.0546638]]
[[-2.85046577]] [[ 0.05465725]]
[[-2.85038042]] [[ 0.05466166]]
[[-2.85047245]] [[ 0.0546569]]
[[-2.85034347]] [[ 0.05466357]]
[[-2.85011554]] [[ 0.05467534]]
[[-2.85013819]] [[ 0.05467417]]
[[-2.85018873]] [[ 0.05467156]]
[[-2.85025454]] [[ 0.05466816]]
[[-2.85028005]] [[ 0.05466684]]
[[-2.85001993]] [[ 0.05468029]]
[[-2.84997416]] [[ 0.05468265]]
[[-2.84994149]] [[ 0.05468434]]
[[-2.84992576]

loss -0.0370126 new_prob:  [[ 0.94549686  0.94549686  0.94549686 ...,  0.94549686  0.94549686
   0.94549686]
 [ 0.94549066  0.94549066  0.94549066 ...,  0.94549066  0.94549066
   0.94549066]
 [ 0.94549471  0.94549471  0.94549471 ...,  0.94549471  0.94549471
   0.94549471]
 ..., 
 [ 0.94545627  0.94545627  0.94545627 ...,  0.94545627  0.94545627
   0.94545627]
 [ 0.9454785   0.9454785   0.9454785  ...,  0.9454785   0.9454785
   0.9454785 ]
 [ 0.94548017  0.94548017  0.94548017 ...,  0.94548017  0.94548017
   0.94548017]]
loss -0.0370978 new_prob:  [[ 0.94537365  0.94537365  0.94537365 ...,  0.94537365  0.94537365
   0.94537365]
 [ 0.94536752  0.94536752  0.94536752 ...,  0.94536752  0.94536752
   0.94536752]
 [ 0.94537163  0.94537163  0.94537163 ...,  0.94537163  0.94537163
   0.94537163]
 ..., 
 [ 0.94533509  0.94533509  0.94533509 ...,  0.94533509  0.94533509
   0.94533509]
 [ 0.94535661  0.94535661  0.94535661 ...,  0.94535661  0.94535661
   0.94535661]
 [ 0.94535822  0.94535822  0.9

[[-2.84742713]] [[ 0.05481446]]
[[-2.84730339]] [[ 0.05482088]]
[[-2.84743834]] [[ 0.05481388]]
[[-2.8475399]] [[ 0.05480862]]
[[-2.847332]] [[ 0.05481939]]
[[-2.84733963]] [[ 0.054819]]
[[-2.84724498]] [[ 0.0548239]]
[[-2.84725571]] [[ 0.05482335]]
[[-2.84750462]] [[ 0.05481045]]
[[-2.84739304]] [[ 0.05481623]]
[[-2.84732485]] [[ 0.05481976]]
[[-2.84740281]] [[ 0.05481572]]
[[-2.84724689]] [[ 0.0548238]]
[[-2.84724522]] [[ 0.05482389]]
[[-2.8474884]] [[ 0.05481129]]
[[-2.8473289]] [[ 0.05481955]]
[[-2.84746647]] [[ 0.05481243]]
[[-2.84733081]] [[ 0.05481945]]
[[-2.84725189]] [[ 0.05482354]]
[[-2.84719658]] [[ 0.05482641]]
[[-2.84720182]] [[ 0.05482614]]
[[-2.84711027]] [[ 0.05483088]]
[[-2.8471005]] [[ 0.05483139]]
[[-2.84734273]] [[ 0.05481884]]
[[-2.8470757]] [[ 0.05483267]]
[[-2.84708595]] [[ 0.05483214]]
[[-2.84705257]] [[ 0.05483387]]
[[-2.84707189]] [[ 0.05483287]]
[[-2.84727788]] [[ 0.0548222]]
[[-2.84734988]] [[ 0.05481847]]
[[-2.84736681]] [[ 0.05481759]]
-1
Game 1 over; aliv

[[-2.84700751]] [[ 0.05483621]]
[[-2.84695959]] [[ 0.05483869]]
[[-2.84718084]] [[ 0.05482723]]
[[-2.84709764]] [[ 0.05483154]]
[[-2.84723091]] [[ 0.05482463]]
[[-2.84739375]] [[ 0.05481619]]
[[-2.84753919]] [[ 0.05480866]]
[[-2.84735084]] [[ 0.05481842]]
[[-2.84731245]] [[ 0.05482041]]
[[-2.8473537]] [[ 0.05481827]]
[[-2.84747386]] [[ 0.05481204]]
[[-2.84743571]] [[ 0.05481402]]
[[-2.84740448]] [[ 0.05481564]]
[[-2.84746885]] [[ 0.0548123]]
[[-2.84753227]] [[ 0.05480902]]
[[-2.84719086]] [[ 0.05482671]]
[[-2.84722137]] [[ 0.05482512]]
[[-2.84698105]] [[ 0.05483758]]
[[-2.8469758]] [[ 0.05483785]]
[[-2.84705353]] [[ 0.05483382]]
[[-2.847054]] [[ 0.0548338]]
[[-2.84684229]] [[ 0.05484477]]
[[-2.84705782]] [[ 0.05483361]]
[[-2.84705877]] [[ 0.05483355]]
[[-2.84708142]] [[ 0.05483238]]
[[-2.84702229]] [[ 0.05483545]]
[[-2.84695888]] [[ 0.05483873]]
[[-2.84701991]] [[ 0.05483557]]
[[-2.846946]] [[ 0.0548394]]
[[-2.84701252]] [[ 0.05483595]]
[[-2.84689426]] [[ 0.05484208]]
[[-2.84674501]] [

In [11]:
# Scratch cell: load the recorded actions at top level for interactive inspection.
actions = np.load(os.path.join(DATA_DIR, "actions.npy"))

In [16]:
# Scratch cell: reload the same files train_iteration() reads and apply the same
# preprocessing, so the arrays can be inspected at top level.
training_images = np.load(os.path.join(DATA_DIR, "images.npy"))
actions = np.load(os.path.join(DATA_DIR, "actions.npy"))
rewards = np.load(os.path.join(DATA_DIR, "adjusted_rewards.npy"))
last_jumps = np.load(os.path.join(DATA_DIR, "last_jumps.npy"))
X_data = add_jumps_to_training(training_images = training_images, last_jumps = last_jumps)

Parsing data...


In [18]:
# Shape of game 0's input matrix: one row per frame, one column per feature.
X_data[0].shape

(40, 11361)

In [22]:
# Shape of the recorded actions for game 2.
actions[2].shape

(41, 1)