Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mzi lattice example #253

Closed
wants to merge 3 commits into from
Closed

fix mzi lattice example #253

wants to merge 3 commits into from

Conversation

joamatab
Copy link
Contributor

trying to fix #251

Still not working, marking it as draft

@lucas-flexcompute

@joamatab joamatab marked this pull request as draft November 28, 2023 17:37
@lucas-flexcompute
Copy link

The plugin seems to be working correctly, but I get strange results when using the jnp.interp function inside a model. Here's a simple example: when the circuit model is calculated with the same wavelengths used in the original simulation, the result is correct:

def bend_model(cross_section: gf.typings.CrossSectionSpec = "xs_sc"):
    component = gf.components.bend_euler(cross_section=cross_section)
    s = gt.write_sparameters(
        component=component,
        filepath=PATH.sparameters_repo / "bend_filter.npz",
        layer_stack=layer_stack,
    )
    wavelengths = s.pop("wavelengths")

    @jax.jit
    def _model(wl=1.55):
        s11 = jnp.interp(wl, wavelengths, s["o1@0,o1@0"])
        s21 = jnp.interp(wl, wavelengths, s["o2@0,o1@0"])
        return {
            ("o1", "o1"): s11,
            ("o1", "o2"): s21,
            ("o2", "o1"): s21,
            ("o2", "o2"): s11,
        }

    return _model

c = gf.Component(name="bend")
ref = c.add_ref(gf.components.bend_euler(cross_section=cross_section))
c.add_ports(ref.ports)
x, _ = sax.circuit(
    c.get_netlist(), {"bend_euler": bend_model(cross_section=cross_section)}
)

wl = np.linspace(1.5, 1.6, 11)
s = x(wl=wl)
plt.plot(wl, jnp.abs(s[("o1", "o2")]) ** 2)
plt.ylabel("S21")
plt.xlabel("λ (µm)")

image

But when I try to use a more refined wavelength set wl = np.linspace(1.5, 1.6, 101), the interpolated points tend to zero as they get farther from the originally sampled points:
image

Any ideas to what might be causing this?

@joamatab
Copy link
Contributor Author

joamatab commented Nov 29, 2023

this worked well for me

i did not see any issue when increasing the number of points

import sax
import gdsfactory as gf 
from gplugins.common.config import PATH
import gplugins.tidy3d as gt
import numpy as np 
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt 

from gdsfactory.pdk import get_active_pdk


pdk = get_active_pdk()
layer_stack = pdk.get_layer_stack()
cross_section = gf.cross_section.xs_sc

def bend_model(cross_section: gf.typings.CrossSectionSpec = "xs_sc"):
    component = gf.components.bend_euler(cross_section=cross_section)
    s = gt.write_sparameters(
        component=component,
        filepath=PATH.sparameters_repo / "bend_filter.npz",
        layer_stack=layer_stack,
    )
    wavelengths = s.pop("wavelengths")

    @jax.jit
    def _model(wl=1.55):
        s11 = jnp.interp(wl, wavelengths, s["o1@0,o1@0"])
        s21 = jnp.interp(wl, wavelengths, s["o2@0,o1@0"])
        return {
            ("o1", "o1"): s11,
            ("o1", "o2"): s21,
            ("o2", "o1"): s21,
            ("o2", "o2"): s11,
        }

    return _model

c = gf.Component(name="bend")
ref = c.add_ref(gf.components.bend_euler(cross_section=cross_section))
c.add_ports(ref.ports)
x, _ = sax.circuit(
    c.get_netlist(), {"bend_euler": bend_model(cross_section=cross_section)}
)

# wl = np.linspace(1.5, 1.6, 11) # works well
wl = np.linspace(1.5, 1.6, 101) # works well as well
s = x(wl=wl)
plt.plot(wl, jnp.abs(s[("o1", "o2")]) ** 2)
plt.ylabel("S21")
plt.xlabel("λ (µm)")
plt.show()

image

@flaport

@flaport
Copy link
Collaborator

flaport commented Nov 29, 2023

Hi, @lucas-flexcompute ,

The issue is that you're interpolating an S-parameter, which is complex valued. jnp.interp only works on real valued arrays (and it will take the real value of the array if complex).

Interpolating S-parameters is kinda difficult, as the phase wraps around, luckily in this case it seems the wavelength step is small enough to be able to unwrap the phase. I adapted the code you provided to correctly interpolate the phase:

def bend_model(cross_section: gf.typings.CrossSectionSpec = "xs_sc"):
    component = gf.components.bend_euler(cross_section=cross_section)
    s = gt.write_sparameters(
        component=component,
        filepath=PATH.sparameters_repo / "bend_filter.npz",
        layer_stack=layer_stack,
    )
    wavelengths = s.pop("wavelengths")
    amp = {k: jnp.abs(v) for k, v in s.items()}
    phi = {k: jnp.unwrap(jnp.angle(v)) for k, v in s.items()}

    @jax.jit
    def _model(wl=1.55):
        amp11 = jnp.interp(wl, wavelengths, amp["o1@0,o1@0"])
        amp21 = jnp.interp(wl, wavelengths, amp["o2@0,o1@0"])
        phi11 = jnp.interp(wl, wavelengths, phi["o1@0,o1@0"])
        phi21 = jnp.interp(wl, wavelengths, phi["o2@0,o1@0"])
        s11 = amp11 * jnp.exp(1j*phi11)
        s21 = amp21 * jnp.exp(1j*phi21)
        return {
            ("o1", "o1"): s11,
            ("o1", "o2"): s21,
            ("o2", "o1"): s21,
            ("o2", "o2"): s11,
        }

    return _model

c = gf.Component(name="bend")
ref = c.add_ref(gf.components.bend_euler(cross_section=cross_section))
c.add_ports(ref.ports)
x, _ = sax.circuit(
    c.get_netlist(), {"bend_euler": bend_model(cross_section=cross_section)}
)

wl = np.linspace(1.5, 1.6, 101)
s = x(wl=wl)
plt.plot(wl, 10*jnp.log10(jnp.abs(s[("o1", "o2")]) ** 2))
plt.ylabel("S21 (amp, dB)")
plt.xlabel("λ (µm)")
plt.twinx()
plt.plot(wl, jnp.unwrap(jnp.angle(s[("o1", "o2")])), color="C1")
plt.ylabel("S21 (phi, rad)")

image

PS: In cases where using small wavelength steps to be able to unwrap the phase is infeasible / expensive you can look into sax.grouped_interp to unwrap a phase for wavelengths containing a 'small' step and a 'big' step.

@lucas-flexcompute
Copy link

Ok, so that's the problem with the notebook, then, thanks @flaport! I guess the safest solution here would be to avoid interpolation altogether, but that means fixing the wavelength from the start. What do you think @joamatab?

@flaport
Copy link
Collaborator

flaport commented Nov 29, 2023

@lucas-flexcompute , I think interpolating is safe as long as you use enough frequency points, so I would recommend that. Adding more wavelengths is relatively cheap for FDTD anyway.

@lucas-flexcompute
Copy link

Using the complex interpolation by magnitude/phase works:

image

@lucas-flexcompute lucas-flexcompute mentioned this pull request Nov 29, 2023
@joamatab
Copy link
Contributor Author

Awesome, now it looks even better than before!

@joamatab
Copy link
Contributor Author

Thank you Floris and Lucas, we made another PR for this

#255

@joamatab joamatab closed this Nov 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Issue on page /notebooks/workflow_3_cascaded_mzi.html
3 participants