<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/tools/Flatten_Wrap_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import logit, expit
from collections import namedtuple

In [2]:
class DotDict(dict):
    """A dict whose items can also be read, set and deleted as attributes (d.key).

    Reading or deleting a missing key through attribute syntax raises
    AttributeError (not KeyError). hasattr(), getattr(obj, name, default) and
    copy.deepcopy() probe instances for optional attributes such as
    __deepcopy__ and only tolerate AttributeError, so they now work.
    """

    def __getattr__(self, name):
        # Only invoked after normal attribute lookup has failed.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


# Description of one unknown, as built by range/range_array/comp/comp_array:
#   x           - initial value
#   flatten     - maps a value to its 1-D flat (unconstrained) coordinates
#   unflatten   - maps flat coordinates back to a value
#   size_flat   - number of flat coordinates
#   size_unflat - number of elements in the unflattened value
#   shape       - shape of the unflattened value
#   start, end  - slice [start, end) in the packed vector; None until dtox assigns them
Unk = namedtuple('Unk', 'x flatten unflatten size_flat size_unflat shape start end')

In [138]:
def range(val, lo, hi):  # NOTE: shadows the builtin range() in this notebook
    """Scalar unknown constrained to the open interval (lo, hi).

    Flat coordinate: logit((val - lo) / (hi - lo)), raveled to one element.
    Unflatten maps any real number back into (lo, hi) through expit and
    squeezes the result.

    Args:
        val: initial value, expected strictly between lo and hi.
        lo: lower bound.
        hi: upper bound.

    Returns:
        Unk with size_flat = size_unflat = 1 and shape ().
    """
    span = hi - lo

    def to_flat(v):
        fraction = (v - lo) / span
        return jnp.ravel(logit(fraction))

    def from_flat(z):
        rescaled = expit(z) * span + lo
        return jnp.squeeze(rescaled)

    return Unk(x=val, flatten=to_flat, unflatten=from_flat,
               size_flat=1, size_unflat=1, shape=(),
               start=None, end=None)

def range_array(val, lo, hi):
    """Array of unknowns, each constrained to the open interval (lo, hi).

    lo and hi are promoted to at least 1-D and broadcast against val. The
    flat coordinates are logit((val - lo) / (hi - lo)) raveled to 1-D. The
    unflatten step reshapes back to val's shape and maps through expit.

    Args:
        val: initial values (promoted to at least 1-D).
        lo: lower bound(s).
        hi: upper bound(s).

    Returns:
        Unk whose size_flat and size_unflat both equal val.size.
    """
    val = jnp.atleast_1d(val)
    lo, hi = jnp.atleast_1d(lo), jnp.atleast_1d(hi)
    width = hi - lo
    target_shape = val.shape
    n = val.size

    def to_flat(v):
        fraction = (v - lo) / width
        return jnp.ravel(logit(fraction))

    def from_flat(z):
        unit = expit(z.reshape(target_shape))
        return unit * width + lo

    return Unk(x=val, flatten=to_flat, unflatten=from_flat,
               size_flat=n, size_unflat=n, shape=target_shape,
               start=None, end=None)

def comp(val):
    """Composition unknown: a 1-D vector of positive parts with a fixed total.

    Flat coordinates are the additive log-ratios log(val[i] / val[-1]) for
    every part but the last, so the flat vector has len(val) - 1 entries.
    The unflatten step appends a 0 for the last part, applies a softmax and
    rescales by the original total. The reconstructed parts are therefore
    positive and sum to sum(val).

    Args:
        val: 1-D sequence of strictly positive parts.

    Returns:
        Unk with size_flat = val.size - 1, size_unflat = val.size and
        shape (val.size,).
    """
    val = jnp.asarray(val)
    total = jnp.sum(val)  # restored by unflatten; the log-ratios are scale-free

    def flatten(x):
        x = jnp.asarray(x)
        return jnp.log(x[:-1]) - jnp.log(x[-1])

    def unflatten(x):
        # Append the reference coordinate 0 and shift by the max before exp()
        # so large flat values cannot overflow to inf/inf = NaN.
        z = jnp.pad(jnp.asarray(x), (0, 1))
        e = jnp.exp(z - jnp.max(z))
        return total * e / jnp.sum(e)

    return Unk(x=val, flatten=flatten, unflatten=unflatten,
               size_flat=val.size - 1, size_unflat=val.size, shape=(val.size,),
               start=None, end=None)

def comp_array(val):
    """Row-wise composition unknown: each row of a 2-D array is a composition.

    For every row, the flat coordinates are log(row[j] / row[-1]) for
    j < cols - 1, raveled row-major. The flat vector therefore has
    rows * (cols - 1) entries. Unflatten applies a per-row softmax with a 0
    appended for the last part and rescales each row by its original row
    total. This matches comp() for 1-D input.

    Args:
        val: 2-D array (rows, cols) of strictly positive parts.

    Returns:
        Unk with size_flat = val.size - rows, size_unflat = val.size and
        shape val.shape.
    """
    val = jnp.asarray(val)
    rows, cols = val.shape
    total = jnp.sum(val, axis=-1, keepdims=True)  # (rows, 1), restored by unflatten

    def flatten(x):
        x = jnp.asarray(x)
        # x[:, -1:] keeps a column so each row is divided by its own last part.
        return jnp.ravel(jnp.log(x[:, :-1]) - jnp.log(x[:, -1:]))

    def unflatten(x):
        z = jnp.reshape(jnp.asarray(x), (rows, cols - 1))
        # Append a 0 reference column and shift by each row's max before
        # exp() so large flat values cannot overflow.
        z = jnp.pad(z, ((0, 0), (0, 1)))
        e = jnp.exp(z - jnp.max(z, axis=-1, keepdims=True))
        return total * e / jnp.sum(e, axis=-1, keepdims=True)

    return Unk(x=val, flatten=flatten, unflatten=unflatten,
               size_flat=val.size - rows, size_unflat=val.size, shape=val.shape,
               start=None, end=None)

In [139]:
def dtox(d):
    """Pack the unknowns of a dict into one flat vector.

    Entries of d that are Unk instances are unknowns. Each one is assigned a
    contiguous slice [start, end) of the flat vector, and its initial value
    (unk.x) is written there through unk.flatten. All other entries are
    constants and are passed through unchanged.

    Args:
        d: mapping from field name to an Unk or a constant. Keys become
           namedtuple field names, so they must be valid identifiers.

    Returns:
        (x, d2, xtod):
          x    - numpy float64 vector holding the flattened initial values,
          d2   - xtod(x),
          xtod - function mapping a flat vector to a namedtuple (type name
                 'd2') with one field per key of d, in d's key order.
    """
    d2_named_tuple = namedtuple('d2', d.keys())
    unks = {}
    consts = {}
    pos = 0

    for k, v in d.items():
        if isinstance(v, Unk):
            unks[k] = v._replace(start=pos, end=pos + v.size_flat)
            pos += v.size_flat
        else:
            consts[k] = v

    x = np.zeros(pos)
    for unk in unks.values():
        x[unk.start:unk.end] = unk.flatten(unk.x)

    def xtod(x):
        # Build a fresh dict on each call rather than mutating one captured
        # from the enclosing scope. xtod is meant to be passed to jax.jit,
        # which expects pure functions; writing traced values into outer
        # state would leak tracers.
        vals = dict(consts)
        for k, unk in unks.items():
            vals[k] = unk.unflatten(x[unk.start:unk.end])
        return d2_named_tuple(**vals)

    return x, xtod(x), xtod

In [140]:
# Example problem: one entry of each kind (keyword order fixes the key order).
d = DotDict(
    a=range(5, 0., 8.),                                # scalar bounded to (0, 8)
    b=6.5,                                             # constant, passed through
    c=range_array([1, 2, 5], 0., 10.),                 # each element bounded to (0, 10)
    d=comp([0.1, 0.4, 0.5]),                           # composition with fixed total
    e=comp_array([[0.2, 0.3, 0.5], [0.1, 0.4, 0.5]]),  # one composition per row
)

In [141]:
# Pack the unknowns of d: x is the flat vector of initial values,
# d2 = xtod(x), and xtod maps any flat vector back to a namedtuple.
x, d2, xtod = dtox(d)

In [142]:
jxtod = jax.jit(xtod)  # jit-compiled version of the flat-vector -> namedtuple map

In [143]:
# Round trip through the jitted map: the unknowns come back at their initial
# values and the constant b is passed through.
jxtod(x)

d2(a=DeviceArray(5., dtype=float32), b=DeviceArray(6.5, dtype=float32, weak_type=True), c=DeviceArray([1., 2., 5.], dtype=float32), d=DeviceArray([0.09999999, 0.4       , 0.5       ], dtype=float32), e=DeviceArray([[0.19999999, 0.3       , 0.5       ],
             [0.09999999, 0.4       , 0.5       ]], dtype=float32))