Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding wavelength-dependent parameter optimization example, adding models #1

Closed
wants to merge 0 commits into from

Conversation

simbilod
Copy link
Contributor

@simbilod simbilod commented Feb 2, 2021

I added an example using S-parameters for thin-film propagation. I then put the models in a .py file, and changed the folder structure to look like PhotonTorch (models folder, containing different python files for different model categories).

More interestingly, the example shows how to define wavelength-dependent parameters and optimize them all independently using the current interface. This could be built-in the source code to make it more wieldy. In the meantine, maybe this can be useful to someone.

@flaport
Copy link
Owner

flaport commented Feb 2, 2021

Thanks Simon!

The optimization of the thin films looks really interesting! I'll happily accept your contribution!

Before I merge this in, however, would you be able to remove the code output from the jupyter cells from the git-history? I'd like to keep the repo as light as possible... Probably the easiest way to do this is to reset to my latest commit, remove the code output from both notebooks and force push your changes to your master branch (be aware that this changes your git history, maybe keep a backup branch of your current master branch...)

As a final comment, here is a little hint to make the jitting faster for the final loss function you're using in the thin film notebook. At least on my laptop with only a CPU, the jitting seems to be orders of magnitudes faster:

    def inner_loop(transmitted, i):
        params = sax.copy_params(fabry_perot_tunable["default_params"])
        params = sax.set_global_params(params, wl=wls[i])
        params = sax.set_global_params(params, t_amp=ts[i])
        params = sax.set_global_params(params, t_ang=ts[N+i])
        params["gap"]["ni"] = 1.
        params["gap"]["di"] = 1000.
        # Perform computation
        transmission_i = fabry_perot_tunable["in","out"](params)
        transmitted = jax.ops.index_update(transmitted, jax.ops.index[i], jnp.abs(transmission_i)**2)
        return transmitted, i

    transmitted, _ = jax.lax.scan(inner_loop, transmitted, jnp.arange(N, dtype=jnp.int32))

@simbilod simbilod closed this Feb 3, 2021
simbilod added a commit to simbilod/sax that referenced this pull request Feb 3, 2021
…orporated lax loop function. Also cosmetic changes.
flaport added a commit that referenced this pull request Feb 3, 2021
Updated #1 with clean history: scrubbed notebook outputs, incorporated lax looping…
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants