filtering, projection, penalization, and optimization convenience wrappers for adjoint plugin #1012
Conversation
(force-pushed from 89d2dea to 6c8d387)
I think having a way to schedule the filtering/projection parameters is quite important for the applications we're pursuing. What's your thinking about that?
I think so too; that's why I mentioned not documenting this optimizer for now and adding things to it over time. I think this could just be a stand-in for a more comprehensive optimizer later. We should also discuss whether it even makes sense to do this stuff ourselves or just import optimization packages in our examples. I just haven't found one I really like yet.
What is the benefit of having an object-oriented interface here?
The main goal here is to provide a convenient way for people who are unfamiliar with or intimidated by jax and these operations to just get something running. I assumed more advanced people could probably write their own operations from scratch, but if there are some higher-level functions we could write that would be convenient, those might be nice to include as well (such as the tanh projection, sketched below). The reason I went with an OO approach to store the parameters is that in the future we might want to serialize these objects and upload them to our server along with some information about the objective function. It's a bit of a vague idea now, but having them as part of our model makes life easier in the future if we ever need to save them. Leaving everything in functional form is somewhat cleaner, but it could cause some security issues for us if we needed to unpickle people's objective functions on our server.
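For concreteness, here is a minimal sketch of what such a serializable projection object could look like. The class name, fields, and pydantic usage here are assumptions for illustration, not the PR's actual API; the formula itself is the standard tanh projection of Wang, Lazarov & Sigmund (2011).

```python
import jax.numpy as jnp
import pydantic


class TanhProjection(pydantic.BaseModel):
    """Hypothetical serializable container for tanh projection parameters."""

    beta: float = 1.0  # projection sharpness
    eta: float = 0.5   # projection threshold

    def evaluate(self, density):
        """Project a density array in [0, 1] toward binary values."""
        num = jnp.tanh(self.beta * self.eta) + jnp.tanh(self.beta * (density - self.eta))
        den = jnp.tanh(self.beta * self.eta) + jnp.tanh(self.beta * (1.0 - self.eta))
        return num / den
```

Because the parameters live on a pydantic model, the object can be serialized like the rest of our models, while `evaluate` stays a deterministic function of its input.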
The Meep adjoint module is a useful reference for different design choices regarding how to interface simulators with automatic differentiation frameworks + optimizers. Meep has an object-oriented `OptimizationProblem` interface, and there is also a jax wrapper around Meep's adjoint solver that takes a more functional approach.
Thanks for the insight @ianwilliamson. I'll definitely spend some more time studying the Meep optimization module and its jax wrapper. It seems that the tidy3d adjoint plugin already works in a way that is very similar to Meep's jax wrapper. We basically wrote a jax VJP for tidy3d's main simulation running function `data = tidy3d.web.run(simulation)` that does all the adjoint stuff behind the scenes. So you can basically "backprop" through this `run` function and plug it into any general objective function:

```python
def objective(params):
    sim = preprocess(params)
    data = run(sim)
    return postprocess(data) - penalties(params)

grad = jax.value_and_grad(objective)
```

The idea here was to make it as flexible as possible and let users (or us) define their own jax-compatible optimization, constraints, or penalties around this function. Our docs notebooks showcase many examples of this. However, the code can get quite complicated, and some users just want something simple they can call instead of having to copy and paste all of the code in the notebooks. So this PR is just a way to capture some of that code in utility functions that users can call.

That being said, if there were a good jax-compatible package with general tools and utilities for optimization and photonic inverse design, I'd rather just import that package in the notebook examples. The problem is that I haven't found one I really like. Some of the jax optimization libraries I've tried do not provide the flexibility needed to tune parameters mid-run or grab optimization history in a way that makes sense to me, so I'm still on the lookout for something else. In the meantime, this PR is meant to be a stop-gap that provides some convenient wrappers for our already flexible jax-based adjoint API, not the final solution. If you have any suggestions for optimizers or inverse-design libraries, I'd be curious to hear about them! And anyway, thanks for taking the time to share your thoughts; it's useful to get another perspective from someone who I assume works with this stuff a lot.
That's great!

An Optax optimization loop is pretty lightweight and flexible:

```python
optimizer = optax.adam(learning_rate=1e-2)
params = ...
opt_state = optimizer.init(params)
num_steps = 1000

def loss_fn(params, step_num):
    # schedule parameter changes based on `step_num`
    loss_val = ...
    aux_data = ...
    return loss_val, aux_data

for step_num in range(num_steps):
    (loss_value, aux_data), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, step_num)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    print(f'step {step_num}, loss: {loss_value}')
```
I agree that having a straightforward interface is important to make the optimizer accessible to more users, especially if it doesn't prevent advanced users from implementing their own solutions.

Do you plan on expanding the available filter/optimizer/penalty classes? It just looks a bit "sparse" as is. Other than that, I only think the docstrings could be improved a little, considering that they are meant for non-specialist users. For example, it is generally helpful to have references (e.g. for the Adam algorithm) and to include a more detailed description of the functions (e.g. equations for the radius penalty) where it makes sense, similarly to what numpy and scipy do. I've often used the references in their docs to start digging into subjects that are new to me (I usually don't like to use functions blindly).

Also, the docs for `RadiusPenalty` mention a taper, which I guess is a leftover from a previous implementation.
The Optax loop above is smart. I've sometimes been passing the step number into the objective function, but hadn't thought to just modify the objective directly as a function of step number. Also, it's nice that Optax lets you update and pass parameters each step. This seems promising.
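As a concrete sketch of that scheduling idea (the schedule shape and the toy loss are assumptions for illustration, not from the PR), one might ramp a projection sharpness `beta` with the step number inside the objective:

```python
import jax
import jax.numpy as jnp


def beta_schedule(step_num, beta_min=1.0, beta_max=50.0, num_steps=100):
    # hypothetical linear ramp of the projection sharpness over the run
    return beta_min + (beta_max - beta_min) * step_num / max(num_steps - 1, 1)


def loss_fn(params, step_num):
    beta = beta_schedule(step_num)
    projected = jnp.tanh(beta * params)  # toy stand-in for projection + simulation
    loss_val = jnp.sum(projected ** 2)
    aux_data = {"beta": beta}
    return loss_val, aux_data


(loss_value, aux_data), grads = jax.value_and_grad(loss_fn, has_aux=True)(jnp.ones(4), 0)
```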
Thanks @lucas-flexcompute, I will work on fleshing out the docstrings a bit more and updating the penalty one to be more generic. Just so you know, I was thinking of keeping the Adam optimization undocumented and not using it in the notebooks, just to have it as a placeholder for later.
@lucas-flexcompute do you think it would be better to use this object-oriented design or just provide a set of functions (maybe in a single namespace)?

```python
# option 1, OO
radius_penalty = RadiusPenalty(min_radius=0.2, wrap=True).evaluate(polyslab.vertices)

# option 2, only function
radius_penalty = radius_penalty(polyslab.vertices, min_radius=0.2, wrap=True)

# option 3, static method of some "penalty" class
radius_penalty = Penalty.radius_of_curvature(polyslab.vertices, min_radius=0.2, wrap=True)
```
Advantage of the OO approach: the parameters are stored on the object, so they can be validated and serialized along with the rest of our models.

Advantage of the functional approach: it's cleaner and more in line with jax conventions. On the other hand, reusing the same parameters across several geometries then requires something like

```python
radius_penalty_fn = lambda x: radius_penalty(x, min_radius=0.2, wrap=True)
radius_penalty_fn(polyslab1.vertices)
radius_penalty_fn(polyslab2.vertices)
```

which could be confusing for people.
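For what it's worth, `functools.partial` is the more idiomatic way to express that reuse (using the hypothetical `radius_penalty` function from option 2 above), though it raises the same discoverability concern:

```python
from functools import partial

# equivalent to the lambda above: freeze the penalty parameters once, reuse everywhere
radius_penalty_fn = partial(radius_penalty, min_radius=0.2, wrap=True)
radius_penalty_fn(polyslab1.vertices)
radius_penalty_fn(polyslab2.vertices)
```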
There may be value in making a library of topology optimization functions available to users, but if serializing the parameters of the functions is a use case, I would suggest adopting an approach that separates the data (the parameters) from the implementation (the functions).
Yea, this was the idea behind the

```python
radius_penalty = RadiusPenalty(min_radius=0.2, wrap=True).evaluate(polyslab.vertices)
```

pattern: the data lives on the object, and `evaluate` does the work. Unless you had something different in mind?
Pure functions. You can call these from within your objective function.
There are actually almost no pure functions in the Tidy3D frontend; maybe some lurk somewhere, but for the most part everything is a method of some class. I don't think this is a detrimental design choice in any way, apart from personal preference? But that has been @tylerflex's approach.
I think maybe what Ian is getting at is to have these pure functions written somewhere and then just provide the object-oriented wrappers for them if we need. Something like this?

```python
from dataclasses import dataclass


@dataclass  # dataclass so the wrapper can actually store y
class MyClass:
    y: float

    @staticmethod
    def a_pure_function(x, y):
        return x + y

    def add(self, x):
        return self.a_pure_function(x, self.y)
```
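A quick usage sketch to make the split concrete (names are from the toy example above):

```python
obj = MyClass(y=1.0)
obj.add(2.0)                       # OO wrapper: uses the stored y  -> 3.0
MyClass.a_pure_function(2.0, 1.0)  # pure function: usable directly in jax code -> 3.0
```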
Pure functions are, generally, the convention in the JAX ecosystem. They're also what I think of as being the most convenient for building optimization objective functions. The motivation for wanting to serialize and upload objective functions to your server is not clear to me. The JAX ecosystem offers well-tested optimization libraries, which is why I suggested Optax. I would not want to be in the business of building and maintaining those myself, especially if you want to use second-order optimizers 😄
Fair points, I'll look into optax a bit tomorrow and see how it looks to implement our notebook examples using it.
Yea, I like optax; it seems super flexible. I'm slightly worried it might be a little confusing for some users who just want to do something super simple, but I think the flexibility makes it worth it.

For reference, I just modified our tutorial 3 to use optax, and here's what the optimizer looks like (with our additions for history, printing progress, etc.). I'm going to try the parameter scheduling from the grating coupler notebook next.

```python
import optax

# hyperparameters
num_steps = 18
learning_rate = 0.2

# initialize adam optimizer with starting parameters
params = np.array(eps_boxes)
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
perms = [params]

for i in range(num_steps):

    # compute gradient and current objective function value
    value, gradient = dJ_fn(params, step_num=i + 1)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # save history
    Js.append(value)
    perms.append(params)
```
@ianwilliamson do you have a recommendation for a convenient way to checkpoint optax optimizer states? I checked out orbax but am not sure it's the best solution here |
You could pickle `opt_state` and `params`.
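A minimal sketch of that suggestion (the file name and dictionary layout are arbitrary choices, not from the thread), reusing `params`, `opt_state`, and the loop counter `i` from the optimization loop above:

```python
import pickle

# save a checkpoint of the current optimization state
with open("checkpoint.pkl", "wb") as f:
    pickle.dump({"params": params, "opt_state": opt_state, "step": i}, f)

# ... later, resume from it
with open("checkpoint.pkl", "rb") as f:
    ckpt = pickle.load(f)
params, opt_state = ckpt["params"], ckpt["opt_state"]
```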
(force-pushed from bd3c537 to 32d73f8)
Pickling `opt_state` and `params` worked pretty well, thanks @ianwilliamson.

Ok @lucas-flexcompute @momchil-flex, I integrated optax into all adjoint notebooks, removed the optimizers from this PR, and implemented the new checkpointing in the grating coupler notebook, which seems a bit cleaner. On the object-oriented vs functional approach to providing these filtering / projection / penalty functions:
@momchil-flex any remaining concerns on this PR, or should I merge? Trying to clear out some PRs before they get stale. Thanks!
Looks good!
(force-pushed from 32d73f8 to 0d00ffc)
Corresponding notebook: flexcompute-readthedocs/tidy3d-docs#237
Note: the optimizer is super basic; I'm not sure I will even document it for now, just testing it internally. Eventually we will probably want to either find a good jax-compatible optimizer (preferred) or put a lot of effort into making our own.
Results from adding the filter and binarization to notebook 3.