-
Notifications
You must be signed in to change notification settings - Fork 5
/
optimizer.py
157 lines (114 loc) · 4.82 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#import sys
#sys.path.append("/home/seolho/paper/meent")
import torch
import torch.nn as nn
from meent.on_torch.convolution_matrix import to_conv_mat
from meent.rcwa import call_solver
import numpy as np
import copy
# Module-level run configuration.
#   fourier_order: truncation order used when building convolution matrices.
#   mode_key:      backend selector for load_setting -- 0: NumPy, 1: JAX, other: PyTorch.
#   dtype:         0: double precision (complex128 / float64), other: single precision.
#   device:        0: CPU, other: GPU.
fourier_order, mode_key, dtype, device = 2, 2, 0, 0
# NOTE(review): disabled legacy implementation of Optimizer, kept inside a
# string literal so it is never executed (the string is built and discarded).
# It looked the optimization target up on the solver by attribute name
# (getattr) and wrote it back with setattr after every step. As written it
# would not run if revived: __init__ tests isinstance(target, ...) but only
# `target_name` is defined there, so that line raises NameError.
# Consider deleting it once the live Optimizer below is confirmed sufficient.
'''
class Optimizer:
def __init__(self, optimizer, solver, target_name, lr = 0.001):
self.solver = solver
self.target_name = target_name
self.target = getattr(self.solver, target_name)
if not isinstance(target, torch.Tensor):
self.target = torch.Tensor(self.target)
self.target.requires_grad = True
if isinstance(optimizer,str):
optimizer = getattr(torch.optim, optimizer)
self.optimizer = optimizer([self.target], lr = lr)
def optimize(self, iterations = 1000):
for iteration in range(iterations):
E_conv_all = to_conv_mat(self.solver.ucell, fourier_order)
o_E_conv_all = to_conv_mat(1 / self.solver.ucell, fourier_order)
de_ri, de_ti, _, _, _ = self.solver.solve(self.solver.wavelength, E_conv_all, o_E_conv_all)
self.optimizer.zero_grad()
loss = self.loss(de_ti)
loss.backward()
self.optimizer.step()
setattr(self.solver, self.target_name, self.target)
print(loss)
def loss(self, de_ti):
return -de_ti[3, 2]
'''
class Optimizer:
    """Gradient-based optimization of a tensor that feeds an RCWA solver.

    Each iteration rebuilds the convolution matrices from ``solver.ucell``,
    runs ``solver.solve`` and takes one optimizer step that minimizes
    ``self.loss(de_ti)`` (by default: maximizes ``de_ti[3, 2]``).

    Parameters
    ----------
    optimizer : str or type
        A ``torch.optim`` optimizer class, or its name (e.g. ``'Adam'``).
    solver : object
        Solver exposing ``ucell``, ``wavelength`` and
        ``solve(wavelength, E_conv_all, o_E_conv_all)``.
    target : torch.Tensor or array-like
        Tensor to optimize. It is updated in place by ``optimizer.step()``,
        so for the optimization to affect the solver it must be the very
        tensor the solver reads (e.g. ``solver.ucell``).
    lr : float
        Learning rate passed to the optimizer.
    fourier_order : int, optional
        Order passed to ``to_conv_mat``. When omitted, the solver's own
        ``fourier_order`` attribute is used if present, otherwise the
        module-level ``fourier_order``.
    loss_index : tuple, optional
        Index into ``de_ti`` whose value is maximized (default ``(3, 2)``).
        Use a tuple: a list would trigger advanced indexing instead.
    """

    def __init__(self, optimizer, solver, target, lr=0.001, fourier_order=None, loss_index=(3, 2)):
        self.solver = solver
        self.target = target
        if not isinstance(target, torch.Tensor):
            # NOTE(review): this builds a NEW default-dtype tensor that the
            # solver never reads, so the loss cannot depend on it and the
            # optimizer will not change anything the solver sees. Prefer
            # passing the solver's own tensor (e.g. solver.ucell).
            self.target = torch.Tensor(self.target)
        # Raises RuntimeError if `target` is a non-leaf tensor (autograd rule).
        self.target.requires_grad = True
        if isinstance(optimizer, str):
            optimizer = getattr(torch.optim, optimizer)
        self.optimizer = optimizer([self.target], lr=lr)
        self.fourier_order = fourier_order
        self.loss_index = loss_index

    def _resolve_fourier_order(self):
        """Return the Fourier order used to build the convolution matrices."""
        if self.fourier_order is not None:
            return self.fourier_order
        # Prefer the solver's own order so the matrices match the truncation
        # it was configured with; fall back to the module-level default.
        return getattr(self.solver, 'fourier_order', fourier_order)

    def optimize(self, iterations=1000):
        """Run ``iterations`` optimization steps, printing the loss each step.

        Returns
        -------
        list of float
            The loss value of every iteration, in order.
        """
        order = self._resolve_fourier_order()
        history = []
        for _ in range(iterations):
            # Rebuilt every step because `ucell` changes after each update.
            E_conv_all = to_conv_mat(self.solver.ucell, order)
            o_E_conv_all = to_conv_mat(1 / self.solver.ucell, order)
            # Only the second output (de_ti) feeds the loss.
            _, de_ti, _, _, _ = self.solver.solve(self.solver.wavelength, E_conv_all, o_E_conv_all)
            self.optimizer.zero_grad()
            loss = self.loss(de_ti)
            loss.backward()
            self.optimizer.step()
            print(loss)
            history.append(loss.item())
        return history

    def loss(self, de_ti):
        """Return the negated ``de_ti`` entry at ``self.loss_index``.

        Minimizing this maximizes that entry.
        """
        return -de_ti[self.loss_index]
def load_setting(mode_key, dtype, device):
    """Return the demo grating settings with ``ucell`` cast for one backend.

    Parameters
    ----------
    mode_key : int
        Backend selector: 0 -> NumPy, 1 -> JAX, any other value -> PyTorch.
    dtype : int
        Precision selector: 0 -> double (complex128 / float64),
        any other value -> single (complex64 / float32).
    device : int
        Device selector: 0 -> CPU, any other value -> GPU (platform 'gpu'
        for JAX, 'cuda' for PyTorch). Forced to 0 for NumPy.

    Returns
    -------
    tuple
        ``(grating_type, pol, n_I, n_II, theta, phi, psi, wavelength,
        thickness, ucell_materials, period, fourier_order, type_complex,
        device, ucell)``. ``type_complex`` and ``ucell`` use the selected
        backend's types. ``device`` is 0 (NumPy), the unchanged input (JAX)
        or a ``torch.device`` (PyTorch).
    """
    grating_type = 2
    pol = 1  # 0: TE, 1: TM
    n_I = 1  # n_incidence
    n_II = 1  # n_transmission
    theta = 0
    phi = 0
    psi = 0 if pol else 90
    wavelength = 900
    ucell_materials = [1, 3.48]
    fourier_order = 2
    thickness, period = [1120.], [1000, 1000]
    # A single 5x5 pattern; array shape is (1, 5, 5).
    ucell = np.array(
        [[
            [3., 1., 1., 1., 3.],
            [3., 1., 1., 1., 3.],
            [3., 1., 1., 1., 3.],
            [3., 1., 1., 1., 3.],
            [3., 1., 1., 1., 3.],
        ]]
    )

    if mode_key == 0:  # NumPy
        device = 0
        type_complex = np.complex128 if dtype == 0 else np.complex64
        ucell = ucell.astype(type_complex)
    elif mode_key == 1:  # JAX
        # jax is not imported at module level; importing it here keeps the
        # NumPy and PyTorch paths free of the dependency.
        import jax
        import jax.numpy as jnp

        jax.config.update('jax_platform_name', 'cpu' if device == 0 else 'gpu')
        if dtype == 0:
            # JAX defaults to 32-bit types; x64 must be enabled for float64/complex128.
            jax.config.update("jax_enable_x64", True)
            type_complex = jnp.complex128
            ucell = jnp.array(ucell, dtype=jnp.float64)
        else:
            type_complex = jnp.complex64
            ucell = jnp.array(ucell, dtype=jnp.float32)
    else:  # PyTorch
        device = torch.device('cpu') if device == 0 else torch.device('cuda')
        type_complex = torch.complex128 if dtype == 0 else torch.complex64
        real_dtype = torch.float64 if dtype == 0 else torch.float32
        ucell = torch.tensor(ucell, dtype=real_dtype, device=device)

    return (grating_type, pol, n_I, n_II, theta, phi, psi, wavelength, thickness, ucell_materials,
            period, fourier_order, type_complex, device, ucell)
# Script entry: build the demo settings, create the solver, then optimize its unit cell.
# NOTE(review): these statements run whenever the module is imported; consider an
# `if __name__ == "__main__":` guard if `Optimizer` is meant to be importable.
(grating_type, pol, n_I, n_II, theta, phi, psi, wavelength, thickness, ucell_materials,
 period, fourier_order, type_complex, device, ucell) = load_setting(mode_key, dtype, device)

# Track gradients w.r.t. the unit cell so the optimizer below can update it.
ucell.requires_grad = True

solver = call_solver(
    mode_key,
    grating_type=grating_type,
    pol=pol,
    n_I=n_I,
    n_II=n_II,
    theta=theta,
    phi=phi,
    psi=psi,
    fourier_order=fourier_order,
    wavelength=wavelength,
    period=period,
    ucell=ucell,
    ucell_materials=ucell_materials,
    thickness=thickness,
    device=device,
    type_complex=type_complex,
)

optim = Optimizer('Adam', solver, solver.ucell, lr=0.001)
optim.optimize()