# Meent Tutorial 2
## PyTorch: Gradient and Optimization

In [1]:
import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector
from meent.on_torch.optimizer.optimizer import OptimizerTorch

In [2]:
backend = 2  # Torch

# common
pol = 0  # 0: TE, 1: TM

n_I = 1  # n_incidence
n_II = 1  # n_transmission

theta = 0 * torch.pi / 180
phi = 0 * torch.pi / 180

wavelength = 900

thickness = torch.tensor([500., 1000.])
period = torch.tensor([1000.])

fourier_order = [10]

type_complex = torch.complex128
device = torch.device('cpu')

grating_type = 0

In [3]:
ucell_1d_m = torch.tensor([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
    ]) * 4 + 1.  # refractive index

## 2.1 Gradient
Gradient can be calculated with the help of `torch.autograd` function.
Read this for further information: [A GENTLE INTRODUCTION TO TORCH.AUTOGRAD](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)

Gradient can be utilized to solve optimization problems. Here some examples show couple of ways to get gradient or optimized values with or without predefined functions of meent.

### 2.1.1 case 1
manually get gradient

In [4]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex, device=device, fft_type=0, improve_dft=True)

mee.ucell.requires_grad = True
mee.thickness.requires_grad = True

de_ri, de_ti = mee.conv_solve()
loss = de_ti[de_ti.shape[0] // 2 + 1]

loss.backward()
print('ucell gradient:')
print(mee.ucell.grad)
print('thickness gradient:')
print(mee.thickness.grad)

ucell gradient:
tensor([[[-0.0512, -0.0253, -0.0073,  0.0787, -0.0184,  0.0945,  0.0878,
          -0.0012, -0.0364, -0.0478]],

        [[-0.1795, -0.0860, -0.2223, -0.1938,  0.0899,  0.0558, -0.0456,
          -0.1359, -0.2983,  0.1288]]])
thickness gradient:
tensor([ 0.0022, -0.0067])


### 2.1.2 case 2
using predefined 'grad' function in meent

In [5]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex, device=device, fft_type=0, improve_dft=True)

pois = ['ucell', 'thickness']  # Parameter Of Interests

forward = mee.conv_solve

# can use custom loss function or predefined loss function in meent.

# loss_fn = LossDeflector(x_order=1)  # predefined in meent
loss_fn = lambda x: x[1][x[1].shape[0] // 2 + 1]  # custom

grad = mee.grad(pois, forward, loss_fn)
print('ucell gradient:')
print(grad['ucell'])
print('thickness gradient:')
print(grad['thickness'])


ucell gradient:
tensor([[[-0.0512, -0.0253, -0.0073,  0.0787, -0.0184,  0.0945,  0.0878,
          -0.0012, -0.0364, -0.0478]],

        [[-0.1795, -0.0860, -0.2223, -0.1938,  0.0899,  0.0558, -0.0456,
          -0.1359, -0.2983,  0.1288]]])
thickness gradient:
tensor([ 0.0022, -0.0067])


## 2.2 Optimization

### 2.2.1 case 1

In [6]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex, device=device, fft_type=0, improve_dft=True)

mee.ucell.requires_grad = True
mee.thickness.requires_grad = True
opt = torch.optim.SGD([mee.ucell, mee.thickness], lr=1E-2, momentum=0.9)

for _ in range(100):

    de_ri, de_ti = mee.conv_solve()

    center = de_ti.shape[0] // 2
    loss = de_ti[center + 1]

    print(loss)
    loss.backward()
    opt.step()
    opt.zero_grad()

tensor(0.0291, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0267, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0235, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0185, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0140, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0092, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0057, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0035, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0018, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0008, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0012, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0023, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0030, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0033, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0029, dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor(0.0023, dtype=torch.float64, grad_fn=<SelectBack

### 2.2.2 case 2

In [7]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex, device=device, fft_type=0, improve_dft=True)


def forward_fn():

    de_ri, de_ti = mee.conv_solve()

    center = de_ti.shape[0] // 2
    loss = de_ti[center + 1]
    return loss

pois = ['ucell', 'thickness']
forward = forward_fn
loss_fn = lambda x: x
opt_torch = torch.optim.SGD
opt_options = {'lr': 1E-2,
               'momentum': 0.9,
               }

mee.fit(pois, forward, loss_fn, opt_torch, opt_options)

step 0, loss: 0.029141591152553024
step 1, loss: 0.026655024513172282
step 2, loss: 0.02348784876570143
step 3, loss: 0.018518856342830476
step 4, loss: 0.014045038431714268
step 5, loss: 0.009240928804785936
step 6, loss: 0.005675249611651967
step 7, loss: 0.0035335352418931344
step 8, loss: 0.0018367942977096337
step 9, loss: 0.0007733766260211415
step 10, loss: 0.001228100591650995
step 11, loss: 0.0023393996402068327
step 12, loss: 0.0029773505075635716
step 13, loss: 0.0033075310517868732
step 14, loss: 0.0029168526201331017
step 15, loss: 0.002292950201690544
step 16, loss: 0.0016667459159062903
step 17, loss: 0.0010469976100209674
step 18, loss: 0.0005248982259839158
step 19, loss: 0.00021465831290599008
step 20, loss: 0.00010435964182362401
step 21, loss: 5.935564914121781e-05
step 22, loss: 4.123373181406379e-05
step 23, loss: 8.568562465758101e-05
step 24, loss: 0.00017595533399906548
step 25, loss: 0.00026939594800027183
step 26, loss: 0.00033817415116938675
step 27, loss: 0

In [8]:
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell_1d_m, thickness=thickness, type_complex=type_complex, device=device, fft_type=0, improve_dft=True)

pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(1, 0)

opt_torch = torch.optim.SGD
opt_options = {'lr': 1E-2,
               'momentum': 0.9,
               }

mee.fit(pois, forward, loss_fn, opt_torch, opt_options)

step 0, loss: 0.029141591152553024
step 1, loss: 0.026655024513172282
step 2, loss: 0.02348784876570143
step 3, loss: 0.018518856342830476
step 4, loss: 0.014045038431714268
step 5, loss: 0.009240928804785936
step 6, loss: 0.005675249611651967
step 7, loss: 0.0035335352418931344
step 8, loss: 0.0018367942977096337
step 9, loss: 0.0007733766260211415
step 10, loss: 0.001228100591650995
step 11, loss: 0.0023393996402068327
step 12, loss: 0.0029773505075635716
step 13, loss: 0.0033075310517868732
step 14, loss: 0.0029168526201331017
step 15, loss: 0.002292950201690544
step 16, loss: 0.0016667459159062903
step 17, loss: 0.0010469976100209674
step 18, loss: 0.0005248982259839158
step 19, loss: 0.00021465831290599008
step 20, loss: 0.00010435964182362401
step 21, loss: 5.935564914121781e-05
step 22, loss: 4.123373181406379e-05
step 23, loss: 8.568562465758101e-05
step 24, loss: 0.00017595533399906548
step 25, loss: 0.00026939594800027183
step 26, loss: 0.00033817415116938675
step 27, loss: 0