# Meent Tutorial 2
## PyTorch: Gradient and Optimization

In [None]:
import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector
from meent.on_torch.optimizer.optimizer import OptimizerTorch

In [None]:
# Shared settings reused by every example below.

# --- backend / numerics ---
backend = 2  # Torch
type_complex = torch.complex128
device = torch.device('cpu')
fourier_order = [10]

# --- illumination ---
pol = 0  # 0: TE, 1: TM
wavelength = 900  # presumably in the same length unit as period/thickness
# Angles are written in degrees and converted to radians.
theta = 0 * (torch.pi / 180)
phi = 0 * (torch.pi / 180)

# --- surrounding media ---
n_I = 1  # n_incidence
n_II = 1  # n_transmission

# --- structure ---
grating_type = 0
period = torch.tensor([1000.0])
# Two entries: one per row of the unit cell defined in the next cell.
thickness = torch.tensor([500.0, 1000.0])

In [None]:
# Refractive-index map of the unit cell: the binary pattern below is
# rescaled so that 0 -> 1.0 and 1 -> 5.0. The outer axis has two entries,
# presumably one per layer (matching the two thickness values).
ucell_1d_m = 1. + 4 * torch.tensor([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1]],
])

## 2.1 Gradient
Gradients can be calculated with the help of PyTorch's `torch.autograd` engine.
Read this for further information: [A GENTLE INTRODUCTION TO TORCH.AUTOGRAD](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)

Gradients can be used to solve optimization problems. The examples below show a couple of ways to obtain gradients or optimized values, with or without the predefined functions of meent.

### 2.1.1 Examples

Example 1: computing the gradient manually


In [None]:
# Build a solver instance from the shared settings.
mee = meent.call_mee(
    backend=backend,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell_1d_m,
    thickness=thickness,
    type_complex=type_complex,
    device=device,
    fft_type=0,
    improve_dft=True,
)

# Ask autograd to track the two tensors we want gradients for.
mee.thickness.requires_grad = True
mee.ucell.requires_grad = True

# Forward pass; the loss is the entry one past the centre of `de_ti`
# (presumably the +1st transmitted diffraction order -- confirm).
de_ri, de_ti = mee.conv_solve()
center = de_ti.shape[0] // 2
loss = de_ti[center + 1]

# Backward pass fills in `.grad` on the tracked tensors.
loss.backward()
print('ucell gradient:', mee.ucell.grad, sep='\n')
print('thickness gradient:', mee.thickness.grad, sep='\n')

Example 2: using the predefined `grad` function in meent

In [None]:
# Build a solver instance from the shared settings.
mee = meent.call_mee(
    backend=backend,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell_1d_m,
    thickness=thickness,
    type_complex=type_complex,
    device=device,
    fft_type=0,
    improve_dft=True,
)

pois = ['ucell', 'thickness']  # Parameters Of Interest

forward = mee.conv_solve

# The loss can be a custom callable or one predefined in meent, e.g.:
# loss_fn = LossDeflector(x_order=1)  # predefined in meent


def loss_fn(x):
    """Custom loss: pick the entry one past the centre of ``x[1]``.

    ``x`` is presumably whatever ``forward`` returns; confirm against
    ``mee.grad``.
    """
    de_ti = x[1]
    return de_ti[de_ti.shape[0] // 2 + 1]


# `mee.grad` returns a mapping keyed by the names listed in `pois`.
grad = mee.grad(pois, forward, loss_fn)
print('ucell gradient:', grad['ucell'], sep='\n')
print('thickness gradient:', grad['thickness'], sep='\n')


## 2.2 Optimization

### 2.2.1 Examples

Example 1

In [None]:
# Build a solver instance from the shared settings.
mee = meent.call_mee(
    backend=backend,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell_1d_m,
    thickness=thickness,
    type_complex=type_complex,
    device=device,
    fft_type=0,
    improve_dft=True,
)

# Track both tensors with autograd and hand them to a plain PyTorch optimizer.
mee.thickness.requires_grad = True
mee.ucell.requires_grad = True
opt = torch.optim.SGD(
    [mee.ucell, mee.thickness],
    lr=1e-2,
    momentum=0.9,
)

# Hand-written optimization loop: forward, backward, step, reset gradients.
for _ in range(3):
    de_ri, de_ti = mee.conv_solve()

    # Loss is the entry one past the centre of `de_ti`.
    center = de_ti.shape[0] // 2
    loss = de_ti[center + 1]
    print(loss)

    loss.backward()
    opt.step()
    opt.zero_grad()

Example 2

In [None]:
# Build a solver instance from the shared settings.
mee = meent.call_mee(
    backend=backend,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell_1d_m,
    thickness=thickness,
    type_complex=type_complex,
    device=device,
    fft_type=0,
    improve_dft=True,
)


def forward_fn():
    """Run the solver and return the entry one past the centre of ``de_ti``."""
    _, efficiencies = mee.conv_solve()
    return efficiencies[efficiencies.shape[0] // 2 + 1]


def loss_fn(x):
    """Identity: ``forward_fn`` already returns the scalar to optimize."""
    return x


pois = ['ucell', 'thickness']
forward = forward_fn

# Optimizer class and its keyword options, both passed on to `mee.fit`.
opt_torch = torch.optim.SGD
opt_options = {'lr': 1e-2, 'momentum': 0.9}

mee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=3)

Example 3

In [None]:
# Build a solver instance from the shared settings.
mee = meent.call_mee(
    backend=backend,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell_1d_m,
    thickness=thickness,
    type_complex=type_complex,
    device=device,
    fft_type=0,
    improve_dft=True,
)

pois = ['ucell', 'thickness']

# Use the solver's own forward pass together with meent's predefined loss.
# NOTE(review): the positional arguments are presumably the x and y
# diffraction orders -- confirm against LossDeflector's signature.
forward = mee.conv_solve
loss_fn = LossDeflector(1, 0)

# Optimizer class and its keyword options, both passed on to `mee.fit`.
opt_torch = torch.optim.SGD
opt_options = {'lr': 1e-2, 'momentum': 0.9}

mee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=3)