# Meent Tutorial 2
Gradient and Optimization

In [4]:
import jax
import optax
import time

import jax.numpy as jnp
import matplotlib.pyplot as plt

import meent
from meent.on_jax.optimizer.loss import LossDeflector
from meent.on_jax.optimizer.optimizer import OptimizerJax

In [46]:
backend = 1  # JAX

# common
pol = 0  # 0: TE, 1: TM

n_I = 1  # n_incidence
n_II = 1  # n_transmission

theta = 20 * jnp.pi / 180
phi = 50 * jnp.pi / 180

wavelength = 900

thickness = [500.]
period = [1000., 300.]

fourier_order = [4, 2]

type_complex = jnp.complex128

grating_type = 2

In [47]:
ucell_1d_s = jnp.array([
    [
        [0, 1, 0, 1, 1.1, 0, 1, 0, 1, 1, ],
    ],
], dtype=jnp.float64) * 4. + 1.  # refractive index

ucell_2d_s = jnp.array([
    [
        [0, 1, 0, 1, 1, 0, 1, 0, 1, 1, ],
        [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, ],
    ],
]) * 4 + 1.  # refractive index

## 2.1 Gradient

In [48]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_2d_s, thickness=thickness, type_complex=type_complex, fft_type=0, improve_dft=True)

In [49]:
pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print(1, grad)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2)
mee.fit(pois, forward, loss_fn, optimizer)
print(3, mee.thickness*1E5)


1 (Array(0.49917828, dtype=float64), {'thickness': [Array(0.02179398, dtype=float64, weak_type=True)], 'ucell': Array([[[-0.04457337, -0.0672744 ,  0.25881959,  0.22487314,
          0.22402433,  0.24415874, -0.01060622,  0.18296631,
          0.09376066,  0.13521143],
        [-0.18894426, -0.16049066,  0.15925966,  0.94591111,
          0.98177941,  0.15896715, -0.07167921,  0.11117953,
          0.0299822 , -0.05946715]]], dtype=float64)})


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Tracer from a higher level: Traced<ShapedArray(float64[], weak_type=True)>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float64[], weak_type=True)>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float64[], weak_type=True), None)
    recipe = LambdaBinding() in trace JVPTrace(level=2/1)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [48]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_2d_s, thickness=thickness, type_complex=type_complex, fft_type=0, improve_dft=True)

In [49]:
pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print(1, grad)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2)
mee.fit(pois, forward, loss_fn, optimizer)
print(3, mee.thickness*1E5)


1 (Array(0.49917828, dtype=float64), {'thickness': [Array(0.02179398, dtype=float64, weak_type=True)], 'ucell': Array([[[-0.04457337, -0.0672744 ,  0.25881959,  0.22487314,
          0.22402433,  0.24415874, -0.01060622,  0.18296631,
          0.09376066,  0.13521143],
        [-0.18894426, -0.16049066,  0.15925966,  0.94591111,
          0.98177941,  0.15896715, -0.07167921,  0.11117953,
          0.0299822 , -0.05946715]]], dtype=float64)})


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Tracer from a higher level: Traced<ShapedArray(float64[], weak_type=True)>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float64[], weak_type=True)>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float64[], weak_type=True), None)
    recipe = LambdaBinding() in trace JVPTrace(level=2/1)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [48]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_2d_s, thickness=thickness, type_complex=type_complex, fft_type=0, improve_dft=True)

In [49]:
pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print(1, grad)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2)
mee.fit(pois, forward, loss_fn, optimizer)
print(3, mee.thickness*1E5)


1 (Array(0.49917828, dtype=float64), {'thickness': [Array(0.02179398, dtype=float64, weak_type=True)], 'ucell': Array([[[-0.04457337, -0.0672744 ,  0.25881959,  0.22487314,
          0.22402433,  0.24415874, -0.01060622,  0.18296631,
          0.09376066,  0.13521143],
        [-0.18894426, -0.16049066,  0.15925966,  0.94591111,
          0.98177941,  0.15896715, -0.07167921,  0.11117953,
          0.0299822 , -0.05946715]]], dtype=float64)})


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Tracer from a higher level: Traced<ShapedArray(float64[], weak_type=True)>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float64[], weak_type=True)>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float64[], weak_type=True), None)
    recipe = LambdaBinding() in trace JVPTrace(level=2/1)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError