In [None]:
from pathlib import Path

from gym import spaces
from stable_baselines import PPO2
from stable_baselines.common.policies import CnnLnLstmPolicy, LstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from tesse.msgs import *  # wildcard import: explicit imports below stay after it so they take precedence

import time

import numpy as np
from tesse_gym import get_network_config
from tesse_gym.tasks.goseek import GoSeekFullPerception, decode_observations

# Configuration

#### Set sim path

In [None]:
# Path to the simulator build; edit this to match the local checkout.
filename = Path("../../goseek-challenge/simulator/goseek-v0.1.4.x86_64")
# Fail early and report the path that was checked when the simulator is missing.
assert filename.exists(), f"Must set a valid path! '{filename}' does not exist"

#### Set environment parameters


__Note__: To minimize training time during initial use, we've set `total_timesteps` and `n_environments` to 1e5 and 2 respectively. Setting `total_timesteps` to 3e6 and `n_environments` to 4 should produce an agent that approximates our baseline. 

In [None]:
# --- Training budget ---
total_timesteps = 100_000  # number of training timesteps
n_environments = 6  # number of environments to train over

# --- Scene and episode settings ---
scene_id = [1, 2, 3, 4, 5, 5]  # scene for each environment, indexed by environment rank
episode_length = 400  # passed as `episode_length` to every environment
n_targets = 30  # number of targets spawned in each scene
target_found_reward = 3  # reward per found target



def make_unity_env(filename, num_env):
    """ Create `num_env` GoSeek environments wrapped in a `SubprocVecEnv`.

    Environment `rank` takes its scene from the notebook-level `scene_id` list.
    Scenes are reused cyclically when `num_env` exceeds the number of listed
    scenes, so asking for more environments than scenes does not raise an
    `IndexError`.

    Args:
        filename (str or Path): Path to the simulator; converted with `str` before use.
        num_env (int): Number of environments to create. Must be at least 1.

    Returns:
        SubprocVecEnv: Vectorized environment built from `num_env`
            `GoSeekFullPerception` constructors.

    Raises:
        ValueError: If `num_env` is less than 1 or `scene_id` is empty.
    """
    if num_env < 1:
        raise ValueError(f"num_env must be at least 1, got {num_env}")
    if len(scene_id) == 0:
        raise ValueError("scene_id must list at least one scene")

    def make_env(rank):

        def _thunk():
            env = GoSeekFullPerception(
                str(filename),
                network_config=get_network_config(worker_id=rank),
                n_targets=n_targets,
                episode_length=episode_length,
                # Wrap around so every rank maps to a valid scene.
                scene_id=scene_id[rank % len(scene_id)],
                target_found_reward=target_found_reward,
            )

            return env

        return _thunk

    return SubprocVecEnv([make_env(i) for i in range(num_env)])

#### Launch environments.

In [None]:
# Build the vectorized training environment from the settings above.
env = make_unity_env(filename=filename, num_env=n_environments)

# Define the Model 

The following network assumes an observation consisting of RGB, segmentation, and depth images along with the agent's pose relative to its starting point. Images are processed using the Stable Baselines default CNN. The resulting feature vector is concatenated with the pose vector and given to an LSTM.

In [None]:
import tensorflow as tf
from stable_baselines.common.policies import nature_cnn

#### Define network to consume images and pose

In [None]:
def decode_tensor_observations(observation, img_shape=(-1, 240, 320, 5), pose_dim=3):
    """ Decode a batch of flat observations into images and poses.

    Args:
        observation (tf.Tensor): Shape (N, H*W*C + `pose_dim`) tensor. Each row is a
            flattened image stack followed by a pose vector.
        img_shape (Tuple[int, int, int, int]): Shape of the decoded image batch as (N, H, W, C).
            Default value is (-1, 240, 320, 5).
        pose_dim (int): Number of trailing entries in each row that hold the pose.
            Default value is 3.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: Tensors with the following information
            - Tensor of shape (N, `img_shape[1:]`) containing RGB,
                segmentation, and depth images stacked across the channel dimension.
            - Tensor of shape (N, `pose_dim`). With the default of 3 this is
                (x, y, heading) relative to starting point. (x, y) are in meters,
                heading is given in degrees in the range [-180, 180].

    Raises:
        ValueError: If `pose_dim` is less than 1.
    """
    # A non-positive pose_dim would make the negative slices below select the wrong columns.
    if pose_dim < 1:
        raise ValueError(f"pose_dim must be a positive integer, got {pose_dim}")

    imgs = tf.reshape(observation[:, :-pose_dim], img_shape)
    pose = observation[:, -pose_dim:]

    return imgs, pose

In [None]:
def image_and_pose_network(observation, **kwargs):
    """ Feature extractor combining image features with the agent pose.

    The image part of the observation is passed through the stable baselines
    `nature_cnn`. The pose is appended unchanged to the resulting feature
    vector, and the combined vector is what the policy's LSTM consumes.

    Args:
        observation (tf.Tensor): Flattened image and pose data, in the layout
            expected by `decode_tensor_observations`.
        **kwargs: Accepted so the function matches the extractor call
            signature; not used here.

    Returns:
        tf.Tensor: Image features concatenated with the pose along the last axis.
    """
    images, agent_pose = decode_tensor_observations(observation)
    cnn_features = nature_cnn(images)
    return tf.concat([cnn_features, agent_pose], axis=-1)

In [None]:
from stable_baselines.common.tf_layers import conv, linear, conv_to_fc, lstm
def image_and_pose_network2(observation, **kwargs):
    """ Image-and-pose feature extractor built from explicit layers.

    Applies three ReLU-activated convolutions (64, 128 and 128 filters) and a
    1024-unit fully connected layer to the images, then appends the pose.

    Args:
        observation (tf.Tensor): Flattened image and pose data, in the layout
            expected by `decode_tensor_observations`.
        **kwargs: Forwarded to each `conv` call.

    Returns:
        tf.Tensor: Image features concatenated with the pose along the last axis.
    """
    images, agent_pose = decode_tensor_observations(observation)
    relu = tf.nn.relu
    gain = np.sqrt(2)  # shared initialisation scale for every layer

    conv_1 = relu(conv(images, 'c1', n_filters=64, filter_size=8, stride=4, init_scale=gain, **kwargs))
    conv_2 = relu(conv(conv_1, 'c2', n_filters=128, filter_size=4, stride=2, init_scale=gain, **kwargs))
    conv_3 = relu(conv(conv_2, 'c3', n_filters=128, filter_size=3, stride=1, init_scale=gain, **kwargs))

    flattened = conv_to_fc(conv_3)
    image_features = relu(linear(flattened, 'fc1', n_hidden=1024, init_scale=gain))
    return tf.concat((image_features, agent_pose), axis=-1)

#### Register custom network

Outputs of the network defined above will be fed into an LSTM defined below in PPO2.

In [None]:
# Select the feature extractor the policy applies to raw observations.
policy_kwargs = dict(cnn_extractor=image_and_pose_network)

In [None]:
class CustomCnnLnLstmPolicy(LstmPolicy):
    """ `LstmPolicy` preset with layer normalization and CNN feature extraction.

    Forwards all arguments to `LstmPolicy`, fixing `layer_norm=True` and
    `feature_extraction="cnn"`, and defaulting `n_lstm` to 1024.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=1024, reuse=False, **_kwargs):
        super().__init__(
            sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
            layer_norm=True,
            feature_extraction="cnn",
            **_kwargs
        )

In [None]:
# PPO2 agent using the custom recurrent policy and the feature extractor
# registered in `policy_kwargs`.
model = PPO2(
    CustomCnnLnLstmPolicy,
    env,
    policy_kwargs=policy_kwargs,
    # optimisation hyperparameters
    gamma=0.999,
    n_steps=128,
    nminibatches=6,
    noptepochs=3,
    learning_rate=0.0003,
    cliprange=0.2,
    # logging
    verbose=1,
    tensorboard_log="./tensorboard/",
    # full_tensorboard_log=True,
)

# Train the Model

#### Define logging directory and callback function to save checkpoints

In [None]:
# Results directory; created with any missing parents and reused if it already exists.
log_dir = Path("results") / "goseek-ppo"
log_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Resume from a saved checkpoint. This replaces the `model` constructed above;
# the environments are handed to `load` directly.
model = PPO2.load(
    "results/ppo2-newhyper4600k.pkl",
    env,
    verbose=1,
    tensorboard_log="./tensorboard/",
)

In [None]:
from stable_baselines.common.callbacks import CallbackList, CheckpointCallback
# Checkpoint callback writing files prefixed 'ppo2-5-NR5' under ./results/.
checkpoint_callback = CheckpointCallback(save_freq=3000, save_path='./results/',
                                         name_prefix='ppo2-5-NR5')

# `TensorboardCallback` is neither imported nor defined in the visible notebook
# cells, so referencing it unconditionally raises NameError on a fresh kernel.
# Include it only when a definition is actually available.
_tensorboard_callback_cls = globals().get("TensorboardCallback")
_callbacks = [checkpoint_callback]
if _tensorboard_callback_cls is not None:
    _callbacks.append(_tensorboard_callback_cls())

callbackList = CallbackList(_callbacks)

In [None]:
from stable_baselines.common.callbacks import CheckpointCallback
# Checkpoint callback used for this training run.
checkpoint_callback = CheckpointCallback(
    save_freq=1000,
    save_path='./results/',
    name_prefix='ppo2-stalin-bigmem-newhyper8',
)

# Train while keeping the model's existing timestep counter.
model.learn(
    total_timesteps=total_timesteps,
    callback=checkpoint_callback,
    reset_num_timesteps=False,
)

In [None]:
# Persist the trained model.
model.save("results/ppo2-newhyper4700k.pkl")

# Visualize Results

__Note__: Stable-Baselines requires that policy input dimensions be consistent across training and testing. Thus, the number of environments used for training must be a multiple of the number of environments used for visualization. The observation vector is then appropriately duplicated during inference. 

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

#### Load model

In [None]:
# Saved-weights location. It is only consulted if the `PPO2.load` line below is
# re-enabled; otherwise the `model` already in memory is visualized.
MODEL_WEIGHTS_PATH = "./results/ppo2-newhyper.pkl"
assert MODEL_WEIGHTS_PATH, "Must give a model weights path!"

# model = PPO2.load(str(MODEL_WEIGHTS_PATH))

# Leading dimension of the policy's initial state, used as the number of
# environments the model expects in each prediction batch.
n_train_envs = model.act_model.initial_state.shape[0]

#### Visualize all observed images

In [None]:
# Start from fresh observations and an empty recurrent state.
obs = env.reset()
rgb, segmentation, depth, pose = decode_observations(obs)
lstm_state = None

# Observations are tiled up to `n_train_envs` rows before prediction, so the
# training count must be an exact multiple of the number of environments here.
assert n_train_envs % obs.shape[0] == 0, (
    f"The number of training environments ({n_train_envs}) must be a multiple "
    f"of the number of visualization environments ({obs.shape[0]})"
)

In [None]:
import sys
import numpy as np  # bind `np`: the calls below use that alias

# Print arrays in full instead of numpy's summarized form.
np.set_printoptions(threshold=sys.maxsize)

# Dump the complete segmentation image of the first environment.
print(segmentation[0])

# Show the three image types of the first environment side by side.
fig, ax = plt.subplots(1, 3)
ax[0].imshow(rgb[0])
ax[0].set_title("RGB")
ax[1].imshow(segmentation[0])
ax[1].set_title("Segmentation")
ax[2].imshow(depth[0])
ax[2].set_title("Depth")

#### Run an episode and plot the first person agent view

In [None]:
done = False
fig, ax = plt.subplots(1, obs.shape[0])
# `plt.subplots` returns a bare Axes for a single column; wrap it so indexing works.
ax = [ax] if obs.shape[0] == 1 else ax

for step in range(episode_length):
    # Tile the observations so the batch has `n_train_envs` rows, as the policy expects.
    actions, lstm_state = model.predict(
        np.concatenate((n_train_envs // obs.shape[0]) * [obs]),
        state=lstm_state,
        deterministic=False,
    )

    # Only the first `obs.shape[0]` actions belong to real environments.
    actions = actions[: obs.shape[0]]
    obs, reward, done, _ = env.step(actions)

    rgb, segmentation, depth, pose = decode_observations(obs)

    for env_idx in range(obs.shape[0]):
        # Clear each subplot itself: `plt.cla()` only clears the current axes,
        # so the remaining subplots would keep accumulating images.
        ax[env_idx].cla()
        ax[env_idx].imshow(rgb[env_idx])
    fig.canvas.draw()

# Leave the environments and recurrent state reset for the next run.
obs = env.reset()
rgb, segmentation, depth, pose = decode_observations(obs)
lstm_state = None