Skip to content

Commit

Permalink
bugfix: add task_id to JaxSimulationData
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed May 3, 2024
1 parent a2f7e83 commit b41c7aa
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Bug in plotting and computing tilted plane intersections of transformed 0 thickness geometries.
- `Simulation.to_gdspy()` and `Simulation.to_gdstk()` now place polygons in GDS layer `(0, 0)` when no `gds_layer_dtype_map` is provided instead of erroring.
- `task_id` now properly stored in `JaxSimulationData`.

## [2.7.0rc1] - 2024-04-22

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
14 changes: 11 additions & 3 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def run(
callback_url=callback_url,
verbose=verbose,
)
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data, jax_info)


Expand Down Expand Up @@ -151,7 +152,9 @@ def run_fwd(
)

res = RunResidual(fwd_task_id=task_id)
jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
jax_sim_data_orig = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)

return jax_sim_data_orig, (res,)

Expand Down Expand Up @@ -410,6 +413,7 @@ def run_async(
task_name = str(_task_name_orig(i))
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down Expand Up @@ -450,8 +454,10 @@ def run_async_fwd(
batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()]

jax_batch_data_orig = []
for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig):
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
for sim_data_orig, jax_info_orig, task_id in zip(batch_data_orig, jax_infos_orig, fwd_task_ids):
jax_sim_data = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)
jax_batch_data_orig.append(jax_sim_data)

residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
Expand Down Expand Up @@ -626,6 +632,7 @@ def run_local(
)

# convert back to jax type and return
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)


Expand Down Expand Up @@ -779,6 +786,7 @@ def run_async_local(
task_name = _task_name_orig_local(i, task_name_suffix)
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down

0 comments on commit b41c7aa

Please sign in to comment.