In [1]:
import collections
import random
import datetime

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import gym_super_mario_bros

from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros import actions

%load_ext tensorboard

In [2]:
state_shape = (56, 56, 1)
learning_rate = 0.1
discount_factor = 0.99999
epsilon = 0.8
eps_decay = 0.999
update_target_network_interval = 5
action_set = gym_super_mario_bros.actions.RIGHT_ONLY

In [3]:
logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1, write_grads=True)
saving_callback = tf.keras.callbacks.ModelCheckpoint('/tmp/mario0', period=500)
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=np.sqrt(0.1),
                                                  cooldown=50, patience=300,
                                                  min_lr=0.000025)
file_writer = tf.summary.create_file_writer(logdir + "/metrics")
file_writer.set_as_default()



In [4]:
def create_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(32, (8, 8), strides=(4, 4), activation='relu',
                                     input_shape=state_shape))
    model.add(tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), activation='relu',
                                     kernel_initializer='he_normal'))
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu',
                                     kernel_initializer='he_normal'))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(512, activation='relu', use_bias=False))
    model.add(tf.keras.layers.Dense(5, use_bias=False))
    model.compile(optimizer = tf.keras.optimizers.Adam(lr = learning_rate),
                  loss = 'mse', metrics=['mse'])
    return model
        

In [5]:
model = create_model()
model.summary()
target_model = create_model()
target_model.set_weights(model.get_weights())

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 13, 13, 32)        2080      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 5, 5, 64)          32832     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               294912    
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 2560      
Total params: 369,312
Trainable params: 369,312
Non-trainable params: 0
__________________________________________________

In [6]:
def update_target_network(
    episode, update_target_network_interval, main_network, target_network):
    if ((episode+1) % update_target_network_interval) == 0:
        target_network.set_weights(main_network.get_weights())
    return target_network

In [7]:
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, action_set)

In [8]:
def select_action(epsilon, state):
    if (np.random.random() <= epsilon):
        return np.random.choice(len(action_set))
    else:
        return np.argmax(model.predict(state))

In [9]:
def greyscale(state):
    return tf.image.rgb_to_grayscale([state])[0]

def resize(state):
    return tf.compat.v1.image.resize_images([state], (state_shape[0], state_shape[1]))[0]

def downsample(state):
    state = resize(state)
    state = greyscale(state)
    state = (state - 128) / 128  # state is in the [0,255] range.
    return tf.cast(tf.reshape(state, (1,) + state_shape), tf.dtypes.bfloat16)

In [10]:
def compute_bellman_target(discount_factor, reward, model, state_next, done):
    if done:
        return reward
    return (reward + discount_factor * np.max(model.predict(state_next)))

In [11]:
initial_epoch = 0

def sample_from_replay_buffer_and_train_model(replay_buffer, batch_size, 
                                              model, target_model, discount_factor):
    global initial_epoch
    if(len(replay_buffer) >= batch_size):
        batch = random.sample(replay_buffer, batch_size)
        
        states = [item[0] for item in batch]
        actions = [item[1] for item in batch]
        rewards = [item[2]/15.0 for item in batch]  # rewards are in the [-15,15] range.
        states_next = [item[3] for item in batch]  
        dones = [item[4] for item in batch]
        
        stacked_states = np.empty(shape=(0,) + state_shape)
        for state in states:
            stacked_states = tf.concat((stacked_states, state), axis=0)

        # TODO:  I changed this from target network.  Which way is right?
        target_q_values = model.predict(stacked_states)
        print ('target_q_values before = {}'.format(target_q_values))
        for i in range(len(states)):
            target_q_values[i, actions[i]] = compute_bellman_target(
                discount_factor, rewards[i], target_model, states_next[i], dones[i])
        print ('target_q_values after = {}'.format(target_q_values))
        
        def summarize_q_values(epoch, logs): 
            if epoch % 10 != 0: return
            tf.summary.scalar('actions', data=tf.reduce_mean(actions), step=epoch)
            tf.summary.scalar('rewards', data=tf.reduce_mean(rewards), step=epoch)
            bt = [target_q_values[i, actions[i]] for i in range(len(states))]
            tf.summary.scalar('bellman_target', data=tf.reduce_mean(bt), step=epoch)
            delta = [rewards[i] - bt[i] for i in range(len(states))]
            tf.summary.scalar('delta-reward-bt', data=tf.reduce_mean(delta), step=epoch)
            tf.summary.scalar('target_q_values', data=tf.reduce_mean(target_q_values), step=epoch)
            tf.summary.scalar('max_rewards', data=tf.reduce_max(rewards), step=epoch)
            tf.summary.scalar('max_q', data=tf.reduce_max(target_q_values), step=epoch)            
            tf.summary.scalar('min_q', data=tf.reduce_min(target_q_values), step=epoch)            
            zeros = [np.count_nonzero(t==0) for t in target_q_values]
            tf.summary.scalar('zeros_q', data=tf.reduce_mean(zeros), step=epoch)                        
        summarize = tf.keras.callbacks.LambdaCallback(on_epoch_begin=summarize_q_values)

        model.fit(stacked_states, target_q_values,
                  epochs=initial_epoch + 1, initial_epoch=initial_epoch,
                  verbose=False, 
                  callbacks=[tensorboard_callback, saving_callback, lr_reducer,
                             summarize])
        initial_epoch += 1
    return model

In [12]:
batch_size = 16
replay_buffer_size = 2000
replay_buffer = collections.deque(maxlen = replay_buffer_size)

In [13]:
%tensorboard --logdir logs/scalars

Reusing TensorBoard on port 6006 (pid 1064), started 19:59:36 ago. (Use '!kill 1064' to kill it.)

In [14]:
for episode in range(100):
    state = downsample(env.reset())
    done = False
    
    while not done:
        action = select_action(epsilon, state)
        state_next, reward, done, info = env.step(action)
        state_next = downsample(state_next)
        replay_buffer.append((state, action, reward, state_next, done))
        model = sample_from_replay_buffer_and_train_model(
            replay_buffer, batch_size, model, target_model, discount_factor)
        state = state_next 
        
    target_network = update_target_network(
            episode, update_target_network_interval, model, target_model)
    epsilon = max(epsilon * eps_decay, 0.01)

target_q_values before = [[-0.00935046 -0.02743847  0.04431012 -0.01514357 -0.0198843 ]
 [-0.00775429 -0.02800521  0.04617035 -0.01531694 -0.0202534 ]
 [-0.00935046 -0.02743847  0.04431012 -0.01514357 -0.0198843 ]
 [-0.00839841 -0.02651665  0.04540986 -0.01507182 -0.01930286]
 [-0.01126588 -0.0268988   0.0444677  -0.01563934 -0.01898715]
 [-0.01019889 -0.02828924  0.0458653  -0.01433276 -0.01940974]
 [-0.01160986 -0.0273679   0.04570829 -0.01254732 -0.0204813 ]
 [-0.01032206 -0.02668443  0.04489542 -0.0149122  -0.01750922]
 [-0.00924347 -0.02904942  0.04580518 -0.01551813 -0.01932049]
 [-0.00935046 -0.02743847  0.04431012 -0.01514357 -0.0198843 ]
 [-0.00867484 -0.02997219  0.04480771 -0.01728663 -0.02016675]
 [-0.01132084 -0.02701094  0.04461735 -0.01171903 -0.02028387]
 [-0.00829883 -0.02955672  0.04340792 -0.01590681 -0.02166599]
 [-0.01052438 -0.02820709  0.04463421 -0.02148392 -0.01517634]
 [-0.00805387 -0.02886324  0.04602066 -0.01433504 -0.01872844]
 [-0.00956083 -0.02805928  0.0

target_q_values after = [[-4.03503876e+01 -4.47492485e+01 -4.55176468e+01  1.10074148e-01
  -4.29698944e+01]
 [-4.89532013e+01 -5.42434464e+01 -5.52529755e+01 -6.72124634e+01
   4.43096757e-02]
 [-5.19744682e+01 -5.75904579e+01 -5.87007980e+01 -7.12244949e+01
   1.81419283e-01]
 [-5.45868607e+01 -6.04639053e+01  1.12076074e-01 -7.46209717e+01
  -5.80926132e+01]
 [-5.25021629e+01 -5.81780548e+01 -5.92994499e+01  4.64029647e-02
  -5.58818283e+01]
 [-5.48463440e+01 -6.07534866e+01 -6.19462051e+01 -7.49782410e+01
   1.11192152e-01]
 [-3.93083038e+01 -4.35911179e+01 -4.43390160e+01  1.11473925e-01
  -4.18277359e+01]
 [-4.12107506e+01  4.58047278e-02 -4.64988556e+01 -5.69155579e+01
  -4.38778839e+01]
 [-4.89532013e+01 -5.42434464e+01  4.58648689e-02 -6.72124634e+01
  -5.20961990e+01]
 [ 4.62440662e-02 -4.56119461e+01 -4.63841553e+01 -5.68683357e+01
  -4.37946396e+01]
 [-3.90104866e+01 -4.32558708e+01  4.57078293e-02 -5.40267906e+01
  -4.15335426e+01]
 [-5.55850296e+01  1.13353118e-01 -6.2780

target_q_values after = [[-11.623469    -9.92231      0.04616991 -15.408962    -6.9295044 ]
 [-12.031958   -10.280169    -7.043397     0.17702784  -7.29231   ]
 [-11.916305   -10.179019     0.11063781 -15.589899    -7.1931877 ]
 [-12.139118   -10.404651    -7.183495     0.04640296  -7.4166694 ]
 [-12.079446   -10.326459    -7.0993648  -15.698741     0.04430968]
 [-12.079446   -10.326459    -7.0993648  -15.698741     0.04430968]
 [-12.079446   -10.326459     0.04586487 -15.698741    -7.34064   ]
 [-11.799086     0.11156163  -6.8723726  -15.519901    -7.1096315 ]
 [-12.467468   -10.6415205   -7.422065   -15.980199     0.04463373]
 [-12.253259   -10.436002    -7.206254     0.11147393  -7.431981  ]
 [-11.816044   -10.126965    -6.909742   -15.508413     0.11119215]
 [-11.655005    -9.965691    -6.7462735    0.1131473   -6.979913  ]
 [-12.048324     0.11124872  -7.091549   -15.667288    -7.3340206 ]
 [-12.685892   -10.842113    -7.6080675    0.11007415  -7.833167  ]
 [-12.245821     0.04580

target_q_values after = [[0.04624407 2.2041826  1.7976813  3.3839197  1.5419132 ]
 [1.8333374  2.4036543  1.9738767  0.1815889  1.7130678 ]
 [1.5324476  2.0888722  0.11063781 3.144921   1.4338312 ]
 [1.5034306  2.0773582  1.6400514  3.193657   0.17852665]
 [1.1770921  1.6925313  1.267345   2.1956985  0.04463373]
 [1.6962383  0.11335312 1.8332955  3.4400756  1.5865229 ]
 [1.6100541  0.11264787 1.748846   3.3233302  1.5082351 ]
 [1.5426776  2.096815   1.6812139  3.16455    0.18186446]
 [1.8223253  2.3923526  1.9620215  3.6845086  0.04602021]
 [1.972575   0.11156163 2.1104195  3.9561865  1.8448329 ]
 [1.6452322  2.1995895  1.7901419  0.11007415 1.5344502 ]
 [1.4935365  0.11124872 1.6242454  3.0173142  1.3970287 ]
 [1.5426776  2.096815   0.04586487 3.16455    1.4424067 ]
 [1.5286422  2.083117   1.6655821  3.1315875  0.11119215]
 [1.3591617  0.04580473 1.4758301  2.7435048  1.2887106 ]
 [2.058092   2.6391578  0.04616991 4.097841   1.9235617 ]]
target_q_values before = [[0.13069576 0.5555632

target_q_values after = [[0.         0.         0.04570783 0.         0.        ]
 [0.         0.         0.         0.         0.04430968]
 [0.         0.         0.11207607 0.         0.        ]
 [0.         0.         0.         0.         0.04602021]
 [0.         0.         0.         0.         0.18141928]
 [0.         0.         0.         0.         0.11048228]
 [0.         0.         0.18007122 0.         0.        ]
 [0.         0.         0.         0.11007415 0.        ]
 [0.04624407 0.         0.         0.         0.        ]
 [0.         0.         0.04586487 0.         0.        ]
 [0.18311423 0.         0.         0.         0.        ]
 [0.         0.11264787 0.         0.         0.        ]
 [0.         0.         0.         0.         0.18186446]
 [0.         0.         0.         0.11128356 0.        ]
 [0.1780235  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.17852665]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values before = [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
target_q_values after = [[0.         0.11124872 0.         0.         0.        ]
 [0.11351626 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11128356 0.        ]
 [0.         0.         0.         0.04640296 0.        ]
 [0.         0.         0.11661106 0.         0.        ]
 [0.         0.11335312 0.         0.         0.        ]
 [0.1780235  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.18186446]
 [0.         0.04580473 0.         0.         0.        ]
 [0.         0.         0.         0.         0.18141928]
 [0.         0.         0.18215801 0.         0.        ]
 [0.         0.         

target_q_values after = [[0.         0.         0.         0.         0.17852665]
 [0.         0.11264787 0.         0.         0.        ]
 [0.         0.         0.         0.17702784 0.        ]
 [0.         0.         0.         0.         0.18141928]
 [0.         0.         0.04586487 0.         0.        ]
 [0.         0.         0.         0.04640296 0.        ]
 [0.         0.11335312 0.         0.         0.        ]
 [0.14638093 0.         0.         0.         0.        ]
 [0.10882609 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.04430968]
 [0.18311423 0.         0.         0.         0.        ]
 [0.         0.         0.04570783 0.         0.        ]
 [0.         0.         0.         0.         0.04463373]
 [0.11480085 0.         0.         0.         0.        ]
 [0.         0.         0.18007122 0.         0.        ]
 [0.         0.         0.11207607 0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values before = [[ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.         0.       ]
 [-2.1016154 -2.2252588 -2.5290742 -2.0563543 -2.1460028]
 [ 0.         0.         0.         0.         0.       ]]
target_q_values after = [[ 0.11480085  0.     

target_q_values after = [[0.         0.         0.         0.04640296 0.        ]
 [0.10882609 0.         0.         0.         0.        ]
 [0.18311423 0.         0.         0.         0.        ]
 [0.         0.11264787 0.         0.         0.        ]
 [0.21729733 0.         0.         0.         0.        ]
 [0.         0.         0.11661106 0.         0.        ]
 [0.         0.04580473 0.         0.         0.        ]
 [0.         0.         0.18716332 0.         0.        ]
 [0.         0.         0.         0.         0.18186446]
 [0.         0.         0.         0.         0.04430968]
 [0.         0.         0.         0.15468536 0.        ]
 [0.18178523 0.         0.         0.         0.        ]
 [0.11480085 0.         0.         0.         0.        ]
 [0.         0.         0.11207607 0.         0.        ]
 [0.         0.         0.         0.15976681 0.        ]
 [0.         0.         0.         0.         0.04430968]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.         0.11128356 0.        ]
 [0.         0.         0.18716332 0.         0.        ]
 [0.11351626 0.         0.         0.         0.        ]
 [0.11480085 0.         0.         0.         0.        ]
 [0.         0.         0.         0.14161351 0.        ]
 [0.10882609 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11007415 0.        ]
 [0.         0.         0.15866414 0.         0.        ]
 [0.         0.         0.         0.         0.18141928]
 [0.         0.         0.         0.1131473  0.        ]
 [0.         0.         0.         0.         0.17852665]
 [0.         0.         0.         0.22494504 0.        ]
 [0.         0.         0.         0.         0.11048228]
 [0.         0.11264787 0.         0.         0.        ]
 [0.         0.         0.         0.21205632 0.        ]
 [0.         0.         0.16049384 0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.14233696 0.         0.        ]
 [0.15574124 0.         0.         0.         0.        ]
 [0.11351626 0.         0.         0.         0.        ]
 [0.         0.         0.         0.17702784 0.        ]
 [0.         0.         0.         0.         0.04430968]
 [0.         0.         0.         0.22180687 0.        ]
 [0.16564736 0.         0.         0.         0.        ]
 [0.         0.         0.22155444 0.         0.        ]
 [0.         0.11264787 0.         0.         0.        ]
 [0.14606132 0.         0.         0.         0.        ]
 [0.         0.         0.         0.21205632 0.        ]
 [0.         0.11156163 0.         0.         0.        ]
 [0.         0.         0.04616991 0.         0.        ]
 [0.14638093 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.11048228]
 [0.         0.         0.         0.         0.18329218]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.11063781 0.         0.        ]
 [0.         0.         0.         0.04640296 0.        ]
 [0.         0.04580473 0.         0.         0.        ]
 [0.2099396  0.         0.         0.         0.        ]
 [0.         0.         0.22155444 0.         0.        ]
 [0.         0.         0.         0.         0.04463373]
 [0.04624407 0.         0.         0.         0.        ]
 [0.         0.         0.         0.22180687 0.        ]
 [0.         0.         0.         0.17702784 0.        ]
 [0.         0.14526322 0.         0.         0.        ]
 [0.         0.         0.         0.11147393 0.        ]
 [0.         0.         0.         0.15976681 0.        ]
 [0.         0.         0.14233696 0.         0.        ]
 [0.14931117 0.         0.         0.         0.        ]
 [0.14781536 0.         0.         0.         0.        ]
 [0.         0.11124872 0.         0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.04586487 0.         0.        ]
 [0.         0.16036437 0.         0.         0.        ]
 [0.         0.         0.04616991 0.         0.        ]
 [0.         0.         0.11063781 0.         0.        ]
 [0.14606132 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.23012371]
 [0.         0.         0.         0.11007415 0.        ]
 [0.         0.         0.         0.         0.17852665]
 [0.11135925 0.         0.         0.         0.        ]
 [0.04624407 0.         0.         0.         0.        ]
 [0.         0.         0.11207607 0.         0.        ]
 [0.2141151  0.         0.         0.         0.        ]
 [0.         0.22169057 0.         0.         0.        ]
 [0.21729733 0.         0.         0.         0.        ]
 [0.         0.         0.         0.15976681 0.        ]
 [0.11351626 0.         0.         0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.         0.11113393 0.        ]
 [0.22577535 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11147393 0.        ]
 [0.         0.11124872 0.         0.         0.        ]
 [0.14638093 0.         0.         0.         0.        ]
 [0.18178523 0.         0.         0.         0.        ]
 [0.         0.         0.         0.1815889  0.        ]
 [0.16564736 0.         0.         0.         0.        ]
 [0.         0.         0.         0.21906103 0.        ]
 [0.         0.         0.         0.         0.22653785]
 [0.         0.         0.         0.         0.04430968]
 [0.         0.16036437 0.         0.         0.        ]
 [0.         0.11335312 0.         0.         0.        ]
 [0.         0.         0.         0.         0.23012371]
 [0.         0.         0.         0.1131473  0.        ]
 [0.21584493 0.         0.         0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.         0.         0.18186446]
 [0.         0.         0.2195433  0.         0.        ]
 [0.         0.         0.         0.         0.20957875]
 [0.         0.         0.         0.         0.18141928]
 [0.         0.         0.         0.         0.21707672]
 [0.         0.         0.         0.22180687 0.        ]
 [0.22577535 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11007415 0.        ]
 [0.         0.         0.         0.21473628 0.        ]
 [0.         0.04580473 0.         0.         0.        ]
 [0.         0.         0.         0.2266287  0.        ]
 [0.         0.         0.         0.21906103 0.        ]
 [0.         0.         0.         0.         0.21615306]
 [0.         0.         0.         0.         0.21240209]
 [0.         0.         0.         0.         0.04430968]
 [0.21455501 0.         0.         0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.         0.         0.         0.         0.04430968]
 [0.         0.         0.         0.         0.11048228]
 [0.         0.         0.         0.21205632 0.        ]
 [0.22190472 0.         0.         0.         0.        ]
 [0.16564736 0.         0.         0.         0.        ]
 [0.20349473 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11113393 0.        ]
 [0.         0.         0.04570783 0.         0.        ]
 [0.14638093 0.         0.         0.         0.        ]
 [0.         0.         0.         0.15976681 0.        ]
 [0.         0.         0.         0.         0.04463373]
 [0.         0.11124872 0.         0.         0.        ]
 [0.         0.         0.15866414 0.         0.        ]
 [0.         0.21338029 0.         0.         0.        ]
 [0.18311423 0.         0.         0.         0.        ]
 [0.13913111 0.         0.         0.         0.        ]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values after = [[0.14330009 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11113393 0.        ]
 [0.22374925 0.         0.         0.         0.        ]
 [0.         0.         0.         0.14161351 0.        ]
 [0.2141151  0.         0.         0.         0.        ]
 [0.         0.         0.         0.2266287  0.        ]
 [0.         0.         0.         0.         0.15749815]
 [0.18311423 0.         0.         0.         0.        ]
 [0.         0.         0.         0.1131473  0.        ]
 [0.         0.         0.04616991 0.         0.        ]
 [0.         0.         0.18007122 0.         0.        ]
 [0.         0.         0.         0.17702784 0.        ]
 [0.         0.         0.16049384 0.         0.        ]
 [0.         0.         0.21682732 0.         0.        ]
 [0.         0.         0.21574102 0.         0.        ]
 [0.         0.         0.         0.         0.04430968]]
target_q_values before = [[0. 0. 0. 0. 0.]
 [0.

target_q_values before = [[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
target_q_values after = [[0.21522032 0.         0.         0.         0.        ]
 [0.11480085 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.20957875]
 [0.         0.         0.2339305  0.         0.        ]
 [0.22634089 0.         0.         0.         0.        ]
 [0.         0.         0.         0.21205632 0.        ]
 [0.13913111 0.         0.         0.         0.        ]
 [0.         0.14526322 0.         0.         0.        ]
 [0.         0.         0.18716332 0.         0.        ]
 [0.20705044 0.         0.         0.         0.        ]
 [0.14330009 0.         0.         0.         0.        ]
 [0.20112617 0.         

target_q_values after = [[0.         0.11124872 0.         0.         0.        ]
 [0.14606132 0.         0.         0.         0.        ]
 [0.         0.         0.         0.11128356 0.        ]
 [0.20705044 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.04430968]
 [0.         0.         0.04616991 0.         0.        ]
 [0.22171001 0.         0.         0.         0.        ]
 [0.21584493 0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.20490782]
 [0.22037333 0.         0.         0.         0.        ]
 [0.         0.04476112 0.         0.         0.        ]
 [0.         0.         0.18215801 0.         0.        ]
 [0.         0.         0.         0.         0.21240209]
 [0.         0.         0.         0.         0.20957875]
 [0.         0.23188038 0.         0.         0.        ]
 [0.         0.         0.         0.22494504 0.        ]]


KeyboardInterrupt: 