In [1]:
"""
DPH Inference Library with JAX
--------------------------------
Supports:
- Decorator-based definition of DPH models
- User-defined parameter decoding from flat vector z
- Central difference autodiff (SVGD/VI compatible)
- Integration with JAX `custom_call` backend (C++)
"""

import os
os.environ['JAX_ENABLE_X64'] = 'True'

import struct

import jax
import jax.numpy as jnp
from functools import wraps
from jaxlib.hlo_helpers import custom_call
# from jax.interpreters import mlir
from jax.interpreters.mlir import ir
import jax.extend as jex
import ctypes
import jax.core

import jax.interpreters.mlir as mlir
from jax.interpreters import ad
#from jax import ad


# # Load C++ shared library with dph_pmf_param
# ctypes.CDLL("./libdph_param.so")

# Import the custom JAX extension module
# import ptdalgorithms._core
import ptdalgorithms

# ----------------------
# Decorator for DPH model
# ----------------------
def dph_model(decode_fn):
    """Decorator for DPH models with user-defined decode function."""
    def decorator(f):
        @wraps(f)
        def wrapper(_theta, t_obs):
            theta = decode_fn(_theta)
            return f(theta, t_obs)
        return wrapper
    return decorator

# ----------------------
# Decorator for registering a named DPH kernel
# ----------------------
def register_ptd_kernel(name: str):
    def decorator(func):
        prim = jex.core.Primitive(name=name.encode() if isinstance(name, str) else name)

        @prim.def_abstract_eval
        def abstract(theta_aval, times_aval):
            return jax.core.ShapedArray(times_aval.shape, jnp.float64)

        def lowering(ctx, graph_ptr, theta, times):
            avals = ctx.avals_in
            theta_layout = list(reversed(range(len(avals[0].shape))))
            times_layout = list(reversed(range(len(avals[1].shape))))
            graph_ptr_layout = list(reversed(range(len(avals[2].shape))))

            out_type = ir.RankedTensorType.get(avals[1].shape, mlir.dtype_to_ir_type(jnp.dtype(jnp.float64)))

            # encode the pointer as an opaque string
            opaque = struct.pack("Q", graph_ptr)  # unsigned 64-bit

            call_op = custom_call(
                call_target_name=name.encode(),
                result_types=[out_type],
                operands=[theta, times],
                operand_layouts=[theta_layout, times_layout],
                result_layouts=[times_layout],
                opaque=opaque,
            )
            
            return call_op.results
        
        mlir.register_lowering(prim, lowering, platform="cpu")

#        @prim.def_jvp
        def jvp(primals, tangents):

            _graph, theta, times = primals
            _, dtheta, dtimes = tangents

            f = prim.bind
            eps = 1e-5
            f0 = f(_graph, theta, times)

            if dtheta is not None:
                def compute_partial_grad(i):
                    theta_plus = theta.at[i].add(eps)
                    theta_minus = theta.at[i].add(-eps)
                    return dtheta[i] * (f(_graph, theta_plus, times) - f(_graph, theta_minus, times)) / (2 * eps)
                
                grad_theta = jnp.sum(jax.vmap(compute_partial_grad)(jnp.arange(theta.shape[0])), axis=0)
            else:
                grad_theta = jnp.zeros_like(times)

            return f0, grad_theta

        ad.primitive_jvps[prim] = jvp

        def wrapper(theta, times):
            return prim.bind(theta, times)

        return wrapper
    return decorator

# ----------------------
# Register DPH kernel
# ----------------------
dph_pmf = register_ptd_kernel("jax_graph_method_pmf")(lambda theta, times: None)
# dph_pdf = register_dph_kernel("dph_pdf_param")(lambda theta, times: None)
# # and so on...

# ----------------------
# Example decode function
# ----------------------
# def decode_z(z):
#     m = 4
#     alpha = jax.nn.softmax(z[:m])
#     T = jnp.reshape(jax.nn.softmax(z[m:m+m*m].reshape((m, m)), axis=1) * 0.9, (m, m))
#     t = 1.0 - jnp.sum(T, axis=1)
#     return alpha, T, t
def decode_z(theta):
    return theta

# ----------------------
# Example model using decorator
# ----------------------
@dph_model(decode_z)
def dph_negloglik(theta, t_obs):
    pmfs = dph_pmf(theta, t_obs)
    return -jnp.sum(jnp.log(pmfs + 1e-8))

# # ----------------------
# # Usage Example
# # ----------------------
# if __name__ == "__main__":

# Test with a function that actually depends on theta
def test_dph_pmf(theta, t_obs):
    # Create a simple parameterized function that depends on theta
    # This mimics a simple discrete distribution
    probs = jax.nn.softmax(theta[0] * jnp.ones_like(t_obs) + theta[1] * t_obs)
    return probs

# ptdalgorithms.Graph.register_custom_call_target(
#     "ptdalgorithms_jax_pmf",
#     test_dph_pmf,
#     platform="cpu",
#     platform_version=None,
#     has_side_effects=False,
#     num_outputs=1,
#     num_inputs=2,
#     input_dtypes=[jnp.float64, jnp.int64],
#     output_dtypes=[jnp.float64],
#     input_layouts=[None, None],
#     output_layouts=[None],
#     opaque=None,
#     name="ptdalgorithms_jax_pmf",
#     doc="Custom call target for DPH PMF computation",
# )

# # Override the dph_pmf for testing
# dph_pmf = test_dph_pmf

m = 2
z = jnp.zeros((m))
z = z.at[0].set(0.5)  # Example parameter setting
z = z.at[1].set(0.1)  # Example parameter setting

t_obs = jnp.array([3, 4, 5], dtype=jnp.int64)

print("Testing with parameterized mock dph_pmf function...")
print(f"Parameters: {z}")

# Test the pmf function directly
pmfs = test_dph_pmf(z, t_obs)
print(f"PMF values: {pmfs}")

loss = jax.jit(dph_negloglik)(z, t_obs)
print(f"Loss: {loss}")

grads = jax.grad(dph_negloglik)(z, t_obs)
print(f"Gradients: {grads}")

# Test with different parameters to verify gradients change
z2 = jnp.array([0.5, 0.5])
loss2 = jax.jit(dph_negloglik)(z2, t_obs)
grads2 = jax.grad(dph_negloglik)(z2, t_obs)
print(f"\nWith different parameters {z2}:")
print(f"Loss: {loss2}")
print(f"Gradients: {grads2}")




Testing with parameterized mock dph_pmf function...
Parameters: [0.5 0.1]
PMF values: [0.30060961 0.33222499 0.3671654 ]


TypeError: register_ptd_kernel.<locals>.decorator.<locals>.lowering() missing 1 required positional argument: 'times'