In [1]:

import example
import jax.numpy as jnp
from jaxlib import xla_client

xla_client.register_custom_call_target(b"py_add_xla", example.py_add(), platform="gpu")

In [2]:

a = jnp.arange(128, dtype=jnp.float32)
b = jnp.arange(128, dtype=jnp.float32)

In [10]:

from jax import core

py_add_p = core.Primitive("py_add")

def py_add_prim(a, b):
    return py_add_p.bind(a, b)

@py_add_p.def_impl
def py_add_impl(a, b):
    return "hello world"

from jax._src import abstract_arrays


@py_add_p.def_abstract_eval
def py_add_abeval(a, b):
    assert a.shape == b.shape
    assert a.ndim == 1
    assert a.shape[0] == 128
    return abstract_arrays.ShapedArray(a.shape, a.dtype)

In [11]:

from jax.interpreters import xla
import numpy as np

xops = xla_client.ops

c = xla_client.XlaBuilder("comp_builder")

def py_add_translation(xla_builder, a, b):
    shape = xla_client.Shape.array_shape(np.dtype("float32"), (128,), (0,))
    opaque = b"This is opaque"
    print(f"Type of a is {type(a)}")
    print(f"Type of shape is {type(shape)}")
    return xops.CustomCallWithLayout(
        xla_builder,
        b"py_add_xla",
        operands=(a, b),
        shape_with_layout=shape,
        operand_shapes_with_layout=(shape, shape),
        opaque=opaque,
        # has_side_effect=False,
        # schedule=0,
        # api_version=1,
    )

xla.backend_specific_translations["gpu"][py_add_p] = py_add_translation

"""
    builder: XlaBuilder,
    call_target_name: bytes,
    operands: Sequence[XlaOp],
    shape_with_layout: Shape,
    operand_shapes_with_layout: Sequence[Shape],
    opaque: bytes = ...,
    has_side_effect: bool = ...,
    schedule: CustomCallSchedule = ...,
    api_version: CustomCallApiVersion = ...,
"""

shape = xla_client.Shape.array_shape(np.dtype("float32"), (128,), (0,))
shape

f32[128]{0}

In [12]:

import jax

c = jax.jit(py_add_prim)(a,b)

Type of a is <class 'jaxlib.xla_extension.XlaOp'>
Type of shape is <class 'jaxlib.xla_extension.Shape'>
Now use custom XLA!
This is opaque


In [13]:

c

DeviceArray([  0.,   2.,   4.,   6.,   8.,  10.,  12.,  14.,  16.,  18.,
              20.,  22.,  24.,  26.,  28.,  30.,  32.,  34.,  36.,  38.,
              40.,  42.,  44.,  46.,  48.,  50.,  52.,  54.,  56.,  58.,
              60.,  62.,  64.,  66.,  68.,  70.,  72.,  74.,  76.,  78.,
              80.,  82.,  84.,  86.,  88.,  90.,  92.,  94.,  96.,  98.,
             100., 102., 104., 106., 108., 110., 112., 114., 116., 118.,
             120., 122., 124., 126., 128., 130., 132., 134., 136., 138.,
             140., 142., 144., 146., 148., 150., 152., 154., 156., 158.,
             160., 162., 164., 166., 168., 170., 172., 174., 176., 178.,
             180., 182., 184., 186., 188., 190., 192., 194., 196., 198.,
             200., 202., 204., 206., 208., 210., 212., 214., 216., 218.,
             220., 222., 224., 226., 228., 230., 232., 234., 236., 238.,
             240., 242., 244., 246., 248., 250., 252., 254.],            dtype=float32)

In [14]:

d = py_add_prim(a,b)
d

'hello world'

In [None]:

d