Skip to content

Commit

Permalink
Merge pull request #53 from mlcommons/fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
xzfc committed Jan 12, 2021
2 parents deb79aa + 18dc5ec commit 3c604b5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 18 deletions.
24 changes: 19 additions & 5 deletions ptd_client_server/lib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,36 @@ def main() -> None:

for mode in ["ranging", "testing"]:
logging.info(f"Running workload in {mode} mode")
out = f"{args.output}/{mode}"
out = os.path.join(args.output, mode)

os.mkdir(out)

env = os.environ.copy()
env["ranging"] = "1" if mode == "ranging" else "0"
env["out"] = out
newEnv = {
"ranging": "1" if mode == "ranging" else "0",
"out": out,
}
env = dict(os.environ.copy(), **newEnv)

common.ntp_sync(args.ntp)
command(serv, f"session,{session},start,{mode}", check=True)

logging.info("Running runWorkload")
logging.info(f"Running the workload {args.run_workload!r}")
logging.info(
"Environment variables: "
+ " ".join((f"{n}={v!r}" for n, v in newEnv.items()))
)
subprocess.run(args.run_workload, shell=True, check=True, env=env)

command(serv, f"session,{session},stop,{mode}", check=True)

if len(os.listdir(out)) == 0:
logging.fatal(f"The directory {out!r} is empty")
logging.fatal(
"Please make sure that the provided workload command writes its "
"output into the directory specified by environment variable $out"
)
exit(1)

if args.send_logs:
logging.info("Packing logs into zip and uploading to the server")
create_zip(f"{out}.zip", out)
Expand Down
1 change: 0 additions & 1 deletion ptd_client_server/lib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import string
import subprocess
import sys
import threading


DEFAULT_PORT = 4950
Expand Down
29 changes: 17 additions & 12 deletions ptd_client_server/lib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ipaddress import ip_address
from typing import Optional, Dict, Tuple
import argparse
import time
import base64
import configparser
import datetime
Expand Down Expand Up @@ -181,9 +180,16 @@ def __init__(self, command: str, port: int) -> None:
self._init_Amps: Optional[str] = None
self._init_Volts: Optional[str] = None

def start(self) -> bool:
def start(self) -> None:
try:
self._start()
except Exception:
logging.exception("Could not start PTDaemon")
exit(1)

def _start(self) -> None:
if self._process is not None:
return True
return
if sys.platform == "win32":
# shell=False:
# On Windows, we don't need a shell to run a command from a single
Expand All @@ -205,6 +211,8 @@ def start(self) -> bool:
s = None
while s is None and retries > 0:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self._process.poll() is not None:
raise RuntimeError("PTDaemon unexpectedly terminated")
try:
s.connect(("127.0.0.1", self._port))
except ConnectionRefusedError:
Expand All @@ -214,22 +222,19 @@ def start(self) -> bool:
s = None
retries -= 1
if s is None:
logging.error("Could not connect to PTD")
self.terminate()
return False
raise RuntimeError("Could not connect to PTDaemon")
self._socket = s
self._proto = common.Proto(s)

if self.cmd("Hello") != "Hello, PTDaemon here!":
logging.error("This is not PTDaemon")
return False
raise RuntimeError("This is not PTDaemon")

self.cmd("Identify") # reply traced in logs

logging.info("Connected to PTDaemon")

self._get_initial_range()
return True

def stop(self) -> None:
self.cmd("Stop")
Expand Down Expand Up @@ -267,6 +272,8 @@ def cmd(self, cmd: str) -> Optional[str]:
logging.info(f"Sending to ptd: {cmd!r}")
self._proto.send(cmd)
reply = self._proto.recv()
if reply is None:
exit_with_error_msg("Got no reply from PTDaemon")
logging.info(f"Reply from ptd: {reply!r}")
return reply

Expand Down Expand Up @@ -443,8 +450,7 @@ def start(self, mode: Mode) -> bool:
return True

if mode == Mode.RANGING and self._state == SessionState.INITIAL:
if not self._server._ptd.start():
return False
self._server._ptd.start()

common.ntp_sync(self._server._config.ntp_server)

Expand All @@ -461,8 +467,7 @@ def start(self, mode: Mode) -> bool:
return True

if mode == Mode.TESTING and self._state == SessionState.RANGING_DONE:
if not self._server._ptd.start():
return False
self._server._ptd.start()

common.ntp_sync(self._server._config.ntp_server)

Expand Down

0 comments on commit 3c604b5

Please sign in to comment.