# Meent Tutorial 2
## Gradient and Optimization with [JAX](https://jax.readthedocs.io) and [Optax](https://optax.readthedocs.io/)

```python
import jax
import optax

import jax.numpy as jnp

import meent
```

```python
backend = 1  # JAX

# common
pol = 0  # 0: TE, 1: TM

n_top = 1  # refractive index of the top (incidence) medium
n_bot = 1  # refractive index of the bottom (transmission) medium

theta = 0 * jnp.pi / 180  # polar angle of incidence, in radians
phi = 0 * jnp.pi / 180  # azimuthal angle of incidence, in radians

wavelength = 900  # same length unit as period and thickness

thickness = [500., 1000.]  # thickness of each layer, from top to bottom.
period = [1000.]  # length of the unit cell. Here it's 1D.

fto = [10]  # Fourier truncation order

type_complex = jnp.complex128
```

```python
# Two layers, each a 1 x 10 pixel pattern; 0 -> 1.0 and 1 -> 5.0 after scaling.
ucell_1d_m = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
    ]) * 4. + 1.  # refractive index
```

## 2.1 Gradient

Gradients can be computed with the `jax.value_and_grad` function.
For further information, see [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html).
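
As a minimal illustration of `jax.value_and_grad` outside meent, the sketch below differentiates a toy sum-of-squares function (`f` and `x` are our own example values, not part of meent). `jax.value_and_grad(f)` returns a new function that evaluates `f` and its gradient in one call.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy scalar objective: sum of squares, so df/dx = 2 * x.
    return jnp.sum(x ** 2)


x = jnp.array([1.0, -2.0, 3.0])

# Evaluate the function value and its gradient together.
value, dfdx = jax.value_and_grad(f)(x)

print(value)  # 1 + 4 + 9 = 14
print(dfdx)   # 2 * x = [2, -4, 6]
```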

Optax is used for optimization. Like PyTorch, Optax provides a range of loss functions and optimizers, so users can easily rely on well-established implementations. See this tutorial: [Learn Optax](https://optax.readthedocs.io/en/latest/optax-101.html)


Gradients can be used to solve optimization problems. The examples below show a couple of ways to obtain gradients or optimized values, with or without meent's predefined functions.

### 2.1.1 Examples

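```python
class Loss:
    def __call__(self, meent_result, *args, **kwargs):
        # Unpack the result object; only `res` is used below.
        res_psi, res_te, res_tm = meent_result.res, meent_result.res_te_inc, meent_result.res_tm_inc
        de_ti = res_psi.de_ti  # transmitted diffraction efficiencies
        center = [a // 2 for a in de_ti.shape]  # index of the central (zeroth) order
        res = de_ti[center[0], center[1]+1]  # entry one step right of the center

        return res


loss_fn = Loss()
```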
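
The loss above picks the entry one step to the right of the center of `de_ti`. The toy array below illustrates that indexing, assuming (our assumption, not confirmed in this tutorial) that the 1D result has shape `(1, 2 * fto + 1)`, with diffraction orders running from `-fto` to `+fto` and the zeroth order in the middle. Under that layout, the loss is the transmitted efficiency of the +1st order. `de_ti_mock` is a hypothetical stand-in, not meent output.

```python
import jax.numpy as jnp

fto = 10  # truncation order, as in the tutorial's fto = [10]

# Hypothetical stand-in for de_ti with shape (1, 2 * fto + 1); each entry simply
# stores its own diffraction order (-10, ..., 10) so the picked index is easy to read.
de_ti_mock = jnp.arange(-fto, fto + 1)[None, :]

center = [a // 2 for a in de_ti_mock.shape]  # [0, 10]: middle row and middle column
picked = de_ti_mock[center[0], center[1] + 1]  # one step right of the zeroth order

print(center, picked)  # picked holds the order number 1
```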

```python
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m,
                     thickness=thickness, type_complex=type_complex)

pois = ['ucell', 'thickness']  # parameters of interest: differentiate with respect to these
forward = mee.conv_solve  # forward simulation whose result is passed to loss_fn

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)

print('ucell gradient:')
print(grad['ucell'])
print('thickness gradient:')
print(grad['thickness'])
```

```
ucell gradient:
[[[-0.05114874 -0.02533636 -0.00729883  0.07873582 -0.01841166
    0.09447967  0.08779338 -0.0012304  -0.03640632 -0.04779842]]

 [[-0.17959986 -0.08614187 -0.22233491 -0.19389416  0.08978906
    0.05564021 -0.04575985 -0.13595162 -0.29835993  0.12867445]]]
thickness gradient:
[ 0.00222043 -0.00671415]
```
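
`mee.grad` returns a dictionary keyed by the names in `pois`, with each gradient shaped like the corresponding parameter. Plain JAX behaves the same way when the parameters are held in a dictionary (a pytree). The sketch below reproduces that structure with a hypothetical toy objective in place of the meent simulation; it says nothing about how `mee.grad` is implemented internally.

```python
import jax
import jax.numpy as jnp

# Toy parameters shaped like the tutorial's `ucell` (2, 1, 10) and `thickness` (2,).
params = {
    'ucell': jnp.ones((2, 1, 10)),
    'thickness': jnp.array([500., 1000.]),
}


def toy_objective(p):
    # Hypothetical objective, NOT a meent simulation: it couples both entries
    # so that each one receives a non-zero gradient.
    return jnp.sum(p['ucell'] ** 2) * 1e-3 * jnp.sum(p['thickness'])


value, grads = jax.value_and_grad(toy_objective)(params)

# The gradient is a dict with the same keys and shapes as `params`.
for name, g in grads.items():
    print(name, g.shape)
```

Here each `ucell` entry gets `2 * 1 * 1e-3 * 1500 = 3.0` and each `thickness` entry gets `20 * 1e-3 = 0.02`.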


## 2.2 Optimization

### 2.2.1 Examples

Example 1

```python
mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m,
                     thickness=thickness, type_complex=type_complex)

pois = ['ucell', 'thickness']  # parameters to optimize
forward = mee.conv_solve

# case 2: SGD with momentum
optimizer = optax.sgd(learning_rate=1e-2, momentum=0.9)
res = mee.fit(pois, forward, loss_fn, optimizer, iteration=3)  # 3 optimization steps

print('ucell final:')
print(res['ucell'])
print('thickness final:')
print(res['thickness'])
```

```
100%|██████████| 3/3 [00:06<00:00,  2.12s/it]

ucell final:
[[[1.00286423 1.00145549 1.00050169 4.99666797 5.00175318 4.99580863
   4.99617526 1.00015109 1.00214635 1.00275083]]

 [[5.0054235  4.99990456 5.00824621 5.0065062  0.99325253 4.99254125
   4.99835018 5.00367578 5.01333396 4.9885967 ]]]
thickness final:
[ 499.99989253 1000.00039487]
```



