In [1]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import Model
from Poisson import equation
import numpy as np
from fenics import *
from fenics_adjoint import *
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 20})

In [2]:
# load data
dataset = equation.Dataset('poisson')
dataset.load()
Phi_train, theta_train = dataset.train
Phi_test, theta_test = dataset.test
Phi_val, theta_val = dataset.validate

In [3]:
class PDESolver(Layer):
    
    def __init__(self, Phi_shape, name='fenics', dynamic=True):
        super().__init__()
        self.Phi_shape = Phi_shape

    def call(self, inputs):
        return self.solver(inputs)
    
    def set_Phi_batch(self, Phi_batch):
        self.Phi_batch = Phi_batch
    
    @tf.custom_gradient
    def solver(self, theta_batch):
        """
        Solves constant force 1D Poisson equation
        and returns the solution as a numpy array
        with length soln_dim.
        """
        # convert eager tensor to numpy array
        # with shape (batch_size, num_params)
        Phi_batch = self.Phi_batch.numpy()
        theta_batch = theta_batch.numpy()
        batch_size = 25
        _, theta_size = theta_batch.shape
        soln_size = self.Phi_shape[0]

        # Create mesh and define function space
        mesh_res = soln_size - 1
        mesh = UnitIntervalMesh(mesh_res)
        V = FunctionSpace(mesh, 'P', 1)

        # Define variational problem
        u = TrialFunction(V)
        v = TestFunction(V)
        a = grad(u)[0] * grad(v)[0] * dx

        # compute batch solutions
        batch_solns = np.zeros((batch_size, soln_size))
        self.batch_grads = np.zeros((batch_size, theta_size))
        for idx, params in enumerate(theta_batch):
            c, b0, b1 = params
            c = Constant(c)
            L = c * v * dx

            # Define boundary condition
            u_D = Expression('x[0] == 0 ? b0: b1',
                            b0 = b0,
                            b1 = b1,
                            degree = 2)
            bd_vals = Function(V)
            bd_vals.assign(project(u_D, V))
            def boundary(x, on_boundary):
                return on_boundary
            bc = DirichletBC(V, bd_vals, boundary)

            # Compute solution
            u = Function(V)
            solve(a == L, u, bc)
            batch_solns[idx,:] = u.compute_vertex_values(mesh)
            
            # Compute L2 loss
            Phi = Function(V)
            dofs = V.dofmap().dofs(mesh, 0)
            Phi.vector().set_local(Phi_batch[idx][dofs])
            J = assemble(0.5 * inner(u - Phi, u - Phi) * dx)

            # Compute loss gradient
            control = [Control(c), Control(bd_vals)]
            dJdc, dJdb = compute_gradient(J, control)
            dJdc = dJdc.values().item()

            # Convert to gradients to scalars
            dJdb = dJdb.compute_vertex_values(mesh)
            self.batch_grads[idx,:] = [dJdc, dJdb[0], dJdb[-1]]

        def d_solver(dJdu):
            # tensorflow expects floats not doubles
            return np.float32(self.batch_grads)

        return batch_solns, d_solver

In [4]:
class SemanticAutoEncoder(Model):
    
    # model layers 
    def __init__(self, name='sae'):
        super().__init__()
        self.Phi_shape = (100,)
        self.dense = Dense(20, 'linear', input_shape=self.Phi_shape, name='hidden')
        self.theta = Dense(3, 'linear', name='theta')
        self.u_theta = PDESolver(self.Phi_shape)
        
    # forward pass starting with inputs 
    def call(self, Phi, training=True):
        self.u_theta.set_Phi_batch(Phi)
        x = self.dense(Phi)
        x = self.theta(x)
        u = self.u_theta(x)
        return u

In [6]:
#tf.config.experimental_run_functions_eagerly(True)
model = SemanticAutoEncoder()
model.compile('adam', 'mse', run_eagerly=True)
#model.summary()
#predict_theta = Model(sae.input, sae.get_layer('theta').output)
fit_model = model.fit(Phi_train, Phi_train, epochs=20, batch_size=25)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20

RuntimeError: 

*** -------------------------------------------------------------------------
*** DOLFIN encountered an error. If you are not able to resolve this issue
*** using the information listed below, you can ask for help at
***
***     fenics-support@googlegroups.com
***
*** Remember to include the error message listed below and, if possible,
*** include a *minimal* running example to reproduce the error.
***
*** -------------------------------------------------------------------------
*** Error:   Unable to successfully call PETSc function 'VecCreate'.
*** Reason:  PETSc error code is: 1 ((null)).
*** Where:   This error was encountered inside /usr/local/miniconda/conda-bld/fenics-pkgs_1575559281723/work/dolfin/dolfin/la/PETScVector.cpp.
*** Process: 0
*** 
*** DOLFIN version: 2019.1.0
*** Git changeset:  
*** -------------------------------------------------------------------------


In [None]:
model.summary()

In [None]:
model.evaluate(Phi_test, Phi_test, batch_size=25)

In [None]:
u_test = model.predict(Phi_test[0:10])

In [None]:
x = dataset.domain()
def plot_pred(idx):
    fig, ax = plt.subplots()
    ax.plot(x, u_test[idx], lw=3, label=r'$u$')
    ax.plot(x, Phi_test[idx], lw=3, label=r'$\Phi$')
    ax.legend()
    plt.show()
    plt.close()

In [None]:
plot_pred(0)

In [None]:
plot_pred(1)

In [None]:
plot_pred(2)

In [None]:
plot_pred(3)