# Finite Element Neural Networks

**Table of Contents**

- [Pure and Statically Composed](#Pure-and-Statically-Composed)
- [Combinators, Composition & Categories](#Combinators-Composition--Categories)
- [Automatic Differentiation](#Automatic-Differentiation)
- [Compiling the Model](#Compiling-the-Model)

## Pure and Statically Composed

This notebook uses [Anabel](http://www.claudioperez.xyz/Projects/Anabel) together with JAX.

- **Statically composed**

Anabel allows us to annotate a computational graph that represents our finite element model through a YAML file that strongly resembles typical FEM input scripts.

```yaml
el:
  model:
    x0: # model node coordinates
      n1: [     0.,    0.]
      n2: [  4000.,    0.]
      n3: [  8000.,    0.]
      n4: [ 12000.,    0.]
      n5: [  4000., 4000.]
      n6: [  8000., 4000.]

    bn: # model boundary conditions
      n1: [1, 1]
      n4: [0, 1]
     
    el:   # graph nodes
      truss: {'jit': true}

    mesh: # graph edges
      e1: [truss,  [n1, n2]]
      e2: [truss,  [n2, n3]]
      e3: [truss,  [n3, n4]]
      e4: [truss,  [n1, n5]]
      e5: [truss,  [n5, n6]]
      e6: [truss,  [n4, n6]]
      e7: [truss,  [n2, n5]]
      e8: [truss,  [n3, n6]]
      e9: [truss,  [n3, n5]]
```

We store this in a file named `graph.yaml`. After importing `anabel`, the data can be loaded with the function `anabel.io.load`, a thin convenience wrapper for reading serialized data.

In [16]:
import anabel as ana

graph = ana.io.load('graph.yaml')

The variable `graph` holds a plain Python dictionary of lists, floats, and strings. The function `anabel.io.load` currently supports input from JSON and YAML files, but any serialization format that yields such a dictionary could be used here.
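
To make the structure concrete, the sketch below writes out by hand the dictionary that a file like `graph.yaml` would yield, without going through Anabel. It then counts degrees of freedom, assuming two per node and assuming that a `1` in `bn` marks a restrained direction. Both assumptions are illustrative, but they agree with the values `nf = 9` and `nr = 3` used later in the notebook.

```python
# Hand-written equivalent of the data in `graph.yaml` (not loaded via Anabel).
graph = {
    "el": {
        "model": {
            "x0": {  # node coordinates
                "n1": [0.0, 0.0], "n2": [4000.0, 0.0], "n3": [8000.0, 0.0],
                "n4": [12000.0, 0.0], "n5": [4000.0, 4000.0], "n6": [8000.0, 4000.0],
            },
            "bn": {"n1": [1, 1], "n4": [0, 1]},  # boundary conditions
            "el": {"truss": {"jit": True}},      # graph nodes
            "mesh": {                            # graph edges
                "e1": ["truss", ["n1", "n2"]], "e2": ["truss", ["n2", "n3"]],
                "e3": ["truss", ["n3", "n4"]], "e4": ["truss", ["n1", "n5"]],
                "e5": ["truss", ["n5", "n6"]], "e6": ["truss", ["n4", "n6"]],
                "e7": ["truss", ["n2", "n5"]], "e8": ["truss", ["n3", "n6"]],
                "e9": ["truss", ["n3", "n5"]],
            },
        }
    }
}

model_data = graph["el"]["model"]
ndf = 2  # assumed: two translational degrees of freedom per node

n_total = ndf * len(model_data["x0"])
# Assumed convention: a flag of 1 marks a restrained direction.
n_restrained = sum(sum(flags) for flags in model_data["bn"].values())
n_free = n_total - n_restrained

print(n_free, n_restrained)
```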

## Combinators, Composition & Categories

In [8]:
from elle import truss2d

In [9]:
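import jax.numpy as jnp

def truss_force(u, xyz, E, A):
    # Element length from the coordinates of its two end nodes
    DX = xyz[1,0] - xyz[0,0]
    DY = xyz[1,1] - xyz[0,1]
    L = jnp.linalg.norm(jnp.array([DX, DY]))
    # Placeholder return value; the model below uses `truss2d.force` from `elle`
    return jnp.array([])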

In [11]:
elements = {
    'model' :  ana.models.basic, 
    'truss' :  truss2d.force}

In [12]:
model = ana.compose(elements, graph, node='el')

In [12]:
import inspect
inspect.signature(model)

<Signature (dx, *, x0={'n1': [0.0, 0.0], 'n2': [4000.0, 0.0], 'n3': [8000.0, 0.0], 'n4': [12000.0, 0.0], 'n5': [4000.0, 4000.0], 'n6': [8000.0, 4000.0]}, params)>

## Automatic Differentiation

We will now generate a function that produces the stiffness matrix of the model using automatic differentiation. Later on, this technique will be used to obtain gradients with respect to far more interesting parameters. For now, the stiffness matrix serves as a good introduction to the concept, since most finite element analysts are familiar with the idea that a stiffness matrix is simply the gradient of a resisting-force function, or equivalently the Hessian of an energy function, with respect to the displacements.

We begin by importing Google's open source [JAX](https://github.com/jax-ml/jax) library. Its interface is designed to emulate the NumPy API, so working with it should feel familiar to anyone with some Python or Matlab experience.

In [13]:
import jax.numpy as jnp
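
The idea that a stiffness matrix is a derivative can be checked on a single element. The following is a minimal sketch, not Anabel's or `elle`'s implementation: `truss_energy` is a hypothetical strain-energy function for a linear 2D truss element, and the values of `E` and `A` are arbitrary. Differentiating the energy once gives the resisting force; differentiating twice gives the stiffness matrix.

```python
import jax
import jax.numpy as jnp

def truss_energy(u, xy, E, A):
    """Strain energy of a linear 2D truss element (hypothetical helper).

    u  : end displacements [u1x, u1y, u2x, u2y]
    xy : (2, 2) array of end-node coordinates
    """
    d = xy[1] - xy[0]
    L = jnp.linalg.norm(d)
    n = d / L                          # unit vector along the element
    elong = jnp.dot(n, u[2:] - u[:2])  # axial elongation
    return 0.5 * E * A / L * elong**2

xy = jnp.array([[0.0, 0.0], [4000.0, 0.0]])
u0 = jnp.zeros(4)
E, A = 200.0, 100.0  # arbitrary illustrative values

# Resisting force is the gradient of the energy with respect to u.
force = jax.grad(truss_energy)

# Stiffness is the Hessian of the energy, or the Jacobian of the force.
K_hess = jax.hessian(truss_energy)(u0, xy, E, A)
K_jac = jax.jacfwd(force)(u0, xy, E, A)

# Closed-form stiffness of a horizontal truss element, for comparison.
K_exact = (E * A / 4000.0) * jnp.array([[ 1.0, 0.0, -1.0, 0.0],
                                        [ 0.0, 0.0,  0.0, 0.0],
                                        [-1.0, 0.0,  1.0, 0.0],
                                        [ 0.0, 0.0,  0.0, 0.0]])
print(jnp.allclose(K_hess, K_exact), jnp.allclose(K_jac, K_exact))
```

Both routes yield the same matrix here because the element is linear elastic, so the force is exactly the gradient of a quadratic energy.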

In [13]:
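nvars = 11  # params = [E, P, A_e1, ..., A_e9]
nf = 9      # free degrees of freedom
nr = 3      # restrained degrees of freedom
u = jnp.zeros(nf + nr, dtype='float32')

# Function returning the model stiffness matrix, generated by automatic differentiation
Kf = ana.autodiff.stiffness_matrix(model, nf)

def f(params):
    # params[0] is the modulus shared by all elements, params[1] the load
    # magnitude, and params[2:] the cross-sectional areas of elements e1..e9.
    kwds = {
        "params":{
          "e1": {"A" : params[ 2], "E": params[0]},
          "e2": {"A" : params[ 3], "E": params[0]},
          "e3": {"A" : params[ 4], "E": params[0]},
          "e4": {"A" : params[ 5], "E": params[0]},
          "e5": {"A" : params[ 6], "E": params[0]},
          "e6": {"A" : params[ 7], "E": params[0]},
          "e7": {"A" : params[ 8], "E": params[0]},
          "e8": {"A" : params[ 9], "E": params[0]},
          "e9": {"A" : params[10], "E": params[0]}}}

    kf = Kf(u, **kwds)
    # Column load vector over the free degrees of freedom
    load_vector = jnp.array([0., params[1], 0., params[1], 0., 0., 0., 0., 0.], dtype='float32')[:,None]
    U = jnp.linalg.solve(kf, load_vector)
    # Displacements at the two loaded degrees of freedom
    return U[[1,3], [0,0]]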
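
Because `f` is built entirely from JAX operations, it can itself be differentiated with respect to `params`. The sketch below shows the same pattern on a toy problem that does not depend on Anabel: a hypothetical two-bar chain fixed at one end and loaded at the tip, with all numerical values chosen purely for illustration. The gradient passes through `jnp.linalg.solve` and matches the closed-form sensitivity.

```python
import jax
import jax.numpy as jnp

def tip_displacement(areas, E=1.0, L=1.0, P=8.0):
    """Tip displacement of two bars in series (hypothetical toy model)."""
    k = E * areas / L  # axial stiffness of each bar
    K = jnp.array([[k[0] + k[1], -k[1]],
                   [-k[1],        k[1]]])
    load = jnp.array([0.0, P])
    return jnp.linalg.solve(K, load)[1]

areas = jnp.array([2.0, 4.0])

u_tip = tip_displacement(areas)
du_dA = jax.grad(tip_displacement)(areas)

# Closed form: u_tip = P*L/(E*A1) + P*L/(E*A2), so du/dAi = -P*L/(E*Ai**2).
du_dA_exact = -8.0 / areas**2
print(u_tip, du_dA, du_dA_exact)
```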

## Compiling the Model

In [14]:
from jaxlib import xla_client
import jax

In [15]:
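f = jax.jit(f)
param_init = jnp.zeros(nvars, dtype='float32')

# Trace `f` to an XLA computation. Note that `jax.xla_computation` is
# deprecated in recent JAX releases.
f_xla = jax.xla_computation(f)

fxla = xla_client.XlaComputation(f_xla(param_init).as_serialized_hlo_module_proto())

# Write the serialized HLO module to disk. The file handle is named `file`
# so that it does not shadow the function `f`.
with open('output.pb', 'wb') as file:
    file.write(fxla.as_serialized_hlo_module_proto())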
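
Recent JAX releases deprecate `jax.xla_computation` in favor of the ahead-of-time lowering API on jitted functions. The following is a minimal sketch of that route. It uses a hypothetical stand-in function `g` rather than the model function from this notebook, and it produces a textual intermediate representation rather than the serialized protobuf written above.

```python
import jax
import jax.numpy as jnp

def g(params):
    # Hypothetical stand-in for the model function: a 2x2 linear solve
    K = jnp.array([[params[0] + params[1], -params[1]],
                   [-params[1],             params[1]]])
    return jnp.linalg.solve(K, jnp.array([0.0, 1.0]))

param_init = jnp.ones(2, dtype='float32')

# Lower the jitted function for a concrete argument shape and dtype.
lowered = jax.jit(g).lower(param_init)

# Textual form of the lowered module, suitable for inspection or saving.
ir_text = lowered.as_text()

# Compile ahead of time and call the compiled executable directly.
compiled = lowered.compile()
result = compiled(param_init)
print(result)
```

The string `ir_text` can be written to a file with an ordinary text-mode `open`, which may be enough when the goal is to inspect or archive the compiled model rather than to load it from another runtime.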