Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(helper): fix gpuutil exception
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 15, 2019
1 parent 59539d7 commit 38147fe
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
3 changes: 3 additions & 0 deletions docs/chapter/enviromentvars.md
Expand Up @@ -38,3 +38,6 @@ Default is not set. A random port will be used.

(*depreciated*) Paths of the third party components. See examples in GNES hub for latest usage.

## `GNES_IPC_SOCK_TMP`

Temp directory for ipc sockets, not used on Windows.
2 changes: 1 addition & 1 deletion gnes/flow/__init__.py
Expand Up @@ -703,7 +703,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
op_flow._build_level = BuildLevel.GRAPH
return op_flow

def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *args, **kwargs) -> 'Flow':
def build(self, backend: Optional[str] = 'process', copy_flow: bool = False, *args, **kwargs) -> 'Flow':
"""
Build the current flow and make it ready to use
Expand Down
45 changes: 35 additions & 10 deletions gnes/service/base.py
Expand Up @@ -15,10 +15,13 @@

import copy
import multiprocessing
import os
import random
import tempfile
import threading
import time
import types
import uuid
from contextlib import ExitStack
from enum import Enum
from typing import Tuple, List, Union, Type
Expand Down Expand Up @@ -113,8 +116,19 @@ class EventLoopEnd(Exception):
pass


def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketType', identity: 'str' = None) -> Tuple[
'zmq.Socket', str]:
def get_random_ipc() -> str:
try:
tmp = os.environ['GNES_IPC_SOCK_TMP']
if not os.path.exists(tmp):
raise ValueError('This directory for sockets ({}) does not seems to exist.'.format(tmp))
tmp = os.path.join(tmp, str(uuid.uuid1())[:8])
except KeyError:
tmp = tempfile.NamedTemporaryFile().name
return 'ipc://%s' % tmp


def build_socket(ctx: 'zmq.Context', host: str, port: int,
socket_type: 'SocketType', identity: 'str' = None, use_ipc: bool = False) -> Tuple['zmq.Socket', str]:
sock = {
SocketType.PULL_BIND: lambda: ctx.socket(zmq.PULL),
SocketType.PULL_CONNECT: lambda: ctx.socket(zmq.PULL),
Expand All @@ -129,11 +143,14 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT
}[socket_type]()

if socket_type.is_bind:
host = BaseService.default_host
if port is None:
sock.bind_to_random_port('tcp://%s' % host)
if use_ipc:
sock.bind(host)
else:
sock.bind('tcp://%s:%d' % (host, port))
host = BaseService.default_host
if port is None:
sock.bind_to_random_port('tcp://%s' % host)
else:
sock.bind('tcp://%s:%d' % (host, port))
else:
if port is None:
sock.connect(host)
Expand Down Expand Up @@ -329,8 +346,12 @@ def __init__(self, args):
self.last_dump_time = time.perf_counter()
self._model = None
self.use_event_loop = True
self.ctrl_addr = 'tcp://%s:%d' % (self.default_host, self.args.port_ctrl)
self.logger.info('control address: %s' % self.ctrl_addr)
self.ctrl_with_ipc = (os.name != 'nt')
if self.ctrl_with_ipc:
self.ctrl_addr = get_random_ipc()
else:
self.ctrl_addr = 'tcp://%s:%d' % (self.default_host, self.args.port_ctrl)

self.send_recv_kwargs = dict(
check_version=self.args.check_version,
timeout=self.args.timeout,
Expand Down Expand Up @@ -401,7 +422,12 @@ def _run(self, ctx):
self.handler.service_context = self
# print('!!!! t_id: %d service_context: %r' % (threading.get_ident(), self.handler.service_context))
self.logger.info('bind sockets...')
ctrl_sock, ctrl_addr = build_socket(ctx, self.default_host, self.args.port_ctrl, SocketType.PAIR_BIND)
if self.ctrl_with_ipc:
ctrl_sock, ctrl_addr = build_socket(ctx, self.ctrl_addr, None, SocketType.PAIR_BIND,
use_ipc=self.ctrl_with_ipc)
else:
ctrl_sock, ctrl_addr = build_socket(ctx, self.default_host, self.args.port_ctrl, SocketType.PAIR_BIND)

self.logger.info('control over %s' % (colored(ctrl_addr, 'yellow')))

in_sock, _ = build_socket(ctx, self.args.host_in, self.args.port_in, self.args.socket_in,
Expand All @@ -412,7 +438,6 @@ def _run(self, ctx):
self.args.identity)
self.logger.info('output %s:%s' % (self.args.host_out, colored(self.args.port_out, 'yellow')))


self.logger.info(
'input %s:%s\t output %s:%s\t control over %s' % (
self.args.host_in, colored(self.args.port_in, 'yellow'),
Expand Down

0 comments on commit 38147fe

Please sign in to comment.