<a href="https://colab.research.google.com/github/profteachkids/chetools/blob/main/FlattenWrap2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from collections import namedtuple
import numpy as np
from scipy.special import expit,logit
from functools import partial

In [2]:
class DotDict(dict):
    __getattr__= dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [18]:
class Unk():
    pass

class Range(Unk):
    def __init__(self,value,lo,hi):
        self.lo=lo
        self.hi=hi
        self.diff = hi-lo
        self.x=value
        self.shape=(1,)
        self.size = 1
        self.unflatten_size = 1

    def flatten(self):
        return np.ravel(logit((self.x-self.lo)/self.diff))

    def unflatten(self,xx):
        return expit(xx.reshape(self.shape))*self.diff +self.lo


class RangeArray(Unk):
    def __init__(self,value,lo,hi):
        self.lo=np.atleast_1d(lo)
        self.hi=np.atleast_1d(hi)
        self.diff = self.hi-self.lo
        value = np.atleast_1d(value)
        self.x = value
        self.shape = value.shape
        self.size = value.size
        self.unflatten_size = value.size

    def flatten(self):
        return np.ravel(logit((self.x-self.lo)/self.diff))

    def unflatten(self,xx):
        return expit(xx.reshape(self.shape))*self.diff +self.lo

class Comp(Unk):
    def __init__(self,value):
        self.x=np.asarray(value).reshape(-1)
        self.x=self.x/np.sum(self.x)
        self.size=self.x.size - 1
        self.unflatten_size = self.x.size

    def __repr__(self):
        return f'{self.x}'

    def flatten(self):
        return np.log(self.x[:-1]) + np.log(1.+ (1. - self.x[-1])/self.x[-1])


    def unflatten(self,xx):
        xm1 = np.exp(xx)/(1+np.sum(np.exp(xx)))
        return np.concatenate((xm1, np.atleast_1d(1.-np.sum(xm1))))

class CompArray(Unk):
    def __init__(self,value):
        self.x=value
        self.x=self.x/np.sum(self.x,axis=1).reshape(-1,1)
        self.nrows,self.ncols = self.x.shape
        self.size=self.x.size - self.nrows
        self.unflatten_size = self.x.size

    def __repr__(self):
        return f'{self.x}'

    def flatten(self):
        return np.ravel(np.log(self.x[:,:-1]) + np.log(1.+ (1. - self.x[:,-1])/self.x[:,-1]).reshape(-1,1))


    def unflatten(self,xx):
        xx=xx.reshape(self.nrows,self.ncols-1)
        xm1 = np.exp(xx)/(1+np.sum(np.exp(xx),axis=1).reshape(-1,1))
        return np.c_[xm1, 1.-np.sum(xm1,axis=1)]


    


In [36]:
Unk_Tuple_Dict = namedtuple('Unk_Tuple_Dict', ['keys', 'start', 'end', 'unk'])
# unk stored in x[start:end] to be unflattened into a list of keys of a dictionary

Unk_Tuple_Dict_Array = namedtuple('Unk_Tuple_Dict_Array', ['key', 'array_start', 'array_end', 'unk_id'])
# unk to be unflatten into an array stored in one key of a dictionary from position array_start to array_end


def dtox(d):
    d2=DotDict()
    size=0
    unks_dict={}
    unks_dict_arr=[]

    for k,v in d.items():
        idv = id(v)
        if idv in unks_dict:
            unks_dict[idv].keys.append(k)        
        elif isinstance(v,Unk):
            unks_dict[idv]=Unk_Tuple_Dict([k], size, size+v.size, v)
            size+=v.size
        elif isinstance(v,list):
            arr = [] # store expanded list items
            arr_pos = 0 #keep track of position
            for vi in v:
                idvi = id(vi)
                if idvi in unks_dict:
                    unks_dict_arr.append(Unk_Tuple_Dict_Array(k, arr_pos, arr_pos+vi.unflatten_size, idvi))
                    arr_pos+=vi.unflatten_size

                if isinstance(vi,Range):
                    arr.append(vi.x)
                    arr_pos+=1
                elif type(vi) in (RangeArray, Comp):
                    arr.extend(vi.x)
                    arr_pos+=vi.unflatten_size
                elif type(vi) in (list,tuple):
                    arr.extend(vi)
                    arr_pos+=len(vi)
                elif isinstance(vi,np.ndarray):
                    arr.extend(vi)
                    arr_pos+=vi.size
                else:
                    arr.append(vi)
                    arr_pos+=1

            d2[k]=np.array(arr)

        else:
            d2[k]=v  # constant
    x = np.zeros(size)
    for k,v in unks_dict.items():
        x[v.start:v.end]= v.unk.flatten()
            
    def xtod(x,d2):
        for k,v in unks_dict.items():
            unflattened = v.unk.unflatten(x[v.start:v.end])
            for key in v.keys:
                d2[key]=unflattened

        for uda in unks_dict_arr:
            v = unks_dict[uda.unk_id]
            d2[uda.key][uda.array_start:uda.array_end]=v.unk.unflatten(x[v.start:v.end])
                

        return d2

    def xtodunk(x):
        dunk={}
        for k,v in unks_dict.items():
            unflattened = v.unk.unflatten(x[v.start:v.end])
            for key in v.keys:
                dunk[key]=unflattened
        return dunk

    def wrap(f):
        def wrapped(x,d2):
            d2=xtod(x,d2)
            res=f(d2)
            return res[0] if type(res) in (tuple,list) else res

        return partial(wrapped,d2=d2)

    d2=xtod(x,d2)
    return wrap, np.asarray(x), d2, xtod, xtodunk

In [37]:
d = DotDict()
d.range1 = Range(10., 0., 50.)
d.range_array = RangeArray([5., 12., 24.], 0., [10., 20., 30.])
d.comp1 = Comp([0.1, 0.2, 0.3, 0.4])
d.comp_array = CompArray(np.tile(np.array([0.1, 0.2, 0.7]), (4,1)))
d.const = np.array([3., 5.])
d.repeat = d.range_array
d.arr1 = [d.range1, d.range_array, d.comp1, [1,2,3], (4,5), 6, np.array([7,8,9])]

In [38]:
wrap, x, d2, xtod, xtodunk = dtox(d)

{'arr1': array([10. ,  5. ,  5. , 12. , 24. ,  0.2,  0.3,  0.4,  0.1,  0.2,  0.3,
         0.4,  5. ,  6. ,  7. ,  8. ,  9. ]),
 'comp1': array([0.1, 0.2, 0.3, 0.4]),
 'comp_array': array([[0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7],
        [0.1, 0.2, 0.7]]),
 'const': array([3., 5.]),
 'range1': array([10.]),
 'range_array': array([ 5., 12., 24.]),
 'repeat': array([ 5., 12., 24.])}

In [40]:
xx

NameError: ignored

In [41]:
x

array([-1.38629436,  0.        ,  0.40546511,  1.38629436, -1.38629436,
       -0.69314718, -0.28768207, -1.94591015, -1.25276297, -1.94591015,
       -1.25276297, -1.94591015, -1.25276297, -1.94591015, -1.25276297])