In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython import display

import random

from skimage.transform import resize
import numpy as np
import tensorflow as tf
import gym

# Gym environment used both for collecting training data and for the final rollout.
env_name = "Skiing-v0"
env = gym.make(env_name)

In [None]:
def render(x, step=0):
    """Draw one image in the notebook, replacing whatever was shown before.

    Args:
        x: image array handed to `imshow` (drawn with a grey colormap).
        step: integer shown in the figure title as "step: <n>".
    """
    display.clear_output(wait=True)

    _, ax = plt.subplots(figsize=(6, 6))
    ax.set_axis_off()
    ax.set_title("step: %d" % step)
    ax.imshow(x, cmap=plt.cm.gray)
    plt.pause(0.001)   # short pause so the backend actually draws the figure
  
def pre_processing(observe):
    """Crop a fixed window from a raw frame and rescale it to 64x64.

    Keeps rows 54:-52 and columns 8:152, then calls skimage's `resize`
    with reflect padding and anti-aliasing enabled.
    """
    cropped = observe[54:-52, 8:152]
    return resize(cropped, (64, 64), mode='reflect', anti_aliasing=True)
  
def batch(batch_size=32, obs_data=None, act_data=None):
  """Sample a random mini-batch of (observation, action) pairs without replacement.

  Args:
    batch_size: requested number of samples. It is clipped to the number of
      stored samples, so an episode that yields little data does not raise.
    obs_data: observations to sample from; defaults to the global `obss`.
    act_data: labels aligned row-for-row with `obs_data`; defaults to the
      global `acts`.

  Returns:
    (b_o, b_a): the sampled observations and their matching actions.
  """
  if obs_data is None:
    obs_data = obss
  if act_data is None:
    act_data = acts
  n_data = len(obs_data)
  # np.random.choice(..., replace=False) raises ValueError when asked for more
  # samples than exist, so never request more than n_data.
  ids = np.random.choice(n_data, min(batch_size, n_data), replace=False)
  return obs_data[ids], act_data[ids]

def model(x, name='policy'):
  """Convolutional policy network producing 3 action logits.

  Three ReLU conv layers (16 filters 8x8 stride 4, 32 filters 4x4 stride 2,
  64 filters 3x3 stride 1), a 512-unit ReLU dense layer, then a linear 3-unit
  output. Variables are created in the `name` variable scope with AUTO_REUSE,
  so calls with the same name reuse the same weights.
  """
  conv_specs = ((16, 8, 4), (32, 4, 2), (64, 3, 1))  # (filters, kernel, stride)
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    h = x
    for filters, kernel, stride in conv_specs:
      h = tf.layers.conv2d(h, filters, kernel, strides=stride, activation=tf.nn.relu)
    h = tf.layers.flatten(h)
    h = tf.layers.dense(h, 512, activation=tf.nn.relu)
    logits = tf.layers.dense(h, 3)
  return logits

In [None]:
# Objects can be distinguished by their exact RGB values:
# Player: [214, 92, 92]
# Flags (blue): [66, 72, 200]
# Flags (red): [184, 50, 50]

def get_pos_player(observe):
  """Return the (row, col) centroid of pixels whose colour is exactly [214, 92, 92].

  All three channels must match. With no matching pixel both coordinates are
  NaN (mean of an empty array).
  """
  is_player = np.all(observe == [214, 92, 92], axis=-1)
  coords = np.nonzero(is_player)
  return coords[0].mean(), coords[1].mean()

def get_pos_flags(observe):
  """Return the (row, col) centroid of the flag pixels used as the steering target.

  Red flag pixels ([184, 50, 50]) anywhere in the frame take priority. If there
  are none, the centroid of blue flag pixels ([66, 72, 200]) is used, ignoring
  the bottom 60 rows. Coordinates are NaN when nothing matches.
  """
  red = np.all(observe == [184, 50, 50], axis=-1)
  if red.any():
    coords = np.nonzero(red)
  else:
    coords = np.nonzero(np.all(observe[:-60] == [66, 72, 200], axis=-1))
  return coords[0].mean(), coords[1].mean()

def get_speed(observe, observe_old):
  """Estimate the vertical scroll (in rows) between two consecutive frames.

  The current frame's crop (rows 54:-52, cols 8:152) is compared with the
  previous frame's crop offset by k rows (rows 54+k:-52+k) for k = 0..6. The k
  with the smallest sum of absolute pixel differences is returned; on ties the
  smallest such k wins.

  Args:
    observe: current frame.
    observe_old: previous frame, same shape as `observe`.

  Returns:
    int in [0, 6]: the best-matching row offset.
  """
  # Frames from the env are presumably uint8. Subtracting uint8 arrays wraps
  # modulo 256, so |a - b| would be wrong whenever a < b. Cast first
  # (float64 represents every uint8 value and every float32 value exactly).
  cur = np.asarray(observe[54:-52, 8:152], dtype=np.float64)
  prev = np.asarray(observe_old, dtype=np.float64)
  min_val = np.inf
  min_idx = 0
  for k in range(0, 7):
    val = np.sum(np.abs(cur - prev[54+k:-52+k, 8:152]))
    if min_val > val:
      min_idx = k
      min_val = val
  return min_idx

### Hyperparameters & data containers

In [None]:
n_cont = 4  # number of consecutive preprocessed frames stacked per observation
# Data buffers filled during training: stacked frames of shape
# (64, 64, 3 * n_cont) and the teacher's action labels (0, 1 or 2). Labels are
# class indices, so keep them integral instead of np.empty's float64 default;
# concatenating Python ints onto an int64 array stays int64.
obss = np.empty((0, 64, 64, 3*n_cont))
acts = np.empty((0,), dtype=np.int64)

### Models

In [None]:
# Input: a batch of stacked frames, (batch, 64, 64, 3 * n_cont).
x = tf.placeholder(tf.float32, (None, 64, 64, 3*n_cont))
# Target: one integer action label per example. Note the tuple `(None,)`:
# a bare `(None)` is just `None`, which declares a placeholder of unknown rank.
y = tf.placeholder(tf.int32, (None,))
y_onehot = tf.one_hot(y, 3)

y_hat = model(x)  # logits over the 3 actions
p_hat = tf.nn.softmax(y_hat)  # action probabilities, sampled from during rollouts

In [None]:
# Mean cross-entropy between the teacher's one-hot labels and the policy logits.
xent = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_onehot, logits=y_hat)
loss = tf.reduce_mean(xent)
# Adam, learning rate 1e-4; `opt` is the training op run once per mini-batch.
optimizer = tf.train.AdamOptimizer(1e-4)
opt = optimizer.minimize(loss)

In [None]:
sess = tf.InteractiveSession()  # an InteractiveSession installs itself as the default session
tf.global_variables_initializer().run()  # initialise every global variable in the graph

### Train

In [None]:
for episode in range(1000):
  observe = env.reset()
  step = 0
  cnt = 0
  done = False
  r_a, c_a = get_pos_player(observe)
  r_f, c_f = get_pos_flags(observe)
  r_a_old, c_a_old = r_a, c_a
  observe_old = observe
  history = np.concatenate([pre_processing(observe)] * n_cont, -1)

  outs_o = []
  outs_a = []
  while not done:
    step += 1

    # TEACHER
    v_f = np.arctan2(r_f - r_a, c_f - c_a) # direction from player to target
    spd = get_speed(observe, observe_old)
    v_a = np.arctan2(spd, c_a - c_a_old) # speed vector of the player
    r_a_old, c_a_old = r_a, c_a
    observe_old = observe
    if spd == 0 and (c_a - c_a_old) == 0:
      # no movement
      cnt += 1
      act_t = np.random.choice(3, 1)[0]
    else:
      cnt = 0
      if v_f - v_a < -0.1:
        act_t = 1
      elif v_f - v_a > 0.1:
        act_t = 2
      else:
        act_t = 0

    if cnt > 10:
      print('no movement!')
      break
    
    outs_o.append(history)
    outs_a.append(act_t)
    
    p = sess.run(p_hat, feed_dict={x: [history]})[0]
    act = np.random.choice(3, 1, p=p)[0]
    observe, reward, done, info = env.step(act)
    history = np.concatenate([pre_processing(observe), history[:,:,3:]], -1)
    r_a, c_a = get_pos_player(observe)
    r_f, c_f = get_pos_flags(observe)
  
  # append data & limit data size
  obss = np.concatenate([obss, outs_o], 0)
  acts = np.concatenate([acts, outs_a], 0)
  if len(obss) > 5000:
    obss = obss[-5000:]
    acts = acts[-5000:]

  for i in range(500):
    d_x, d_y = batch()
    ret = sess.run([opt, loss], feed_dict={x: d_x, y: d_y})
    print('%5d %5d' % (episode, i), ret[1], end='\r')
  print()

### Simulation

In [None]:
observe = env.reset()
done = False
history = np.concatenate([pre_processing(observe)] * n_cont, -1)

tmp_obs = [observe]
while not done:
  p = sess.run(p_hat, feed_dict={x: [history]})[0]
  act = np.random.choice(3, 1, p=p)[0]
  observe, reward, done, info = env.step(act)
  history = np.concatenate([pre_processing(observe), history[:,:,3:]], -1)
  tmp_obs.append(observe)

# Replay the recorded episode, drawing every third frame cropped for display.
for i, o in enumerate(tmp_obs):
  if i % 3 != 0:
    continue
  render(o[:, :, :3][28:-52, 8:152], i)
# `i` and `o` still refer to the last recorded frame; draw it so the final
# state is always shown, even when its index is not a multiple of 3.
render(o[:, :, :3][28:-52, 8:152], i)