# Julia

This example demonstrates how to use BlackJAX nested sampling with likelihood and prior functions implemented in Julia. Because of a threading incompatibility between Julia and JAX's `pure_callback`, Julia runs in a separate process. The two processes communicate through a simple line-delimited JSON RPC protocol over that process's stdin and stdout.

## Prerequisites

Install the required Python packages:
```bash
# BlackJAX from the handley-lab GitHub repository (the code below uses blackjax.nss and blackjax.ns.utils)
pip install git+https://github.com/handley-lab/blackjax
# numpy converts arrays for the Julia RPC calls; tqdm draws the sampling progress bar
pip install numpy tqdm
```

Install Julia, then add the Julia packages the server needs (run once in the Julia REPL):
```julia
using Pkg
# One-time setup: install Distributions, JSON and Base64, which julia_server.jl loads with `using`.
Pkg.add(["Distributions", "JSON", "Base64"])
```

## Setup Instructions

### 1. Create the Julia server

First, create a file `julia_server.jl` with the RPC server and your likelihood/prior functions:

```julia
# julia_server.jl
using LinearAlgebra, Distributions, JSON, Base64

"""
    loglikelihood(theta)

Return a vector with the log-likelihood of every row of `theta` (one parameter
vector per row) under a 5-dimensional Gaussian with mean `ones(5)` and
covariance `0.01 * I`.
"""
function loglikelihood(theta)
    points = Matrix{Float64}(theta)
    likelihood = MvNormal(ones(5), 0.01 * I(5))
    return [logpdf(likelihood, row) for row in eachrow(points)]
end

"""
    logprior(theta)

Return a vector with the log-prior density of every row of `theta` (one
parameter vector per row) under a standard 5-dimensional Gaussian
(zero mean, identity covariance).
"""
function logprior(theta)
    points = Matrix{Float64}(theta)
    prior = MvNormal(zeros(5), I(5))
    return [logpdf(prior, row) for row in eachrow(points)]
end

# Line-delimited JSON RPC server: one request per stdin line, one reply per stdout line.

# Decode the "data"/"shape" params into an nrows×ncols Float64 matrix. The Python
# side sends the raw bytes of a C-order (row-major) array, so the bytes are read
# column-major as ncols×nrows and then transposed.
function _decode_matrix(params)
    shape = params["shape"]
    length(shape) == 2 ||
        throw(ArgumentError("expected a 2-D \"shape\" [nrows, ncols], got $shape"))
    nrows, ncols = shape
    data = reinterpret(Float64, base64decode(params["data"]))
    return permutedims(reshape(data, ncols, nrows))
end

# Route one request to its handler and return the handler's result.
function _dispatch(method, params)
    if method == "loglikelihood"
        return loglikelihood(_decode_matrix(params))
    elseif method == "logprior"
        return logprior(_decode_matrix(params))
    elseif method == "ping"
        return "pong"
    else
        error("Unknown method: $method")
    end
end

# JSON has no literal for ±Inf or NaN (e.g. a log-prior of -Inf outside a bounded
# support). Send those as the strings "-Inf", "Inf", "NaN", which the Python side's
# float64 conversion parses back, instead of relying on how JSON.jl serialises them.
_jsonable(x::AbstractFloat) = isfinite(x) ? x : string(x)
_jsonable(x::AbstractArray) = map(_jsonable, x)
_jsonable(x) = x

"""
    serve()

Answer RPC requests read from `stdin` until end-of-file.

Each request line is a JSON object with `"id"`, `"method"` and `"params"`. Methods:
`"loglikelihood"` and `"logprior"` (params `"data"`: base64 of row-major Float64
bytes, `"shape"`: `[nrows, ncols]`; result: one value per row) and `"ping"`
(result `"pong"`). Every non-blank line gets exactly one reply line on `stdout`:
`{"id": ..., "result": ...}` on success or `{"id": ..., "error": "..."}` on failure.
"""
function serve()
    # eachline stops at end-of-file; readline would instead return "" at EOF
    # (it does not throw EOFError), so a readline loop never terminates.
    for line in eachline(stdin)
        isempty(strip(line)) && continue
        # Declared outside `try` so the `catch` branch can still read the id:
        # a variable first assigned inside a `try` block is local to that block.
        request = nothing
        response = try
            request = JSON.parse(line)
            params = get(request, "params", nothing)
            result = _dispatch(request["method"], params)
            Dict("id" => request["id"], "result" => _jsonable(result))
        catch e
            # Let Ctrl-C stop the server instead of turning it into an error reply.
            e isa InterruptException && rethrow()
            id = request isa AbstractDict ? get(request, "id", nothing) : nothing
            Dict("id" => id, "error" => string(e))
        end
        println(JSON.json(response))
        flush(stdout)
    end
    return nothing
end

serve()
```

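Save this as `julia_server.jl` in the directory from which you run the Python code below. The Python code launches it by relative path (`julia julia_server.jl`), so it must be in that directory.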

### 2. Run nested sampling with Julia functions

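**Note:** Because of the threading incompatibility, Julia runs in a separate process. Every likelihood and prior evaluation therefore sends a base64-encoded batch of points as JSON through a pipe and waits for the reply. This is slower than calling the functions in-process, but it still lets you use Julia's scientific computing ecosystem.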

In [None]:
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import subprocess
import json
import base64
import atexit

rng_key = jax.random.PRNGKey(0)

# Start the Julia server as a child process. Requests are written to its stdin
# and replies are read line by line from its stdout (see julia_rpc below).
# stderr is deliberately NOT piped: nothing here reads it, so Julia's
# precompilation output or error backtraces could fill the OS pipe buffer and
# block the server mid-reply. Inheriting stderr avoids that deadlock and keeps
# Julia's errors visible.
julia_process = subprocess.Popen(
    ['julia', 'julia_server.jl'],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
    bufsize=1,
)

# Make sure this server does not outlive the Python process. Registering the
# bound method (rather than a lambda that looks up the global at exit) means a
# re-run of this cell still cleans up every process it started.
atexit.register(julia_process.terminate)

# Counter for RPC request IDs; a one-element list so julia_rpc can mutate it
# without a `global` statement.
request_id = [0]

def julia_rpc(method, theta):
    """Evaluate ``method`` on a batch of parameter vectors in the Julia server.

    The server decodes the request as a 2-D row-major float64 matrix, so any
    leading batch dimensions of ``theta`` are flattened before sending and
    restored on the result. This also covers a single 1-D point and the extra
    leading axes produced by nested ``vmap``.

    Args:
        method: Julia-side method name, e.g. ``"loglikelihood"`` or ``"logprior"``.
        theta: Array-like of shape ``(..., d)``; the last axis is one parameter vector.

    Returns:
        float64 ``np.ndarray`` of shape ``theta.shape[:-1]``, one value per vector.

    Raises:
        ValueError: If ``theta`` is 0-dimensional.
        RuntimeError: If the Julia process closed its output, replied to a
            different request id, or reported an error.
    """
    theta_np = np.asarray(theta, dtype=np.float64)
    if theta_np.ndim == 0:
        raise ValueError("theta must have at least one dimension (the parameter axis)")

    batch_shape = theta_np.shape[:-1]
    matrix = theta_np.reshape(int(np.prod(batch_shape)), theta_np.shape[-1])

    # tobytes() emits C (row-major) order, which is what the server expects.
    theta_b64 = base64.b64encode(matrix.tobytes()).decode('ascii')

    request_id[0] += 1
    current_id = request_id[0]
    request = {
        "id": current_id,
        "method": method,
        "params": {
            "data": theta_b64,
            "shape": list(matrix.shape),
        },
    }

    julia_process.stdin.write(json.dumps(request) + '\n')
    julia_process.stdin.flush()

    response_line = julia_process.stdout.readline()
    if not response_line:
        # readline() returns '' only once the server's stdout is closed.
        raise RuntimeError(
            f"Julia server exited (return code {julia_process.poll()}) "
            f"before replying to {method!r}"
        )
    response = json.loads(response_line)

    if "error" in response:
        raise RuntimeError(f"Julia error: {response['error']}")
    if response.get("id") != current_id:
        raise RuntimeError(
            f"Julia reply id {response.get('id')!r} does not match request id {current_id}"
        )

    # The float64 conversion also accepts numeric strings such as "-Inf".
    return np.asarray(response["result"], dtype=np.float64).reshape(batch_shape)

def wrap_fn(method, vmap_method='legacy_vectorized'):
    def numpy_wrapper(theta):
        return julia_rpc(method, theta)
    
    def jax_wrapper(x):
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
    
    return jax_wrapper

# JAX-callable wrappers around the Julia server's likelihood and prior.
loglikelihood_fn = wrap_fn("loglikelihood")
logprior_fn = wrap_fn("logprior")

algo = blackjax.nss(
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    num_delete=50,
    num_inner_steps=20,
)

# sampling_key is not used below; the three-way split is kept so the keys match.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)

try:
    # 1000 initial live points, each a 5-dimensional standard-normal draw.
    live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
    step = jax.jit(algo.step)

    dead_points = []

    with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
        # Iterate until logZ_live - logZ < -3, i.e. (presumably) until the evidence
        # still held by the live points is below e^-3 of the accumulated estimate.
        while (not live.logZ_live - live.logZ < -3):
            rng_key, subkey = jax.random.split(rng_key)
            live, dead = step(subkey, live)
            dead_points.append(dead)
            pbar.update(len(dead.particles))

    ns_run = finalise(live, dead_points)
finally:
    # Shut the Julia server down even if sampling raised, and reap the child
    # process so it does not linger; escalate to kill if it ignores SIGTERM.
    julia_process.terminate()
    try:
        julia_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        julia_process.kill()
        julia_process.wait()