# Meent Tutorial 2
## Gradient and Optimization with [JAX](https://jax.readthedocs.io) and [Optax](https://optax.readthedocs.io/)

In [1]:
import jax
import optax

import jax.numpy as jnp

import meent
from meent.on_jax.optimizer.loss import LossDeflector

In [2]:
backend = 1  # JAX

# common
pol = 0  # 0: TE, 1: TM

n_I = 1  # n_incidence
n_II = 1  # n_transmission

theta = 0 * jnp.pi / 180  # angle of incidence
phi = 0 * jnp.pi / 180  # angle of rotation

wavelength = 900

thickness = [500., 1000.]  # thickness of each layer, from top to bottom.
period = [1000.]  # length of the unit cell. Here it's 1D.

fourier_order = [10]

type_complex = jnp.complex128
jax.config.update('jax_enable_x64', True)

grating_type = 0  # grating type: 0 for 1D grating without rotation (phi == 0)

In [3]:
ucell_1d_m = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
    ]) * 4. + 1.  # refractive index

## 2.1 Gradient

Gradients can be computed with `jax.value_and_grad`, which returns both the value of a function and its gradient.
For further information, see [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html).
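
As a minimal sketch of how `jax.value_and_grad` behaves, the snippet below differentiates a toy scalar loss with respect to a dictionary of parameters. The function `toy_loss` and the parameter values are hypothetical stand-ins chosen for illustration; they are not part of meent and do not run an RCWA simulation.

```python
import jax
import jax.numpy as jnp


def toy_loss(params):
    # Hypothetical stand-in for "simulate, then compute a scalar loss".
    ucell = params['ucell']
    thickness = params['thickness']
    return jnp.sum(ucell ** 2) + jnp.sum(jnp.sin(thickness))


params = {
    'ucell': jnp.array([1.0, 2.0, 3.0]),
    'thickness': jnp.array([0.0, jnp.pi]),
}

# value_and_grad wraps the function; calling the result returns
# (loss value, gradients with the same structure as params).
value, grads = jax.value_and_grad(toy_loss)(params)

print('loss value:', value)
print('ucell gradient:', grads['ucell'])          # analytically 2 * ucell
print('thickness gradient:', grads['thickness'])  # analytically cos(thickness)
```

The gradients come back in the same nested structure as the inputs, which is why a dictionary keyed by parameter name is convenient here.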

Optax is used for optimization. Like PyTorch, Optax provides a variety of loss functions and optimizers, so users can easily rely on well-established implementations. See this tutorial: [Learn Optax](https://optax.readthedocs.io/en/latest/optax-101.html)


Gradients can be used to solve optimization problems. The examples below show a couple of ways to obtain gradients or optimized values, with or without meent's predefined functions.

### 2.1.1 Examples

In [4]:
mee = meent.call_mee(
    backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II,
    theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, fft_type=0, improve_dft=True)

pois = ['ucell', 'thickness']
forward = mee.conv_solve
loss_fn = LossDeflector(x_order=1, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)

print('ucell gradient:')
print(grad['ucell'])
print('thickness gradient:')
print(grad['thickness'])

ucell gradient:
[[[-0.05115948 -0.02534053 -0.00729983  0.07873275 -0.01842706
    0.09449833  0.08780079 -0.001232   -0.03641673 -0.04781187]]

 [[-0.1795402  -0.08599972 -0.2222932  -0.19380002  0.08989283
    0.05578499 -0.04559217 -0.13589897 -0.29833958  0.12877706]]]
thickness gradient:
[Array(0.00222085, dtype=float64, weak_type=True), Array(-0.00671622, dtype=float64, weak_type=True)]


## 2.2 Optimization

### 2.2.1 Examples

Example 1

In [5]:
mee = meent.call_mee(
    backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II,
    theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength,
    period=period, ucell=ucell_1d_m, thickness=thickness,
    type_complex=type_complex, fft_type=0, improve_dft=True)

pois = ['ucell', 'thickness']
forward = mee.conv_solve
loss_fn = LossDeflector(x_order=1, y_order=0)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2, momentum=0.9)
res = mee.fit(pois, forward, loss_fn, optimizer, iteration=3)

print('ucell final:')
print(res['ucell'])
print('thickness final:')
print(res['thickness'])

100%|██████████| 3/3 [00:04<00:00,  1.59s/it]

ucell final:
[[[1.00286486 1.00145571 1.00050162 4.9966673  5.00175321 4.99580683
   4.99617408 1.00015106 1.00214675 1.00275149]]

 [[5.00542326 4.99990074 5.00824614 5.00650358 0.99324857 4.99253641
   4.99834413 5.00367486 5.01333385 4.98859416]]]
thickness final:
[Array(499.9998925, dtype=float64), Array(1000.00039494, dtype=float64)]



