<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/FlattenWrap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from collections import namedtuple
import numpy as np
from scipy.special import expit,logit
from functools import partial

In [1]:
class DotDict(dict):
    """A ``dict`` whose items can also be accessed as attributes.

    ``d.key`` reads ``d['key']``, ``d.key = v`` writes ``d['key'] = v`` and
    ``del d.key`` deletes ``d['key']``.  A missing key raises ``AttributeError``
    on attribute access (not ``KeyError``), as the Python data model requires.
    This keeps ``hasattr``, ``getattr(obj, name, default)``, ``copy.deepcopy``
    and ``pickle`` working on instances.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only invoked after normal attribute lookup fails, so genuine dict
        # methods (``keys``, ``items``, ...) still win over same-named keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

In [3]:
class Unk:
    """Marker base class for optimisation unknowns.

    ``dtox`` treats any instance of a subclass as an unknown. It reads the
    subclass's ``size`` attribute and calls ``flatten()`` / ``unflatten(xx)``
    to move between the user-facing value and a slice of one flat vector.
    """

class Range(Unk):
    """Scalar unknown bounded to the open interval ``(lo, hi)``.

    ``flatten`` maps the value to an unconstrained real through a scaled
    logit. ``unflatten`` applies the inverse scaled logistic map, so any real
    input lands back inside ``(lo, hi)``.
    """

    def __init__(self, value, lo, hi):
        self.lo, self.hi = lo, hi
        self.diff = hi - lo
        self.x = value
        # A scalar occupies exactly one slot of the flat vector.
        self.shape = (1,)
        self.size = 1

    def flatten(self):
        """Return ``logit((x - lo) / (hi - lo))`` as a 1-D array."""
        fraction = (self.x - self.lo) / self.diff
        return np.ravel(logit(fraction))

    def unflatten(self, xx):
        """Reshape ``xx`` to ``self.shape`` and map it back into ``(lo, hi)``."""
        squashed = expit(xx.reshape(self.shape))
        return self.lo + self.diff * squashed


class RangeArray(Range):
    """Array of unknowns, each bounded to its own interval ``(lo, hi)``.

    ``lo`` and ``hi`` are promoted to at least 1-D arrays. They must broadcast
    against ``value`` in the elementwise arithmetic of the inherited
    ``flatten``/``unflatten``. The original shape of ``value`` is restored
    on ``unflatten``.
    """

    def __init__(self, value, lo, hi):
        lower = np.atleast_1d(lo)
        upper = np.atleast_1d(hi)
        values = np.atleast_1d(value)
        self.lo, self.hi = lower, upper
        self.diff = upper - lower
        self.x = values
        self.shape, self.size = values.shape, values.size

class Comp(Unk):
    """Composition vector (fractions summing to 1) as an unconstrained unknown.

    The normalised composition ``x`` of length n is mapped to n-1 free reals
    by the additive log-ratio transform ``log(x_i / x_n)``. ``unflatten`` is
    the inverse: a softmax over the free reals with the last logit fixed at 0.
    """

    def __init__(self, value):
        x = np.asarray(value).reshape(-1)
        self.x = x / np.sum(x)
        # The last fraction is implied by the others, so it is not free.
        self.size = self.x.size - 1

    def __repr__(self):
        return f'{self.x}'

    def flatten(self):
        """Return the n-1 log-ratios ``log(x[:-1]) - log(x[-1])``."""
        return np.log(self.x[:-1]) - np.log(self.x[-1])

    def unflatten(self, xx):
        """Map n-1 free reals back to an n-component composition.

        Before exponentiating, the code subtracts the largest logit, counting
        the implicit 0 of the last component. That stops large inputs from
        overflowing to ``inf / inf = nan``. The last fraction is computed
        directly rather than as ``1 - sum``, so it cannot go negative through
        round-off.
        """
        xx = np.asarray(xx, dtype=float)
        shift = np.max(xx, initial=0.0)  # initial=0.0 also handles empty xx
        expx = np.exp(xx - shift)
        ref = np.exp(-shift)  # weight of the last component, exp(0 - shift)
        denom = ref + np.sum(expx)
        return np.concatenate((expx / denom, np.atleast_1d(ref / denom)))

class CompArray(Unk):
    """2-D stack of compositions (one per row) as an unconstrained unknown.

    Each row of ``value`` is normalised to sum to 1. It is then mapped to
    ``ncols - 1`` free reals with the row-wise log-ratio
    ``log(x[r, i] / x[r, -1])``. The flat vector holds
    ``nrows * (ncols - 1)`` reals in row-major order.
    """

    def __init__(self, value):
        x = np.asarray(value)
        self.x = x / np.sum(x, axis=1, keepdims=True)
        self.nrows, self.ncols = self.x.shape
        # One fraction per row is implied by the others.
        self.size = self.x.size - self.nrows

    def __repr__(self):
        return f'{self.x}'

    def flatten(self):
        """Return the raveled row-wise log-ratios ``log(x[:, :-1] / x[:, -1:])``."""
        return np.ravel(np.log(self.x[:, :-1]) - np.log(self.x[:, -1:]))

    def unflatten(self, xx):
        """Map the flat free reals back to an ``(nrows, ncols)`` array of compositions.

        Each row is a softmax with its last logit fixed at 0. The row maximum,
        counting that implicit 0, is subtracted first so large inputs cannot
        overflow to ``nan``.
        """
        xx = np.asarray(xx, dtype=float).reshape(self.nrows, self.ncols - 1)
        shift = np.max(xx, axis=1, keepdims=True, initial=0.0)
        expx = np.exp(xx - shift)
        ref = np.exp(-shift)  # per-row weight of the last component
        denom = ref + np.sum(expx, axis=1, keepdims=True)
        return np.hstack((expx / denom, ref / denom))



In [15]:
Unk_Tuple = namedtuple('Unk_Tuple', ['keys', 'start', 'end', 'unk'])

def dtox(d):
    """Split ``d`` into a flat vector of unknowns plus a dict of the rest.

    Each ``Unk`` value in ``d`` is given a contiguous slice of length
    ``v.size`` in the flat vector ``x``, initialised with ``v.flatten()``.
    The same ``Unk`` object stored under several keys (detected by ``id``)
    gets one slice, and every one of those keys is filled from it. Non-``Unk``
    values are copied into ``d2`` unchanged.

    Returns ``(wrap, x, d2, xtod, xtodunk)``:

    * ``wrap(f)``: a callable ``g(x)`` that unflattens ``x`` into ``d2``
      (in place) and returns ``f(d2)``. If that result is exactly a ``tuple``
      or ``list``, only its first element is returned.
    * ``x``: the initial flat vector.
    * ``d2``: a ``DotDict`` of the constants plus the unflattened unknowns.
    * ``xtod(x, d2)``: writes the unflattened unknowns into ``d2`` and
      returns that same dict.
    * ``xtodunk(x)``: a new plain dict holding only the unflattened unknowns.
    """
    d2 = DotDict()
    slots = {}  # id(Unk object) -> Unk_Tuple
    offset = 0

    for key, val in d.items():
        ident = id(val)
        if ident in slots:
            # Alias of an already-allocated unknown: share its slice.
            slots[ident].keys.append(key)
        elif isinstance(val, Unk):
            slots[ident] = Unk_Tuple([key], offset, offset + val.size, val)
            offset += val.size
        else:
            d2[key] = val

    x = np.zeros(offset)
    for slot in slots.values():
        x[slot.start:slot.end] = slot.unk.flatten()

    def iter_unflattened(x):
        # Lazily yield (key, value); each unknown is unflattened exactly once
        # and the resulting object is shared by all of its keys.
        for slot in slots.values():
            value = slot.unk.unflatten(x[slot.start:slot.end])
            for key in slot.keys:
                yield key, value

    def xtod(x, d2):
        for key, value in iter_unflattened(x):
            d2[key] = value
        return d2

    def xtodunk(x):
        return dict(iter_unflattened(x))

    def wrap(f):
        def wrapped(x, d2):
            d2 = xtod(x, d2)
            res = f(d2)
            # Exact type test: tuple/list subclasses (e.g. namedtuples) are returned whole.
            return res[0] if type(res) in (tuple, list) else res

        return partial(wrapped, d2=d2)

    d2 = xtod(x, d2)
    return wrap, np.asarray(x), d2, xtod, xtodunk

In [21]:
# Example problem: one unknown of each kind, a plain constant, and an alias.
d = DotDict(
    range1=Range(10., 0., 50.),
    range_array=RangeArray([5., 12., 24.], 0., [10., 20., 30.]),
    comp1=Comp([0.1, 0.2, 0.3, 0.4]),
    comp_array=CompArray(np.tile(np.array([0.1, 0.2, 0.7]), (4, 1))),
    const=np.array([3., 5.]),
)
# 'repeat' is the very same RangeArray object, so dtox allocates it only once.
d.repeat = d.range_array

In [22]:
# Flatten the unknowns of d into x and get the helpers for mapping back.
(wrap, x, d2, xtod, xtodunk) = dtox(d)

In [23]:
d2  # display constants plus unknowns unflattened from the initial x

{'comp1': array([0.1, 0.2, 0.3, 0.4]), 'comp_array': array([[0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7]]), 'const': array([3., 5.]), 'range1': array([10.]), 'range_array': array([ 5., 12., 24.]), 'repeat': array([ 5., 12., 24.])}