In [1]:
import random, os, numpy as np, tensorflow as tf
from config import *

In [2]:
def add_jumps_to_training(training_images, last_jumps):
    print("Parsing data...")
    iter_counter = 0
    X_data = []
    # (game frames, height, width)
    for i, game in enumerate(training_images):
        X_data.append([])
        for j, image in enumerate(game):
            X_data[iter_counter].append(np.append(image.ravel(), last_jumps[i][j]))
        X_data[iter_counter] = np.array(X_data[iter_counter], dtype=np.float32)
        iter_counter += 1

    return np.array(X_data)

In [3]:
class FlappyGraph:
    def __init__(self, img_size):
        L1 = 200
        L2 = 50
        output_dim = 1
        tf.reset_default_graph()
        self.inputs = tf.placeholder(tf.float32, [None, img_size], name='inputs')
        self.actions = tf.placeholder(tf.float32, [None, 2], name='actions')

        self.rewards = tf.placeholder(tf.float32, [None], name='rewards')

        # single layer neural network
        W1 = tf.Variable(tf.truncated_normal([img_size, L1], stddev=0.001, dtype=tf.float32))
        b1 = tf.Variable(tf.zeros(L1))

        W2 = tf.Variable(tf.truncated_normal([L1, L2], stddev=0.001, dtype=tf.float32))
        b2 = tf.Variable(tf.ones(L2))
        
        W3 = tf.Variable(tf.truncated_normal([L2, output_dim], stddev=0.001, dtype=tf.float32))
        b3 = tf.Variable(tf.ones(output_dim))

        y1 = tf.nn.relu(tf.matmul(self.inputs, W1) + b1)
        y2 = tf.nn.relu(tf.matmul(y1, W2) + b2)
        
        self.y_logits = tf.matmul(y2, W3) + b3
        self.sigmoid = tf.sigmoid(self.y_logits)
        
        if self.actions[1] == 1:
            self.loss = tf.reduce_mean(self.rewards * tf.log(self.sigmoid[0]))
        else:
            self.loss = tf.reduce_mean(self.rewards * tf.log(1 - self.sigmoid[0]))
        # policy gradient loss function
    
        self.lr = tf.placeholder(tf.float32)
        self.train_step = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

In [4]:
flappy_graph = FlappyGraph(int((CANVAS_WIDTH * IMG_SCALE_FACTOR) * round(CANVAS_HEIGHT * IMG_SCALE_FACTOR)) + 1)
init = tf.global_variables_initializer()
global sess
sess = tf.Session()
sess.run(init)

In [5]:
global saver
saver = tf.train.Saver()
def save_model():
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    saver.save(sess, os.path.join(MODEL_DIR, "trained_flappy"))

In [6]:
def train_iteration():
    print("Loading data...")
    training_images = np.load(os.path.join(DATA_DIR, "images.npy"))
    actions = np.load(os.path.join(DATA_DIR, "actions.npy"))
    rewards = np.load(os.path.join(DATA_DIR, "adjusted_rewards.npy"))
    last_jumps = np.load(os.path.join(DATA_DIR, "last_jumps.npy"))
    X_data = add_jumps_to_training(training_images = training_images, last_jumps = last_jumps)

    for i in range(5):
        for j in range(5):
            _, train_loss = sess.run([flappy_graph.train_step, flappy_graph.loss], 
                        feed_dict={
                            flappy_graph.inputs: X_data[i], 
                            flappy_graph.actions: actions[i], 
                            flappy_graph.rewards: rewards[i], 
                            flappy_graph.lr: 1e-4}
                        )
            print("loss", train_loss)

In [7]:
import run_agent

In [8]:
for i in range(10):
    run_agent.run()
    train_iteration()
    save_model()

INFO:tensorflow:Restoring parameters from ./models/trained_flappy
[[-14.96641159]] [[  3.16351503e-07]]
[[-14.96641445]] [[  3.16350594e-07]]
[[-14.96642017]] [[  3.16348775e-07]]
[[-14.95883942]] [[  3.18756065e-07]]
[[-14.87686539]] [[  3.45986621e-07]]
[[-14.85658455]] [[  3.53075137e-07]]
[[-14.78149796]] [[  3.80607048e-07]]
[[-14.77673054]] [[  3.82425895e-07]]
[[-14.80684757]] [[  3.71080091e-07]]
[[-14.75785637]] [[  3.89712397e-07]]
[[-14.76136494]] [[  3.88347473e-07]]
[[-14.70836735]] [[  4.09484102e-07]]
[[-14.72859955]] [[  4.01282591e-07]]
[[-14.75746918]] [[  3.89863345e-07]]
[[-14.69422913]] [[  4.15314588e-07]]
[[-14.69995594]] [[  4.12942967e-07]]
[[-14.64801979]] [[  4.34956320e-07]]
[[-14.65798473]] [[  4.30643524e-07]]
[[-14.68261719]] [[  4.20165321e-07]]
[[-14.64347839]] [[  4.36936148e-07]]
[[-14.62577343]] [[  4.44740948e-07]]
[[-14.55106926]] [[  4.79237372e-07]]
[[-14.55886078]] [[  4.75517908e-07]]
[[-14.59005737]] [[  4.60912389e-07]]
[[-14.53526878]] [[  4

loss 0.897366
loss 0.893487
loss 0.889274
loss 0.88464
INFO:tensorflow:Restoring parameters from ./models/trained_flappy
[[ 0.89482641]] [[ 0.70988518]]
[[ 0.89482641]] [[ 0.70988518]]
[[ 0.89482629]] [[ 0.70988518]]
[[ 0.89494008]] [[ 0.7099086]]
[[ 0.89499307]] [[ 0.70991945]]
[[ 0.89526135]] [[ 0.70997477]]
[[ 0.89493108]] [[ 0.70990676]]
[[ 0.8949092]] [[ 0.70990223]]
[[ 0.89502609]] [[ 0.70992625]]
[[ 0.89537829]] [[ 0.70999885]]
[[ 0.89534861]] [[ 0.70999271]]
[[ 0.89544404]] [[ 0.71001238]]
[[ 0.8952651]] [[ 0.70997554]]
[[ 0.89513081]] [[ 0.70994788]]
[[ 0.8950187]] [[ 0.70992476]]
[[ 0.89486521]] [[ 0.70989317]]
[[ 0.89508963]] [[ 0.70993936]]
[[ 0.89503437]] [[ 0.70992804]]
[[ 0.89526606]] [[ 0.70997572]]
[[ 0.89526725]] [[ 0.70997596]]
[[ 0.89540613]] [[ 0.71000457]]
[[ 0.8956126]] [[ 0.71004707]]
[[ 0.89564341]] [[ 0.71005344]]
[[ 0.89571381]] [[ 0.71006793]]
[[ 0.89573479]] [[ 0.71007222]]
[[ 0.89563131]] [[ 0.71005088]]
[[ 0.89589882]] [[ 0.71010596]]
[[ 0.89588153]] [[ 0

[[-2.82490849]] [[ 0.05599292]]
[[-2.82934642]] [[ 0.0557588]]
[[-2.83512712]] [[ 0.05545522]]
-1
Game 1 over; alive frames: 32
[[-2.82560682]] [[ 0.05595602]]
[[-2.96717739]] [[ 0.04893091]]
[[-2.96463943]] [[ 0.04904916]]
[[-2.94268608]] [[ 0.05008333]]
[[-2.93660545]] [[ 0.05037341]]
[[-2.92198634]] [[ 0.05107734]]
[[-2.92591882]] [[ 0.05088707]]
[[-2.92926526]] [[ 0.05072569]]
[[-2.91394401]] [[ 0.05146855]]
[[-2.91411018]] [[ 0.05146044]]
[[-2.90159035]] [[ 0.052075]]
[[-2.90288496]] [[ 0.05201114]]
[[-2.90939546]] [[ 0.05169106]]
[[-2.90780592]] [[ 0.05176903]]
[[-2.90427256]] [[ 0.05194276]]
[[-2.88547277]] [[ 0.05287639]]
[[-2.88627863]] [[ 0.05283605]]
[[-2.89110661]] [[ 0.05259495]]
[[-2.88295579]] [[ 0.05300258]]
[[-2.88172197]] [[ 0.05306454]]
[[-2.86950302]] [[ 0.05368189]]
[[-2.87155247]] [[ 0.05357787]]
[[-2.87657404]] [[ 0.05332381]]
[[-2.86675239]] [[ 0.0538218]]
[[-2.86812234]] [[ 0.05375207]]
[[-2.84887409]] [[ 0.05473955]]
[[-2.8472867]] [[ 0.05482174]]
[[-2.8530590

[[-2.88166714]] [[ 0.0530673]]
[[-2.87496781]] [[ 0.05340496]]
[[-2.87608671]] [[ 0.05334842]]
[[-2.86732054]] [[ 0.05379287]]
[[-2.86651206]] [[ 0.05383404]]
[[-2.87505412]] [[ 0.05340059]]
-1
Game 2 over; alive frames: 264
[[-2.86893916]] [[ 0.05371054]]
[[-2.96718264]] [[ 0.04893067]]
[[-2.96464396]] [[ 0.04904895]]
[[-2.94269109]] [[ 0.05008309]]
[[-2.93661141]] [[ 0.05037312]]
[[-2.92199206]] [[ 0.05107706]]
[[-2.92592502]] [[ 0.05088678]]
[[-2.92927098]] [[ 0.05072542]]
[[-2.9139483]] [[ 0.05146834]]
[[-2.91411543]] [[ 0.05146018]]
[[-2.90157223]] [[ 0.0520759]]
[[-2.9028852]] [[ 0.05201112]]
[[-2.90943742]] [[ 0.05168901]]
[[-2.90232968]] [[ 0.05203852]]
[[-2.90582609]] [[ 0.05186631]]
[[-2.88539696]] [[ 0.05288019]]
[[-2.88435483]] [[ 0.0529324]]
[[-2.89292979]] [[ 0.05250417]]
[[-2.88856053]] [[ 0.05272196]]
[[-2.8887651]] [[ 0.05271175]]
[[-2.87288761]] [[ 0.05351022]]
[[-2.8736093]] [[ 0.05347367]]
[[-2.87923503]] [[ 0.05318965]]
[[-2.86966133]] [[ 0.05367385]]
[[-2.87106514

[[-2.87649965]] [[ 0.05332757]]
[[-2.88274503]] [[ 0.05301316]]
[[-2.8712616]] [[ 0.05359263]]
[[-2.86798429]] [[ 0.05375909]]
[[-2.85469604]] [[ 0.05443908]]
[[-2.85859537]] [[ 0.05423871]]
[[-2.86449409]] [[ 0.05393692]]
[[-2.85264635]] [[ 0.05454469]]
[[-2.85191393]] [[ 0.05458247]]
[[-2.83653116]] [[ 0.05538173]]
[[-2.836869]] [[ 0.05536406]]
[[-2.83953977]] [[ 0.05522455]]
[[-2.82460833]] [[ 0.05600879]]
[[-2.82291627]] [[ 0.05609832]]
[[-2.80893493]] [[ 0.05684325]]
[[-2.80920339]] [[ 0.05682887]]
[[-2.81728029]] [[ 0.05639749]]
[[-2.80712271]] [[ 0.05694049]]
[[-2.8011651]] [[ 0.05726125]]
[[-2.7850101]] [[ 0.0581396]]
[[-2.78765583]] [[ 0.05799489]]
[[-2.79715991]] [[ 0.05747784]]
[[-2.78455567]] [[ 0.05816449]]
[[-2.77939868]] [[ 0.05844764]]
[[-2.76582885]] [[ 0.05919889]]
[[-2.76979184]] [[ 0.05897857]]
[[-2.77589607]] [[ 0.05864069]]
[[-2.75930142]] [[ 0.05956348]]
[[-2.75989032]] [[ 0.0595305]]
[[-2.74657583]] [[ 0.06028032]]
-1
Game 3 over; alive frames: 287
[[-2.74805737

[[-2.72821951]] [[ 0.06132858]]
[[-2.73142576]] [[ 0.06114427]]
[[-2.73281193]] [[ 0.06106474]]
[[-2.72389674]] [[ 0.0615779]]
[[-2.72987461]] [[ 0.06123338]]
[[-2.72468877]] [[ 0.06153215]]
[[-2.71768737]] [[ 0.0619377]]
[[-2.72269082]] [[ 0.06164762]]
[[-2.73485684]] [[ 0.0609476]]
[[-2.72803807]] [[ 0.06133903]]
[[-2.72747564]] [[ 0.06137141]]
[[-2.72451448]] [[ 0.06154222]]
[[-2.72836828]] [[ 0.06132002]]
[[-2.72816157]] [[ 0.06133191]]
[[-2.72525358]] [[ 0.06149954]]
[[-2.72377586]] [[ 0.06158489]]
[[-2.72250724]] [[ 0.06165825]]
[[-2.72104359]] [[ 0.06174298]]
[[-2.72756243]] [[ 0.06136641]]
[[-2.72795033]] [[ 0.06134408]]
[[-2.72045517]] [[ 0.06177708]]
[[-2.72109103]] [[ 0.06174023]]
[[-2.71978903]] [[ 0.0618157]]
[[-2.73486042]] [[ 0.06094739]]
[[-2.73333621]] [[ 0.06103469]]
[[-2.72768831]] [[ 0.06135917]]
[[-2.72436833]] [[ 0.06155066]]
[[-2.72053647]] [[ 0.06177237]]
[[-2.72312641]] [[ 0.06162244]]
[[-2.71889997]] [[ 0.06186729]]
[[-2.71775246]] [[ 0.06193392]]
[[-2.7147810

[[ 0.6117692]] [[ 0.64834428]]
[[ 0.6125505]] [[ 0.64852238]]
[[ 0.61272383]] [[ 0.64856184]]
[[ 0.61223423]] [[ 0.64845026]]
[[ 0.61296999]] [[ 0.64861804]]
[[ 0.6130631]] [[ 0.6486392]]
[[ 0.6142689]] [[ 0.64891398]]
[[ 0.6146605]] [[ 0.64900321]]
[[ 0.61367679]] [[ 0.64877903]]
[[ 0.61519086]] [[ 0.64912403]]
[[ 0.61530399]] [[ 0.64914978]]
[[ 0.61559314]] [[ 0.64921558]]
[[ 0.61572886]] [[ 0.64924651]]
[[ 0.61532521]] [[ 0.6491546]]
[[ 0.61598909]] [[ 0.64930576]]
[[ 0.61560631]] [[ 0.64921862]]
[[ 0.61659753]] [[ 0.64944434]]
[[ 0.61646831]] [[ 0.6494149]]
[[ 0.61673725]] [[ 0.64947617]]
[[ 0.61729777]] [[ 0.64960372]]
[[ 0.61737281]] [[ 0.64962083]]
-1
Game 5 over; alive frames: 40
5 games finished. Exiting...
Processing training data...
201 images saved.
Calculating adjusted rewards..
Saving data...
Completed data parsing!
Loading data...
Parsing data...
loss -0.184307
loss -0.190652
loss -0.195615
loss -0.199554
loss -0.202729
loss 0.760657
loss 0.725954
loss 0.664627
loss 0.56

[[-4.87645912]] [[ 0.00756628]]
[[-4.87935972]] [[ 0.00754453]]
[[-4.88625765]] [[ 0.00749305]]
[[-4.87093067]] [[ 0.0076079]]
[[-4.86407328]] [[ 0.00765985]]
[[-4.84277534]] [[ 0.00782345]]
-1
Game 5 over; alive frames: 41
5 games finished. Exiting...
Processing training data...
221 images saved.
Calculating adjusted rewards..
Saving data...
Completed data parsing!
Loading data...
Parsing data...
loss -0.00107042
loss -0.00110733
loss -0.00114869
loss -0.00119524
loss -0.00124796
loss 0.00468607
loss 0.00389084
loss 0.00328385
loss 0.00280802
loss 0.002427
loss 0.00211648
loss 0.00185962
loss 0.00164444
loss 0.0014623
loss 0.00130664
loss 0.000849952
loss 0.000787999
loss 0.000731041
loss 0.000678628
loss 0.000630338
loss 0.0008872
loss 0.000794613
loss 0.000714633
loss 0.000645092
loss 0.000584164
INFO:tensorflow:Restoring parameters from ./models/trained_flappy
[[-7.19929123]] [[ 0.00074656]]
[[-7.19928932]] [[ 0.00074656]]
[[-7.19928932]] [[ 0.00074656]]
[[-7.19512844]] [[ 0.000749

[[-8.82509232]] [[ 0.00014698]]
[[-8.84628582]] [[ 0.00014389]]
[[-8.81586838]] [[ 0.00014834]]
[[-8.82019138]] [[ 0.0001477]]
[[-8.78724098]] [[ 0.00015265]]
[[-8.79853249]] [[ 0.00015093]]
[[-8.81645679]] [[ 0.00014825]]
[[-8.78193951]] [[ 0.00015346]]
[[-8.78668308]] [[ 0.00015273]]
[[-8.74623299]] [[ 0.00015903]]
[[-8.74627399]] [[ 0.00015903]]
[[-8.76452923]] [[ 0.00015615]]
[[-8.74408531]] [[ 0.00015938]]
[[-8.739501]] [[ 0.00016011]]
[[-8.70053768]] [[ 0.00016647]]
[[-8.70197487]] [[ 0.00016623]]
[[-8.71443558]] [[ 0.00016417]]
[[-8.67963314]] [[ 0.00016998]]
[[-8.68045044]] [[ 0.00016985]]
[[-8.64006424]] [[ 0.00017684]]
[[-8.64040661]] [[ 0.00017678]]
[[-8.65579796]] [[ 0.00017408]]
[[-8.62078476]] [[ 0.00018029]]
[[-8.62246037]] [[ 0.00017998]]
[[-8.5918293]] [[ 0.00018558]]
[[-8.59229469]] [[ 0.0001855]]
[[-8.61445522]] [[ 0.00018143]]
-1
Game 1 over; alive frames: 32
[[-8.59047318]] [[ 0.00018583]]
[[-8.94333935]] [[ 0.00013059]]
[[-8.93385696]] [[ 0.00013183]]
[[-8.8896122

-1
Game 1 over; alive frames: 32
[[-10.11398506]] [[  4.05074134e-05]]
[[-10.51557064]] [[  2.71102763e-05]]
[[-10.50472164]] [[  2.74059894e-05]]
[[-10.44922352]] [[  2.89699201e-05]]
[[-10.43974018]] [[  2.92459499e-05]]
[[-10.38675117]] [[  3.08374110e-05]]
[[-10.38780212]] [[  3.08050185e-05]]
[[-10.41012764]] [[  3.01249238e-05]]
[[-10.36459255]] [[  3.15283287e-05]]
[[-10.37072659]] [[  3.13355304e-05]]
[[-10.33483887]] [[  3.24804787e-05]]
[[-10.34582901]] [[  3.21254774e-05]]
[[-10.36480331]] [[  3.15216857e-05]]
[[-10.33430767]] [[  3.24977336e-05]]
[[-10.33269596]] [[  3.25501533e-05]]
[[-10.29057217]] [[  3.39505314e-05]]
[[-10.28634644]] [[  3.40942934e-05]]
[[-10.30581951]] [[  3.34368160e-05]]
[[-10.27716255]] [[  3.44088439e-05]]
[[-10.28018284]] [[  3.43050779e-05]]
[[-10.23392487]] [[  3.59291807e-05]]
[[-10.24330044]] [[  3.55939083e-05]]
[[-10.25308609]] [[  3.52473107e-05]]
[[-10.21031952]] [[  3.67873581e-05]]
[[-10.20961571]] [[  3.68132569e-05]]
[[-10.16257]] [[ 

[[-11.55996513]] [[  9.54040479e-06]]
[[-11.55847168]] [[  9.55466294e-06]]
[[-11.58915138]] [[  9.26598295e-06]]
-1
Game 1 over; alive frames: 32
[[-11.5693779]] [[  9.45102511e-06]]
[[-12.01907349]] [[  6.02809496e-06]]
[[-12.00691986]] [[  6.10180496e-06]]
[[-11.94966125]] [[  6.46137914e-06]]
[[-11.93293381]] [[  6.57037026e-06]]
[[-11.87636566]] [[  6.95275457e-06]]
[[-11.87948608]] [[  6.93109314e-06]]
[[-11.90374565]] [[  6.76497211e-06]]
[[-11.85958481]] [[  7.07041136e-06]]
[[-11.85783958]] [[  7.08276184e-06]]
[[-11.81537628]] [[  7.38999370e-06]]
[[-11.82792473]] [[  7.29784097e-06]]
[[-11.84989357]] [[  7.13926511e-06]]
[[-11.8190403]] [[  7.36296624e-06]]
[[-11.81237793]] [[  7.41218491e-06]]
[[-11.76337242]] [[  7.78446702e-06]]
[[-11.75896835]] [[  7.81882591e-06]]
[[-11.78026962]] [[  7.65403729e-06]]
[[-11.75341225]] [[  7.86238888e-06]]
[[-11.75141525]] [[  7.87810495e-06]]
[[-11.69700241]] [[  8.31864872e-06]]
[[-11.70750427]] [[  8.23174560e-06]]
[[-11.71862602]] [[

[[-12.91021633]] [[  2.47265325e-06]]
[[-12.90799141]] [[  2.47816092e-06]]
[[-12.94308662]] [[  2.39269775e-06]]
-1
Game 1 over; alive frames: 32
[[-12.9189291]] [[  2.45120327e-06]]
[[-13.42792606]] [[  1.47341518e-06]]
[[-13.4145422]] [[  1.49326775e-06]]
[[-13.34727573]] [[  1.59716978e-06]]
[[-13.33354568]] [[  1.61925027e-06]]
[[-13.27101994]] [[  1.72372722e-06]]
[[-13.27278233]] [[  1.72069190e-06]]
[[-13.29962444]] [[  1.67511939e-06]]
[[-13.24356461]] [[  1.77170830e-06]]
[[-13.24813366]] [[  1.76363164e-06]]
[[-13.20179844]] [[  1.84727253e-06]]
[[-13.21625423]] [[  1.82076087e-06]]
[[-13.24072647]] [[  1.77674372e-06]]
[[-13.20343113]] [[  1.84425903e-06]]
[[-13.20024586]] [[  1.85014278e-06]]
[[-13.1467104]] [[  1.95189023e-06]]
[[-13.14222717]] [[  1.96066048e-06]]
[[-13.16532898]] [[  1.91588492e-06]]
[[-13.12987995]] [[  1.98501925e-06]]
[[-13.1312561]] [[  1.98228940e-06]]
[[-13.0725975]] [[  2.10204553e-06]]
[[-13.08547211]] [[  2.07515609e-06]]
[[-13.09776878]] [[  2

KeyboardInterrupt: 