
filtering, projection, penalization, and optimization convenience wrappers for adjoint plugin #1012

Merged: 1 commit, Jul 25, 2023

Conversation

@tylerflex (Collaborator) commented Jul 17, 2023

Corresponding notebook: flexcompute-readthedocs/tidy3d-docs#237

Note: the optimizer is super basic; I'm not sure I'll even document it for now, just test it internally. Eventually we will probably want to either find a good Jax-compatible optimizer (preferred) or put a lot of effort into making our own.

Results from adding the filter and binarization to notebook 3: [two result images]

@tylerflex tylerflex marked this pull request as draft July 17, 2023 19:37
@tylerflex tylerflex force-pushed the tyler/adjoint/utils branch 5 times, most recently from 89d2dea to 6c8d387 Compare July 17, 2023 20:36
@tylerflex tylerflex marked this pull request as ready for review July 17, 2023 20:39
@momchil-flex (Collaborator)

I think having a way to schedule the filtering/projection parameters is quite important for the applications we're pursuing. What's your thinking about that?

@tylerflex (Collaborator, Author)

I think so too; that's why I mentioned not documenting this optimizer for now and adding things to it over time. It could just be a stand-in for a more comprehensive optimizer later. We should also discuss whether it even makes sense to do this stuff ourselves or just import optimization packages in our examples. I just haven't found one I really like yet.

@ianwilliamson

What is the benefit of having an object-oriented evaluate() API vs. allowing users to compose JAX operations into their own objective function? Restricting the evaluate() API to a single jnp.ndarray input and output is less flexible than the alternative.

@tylerflex (Collaborator, Author)

The main goal here is to provide a convenient way for people who are unfamiliar with or intimidated by jax and these operations to just get something running. I assumed more advanced people could probably write their own operations from scratch, but if there are some higher-level functions we could write that would be convenient, those might be nice to include as well (such as the tanh projection).

The reason I went with an OO approach to store the parameters is that in the future we might want to serialize these objects and upload them to our server along with some information about the objective function. It's a bit of a vague idea now, but having them as part of our model makes life easier in the future if we ever need to save them. Leaving everything in functional form is somewhat cleaner, but it could cause security issues for us if we ever needed to unpickle people's objective functions on our server.
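As an aside on the serialization idea: one way to keep it safe is to serialize only the parameter data (never code), e.g. a JSON round-trip of a plain data object. The `RadiusPenaltySpec` below is a hypothetical illustration of that pattern, not the actual plugin class:

```python
import json
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class RadiusPenaltySpec:
    # hypothetical parameter container (not the actual plugin class):
    # it stores only data, so no user code ever needs to be unpickled
    min_radius: float = 0.2
    wrap: bool = True

spec = RadiusPenaltySpec(min_radius=0.15)
payload = json.dumps(asdict(spec))        # safe to upload: parameters only

# the server (or a later session) rebuilds the object from data alone
restored = RadiusPenaltySpec(**json.loads(payload))
```

The evaluation code then lives entirely client-side; the server only ever sees the JSON payload.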

@ianwilliamson

The Meep adjoint module is a useful reference for different design choices regarding how to interface simulators with automatic differentiation frameworks + optimizers.

Meep has an OptimizationProblem class (https://github.com/NanoComp/meep/blob/5fdec2d3b6d3e6eb151ee4b2193532b8080014e0/python/adjoint/optimization_problem.py#L12-L23) which predates its JAX wrapper (https://github.com/NanoComp/meep/blob/5fdec2d3b6d3e6eb151ee4b2193532b8080014e0/python/adjoint/wrapper.py#L3-L48).

Meep's OptimizationProblem works well for specific open source optimization libraries and design parameterizations, but is not very flexible. In contrast, the JAX wrapper provides a low-level interface for embedding differentiable simulation calls into JAX graphs. The wrapper does not enforce any opinions on the optimization loop or the parameterization. When building an ecosystem around a simulator, there is a lot of value in having the latter. You can focus on making that entry point as efficient and as flexible as possible, without needing to support the myriad of hyperparameters and variations (e.g. scheduling) that are required to extract value from topology optimization.

@tylerflex (Collaborator, Author)

Thanks for the insight @ianwilliamson. I'll definitely spend some more time studying the meep optimization and its jax wrapper. It seems that the tidy3d adjoint plugin already works in a way that is very similar to meep's jax wrapper. We basically wrote a jax VJP for tidy3d's main simulation-running function

data = tidy3d.web.run(simulation)

that does all the adjoint stuff behind the scenes. So you basically can "backprop" through this run function and plug it into any general objective function.

def objective(params):
    sim = preprocess(params)
    data = run(sim)
    return postprocess(data) - penalties(params)

val_and_grad_fn = jax.value_and_grad(objective)

The idea here was to make it as flexible as possible and let users (or us) define their own Jax-compatible optimization, constraints, or penalties around this function. Our docs notebooks showcase many examples of this. However, the code can get quite complicated and some users just want something simple that they can call instead of having to copy and paste all of the code in the notebooks.

So this PR is just a way to capture some of that code into some utility functions that users can call. That being said, if there is some good Jax-compatible package with general tools and utilities for optimization and photonic inverse design, I'd rather just import that package in the notebook examples. The problem is that I haven't really found one that I really like. Some of the jax optimization libraries I've tried out do not provide the flexibility needed to tune parameters mid-run or grab optimization history information in a way that makes sense to me, so I'm still on the lookout for something else.

In the meantime, this PR is meant to be a stop-gap to provide some convenient wrappers for our already flexible Jax-based adjoint API but isn't meant as the final solution. If you have any suggestions for optimizers or inv-des libraries I'd be curious to hear about them! And anyway thanks for taking the time to share your thoughts on it, it's useful to get another perspective on this from someone who I assume works with this stuff a lot.

@ianwilliamson

> that does all the adjoint stuff behind the scenes. So you basically can "backprop" through this run function and plug it into any general objective function.

That's great!

> Some of the jax optimization libraries I've tried out do not provide the flexibility needed to tune parameters mid-run or grab optimization history information in a way that makes sense to me, so I'm still on the lookout for something else.

An Optax optimization loop is pretty light weight and flexible:

optimizer = optax.adam(learning_rate=1e-2)
params = ...
opt_state = optimizer.init(params)
num_steps = 1000

def loss_fn(params, step_num):
  # schedule parameter changes based on `step_num`
  loss_val = ...
  aux_data = ...
  return loss_val, aux_data

for step_num in range(num_steps):
  (loss_value, aux_data), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, step_num)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  print(f'step {step_num}, loss: {loss_value}')

The params are a PyTree instead of a single ndarray and you can append them (or the state) to an in-memory history list at each step. One can easily extend the sketch above with other Optax gradient transformations, such as clipping or decay, which is not ergonomic following the approach proposed in this PR.

@lucas-flexcompute (Collaborator) left a comment

I agree that having a straightforward interface is important to give more users access to the optimizer, especially if it doesn't prevent advanced users from implementing their own solutions.

Do you plan on expanding the available filter/optimizer/penalty classes? It just looks a bit "sparse" as is. Other than that, I only think the docstrings could be improved a little, considering that they are meant for non-specialist users. For example, it is generally helpful to have references (e.g. for the Adam algorithm) and to include a more detailed description of the functions (e.g. equations for the radius penalty) where it makes sense, similarly to what numpy and scipy do. I've often used the references in their docs to start digging into subjects that are new to me (I usually don't like to use functions blindly).

Also, the docs for RadiusPenalty mention a taper, which I guess is a leftover from a previous implementation.
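As a point of reference for such a docstring, the standard Adam update (Kingma & Ba) is compact enough to sketch in a few lines of NumPy. This is a generic illustration of the algorithm, not the plugin's actual implementation:

```python
import numpy as np

def adam_step(params, grad, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step with bias-corrected first/second moment estimates."""
    m = b1 * m + (1 - b1) * grad          # update first moment (mean of grads)
    v = b2 * v + (1 - b2) * grad ** 2     # update second moment (mean of grad^2)
    m_hat = m / (1 - b1 ** t)             # bias correction for early steps
    v_hat = v / (1 - b2 ** t)
    new_params = params - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_params, m, v

# toy usage: minimize f(x) = x^2, whose gradient is 2x
x = np.array([1.0])
m, v = np.zeros_like(x), np.zeros_like(x)
for t in range(1, 201):
    x, m, v = adam_step(x, 2 * x, m, v, t, lr=0.1)
```

Linking the docstring to the original paper and including the update equations above would cover the reference request.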

@tylerflex (Collaborator, Author)

> [Optax optimization loop quoted from the comment above]

This is smart. I've sometimes passed the step number into the objective function, but hadn't thought to just modify the objective directly as a function of step number. It's also nice that optax lets you update and pass parameters each step. This seems promising.

@tylerflex (Collaborator, Author)

Thanks @lucas-flexcompute, I will work on fleshing out the docstrings a bit more and updating the penalty one to be more generic. Just so you know, I was thinking of keeping the Adam optimization undocumented and not using it in the notebooks, just to have it as a placeholder for later.

@tylerflex (Collaborator, Author)

@lucas-flexcompute do you think it would be better to use this object-oriented design or just provide a set of functions (maybe in a single namespace)?

# option 1, OO
penalty_value = RadiusPenalty(min_radius=0.2, wrap=True).evaluate(polyslab.vertices)

# option 2, plain function
penalty_value = radius_penalty(polyslab.vertices, min_radius=0.2, wrap=True)

# option 3, static method of some "Penalty" class
penalty_value = Penalty.radius_of_curvature(polyslab.vertices, min_radius=0.2, wrap=True)

@tylerflex (Collaborator, Author) commented Jul 18, 2023

Advantages of the OO approach:

  • these variables can be stored easily if needed (e.g. adding a projector to your optimizer eventually).
  • a bit easier to use pydantic's self-documentation with pd.Field().

Advantages of the functional approach:

  • more flexible: if the user needs to differentiate w.r.t. other parameters (e.g. the beta value in the projection), it is possible, though maybe not recommended.
  • potentially simpler syntax, unless one needs to reuse these projectors and filters in many places, in which case the functional form would require a closure, e.g.

radius_penalty_fn = lambda x: radius_penalty(x, min_radius=0.2, wrap=True)

radius_penalty_fn(polyslab1.vertices)
radius_penalty_fn(polyslab2.vertices)

which could be confusing for people.
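As a side note, functools.partial is a slightly more explicit way to build that closure than a lambda. The radius_penalty below is a toy stand-in with a made-up body, not the plugin's function:

```python
import numpy as np
from functools import partial

def radius_penalty(vertices, min_radius=0.1, wrap=False):
    # toy stand-in: penalize vertices closer to the origin than min_radius
    # (wrap is accepted only to mirror the signature above; unused here)
    radii = np.linalg.norm(np.asarray(vertices, dtype=float), axis=-1)
    return float(np.sum(np.maximum(0.0, min_radius - radii)))

# bind the hyperparameters once, then reuse on many vertex sets
radius_penalty_fn = partial(radius_penalty, min_radius=0.2, wrap=True)

p1 = radius_penalty_fn([[0.05, 0.0], [1.0, 0.0]])  # one vertex inside min_radius
p2 = radius_penalty_fn([[2.0, 2.0]])               # nothing to penalize
```

Unlike a lambda, the partial object also exposes its bound keywords (`radius_penalty_fn.keywords`), which makes the configuration inspectable.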

@ianwilliamson

There may be value in making a library of topology optimization functions available to users, but if serializing the parameters of the functions is a use case, I would suggest adopting an approach which separates the data (the parameters) from the implementation (the functions).

@tylerflex (Collaborator, Author) commented Jul 18, 2023

> There may be value in making a library of topology optimization functions available to users, but if serializing the parameters of the functions is a use case, I would suggest adopting an approach which separates the data (the parameters) from the implementation (the functions).

Yeah, this was the idea behind the .evaluate interface. Basically it's just a function, but the behavior is controlled by the data stored in the instance.

radius_penalty = RadiusPenalty(min_radius=0.2, wrap=True).evaluate(polyslab.vertices)

Unless you had something different in mind?

@ianwilliamson

> Unless you had something different in mind

Pure functions. You can call these from within your evaluate() API if you still need that.

@momchil-flex (Collaborator)

There are actually no pure functions in the Tidy3D frontend, or at least predominantly so; maybe some lurk somewhere. But for the most part everything is a method of some class.

I don't think this is a detrimental design choice in any way, apart from personal preference. But that has been @tylerflex's approach.

@tylerflex (Collaborator, Author) commented Jul 18, 2023

I think maybe what Ian is getting at is to have these pure functions written somewhere and then just provide object-oriented wrappers for them if we need:

from dataclasses import dataclass

@dataclass
class MyClass:
    y: float

    @staticmethod
    def a_pure_function(x, y):
        return x + y

    def add(self, x):
        return self.a_pure_function(x, self.y)

?

@ianwilliamson

Pure functions are, generally, the convention in the JAX ecosystem. They're also what I think of as being the most convenient for building optimization objective functions. The motivation for wanting to serialize and upload objective functions to your server is not clear to me. The JAX ecosystem offers well-tested optimization libraries, which is why I suggested Optax. I would not want to be in the business of building and maintaining those myself, especially if you want to use second order optimizers 😄

@tylerflex (Collaborator, Author)

Fair points, I'll look into optax a bit tomorrow and see what it looks like to implement our notebook examples using it.
As for the purity of the penalty functions, I feel it might be a matter of API preference: our tidy3d models are immutable, and all of their methods should therefore be pure functions (no side effects).

@tylerflex (Collaborator, Author)

Yeah, I like optax; it seems super flexible. I'm slightly worried it might be a little confusing for some users who just want to do something super simple, but I think the flexibility makes it worth it.

For reference, I just modified our tutorial 3 to use optax, and here's what the optimizer looks like (with our additions for history, progress printing, etc.). I'm going to try the parameter scheduling from the grating coupler notebook next.

import optax

# hyperparameters
num_steps = 18
learning_rate = 0.2

# initialize adam optimizer with starting parameters
params = np.array(eps_boxes)
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# store history
Js = []
perms = [params]

for i in range(num_steps):

    # compute gradient and current objective function value
    value, gradient = dJ_fn(params, step_num=i+1)

    # outputs
    print(f"step = {i + 1}")
    print(f"\tJ = {value:.4e}")
    print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")    

    # compute and apply updates to the optimizer based on gradient (-1 sign to maximize obj_fn)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)
    params = optax.apply_updates(params, updates)

    # save history
    Js.append(value)
    perms.append(params)    

@tylerflex (Collaborator, Author)

@ianwilliamson do you have a recommendation for a convenient way to checkpoint optax optimizer states? I checked out orbax but am not sure it's the best solution here

@ianwilliamson

> @ianwilliamson do you have a recommendation for a convenient way to checkpoint optax optimizer states? I checked out orbax but am not sure it's the best solution here

You could pickle params and opt_state to disk.
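That pickling pattern can be sketched as follows; the `params` and `opt_state` here are stand-in objects rather than a real array and optax state, since the shape of the real objects depends on the optimizer used:

```python
import os
import pickle
import tempfile

# stand-ins for the real objects (params array / optax optimizer state)
params = [0.1, 0.2, 0.3]
opt_state = {"step": 42, "mu": [0.0, 0.0, 0.0]}

# checkpoint: dump both objects to disk together
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
with open(path, "wb") as f:
    pickle.dump({"params": params, "opt_state": opt_state}, f)

# resume: load the checkpoint and continue the optimization loop from there
with open(path, "rb") as f:
    ckpt = pickle.load(f)
params, opt_state = ckpt["params"], ckpt["opt_state"]
```

Since the checkpoint is written and read by the same trusted code, the unpickling-user-code concern raised earlier in the thread does not apply here.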

@tylerflex tylerflex added the 2.4 label Jul 21, 2023
@tylerflex tylerflex force-pushed the tyler/adjoint/utils branch 2 times, most recently from bd3c537 to 32d73f8 Compare July 21, 2023 15:46
@tylerflex (Collaborator, Author)

Pickling opt_state and params worked pretty well, thanks @ianwilliamson.

OK @lucas-flexcompute @momchil-flex, I integrated optax into all the adjoint notebooks, removed the optimizers from this PR, and implemented the new checkpointing in the grating coupler notebook, which seems a bit cleaner.

On the object-oriented vs. functional approach to providing these filtering / projection / penalty functions:

  1. The OO approach's methods are still "pure functions", so I feel it still fits the jax paradigm.
  2. I think the OO approach is more in line with how we've been doing things (e.g. the Medium nk -> eps conversions), so that would be my vote. It also makes it easier to document and validate parameters. We could always provide plain functions later if we choose to.

@tylerflex (Collaborator, Author)

@momchil-flex any remaining concerns on this PR, or should I merge? Trying to clear out some PRs before they get stale. Thanks!

@momchil-flex (Collaborator) left a comment

Looks good!

Inline review comments (resolved): tidy3d/plugins/adjoint/utils/filter.py, tidy3d/plugins/adjoint/utils/penalty.py
@tylerflex tylerflex merged commit 1f82e1a into pre/2.4 Jul 25, 2023
11 checks passed
@tylerflex tylerflex deleted the tyler/adjoint/utils branch July 25, 2023 21:15