Skip to content

Commit

Permalink
supports upload task with gz
Browse files Browse the repository at this point in the history
1
  • Loading branch information
qingeng committed Jul 26, 2023
1 parent eae94bf commit 52a0e89
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 83 deletions.
37 changes: 7 additions & 30 deletions tests/test_plugins/test_mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tidy3d.plugins.mode.solver import compute_modes
from tidy3d import ScalarFieldDataArray
from tidy3d.web.environment import Env
from tidy3d.version import __version__


WAVEGUIDE = td.Structure(geometry=td.Box(size=(100, 0.5, 0.5)), medium=td.Medium(permittivity=4.0))
Expand Down Expand Up @@ -44,7 +45,7 @@ def mock_download(task_id, remote_path, to_file, *args, **kwargs):
sources=[SRC],
)
mode_spec = td.ModeSpec(
num_modes=1,
num_modes=3,
target_neff=2.0,
filter_pol="tm",
precision="double",
Expand Down Expand Up @@ -80,7 +81,9 @@ def mock_download(task_id, remote_path, to_file, *args, **kwargs):
"projectId": PROJECT_ID,
"taskName": TASK_NAME,
"modeSolverName": MODESOLVER_NAME,
"fileType": "Json",
"fileType": "Gz",
"source": "Python",
"protocolVersion": __version__,
}
)
],
Expand All @@ -91,33 +94,7 @@ def mock_download(task_id, remote_path, to_file, *args, **kwargs):
"status": "draft",
"createdAt": "2023-05-19T16:47:57.190Z",
"charge": 0,
"fileType": "Json",
}
},
status=200,
)

responses.add(
responses.POST,
f"{Env.current.web_api_endpoint}/tidy3d/modesolver/py",
match=[
responses.matchers.json_params_matcher(
{
"projectId": PROJECT_ID,
"taskName": TASK_NAME,
"modeSolverName": MODESOLVER_NAME,
"fileType": "Hdf5",
}
)
],
json={
"data": {
"refId": TASK_ID,
"id": SOLVER_ID,
"status": "draft",
"createdAt": "2023-05-19T16:47:57.190Z",
"charge": 0,
"fileType": "Hdf5",
"fileType": "Gz",
}
},
status=200,
Expand Down Expand Up @@ -149,7 +126,7 @@ def mock_download(task_id, remote_path, to_file, *args, **kwargs):
"status": "queued",
"createdAt": "2023-05-19T16:47:57.190Z",
"charge": 0,
"fileType": "Json",
"fileType": "Gz",
}
},
status=200,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_web/test_tidy3d_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_create(set_api_key):
{
"taskName": "test task",
"callbackUrl": None,
"fileType": "Gz",
"simulationType": "tidy3d",
"parentTasks": None,
}
Expand Down Expand Up @@ -186,6 +187,7 @@ def test_submit(set_api_key):
{
"taskName": task_name,
"callbackUrl": None,
"fileType": "Gz",
"simulationType": "tidy3d",
"parentTasks": None,
}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os

import tidy3d as td
import tidy3d.web as web

from responses import matchers

Expand All @@ -16,7 +15,6 @@
from tidy3d.web.webapi import download_log, estimate_cost, get_info, get_run_info, get_tasks
from tidy3d.web.webapi import load, load_simulation, start, upload, monitor, real_cost
from tidy3d.web.container import Job, Batch
from tidy3d.web.task import TaskInfo
from tidy3d.web.asynchronous import run_async

from tidy3d.__main__ import main
Expand Down Expand Up @@ -80,6 +78,7 @@ def mock_upload(monkeypatch, set_api_key):
"callbackUrl": None,
"simulationType": "tidy3d",
"parentTasks": None,
"fileType": "Gz",
}
)
],
Expand All @@ -97,6 +96,7 @@ def mock_download(*args, **kwargs):
pass

monkeypatch.setattr("tidy3d.web.simulation_task.upload_string", mock_download)
monkeypatch.setattr("tidy3d.web.simulation_task.upload_file", mock_download)


@pytest.fixture
Expand Down
100 changes: 68 additions & 32 deletions tidy3d/plugins/mode/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
from ...components.data.monitor_data import ModeSolverData
from ...exceptions import WebError
from ...log import log
from ...web.file_util import compress_file_to_gzip, extract_gz_file
from ...web.http_management import http
from ...web.s3utils import download_file, upload_file, upload_string
from ...web.simulation_task import Folder, SIMULATION_JSON
from ...web.simulation_task import Folder, SIMULATION_JSON, SIM_FILE_HDF5_GZ
from ...web.types import ResourceLifecycle, Submittable

from .mode_solver import ModeSolver, MODE_MONITOR_NAME
from ...version import __version__

MODESOLVER_API = "tidy3d/modesolver/py"
MODESOLVER_JSON = "mode_solver.json"
MODESOLVER_HDF5 = "mode_solver.hdf5"
MODESOLVER_GZ = "mode_solver.hdf5.gz"

MODESOLVER_LOG = "output/result.log"
MODESOLVER_RESULT = "output/result.hdf5"

Expand Down Expand Up @@ -94,7 +98,7 @@ def run(
status = task.get_info().status

if status == "error":
raise WebError("Error runnig mode solver.")
raise WebError("Error running mode solver.")

log.log(log_level, f"Mode solver status: {status}")
if verbose:
Expand Down Expand Up @@ -154,6 +158,7 @@ class ModeSolverTask(ResourceLifecycle, Submittable, extra=pydantic.Extra.allow)
)

# pylint: disable=arguments-differ
# pylint: disable=protected-access
@classmethod
def create(
cls,
Expand Down Expand Up @@ -195,8 +200,10 @@ def create(
{
"projectId": folder.folder_id,
"taskName": task_name,
"protocolVersion": __version__,
"modeSolverName": mode_solver_name,
"fileType": "Hdf5" if len(mode_solver.simulation.custom_datasets) > 0 else "Json",
"fileType": "Gz",
"source": "Python",
},
)
log.info(
Expand Down Expand Up @@ -259,6 +266,7 @@ def get_info(self) -> ModeSolverTask:
resp = http.get(f"{MODESOLVER_API}/{self.task_id}/{self.solver_id}")
return ModeSolverTask(**resp, mode_solver=self.mode_solver)

# pylint: disable=protected-access
def upload(
self, verbose: bool = True, progress_callback: Callable[[float], None] = None
) -> None:
Expand All @@ -273,46 +281,60 @@ def upload(
"""
mode_solver = self.mode_solver.copy()

# Upload simulation as json for GUI display
sim = mode_solver.simulation

file, file_name = tempfile.mkstemp()
gz_file, gz_file_name = tempfile.mkstemp()
os.close(file)
os.close(gz_file)

# Upload simulation as json for download_json
upload_string(
self.task_id,
mode_solver.simulation._json_string, # pylint: disable=protected-access
sim._json_string, # pylint: disable=protected-access
SIMULATION_JSON,
verbose=verbose,
progress_callback=progress_callback,
)

if self.file_type == "Hdf5":
# Upload a single HDF5 file with the full data
file, file_name = tempfile.mkstemp()
os.close(file)
mode_solver.to_hdf5(file_name)

try:
upload_file(
self.solver_id,
file_name,
MODESOLVER_HDF5,
verbose=verbose,
progress_callback=progress_callback,
)
finally:
os.unlink(file_name)
else:
# Send only mode solver, without simulation
mode_solver_spec = mode_solver.dict()
sim.to_hdf5(file_name)
try:
# Upload simulation.hdf5.gz for GUI display
# compress .hdf5 to .hdf5.gz
compress_file_to_gzip(file_name, gz_file_name)

# Upload mode solver without simulation: 'construct' skips all validation
mode_solver_spec["simulation"] = None
mode_solver = ModeSolver.construct(**mode_solver_spec)
upload_string(
upload_file(
self.task_id,
gz_file_name,
SIM_FILE_HDF5_GZ,
verbose=verbose,
progress_callback=progress_callback,
)
finally:
os.unlink(file_name)
os.unlink(gz_file_name)

# Upload a single HDF5 file with the full data
file, file_name = tempfile.mkstemp()
gz_file, gz_file_name = tempfile.mkstemp()
os.close(file)
os.close(gz_file)
mode_solver.to_hdf5(file_name)

try:
# compress .hdf5 to .hdf5.gz
compress_file_to_gzip(file_name, gz_file_name)

upload_file(
self.solver_id,
mode_solver._json_string, # pylint: disable=protected-access
MODESOLVER_JSON,
gz_file_name,
MODESOLVER_GZ,
verbose=verbose,
progress_callback=progress_callback,
extra_arguments={"type": "ms"},
)
finally:
os.unlink(file_name)
os.unlink(gz_file_name)

# pylint: disable=arguments-differ
def submit(self):
Expand Down Expand Up @@ -362,7 +384,21 @@ def get_modesolver(
stored in the same path as 'to_file', but with '.hdf5' extension, and neither 'to_file' or
'sim_file' will be created.
"""
if self.file_type == "Hdf5":
if self.file_type == "Gz":
to_gz = pathlib.Path(to_file).with_suffix(".hdf5.gz")
to_hdf5 = pathlib.Path(to_file).with_suffix(".hdf5")
download_file(
self.task_id,
self.mode_solver_path + MODESOLVER_GZ,
to_file=to_gz,
verbose=verbose,
progress_callback=progress_callback,
)
extract_gz_file(to_gz, to_hdf5)
to_file = str(to_hdf5)
mode_solver = ModeSolver.from_hdf5(to_hdf5)

elif self.file_type == "Hdf5":
to_hdf5 = pathlib.Path(to_file).with_suffix(".hdf5")
download_file(
self.task_id,
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .cli.migrate import migrate
from .webapi import run, upload, get_info, start, monitor, delete, download, load, estimate_cost
from .webapi import get_tasks, delete_old, download_json, download_log, load_simulation, real_cost
from .webapi import get_tasks, delete_old, download_log, download_json, load_simulation, real_cost
from .container import Job, Batch, BatchData
from .cli import tidy3d_cli
from .cli.app import configure_fn as configure
Expand Down
29 changes: 29 additions & 0 deletions tidy3d/web/file_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""compress and extract file"""

import gzip


def compress_file_to_gzip(input_file, output_gz_file):
"""
Compresses a file using gzip.
Args:
input_file (str): The path of the input file.
output_gz_file (str): The path of the output gzip file.
"""
with open(input_file, "rb") as file_in:
with gzip.open(output_gz_file, "wb") as file_out:
file_out.writelines(file_in)


def extract_gz_file(input_gz_file, output_file):
"""
Extract the GZ file
Args:
input_gz_file (str): The path of the gzip input file.
output_file (str): The path of the output file.
"""
with gzip.open(input_gz_file, "rb") as f_in:
with open(output_file, "wb") as f_out:
f_out.write(f_in.read())
3 changes: 3 additions & 0 deletions tidy3d/web/s3utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ def _upload(_callback: Callable) -> None:
Key=token.get_s3_key(),
Callback=_callback,
Config=_s3_config,
ExtraArgs={"ContentEncoding": "gzip"}
if token.get_s3_key().endswith(".gz")
else None,
)

if progress_callback is not None:
Expand Down
Loading

0 comments on commit 52a0e89

Please sign in to comment.