In [None]:
from collections import namedtuple
import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
from tqdm import tqdm

In [None]:
import env

In [None]:
class Brain(keras.Model):
    """Small dense network that maps a batch of observations to action logits.

    The model is one 32-unit ReLU layer followed by a linear layer with
    ``action_dim`` units; no softmax is applied, so the outputs are raw logits.
    """
    def __init__(self, action_dim = 5, input_shape = (1, 8 * 8)):
        """Create the two Dense layers.

        Args:
            action_dim: Number of units in the output layer (one logit per action).
            input_shape: Shape hint forwarded to the first Dense layer. Must be a
                tuple; the previous default ``(1.8 * 8)`` evaluated to the float
                14.4, which is not a valid shape.
        """
        super(Brain, self).__init__()
        self.dense1 = layers.Dense(32, input_shape = input_shape, activation = 'relu')
        self.logits = layers.Dense(action_dim)
    def call(self, inputs):
        """Run the forward pass and return the output-layer logits."""
        x = tf.convert_to_tensor(inputs)
        hidden = self.dense1(x)
        return self.logits(hidden)
    def process(self, observations):
        """Return the logits for a batch of observations via ``predict_on_batch``."""
        # Fixed typo: the method is `predict_on_batch`, not `proedict_on_batch`.
        action_logits = self.predict_on_batch(observations)
        return action_logits

In [None]:
class Agent(object):
    """Wraps a ``Brain`` model with a stochastic action-selection policy."""
    def __init__(self, action_dim = 5, input_shape = (1, 8 * 8)):
        """Build and compile the underlying ``Brain``.

        Args:
            action_dim: Number of actions; forwarded to ``Brain``.
            input_shape: Shape hint forwarded to ``Brain``. Must be a tuple; the
                previous default ``(1.8 * 8)`` evaluated to the float 14.4.
        """
        self.brain = Brain(action_dim, input_shape)
        # NOTE(review): the brain outputs raw logits while this loss string uses
        # Keras' default (probability) interpretation — confirm this is intended.
        self.brain.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
        self.policy = self.policy_mlp
    def policy_mlp(self, observations):
        """Sample one action from the brain's output for a single observation.

        Args:
            observations: Array-like with a ``reshape`` method; it is flattened
                to a batch of one row before being fed to the brain.

        Returns:
            The tensor returned by ``tf.random.categorical`` with
            ``num_samples = 1``.
        """
        observations = observations.reshape(1, -1)
        action_logits = self.brain.process(observations)
        # tf.random.categorical expects unnormalized log-probabilities (logits).
        # The brain already emits logits, so they are passed through directly;
        # applying tf.math.log to them produced NaN for any negative logit.
        action = tf.random.categorical(action_logits, num_samples = 1)
        return action
    def get_action(self, observations):
        """Return an action for ``observations`` using the active policy."""
        return self.policy(observations)
    def learn(self, obs, actions, **kwargs):
        """Fit the brain on ``obs``/``actions``; extra kwargs go to ``fit``."""
        self.brain.fit(obs, actions, **kwargs)

In [None]:
# Lightweight record type with three fields: 'obs', 'actions' and 'reward'.
# NOTE(review): not referenced by the cells visible here; presumably it holds
# the result of one rollout — confirm against later cells.
Trajectory = namedtuple('Trajectory', ['obs', 'actions', 'reward'])

In [None]:
def evaluate(agent, env, render = True):
    """Run one episode with ``agent`` in ``env`` and report summary values.

    Args:
        agent: Object exposing ``get_action(obs)``.
        env: Object exposing ``reset()``, ``step(action)`` returning
            ``(obs, reward, done, info)``, and ``render()``.
        render: When truthy, ``env.render()`` is called after every step.

    Returns:
        Tuple ``(step_num, episode_reward, done, info)`` where ``info`` is the
        value returned by the last ``env.step`` call.
    """
    current_obs = env.reset()
    total_reward = .0
    finished = False
    steps_taken = 0
    last_info = None
    while not finished:
        chosen_action = agent.get_action(current_obs)
        step_result = env.step(chosen_action)
        current_obs, step_reward, finished, last_info = step_result
        total_reward += step_reward
        steps_taken += 1
        if render:
            env.render()
    return steps_taken, total_reward, finished, last_info

In [None]:
def rollout(agent, env, render=False):
    """Play one episode and record the observations and actions taken.

    Args:
        agent: Object exposing ``get_action(obs)``.
        env: Object exposing ``reset()``, ``step(action)`` returning
            ``(next_obs, reward, done, info)``, ``render()`` and ``close()``.
        render: When truthy, ``env.render()`` is called after every step.

    Returns:
        Tuple ``(observations, actions, episode_reward)``: a list of flattened
        observation arrays, a list of actions with their leading axis squeezed
        out, and the sum of rewards over the episode.
    """
    # Fixed: the original line used `-` instead of `=`, so it was an expression
    # over unbound locals and raised UnboundLocalError on every call.
    obs, episode_reward, done, step_num = env.reset(), .0, False, 0
    observations, actions = [], []
    while not done:
        action = agent.get_action(obs)
        next_obs, reward, done, info = env.step(action)
        # Store the observation the action was chosen from, flattened to 1-D.
        observations.append(np.array(obs).reshape(-1))
        # Drop the leading (batch) axis of the sampled action.
        actions.append(np.squeeze(action, 0))
        episode_reward += reward
        obs = next_obs
        step_num += 1
        if render:
            env.render()
    env.close()
    return observations, actions, episode_reward