Skip to content

Commit

Permalink
adapt to AWS lambda environment
Browse files Browse the repository at this point in the history
  • Loading branch information
qingeng committed Aug 26, 2023
1 parent 901a813 commit a056231
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 308 deletions.
13 changes: 10 additions & 3 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Tests webapi and things that depend on it
import json

import pytest
import responses
from _pytest import monkeypatch
from tidy3d.plugins.mode import ModeSolver

from tests.test_plugins.test_mode_solver import SIM_SIZE, WAVEGUIDE, SRC, PLANE

import tidy3d as td

Expand Down Expand Up @@ -303,11 +308,13 @@ def test_download_json(monkeypatch, mock_get_info, tmp_path):
sim = make_sim()

def mock_download(*args, **kwargs):
file_path = "simulation.hdf5"
sim.to_file(file_path)
compress_file_to_gzip(file_path, "simulation.hdf5.gz")
pass

def get_str(*args, **kwargs):
return sim.json().encode("utf-8")

monkeypatch.setattr("tidy3d.web.simulation_task.download_file", mock_download)
monkeypatch.setattr("tidy3d.web.simulation_task.read_simulation_from_hdf5", get_str)

fname_tmp = str(tmp_path / "web_test_tmp.json")
download_json(TASK_ID, fname_tmp)
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/plugins/dispersion/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ...components.medium import PoleResidue
from ...constants import MICROMETER, HERTZ
from ...exceptions import WebError, Tidy3dError, SetupError
from ...web.httputils import get_headers
from ...web.http_management import get_headers
from ...web.environment import Env

from .fit import DispersionFitter
Expand Down
81 changes: 49 additions & 32 deletions tidy3d/plugins/mode/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def get(
cls,
task_id: str,
solver_id: str,
to_file: str = "mode_solver.json",
sim_file: str = "simulation.json",
to_file: str = "mode_solver.hdf5",
sim_file: str = "simulation.hdf5",
verbose: bool = True,
progress_callback: Callable[[float], None] = None,
) -> ModeSolverTask:
Expand All @@ -225,9 +225,9 @@ def get(
Unique identifier of the task on server.
solver_id: str
Unique identifier of the mode solver in the task.
to_file: str = "mode_solver.json"
to_file: str = "mode_solver.hdf5"
File to store the mode solver downloaded from the task.
sim_file: str = "simulation.json"
sim_file: str = "simulation.hdf5"
File to store the simulation downloaded from the task.
verbose: bool = True
Whether to display progress bars.
Expand Down Expand Up @@ -339,18 +339,18 @@ def abort(self):

def get_modesolver(
self,
to_file: str = "mode_solver.json",
sim_file: str = "simulation.json",
to_file: str = "mode_solver.hdf5",
sim_file: str = "simulation.hdf5",
verbose: bool = True,
progress_callback: Callable[[float], None] = None,
) -> ModeSolver:
"""Get mode solver associated with this task from the server.
Parameters
----------
to_file: str = "mode_solver.json"
to_file: str = "mode_solver.hdf5"
File to store the mode solver downloaded from the task.
sim_file: str = "simulation.json"
sim_file: str = "simulation.hdf5"
File to store the simulation downloaded from the task.
verbose: bool = True
Whether to display progress bars.
Expand All @@ -369,31 +369,48 @@ def get_modesolver(
'sim_file' will be created.
"""
if self.file_type == "Gz":
to_gz = pathlib.Path(to_file).with_suffix(".hdf5.gz")
to_hdf5 = pathlib.Path(to_file).with_suffix(".hdf5")
download_file(
self.task_id,
MODESOLVER_GZ,
to_file=to_gz,
verbose=verbose,
progress_callback=progress_callback,
)
extract_gz_file(to_gz, to_hdf5)
to_file = str(to_hdf5)
mode_solver = ModeSolver.from_hdf5(to_hdf5)
hdf5_gz_file, hdf5_gz_file_path = tempfile.mkstemp()
os.close(hdf5_gz_file)
# keep hdf5_file_path
hdf5_file, hdf5_file_path = tempfile.mkstemp()
os.close(hdf5_file)
try:
download_file(
self.solver_id,
MODESOLVER_GZ,
to_file=hdf5_gz_file_path,
verbose=verbose,
progress_callback=progress_callback,
)
extract_gz_file(hdf5_gz_file_path, hdf5_file_path)
mode_solver = ModeSolver.from_hdf5(hdf5_file_path)
if to_file.endswith(".json"):
mode_solver.to_json(to_file)
if os.path.exists(hdf5_file_path):
os.remove(hdf5_file_path)
finally:
os.unlink(hdf5_gz_file_path)
os.unlink(hdf5_file_path)

elif self.file_type == "Hdf5":
to_hdf5 = pathlib.Path(to_file).with_suffix(".hdf5")
download_file(
self.solver_id,
MODESOLVER_HDF5,
to_file=to_hdf5,
verbose=verbose,
progress_callback=progress_callback,
)

to_file = str(to_hdf5)
mode_solver = ModeSolver.from_hdf5(to_hdf5)
hdf5_file, hdf5_file_path = tempfile.mkstemp()
os.close(hdf5_file)
try:
download_file(
self.solver_id,
MODESOLVER_HDF5,
to_file=hdf5_file_path,
verbose=verbose,
progress_callback=progress_callback,
)
mode_solver = ModeSolver.from_hdf5(hdf5_file_path)
if to_file.endswith(".json"):
mode_solver.to_json(to_file)
if os.path.exists(hdf5_file_path):
os.remove(hdf5_file_path)

finally:
os.unlink(hdf5_file_path)

else:
download_file(
Expand All @@ -410,14 +427,14 @@ def get_modesolver(
verbose=verbose,
progress_callback=progress_callback,
)

mode_solver_dict = ModeSolver.dict_from_json(to_file)
mode_solver_dict["simulation"] = Simulation.from_json(sim_file)

mode_solver = ModeSolver.parse_obj(mode_solver_dict)

# Overwrite downloaded file with valid contents
mode_solver.to_file(to_file)

return mode_solver

def get_result(
Expand Down
5 changes: 5 additions & 0 deletions tidy3d/web/cli/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Constants for the CLI."""
import os
from os.path import expanduser

TIDY3D_DIR = f"{expanduser('~')}/.tidy3d"

if not os.access(TIDY3D_DIR, os.W_OK):
TIDY3D_DIR = "/tmp/.tidy3d"

CONFIG_FILE = TIDY3D_DIR + "/config"
CREDENTIAL_FILE = TIDY3D_DIR + "/auth.json"
21 changes: 18 additions & 3 deletions tidy3d/web/http_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import os
from functools import wraps
from os.path import expanduser
from enum import Enum
from typing import Dict

import requests
import toml
from tidy3d.web.cli.constants import CONFIG_FILE

from .environment import Env
from ..exceptions import WebError
Expand All @@ -28,8 +29,8 @@ def api_key() -> None:

if os.environ.get(SIMCLOUD_APIKEY):
return os.environ.get(SIMCLOUD_APIKEY)
if os.path.exists(f"{expanduser('~')}/.tidy3d/config"):
with open(f"{expanduser('~')}/.tidy3d/config", encoding="utf-8") as config_file:
if os.path.exists(CONFIG_FILE):
with open(CONFIG_FILE, encoding="utf-8") as config_file:
config = toml.loads(config_file.read())
return config.get("apikey", "")

Expand Down Expand Up @@ -65,6 +66,20 @@ def api_key_auth(request: requests.request) -> requests.request:
return request


def get_headers() -> Dict[str, str]:
"""get headers for http request.
Returns
-------
Dict[str, str]
dictionary with "Authorization" and "Application" keys.
"""
return {
"simcloud-api-key": api_key(),
"Application": "TIDY3D",
}


def http_interceptor(func):
"""Intercept the response and raise an exception if the status code is not 200."""

Expand Down
Loading

0 comments on commit a056231

Please sign in to comment.