Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: add task_id to JaxSimulationData #1674

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Bug in plotting and computing tilted plane intersections of transformed 0 thickness geometries.
- `Simulation.to_gdspy()` and `Simulation.to_gdstk()` now place polygons in GDS layer `(0, 0)` when no `gds_layer_dtype_map` is provided instead of erroring.
- `task_id` now properly stored in `JaxSimulationData`.

## [2.7.0rc1] - 2024-04-22

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
14 changes: 11 additions & 3 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def run(
callback_url=callback_url,
verbose=verbose,
)
# TODO: add task_id
yaugenst-flex marked this conversation as resolved.
Show resolved Hide resolved
return JaxSimulationData.from_sim_data(sim_data, jax_info)


Expand Down Expand Up @@ -151,7 +152,9 @@ def run_fwd(
)

res = RunResidual(fwd_task_id=task_id)
jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
jax_sim_data_orig = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)

return jax_sim_data_orig, (res,)

Expand Down Expand Up @@ -410,6 +413,7 @@ def run_async(
task_name = str(_task_name_orig(i))
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down Expand Up @@ -450,8 +454,10 @@ def run_async_fwd(
batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()]

jax_batch_data_orig = []
for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig):
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
for sim_data_orig, jax_info_orig, task_id in zip(batch_data_orig, jax_infos_orig, fwd_task_ids):
jax_sim_data = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)
jax_batch_data_orig.append(jax_sim_data)

residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
Expand Down Expand Up @@ -626,6 +632,7 @@ def run_local(
)

# convert back to jax type and return
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)


Expand Down Expand Up @@ -779,6 +786,7 @@ def run_async_local(
task_name = _task_name_orig_local(i, task_name_suffix)
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down
Loading