In [54]:
import jax
import jax.numpy as jnp
from jax.core import Literal, JaxprEqn, Var as JaxprVar, ClosedJaxpr
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# === 1) Define a minimal Futhark AST ===


@dataclass
class FutharkType:
    base: str  # e.g. "f32"
    dims: List[int]  # e.g. [5] for f32[5]


@dataclass
class Var:
    name: str
    type: FutharkType


class Expr:
    pass


@dataclass
class LiteralExpr(Expr):
    value: Any
    type: FutharkType


@dataclass
class UnaryOp(Expr):
    op: str  # "sin", "exp", ...
    x: Expr


@dataclass
class BinaryOp(Expr):
    op: str  # "+", "*", etc.
    x: Expr
    y: Expr


@dataclass
class Let:
    var: Var
    expr: Expr


@dataclass
class Function:
    name: str
    params: List[Var]
    body: List[Let]
    result: Var


# === 2) Helpers: map JAX dtypes & avals to Futhark types ===


def jax_dtype_to_futhark(dtype) -> str:
    if dtype == jnp.float32:
        return "f32"
    if dtype == jnp.float64:
        return "f64"
    if dtype == jnp.int32:
        return "i32"
    if dtype == jnp.int64:
        return "i64"
    raise NotImplementedError(f"dtype {dtype}")


def aval_to_futhark_type(aval) -> FutharkType:
    base = jax_dtype_to_futhark(aval.dtype)
    dims = list(aval.shape)
    return FutharkType(base, dims)


# === 3) The main translator: ClosedJaxpr → Futhark AST ===


def jaxpr_to_futhark_ast(cj: ClosedJaxpr, name: str = "f") -> Function:
    jaxpr = cj.jaxpr
    env: Dict[JaxprVar, Union[Var, LiteralExpr]] = {}
    lets: List[Let] = []

    # 3.1) parameters
    params: List[Var] = []
    print(jaxpr.invars)
    for v in jaxpr.invars:
        vt = aval_to_futhark_type(v.aval)
        fv = Var(str(v), vt)
        env[v] = fv
        params.append(fv)

    # 3.2) walk equations
    for eqn in jaxpr.eqns:
        # resolve inputs
        inputs: List[Expr] = []
        for iv in eqn.invars:
            if isinstance(iv, Literal):
                lit_t = aval_to_futhark_type(iv.aval)
                inputs.append(LiteralExpr(iv.val, lit_t))
            else:
                inputs.append(env[iv])

        # pick the right AST node
        prim = eqn.primitive.name
        if prim == "add":
            expr = BinaryOp("+", inputs[0], inputs[1])
        elif prim == "mul":
            expr = BinaryOp("*", inputs[0], inputs[1])
        elif prim == "sin":
            expr = UnaryOp("sin", inputs[0])
        # … you can add more primitives here …
        else:
            raise NotImplementedError(f"primitive {prim}")

        # bind the output var
        outv = eqn.outvars[0]
        vt = aval_to_futhark_type(outv.aval)
        fv = Var(str(outv), vt)
        env[outv] = fv
        lets.append(Let(fv, expr))

    # 3.3) final result
    result = env[jaxpr.outvars[0]]
    return Function(name, params, lets, result)


# === 4) (Optional) Pretty-printer for the AST ===


def print_futhark_type(t: FutharkType) -> str:
    dims = "".join(f"[{d}]" for d in t.dims)
    return f"{t.base}{dims}"


def print_expr(e: Expr) -> str:
    if isinstance(e, LiteralExpr):
        # format floats with .0
        v = e.value
        if isinstance(v, float) and "." not in repr(v):
            v = f"{v:.1f}"
        return repr(v)
    if isinstance(e, Var):
        return e.name
    if isinstance(e, UnaryOp):
        return f"{e.op}({print_expr(e.x)})"
    if isinstance(e, BinaryOp):
        return f"({print_expr(e.x)} {e.op} {print_expr(e.y)})"
    raise NotImplementedError


def print_function(fn: Function) -> str:
    ps = ", ".join(f"{p.name}: {print_futhark_type(p.type)}" for p in fn.params)
    body = "\n  ".join(f"let {l.var.name} = {print_expr(l.expr)}" for l in fn.body)
    return f"let {fn.name} ({ps}) =\n" f"  {body}\n" f"  in {fn.result.name}"


# === 5) Example usage ===


def f(x):
    return jnp.sin(x) + x


x0 = jnp.ones((5,), jnp.float32)
cj = jax.make_jaxpr(f)(x0)
fn = jaxpr_to_futhark_ast(cj, name="my_f")
print(print_function(fn))


[Var(id=4750816960):float32[5]]
let my_f (Var(id=4750816960):float32[5]: f32[5]) =
  let Var(id=4749379072):float32[5] = sin(Var(id=4750816960):float32[5])
  let Var(id=4749308736):float32[5] = (Var(id=4749379072):float32[5] + Var(id=4750816960):float32[5])
  in Var(id=4749308736):float32[5]


In [57]:
import tempfile, subprocess

# 1) Generate Futhark code
src = print_function(fn)

# 2) Write to a temp file
with open("temp.fut", "w") as f:
    f.write(src)
    print("temp.fut")


temp.fut


In [None]:
# If you installed futhark locally:
subprocess.run(["futhark", "c", "-o", "prog", fut_path], check=True)

# Or, if using Docker:
# subprocess.run([
#   "docker","run","--rm",
#   "-v", f"{os.getcwd()}:/src","-w","/src",
#   "futhark/futhark","c","-o","prog", f"/src/{os.path.basename(fut_path)}"
# ], check=True)
