# Stochastic Truss

In [1]:
import inspect

import jax.numpy as jnp
import numpy as onp

import anabel as ana
from elle import truss2d

## Combinators, Composition & Categories

In [None]:
def truss(u, xyz, E, A):
    """Placeholder for a 2D truss element response (not yet implemented).

    So far only the element length is computed, as the Euclidean distance
    between rows 0 and 1 of ``xyz`` (presumably the element's end-node
    coordinates). ``u``, ``E`` and ``A`` are accepted but unused. The
    function always returns an empty array.
    """
    tail, head = xyz[0, :], xyz[1, :]
    length = jnp.linalg.norm(head - tail)  # unused until the element is implemented
    return jnp.array([])

In [2]:
# Load the structure description from YAML (presumably nodes and element
# connectivity — confirm against graph.yml).
graph = ana.io.load("graph.yml")

In [3]:
# Element implementations, keyed by name, handed to `ana.compose`.
elements = dict(
    model=ana.models.basic,
    truss=truss2d.force,
)

In [4]:
# Assemble one model function from the element library and the graph.
# NOTE(review): `node="el"` presumably names the graph attribute that picks
# each element's implementation out of `elements` — confirm in anabel docs.
model = ana.compose(elements, graph, node="el")
inspect.signature(model)  # last expression: shown as the cell's output



<Signature (dx, *, x0={'n1': [0.0, 0.0], 'n2': [4000.0, 0.0], 'n3': [8000.0, 0.0], 'n4': [12000.0, 0.0], 'n5': [4000.0, 4000.0], 'n6': [8000.0, 4000.0]}, params)>

## Automatic Differentiation

In [5]:
# Problem dimensions.
nvars = 11  # length of the parameter vector consumed by `f` (layout in its docstring)
nf = 9      # free DOFs: size of the reduced stiffness matrix and of the load vector
nr = 3      # presumably restrained DOFs; `u` below spans nf + nr entries
u = onp.zeros(nf+nr, dtype='float32')  # all-zero displacement state passed to Kf

Kf = ana.autodiff.stiffness_matrix(model, nf)


def f(params):
    """Solve the linear system and return two selected free-DOF displacements.

    Parameter layout (``params`` has length ``nvars``):
      params[0]  -- modulus ``E`` shared by every element
      params[1]  -- load magnitude placed at free DOFs 1 and 3
      params[2:] -- area ``A`` of elements e1, e2, ... in order
                    (one per element, so there are ``nvars - 2`` elements)

    Returns:
        Array of shape (2,) holding the solved displacements ``U[1]`` and
        ``U[3]``.
    """
    E = params[0]
    # e{i} takes its area from params[i + 1]; generated instead of written out.
    kwds = {
        "params": {
            f"e{i}": {"A": params[i + 1], "E": E}
            for i in range(1, nvars - 1)
        }
    }

    kf = Kf(u, **kwds)
    # Size the load vector from `nf` rather than a hand-written 9-entry literal.
    load_vector = (
        jnp.zeros(nf, dtype='float32')
        .at[jnp.array([1, 3])]
        .set(params[1])[:, None]
    )
    U = jnp.linalg.solve(kf, load_vector)
    # Advanced indexing picks U[1, 0] and U[3, 0] -> shape (2,).
    return U[[1, 3], [0, 0]]

## Compilation

In [6]:
from jaxlib import xla_client
import jax

In [7]:
# JIT-compile the response function and export its HLO module to disk.
f = jax.jit(f)
# Used only to fix the traced input shape/dtype; its values are irrelevant here.
param_init = jnp.zeros(nvars, dtype='float32')

# NOTE(review): `jax.xla_computation` is deprecated in newer JAX releases
# (replacement: `jax.jit(f).lower(...)`); confirm against the pinned version.
f_xla = jax.xla_computation(f)

# Rebuild an XlaComputation from the traced computation's serialized proto.
fxla = xla_client.XlaComputation(f_xla(param_init).as_serialized_hlo_module_proto())

# The file handle must not be called `f`: that would rebind the module-level
# jitted function above to a closed file object once this cell finishes.
with open('output.pb', 'wb') as out_file:
    out_file.write(fxla.as_serialized_hlo_module_proto())