In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from copy import copy
import sys; sys.path.append('../../../tidy3d')

import tmm # need to pip install tmm for this
import tidy3d as td
from tidy3d import web

In [2]:
def grad_tmm(freqs, theta, bck_eps, slab_eps, slab_ds):
    """Use numerical derivative to compute the gradient of transmission through slabs with respect
    to each slab permittivity and upper and lower boundary. Uses tmm to get the analytical
    transmission results.
    """
    delta = 1e-4
    grad_eps = np.zeros((len(slab_eps), len(freqs)))  # gradient of each slab's permittivity
    grad_bot = np.zeros((len(slab_eps), len(freqs)))  # gradient of shifting each slab's bottom boundary
    grad_top = np.zeros((len(slab_eps), len(freqs)))  # gradient of shifting each slab's top boundary
    T = np.zeros((len(freqs),))
    freq0 = freqs[len(freqs) // 2]

    def slab_ds_boundary(d_list, islab, dl, boundary):
        """Compute a list of slab thicknesses if the ``islab``-th slab's ``"top"`` or ``"bottom"``
        ``boundary`` is perturbed by ``dl``."""
        slab_ds_pert = copy(d_list)
        if boundary == "bottom":
            slab_ds_pert[islab] -= dl
            slab_ds_pert[islab - 1] += dl
        if boundary == "top":
            slab_ds_pert[islab] += dl
            slab_ds_pert[islab + 1] -= dl

        return slab_ds_pert

    eps_list = [bck_eps] + slab_eps + [bck_eps]
    n_list = np.sqrt(eps_list)
    d_list = [np.inf] + slab_ds + [np.inf]
    for ifreq, freq in enumerate(freqs):
        wavelength = td.C_0 / freq
        for islab in range(len(slab_eps)):
            n_list_p = n_list.copy()
            n_list_p[islab + 1] = np.sqrt(slab_eps[islab] + delta / 2)
            n_list_m = n_list.copy()
            n_list_m[islab + 1] = np.sqrt(slab_eps[islab] - delta / 2)
            d_list_p_top = slab_ds_boundary(d_list, islab + 1, delta / 2, "top")
            d_list_m_top = slab_ds_boundary(d_list, islab + 1, -delta / 2, "top")
            d_list_p_bot = slab_ds_boundary(d_list, islab + 1, delta / 2, "bottom")
            d_list_m_bot = slab_ds_boundary(d_list, islab + 1, -delta / 2, "bottom")
            
            # using a constant-k rather than constant-angle formulation, so angle is
            # frequency-dependent
            th = np.arcsin(np.sin(theta) * freq0 / freq)

            T[ifreq] = tmm.coh_tmm("p", n_list, d_list, th, wavelength)["T"]
            t_deps_p = tmm.coh_tmm("p", n_list_p, d_list, th, wavelength)["T"]
            t_deps_m = tmm.coh_tmm("p", n_list_m, d_list, th, wavelength)["T"]
            t_dd_p_top = tmm.coh_tmm("p", n_list, d_list_p_top, th, wavelength)["T"]
            t_dd_m_top = tmm.coh_tmm("p", n_list, d_list_m_top, th, wavelength)["T"]
            t_dd_p_bot = tmm.coh_tmm("p", n_list, d_list_p_bot, th, wavelength)["T"]
            t_dd_m_bot = tmm.coh_tmm("p", n_list, d_list_m_bot, th, wavelength)["T"]

            grad_eps[islab, ifreq] = (t_deps_p - t_deps_m) / delta
            grad_top[islab, ifreq] = (t_dd_p_top - t_dd_m_top) / delta
            grad_bot[islab, ifreq] = (t_dd_p_bot - t_dd_m_bot) / delta

    # Normalize gradients to sum up to 1 at every frequency
    grad_eps = grad_eps / np.linalg.norm(grad_eps, axis=0)
    grad_top = grad_top / np.linalg.norm(grad_top, axis=0)
    grad_bot = grad_bot / np.linalg.norm(grad_bot, axis=0)
        
    tmm_data = {"grad_eps": grad_eps, "grad_top": grad_top, "grad_bot": grad_bot, "T": T}
    return tmm_data

In [3]:
def run(
    freq0=2e14,
    num_freqs=1,
    dl=.0125,
    slab_eps=[2**2, 1.8**2, 1.5**2, 1.9**2],
    slab_ds=[0.5, 0.25, 0.5, 0.25],
    bck_eps=1.4**2,
    angle_theta = 0,
):

    # frequency setup
    wavelength = td.C_0 / freq0
    fwidth = freq0 / 10.0
    ind_freq0 = num_freqs // 2
    freqs = np.linspace(freq0 - fwidth, freq0 + fwidth, num_freqs)

    # geometry setup
    bck_medium = td.Medium(permittivity=bck_eps)

    space_above = 2
    space_below = 2

    length_x = 0.5
    center_x = 0.0
    length_z = space_below + np.sum(slab_ds) + space_above
    sim_size = (length_x, 0, length_z)

    # make structures
    slabs = []
    z_start = -np.sum(slab_ds) / 2
    for (d, eps) in zip(slab_ds, slab_eps):
        slab = td.Structure(
            geometry=td.Box(center=[0, 0, z_start + d / 2], size=[td.inf, td.inf, d]),
            medium=td.Medium(permittivity=eps),
        )
        slabs.append(slab)
        z_start += d

    # source setup
    gaussian = td.GaussianPulse(freq0=freq0, fwidth=fwidth)
    src_z = -length_z / 2 + 3 * space_below / 4

    source = td.PlaneWave(
        center=(center_x, 0, src_z),
        size=(td.inf, td.inf, 0),
        source_time=gaussian,
        direction="+",
        angle_theta=angle_theta,
        angle_phi=0,
        pol_angle=0,
    )

    # boundaries
    boundary_x = td.Boundary.bloch_from_source(
        source=source, domain_size=sim_size[0], axis=0, medium=bck_medium
    )
    boundary_spec = td.BoundarySpec(x=boundary_x, y=td.Boundary.periodic(), z=td.Boundary.pml(num_layers=40))

    # monitors
    mnt_z = length_z / 2 - wavelength
    monitor_1 = td.DiffractionMonitor(
        center=[0.0, 0.0, mnt_z],
        size=[td.inf, td.inf, 0],
        freqs=freqs,
        name="diffraction",
        normal_dir="+",
    )

    monitor_2 = td.FieldMonitor(
        center=[0.0, 0.0, mnt_z],
        size=[td.inf, td.inf, 0],
        freqs=freqs,
        name="field",
    )

    # monitors to record the fields and permittivity needed for the gradient computation
    # they need to span 
    monitor_g1 = td.FieldMonitor(
        center=[0.0, 0.0, 0.0],
        size=[td.inf, td.inf, np.sum(slab_ds)],
        freqs=freqs,
        name="field_grad",
    )
    monitor_g2 = td.PermittivityMonitor(
        center=[0.0, 0.0, 0.0],
        size=[td.inf, td.inf, np.sum(slab_ds)],
        freqs=freqs,
        name="eps_grad",
    )

    # make simulation
    sim = td.Simulation(
        size=sim_size,
        grid_spec=td.GridSpec.uniform(dl=dl),
        structures=slabs,
        sources=[source],
        monitors=[monitor_1, monitor_2, monitor_g1, monitor_g2],
        run_time=50 / fwidth,
        boundary_spec=boundary_spec,
        medium=bck_medium,
        shutoff=1e-8,
    )
    
    # fig, ax = plt.subplots(1, 3)
    # sim.plot(y=0, ax=ax[0])

    # run forward simulation
    sim_data = web.run(sim, task_name="multilayer_forward")
    
    # output amplitudes
    amps = sim_data["diffraction"].amps.sel(polarization="p").values.ravel()

    # setup and run adjoint simulation
    # source needs to be centered at the simulation center_x, not the monitor one
    source_adj = td.PlaneWave(
        center=(center_x, 0, mnt_z),
        size=(td.inf, td.inf, 0),
        source_time=gaussian,
        direction="-",
        angle_theta=angle_theta,
        angle_phi=0,
        pol_angle=0,
    )
    
    

    # adjoint boundaries (bloch vector flips sign) because of source direction = "-"
    boundary_x = td.Boundary.bloch_from_source(
        source=source_adj, domain_size=sim_size[0], axis=0, medium=bck_medium
    )
    boundary_spec = td.BoundarySpec(x=boundary_x, y=td.Boundary.periodic(), z=td.Boundary.pml(num_layers=40))
    sim_adj = td.Simulation(
        size=sim_size,
        grid_spec=td.GridSpec.uniform(dl=dl),
        structures=slabs,
        sources=[source_adj],
        monitors=[monitor_g1, monitor_g2],
        run_time=50 / fwidth,
        boundary_spec=boundary_spec,
        medium=bck_medium,
        shutoff=1e-8,
    )
    # sim_adj.plot(y=0, ax=ax[1])
    # plt.show()
    import pdb; pdb.set_trace()
    sim_data_adj = web.run(sim_adj, task_name="multilayer_adjoint")

    # compute gradient w.r.t. slab permittivity and the top of the slab boundaries
    Exf, Eyf = sim_data["field_grad"].Ex, sim_data["field_grad"].Ey
    Dzf = sim_data["field_grad"].Ez * sim_data["eps_grad"].eps_zz
    Exa, Eya = sim_data_adj["field_grad"].Ex, sim_data_adj["field_grad"].Ey
    Dza = sim_data_adj["field_grad"].Ez * sim_data_adj["eps_grad"].eps_zz
    eps_list = [bck_eps] + slab_eps + [bck_eps]
    xs = np.linspace(-length_x / 2, length_x / 2, 100)
    zs = sim_data["field_grad"].Ex.z

    grad_top_adj = []
    grad_eps_adj = []

    for islab, slab in enumerate(slabs):
        """epsilon gradient"""
        zinds = np.where((zs > slab.geometry.bounds[0][2]) * (zs < slab.geometry.bounds[1][2]))[0]
        e_fwd = np.stack(
            (
                sim_data["field_grad"].Ex.isel(x=slice(1, -1), z=zinds),
                sim_data["field_grad"].Ey.isel(x=slice(1, -1), z=zinds),
                sim_data["field_grad"].Ez.isel(x=slice(1, -1), z=zinds),
            ),
            axis=0,
        )
        e_adj = np.stack(
            (
                sim_data_adj["field_grad"].Ex.isel(x=slice(1, -1), z=zinds),
                sim_data_adj["field_grad"].Ey.isel(x=slice(1, -1), z=zinds),
                sim_data_adj["field_grad"].Ez.isel(x=slice(1, -1), z=zinds),
            ),
            axis=0,
        )
        grad = np.sum(e_fwd * e_adj, axis=(0, 1, 2, 3)) # shape (nfreqs)
        # The 1j * np.conj(amps) due to dL/dE_fwd comes in here
        grad_eps_adj.append(grad * 1j * np.conj(amps))

        """top boundary gradient"""
        z_top = slab.geometry.bounds[1][2]
        d_eps = eps_list[islab + 1] - eps_list[islab + 2]
        d_eps_inv = (1 / eps_list[islab + 1] - 1 / eps_list[islab + 2])

        ex_fwd = Exf.interp(x=xs, z=z_top)
        ex_adj = Exa.interp(x=xs, z=z_top) 
        ey_fwd = Eyf.interp(x=xs, z=z_top) 
        ey_adj = Eya.interp(x=xs, z=z_top)
        dz_fwd = Dzf.interp(x=xs, z=z_top) 
        dz_adj = Dza.interp(x=xs, z=z_top) 

        integrand = d_eps * (ex_fwd * ex_adj + ey_fwd * ey_adj)
    #     print(float(np.real(integrand.isel(x=0)*1j*np.conj(amps))))
    #     print(float(np.real((d_eps_inv * (dz_fwd * dz_adj)).isel(x=0)*1j*np.conj(amps))))
        integrand -= d_eps_inv * (dz_fwd * dz_adj)

        grad = integrand.sum(dim=["x", "y"]) # shape (nfreqs)
        # The 1j * np.conj(amps) due to dL/dE_fwd comes in here
        grad_top_adj.append(grad * 1j * np.conj(amps))

    g_eps_adj = np.real(grad_eps_adj)
    grad_eps_adj = g_eps_adj / np.linalg.norm(g_eps_adj, axis=0)
    # print(grad_eps_adj)

    g_top_adj = np.real(grad_top_adj)
    grad_top_adj = g_top_adj / np.linalg.norm(g_top_adj, axis=0)
    # print(grad_top_adj)

    # compute numerical gradient
    tmm_data = grad_tmm(freqs, angle_theta, bck_eps, slab_eps, slab_ds)
    grad_eps = tmm_data["grad_eps"]
    grad_top = tmm_data["grad_top"]

    # print(grad_eps)

    # for p in 2 * np.pi * np.linspace(-1, 1, 2001):
    #     g_phased = np.real(np.array(grad_eps_adj) * np.exp(1j * p))
    #     g_phased /= np.linalg.norm(g_phased)
    #     diff = np.linalg.norm(g_phased - grad_eps)
    #     if diff < 0.1:
    #         print(p / 2 / np.pi, g_phased, diff)

#     print("Difference in epsilon gradient: ", np.linalg.norm(grad_eps - grad_eps_adj, axis=0))
#     print("Difference in top boundary gradient: ", np.linalg.norm(grad_top - grad_top_adj, axis=0))
    
    grad_data = {"eps": grad_eps_adj, "top": grad_top_adj}
    
    return sim_data, sim_data_adj, grad_data, tmm_data



In [4]:
def plot_grad_eps(sim_data, grad_eps_adj, grad_eps):
    freqs = sim_data["diffraction"].amps.f
    amps = sim_data["diffraction"].amps.sel(polarization="p").values.ravel()
    fig, ax = plt.subplots(1, 3, figsize=(12, 3.5), constrained_layout=True)
    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']
    for islab in range(0, 4):
        ax[0].plot(freqs, grad_eps_adj[islab], color=colors[islab + 1])
    for islab in range(0, 4):
        ax[0].plot(freqs, grad_eps[islab], color=colors[islab + 1], linestyle="dashed")
    ax[0].set_xlabel("frequency")
    ax[0].set_ylabel("Gradient (adjoint/numerical)")
    ax[0].legend(["slab 1", "slab 2", "slab 3", "slab 4"])

    for islab in range(0, 4):
        ax[1].plot(freqs, np.abs((grad_eps - grad_eps_adj)[islab]), color=colors[islab+1])
    ax[1].set_xlabel("frequency")
    ax[1].set_ylabel("Gradient error")
    ax[1].legend(["slab 1", "slab 2", "slab 3", "slab 4"])

    ax[2].plot(freqs, np.abs(amps), color="k")
    ax[2].plot(freqs, np.real(amps))
    ax[2].plot(freqs, np.imag(amps))
    ax[2].set_xlabel("frequency")
    ax[2].set_ylabel("Transmission")
    ax[2].legend(["abs", "re", "im"])
    
def plot_grad_top(sim_data, grad_top_adj, grad_top):
    freqs = sim_data["diffraction"].amps.f
    amps = sim_data["diffraction"].amps.sel(polarization="p").values.ravel()
    fig, ax = plt.subplots(1, 3, figsize=(12, 3.5), constrained_layout=True)
    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']
    for islab in range(0, 4):
        ax[0].plot(freqs, grad_top_adj[islab], color=colors[islab + 1])
    for islab in range(0, 4):
        ax[0].plot(freqs, grad_top[islab], color=colors[islab + 1], linestyle="dashed")
    ax[0].set_xlabel("frequency")
    ax[0].set_ylabel("Gradient (adjoint/numerical)")
    ax[0].legend(["slab 1", "slab 2", "slab 3", "slab 4"])

    for islab in range(0, 4):
        ax[1].plot(freqs, np.abs((grad_top - grad_top_adj)[islab]), color=colors[islab+1])
    ax[1].set_xlabel("frequency")
    ax[1].set_ylabel("Gradient error")
    ax[1].legend(["slab 1", "slab 2", "slab 3", "slab 4"])

    ax[2].plot(freqs, np.abs(amps), color="k")
    ax[2].plot(freqs, np.real(amps))
    ax[2].plot(freqs, np.imag(amps))
    ax[2].set_xlabel("frequency")
    ax[2].set_ylabel("Transmission")
    ax[2].legend(["abs", "re", "im"])

In [None]:
# Normal incidence
sim_data, sim_data_adj, grad_data, tmm_data = run()
plot_grad_eps(sim_data, grad_data["eps"], tmm_data["grad_eps"])
plot_grad_top(sim_data, grad_data["top"], tmm_data["grad_top"])

> [0;32m/var/folders/jx/9y0mtn3s3zzb6mzgmsw6s6gr0000gn/T/ipykernel_67992/2134217066.py[0m(146)[0;36mrun[0;34m()[0m
[0;32m    144 [0;31m    [0;31m# plt.show()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    145 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 146 [0;31m    [0msim_data_adj[0m [0;34m=[0m [0mweb[0m[0;34m.[0m[0mrun[0m[0;34m([0m[0msim_adj[0m[0;34m,[0m [0mtask_name[0m[0;34m=[0m[0;34m"multilayer_adjoint"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    147 [0;31m[0;34m[0m[0m
[0m[0;32m    148 [0;31m    [0;31m# compute gradient w.r.t. slab permittivity and the top of the slab boundaries[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  amps


array([0.93398277-0.29135976j])


ipdb>  amps


array([0.93398277-0.29135976j])


In [None]:
tmm_data = grad_tmm(
    freqs=[2e14],
    slab_eps=[2**2, 1.8**2, 1.5**2, 1.9**2],
    slab_ds=[0.5, 0.25, 0.5, 0.25],
    bck_eps=1.4**2,
    theta = 0
)

print(tmm_data['eps_grad'])

In [None]:
# # 30 degree angle
# sim_data, sim_data_adj, grad_data, tmm_data = run(angle_theta=np.pi/6)
# plot_grad_eps(sim_data, grad_data["eps"], tmm_data["grad_eps"])
# plot_grad_top(sim_data, grad_data["top"], tmm_data["grad_top"])

In [None]:
# # plot forward and adjoint simulation
# fig, ax = plt.subplots(1, 2, figsize=(6, 8))
# sim_data.simulation.plot(y=0, ax=ax[0])
# sim_data_adj.simulation.plot(y=0, ax=ax[1])