Skip to content

Commit

Permalink
gRPC-based microservice approach for distributed agents (#331)
Browse files Browse the repository at this point in the history
* Working draft.

* Changed relative import path.

* Added grpc requirements.

* Changed subprocess to multiprocess call.

* Upon termination, the remote agent executes the `Stop` RPC to stop the remote worker process.

* Enabled the client to terminate the master process running on localhost.

* Black format.

* Edited cli.

* Changed handling of action response timeout.

* Extracted AgentServicer into separate file.

* Verifying errors.

* Added print statements.

* Fixed creation of multiple local zoo masters.

* Use a gRPC future and remove the timeout for act().

* Client informs master to stop worker process.

* Make format.

* Workers run using subprocess instead of multiprocessing.

* Separate master proto and worker proto.

* Make format.

* Use gRPC StatusCode and context for error messages.

* Clean up.

* Removed old authentication mechanism.

* Downgraded grpcio version.

* Follow our naming convention.

* Readable `for` loop variable.

* Changed master/worker to manager/worker.

* Improved comment on zoo_addrs.
  • Loading branch information
Adaickalavan committed Dec 31, 2020
1 parent f2af260 commit d006ace
Show file tree
Hide file tree
Showing 21 changed files with 1,320 additions and 484 deletions.
15 changes: 8 additions & 7 deletions cli/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ def build():
)


@zoo_cli.command(name="worker", help="Start the agent worker")
@click.argument("auth_key", type=str, default=None)
@zoo_cli.command(
name="manager",
help="Start the manager process which instantiates workers. Workers execute remote agents.",
)
@click.argument("port", default=7432, type=int)
def worker(auth_key, port):
from smarts.zoo.worker import listen
def manager(port):
from smarts.zoo import manager as zoo_manager

auth_key = auth_key if auth_key else ""
listen(port, auth_key)
zoo_manager.serve(port)


zoo_cli.add_command(build_policy)
zoo_cli.add_command(worker)
zoo_cli.add_command(manager)
2 changes: 1 addition & 1 deletion envision/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

import smarts.core.models
from smarts.core.utils.file import path2hash
from .web import dist as web_dist
from envision.web import dist as web_dist

logging.basicConfig(level=logging.WARNING)

Expand Down
21 changes: 2 additions & 19 deletions examples/single_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,7 @@ def act(self, obs: Observation):
)


def main(
scenarios,
sim_name,
headless,
num_episodes,
seed,
auth_key=None,
max_episode_steps=None,
):
def main(scenarios, sim_name, headless, num_episodes, seed, max_episode_steps=None):
agent_spec = AgentSpec(
interface=AgentInterface.from_type(
AgentType.Laner, max_episode_steps=max_episode_steps
Expand All @@ -61,8 +53,7 @@ def main(
timestep_sec=0.1,
sumo_headless=True,
seed=seed,
# zoo_workers=[("143.110.210.157", 7432)], # Distribute social agents across these workers
auth_key=auth_key,
# zoo_addrs=[("10.193.241.236", 7432)], # Sample server address (ip, port), to distribute social agents in remote server.
# envision_record_data_replay_path="./data_replay",
)

Expand All @@ -83,20 +74,12 @@ def main(

if __name__ == "__main__":
parser = default_argument_parser("single-agent-example")
parser.add_argument(
"--auth_key",
type=str,
default=None,
help="Authentication key for connection to run agent",
)
args = parser.parse_args()
auth_key = args.auth_key if args.auth_key else ""

main(
scenarios=args.scenarios,
sim_name=args.sim_name,
headless=args.headless,
num_episodes=args.episodes,
seed=args.seed,
auth_key=auth_key,
)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ google==3.0.0
google-auth==1.23.0
google-auth-oauthlib==0.4.2
google-pasta==0.2.0
grpcio==1.32.0
grpcio==1.30.0
gym==0.17.3
h5py==2.10.0
hyperlink==20.0.1
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@
"matplotlib",
"scikit-image",
# The following are for /smarts/zoo
"twisted",
"grpcio==1.30.0",
"PyYAML",
"twisted",
],
extras_require={
"train": [
Expand All @@ -75,6 +76,7 @@
],
"dev": [
"black==19.10b0",
"grpcio-tools==1.30.0",
"sphinx",
"sphinx-rtd-theme",
"sphinxcontrib-apidoc",
Expand Down
32 changes: 13 additions & 19 deletions smarts/core/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import cloudpickle
import logging
from typing import Set

from envision.types import format_actor_id

from smarts.core.bubble_manager import BubbleManager
from smarts.core.data_model import SocialAgent
from smarts.core.mission_planner import MissionPlanner
from smarts.core.remote_agent_buffer import RemoteAgentBuffer
from smarts.core.sensors import Sensors
from smarts.core.utils.id import SocialAgentId
from smarts.core.vehicle import VehicleState
from smarts.zoo.registry import make as make_social_agent

from .mission_planner import MissionPlanner
from .remote_agent_buffer import RemoteAgentBuffer
from .sensors import Sensors
from .vehicle import VehicleState


class AgentManager:
"""Tracks agent states and implements methods for managing agent life cycle.
Expand All @@ -41,11 +40,9 @@ class AgentManager:
time.
"""

def __init__(self, interfaces, zoo_workers=None, auth_key=None):
def __init__(self, interfaces, zoo_addrs=None):
self._log = logging.getLogger(self.__class__.__name__)
self._remote_agent_buffer = RemoteAgentBuffer(
zoo_worker_addrs=zoo_workers, auth_key=auth_key
)
self._remote_agent_buffer = RemoteAgentBuffer(zoo_manager_addrs=zoo_addrs)

self._ego_agent_ids = set()
self._social_agent_ids = set()
Expand Down Expand Up @@ -229,17 +226,18 @@ def fetch_agent_actions(self, sim, ego_agent_actions):
try:
social_agent_actions = {
agent_id: (
self._remote_social_agents_action[agent_id].result()
cloudpickle.loads(
self._remote_social_agents_action[agent_id].result().action
)
if self._remote_social_agents_action.get(agent_id, None)
else None
)
for agent_id, remote_agent in self._remote_social_agents.items()
}
except Exception as e:
self._log.error(
"RemoteAgent: Resolving the remote agent's action (a Future object) generated exception."
"Resolving the remote agent's action (a Future object) generated exception."
)
self._log.exception(e)
raise e

agents_without_actions = [
Expand Down Expand Up @@ -300,9 +298,7 @@ def send_observations_to_social_agents(self, observations):
self._remote_social_agents_action = {}
for agent_id, remote_agent in self._remote_social_agents.items():
obs = observations[agent_id]
self._remote_social_agents_action[agent_id] = remote_agent.act(
obs, timeout=5
)
self._remote_social_agents_action[agent_id] = remote_agent.act(obs)

def setup_agents(self, sim):
self.init_ego_agents(sim)
Expand Down Expand Up @@ -471,9 +467,7 @@ def reset_agents(self, observations):
self._remote_social_agents_action = {}
for agent_id, remote_agent in self._remote_social_agents.items():
obs = observations[agent_id]
self._remote_social_agents_action[agent_id] = remote_agent.act(
obs, timeout=5
)
self._remote_social_agents_action[agent_id] = remote_agent.act(obs)

# Observations contain those for social agents; filter them out
return self._filter_for_active_ego(observations)
Expand Down
96 changes: 49 additions & 47 deletions smarts/core/remote_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,73 +18,75 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import cloudpickle
import grpc
import logging
import time

from concurrent import futures
from multiprocessing.connection import Client

from .agent import AgentSpec
from smarts.core.agent import AgentSpec
from smarts.zoo import manager_pb2
from smarts.zoo import manager_pb2_grpc
from smarts.zoo import worker_pb2
from smarts.zoo import worker_pb2_grpc


class RemoteAgentException(Exception):
pass


class RemoteAgent:
def __init__(self, address, socket_family, auth_key, connection_retries=100):
def __init__(self, manager_address, worker_address):
self._log = logging.getLogger(self.__class__.__name__)
auth_key_conn = str.encode(auth_key) if auth_key else None

self._conn = None
self._tp_exec = futures.ThreadPoolExecutor()
# Track the last action future.
self._act_future = None

for i in range(connection_retries):
# Waiting on agent to open its socket.
try:
self._conn = Client(
address, family=socket_family, authkey=auth_key_conn
)
break
except Exception:
self._log.debug(
f"RemoteAgent retrying connection to agent in: attempt {i}"
)
time.sleep(0.1)

if self._conn is None:
raise RemoteAgentException("Failed to connect to remote agent")

def __del__(self):
self.terminate()
self._manager_channel = grpc.insecure_channel(
f"{manager_address[0]}:{manager_address[1]}"
)
self._worker_address = worker_address
self._worker_channel = grpc.insecure_channel(
f"{worker_address[0]}:{worker_address[1]}"
)
try:
# Wait until the grpc server is ready or timeout after 30 seconds.
grpc.channel_ready_future(self._manager_channel).result(timeout=30)
grpc.channel_ready_future(self._worker_channel).result(timeout=30)
except grpc.FutureTimeoutError as e:
raise RemoteAgentException(
"Timeout while connecting to remote worker process."
) from e
self._manager_stub = manager_pb2_grpc.ManagerStub(self._manager_channel)
self._worker_stub = worker_pb2_grpc.WorkerStub(self._worker_channel)

def _act(self, obs, timeout):
# Send observation
self._conn.send({"type": "obs", "payload": obs})
# Receive action
if self._conn.poll(timeout):
try:
return self._conn.recv()
except ConnectionResetError as e:
self.terminate()
raise e
else:
return None
def act(self, obs):
# Run task asynchronously and return a Future.
self._act_future = self._worker_stub.act.future(
worker_pb2.Observation(payload=cloudpickle.dumps(obs))
)

def act(self, obs, timeout=None):
# Run task asynchronously and return a Future
return self._tp_exec.submit(self._act, obs, timeout)
return self._act_future

def start(self, agent_spec: AgentSpec):
# Send the AgentSpec to the agent runner
self._conn.send(
# We use cloudpickle only for the agent_spec to allow for serialization of lambdas
{"type": "agent_spec", "payload": cloudpickle.dumps(agent_spec)}
# Send the AgentSpec to the agent runner.
# Cloudpickle used only for the agent_spec to allow for serialization of lambdas.
self._worker_stub.build(
worker_pb2.Specification(payload=cloudpickle.dumps(agent_spec))
)

def terminate(self):
if self._conn:
self._conn.close()
# If the last action future returned is incomplete, cancel it first.
if (self._act_future is not None) and (not self._act_future.done()):
self._act_future.cancel()

# Close worker channel
self._worker_channel.close()

# Stop the remote worker process
response = self._manager_stub.stop_worker(
manager_pb2.Port(num=self._worker_address[1])
)

# Shutdown thread pool executor
self._tp_exec.shutdown()
# Close manager channel
self._manager_channel.close()
Loading

0 comments on commit d006ace

Please sign in to comment.