
Fix warning capture #1136

Merged
merged 3 commits into from
Sep 6, 2023

Conversation

dbochkov-flexcompute
Contributor

@dbochkov-flexcompute dbochkov-flexcompute commented Aug 31, 2023

While trying to incorporate warning capture into the GUI, the MC team noticed that when we have validation errors, setting td.log.set_capture(True) changes the returned validation errors to something not very informative. This happens because during initialization of models we use try-finally in Tidy3dBaseModel.__init__() https://github.com/flexcompute/tidy3d/pull/1136/files#diff-d331aef8c3290adc3733d4ce375c5a4b139ad41bd5634113ca404cc827e3d949L76-R82 to try to finish warning capture no matter what. But our warning capture parsing assumes the models are well-built. So, when there are validation errors, the instances are incomplete and parsing produces its own errors that are unrelated to simulation validation.
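As a stdlib-only sketch of the failure mode described above (the names `Logger`, `begin_capture`, and the toy `Model` are illustrative, not the actual tidy3d implementation):

```python
class ValidationError(ValueError):
    pass


class Logger:
    """Toy logger that records the warning-capture lifecycle."""

    def __init__(self):
        self.events = []

    def begin_capture(self):
        self.events.append("begin")

    def end_capture(self, model):
        # In tidy3d this parses the model tree and assumes it is well-built.
        self.events.append(f"end:{type(model).__name__}")


log = Logger()


class Model:
    def __init__(self, x):
        log.begin_capture()
        try:
            if x < 0:
                raise ValidationError("x must be non-negative")
            self.x = x
        finally:
            # Runs even on validation failure, handing end_capture()
            # a half-built instance -- the source of the confusing errors.
            log.end_capture(self)


try:
    Model(-1)
except ValidationError:
    pass
print(log.events)  # ['begin', 'end:Model']
```

The `finally:` block guarantees `end_capture()` fires even for the instance that never finished validating, which is exactly the case the capture parser cannot handle.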

A seemingly straightforward way to fix this is to just remove the try-finally construction, so that we don't attempt to build the warning capture tree if there are any errors. Initially this didn't work as intended, but after I added the missing discriminator=TYPE_TAG_STR everything is fine. I believe this is because when there is no discriminator, pydantic tries to initialize with every model possible, and those that fail exit before reaching log.end_capture(self) in Tidy3dBaseModel.__init__(). Previously, finally: would ensure we still executed that.
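A stdlib-only sketch (not pydantic itself, and the shape classes are hypothetical) of why the discriminator matters here: without one, a union field is validated by attempting each member model in turn, and the attempts that fail exit partway through `__init__`:

```python
class Square:
    def __init__(self, data):
        if data.get("type") != "square":
            raise ValueError("not a square")  # exits mid-__init__
        self.side = data["side"]


class Circle:
    def __init__(self, data):
        if data.get("type") != "circle":
            raise ValueError("not a circle")
        self.radius = data["radius"]


UNION = [Square, Circle]


def validate_without_discriminator(data):
    # Try every member; most constructors raise partway through,
    # skipping any post-init bookkeeping (like end_capture()).
    for model in UNION:
        try:
            return model(data)
        except ValueError:
            continue
    raise ValueError("no member matched")


def validate_with_discriminator(data):
    # The tag selects exactly one member, so only one __init__ ever runs.
    by_tag = {"square": Square, "circle": Circle}
    return by_tag[data["type"]](data)


shape = validate_with_discriminator({"type": "circle", "radius": 2.0})
print(type(shape).__name__)  # Circle
shape2 = validate_without_discriminator({"type": "square", "side": 3.0})
print(type(shape2).__name__)  # Square
```

With the discriminated dispatch, the only `__init__` that runs is the one that can actually succeed, so removing the try-finally no longer leaves captures dangling.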

Additional improvements:

  • added log.start_capture()/log.end_capture(self) into Simulation.validate_pre_upload() to capture those warnings too
  • added an optional argument custom_loc: List = None to log.warning() so that additional information can be provided about a warning's location. This is useful for Simulation validators, which otherwise would just point at the simulation itself.
  • added an optional argument capture: bool = True to log.warning() to avoid capturing the not very informative and somewhat arbitrary frequency passed to 'Medium.eps_model()' is outside of 'Medium.frequency_range' = ...,
  • expanded the warning capture test significantly to produce and capture pretty much every possible warning. Also added checks that td.log.set_capture(True) doesn't screw anything up if there are validation errors.
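The two new log.warning() arguments in the bullets above could be pictured roughly as follows (a hedged sketch; the real tidy3d logger differs in detail):

```python
from typing import List, Optional


class Logger:
    """Stand-in logger illustrating the custom_loc and capture kwargs."""

    def __init__(self):
        self.capture_enabled = False
        self.captured = []

    def warning(self, msg: str, custom_loc: Optional[List] = None, capture: bool = True):
        print(f"WARNING: {msg}")  # python client behavior is unchanged
        if self.capture_enabled and capture:
            # custom_loc lets a Simulation validator point at the offending
            # field instead of at the simulation object itself.
            self.captured.append({"loc": custom_loc or [], "msg": msg})


log = Logger()
log.capture_enabled = True
log.warning("source bandwidth too narrow", custom_loc=["sources", 0])
log.warning("eps_model frequency out of range", capture=False)
print(len(log.captured))  # 1
```

Both warnings still print, but only the first lands in the captured list, matching the behavior described in the bullets.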

Note that none of these changes affect python client warning behavior. They only matter for the output of td.log.captured_warnings() when td.log.set_capture(True) is on.

@dbochkov-flexcompute
Contributor Author

Having issues getting this PR to pass GitHub tests. Everything seems fine on linux and macos, but it's failing on windows: https://github.com/flexcompute/tidy3d/actions/runs/6086188923. There are two issues:

  1. monitor storage size for (at least) FieldTimeMonitor seems to be different on linux vs windows. In my test simulation setup on linux I get {'flux': 44.0, 'mode': 864.0, 'time': 19296371088.0, 'n2f_monitor': 1440.0} for sim.monitor_data_sizes, but in the windows test runs it is {'flux': 44.0, 'mode': 864.0, 'time': 2116501904.0, 'n2f_monitor': 1440.0} https://github.com/flexcompute/tidy3d/actions/runs/6086188923/job/16511966069#step:5:446
  2. the second issue is not being able to get the hash of jax-related objects. Weird that it shows up only in windows tests:
          # Check the class type and its superclasses for a matching encoder
          for base in obj.__class__.__mro__[:-1]:
              try:
                  encoder = ENCODERS_BY_TYPE[base]
              except KeyError:
                  continue
              return encoder(obj)
          else:  # We have exited the for loop without finding a suitable encoder
  >           raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
  E           jax._src.traceback_util.UnfilteredStackTrace: TypeError: Object of type 'JVPTracer' is not JSON serializable

@momchil-flex @tylerflex let me know if you have any suggestions

@tylerflex
Collaborator

Hm, not sure about the monitor size thing, but in the meantime it could be possible to just relax the monitor sizes in the test a bit until we figure it out.

Regarding the jax part. So the JVPTracer is what jax uses to store the gradient information. It can't be serialized to json, unfortunately. Does this come up when trying to serialize a warning or error due to the changes in this PR? I'm having trouble pinpointing where this is happening in the code. I also have no idea why this would occur in windows and not the other OS :/

@tylerflex
Collaborator

for some reason I couldn't see the error but now I do, so yeah, I guess:

tidy3d\log.py:167: in end_capture
[1056](https://github.com/flexcompute/tidy3d/actions/runs/6086188923/job/16511966069?pr=1136#step:5:1057)
      model_fields = model.get_submodels_by_hash()

I wonder if we need to instead wrap whatever block this is in a try/except TypeError and then find an alternative way to handle unhashable objects? Or we could try a different way to get the model fields; I think model.__fields__ should work?
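A minimal sketch of that fallback idea, with a toy class standing in for the real Tidy3dBaseModel (the helper name get_submodel_fields and the ToyModel fields are hypothetical):

```python
def get_submodel_fields(model):
    """Prefer hash-based submodel lookup, falling back on declared fields."""
    try:
        return model.get_submodels_by_hash()
    except TypeError:
        # Unhashable contents (like a jax JVPTracer) land here; pydantic
        # models still expose their declared fields via __fields__.
        return {name: getattr(model, name) for name in model.__fields__}


class ToyModel:
    __fields__ = ("size", "medium")
    size = 1.0
    medium = "vacuum"

    def get_submodels_by_hash(self):
        raise TypeError("unhashable tracer in fields")


fields = get_submodel_fields(ToyModel())
print(sorted(fields))  # ['medium', 'size']
```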

@dbochkov-flexcompute
Contributor Author

Actually, after looking more carefully, .get_submodels_by_hash() shouldn't even be triggered in those test cases. It should only be invoked if log.set_capture(True) is called. I guess what happened is that because of the failed test, log.set_capture(False) was not called, so every test afterwards continued warning capture. Making the very first failing test pass (so that log.set_capture(False) is successfully called) removed the failures in those adjoint tests as well.
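One way to guard against the cross-test leakage described above is to wrap capture in a context manager (or equivalently a pytest fixture) so that set_capture(False) runs even when a test body raises. A sketch with a stand-in logger, not the real tidy3d one:

```python
from contextlib import contextmanager


class Logger:
    def __init__(self):
        self.capturing = False

    def set_capture(self, value: bool):
        self.capturing = value


log = Logger()


@contextmanager
def captured_warnings():
    log.set_capture(True)
    try:
        yield log
    finally:
        # Guarantees later tests do not inherit an active capture session.
        log.set_capture(False)


try:
    with captured_warnings():
        raise AssertionError("failing test body")
except AssertionError:
    pass
print(log.capturing)  # False
```

Even though the "test" inside the with-block fails, capture is switched off before the exception propagates, so subsequent tests see a clean logger.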

We still need to figure out why the calculated monitor size is different on different OS's, and make warning capture work with adjoint-related objects. The latter will be needed when the GUI starts integrating the adjoint feature.

For now, this PR passes tests and is ready to be reviewed.

@dbochkov-flexcompute dbochkov-flexcompute marked this pull request as ready for review September 5, 2023 16:15
Collaborator

@tylerflex tylerflex left a comment


Thanks @dbochkov-flexcompute. It looks good to me. I don't 100% follow all of the details, so feel free to get another review if you are unsure, but given that the tests are passing and the issue seems cleared up, I approve. The improvements are also quite nice, thanks 👍

@@ -89,49 +93,166 @@ def test_logging_warning_capture():
monitor_flux = td.FluxMonitor(
center=(0, 0, 0),
size=(8, 8, 8),
freqs=freqs,
freqs=list(freqs),
Collaborator


did this not work with freqs passed as a numpy array? I would have thought ArrayLike would handle this?

Contributor Author


no, I was just playing around with adding frequencies outside the source range (something like freqs=list(freqs) + [1]), and it accidentally remained like that

@momchil-flex momchil-flex merged commit 10ef9bf into pre/2.4 Sep 6, 2023
14 checks passed
@momchil-flex momchil-flex deleted the daniil/fix-warning-capture branch September 6, 2023 21:35