In [11]:
from modeling.arghandling import flatten_args, unflatten_args
from modeling.arghandling import encode, decode
import numpy as np

In [12]:
# Sample arguments of different ranks: a 0-d array, a length-2 vector and a 2x2 matrix.
x = [np.array(1), np.array([1,2]), np.array([[4,5],[6,7]])]

In [13]:
# Flatten the arguments into one vector, then rebuild them from per-argument shapes.
# NOTE(review): x[0] is 0-d but is described as (1,) here; it comes back as the
# plain scalar 1 (presumably the effect of cleanup=True).
v = flatten_args(x)
unflatten_args(v, shapes=((1,), (2,), (2,2)), cleanup=True)

(1,
 array([1, 2]),
 array([[4, 5],
        [6, 7]]))

In [4]:
# Encode a dict as its values taken in the order ('c', 'b', 'a'); with flatten=True the
# result is presumably a single flat vector (the next cell unflattens it).
w = encode({'a': 1, 'b': np.array([[4,5],[6,7]]), 'c': 3}, 
        ('c', 'b', 'a'), flatten=True)

In [5]:
# Inverse of the previous cell: split w back into named entries, using the same key
# order and one shape per key.
decode(w, order=['c', 'b', 'a'], shapes=((1,), (2,2), (1,)), 
       unflatten=True, cleanup=True)

{'c': 3,
 'b': array([[4, 5],
        [6, 7]]),
 'a': 1}

In [6]:
def f(x, y, z):
    """Toy model used throughout this notebook.

    Returns a 2-tuple: the scalar ``x**2 + z + y + exp(-y)`` and the inputs
    packed as ``np.array([x, y, z])``.
    """
    return x**2 + z + y + np.exp(-y), np.array([x, y, z])

In [7]:
def dict_in_dict_out(d):
    """Call ``f`` with ``d``'s values taken in the order x, y, z; return its outputs named 'a' and 'b'."""
    return decode(f(*encode(d, ('x', 'y', 'z'))), ('a', 'b'), shapes=((1,), (3,)))


def dict_in_flat_out(d):
    """Call ``f`` with ``d``'s values taken in the order x, y, z; return its outputs flattened into one vector."""
    return flatten_args(f(*encode(d, ('x', 'y', 'z'))))

In [8]:
# Call f with named inputs and get named outputs back.
dict_in_dict_out({'x': 1, 'y': 2, 'z': 3})

{'a': 6.135335283236612, 'b': array([1, 2, 3])}

In [9]:
# Same evaluation, but with f's outputs flattened into a single vector.
residuals = dict_in_flat_out({'x': 1, 'y': 2, 'z': 3})
residuals

array([6.13533528, 1.        , 2.        , 3.        ])

In [10]:
# Split the flat vector back into named outputs; 'b' comes back as floats because
# the flattened vector shown above is a float array.
decode(residuals, ('a','b'), shapes=((1,),(3,)), unflatten=True, cleanup=True)

{'a': 6.135335283236612, 'b': array([1., 2., 3.])}

In [16]:
class Encoder:
    """Ordered set of keys used to convert between dicts and positional values.

    Attributes
    ----------
    order : sequence
        Keys in positional order.
    parent : Encoder or None
        The encoder whose keys this one renames (``compose`` sets it; ``reverse_encoding``
        follows it back to the root).
    shapes : sequence or None
        One shape per key, in the same order as ``order``.
    """

    def __init__(self, order, parent=None, shapes=None):
        self.order, self.parent, self.shapes = order, parent, shapes

In [17]:
class EncodedFunction:
    """Wrap a positional function ``f`` so it can take and/or return dicts.

    ``encoder.order`` is the key order used to turn an input dict into ``f``'s
    positional arguments; ``decoder.order`` and ``decoder.shapes`` name ``f``'s
    outputs.
    """

    def __init__(self, f, encoder=None, decoder=None):
        self.f = f
        self.encoder = encoder
        self.decoder = decoder

    def _decode(self, outputs):
        """Name raw outputs of ``self.f`` using the decoder's order and shapes."""
        return decode(outputs, self.decoder.order, shapes=self.decoder.shapes)

    def dict_out_only(self, *args):
        """Call ``self.f`` with positional ``args`` and return its outputs as a dict."""
        return self._decode(self.f(*args))

    def dict_in_only(self, d):
        """Call ``self.f`` with the values of ``d`` in encoder order; return its raw outputs."""
        # Must use the wrapped function, not a notebook-global ``f``.
        return self.f(*encode(d, self.encoder.order))

    def dict_in_flat_out(self, d):
        """Like ``dict_in_only``, with the outputs flattened by ``flatten_args``."""
        return flatten_args(self.dict_in_only(d))

    def dict_in_dict_out(self, d):
        """Call ``self.f`` with a dict of inputs and return a dict of outputs."""
        return self._decode(self.dict_in_only(d))
    
    

In [50]:
def reverse_encoding(encoder):
    """Map each key of ``encoder`` to the positionally matching key of its root ancestor.

    Follows ``parent`` links up to the top-most encoder and pairs
    ``encoder.order`` with that root's ``order`` by position (via ``zip``, so
    surplus keys on either side are dropped). Intermediate levels are skipped.
    Returns an empty dict when ``encoder`` has no parent.
    """
    if encoder.parent is None:
        return {}
    root = encoder.parent
    while root.parent is not None:
        root = root.parent
    return {key: root_key for key, root_key in zip(encoder.order, root.order)}

In [51]:
# Chain of encoders: (1, 2, 3) has parent ('x', 'y', 'z'), whose parent is ('A', 'B', 'C').
E1 = Encoder((1,2,3), Encoder(('x','y','z'), Encoder(('A','B','C'))))

In [52]:
# Maps the leaf keys directly to the root keys; the intermediate level is skipped.
reverse_encoding(E1)

{1: 'A', 2: 'B', 3: 'C'}

In [57]:
def compose(f, g=None, mapping=None, reverse=False):
    """Build a new EncodedFunction from ``f`` (and optionally ``g``) with renamed keys.

    Parameters
    ----------
    f : EncodedFunction
        Its encoder defines the inputs of the result.
    g : EncodedFunction, optional
        If given, the result feeds ``f``'s named outputs into ``g``'s named
        inputs and returns ``g``'s raw outputs. Names are matched after
        applying ``mapping`` to both ``f.decoder.order`` and ``g.encoder.order``.
    mapping : dict, optional
        Key renames (``old -> new``); keys not in the mapping keep their name.
        Ignored when ``reverse`` is true.
    reverse : bool
        If true, the mapping is built with ``reverse_encoding`` so keys are
        renamed back to the root ancestors of ``f.encoder`` and of the output
        decoder.

    Returns
    -------
    EncodedFunction
        Its encoder/decoder hold the renamed orders, record the encoder they
        were derived from as ``parent``, and keep that encoder's ``shapes``.
    """
    mapping = mapping if mapping is not None else dict()
    source_decoder = g.decoder if g is not None else f.decoder
    if reverse:
        mapping = reverse_encoding(f.encoder)
        mapping.update(reverse_encoding(source_decoder))

    def rename(order):
        return tuple(mapping.get(key, key) for key in order)

    if g is None:
        func = f.f
    else:
        f_func, g_func = f.f, g.f
        out_order = rename(f.decoder.order)
        out_shapes = f.decoder.shapes
        in_order = rename(g.encoder.order)

        def func(*args):
            # Name f's outputs in the shared key space, then pick g's inputs from them.
            intermediate = decode(f_func(*args), out_order, shapes=out_shapes)
            return g_func(*encode(intermediate, in_order))

    new_encoder = Encoder(rename(f.encoder.order), parent=f.encoder,
                          shapes=f.encoder.shapes)
    new_decoder = Encoder(rename(source_decoder.order), parent=source_decoder,
                          shapes=source_decoder.shapes)
    return EncodedFunction(func, new_encoder, new_decoder)

In [58]:
# Wrap f: dict values are passed to f positionally in the key order 'x', 'z', 'y';
# its two outputs are named 'a' (shape (1,)) and 'b' (shape (3,)).
F = EncodedFunction(f, 
                    Encoder(('x','z','y')), 
                    Encoder(('a','b'), shapes=((1,),(3,))))

In [59]:
# Rename keys: x->1, y->2, z->3 on the input side and a->4 on the output side
# ('b' is not in the mapping, so it keeps its name).
G = compose(F, mapping={'x':1, 'y':2, 'z':3, 'a':4})

In [60]:
# Inputs are now keyed 1, 2, 3; they are passed to f in G's encoder order (1, 3, 2).
G.dict_in_dict_out({1: 1, 2: 2, 3: 3})

{4: 6.049787068367864, 'b': array([1, 3, 2])}

In [61]:
# Undo G's renaming: reverse=True maps G's keys back to the names of their root
# ancestor encoders (here, F's encoder and decoder).
H = compose(G, reverse=True)

In [66]:
# Same call as G's above but with the original key names; the values match G's result.
H.dict_in_dict_out({'x': 1, 'y': 2, 'z': 3})

{'a': 6.049787068367864, 'b': array([1, 3, 2])}

In [21]:
# NOTE(review): the output recorded below is stale. It comes from In [21], which ran
# before F was (re)defined in In [58]. With F's encoder order ('x', 'z', 'y') this call
# is expected to match H's result above (a ~ 6.0498, b = [1, 3, 2]); re-run to refresh.
F.dict_in_dict_out({'x': 1, 'y': 2, 'z': 3})

{'a': 6.135335283236612, 'b': array([1, 2, 3])}