In [20]:
import jax
import jax.numpy as jnp
import numpy as np

from jax.experimental import loops
from jax.ops import index_update, index_add

In [65]:
def pava(X, Y):
    """Sort paired samples by ``X`` and merge samples with duplicate ``X``.

    NOTE(review): despite the name, this does not (yet) run the
    pool-adjacent-violators step itself; it currently sorts the data by ``X``
    and collapses duplicate ``X`` values via ``make_unique``.

    Parameters
    ----------
    X : array_like, shape (n_samples,)
        Independent variable.
    Y : array_like, shape (n_samples,)
        Observations paired element-wise with ``X``.

    Returns
    -------
    tuple
        ``(X_unique, Y_unique)`` exactly as returned by ``make_unique`` on the
        sorted data.

    Raises
    ------
    ValueError
        If ``X`` or ``Y`` is not 1-D, or if their lengths differ.
    """
    X = jnp.asarray(X)
    Y = jnp.asarray(Y)
    if X.ndim != 1 or Y.ndim != 1:
        raise ValueError(
            f"X and Y must be 1-D, got shapes {X.shape} and {Y.shape}"
        )
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same length, got {X.shape[0]} and {Y.shape[0]}"
        )

    # JAX clamps out-of-range gather indices instead of raising, so without the
    # checks above a shorter Y would be silently misaligned by Y[order].
    order = jnp.argsort(X)
    return make_unique(X[order], Y[order])
    

def make_unique(X, Y, eps=1e-8):
    """Collapse runs of (near-)equal ``X`` values, averaging their ``Y`` values.

    JAX port of scikit-learn's ``_make_unique``, specialised to the case where
    every sample weight is 1:
    https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09bcc2eaeba98f7e737aac2ac782f0e5f1/sklearn/_isotonic.pyx#L66

    Parameters
    ----------
    X : array_like, shape (n_samples,)
        Independent variable. Expected to be sorted in ascending order
        (``pava`` sorts before calling this); only consecutive samples are
        ever merged.
    Y : array_like, shape (n_samples,)
        Observations paired element-wise with ``X``.
    eps : float, default 1e-8
        Sample ``j`` starts a new group when ``X[j]`` exceeds the *first* ``X``
        value of the current group by at least ``eps``; otherwise it is merged
        into that group.

    Returns
    -------
    X_out, Y_out : jax.Array, shape (n_groups,)
        For each group, its first ``X`` value and the mean of its ``Y`` values,
        in JAX's default float dtype. Both are empty when the input is empty.
    """
    X = jnp.asarray(X)
    Y = jnp.asarray(Y)
    # Default float dtype: float32 unless jax_enable_x64 is set.
    float_dtype = jnp.result_type(float)
    n_samples = X.shape[0]
    if n_samples == 0:
        # The loop state below is seeded from X[0]; there is nothing to group.
        empty = jnp.zeros(0, dtype=float_dtype)
        return empty, empty

    # Accumulate Y in floating point (as sklearn does) so integer inputs cannot
    # overflow the running sum.
    Y_f = Y.astype(float_dtype)

    def _close_group(args):
        """Write the current group into slot ``i`` and open a new group at ``j``."""
        (i, w_cur, x_cur, y_cur, X_out, Y_out), j = args
        X_out = X_out.at[i].set(x_cur.astype(float_dtype))
        Y_out = Y_out.at[i].set(y_cur / w_cur)
        return (i + 1, jnp.ones((), float_dtype), X[j], Y_f[j], X_out, Y_out)

    def _extend_group(args):
        """Merge sample ``j`` into the current group."""
        (i, w_cur, x_cur, y_cur, X_out, Y_out), j = args
        return (i, w_cur + 1, x_cur, y_cur + Y_f[j], X_out, Y_out)

    def _step(j, carry):
        x_cur = carry[2]
        # Compare in X's own dtype so distinct integers are never merged by a
        # lossy float cast.
        return jax.lax.cond(
            X[j] - x_cur >= eps, _close_group, _extend_group, (carry, j)
        )

    init = (
        jnp.zeros((), jnp.int32),  # i: slot of the group being accumulated
        jnp.zeros((), float_dtype),  # w_cur: number of samples in the group
        X[0],  # x_cur: first X value of the group
        jnp.zeros((), float_dtype),  # y_cur: running sum of the group's Y
        # There are at most n_samples groups, so these buffers always suffice.
        jnp.zeros(n_samples, float_dtype),
        jnp.zeros(n_samples, float_dtype),
    )
    i, w_cur, x_cur, y_cur, X_out, Y_out = jax.lax.fori_loop(
        0, n_samples, _step, init
    )

    # The final group is never closed inside the loop; flush it here.
    X_out = X_out.at[i].set(x_cur.astype(float_dtype))
    Y_out = Y_out.at[i].set(y_cur / w_cur)

    # Slicing needs a concrete length, so this function cannot be jit-compiled
    # as a whole (the original, which called np.unique, could not either).
    n_groups = int(i) + 1
    return X_out[:n_groups], Y_out[:n_groups]
    

    

In [66]:
x = jnp.array([25, 25, 25, 30, 30, 30])
y = jnp.array([10, 12, 1, 23, 19, 7])

x_unique, y_unique = pava(x, y)

# Duplicate x values are averaged: (10 + 12 + 1) / 3 and (23 + 19 + 7) / 3.
np.testing.assert_allclose(np.asarray(x_unique), [25.0, 30.0])
np.testing.assert_allclose(np.asarray(y_unique), [23 / 3, 49 / 3], rtol=1e-6)

x_unique, y_unique

(DeviceArray([25., 30.], dtype=float32),
 DeviceArray([ 7.6666665, 16.333334 ], dtype=float32))

In [62]:
def f1(args):
    """True branch for the ``lax.cond`` check below: return the third operand."""
    _, _, c = args
    return c


def f2(args):
    """False branch for the ``lax.cond`` check below: return the first operand."""
    a, _, _ = args
    return a


# Sanity check of `jax.lax.cond` semantics: since a <= b, the false branch (f2)
# must run. `c` is chosen distinct from `a` so the result reveals which branch
# actually executed (with a == c both branches would return the same value).
a = 1
b = 3
c = 2
result = jax.lax.cond(
    a > b,
    f1,
    f2,
    (a, b, c),
)
assert int(result) == a, "expected the false branch (f2) to run"
result

DeviceArray(1, dtype=int32)