In [1]:
import tensorflow as tf

from anyrl.algos import DQN
from anyrl.envs import BatchedGymEnv
from anyrl.envs.wrappers import BatchedFrameStack
from anyrl.models import rainbow_models
from anyrl.rollouts import BasicPlayer, PrioritizedReplayBuffer, NStepPlayer
from anyrl.spaces import gym_space_vectorizer

from sonic_util import AllowBacktracking, make_env
import numpy as np
import csv
import ray
import time

In [2]:
@ray.remote(num_cpus=4,num_gpus=1)
class MultiAgent():
    """Ray actor that hosts several DistralAgent instances on one TensorFlow session.

    The actor is declared with 4 CPUs and 1 GPU. All agents it creates share
    the single ``tf.Session`` built in ``__init__``.
    """

    def __init__(self, num_agent=4):
        """Create the shared session and ``num_agent`` agents, then initialise them.

        Args:
            num_agent: number of DistralAgent instances to create; agent ``i``
                is constructed with ``env_index=i``.
        """
        config = tf.ConfigProto()
        # Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
        config.gpu_options.allow_growth = True
        # Keep the session and the agents on the instance: `train` runs in a
        # later actor call and cannot see locals of `__init__`.
        self.sess = tf.Session(config=config)

        self.agents = [DistralAgent(self.sess, i) for i in range(num_agent)]

        # Variables must be initialised before the target networks are synced.
        self.sess.run(tf.global_variables_initializer())

        for agent in self.agents:
            agent.update_target()

    def train(self,distill_policy_weights):
        """Run one training call on every agent.

        Args:
            distill_policy_weights: passed unchanged to each agent's ``train``.

        Returns:
            A list with one entry per agent, in agent order, holding whatever
            that agent's ``train`` returned.
        """
        return [agent.train(distill_policy_weights) for agent in self.agents]

In [3]:
@ray.remote(num_cpus=1)
class SonicEnv():
    """Ray actor (1 CPU) wrapping a single Sonic environment chosen from ./sonic-train.csv."""

    def __init__(self,env_index):
        """Load the game list and build the environment for row ``env_index``.

        Args:
            env_index: zero-based index into the data rows of
                ``./sonic-train.csv`` (the first line of the file is skipped).
                Column 0 of the row is passed to ``make_env`` as ``game`` and
                column 1 as ``state``.
        """
        # Use a context manager so the file handle is closed once the rows are read.
        with open('./sonic-train.csv','r') as csv_file:
            rows = csv.reader(csv_file, delimiter=',')
            next(rows, None)  # skip the first (header) line
            self.games = list(rows)

        game, state = self.games[env_index][0], self.games[env_index][1]
        self.env = AllowBacktracking(make_env(game=game, state=state))

    def step(self,action):
        """Forward ``action`` to the wrapped environment's ``step`` and return its result."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Forward ``reset`` (with any keyword arguments) to the wrapped environment."""
        return self.env.reset(**kwargs)

    def action_space(self):
        """Return ``n`` of the wrapped environment's action space (not the space object)."""
        return self.env.action_space.n

    def observation_space(self):
        """Return the wrapped environment's observation space object."""
        return self.env.observation_space

In [4]:

class DistralAgent():
    """Per-environment learner: a remote SonicEnv, a DQN, a player and a replay buffer.

    The TensorFlow session is supplied by the caller. The caller is expected
    to initialise variables and then call ``update_target`` before training.
    """

    def __init__(self,sess,env_index):
        """Create the remote environment, the DQN and the training bookkeeping.

        Args:
            sess: the ``tf.Session`` used for every ``run`` call of this agent.
            env_index: index of the game/state row used by ``SonicEnv``.
        """
        # step 1: init env
        self.env_index = env_index
        self.env = SonicEnv.remote(env_index)
        action_space = ray.get(self.env.action_space.remote())
        observation_space = ray.get(self.env.observation_space.remote())

        self.sess = sess
        # Build the networks in the graph that `sess` actually executes; ops
        # created in a separate tf.Graph() cannot be run by this session.
        with self.sess.graph.as_default():
            # A per-agent scope keeps variable names distinct when several
            # agents are built in the same graph.
            with tf.variable_scope('agent_{}'.format(env_index)):
                self.dqn = DQN(*rainbow_models(self.sess,
                                               action_space,
                                               gym_space_vectorizer(observation_space),
                                               min_val=-200,
                                               max_val=200))
        # NOTE(review): `self.env` is a Ray actor handle, whose methods are
        # invoked via `.remote()`. Confirm BasicPlayer can drive it, or wrap
        # it in an adapter / use the env built by `init_env` instead.
        self.player = NStepPlayer(BasicPlayer(self.env, self.dqn.online_net), 3)

        self.replay_buffer = PrioritizedReplayBuffer(500000, 0.5, 0.4, epsilon=0.1)
        # The target network is not synced here; call `update_target` once
        # the session's variables have been initialised.
        self.steps_taken = 0
        self.train_interval=1
        self.target_interval=8192
        self.batch_size=32
        self.min_buffer_size=200
        # Episode-end callback, called as handle_ep(steps, total_reward); no-op by default.
        self.handle_ep=lambda steps, rew: None
        self.next_target_update = self.target_interval
        self.next_train_step = self.train_interval

    def update_target(self):
        """Run the DQN's ``update_target`` op in the agent's session."""
        self.sess.run(self.dqn.update_target)


    def init_env(self,env_index):
        """Build a local, batched and frame-stacked environment for row ``env_index``.

        Reads ``./sonic-train.csv`` (skipping its first line), stores the rows
        in ``self.games`` and returns the environment wrapped in
        ``BatchedGymEnv`` and ``BatchedFrameStack`` (4 images, ``concat=False``).
        Nothing in this class calls this method.
        """
        # Context manager closes the file handle once the rows are read.
        with open('./sonic-train.csv','r') as csv_file:
            rows = csv.reader(csv_file, delimiter=',')
            next(rows, None)  # skip the first (header) line
            self.games = list(rows)

        env = AllowBacktracking(make_env(game=self.games[env_index][0],state=self.games[env_index][1]))

        env = BatchedFrameStack(BatchedGymEnv([[env]]), num_images=4, concat=False)
        return env

    def train(self,distill_policy_weights):
        """Collect one batch of transitions and train on them.

        Each transition is added to the replay buffer. Once the buffer holds
        at least ``min_buffer_size`` samples, a training step runs every
        ``train_interval`` steps, and the target network is refreshed every
        ``target_interval`` steps.

        Args:
            distill_policy_weights: handed to ``dqn.set_distill_policy_weights``
                before any transitions are played.

        Returns:
            The distillation gradient values fetched by the last training step
            of this call, or ``0`` if no training step ran.
        """
        self.dqn.set_distill_policy_weights(distill_policy_weights)

        transitions = self.player.play()
        distill_grads = 0
        # Fetch list for the distillation gradients; built on the first
        # training step of this call and reused for the rest.
        grad_tensors = None
        for trans in transitions:
            if trans['is_last']:
                self.handle_ep(trans['episode_step'] + 1, trans['total_reward'])
            self.replay_buffer.add_sample(trans)
            self.steps_taken += 1
            if self.replay_buffer.size >= self.min_buffer_size and self.steps_taken >= self.next_train_step:
                self.next_train_step = self.steps_taken + self.train_interval
                batch = self.replay_buffer.sample(self.batch_size)

                if grad_tensors is None:
                    # Entries whose first element is None have no gradient to
                    # fetch. Identity comparison avoids invoking tensor `!=`.
                    grad_tensors = [grad[0] for grad in self.dqn.distill_grads
                                    if grad[0] is not None]

                _,losses,distill_grads = self.sess.run((self.dqn.optim,self.dqn.losses,grad_tensors),
                                     feed_dict=self.dqn.feed_dict(batch))
                # Feed the per-sample losses back to the prioritized buffer.
                self.replay_buffer.update_weights(batch, losses)

            if self.steps_taken >= self.next_target_update:
                self.next_target_update = self.steps_taken + self.target_interval
                self.sess.run(self.dqn.update_target)

        return distill_grads

In [None]:
# Start Ray. Let this cell finish before running any cell that creates actors
# with `.remote()`; creating one earlier raises
# "Actors cannot be created before ray.init() has been called."
ray.init()

Process STDOUT and STDERR is being redirected to /tmp/raylogs/.
Waiting for redis server at 127.0.0.1:24205 to respond...


In [5]:
# Create four MultiAgent actors (each declared with num_cpus=4, num_gpus=1).
# Requires that ray.init() has already completed in this kernel.
agents = [MultiAgent.remote() for i in range(4)]

Exception: Actors cannot be created before ray.init() has been called.