Skip to content

Commit

Permalink
Test refactor
Browse files Browse the repository at this point in the history
This removes some redundant functions and simplifies
the mocking as some of it was outdated.
  • Loading branch information
oyvindeide committed Sep 26, 2023
1 parent 36da8dc commit c1ac285
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 203 deletions.
5 changes: 0 additions & 5 deletions tests/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
import stat
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import Mock

Expand Down Expand Up @@ -106,10 +105,6 @@ def _make_ensemble_builder(monkeypatch, tmpdir, num_reals, num_jobs, job_sleep=0
)
)

@dataclass
class RunArg:
iens: int

for iens in range(0, num_reals):
run_path = Path(tmpdir / f"real_{iens}")
os.mkdir(run_path)
Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/job_queue/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import stat
from pathlib import Path
from unittest.mock import MagicMock

import pytest

import ert
from ert.load_status import LoadStatus


@pytest.fixture
def mock_fm_ok(monkeypatch):
fm_ok = MagicMock(return_value=(LoadStatus.LOAD_SUCCESSFUL, ""))
monkeypatch.setattr(ert.job_queue.job_queue_node, "forward_model_ok", fm_ok)
yield fm_ok


@pytest.fixture
def simple_script(tmp_path):
SIMPLE_SCRIPT = """#!/bin/sh
echo "finished successfully" > STATUS
"""
fout = Path(tmp_path / "job_script")
fout.write_text(SIMPLE_SCRIPT, encoding="utf-8")
fout.chmod(stat.S_IRWXU | stat.S_IRWXO | stat.S_IRWXG)
yield str(fout)


@pytest.fixture
def failing_script(tmp_path):
"""
This script is susceptible to race conditions. Python works
better than sh."""
FAILING_SCRIPT = """#!/usr/bin/env python
import sys
with open("one_byte_pr_invocation", "a") as f:
f.write(".")
sys.exit(1)
"""
fout = Path(tmp_path / "failing_script")
fout.write_text(FAILING_SCRIPT, encoding="utf-8")
fout.chmod(stat.S_IRWXU | stat.S_IRWXO | stat.S_IRWXG)
yield str(fout)
85 changes: 28 additions & 57 deletions tests/unit_tests/job_queue/test_job_queue.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import json
import stat
import time
from dataclasses import dataclass
from pathlib import Path
from threading import BoundedSemaphore
from typing import Any, Callable, Dict, Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

import ert.callbacks
from ert.config import QueueSystem
from ert.job_queue import Driver, JobQueue, JobQueueNode, JobStatus
from ert.load_status import LoadStatus


def wait_for(
Expand All @@ -28,71 +27,46 @@ def wait_for(
)


def dummy_exit_callback(*args):
print(args)


DUMMY_CONFIG: Dict[str, Any] = {
"job_script": "job_script.py",
"num_cpu": 1,
"job_name": "dummy_job_{}",
"run_path": "dummy_path_{}",
"ok_callback": lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""),
"exit_callback": dummy_exit_callback,
}

SIMPLE_SCRIPT = """#!/usr/bin/env python
print('hello')
"""

NEVER_ENDING_SCRIPT = """#!/usr/bin/env python
@pytest.fixture
def never_ending_script(tmp_path):
NEVER_ENDING_SCRIPT = """#!/usr/bin/env python
import time
while True:
time.sleep(0.5)
"""

FAILING_SCRIPT = """#!/usr/bin/env python
import sys
sys.exit(1)
"""


@dataclass
class RunArg:
iens: int
"""
fout = Path(tmp_path / "never_ending_job_script")
fout.write_text(NEVER_ENDING_SCRIPT, encoding="utf-8")
fout.chmod(stat.S_IRWXU | stat.S_IRWXO | stat.S_IRWXG)
yield str(fout)


def create_local_queue(
monkeypatch,
executable_script: str,
max_submit: int = 1,
max_runtime: Optional[int] = None,
callback_timeout: Optional["Callable[[int], None]"] = None,
):
monkeypatch.setattr(
ert.job_queue.job_queue_node, "forward_model_ok", DUMMY_CONFIG["ok_callback"]
)
monkeypatch.setattr(
JobQueueNode, "run_exit_callback", DUMMY_CONFIG["exit_callback"]
)

driver = Driver(driver_type=QueueSystem.LOCAL)
job_queue = JobQueue(driver, max_submit=max_submit)

scriptpath = Path(DUMMY_CONFIG["job_script"])
scriptpath.write_text(executable_script, encoding="utf-8")
scriptpath.chmod(stat.S_IRWXU | stat.S_IRWXO | stat.S_IRWXG)

for iens in range(10):
Path(DUMMY_CONFIG["run_path"].format(iens)).mkdir(exist_ok=False)
job = JobQueueNode(
job_script=DUMMY_CONFIG["job_script"],
job_script=executable_script,
job_name=DUMMY_CONFIG["job_name"].format(iens),
run_path=DUMMY_CONFIG["run_path"].format(iens),
num_cpu=DUMMY_CONFIG["num_cpu"],
status_file=job_queue.status_file,
exit_file=job_queue.exit_file,
run_arg=RunArg(iens),
run_arg=MagicMock(),
max_runtime=max_runtime,
callback_timeout=callback_timeout,
)
Expand All @@ -109,9 +83,9 @@ def start_all(job_queue, sema_pool):
job = job_queue.fetch_next_waiting()


def test_kill_jobs(tmpdir, monkeypatch):
def test_kill_jobs(tmpdir, monkeypatch, never_ending_script):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(monkeypatch, NEVER_ENDING_SCRIPT)
job_queue = create_local_queue(never_ending_script)

assert job_queue.queue_size == 10
assert job_queue.is_active()
Expand Down Expand Up @@ -140,9 +114,9 @@ def test_kill_jobs(tmpdir, monkeypatch):
job.wait_for()


def test_add_jobs(tmpdir, monkeypatch):
def test_add_jobs(tmpdir, monkeypatch, simple_script):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT)
job_queue = create_local_queue(simple_script)

assert job_queue.queue_size == 10
assert job_queue.is_active()
Expand All @@ -160,9 +134,9 @@ def test_add_jobs(tmpdir, monkeypatch):
job.wait_for()


def test_failing_jobs(tmpdir, monkeypatch):
def test_failing_jobs(tmpdir, monkeypatch, failing_script):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(monkeypatch, FAILING_SCRIPT, max_submit=1)
job_queue = create_local_queue(failing_script, max_submit=1)

assert job_queue.queue_size == 10
assert job_queue.is_active()
Expand All @@ -186,20 +160,17 @@ def test_failing_jobs(tmpdir, monkeypatch):
assert job_queue.snapshot()[iens] == str(JobStatus.FAILED)


def test_timeout_jobs(tmpdir, monkeypatch):
def test_timeout_jobs(tmpdir, monkeypatch, never_ending_script):
monkeypatch.chdir(tmpdir)
job_numbers = set()

def callback(iens):
nonlocal job_numbers
job_numbers.add(iens)
mock_callback = MagicMock()

job_queue = create_local_queue(
monkeypatch,
NEVER_ENDING_SCRIPT,
never_ending_script,
max_submit=1,
max_runtime=5,
callback_timeout=callback,
callback_timeout=mock_callback,
)

assert job_queue.queue_size == 10
Expand All @@ -222,15 +193,15 @@ def callback(iens):
iens = job_queue._differ.qindex_to_iens(q_index)
assert job_queue.snapshot()[iens] == str(JobStatus.IS_KILLED)

assert job_numbers == set(range(10))
assert len(mock_callback.mock_calls) == 20

for job in job_queue.job_list:
job.wait_for()


def test_add_dispatch_info(tmpdir, monkeypatch):
def test_add_dispatch_info(tmpdir, monkeypatch, simple_script):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT)
job_queue = create_local_queue(simple_script)
ens_id = "some_id"
cert = "My very nice cert"
token = "my_super_secret_token"
Expand Down Expand Up @@ -259,9 +230,9 @@ def test_add_dispatch_info(tmpdir, monkeypatch):
assert (runpath / cert_file).read_text(encoding="utf-8") == cert


def test_add_dispatch_info_cert_none(tmpdir, monkeypatch):
def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT)
job_queue = create_local_queue(simple_script)
ens_id = "some_id"
dispatch_url = "wss://example.org"
cert = None
Expand Down
Loading

0 comments on commit c1ac285

Please sign in to comment.