In [None]:
import fdtdx
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import pytreeclass as tc
from IPython.display import Video
%matplotlib inline
import numpy as np

In [None]:

def build_and_run_single_lambda(
    lam_m,
    *,
    seed=7,
    sim_time=100e-15,
    detector_start_fraction=0.99,
    detector_on_periods=40,
):
    """Build the 1x12 splitter geometry, run one FDTD simulation at ``lam_m``
    and return the mode overlap measured at the end of every output waveguide.

    Layout along +x: input bus waveguide -> linear taper (0.2 um -> 10 um wide)
    -> 0.5 um rectangular section -> 1 um uniform block standing in for the
    optimization region -> ``NUM_OUTPUT_WG`` parallel output waveguides.
    A ``ModePlaneSource`` (mode_index=0) launches in +x near the left edge and a
    ``ModeOverlapDetector`` sits on the last cell of each output waveguide.

    Args:
        lam_m: Free-space wavelength in metres (used by source and detectors).
        seed: Seed of the JAX PRNG key; the default 7 matches the original code.
        sim_time: Total simulated time in seconds (default 100 fs).
        detector_start_fraction: Detectors switch on at
            ``detector_start_fraction * sim_time``.
        detector_on_periods: Number of optical periods the detectors are asked
            to stay on (presumably clipped at the end of the run — confirm).

    Returns:
        dict with keys
          "lambda_m":     the wavelength as a Python float,
          "Transmission": ``jnp`` array stacking ``compute_overlap`` of the
                          detectors for outputs 0..NUM_OUTPUT_WG-1 (NOTE(review):
                          whether these are complex amplitudes or powers depends
                          on ``compute_overlap`` — confirm before plotting),
          "arrays_out":   arrays returned by ``fdtdx.run_fdtd``,
          "objects":      the placed object container.
    """
    # Normalise numpy scalars once so every consumer sees the same Python float.
    wavelength = float(lam_m)

    # =========================================================================
    # Geometry / grid parameters (SI units: metres)
    # =========================================================================
    EPS_CORE = 12.0            # core relative permittivity (n_core = sqrt(12) ~ 3.46)
    DX = 20e-9                 # grid resolution

    GRID_X_LENGTH = 6.5e-6
    GRID_Y_LENGTH = 20e-6
    GRID_Z_LENGTH = 1.5e-6

    BUS_LENGTH = 2.0e-6
    BUS_WIDTH = 0.2e-6

    TAPER_LENGTH = 0.5e-6      # taper extent along x
    TAPER_WIDTH_IN = 0.2e-6    # taper width at its left (bus) face
    TAPER_WIDTH_OUT = 10.0e-6  # taper width at its right face

    RECT_LENGTH_X = 0.5e-6
    RECT_WIDTH_Y = 10.0e-6

    OPT_REGION_LENGTH_X = 1.0e-6
    OPT_REGION_WIDTH_Y = 10.0e-6

    NUM_OUTPUT_WG = 12
    OUTPUT_WG_LENGTH_X = 2.0e-6
    OUTPUT_WG_WIDTH = 0.2e-6
    OUTPUT_WG_GAP = 0.64e-6
    OUTPUT_WG_CENTER_SPACING = OUTPUT_WG_WIDTH + OUTPUT_WG_GAP

    CORE_THICKNESS_Z = 0.5e-6  # z-thickness shared by every core element

    COLOR_WG = fdtdx.colors.LIGHT_BLUE
    COLOR_TAPERRECT = fdtdx.colors.TAN
    COLOR_OPTREG = fdtdx.colors.MAGENTA
    COLOR_SOURCE = fdtdx.colors.ORANGE

    # =========================================================================
    # Materials
    # =========================================================================
    material_config = {
        "mat_bg": fdtdx.Material(permittivity=fdtdx.constants.relative_permittivity_air),
        "mat_core": fdtdx.Material(permittivity=EPS_CORE),
    }

    # =========================================================================
    # Simulation config + volume + PML
    # =========================================================================
    config = fdtdx.SimulationConfig(
        time=float(sim_time),
        resolution=DX,
        dtype=jnp.float32,
        courant_factor=0.99,
    )

    volume = fdtdx.SimulationVolume(
        partial_real_shape=(GRID_X_LENGTH, GRID_Y_LENGTH, GRID_Z_LENGTH),
        material=material_config["mat_bg"],
    )

    object_list = [volume]
    placement_constraints = []

    bound_cfg = fdtdx.BoundaryConfig.from_uniform_bound(thickness=10, boundary_type="pml")
    boundaries, b_constraints = fdtdx.boundary_objects_from_config(bound_cfg, volume)
    object_list.extend(boundaries.values())
    placement_constraints.extend(b_constraints)

    # =========================================================================
    # Bus waveguide: its left face is aligned with the left face of the min-x PML
    # =========================================================================
    bus = fdtdx.UniformMaterialObject(
        name="bus_waveguide",
        material=material_config["mat_core"],
        partial_real_shape=(BUS_LENGTH, BUS_WIDTH, CORE_THICKNESS_Z),
        color=COLOR_WG,
    )
    object_list.append(bus)
    placement_constraints.append(
        bus.place_relative_to(
            boundaries["min_x"],
            axes=(0, 1, 2),
            own_positions=(-1, 0, 0),
            other_positions=(-1, 0, 0),
        )
    )

    # =========================================================================
    # Linear taper (extruded trapezoid in the x-y plane), butted to the bus
    # =========================================================================
    taper_vertices = jnp.array(
        [
            (-TAPER_LENGTH / 2, -TAPER_WIDTH_IN / 2),
            (-TAPER_LENGTH / 2, TAPER_WIDTH_IN / 2),
            (TAPER_LENGTH / 2, TAPER_WIDTH_OUT / 2),
            (TAPER_LENGTH / 2, -TAPER_WIDTH_OUT / 2),
        ],
        dtype=jnp.float32,
    )
    taper = fdtdx.ExtrudedPolygon(
        name="taper",
        materials=material_config,
        material_name="mat_core",
        axis=2,
        vertices=taper_vertices,
        # Bounding box of the polygon: full length in x, the wider face in y.
        partial_real_shape=(
            TAPER_LENGTH,
            max(TAPER_WIDTH_IN, TAPER_WIDTH_OUT),
            CORE_THICKNESS_Z,
        ),
        color=COLOR_TAPERRECT,
    )
    object_list.append(taper)
    placement_constraints.append(
        taper.place_relative_to(
            bus,
            axes=(0, 1, 2),
            own_positions=(-1, 0, 0),
            other_positions=(+1, 0, 0),
        )
    )

    # Rectangular section butted to the wide end of the taper.
    rectangle = fdtdx.UniformMaterialObject(
        name="rectangle",
        material=material_config["mat_core"],
        partial_real_shape=(RECT_LENGTH_X, RECT_WIDTH_Y, CORE_THICKNESS_Z),
        color=COLOR_TAPERRECT,
    )
    object_list.append(rectangle)
    placement_constraints.append(
        rectangle.place_relative_to(
            taper,
            axes=(0, 1, 2),
            own_positions=(-1, 0, 0),
            other_positions=(+1, 0, 0),
        )
    )

    # =========================================================================
    # Optimization region — currently a uniform core-material block used as a
    # placeholder for a designable device.
    # =========================================================================
    optimization_region = fdtdx.UniformMaterialObject(
        name="optimization_region",
        material=material_config["mat_core"],
        partial_real_shape=(OPT_REGION_LENGTH_X, OPT_REGION_WIDTH_Y, CORE_THICKNESS_Z),
        color=COLOR_OPTREG,
    )
    object_list.append(optimization_region)
    placement_constraints.append(
        optimization_region.place_relative_to(
            rectangle,
            axes=(0, 1, 2),
            own_positions=(-1, 0, 0),
            other_positions=(+1, 0, 0),
        )
    )

    # =========================================================================
    # Output waveguides, centred symmetrically in y about the optimization region
    # =========================================================================
    mid = (NUM_OUTPUT_WG - 1) / 2.0
    output_wgs = []
    for i in range(NUM_OUTPUT_WG):
        out_wg = fdtdx.UniformMaterialObject(
            name=f"out_wg_{i:02d}",
            material=material_config["mat_core"],
            partial_real_shape=(OUTPUT_WG_LENGTH_X, OUTPUT_WG_WIDTH, CORE_THICKNESS_Z),
            color=COLOR_WG,
        )
        output_wgs.append(out_wg)
        object_list.append(out_wg)
        placement_constraints.append(
            out_wg.place_relative_to(
                optimization_region,
                axes=(0, 1, 2),
                own_positions=(-1, 0, 0),
                other_positions=(+1, 0, 0),
                margins=(0.0, float((i - mid) * OUTPUT_WG_CENTER_SPACING), 0.0),
            )
        )

    # =========================================================================
    # Source: 0.3 um in from the left edge of the volume, inside the bus
    # =========================================================================
    source = fdtdx.ModePlaneSource(
        name="input_source",
        wave_character=fdtdx.WaveCharacter(wavelength=wavelength),
        direction="+",
        partial_grid_shape=(1, 20, 50),
        mode_index=0,
        color=COLOR_SOURCE,
    )
    object_list.append(source)
    placement_constraints.append(
        source.place_relative_to(
            volume,
            axes=(0, 1, 2),
            other_positions=(-1, 0, 0),
            own_positions=(-1, 0, 0),
            margins=(0.3e-6, 0.0, 0.0),
        )
    )

    # =========================================================================
    # Mode overlap detectors (one wavelength per run) on the last x-cell of each
    # output waveguide.
    # =========================================================================
    # NOTE(review): with the defaults the detectors switch on at 0.99 * sim_time,
    # i.e. ~1 fs before the run ends — shorter than one optical period at
    # 1.55 um (~5.2 fs) — and 40 periods (~200 fs) cannot fit in a 100 fs run.
    # The overlap is then accumulated over a fraction of a cycle. Consider a
    # longer sim_time and/or a smaller detector_start_fraction — confirm intent.
    period = fdtdx.constants.wavelength_to_period(wavelength)
    detector_names = [f"det_out_{i:02d}" for i in range(NUM_OUTPUT_WG)]

    for det_name, out_wg in zip(detector_names, output_wgs):
        det = fdtdx.ModeOverlapDetector(
            name=det_name,
            wave_characters=[fdtdx.WaveCharacter(wavelength=wavelength)],
            direction="+",
            filter_pol=None,
            reduce_volume=False,
            dtype=jnp.complex64,
            partial_grid_shape=(1, 10, 25),  # plane normal to the propagation axis
            switch=fdtdx.OnOffSwitch(
                period=period,
                start_time=detector_start_fraction * config.time,
                on_for_periods=detector_on_periods,
            ),
        )
        object_list.append(det)
        placement_constraints += [
            det.place_relative_to(out_wg, axes=(0,), other_positions=(+1,), own_positions=(+1,)),
            det.place_at_center(out_wg, axes=(1, 2)),
            det.same_size(out_wg, axes=(1, 2)),
        ]

    # =========================================================================
    # Place objects, then apply params.
    # Each consumer gets its own subkey: a JAX key must not be reused after it
    # has been split.
    # =========================================================================
    key = jax.random.PRNGKey(seed)
    key_place, key_params, key_run = jax.random.split(key, 3)

    objects, arrays, params, config, _ = fdtdx.place_objects(
        object_list=object_list,
        config=config,
        constraints=placement_constraints,
        key=key_place,
    )

    arrays, objects, _ = fdtdx.apply_params(
        arrays=arrays,
        objects=objects,
        params=params,
        key=key_params,
    )

    # =========================================================================
    # Run simulation
    # =========================================================================
    _, arrays_out = fdtdx.run_fdtd(
        arrays=arrays,
        objects=objects,
        config=config,
        key=key_run,
    )

    # =========================================================================
    # Collect the overlap of every output detector, in output order.
    # Look detectors up in the *placed* container (not the pre-placement
    # objects created above).
    # =========================================================================
    detector_states = arrays_out.detector_states
    detectors_by_name = {d.name: d for d in objects.forward_detectors}
    transmission = jnp.stack(
        [
            detectors_by_name[name].compute_overlap(detector_states[name])
            for name in detector_names
        ],
        axis=0,
    )

    return {
        "lambda_m": wavelength,
        "Transmission": transmission,
        "arrays_out": arrays_out,
        "objects": objects,
    }


# =============================================================================
# Wavelength sweep: one independent simulation per wavelength
# (10 points spanning 1540-1560 nm).
# =============================================================================
WAVELENGTHS_M = np.linspace(1.54e-6, 1.56e-6, 10)

# NOTE(review): each entry keeps the full `arrays_out` of its run, so memory
# grows linearly with len(WAVELENGTHS_M).
results = []
for lam_m in WAVELENGTHS_M:
    out = build_and_run_single_lambda(lam_m)
    results.append(out)
