In [None]:
import sax
import meow as mw
import jax.numpy as jnp

In [None]:
def fresnel_mirror_ij(ni=1.0, nj=1.0, theta_0=0, pol="s"):
    """Model a (Fresnel) interface between two refractive indices as an S-dict.

    Port "left" faces the initial medium i and port "right" faces the final
    medium j. The entries are the Fresnel *amplitude* coefficients.

    Args:
        ni: refractive index of the initial medium (i, "left" port).
        nj: refractive index of the final medium (j, "right" port).
        theta_0: angle of incidence measured from the normal *in vacuum*,
            in radians. The propagation angle inside each medium follows from
            Snell's law referenced to vacuum: ``sin(theta_0) = n * sin(theta_n)``.
        pol: "s" or "p" polarization.

    Returns:
        dict mapping port pairs to coefficients:
            ("left", "left")   -> r_ij  (reflection back into medium i)
            ("left", "right")  -> t_ij  (transmission i -> j)
            ("right", "left")  -> t_ji  (transmission j -> i)
            ("right", "right") -> r_ji = -r_ij (reflection back into medium j)

    Raises:
        ValueError: if ``pol`` is neither "s" nor "p".

    NOTE(review): these are amplitude coefficients, not power-normalized ones.
    E.g. ni=1, nj=2 at normal incidence gives |r_ij|^2 + |t_ij|^2 = 5/9, so the
    resulting S-matrix is not unitary. Confirm this is what downstream sax
    circuits expect.
    """
    # Validate first so no work is done for an invalid polarization.
    if pol not in ("s", "p"):
        raise ValueError(f"polarization should be either 's' or 'p', got {pol!r}")

    # Snell's law referenced to vacuum: sin(theta_0) = n * sin(theta_n).
    # NOTE(review): for real indices with |sin(theta_0) / n| > 1 (total internal
    # reflection) jnp.arcsin returns NaN; complex indices are presumably needed
    # to model that regime.
    sin_0 = jnp.sin(theta_0)
    cos_i = jnp.cos(jnp.arcsin(sin_0 / ni))
    cos_j = jnp.cos(jnp.arcsin(sin_0 / nj))

    if pol == "s":
        # s (TE): r_ij = (ni cos_i - nj cos_j) / (ni cos_i + nj cos_j)
        denom = ni * cos_i + nj * cos_j
        r_fresnel_ij = (ni * cos_i - nj * cos_j) / denom
    else:
        # p (TM): r_ij = (nj cos_i - ni cos_j) / (nj cos_i + ni cos_j)
        denom = nj * cos_i + ni * cos_j
        r_fresnel_ij = (nj * cos_i - ni * cos_j) / denom

    # The transmission numerators are the same for both polarizations, and each
    # denominator is symmetric under i <-> j, so it also serves the j -> i direction.
    t_fresnel_ij = 2 * ni * cos_i / denom  # i -> j transmission
    t_fresnel_ji = 2 * nj * cos_j / denom  # j -> i transmission
    r_fresnel_ji = -r_fresnel_ij  # reflection back into medium j

    sdict = {
        ("left", "left"): r_fresnel_ij,
        ("left", "right"): t_fresnel_ij,
        ("right", "left"): t_fresnel_ji,
        ("right", "right"): r_fresnel_ji,
    }
    return sdict

In [None]:
# Interface from n=1 to n=2 at normal incidence (default theta_0=0), s-polarized.
s = fresnel_mirror_ij(1, 2)

In [None]:
# Show the raw S-dict: (port, port) -> coefficient.
s

In [None]:
# Convert the S-dict to sax's dense form for inspection.
sax.sdense(s)

In [None]:
# Toy S-dict used to check how sax lays out / orders ports.
# NOTE(review): if keys are (from, to), as in fresnel_mirror_ij, the values look
# like conventional S-parameter labels S_{to,from} (e.g. ("1", "2") -> 21),
# presumably chosen so the dense layout can be read off directly. Confirm intent.
tester = {
    ("1", "1"): 11,
    ("1", "2"): 21,
    ("2", "1"): 12,
    ("2", "2"): 22,
}

In [None]:
# sax.sdense returns a pair; presumably (matrix, port-name -> index map). TODO confirm.
S, pm = sax.sdense(tester)
jnp.abs(S)

In [None]:
# Same S-dict in sax's SCOO form (presumably sparse / coordinate format).
S_scoo = sax.scoo(tester)

In [None]:
S_scoo

In [None]:
# Convert SCOO -> dense -> SCOO; presumably this should reproduce S_scoo. TODO confirm.
sax.scoo(sax.sdense(S_scoo))