Skip to content

Commit

Permalink
More JobStateUpdateClient
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisburr committed Nov 30, 2023
1 parent 3d76331 commit dfc0f3d
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,50 +1,128 @@
import functools
from datetime import datetime, timezone


from DIRAC.Core.Security.DiracX import DiracXClient
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue
from DIRAC.Core.Utilities.TimeUtilities import fromString


def stripValueIfOK(func):
"""Decorator to remove S_OK["Value"] from the return value of a function if it is OK.
This is done as some update functions return the number of modified rows in
the database. This likely not actually useful so it isn't supported in
DiracX. Stripping the "Value" key of the dictionary means that we should
get a fairly straight forward error if the assumption is incorrect.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if result.get("OK"):
assert result.pop("Value") is None, "Value should be None if OK"
return result

return wrapper


class JobStateUpdateClient:
def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict):
raise NotImplementedError("TODO")

@stripValueIfOK
@convertToReturnValue
def setJobApplicationStatus(self, jobID: str | int, appStatus: str, source: str = "Unknown"):
raise NotImplementedError("TODO")
statusDict = {
"application_status": appStatus,
}
if source:
statusDict["Source"] = source
with DiracXClient() as api:
api.jobs.set_single_job_status(
jobID,
{datetime.now(tz=timezone.utc): statusDict},
)

@stripValueIfOK
@convertToReturnValue
def setJobAttribute(self, jobID: str | int, attribute: str, value: str):
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, "need to [patch the client to have a nice summer body ?")
raise NotImplementedError("TODO")
if attribute == "Status":
api.jobs.set_single_job_status(
jobID,
{datetime.now(tz=timezone.utc): {"status": value}},
)
else:
api.jobs.set_single_job_properties(jobID, {attribute: value})

@stripValueIfOK
@convertToReturnValue
def setJobFlag(self, jobID: str | int, flag: str):
raise NotImplementedError("TODO")
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {flag: True})

@stripValueIfOK
@convertToReturnValue
def setJobParameter(self, jobID: str | int, name: str, value: str):
raise NotImplementedError("TODO")
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def setJobParameters(self, jobID: str | int, parameters: list):
raise NotImplementedError("TODO")
print("HACK: This is a no-op until we decide what to do")

@stripValueIfOK
@convertToReturnValue
def setJobSite(self, jobID: str | int, site: str):
raise NotImplementedError("TODO")
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {"Site": site})

@stripValueIfOK
@convertToReturnValue
def setJobStatus(
self,
jobID: str | int,
status: str = "",
minorStatus: str = "",
source: str = "Unknown",
datetime=None,
datetime_=None,
force=False,
):
raise NotImplementedError("TODO")
statusDict = {}
if status:
statusDict["Status"] = status
if minorStatus:
statusDict["MinorStatus"] = minorStatus
if source:
statusDict["Source"] = source
if datetime_ is None:
datetime_ = datetime.utcnow()
with DiracXClient() as api:
api.jobs.set_single_job_status(
jobID,
{fromString(datetime_).replace(tzinfo=timezone.utc): statusDict},
force=force,
)

@stripValueIfOK
@convertToReturnValue
def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False):
raise NotImplementedError("TODO")
statusDict = {fromString(k).replace(tzinfo=timezone.utc): v for k, v in statusDict.items()}
with DiracXClient() as api:
api.jobs.set_job_status_bulk(
{jobID: statusDict},
force=force,
)

def setJobsParameter(self, jobsParameterDict: dict):
raise NotImplementedError("TODO")

@stripValueIfOK
@convertToReturnValue
def unsetJobFlag(self, jobID: str | int, flag: str):
raise NotImplementedError("TODO")
with DiracXClient() as api:
api.jobs.set_single_job_properties(jobID, {flag: False})

def updateJobFromStager(self, jobID: str | int, status: str):
raise NotImplementedError("TODO")
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
from datetime import datetime
from functools import partial
from textwrap import dedent

import pytest

import DIRAC

DIRAC.initialize()
from DIRAC.Core.Security.DiracX import DiracXClient
from DIRAC.WorkloadManagementSystem.Client.JobStateUpdateClient import JobStateUpdateClient
from ..utils import compare_results
from ..utils import compare_results2

test_jdl = """
Arguments = "Hello world from DiracX";
Executable = "echo";
JobGroup = jobGroup;
JobName = jobName;
JobType = User;
LogLevel = INFO;
MinNumberOfProcessors = 1000;
OutputSandbox =
{
std.err,
std.out
};
Priority = 1;
Sites = ANY;
StdError = std.err;
StdOutput = std.out;
"""


@pytest.fixture()
def example_jobids():
from DIRAC.Interfaces.API.Dirac import Dirac
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise

d = Dirac()
job_id_1 = returnValueOrRaise(d.submitJob(test_jdl))
job_id_2 = returnValueOrRaise(d.submitJob(test_jdl))
return job_id_1, job_id_2


def test_sendHeartBeat(monkeypatch):
Expand All @@ -15,16 +48,22 @@ def test_sendHeartBeat(monkeypatch):
pytest.skip()


def test_setJobApplicationStatus(monkeypatch):
def test_setJobApplicationStatus(monkeypatch, example_jobids):
# JobStateUpdateClient().setJobApplicationStatus(jobID: str | int, appStatus: str, source: str = Unknown)
method = JobStateUpdateClient().setJobApplicationStatus
pytest.skip()
args = ["MyApplicationStatus"]
test_func1 = partial(method, example_jobids[0], *args)
test_func2 = partial(method, example_jobids[1], *args)
compare_results2(monkeypatch, test_func1, test_func2)


def test_setJobAttribute(monkeypatch):
@pytest.mark.parametrize("args", [["Status", "Killed"], ["JobGroup", "newJobGroup"]])
def test_setJobAttribute(monkeypatch, example_jobids, args):
# JobStateUpdateClient().setJobAttribute(jobID: str | int, attribute: str, value: str)
method = JobStateUpdateClient().setJobAttribute
pytest.skip()
test_func1 = partial(method, example_jobids[0], *args)
test_func2 = partial(method, example_jobids[1], *args)
compare_results2(monkeypatch, test_func1, test_func2)


def test_setJobFlag(monkeypatch):
Expand All @@ -45,22 +84,37 @@ def test_setJobParameters(monkeypatch):
pytest.skip()


def test_setJobSite(monkeypatch):
@pytest.mark.parametrize("jobid_type", [int, str])
def test_setJobSite(monkeypatch, example_jobids, jobid_type):
# JobStateUpdateClient().setJobSite(jobID: str | int, site: str)
method = JobStateUpdateClient().setJobSite
pytest.skip()
args = ["LCG.CERN.ch"]
test_func1 = partial(method, jobid_type(example_jobids[0]), *args)
test_func2 = partial(method, jobid_type(example_jobids[1]), *args)
compare_results2(monkeypatch, test_func1, test_func2)


def test_setJobStatus(monkeypatch):
def test_setJobStatus(monkeypatch, example_jobids):
# JobStateUpdateClient().setJobStatus(jobID: str | int, status: str = , minorStatus: str = , source: str = Unknown, datetime = None, force = False)
method = JobStateUpdateClient().setJobStatus
pytest.skip()
args = ["", "My Minor"]
test_func1 = partial(method, example_jobids[0], *args)
test_func2 = partial(method, example_jobids[1], *args)
compare_results2(monkeypatch, test_func1, test_func2)


def test_setJobStatusBulk(monkeypatch):
def test_setJobStatusBulk(monkeypatch, example_jobids):
# JobStateUpdateClient().setJobStatusBulk(jobID: str | int, statusDict: dict, force = False)
method = JobStateUpdateClient().setJobStatusBulk
pytest.skip()
args = [
{
datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "SomethingElse"},
datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "Something"},
}
]
test_func1 = partial(method, example_jobids[0], *args)
test_func2 = partial(method, example_jobids[1], *args)
compare_results2(monkeypatch, test_func1, test_func2)


def test_setJobsParameter(monkeypatch):
Expand Down
53 changes: 40 additions & 13 deletions tests/Integration/FutureClient/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,47 @@
def compare_results(test_func):
import time


def compare_results(monkeypatch, test_func):
"""Compare the results from DIRAC and DiracX based services for a reentrant function."""
ClientClass = test_func.func.__self__
assert ClientClass.diracxClient, "FutureClient is not set up!"
compare_results2(monkeypatch, test_func, test_func)


def compare_results2(monkeypatch, test_func1, test_func2):
"""Compare the results from DIRAC and DiracX based services for two functions which should behave identically."""
# Get the result from the diracx-based handler
future_result = test_func()
start = time.monotonic()
with monkeypatch.context() as m:
m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: True)
try:
future_result = test_func1()
except Exception as e:
future_result = e
else:
assert "rpcStub" not in future_result, "rpcStub should never be present when using DiracX!"
diracx_duration = time.monotonic() - start

# Get the result from the DIRAC-based handler
diracxClient = ClientClass.diracxClient
ClientClass.diracxClient = None
try:
old_result = test_func()
finally:
ClientClass.diracxClient = diracxClient
# We don't care about the rpcStub
start = time.monotonic()
with monkeypatch.context() as m:
m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: False)
old_result = test_func2()
assert "rpcStub" in old_result, "rpcStub should always be present when using legacy DIRAC!"
legacy_duration = time.monotonic() - start

# We don't care about the rpcStub or Errno
old_result.pop("rpcStub")
old_result.pop("Errno", None)

if not old_result["OK"]:
assert not future_result["OK"], "FutureClient should have failed too!"
elif "Value" in future_result:
# Ensure the results match exactly
assert old_result == future_result
else:
# See the "stripValueIfOK" decorator for explanation
assert old_result["OK"] == future_result["OK"]
# assert isinstance(old_result["Value"], int)

# Ensure the results match
assert old_result == future_result
# if 3 * legacy_duration < diracx_duration:
# print(f"Legacy DIRAC took {legacy_duration:.3f}s, FutureClient took {diracx_duration:.3f}s")
# assert False, "FutureClient should be faster than legacy DIRAC!"

0 comments on commit dfc0f3d

Please sign in to comment.