<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/tools/toNamedTuple.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# !wget -N -q https://raw.githubusercontent.com/profteachkids/chetools/main/tools/dotmap.ipynb
# !pip install importnb
# from importnb import Notebook
# with Notebook(): 
#     from dotmap import DotMap
from collections import namedtuple
import jax
import numpy as np
import jax.numpy as jnp
import jax.tree_util as tu
from jax.flatten_util import ravel_pytree
from functools import partial
from copy import deepcopy

In [56]:
class Unk:
    """Marker base class for the special leaf types defined in this notebook.

    Subclasses (Comp, Range, RangeArray) are registered as JAX pytree nodes
    and provide a ``nan()`` method. ``toNamedTuple_recursive(..., NA=True)``
    uses ``isinstance(v, Unk)`` to find the leaves it should copy and NaN out.
    """

class Comp(Unk):
    """A vector of at least two components, stored flattened to 1-D.

    As a JAX pytree node it flattens to the n-1 log-ratios
    ``log(x_i) - log(x_n)``. It unflattens through
    ``exp(q) / (1 + sum(exp(q)))``, with the last entry set to one minus the
    rest, so the rebuilt values are positive and sum to 1.

    Note: ``unflatten`` returns a plain jnp array, not a ``Comp`` instance.
    Rebuilding a tree therefore yields bare values.
    """

    def __init__(self, x):
        flat = jnp.asarray(x).reshape(-1)
        self.x = flat
        if flat.size < 2:
            raise ValueError('At least 2 components required')

    def __repr__(self):
        return f'{self.x}'

    def nan(self):
        """Overwrite every component with NaN in place (same shape/dtype)."""
        self.x = jnp.full_like(self.x, jnp.nan)

    @staticmethod
    def flatten(c):
        """Return (log-ratio children, aux=None) for pytree flattening."""
        last = c.x[-1]
        # 1 + (1 - last) / last == 1 / last, so this equals log(x_i) - log(x_n).
        ratios = jnp.log(c.x[:-1]) + jnp.log(1. + (1. - last) / last)
        return ratios, None

    @staticmethod
    def unflatten(aux, q):
        """Map log-ratios back to a composition; returns a plain jnp array."""
        # The children were returned as one array, but JAX may hand them back
        # as a tuple of scalars, which cannot be squeezed directly.
        q = jnp.squeeze(jnp.asarray(q))
        expq = jnp.exp(q)
        head = expq / (1 + jnp.sum(expq))
        tail = 1. - jnp.sum(head)
        return jnp.concatenate((jnp.atleast_1d(head), jnp.atleast_1d(tail)))


jax.tree_util.register_pytree_node(Comp, Comp.flatten, Comp.unflatten)

class Range(Unk):
    """Scalar ``x`` constrained to the open interval (lo, hi).

    Registered as a JAX pytree node. It flattens to the single unconstrained
    leaf ``logit((x - lo) / (hi - lo))`` with ``(lo, diff)`` as auxiliary
    data. It unflattens as ``sigmoid(leaf) * diff + lo``. Note that
    ``unflatten`` returns that plain value, not a ``Range`` instance.

    The bounds check uses a plain ``if``, so ``x``, ``lo`` and ``hi`` are
    expected to be scalars (``RangeArray`` handles arrays).

    Raises:
        ValueError: unless ``lo < x < hi`` holds strictly.
    """

    def __init__(self, x, lo, hi):
        self.x = x
        self.lo = lo
        self.diff = hi - lo
        # Strict inequalities: x == lo or x == hi would flatten to a logit of
        # -inf / +inf, which the error message already forbids.
        if self.diff <= 0. or self.x <= lo or self.x >= hi:
            raise ValueError('Hi > x > Lo is required')

    def __repr__(self):
        return f'{self.x}, lo={self.lo}, diff={self.diff}'

    def nan(self):
        """Replace the value with NaN in place (bounds are kept)."""
        self.x = jnp.nan

    @staticmethod
    def flatten(v):
        """Return ((logit,), (lo, diff)) for pytree flattening."""
        p = (v.x - v.lo) / v.diff
        return (jnp.log(p) - jnp.log(1. - p),), (v.lo, v.diff)

    @staticmethod
    def unflatten(aux, f):
        """Map the logit leaf back into (lo, hi); returns a plain value."""
        return jax.nn.sigmoid(f[0]) * aux[1] + aux[0]


jax.tree_util.register_pytree_node(Range, Range.flatten, Range.unflatten)

class RangeArray(Unk):
    """Array ``x`` constrained elementwise to the open interval (lo, hi).

    ``lo`` and ``hi`` are combined with ``x`` elementwise, so they presumably
    need to broadcast against ``x`` (TODO confirm intended shapes).

    Registered as a JAX pytree node. It flattens to one array leaf,
    ``logit((x - lo) / (hi - lo))``, with ``(lo, diff)`` as auxiliary data.
    It unflattens as ``sigmoid(leaf) * diff + lo``. ``unflatten`` returns that
    plain array, not a ``RangeArray`` instance.

    Raises:
        ValueError: unless ``lo < x < hi`` holds strictly for every element.
    """

    def __init__(self, x, lo, hi):
        self.x = x
        self.lo = lo
        self.diff = hi - lo
        # Strict inequalities: an element equal to lo or hi would flatten to
        # an infinite logit, which the error message already forbids.
        if (jnp.any(self.diff <= 0.) or jnp.any(self.x <= lo)
                or jnp.any(self.x >= hi)):
            raise ValueError('Hi > x > Lo is required')

    def __repr__(self):
        return f'{self.x}, lo={self.lo}, diff={self.diff}'

    def nan(self):
        """Overwrite every element with NaN in place (same shape/dtype)."""
        self.x = jnp.full_like(self.x, jnp.nan)

    @staticmethod
    def flatten(v):
        """Return ((logit_array,), (lo, diff)) for pytree flattening."""
        p = (v.x - v.lo) / v.diff
        return (jnp.log(p) - jnp.log(1. - p),), (v.lo, v.diff)

    @staticmethod
    def unflatten(aux, f):
        """Map the logit leaf back into (lo, hi); returns a plain array."""
        # Index the single child instead of squeezing the stacked children:
        # squeeze also removed any size-1 dimensions belonging to x itself.
        return jax.nn.sigmoid(jnp.asarray(f[0])) * aux[1] + aux[0]


jax.tree_util.register_pytree_node(RangeArray, RangeArray.flatten, RangeArray.unflatten)

In [121]:
def toNamedTuple_recursive(d, e=None, NA=False):
    """Recursively convert nested dicts/lists into namedtuples/tuples.

    Every dict becomes a namedtuple (type name '_', fields = the dict keys in
    insertion order). Every list becomes a tuple. All other values are kept as
    leaves. The result is an immutable structure that JAX treats as a pytree.

    Args:
        d: A dict, a list, or any leaf value.
        e: Optional dict accumulator used only when ``d`` is a dict. Its
            existing entries are kept and ``d``'s converted entries are
            written into it. Ignored for lists and leaves; normally None.
        NA: If true, each ``Unk`` leaf is replaced by a deep copy on which
            ``nan()`` has been called. The originals are not modified.

    Returns:
        The converted structure. A non-container ``d`` is returned as a leaf
        (NaN-copied if it is an ``Unk`` and ``NA`` is set). Previously this
        raised UnboundLocalError.

    Raises:
        ValueError: from ``namedtuple`` if a dict key is not a valid field
            name (e.g. a keyword, or a name starting with an underscore).
    """
    if isinstance(d, dict):
        fields = {} if e is None else e
        for k, v in d.items():
            fields[k] = toNamedTuple_recursive(v, NA=NA)
        return namedtuple('_', fields.keys())(**fields)

    if isinstance(d, list):
        return tuple(toNamedTuple_recursive(v, NA=NA) for v in d)

    # Leaf: only Unk instances are touched, and only on a private copy so the
    # caller's objects keep their values.
    if NA and isinstance(d, Unk):
        d = deepcopy(d)
        d.nan()
    return d

In [122]:
# Demo: nested dicts/lists mixing Unk leaves (Comp, Range) with plain values.
# The same dict objects (d, e, f) are reused at several positions.
d = {'b': {'m': Comp([1/3, 1/3, 1/3]), 'n': jnp.array([1, 2, 3])}}
e = {'c': Range(50, 0, 100), 'g': {'z': 2}}
f = {'x': 3., 'y': [d, d]}

z = [{'a': d, 'k': [d, [e, f]]}, [d, e, f]]
dt = toNamedTuple_recursive(z, NA=False)
dt

(_(a=_(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32))), k=(_(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32))), (_(c=50, lo=0, diff=100, g=_(z=2)), _(x=3.0, y=(_(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32))), _(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32)))))))),
 (_(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32))),
  _(c=50, lo=0, diff=100, g=_(z=2)),
  _(x=3.0, y=(_(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32))), _(b=_(m=[0.33333334 0.33333334 0.33333334], n=DeviceArray([1, 2, 3], dtype=int32)))))))

In [123]:
# Flatten every leaf of dt into one 1-D vector; `unflat` maps a vector back to the tree.
val,unflat = ravel_pytree(dt)

In [124]:
# Flat vector: Comp/Range leaves contribute their flattened (log-ratio / logit) coordinates.
val

DeviceArray([0., 0., 1., 2., 3., 0., 0., 1., 2., 3., 0., 2., 3., 0., 0.,
             1., 2., 3., 0., 0., 1., 2., 3., 0., 0., 1., 2., 3., 0., 2.,
             3., 0., 0., 1., 2., 3., 0., 0., 1., 2., 3.], dtype=float32)

In [125]:
# Rebuild the tree from the flat vector.
zz=unflat(val)

In [126]:
# Comp and Range leaves come back as plain arrays: their unflatten returns values, not instances.
zz

(_(a=_(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32))), k=(_(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32))), (_(c=DeviceArray(50., dtype=float32), g=_(z=DeviceArray(2, dtype=int32))), _(x=DeviceArray(3., dtype=float32), y=(_(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32))), _(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32)))))))),
 (_(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32))),
  _(c=DeviceArray(50., dtype=float32), g=_(z=DeviceArray(2, dtype=int32))),
  _(x=DeviceArray(3., dtype=float32), y=(_(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=float32), n=DeviceArray([1, 2, 3], dtype=int32))), _(b=_(m=DeviceArray([0.33333334, 0.33333334, 0.3333333 ], dtype=

In [101]:

def dictHandler(build, k):
    """Return a setter that stores a converted value under ``build[k]``.

    toNamedTuple_stack uses this to replace a fully processed nested
    container inside its parent dict:
    - a dict is stored as a namedtuple (type name '_', fields = its keys);
    - a list is stored as a tuple;
    - any other value is stored unchanged, matching listHandler (previously
      it was silently dropped).
    """
    def dictSetter(v):
        if isinstance(v, dict):
            build[k] = namedtuple('_', v.keys())(**v)
        elif isinstance(v, list):
            build[k] = tuple(v)
        else:
            build[k] = v
    return dictSetter

def listHandler(l, i):
    """Return a setter that writes a converted value into ``l[i]``.

    Dicts are stored as namedtuples (type name '_', fields = keys), lists as
    tuples, and anything else unchanged. The setter itself returns None.
    """
    def listSetter(v):
        if isinstance(v, dict):
            converted = namedtuple('_', v.keys())(**v)
        elif isinstance(v, list):
            converted = tuple(v)
        else:
            converted = v
        l[i] = converted
    return listSetter

def toNamedTuple_stack(orig):
    """Convert nested dicts/lists to namedtuples/tuples using an explicit stack.

    Works on a deep copy of ``orig``, so the input is never modified. Every
    dict becomes a namedtuple (type name '_', fields = keys) and every list a
    tuple. Other values are kept as leaves, and a non-container ``orig`` is
    returned as a copy (previously this raised UnboundLocalError). Unlike
    toNamedTuple_recursive, there is no ``NA`` option.

    Algorithm: a node stays on the stack while any of its children is still
    a dict/list. Those children are pushed above it. Each one, once
    converted, is written back into the node by its handler, replacing the
    container. When no container children remain, the node itself is
    converted and written into its own parent via its setter.
    """
    d = deepcopy(orig)
    if not isinstance(d, (dict, list)):
        return d

    def rootSetter(v):
        # The root has no parent container, so rebind the result directly.
        nonlocal d
        if isinstance(v, dict):
            d = namedtuple('_', v.keys())(**v)
        else:
            d = tuple(v)

    stack = [(d, rootSetter)]
    while stack:
        source, setter = stack[-1]
        # isinstance (not an exact type test) so dict/list subclasses are
        # converted too, consistent with the handlers and the recursive version.
        if isinstance(source, dict):
            pending = [(v, dictHandler(source, k))
                       for k, v in source.items() if isinstance(v, (dict, list))]
        else:
            pending = [(v, listHandler(source, i))
                       for i, v in enumerate(source) if isinstance(v, (dict, list))]

        if pending:
            stack.extend(pending)
        else:
            # All children are leaves or already converted: convert this node.
            setter(source)
            stack.pop()
    return d