In [23]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

class Model(eqx.Module):
    """Toy model holding a single complex scalar parameter ``a``."""

    a: jax.Array  # complex-valued parameter stored as a JAX array

    def __init__(self):
        # The parameter starts at 1 + 1j.
        self.a = jnp.array(complex(1.0, 1.0))

    def __call__(self):
        """Return the imaginary part of ``a`` plus its real part."""
        imag_part = jnp.imag(self.a)
        real_part = jnp.real(self.a)
        return imag_part + real_part

# Create the model instance that the optimizer and training loop below operate on.
model = Model()

def loss_fn(model):
    """Call ``model`` with no arguments and use its output as the loss."""
    value = model()
    return value

# Show the model before any optimization step has been applied.
print("Initial model:", model)

# Custom transformation to conjugate gradients
def conjugate_grads_transform():
    """Build a stateless optax transformation that conjugates complex updates.

    Complex leaves of the update tree are replaced by their complex
    conjugate; every other leaf is passed through untouched.

    Returns:
        optax.GradientTransformation: an ``(init, update)`` pair that can be
        used on its own or as one stage of ``optax.chain``.
    """

    def _conjugate_if_complex(g):
        # Only complex leaves are conjugated; other leaves are returned as-is.
        return jnp.conj(g) if jnp.iscomplexobj(g) else g

    def init_fn(params):
        # Nothing is carried between steps. EmptyState is optax's marker for
        # stateless transformations, used here instead of a bare None.
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        # `params` is accepted for optax API compatibility but is not needed.
        del params
        updates = jax.tree_util.tree_map(_conjugate_if_complex, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

# Optimizer pipeline: updates are conjugated first, then handed to Adam.
opt = optax.chain(
    conjugate_grads_transform(),
    optax.adam(learning_rate=0.1),
)
opt_state = opt.init(model)

# Training loop
# The gradient function is the same for every step, so wrap loss_fn once
# instead of rebuilding the wrapper on each iteration.
grad_fn = eqx.filter_grad(loss_fn)
for i in range(10):
    grad = grad_fn(model)
    updates, opt_state = opt.update(grad, opt_state)
    # Rebind `model` to the updated copy returned by apply_updates.
    model = eqx.apply_updates(model, updates)
    print(f"Step {i + 1}, model.a: {model.a}")

Initial model: Model(a=weak_c64[])
Step 1, model.a: (0.9292898178100586+0.9292898178100586j)
Step 2, model.a: (0.8585798740386963+0.8585798740386963j)
Step 3, model.a: (0.7878694534301758+0.7878694534301758j)
Step 4, model.a: (0.7171594500541687+0.7171594500541687j)
Step 5, model.a: (0.6464494466781616+0.6464494466781616j)
Step 6, model.a: (0.5757391452789307+0.5757391452789307j)
Step 7, model.a: (0.505029022693634+0.505029022693634j)
Step 8, model.a: (0.434318870306015+0.434318870306015j)
Step 9, model.a: (0.36360877752304077+0.36360877752304077j)
Step 10, model.a: (0.29289859533309937+0.29289859533309937j)


In [17]:
class Model(eqx.Module):
    """Toy model with two real scalar parameters, ``a`` and ``b``."""

    a: jax.Array  # real-valued scalar parameter
    b: jax.Array  # real-valued scalar parameter

    def __init__(self):
        # Both parameters start at 1.0.
        initial = 1.0
        self.a = jnp.array(initial)
        self.b = jnp.array(initial)

    def __call__(self):
        """Return the sum ``a + b``."""
        total = self.a + self.b
        return total

# Create the model instance that the optimizer and training loop below operate on.
model = Model()

def loss_fn(model):
    """Call ``model`` with no arguments and use its output as the loss."""
    value = model()
    return value

# Show the model before any optimization step has been applied.
print("Initial model:", model)

# Custom transformation to conjugate gradients
# NOTE: the optimizer built in this cell is plain Adam and does not include
# this transformation.
def conjugate_grads_transform():
    """Build a stateless optax transformation that conjugates complex updates.

    Complex leaves of the update tree are replaced by their complex
    conjugate; every other leaf is passed through untouched.

    Returns:
        optax.GradientTransformation: an ``(init, update)`` pair that can be
        used on its own or as one stage of ``optax.chain``.
    """

    def _conjugate_if_complex(g):
        # Only complex leaves are conjugated; other leaves are returned as-is.
        return jnp.conj(g) if jnp.iscomplexobj(g) else g

    def init_fn(params):
        # Nothing is carried between steps. EmptyState is optax's marker for
        # stateless transformations, used here instead of a bare None.
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        # `params` is accepted for optax API compatibility but is not needed.
        del params
        updates = jax.tree_util.tree_map(_conjugate_if_complex, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

# Plain Adam optimizer; no conjugation stage is chained in front of it here.
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(model)

# Training loop
# The gradient function is the same for every step, so wrap loss_fn once
# instead of rebuilding the wrapper on each iteration.
grad_fn = eqx.filter_grad(loss_fn)
for i in range(10):
    grad = grad_fn(model)
    updates, opt_state = opt.update(grad, opt_state)
    # Rebind `model` to the updated copy returned by apply_updates.
    model = eqx.apply_updates(model, updates)
    print(f"Step {i + 1}, model.a: {model.a}, model.b: {model.b}")

Initial model: Model(a=weak_f32[], b=weak_f32[])
Step 1, model.a: 0.9000006914138794, model.b: 0.9000006914138794
Step 2, model.a: 0.8000016808509827, model.b: 0.8000016808509827
Step 3, model.a: 0.7000020742416382, model.b: 0.7000020742416382
Step 4, model.a: 0.6000030636787415, model.b: 0.6000030636787415
Step 5, model.a: 0.5000039935112, model.b: 0.5000039935112
Step 6, model.a: 0.400004506111145, model.b: 0.400004506111145
Step 7, model.a: 0.3000052869319916, model.b: 0.3000052869319916
Step 8, model.a: 0.20000602304935455, model.b: 0.20000602304935455
Step 9, model.a: 0.1000068262219429, model.b: 0.1000068262219429
Step 10, model.a: 7.495284080505371e-06, model.b: 7.495284080505371e-06


In [25]:
from jax import random

In [27]:
in_channels = 100
out_channels = 150

# Both arrays are drawn uniformly from [-scale, scale).
scale = 1.0 / (in_channels * out_channels)

# Split the seed key so `v` and `bypass` come from independent random streams;
# passing the same key to two sampling calls is a JAX PRNG anti-pattern.
v_key, bypass_key = random.split(random.key(0))
v = random.uniform(v_key, (in_channels, 40, 50), minval=-scale, maxval=+scale)
bypass = random.uniform(bypass_key, (out_channels, in_channels), minval=-scale, maxval=+scale)

In [28]:
# Contract the input-channel axis (axis 1 of `bypass`, axis 0 of `v`) with two
# different calls: an einsum and a tensordot.
result2 = jnp.einsum("oi,ixy->oxy", bypass, v)
result1 = jnp.tensordot(bypass, v, axes=((1,), (0,)))

In [36]:
# Time the einsum contraction. jax.block_until_ready waits for the result to be
# computed, so each measurement covers the computation and not only its dispatch.
%timeit jax.block_until_ready(jnp.einsum("oi,ixy->oxy", bypass, v))

1.16 ms ± 41.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
