In [29]:
from jax import jit, grad, vmap, random
from jax.lax import cond, fori_loop
import jax.numpy as jnp
# import jax.tree_util
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node, \
    tree_flatten_with_path, tree_structure, tree_map
from utils_fs_v2 import nonlinear_DNN, linear_DNN, DNN
from functools import partial
from anytree import NodeMixin
from math import pi

# Section 1

#### This is the main code. All the other sections are proofs of concept that are used to develop this main code

#### Defining useful parameters

In [2]:
# N_sf = 200 # single fidelity hidden width
# N_nl = 80 # multifidelity nonlinear network hidden width
N_sf = 4 # single fidelity hidden width
N_nl = 4 # multifidelity nonlinear network hidden width
sfnet_shape = [1, N_sf, N_sf, 2] # shape of the single fidelity network
nonlin_mfnet_shape = [3, N_nl, N_nl, 2] # shape of the multifidelity nonlinear network
lin_mfnet_shape = [2,  4, 2] # shape of the multifidelity linear network
init_nl, apply_nl, weight_nl = nonlinear_DNN(nonlin_mfnet_shape)
init_l, apply_l = linear_DNN(lin_mfnet_shape)
init_sf, apply_sf = DNN(sfnet_shape)

#### Defining the function for applying the multifidelity network

In [3]:
def apply_mf(params,vertices,pt,u_lf):
    u_nl = apply_nl(params[1], jnp.hstack([pt, u_lf]))
    u_l = apply_l(params[0], u_lf)
    u_local = u_l + u_nl
    w = weight(vertices,pt)
    # u_local = w*(u_l + u_nl)
    return [u_local, w]

def weight(vertices,pt):
    """
    NOTE: In Alexander's paper, this is w hat, not w. Amanda does not ensure that the weight
            functions constitute a partition of unity globally. I am not sure if this is important?
    """
    mu = (vertices[0] + vertices[1])/2
    sigma = (vertices[1] - vertices[0])/2
    w = 1 + jnp.cos(pi*(pt-mu)/sigma)
    w = w**2
    return w

#### Defining the classes `SFDomain` and `MFDomain`

In [4]:
class RootUtilities:
    def __init__(self):
        self.levels = [] # list of what classes are on which levels of the domain tree

class SFDomain(NodeMixin,RootUtilities):
    def __init__(self,vertices,params_prev=[]):
        super(SFDomain,self).__init__()
        if len(params_prev) > 0:
            params = params_prev
        else:
            params_sf = init_sf(random.PRNGKey(1))
            params = (params_sf)
        self.params = params
        self.vertices = vertices
        self.tree_level_organizer(self) # add root to the level organizer
        
    def register_children(self):
        registration = []
        for child in self.children:
            registration.append(child.__repr__())
        return registration
    
    def tree_level_organizer(self,new_node): # TESTED
        """
        Recognizes which level "new_node" is on and places "new_node" in the appropriate level in 
        "self.levels".
        =================================================================================================
        INPUT:
        new_node:   A class that is being added to the tree
        """
        depth = new_node.depth # find the depth of the new node
        if depth >= len(self.levels): # if "new_node" is the first on a new level, add that level to "self.levels"
            self.levels.append([new_node])
        else: # if "new_node" is a member of an existing level, add it to that level of "self.levels"
            self.levels[depth].append(new_node)


    def __repr__(self):
        return "SFDomain(vertices={}, params={},children={})".format(self.vertices, "sf_params", self.register_children())

class MFDomain(NodeMixin):
    def __init__(self,vertices,params_prev=[],parent=None):
        if len(params_prev) > 0:
            params = params_prev
        else:
            params_nl = init_nl(random.PRNGKey(13))
            params_l = init_l(random.PRNGKey(12345))
            params = (params_l, params_nl)
        self.params = params
        self.vertices = vertices
        self.parent = parent
        self.root.tree_level_organizer(self) # add current node to the level organizer 

    def register_children(self):
        registration = []
        for child in self.children:
            registration.append(child.__repr__())
        return registration

    def __repr__(self):
        return "MFDomain(vertices={}, params={},children={})".format(self.vertices, "mf_params", self.register_children())

#### Registering `SFDomain` so that it can be jitted. I believe that it works

In [5]:
def return_leaves(v):
  leaves = [v.vertices,v.params]
  # leaves = (v.name,)
  for child in v.children:
    leaves.append(return_leaves(child))
    # leaves = leaves + (return_leaves(child),)
  return leaves

def special_flatten(v):
  leaves = [v.vertices,v.params]
  # leaves = (v.name,)
  for child in v.children:
    leaves.append(return_leaves(child))
    # leaves = leaves + (return_leaves(child),)
  aux_data = None
  return (leaves, aux_data)

def grow_tree(parent,tree):
  new_parent = MFDomain(tree[0],tree[1],parent=parent)
  for branch in tree[2:]:
    grow_tree(new_parent,branch)

def special_unflatten(aux_data, tree):
  root = SFDomain(tree[0],tree[1])
  for branch in tree[2:]:
    grow_tree(root,branch)
  return root

# Global registration
register_pytree_node(
    SFDomain,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

In [6]:
A = SFDomain([[0.0],[1.0]])
B = MFDomain([[0.0],[0.6]],parent = A)
C = MFDomain([[0.4],[1.0]],parent = A)
# D = MFDomain([[0.0],[0.3]],parent = B)
# E = MFDomain([[0.2],[0.5]],parent = B)
# F = MFDomain([[0.4],[0.7]],parent = C)
# G = MFDomain([[0.6],[1.0]],parent = C)

In [9]:
def is_interior(vertices,pt):
    """
    Returns true if the point "pt" is on the interior of the hyperrectangle defined by vertices
    """
    return ((vertices[0] < pt) & (pt < vertices[1])).all()

def apply_lvl(is_in,params,vertices,pt,u_pred):
    return cond(is_in,apply_mf,lambda a,b,c,d: [0., 0.], (params,vertices,pt,u_pred))

def multifidelity_network_body(idx,val):
    root, pt, u_pred = val
    level = root.levels[idx]
    lvl_vertices = jnp.array([mfdomain.vertices for mfdomain in level])
    lvl_params = jnp.array([mfdomain.params for mfdomain in level])
    lvl_is_in = vmap(is_interior,(0,None))(lvl_vertices,pt)
    mfouts_and_weights = vmap(apply_lvl, (0,0,0,None,None))(lvl_is_in,lvl_params,lvl_vertices,pt,u_pred)
    mfouts = mfouts_and_weights[:,0]
    weights = mfouts_and_weights[:,1]
    total_weight = weights.sum()
    weights = weights/total_weight
    mfouts = mfouts*weights
    u_pred = mfouts.sum()
    return (root,pt,u_pred)

# @jit
def multifidelity_network(root,pt,max_level):
    u_pred = apply_sf(root.params,pt) # get the single fidelity approximation
    # print(root.vertices)
    # print(root.params[0])
    # print(root.params[1])
    # print(root.params[2])

    # temp = jnp.array([[root.params]])
    # temp = jnp.array([root])
    # print(root.children[0].vertices)
    fori_loop(1,max_level+1,multifidelity_network_body,(root,pt,u_pred))
    return u_pred


@jit
def multifidelity_network(root,pt,max_level):
    u_pred = apply_sf(root.params,pt) # get the single fidelity approximation
    for idx in range(1,max_level+1):
        level = root.levels[idx]
        lvl_vertices = jnp.array([mfdomain.vertices for mfdomain in level])
        lvl_params = jnp.array([mfdomain.params for mfdomain in level])
        lvl_is_in = vmap(is_interior,(0,None))(lvl_vertices,pt)
        mfouts_and_weights = vmap(apply_lvl, (0,0,0,None,None))(lvl_is_in,lvl_params,lvl_vertices,pt,u_pred)
        mfouts = mfouts_and_weights[:,0]
        weights = mfouts_and_weights[:,1]
        total_weight = weights.sum()
        weights = weights/total_weight
        mfouts = mfouts*weights
        u_pred = mfouts.sum()
    return u_pred

In [10]:
pts = jnp.linspace(0,1,21)
level_sizes = jnp.array([len(A.levels[i]) for i in range(len(A.levels))])
vmap(multifidelity_network,(None,0,None))(A,pts,1)

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [32]:
jnp.array(range(1,3))

Array([1, 2], dtype=int32)

In [46]:
func = lambda x,y: [0.,0.]
print(func(1,2))

[0.0, 0.0]


# Section 2

#### The below example shows how to register a tree so that one can utilize the tree inside of a @jit function

We define a class called `Special` and a class called `SpecialLeaf`. The `Special` class serves as the root and the `SpecialLeaf` class will be used to create nodes on this tree. Below we create a function that will output the flattened and unflattened trees.

In [38]:
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

In [39]:
class Special(NodeMixin):
  def __init__(self, name):
    self.name = name
  
  def register_children(self):
    registration = []
    # registration = ()
    for child in self.children:
      registration.append(child.__repr__())
      # registration = registration + (child.__repr__(),)
    return registration

  def __repr__(self):
    return "Special(name={}, children={})".format(self.name, self.register_children())

class SpecialLeaf(NodeMixin):
  def __init__(self, name, parent):
    self.name = name
    self.parent = parent

  def register_children(self):
    registration = []
    # registration = ()
    for child in self.children:
      registration.append(child.__repr__())
      # registration = registration + (child.__repr__(),)
    return registration

  def __repr__(self):
    return "SpecialLeaf(name={}, children={})".format(self.name,self.register_children())

In [40]:
A = Special(1)
B = SpecialLeaf(2,parent=A)
C = SpecialLeaf(3,parent=A)
D = SpecialLeaf(4,parent=B)

#### The below examples will not work because we have not registered the tree

In [41]:
show_example(A)

structured=Special(name=1, children=["SpecialLeaf(name=2, children=['SpecialLeaf(name=4, children=[])'])", 'SpecialLeaf(name=3, children=[])'])
  flat=[Special(name=1, children=["SpecialLeaf(name=2, children=['SpecialLeaf(name=4, children=[])'])", 'SpecialLeaf(name=3, children=[])'])]
  tree=PyTreeDef(*)
  unflattened=Special(name=1, children=["SpecialLeaf(name=2, children=['SpecialLeaf(name=4, children=[])'])", 'SpecialLeaf(name=3, children=[])'])


Note that "tree" does not show any of the structure of A. The "flat" attribute is also not flat.

#### Now we register the code to see what happens

In [42]:
class RegisteredSpecial(Special):
  def __repr__(self):
    return "Special(name={}, children={})".format(self.name,self.register_children())
  
def return_leaves(v):
  leaves = [v.name]
  # leaves = (v.name,)
  for child in v.children:
    leaves.append(return_leaves(child))
    # leaves = leaves + (return_leaves(child),)
  return leaves

def special_flatten(v):
  leaves = [v.name]
  # leaves = (v.name,)
  for child in v.children:
    leaves.append(return_leaves(child))
    # leaves = leaves + (return_leaves(child),)
  aux_data = None
  return (leaves, aux_data)

def grow_tree(parent,tree):
  new_parent = SpecialLeaf(tree[0],parent=parent)
  for branch in tree[1:]:
    grow_tree(new_parent,branch)

def special_unflatten(aux_data, tree):
  root = RegisteredSpecial(tree[0])
  for branch in tree[1:]:
    grow_tree(root,branch)
  return root

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)


In [43]:
a = RegisteredSpecial(1)
b = SpecialLeaf(2,parent=a)
c = SpecialLeaf(3,parent=a)
d = SpecialLeaf(4,parent=b)

In [45]:
show_example(a)

structured=Special(name=1, children=["SpecialLeaf(name=2, children=['SpecialLeaf(name=4, children=[])'])", 'SpecialLeaf(name=3, children=[])'])
  flat=[1, 2, 4, 3]
  tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, [*, [*]], [*]]))
  unflattened=Special(name=1, children=["SpecialLeaf(name=2, children=['SpecialLeaf(name=4, children=[])'])", 'SpecialLeaf(name=3, children=[])'])


This is what the output of "show_example should look like.

#### Now, we make sure that a jitted function actually detects whether the data inside the structure has changed.

In [187]:
@jit
def print_node(root,temp):
    return root.name

In [190]:
temp = jnp.linspace(0,1,4)
vmap(print_node,(None,0))(a,temp)

Array([1, 1, 1, 1], dtype=int32, weak_type=True)

In [191]:
a.name = 2

In [192]:
temp = jnp.linspace(0,1,4)
vmap(print_node,(None,0))(a,temp)

Array([2, 2, 2, 2], dtype=int32, weak_type=True)

#### Yay! The output changed meaning the jitted function `print_node` detected the fact that the tree changed

# Section 3

#### In this section I am analyzing how to retrieve key paths in the pytree. Make sure to run Section 2 before this section.

In [52]:
from jax.tree_util import treedef_children

In [53]:
a = RegisteredSpecial(1)
b = SpecialLeaf(2,parent=a)
c = SpecialLeaf(3,parent=a)
d = SpecialLeaf(4,parent=b)

In [64]:
treedef = tree_structure(a)
treedef

PyTreeDef(CustomNode(RegisteredSpecial[None], [*, [*, [*]], [*]]))

In [65]:
treedef_children(treedef)

[PyTreeDef(*), PyTreeDef([*, [*]]), PyTreeDef([*])]

In [43]:
import collections
import jax
ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
    print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')

Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo


# Section 4

#### In this section I am seeing whether it is possible to iterate through an array of structs in jax

In [2]:
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

In [3]:
class Test(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Test(x={}, y={})".format(self.x, self.y)

show_example(Test(1., 2.))

structured=Test(x=1.0, y=2.0)
  flat=[Test(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Test(x=1.0, y=2.0)


In [4]:
from jax.tree_util import register_pytree_node

class RegisteredTest(Test):
  def __repr__(self):
    return "RegisteredTest(x={}, y={})".format(self.x, self.y)

def test_flatten(v):
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def test_unflatten(aux_data, children):
  return RegisteredTest(*children)

# Global registration
register_pytree_node(
    RegisteredTest,
    test_flatten,    # tell JAX what are the children nodes
    test_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredTest(1., 2.))

structured=RegisteredTest(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredTest[None], [*, *]))
  unflattened=RegisteredTest(x=1.0, y=2.0)


In [5]:
T1 = RegisteredTest(1., 1.)
T2 = RegisteredTest(2.,2.)
T3 = RegisteredTest(3.,3.)

In [11]:
class TestWrapper:
    def __init__(self,T1,T2,T3):
        self.T_list = [T1,T2,T3]
    
    def __repr__(self):
      return "TestWrapper(T_list={})".format(self.T_list)
       

def tw_flatten(v):
  temp = v.T_list
  children = (temp[0].x, temp[0].y,temp[1].x, temp[1].y,temp[2].x, temp[2].y)
  aux_data = None
  return (children, aux_data)

def tw_unflatten(aux_data, children):
  Te1 = RegisteredTest(children[0],children[1])
  Te2 = RegisteredTest(children[2],children[3])
  Te3 = RegisteredTest(children[4],children[5])

  return TestWrapper(Te1,Te2,Te3)

# Global registration
register_pytree_node(
    TestWrapper,
    tw_flatten,    # tell JAX what are the children nodes
    tw_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

In [12]:
TW = TestWrapper(T1,T2,T3)

In [22]:
show_example(TW)

structured=TestWrapper(T_list=[RegisteredTest(x=1.0, y=1.0), RegisteredTest(x=2.0, y=2.0), RegisteredTest(x=3.0, y=3.0)])
  flat=[1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
  tree=PyTreeDef(CustomNode(TestWrapper[None], [*, *, *, *, *, *]))
  unflattened=TestWrapper(T_list=[RegisteredTest(x=1.0, y=1.0), RegisteredTest(x=2.0, y=2.0), RegisteredTest(x=3.0, y=3.0)])


In [30]:
def get_x(T):
    return T.x
@jit
def iter_through(TeW):
    # v = np.zeros(3)
    # fori_loop(0.,3.,g,0.)
    idk = jnp.arange(3)
    v = tree_map(lambda x,y: x + y, TeW, TeW)
    # for i in np.arange(3):
    #     v[i] = TeW.T_list[i].x
    #     # TeW.T_list[i].x
    #     v.at[i].set(i)
    # v.at[i].set(TeW.T_list[i].x)
    return v

In [31]:
iter_through(TW)

TestWrapper(T_list=[RegisteredTest(x=2.0, y=2.0), RegisteredTest(x=4.0, y=4.0), RegisteredTest(x=6.0, y=6.0)])