complex circuit simulation #667

Closed
joamatab opened this issue Sep 5, 2022 · 15 comments
Labels
bug Something isn't working

Comments

@joamatab
Contributor

joamatab commented Sep 5, 2022

"""FIXME.


"""
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import sax
import gdsfactory as gf


def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2() -> sax.SDict:
    """Returns an ideal 1x2 splitter."""
    return sax.reciprocal(
        {
            ("o1", "o2"): 0.5**0.5,
            ("o1", "o3"): 0.5**0.5,
        }
    )


def mmi2x2(*, coupling: float = 0.5) -> sax.SDict:
    """Returns an ideal 2x2 splitter.

    Args:
        coupling: power coupling coefficient.
    """
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return sax.reciprocal(
        {
            ("o1", "o4"): tau,
            ("o1", "o3"): 1j * kappa,
            ("o2", "o4"): 1j * kappa,
            ("o2", "o3"): tau,
        }
    )


def bend_euler(wl: float = 1.5, length: float = 20.0, loss: float = 50e-3) -> sax.SDict:
    """Returns bend Sparameters."""
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    return {k: amplitude * v for k, v in straight(wl=wl, length=length).items()}


def phase_shifter(
    wl: float = 1.55,
    neff: float = 2.34,
    voltage: float = 0,
    length: float = 10,
    loss: float = 0.0,
) -> sax.SDict:
    """Returns simple phase shifter model.

    Args:
        wl: wavelength in um.
        neff: effective index.
        voltage: applied voltage in units of V_pi (1 gives a pi phase shift).
        length: in um.
        loss: propagation loss in dB/um.
    """
    deltaphi = voltage * jnp.pi
    phase = 2 * jnp.pi * neff * length / wl + deltaphi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission = amplitude * jnp.exp(1j * phase)
    return sax.reciprocal(
        {
            ("o1", "o2"): transmission,
        }
    )


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "mmi2x2": mmi2x2,
    "straight": straight,
    "taper": straight,
    "straight_heater_metal_undercut": phase_shifter,
    "compass": sax.models.passthru(10),
    "via": sax.models.passthru(10),
}


if __name__ == "__main__":
    c = gf.components.switch_tree(bend_s=None)
    c.show(show_ports=True)
    n = netlist = c.get_netlist_recursive(
        exclude_port_types=("electrical", "placement")
    )
    # netlist.pop(list(n.keys())[0])
    mzi_circuit, _ = sax.circuit(netlist=netlist, models=models)
    S = mzi_circuit(wl=1.55)
    wl = np.linspace(1.5, 1.6, 256)
    S = mzi_circuit(wl=wl)

    plt.figure(figsize=(14, 4))
    plt.title("MZI")

    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o2_1_9"]) ** 2)
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o3_1_6"]) ** 2)
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o2_1_5"]) ** 2)
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o3_1_10"]) ** 2)
    plt.xlabel("λ [nm]")
    plt.ylabel("T")
    plt.grid(True)
    plt.show()

@joamatab added the "bug" (Something isn't working) label on Sep 5, 2022
@flaport
Collaborator

flaport commented Sep 6, 2022

I am not completely sure what you mean @joamatab.

But if your intention is to ignore some lower-level components, this can be done by just adding a model for a subcircuit higher in the dependency tree. See, for example, the SAX Circuit documentation.

On the other hand, missing components can be stubbed by just adding a passthru model for them. See the SAX models here.
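As a rough illustration of both points, here is a toy sketch, not SAX's actual resolver: a hypothetical `children` dict stands in for the hierarchical netlist, and any component that has a model is treated as a leaf, so nothing below it needs a model. Component names such as "heater" are made up for the example.

```python
from typing import Dict, List, Set


def missing_models(top: str, children: Dict[str, List[str]], models: Set[str]) -> Set[str]:
    """Return leaf components that have neither a model nor a sub-netlist.

    A component that has a model is not expanded, which is how providing a
    model higher in the tree "cuts" everything below it.
    """
    if top in models:
        return set()
    subs = children.get(top)
    if not subs:  # leaf without a model
        return {top}
    missing: Set[str] = set()
    for child in subs:
        missing |= missing_models(child, children, models)
    return missing


# Hypothetical hierarchy, loosely modeled on the switch tree in this issue.
children = {
    "switch_tree": ["mzi", "bend_euler"],
    "mzi": ["mmi1x2", "mmi2x2", "heater"],
    "heater": ["straight", "via", "compass"],
}
leaf_models = {"mmi1x2", "mmi2x2", "bend_euler", "straight"}

print(sorted(missing_models("switch_tree", children, leaf_models)))  # ['compass', 'via']
# Giving "heater" its own model prunes its subtree, so nothing is missing.
print(missing_models("switch_tree", children, leaf_models | {"heater"}))  # set()
```

Stubbing "via" and "compass" with passthru models would also empty the set, but only the model at the higher level replaces the heater's internal connectivity.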

@joamatab changed the title from "ValueError: Missing models. The following models are still missing to build the circuit" to "complex circuit simulation" on Sep 8, 2022
@joamatab
Contributor Author

joamatab commented Sep 8, 2022

I added passthru models for the missing components, but now I don't see any light coming out.


@flaport

@joamatab
Contributor Author

Because we use lossless couplers, the powers at all outputs should sum to T = 100%.
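This can be checked directly on the coupler models from the first comment. The sketch below re-creates their transmission values with plain NumPy (so it relies on no SAX API) and tests that the 2x2 block is unitary. If the waveguide models are lossless too, the output powers of the whole circuit must then sum to 1.

```python
import numpy as np


def mmi2x2_matrix(coupling: float = 0.5) -> np.ndarray:
    """Transmission block of the issue's ideal mmi2x2 model.

    Rows are the outputs (o3, o4), columns the inputs (o1, o2).
    """
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return np.array([[1j * kappa, tau], [tau, 1j * kappa]])


# Lossless means unitary: M^H M must be the identity for any coupling value.
for coupling in (0.5, 0.3):
    M = mmi2x2_matrix(coupling)
    print(coupling, np.allclose(M.conj().T @ M, np.eye(2)))  # True for both

# The ideal 1x2 splitter sends |1/sqrt(2)|^2 = 50% of the power to each of o2 and o3.
mmi1x2_out = np.array([0.5**0.5, 0.5**0.5])
print(np.sum(np.abs(mmi1x2_out) ** 2))  # approximately 1.0
```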

@jan-david-fischbach
Collaborator

The Euler bends impart some loss to the system:

def bend_euler(wl: float = 1.5, length: float = 20.0, loss: float = 50e-3)
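With these defaults, each bend costs 1 dB, because the model scales the field by 10 ** (-loss * length / 20), so `loss` acts as dB per unit length (dB/um with gdsfactory's um lengths). The arithmetic as a small sketch; the helper name is hypothetical:

```python
import math


def bend_power_transmission(loss: float = 50e-3, length: float = 20.0) -> float:
    """Power transmission of the issue's bend_euler model (loss in dB/um, length in um)."""
    amplitude = 10 ** (-loss * length / 20)  # field amplitude, as in the model
    return amplitude**2


t = bend_power_transmission()
print(f"{-10 * math.log10(t):.2f} dB per bend, {t:.3f} of the power transmitted")
# 1.00 dB per bend, 0.794 of the power transmitted

# A 10x drop in power is 10 dB, i.e. the loss of ten such bends in series.
print(f"{-10 * math.log10(t**10):.1f} dB")  # 10.0 dB
```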

Eliminating that source of loss, however, only increases the output power by a factor of 10, to 4 × 6.25% = 25% of the input power.
When isolating the MZI, I also get only 25% at each output, indicating that every MZI somehow loses half the power. I will continue the investigation.
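These numbers are exactly what a disconnected MZI arm would produce. A minimal NumPy sketch (not SAX), assuming the ideal 1x2 splitter and 2x2 combiner from the first comment and lossless arms; `arm_transmission` is a hypothetical parameter that sets each arm's field transmission:

```python
import numpy as np

SPLIT = np.array([0.5**0.5, 0.5**0.5])  # ideal 1x2: field amplitude into each arm


def combiner(coupling: float = 0.5) -> np.ndarray:
    kappa, tau = coupling**0.5, (1 - coupling) ** 0.5
    return np.array([[1j * kappa, tau], [tau, 1j * kappa]])  # rows o3, o4; cols o1, o2


def mzi_output_powers(arm_transmission) -> np.ndarray:
    """Output powers of a 1x2 -> two arms -> 2x2 MZI for unit input power."""
    fields = combiner() @ (SPLIT * np.asarray(arm_transmission))
    return np.abs(fields) ** 2


print(mzi_output_powers([1, 1]))  # both arms connected: approximately [0.5 0.5], sum 1
print(mzi_output_powers([1, 0]))  # one arm disconnected: approximately [0.25 0.25]
```

With one arm dead, each MZI passes only half the power. Assuming two MZI levels between the input and each output, every output receives 0.25 × 0.25 = 6.25%, which matches the 4 × 6.25% = 25% reported above.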

@jan-david-fischbach
Collaborator

So the problem seems to be with gf.components.straight_heater_metal. If I assign it a model like this:

models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "mmi2x2": mmi2x2,
    "straight": straight,
    "taper": straight,
    "straight_heater_metal_u_cfc14dc8": phase_shifter,
    #"compass": sax.models.passthru(10),
    #"via": sax.models.passthru(10),
}

the expected behavior is recovered. As @flaport suggested, this works by cutting the dependency tree: the model is provided at a higher level of the hierarchy.

> But if your intention is to ignore some lower-level components, this can be done by just adding a model for a subcircuit higher in the dependency tree. See, for example, the SAX Circuit documentation.

However, I find it rather inconvenient that the autogenerated hash has to be included in the model name (and it is probably less reliable/reproducible?).
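One possible workaround, sketched here as a hypothetical helper (not a SAX or gdsfactory feature): register models under the hash-free name and strip the suffix when matching. It assumes autogenerated names end in "_" followed by 8 hex characters, as in the two heater names in this thread.

```python
import re

# Assumption: the autogenerated suffix is "_" + 8 lowercase hex characters.
HASH_SUFFIX = re.compile(r"_[0-9a-f]{8}$")


def models_by_base_name(component_names, base_models):
    """Map (possibly hashed) component names to models registered under their hash-free name."""
    resolved = {}
    for name in component_names:
        base = HASH_SUFFIX.sub("", name)
        if base in base_models:
            resolved[name] = base_models[base]
    return resolved


# Strings stand in for the model functions.
base_models = {"straight_heater_metal_u": "phase_shifter", "mmi2x2": "mmi2x2"}
print(models_by_base_name(["straight_heater_metal_u_cfc14dc8", "mmi2x2"], base_models))
# {'straight_heater_metal_u_cfc14dc8': 'phase_shifter', 'mmi2x2': 'mmi2x2'}
```

A name that legitimately ends in 8 hex characters would be truncated too, so this is fragile compared with fixing the naming at the source.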

@jan-david-fischbach
Collaborator

jan-david-fischbach commented Nov 6, 2022

This is the dependency tree before providing the model for straight_heater_metal_u_... (sorry for the overlapping labels):
[Image: dep_tree2]
and after:
[Image: dep_tree1]

This is the response I get from the full system. Note how the outputs add up to 100% at all wavelengths (I introduced a length mismatch between the MZI arms for better visualization):
[Image: MZI_splitter_out]
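The complementary fringes follow from the same lossless building blocks. A small NumPy sketch, assuming a 10 um arm length difference (an arbitrary choice) and neff = 2.4, the default of the issue's straight model, whose phase exp(2j*pi*neff*L/wl) is applied to the extra length:

```python
import numpy as np


def mzi_sweep(wl, delta_length=10.0, neff=2.4, coupling=0.5):
    """Output powers (o3, o4) of an ideal MZI whose arms differ by delta_length um."""
    kappa, tau = coupling**0.5, (1 - coupling) ** 0.5
    a = 0.5**0.5  # ideal 1x2 split
    arm2 = np.exp(2j * np.pi * neff * delta_length / wl)  # extra phase in one arm
    o3 = 1j * kappa * a + tau * a * arm2
    o4 = tau * a + 1j * kappa * a * arm2
    return np.abs(o3) ** 2, np.abs(o4) ** 2


wl = np.linspace(1.5, 1.6, 256)
p3, p4 = mzi_sweep(wl)
print(np.allclose(p3 + p4, 1.0))  # True: the two fringes are complementary
```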

@jan-david-fischbach
Collaborator

The changed code, for reference:

"""FIXME.


"""
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import sax
import gdsfactory as gf
from sax.circuit import create_dag, draw_dag, _validate_net


def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2() -> sax.SDict:
    """Returns an ideal 1x2 splitter."""
    return sax.reciprocal(
        {
            ("o1", "o2"): 0.5**0.5,
            ("o1", "o3"): 0.5**0.5,
        }
    )

def mmi2x2(*, coupling: float = 0.5) -> sax.SDict:
    """Returns an ideal 2x2 splitter.

    Args:
        coupling: power coupling coefficient.
    """
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return sax.reciprocal(
        {
            ("o1", "o4"): tau,
            ("o1", "o3"): 1j * kappa,
            ("o2", "o4"): 1j * kappa,
            ("o2", "o3"): tau,
        }
    )


def bend_euler(wl: float = 1.5, length: float = 20.0, loss: float = 50e-3) -> sax.SDict:
    """Returns bend Sparameters."""
    loss = 0  # debugging: override the bend loss to rule it out as the cause
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    return {k: amplitude * v for k, v in straight(wl=wl, length=length).items()}


def phase_shifter(
    wl: float = 1.55,
    neff: float = 2.34,
    voltage: float = 0,
    length: float = 10,
    loss: float = 0.0,
) -> sax.SDict:
    """Returns simple phase shifter model.

    Args:
        wl: wavelength in um.
        neff: effective index.
        voltage: applied voltage in units of V_pi (1 gives a pi phase shift).
        length: in um.
        loss: propagation loss in dB/um.
    """
    deltaphi = voltage * jnp.pi
    phase = 2 * jnp.pi * neff * length / wl + deltaphi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission = amplitude * jnp.exp(1j * phase)
    return sax.reciprocal(
        {
            ("o1", "o2"): transmission,
        }
    )


models = {
    "bend_euler": bend_euler,
    "mmi1x2": mmi1x2,
    "mmi2x2": mmi2x2,
    "straight": straight,
    "taper": straight,
    "straight_heater_metal_u_5d466884": phase_shifter,
    "compass": sax.models.passthru(10),
    "via": sax.models.passthru(10),
}


if __name__ == "__main__":
    print("starting test")
    c = gf.components.switch_tree(bend_s=None, noutputs=4)

    # c = gf.components.mzi1x2_2x2(combiner=gf.components.mmi2x2,
    #     delta_length=0,
    #     straight_x_top=gf.components.straight_heater_metal,
    #     length_x=None)

    # c = gf.components.mzi1x2_2x2(combiner=gf.components.mmi2x2, straight_x_top=gf.components.straight_heater_metal)

    # #c = gf.components.straight_heater_metal()
    

    c.show(show_ports=True)
    n = netlist = c.get_netlist_recursive(
        exclude_port_types=("electrical", "placement")
    )
    # netlist.pop(list(n.keys())[0])
    # c.plot_netlist()
    # print(netlist)
    mzi_circuit, _ = sax.circuit(netlist=netlist, models=models)

    dag = create_dag(netlist=_validate_net(netlist), models=models)
    draw_dag(dag)

    

    wl = np.linspace(1.5, 1.6, 256)
    S = mzi_circuit(wl=wl)

    plt.figure(figsize=(14, 4))
    plt.title("MZI")
    print(S.keys())

    # plt.plot(1e3 * wl, jnp.abs(S["o1", "o1"]) ** 2)
    # plt.plot(1e3 * wl, jnp.abs(S["o1", "o2"]) ** 2)
    # plt.plot(1e3 * wl, jnp.abs(S["o1", "o3"]) ** 2)

    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o2_1_9"]) ** 2)
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o3_1_6"]) ** 2, "--")
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o2_1_5"]) ** 2)
    plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o3_1_10"]) ** 2, "--")

    #plt.plot(1e3 * wl, jnp.abs(S["o1_0_0", "o3_2_10"]) ** 2)

    plt.xlabel("λ [nm]")
    plt.ylabel("T")
    plt.grid(True)
    plt.show()

@joamatab
Contributor Author

joamatab commented Nov 6, 2022

Thank you, Jan-David.

How could we make it easier to see the layout ports and the netlist?

@jan-david-fischbach
Collaborator

Definitely a good question. The lower graph is the dependency graph, not the netlist, correct? I think plot_netlist already displays the netlist quite well, except that each node should perhaps be placed at the center of its component. Right now it seems to be placed at the first port (not sure).

@jan-david-fischbach
Collaborator

About the original error: I suspect SAX incorrectly handles component names starting with a "-". This leaves one arm of the MZI disconnected (the heater is implemented as a sequence that includes "-"), which results in the incorrect behavior observed. Further investigation to come.

@flaport
Collaborator

flaport commented Nov 6, 2022

  1. I think the hashes in the netlist come from autonaming. PR #624 (Improve reference naming) was merged to add reference aliases. I would recommend adding a reference alias to replace the autogenerated name with a name of your choice.

  2. In general, names in a netlist should be valid Python identifiers (i.e. strings that can be used as Python variable names). SAX will try to do some name cleanup if this is not the case, but in general the behavior is then pretty much undefined.
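Checking this rule is cheap with the standard library. A minimal sketch; the helper and the example names "-heater", "via-1" and "1bend" are hypothetical:

```python
import keyword


def invalid_instance_names(names):
    """Return names that are not valid Python identifiers (or are reserved keywords)."""
    return [n for n in names if not n.isidentifier() or keyword.iskeyword(n)]


names = ["mmi1x2_0_0", "straight_heater_metal_u_5d466884", "-heater", "via-1", "1bend"]
print(invalid_instance_names(names))  # ['-heater', 'via-1', '1bend']
```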

@jan-david-fischbach
Collaborator

In conclusion, we should try to avoid "-" in names on the gdsfactory side.

@jan-david-fischbach
Collaborator

@flaport Would it be possible to emit a warning when the naming scheme is violated?

@flaport
Collaborator

flaport commented Nov 8, 2022

I've opened an issue in SAX for this. I'll try to implement something like it ASAP.

@joamatab
Contributor Author

I think this issue has been solved. Let us know if there is something that needs to be done on the gdsfactory side.
