In [1]:
import numpy as np
import tensorflow as tf
from ludus.env import EnvController
from tensorflow.keras.layers import Conv2D, Dense, MaxPool2D, Flatten, Input, ZeroPadding2D
import cv2
import time
import copy
from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
import gym_super_mario_bros
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT

In [2]:
def make_env():
    env = gym_super_mario_bros.make('SuperMarioBros-v0')
    env = BinarySpaceToDiscreteSpaceEnv(env, COMPLEX_MOVEMENT)
    return env

In [3]:
env = make_env()
obs_shape = (42, 42)
obs_buffer_size = 3

with tf.variable_scope('policy'):
    state_ph = Input(list(obs_shape) + [obs_buffer_size])
    conv1 = Conv2D(32, 2, strides=(2, 2), activation='elu')(state_ph)
    mp1 = MaxPool2D(2)(conv1)
    conv2 = Conv2D(32, 2, activation='elu')(mp1)
    mp2 = MaxPool2D(2)(conv2)
    conv3 = Conv2D(32, 2, activation='elu')(mp2)
    mp3 = MaxPool2D(2)(conv3)
    flat = Flatten()(mp3)
    dense1 = Dense(128, activation='elu')(flat)
    act_oh = Dense(env.action_space.n, activation='softmax', use_bias=False)(dense1)
    act_out = tf.random.categorical(act_oh, 1)

  result = entry_point.load(False)


Instructions for updating:
Colocations handled automatically by placer.


In [4]:
lam = 0.1
beta = 0.2
enc_dim = 288 # Encoded Feature Dimension

# Placeholders
state_p_ph = tf.placeholder(dtype=tf.float32, shape=(None, *list(obs_shape), obs_buffer_size))
act_ph = tf.placeholder(dtype=tf.int32, shape=(None,))

# State Encoder
with tf.variable_scope('encoder'):
    enc_layers = [
        ZeroPadding2D(),
        Conv2D(32, 3, strides=(2, 2), activation='elu'),
        ZeroPadding2D(),
        Conv2D(32, 3, strides=(2, 2), activation='elu'),
        ZeroPadding2D(),
        Conv2D(32, 3, strides=(2, 2), activation='elu'),
        ZeroPadding2D(),
        Conv2D(32, 3, strides=(2, 2), activation='elu'),
        Flatten()
    ]
    
    enc_state = enc_layers[0](state_ph)
    for i in range(1, len(enc_layers)):
        enc_state = enc_layers[i](enc_state)
        
    enc_state_p = enc_layers[0](state_p_ph)
    for i in range(1, len(enc_layers)):
        enc_state_p = enc_layers[i](enc_state_p)
        
# Inverse Dynamics Model
with tf.variable_scope('inverse_model'):
    state_state_pair = tf.concat([enc_state, enc_state_p], axis=1)
    im_dense = Dense(256, activation='elu')(state_state_pair)
    act_pred = Dense(env.action_space.n, activation='softmax', use_bias=False)(im_dense)
    
# Foward Model
with tf.variable_scope('forward_model'):
    state_act_pair = tf.concat([enc_state, act_oh], axis=1)
    fm_dense = Dense(256, activation='elu')(state_act_pair)
    enc_state_p_pred = Dense(enc_dim, activation='elu')(fm_dense)
    # enc_state_p_pred = tf.clip_by_value(enc_state_p_pred, 0., 1.)
    
# Losses

# Inverse Dynamics Loss
act_actual_oh = tf.one_hot(act_ph, env.action_space.n)
loss_i = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(act_actual_oh, act_pred))
loss_i = (1 - beta) * loss_i

# Forward Loss
state_p_diff = tf.square(enc_state_p_pred - enc_state_p)
loss_f = 0.5 * tf.reduce_mean(tf.reduce_sum(state_p_diff, axis=1))
loss_f = beta * loss_f

# Intrinsic Reward
ri = 0.5 * tf.reduce_mean(tf.reduce_sum(state_p_diff, axis=1))
loss_p = -lam * ri

icm_losses = [loss_i, loss_f, loss_p]

# Update Ops
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
update_i = optimizer.minimize(loss_i, var_list=tf.trainable_variables(scope='encoder') + tf.trainable_variables(scope='inverse_model'))
update_f = optimizer.minimize(loss_f, var_list=tf.trainable_variables(scope='forward_model'))
update_p = optimizer.minimize(loss_p, var_list=tf.trainable_variables(scope='policy'))

icm_objective = [update_i, update_f, update_p]

Instructions for updating:
Use tf.cast instead.


In [7]:
def filter_obs(obs):
    obs = cv2.resize(obs, obs_shape, interpolation=cv2.INTER_LINEAR)
    obs = cv2.cvtColor(obs, cv2.COLOR_BGR2GRAY)
    return obs / 255

In [8]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
n_episodes = 1000000
steps_per_epoch = 2048
print_freq = 8
act_skip = 6
max_steps = int(4096 / act_skip)
render = False

all_rewards = []
ris, lis, lfs, lps = 0, 0, 0, 0
train_iters = 0
train_data = [] # Formatted as [obs_buffer_t, act_t, reward_t, obs_buffer_t+1]
best_runs = []
for episode in range(n_episodes):
    obs = env.reset()
    obs = filter_obs(obs)
    obs_buffer = np.rollaxis(np.array([obs] * obs_buffer_size), 0, 3)
    ep_reward = 0

    for step in range(max_steps):
        act = sess.run(act_out, feed_dict={state_ph: [obs_buffer]})[0][0]
        # act = np.random.randint(0, 7)
        
        step_reward = 0
        for i in range(act_skip):
            obs_p, r, d, _ = env.step(act)
            step_reward += r
            if d:
                break
        ep_reward += step_reward
        
        train_data.append([obs_buffer.copy(), act, step_reward])

        obs_p = filter_obs(obs_p)
        obs_buffer[:,:,:-1] = obs_buffer[:,:,1:]
        obs_buffer[:,:,-1] = obs_p

        train_data[-1].append(obs_buffer.copy())

        if render:
            env.render()
            time.sleep(0.02)

        if len(train_data) >= steps_per_epoch:
            np.random.shuffle(train_data)
            train_obs = np.array([x[0] for x in train_data])
            train_acts = np.array([x[1] for x in train_data])
            train_rewards = np.array([x[2] for x in train_data])
            train_obs_ps = np.array([x[3] for x in train_data])

            nri, li, lf, lp, _, _, _ = sess.run([ri] + icm_losses + icm_objective,
                                    feed_dict={
                                        state_ph: train_obs,
                                        act_ph: train_acts,
                                        state_p_ph: train_obs_ps
                                    })
            ris += nri
            lis += li
            lfs += lf
            lps += lp
            train_iters += 1
            
            train_data = []

        if d:
            break
            
    if ep_reward >= 2300:
        best_runs.append(copy.deepcopy(train_data[-step-1:]))
        print(f'Run with {ep_reward} reward')

    all_rewards.append(ep_reward)

    if (episode + 1) % print_freq == 0:
        print(f'R_e: {np.mean(all_rewards[-print_freq:])}, R_i: {ris/train_iters}')
        print(f'L_i: {lis/train_iters}, L_f: {lfs/train_iters}, L_p: {lps/train_iters}')
        print()
        
        ris, lis, lfs, lps = 0, 0, 0, 0
        train_iters = 0

R_e: 355.0, R_i: 6.833077907562256
L_i: 2.052494525909424, L_f: 1.366615653038025, L_p: -0.6833078265190125

R_e: 378.125, R_i: 2.3280125061670938
L_i: 1.998560945192973, L_f: 0.4656024972597758, L_p: -0.2328012486298879

R_e: 403.5, R_i: 2.0331903100013733
L_i: 1.9929001927375793, L_f: 0.4066380709409714, L_p: -0.2033190354704857

R_e: 367.75, R_i: 1.6912469466527302
L_i: 1.9970301787058513, L_f: 0.3382493903239568, L_p: -0.1691246951619784

R_e: 356.5, R_i: 0.6690746744473776
L_i: 1.9943735202153523, L_f: 0.13381493836641312, L_p: -0.06690746918320656

R_e: 383.625, R_i: 0.5299195498228073
L_i: 1.9906585216522217, L_f: 0.1059839129447937, L_p: -0.05299195647239685

R_e: 309.875, R_i: 0.8818743427594503
L_i: 1.9897712469100952, L_f: 0.17637487252553305, L_p: -0.08818743626276652

R_e: 293.0, R_i: 0.8507293462753296
L_i: 1.990841031074524, L_f: 0.17014586925506592, L_p: -0.08507293462753296

R_e: 238.625, R_i: 0.4895363748073578
L_i: 1.9875903725624084, L_f: 0.09790727496147156, L_p: -

R_e: 96.125, R_i: 0.19583994646867117
L_i: 1.9114255905151367, L_f: 0.03916799028714498, L_p: -0.01958399514357249

R_e: 54.25, R_i: 0.17849722504615784
L_i: 1.9275845885276794, L_f: 0.03569944482296705, L_p: -0.017849722411483526

R_e: 69.5, R_i: 0.18882253766059875
L_i: 1.9070003827412922, L_f: 0.03776450827717781, L_p: -0.018882254138588905

R_e: 90.5, R_i: 0.23161991933981577
L_i: 1.9179505904515584, L_f: 0.0463239848613739, L_p: -0.02316199243068695

R_e: 27.0, R_i: 0.30705372989177704
L_i: 1.9009760022163391, L_f: 0.061410749331116676, L_p: -0.030705374665558338

R_e: 68.375, R_i: 0.22247820099194845
L_i: 1.8889899253845215, L_f: 0.044495640943447747, L_p: -0.022247820471723873

R_e: 34.625, R_i: 0.2618107423186302
L_i: 1.9060755372047424, L_f: 0.05236214958131313, L_p: -0.026181074790656567

R_e: 48.25, R_i: 0.27340514461199444
L_i: 1.8945330381393433, L_f: 0.05468102917075157, L_p: -0.027340514585375786

R_e: 24.125, R_i: 0.2508036096890767
L_i: 1.8934104045232136, L_f: 0.05016

R_e: 85.875, R_i: 0.17695044974486032
L_i: 1.8362496693929036, L_f: 0.03539009019732475, L_p: -0.017695045098662376

R_e: 12.5, R_i: 0.16771257917086282
L_i: 1.8425149122873943, L_f: 0.03354251633087794, L_p: -0.01677125816543897

R_e: 21.625, R_i: 0.1705331727862358
L_i: 1.8248887658119202, L_f: 0.03410663641989231, L_p: -0.017053318209946156

R_e: 117.125, R_i: 0.18171665569146475
L_i: 1.8239476283391316, L_f: 0.036343333000938095, L_p: -0.018171666500469048

R_e: 65.375, R_i: 0.1774641474088033
L_i: 1.8355693022410076, L_f: 0.03549283059934775, L_p: -0.017746415299673874

R_e: 75.625, R_i: 0.15632963180541992
L_i: 1.819531261920929, L_f: 0.0312659265473485, L_p: -0.01563296327367425

R_e: 47.875, R_i: 0.17386223872502646
L_i: 1.820167899131775, L_f: 0.03477244824171066, L_p: -0.01738622412085533

R_e: 89.125, R_i: 0.1805775687098503
L_i: 1.8328160643577576, L_f: 0.03611551411449909, L_p: -0.018057757057249546

R_e: 58.5, R_i: 0.1937691867351532
L_i: 1.821010669072469, L_f: 0.0387538

R_e: 54.875, R_i: 0.2263376166423162
L_i: 1.8183639446894329, L_f: 0.045267523576815925, L_p: -0.022633761788407963

R_e: 10.625, R_i: 0.22037900984287262
L_i: 1.793946087360382, L_f: 0.044075801968574524, L_p: -0.022037900984287262

R_e: 125.0, R_i: 0.18627222875754038
L_i: 1.8165851831436157, L_f: 0.03725444649656614, L_p: -0.01862722324828307

R_e: 22.875, R_i: 0.22306901216506958
L_i: 1.8160701990127563, L_f: 0.04461380218466123, L_p: -0.022306901092330616

R_e: 89.625, R_i: 0.24887185543775558
L_i: 1.802974820137024, L_f: 0.04977437108755112, L_p: -0.02488718554377556

R_e: 75.25, R_i: 0.23834454516569772
L_i: 1.8015353679656982, L_f: 0.04766891027490298, L_p: -0.02383445513745149

R_e: 174.875, R_i: 0.2662657896677653
L_i: 1.8111714919408162, L_f: 0.053253158926963806, L_p: -0.026626579463481903

R_e: 16.5, R_i: 0.21749545633792877
L_i: 1.8103886842727661, L_f: 0.043499091640114784, L_p: -0.021749545820057392

R_e: 218.75, R_i: 0.25005261103312176
L_i: 1.8093153238296509, L_f: 0.

R_e: 107.75, R_i: 0.42166246473789215
L_i: 1.7914717197418213, L_f: 0.08433249220252037, L_p: -0.042166246101260185

R_e: 67.25, R_i: 0.4369794925053914
L_i: 1.807880123456319, L_f: 0.08739590148131053, L_p: -0.043697950740655266

R_e: 139.5, R_i: 0.4320560892422994
L_i: 1.7892718315124512, L_f: 0.08641122033198674, L_p: -0.04320561016599337

R_e: 125.5, R_i: 0.4988468289375305
L_i: 1.813711166381836, L_f: 0.09976936876773834, L_p: -0.04988468438386917

R_e: 6.875, R_i: 0.5181688666343689
L_i: 1.7842648824055989, L_f: 0.10363377630710602, L_p: -0.05181688815355301

R_e: 71.5, R_i: 0.48216254512468976
L_i: 1.7816830078760784, L_f: 0.09643250952164333, L_p: -0.04821625476082166

R_e: 157.125, R_i: 0.5158801227807999
L_i: 1.8093048334121704, L_f: 0.10317602753639221, L_p: -0.051588013768196106

R_e: 8.5, R_i: 0.45498016476631165
L_i: 1.7812662919362385, L_f: 0.09099603444337845, L_p: -0.045498017221689224

R_e: 62.0, R_i: 0.5406549374262491
L_i: 1.7892343600591023, L_f: 0.1081309889753659

R_e: 54.625, R_i: 0.9492529332637787
L_i: 1.7421146631240845, L_f: 0.1898505911231041, L_p: -0.09492529556155205

R_e: 72.75, R_i: 1.1307154496510823
L_i: 1.7519832452138264, L_f: 0.2261430968840917, L_p: -0.11307154844204585

R_e: 53.25, R_i: 1.1119122902552288
L_i: 1.766250769297282, L_f: 0.22238246103127798, L_p: -0.11119123051563899

R_e: -15.0, R_i: 0.9412918984889984
L_i: 1.7760155200958252, L_f: 0.18825838714838028, L_p: -0.09412919357419014

R_e: 44.875, R_i: 1.1075345277786255
L_i: 1.7694456179936726, L_f: 0.22150690853595734, L_p: -0.11075345426797867

R_e: 103.5, R_i: 1.0735809405644734
L_i: 1.7856628894805908, L_f: 0.21471619109312692, L_p: -0.10735809554656346

R_e: 100.125, R_i: 1.0517348051071167
L_i: 1.7697064876556396, L_f: 0.21034697443246841, L_p: -0.10517348721623421

R_e: 77.125, R_i: 1.2883622248967488
L_i: 1.7798113028208415, L_f: 0.257672443985939, L_p: -0.1288362219929695

R_e: 60.0, R_i: 1.0348154306411743
L_i: 1.7638638416926067, L_f: 0.2069630871216456, L_p:

R_e: 69.875, R_i: 1.291034460067749
L_i: 1.7325323224067688, L_f: 0.25820690393447876, L_p: -0.12910345196723938

R_e: 165.0, R_i: 1.711108406384786
L_i: 1.7751785516738892, L_f: 0.3422216872374217, L_p: -0.17111084361871085

R_e: 22.0, R_i: 1.4085260033607483
L_i: 1.7478429675102234, L_f: 0.28170520067214966, L_p: -0.14085260033607483

R_e: 28.0, R_i: 1.3634497721989949
L_i: 1.7396112283070881, L_f: 0.27268996834754944, L_p: -0.13634498417377472

R_e: 190.25, R_i: 1.4768288532892864
L_i: 1.7654232184092205, L_f: 0.29536577065785724, L_p: -0.14768288532892862

R_e: 63.75, R_i: 1.3803044557571411
L_i: 1.750073492527008, L_f: 0.27606089413166046, L_p: -0.13803044706583023

R_e: 67.25, R_i: 1.4319603443145752
L_i: 1.7389024496078491, L_f: 0.28639208277066547, L_p: -0.14319604138533273

R_e: 39.75, R_i: 1.2997668584187825
L_i: 1.7429654598236084, L_f: 0.259953369696935, L_p: -0.1299766848484675

R_e: 80.125, R_i: 1.6675457954406738
L_i: 1.7574055790901184, L_f: 0.3335091769695282, L_p: -0.

R_e: 41.125, R_i: 2.2845245997111
L_i: 1.7378268241882324, L_f: 0.45690492788950604, L_p: -0.22845246394475302

R_e: 14.5, R_i: 1.891304850578308
L_i: 1.7152981162071228, L_f: 0.3782609701156616, L_p: -0.1891304850578308

R_e: 24.875, R_i: 2.38569974899292
L_i: 1.761512041091919, L_f: 0.477139949798584, L_p: -0.238569974899292

R_e: 154.625, R_i: 2.043576995531718
L_i: 1.7468955119450886, L_f: 0.4087154070536296, L_p: -0.2043577035268148

R_e: -24.75, R_i: 1.7194057703018188
L_i: 1.7101767659187317, L_f: 0.34388116002082825, L_p: -0.17194058001041412

R_e: -28.875, R_i: 1.6990043719609578
L_i: 1.7029457489649455, L_f: 0.3398008743921916, L_p: -0.1699004371960958

R_e: 140.125, R_i: 2.4375224908192954
L_i: 1.742671807607015, L_f: 0.4875045120716095, L_p: -0.24375225603580475

R_e: 150.0, R_i: 1.9249382019042969
L_i: 1.7448996305465698, L_f: 0.38498765230178833, L_p: -0.19249382615089417

R_e: 34.875, R_i: 2.1611825625101724
L_i: 1.7199348608652751, L_f: 0.43223653237024945, L_p: -0.2161

R_e: 90.25, R_i: 2.7646025021870932
L_i: 1.729021430015564, L_f: 0.5529205103715261, L_p: -0.27646025518576306

R_e: -4.25, R_i: 2.5449063777923584
L_i: 1.7111451625823975, L_f: 0.5089812874794006, L_p: -0.2544906437397003

R_e: -38.5, R_i: 2.5325711568196616
L_i: 1.7305335601170857, L_f: 0.5065142313639323, L_p: -0.2532571156819661

R_e: 31.875, R_i: 2.8991196950276694
L_i: 1.7264331579208374, L_f: 0.5798239509264628, L_p: -0.2899119754632314

R_e: 82.875, R_i: 3.0198878049850464
L_i: 1.7212113738059998, L_f: 0.6039775907993317, L_p: -0.30198879539966583

R_e: 26.25, R_i: 2.733422597249349
L_i: 1.7211395899454753, L_f: 0.5466845333576202, L_p: -0.2733422666788101

R_e: 21.875, R_i: 2.790093104044596
L_i: 1.708811640739441, L_f: 0.5580186247825623, L_p: -0.27900931239128113

R_e: 59.75, R_i: 2.9547770023345947
L_i: 1.7567792534828186, L_f: 0.5909554064273834, L_p: -0.2954777032136917

R_e: 99.0, R_i: 2.778289477030436
L_i: 1.722423831621806, L_f: 0.5556579033533732, L_p: -0.27782895167

R_e: 170.875, R_i: 3.380659262339274
L_i: 1.7364594141642253, L_f: 0.6761318643887838, L_p: -0.3380659321943919

R_e: 47.25, R_i: 3.1297534306844077
L_i: 1.7243624130884807, L_f: 0.6259506940841675, L_p: -0.31297534704208374

R_e: 91.25, R_i: 3.028611898422241
L_i: 1.7238348126411438, L_f: 0.6057223975658417, L_p: -0.30286119878292084

R_e: -50.125, R_i: 2.8979559739430747
L_i: 1.7162675857543945, L_f: 0.5795912047227224, L_p: -0.2897956023613612

R_e: 80.0, R_i: 3.7842690149943032
L_i: 1.72083576520284, L_f: 0.7568537990252177, L_p: -0.3784268995126088

R_e: 45.375, R_i: 2.7692177295684814
L_i: 1.6983221769332886, L_f: 0.5538435578346252, L_p: -0.2769217789173126

R_e: -42.375, R_i: 2.7045787970225015
L_i: 1.6835299332936604, L_f: 0.5409157673517863, L_p: -0.27045788367589313

R_e: 45.125, R_i: 3.022427956263224
L_i: 1.7262410720189412, L_f: 0.6044856111208597, L_p: -0.3022428055604299

R_e: 28.875, R_i: 2.6675944328308105
L_i: 1.7183989882469177, L_f: 0.5335188955068588, L_p: -0.2667

In [None]:
run_num = 2
print(len(best_runs[run_num]))
print(sum([x[2] for x in best_runs[run_num]]))

In [None]:
obs = env.reset()
acts = [x[1] for x in best_runs[run_num]]
tr = 0
for i in range(len(acts)):
    for j in range(act_skip):
        _, r, d, _ = env.step(acts[i])
        tr += r
        
        env.render()
        time.sleep(0.02)
        if d:
            break
            
    
print(tr)

In [None]:
tr

In [None]:
best_runs[0][10]

In [None]:
print(np.mean(all_rewards))
print(np.mean(all_rewards[:int(len(all_rewards)/2)]))
print(np.mean(all_rewards[int(len(all_rewards)/2):]))
print(np.std(all_rewards))
print(np.std(all_rewards[:int(len(all_rewards)/2)]))
print(np.std(all_rewards[int(len(all_rewards)/2):]))

In [None]:
print(episode)

In [None]:
saver = tf.train.Saver()
saver.save(sess, '3670_episode.model')