In [1]:
import jax
import jax.numpy as jnp
# Turn on 64-bit types; the complex arrays displayed later in this notebook are complex128.
jax.config.update('jax_enable_x64', True)


In [2]:
# Complex test array. The doubly nested literal gives shape (2, 1, 10);
# `* 4 + 1 + 0j` maps the 0/1 pattern to 1+0j / 5+0j.
a = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
]) * 4 + 1 + 0j  # refractive index


In [4]:
# Time the built-in complex conjugate of `a`.
%timeit aa=a.conj()

1.93 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [3]:
# Time a hand-written conjugate (real part plus imaginary part times -1j) for comparison with a.conj().
%timeit aa = a.real + a.imag * -1j

9.54 µs ± 24.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [4]:
# Repeat of the built-in conj() timing.
%timeit aa=a.conj()

1.91 µs ± 93.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [5]:
# Repeat of the hand-written conjugate timing.
%timeit aa = a.real + a.imag * -1j

9.61 µs ± 17 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [2]:
import jax
import optax
import jax.numpy as jnp

from tqdm import tqdm

In [8]:
def _grad(params, forward, loss_fn):

    def forward_pass(params, forward, loss):
        result = forward(**params)
        loss_value = loss(result)
        return loss_value

    loss_value, grads = jax.value_and_grad(forward_pass)(params, forward, loss_fn)
    return loss_value, grads

def grad(self, pois, forward, loss_fn):
    """Return the gradient of the loss with respect to the named attributes.

    Args:
        self: Object holding the parameters as attributes and providing
            ``_grad(params, forward, loss_fn)``.
        pois: Iterable of attribute names to differentiate with respect to.
        forward: Forward callable, passed through to ``self._grad``.
        loss_fn: Loss callable, passed through to ``self._grad``.

    Returns:
        The gradients returned by ``self._grad``; the loss value is discarded.
    """
    params = {poi: getattr(self, poi) for poi in pois}
    _, grads = self._grad(params, forward, loss_fn)

    # Write the collected values back onto the object. A plain loop is used
    # because the assignment is a side effect; no result list is needed.
    for poi in pois:
        setattr(self, poi, params[poi])

    return grads

def fit(self, pois, forward, loss_fn, optimizer, iteration=1):
    """Optimise the named attributes of ``self`` and store the result.

    Args:
        self: Object holding the parameters as attributes and providing
            ``_grad(params, forward, loss_fn)``.
        pois: Iterable of attribute names to optimise.
        forward: Forward callable, passed through to ``self._grad``.
        loss_fn: Loss callable, passed through to ``self._grad``.
        optimizer: Object exposing ``init(params)`` and
            ``update(grads, opt_state, params)``; its updates are applied
            with ``optax.apply_updates``.
        iteration: Number of optimisation steps to run.

    Returns:
        Dict mapping each name in ``pois`` to its value after the last step.
        The same values are also written back onto ``self``.
    """
    params = {poi: getattr(self, poi) for poi in pois}
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        loss_value, grads = self._grad(params, forward, loss_fn)
        # Conjugate the gradients before the optimizer sees them (JAX's
        # convention for complex inputs). tree_map handles parameters that
        # are themselves nested containers, not only single arrays.
        grads = jax.tree_util.tree_map(jnp.conj, grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for _ in tqdm(range(iteration)):
        params, opt_state, loss_value = step(params, opt_state)

    # Store the optimised values on the object. A plain loop is used because
    # the assignment is a side effect; no result list is needed.
    for poi in pois:
        setattr(self, poi, params[poi])

    return params

In [3]:

@jax.grad
def grad_loss(ucell):
    """Real part of the conjugated ``[0, 0]`` entry of ``ucell``.

    Wrapped in ``jax.grad``, so calling ``grad_loss(ucell)`` returns the
    gradient of this scalar with respect to ``ucell``.
    """
    return ucell.conj()[0, 0].real

# Complex test array of shape (2, 10): the 0/1 pattern becomes 1+3j / 5+3j.
ucell = jnp.array([
    [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ],
    [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ],
]) * 4 + 1 + 3j  # refractive index

# Gradient of Re(conj(ucell)[0, 0]) with respect to every entry of ucell.
grad_ad = grad_loss(ucell)


In [4]:
# Only the [0, 0] entry is non-zero, because the loss reads just that element.
grad_ad

Array([[1.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j,
        0.-0.j, 0.-0.j],
       [0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j,
        0.-0.j, 0.-0.j]], dtype=complex128)

In [30]:
@jax.jit
def ff(arr):
    """Return the complex conjugate of ``arr`` (jit-compiled)."""
    return arr.conj()

In [31]:
def grad(ucell, forward):
    """Value and gradient of ``Re(forward(ucell))[0, 0]`` with respect to ``ucell``.

    Args:
        ucell: Array passed to ``forward`` and differentiated against.
        forward: Callable applied to ``ucell``; its output is indexed at
            ``[0, 0]`` after taking the real part.

    Returns:
        Tuple ``(loss_value, grads)`` from ``jax.value_and_grad``.
    """

    def scalar_objective(cell):
        # Reduce the forward output to a single real scalar.
        return forward(cell).real[0, 0]

    return jax.value_and_grad(scalar_objective)(ucell)


In [38]:
# Same (2, 10) complex test array as in the grad_loss cell: 0/1 pattern -> 1+3j / 5+3j.
ucell = jnp.array([
    [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ],
    [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ],
]) * 4 + 1 + 3j  # refractive index


In [39]:
# Value and gradient through the jitted conjugate `ff`; the gradient is 1 at [0, 0] and zero elsewhere.
grad(ucell, ff)

(Array(1., dtype=float64, weak_type=True),
 Array([[1.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j,
         0.-0.j, 0.-0.j],
        [0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j, 0.-0.j,
         0.-0.j, 0.-0.j]], dtype=complex128))