In [None]:
from fenics import *
import matplotlib.pyplot as plt
import numpy as np
from fenics_adjoint import *
import moola
from mshr import *
from braininversion.DarcySolver import solve_darcy
from braininversion.Optimization import optimize_biot_force
from braininversion.PlottingHelper import (plot_pressures_and_forces_timeslice, 
                            plot_pressures_and_forces_cross_section,
                            extract_cross_section, style_dict)

# time stepping
T = 1.2          # final time (s)
num_steps = 12   # number of time steps
dt = T/num_steps
# stored time levels are dt, 2*dt, ..., T (t = 0 itself is not included)
times = np.linspace(dt, T, num_steps)

# material parameters
# NOTE(review): 15*(1e-9)**2 evaluates to 1.5e-17, not 1e-17 -- confirm which permeability is intended
kappa = 1e-17       # permeability 15*(1e-9)**2
visc = 0.8*1e-3     # viscosity
K = kappa/visc      # hydraulic conductivity
c = 2*1e-8          # storage coefficient
alpha = 1.0         # Biot-Willis coefficient

# Biot material parameters
E = 1500.0          # Young modulus (presumably in Pa)
nu = 0.479          # Poisson ratio

# parameter set handed to the Biot solver; lmbda and mu are the Lamé parameters derived from E and nu
material_parameter = {
    "c": c,
    "K": K,
    "lmbda": nu*E/((1.0 - 2.0*nu)*(1.0 + nu)),
    "mu": E/(2.0*(1.0 + nu)),
    "alpha": alpha,
}
# conversion factor from mmHg to Pa (1 mmHg = 133.322 Pa)
mmHg2Pa = 133.322

# create mesh: disc of radius brain_radius minus a concentric ventricle disc, plus a rectangular stem
N = 12  # resolution passed to generate_mesh
brain_radius = 0.1
ventricle_radius = brain_radius/3
stem_length = brain_radius*1.4
brain_disc = Circle(Point(0, 0), brain_radius)
ventricle_disc = Circle(Point(0, 0), ventricle_radius)
stem_rect = Rectangle(Point(-brain_radius/4, -stem_length),
                      Point(brain_radius/4, -ventricle_radius))
brain = brain_disc - ventricle_disc + stem_rect
mesh = Mesh(generate_mesh(brain, N))

# mark boundaries: 1 = skull, 2 = ventricle, 3 = stem.
# The 0.95*R^2 radius threshold separates the inner ventricle wall (radius R/3) from the outer skull (radius R);
# stem facets lie below y = -R (the stem reaches down to y = -1.4*R).
ventricle = CompiledSubDomain("on_boundary && (x[0]*x[0] + x[1]*x[1] < R*R*0.95)",
                              R=brain_radius)
skull = CompiledSubDomain("on_boundary && (x[0]*x[0] + x[1]*x[1] >= R*R*0.95 )",
                          R=brain_radius)
stem = CompiledSubDomain("on_boundary && x[1] < -R", R=brain_radius)

boundary_marker = MeshFunction("size_t", mesh, mesh.topology().dim()-1, value=0)
# stem is marked last so that its facets (also outside 0.95*R^2) end up with marker 3, not 1
skull.mark(boundary_marker, 1)
ventricle.mark(boundary_marker, 2)
stem.mark(boundary_marker, 3)

# sample points along the positive x-axis, from the ventricle wall to the skull
x_coords = np.linspace(ventricle_radius, brain_radius, 20)
slice_points = [Point(x, 0.0) for x in x_coords]

# oscillating pressure of amplitude 2 mmHg (in Pa) and frequency f (1/s);
# used below as boundary traction and as the minimization target
A = 2*mmHg2Pa
f = 1
# NOTE(review): t is initialised to 0 and never updated in this notebook -- presumably optimize_biot_force advances it
p_obs = Expression("A*sin(2*pi*f*t)", A=A, f=f, t=0, degree=2)


In [None]:
# Boundary conditions (markers: 1 = skull, 2 = ventricle, 3 = stem)
n = FacetNormal(mesh)
# normal traction n*p_obs, shared by the skull and ventricle walls
pressure_traction = n*p_obs

boundary_conditions_u = {1: {"Neumann": pressure_traction},       # skull
                         2: {"Neumann": pressure_traction},       # ventricle
                         3: {"Dirichlet": Constant((0.0, 0.0))}}  # stem: zero displacement

# homogeneous Neumann condition for the pressure on every marked boundary
boundary_conditions_p = {marker: {"Neumann": Constant(0.0)} for marker in (1, 2, 3)}

# objective: squared misfit to p_obs integrated over the whole domain
# NOTE(review): presumably x is the fluid pressure -- confirm in braininversion.Optimization
minimization_target = {"dx": {"everywhere": lambda x: (x - p_obs)**2}}

res = optimize_biot_force(mesh, material_parameter, times, minimization_target,
                          boundary_marker, boundary_conditions_p,
                          boundary_marker, boundary_conditions_u,
                          opt_solver="moola_bfgs")
opt_ctrl, opt_solution, initial_solution = res

In [None]:
# Each mixed solution splits into (displacement u, total pressure pT, fluid pressure p).
# Split every solution once and pick the components, instead of calling split() per component.
_opt_parts = [s.split() for s in opt_solution]
_init_parts = [s.split() for s in initial_solution]

u_opt = [parts[0] for parts in _opt_parts]
pT_opt = [parts[1] for parts in _opt_parts]
p_opt = [parts[2] for parts in _opt_parts]

u_init = [parts[0] for parts in _init_parts]
p_init = [parts[2] for parts in _init_parts]

In [None]:
def extract_spatial_total(solution):
    """Integrate every function in `solution` over its domain.

    Parameters
    ----------
    solution : iterable of functions (one per time step)

    Returns
    -------
    numpy.ndarray with one assembled integral per entry.
    """
    totals = []
    for func in solution:
        totals.append(assemble(func*dx))
    return np.array(totals)

def extract_displaced_volume(displacement):
    """Return the boundary integral of u·n for each displacement field.

    The integral runs over the whole boundary of the mesh belonging to the
    first field; all entries are assumed to live on that same mesh.

    Parameters
    ----------
    displacement : sequence of displacement functions (one per time step)

    Returns
    -------
    numpy.ndarray with one value per entry; an empty array for empty input.
    """
    if len(displacement) == 0:
        return np.array([])
    mesh = displacement[0].function_space().mesh()
    n = FacetNormal(mesh)
    ds = Measure("ds", domain=mesh)
    return np.array([assemble(inner(u, n)*ds) for u in displacement])

def extract_total_outflow(solution, K):
    """Return the total Darcy outflow through the boundary for each pressure field.

    The Darcy flux is q = -K*grad(p); the outflow is the integral of q·n over
    the whole boundary of the mesh belonging to the first field.

    Parameters
    ----------
    solution : sequence of pressure functions (one per time step)
    K : hydraulic conductivity (scalar)

    Returns
    -------
    numpy.ndarray with one value per entry (positive = net outflow);
    an empty array for empty input.
    """
    if len(solution) == 0:
        return np.array([])
    mesh = solution[0].function_space().mesh()
    n = FacetNormal(mesh)
    ds = Measure("ds", domain=mesh)
    # fluid moves down the pressure gradient, hence the minus sign in the flux
    return np.array([assemble(inner(-K*grad(p), n)*ds) for p in solution])

# Boundary-integrated normal displacement, one value per stored time step
total_displaced = extract_displaced_volume(u_opt)
# Change per time step (not per second); the displaced volume before the first step is taken as zero
total_diff_displaced = np.diff(np.concatenate(([0.0], total_displaced)))

# Disabled diagnostics:
#total_source = extract_spatial_total(opt_ctrl)
#total_outlow = extract_total_outflow(p_opt, K)
#dx = Measure("dx", domain=mesh)
#total_area = assemble(Constant(1.0)*dx)
#mean_source = total_source/total_area

In [None]:
fig, ax = plt.subplots(figsize=(9, 7))
#ax.plot(times, total_source, ".-", label="total mass source")
ax.plot(times, total_displaced, ".-", label="total displaced volume")
ax.plot(times, total_diff_displaced, ".-", label="change of total displaced volume")
#ax.plot(times, total_outlow, ".-", label="total outflow")
ax.legend()
ax.grid()
ax.set_title("Displaced volume over time")
ax.set_xlabel("t in s")
# both curves are 2-D volumes (areas): the total and its change per time step, neither is a rate
ax.set_ylabel("volume in m^2");

In [None]:
# Reuse the components already extracted from opt_solution in the earlier cell
displ = u_opt             # displacement u
total_pressure = pT_opt   # total pressure pT
pressure = p_opt          # fluid pressure p

# lambda*div(u), projected onto continuous piecewise-linear functions so it can be sampled pointwise
V = FunctionSpace(mesh, "CG", 1)
lmbda_div_u = [project(material_parameter["lmbda"]*div(u), V) for u in displ]

In [None]:
# Pressure-like quantities sampled along the cross-section, converted from Pa to mmHg.
# NOTE(review): the key says "negative" but the values are total_pressure unchanged -- confirm the sign convention
pressure_fields = {
    "negative_total_pressure": total_pressure,
    "fluid_pressure": pressure,
    "lambda_div_u": lmbda_div_u,
}
pressures = {key: extract_cross_section(fields, slice_points)/mmHg2Pa
             for key, fields in pressure_fields.items()}

# Displacement and optimized control sampled along the same cross-section (no unit conversion)
forces = {
    "displacement [m]": extract_cross_section(displ, slice_points),
    "f_opt [N]": extract_cross_section(opt_ctrl, slice_points),
}

# Line styles for the curves above, registered in the shared plotting style table
style_dict.update({
    "negative_total_pressure": {"ls": ":", "lw": 3, "color": "firebrick"},
    "fluid_pressure": {"ls": ":", "lw": 3, "color": "orange"},
    "displacement [m]": {"ls": "-.", "lw": 3, "color": "green"},
    "f_opt [N]": {"ls": "-.", "lw": 3, "color": "blue"},
})


In [None]:
# Profiles along the cross-section at a few selected time steps
for step in (2, 4, 6, 8):
    plot_pressures_and_forces_cross_section(pressures, forces, step, x_coords)
    plt.suptitle(f"t = {times[step]:.3f} s")

In [None]:
# Time histories for a few selected indices
# NOTE(review): presumably the index selects a point in x_coords here (not a time step) -- confirm in PlottingHelper
for sample_idx in (2, 4, 6, 8):
    plot_pressures_and_forces_timeslice(pressures, forces, sample_idx, times)

In [None]:
# One figure per stored time step: displacement, total pressure and fluid pressure.
# Both pressure panels use the fixed colour range [-A, A] so snapshots share one scale.
# The colorbar handle is named `im`, not `c`, so the storage coefficient `c` is not overwritten.
for t, s in zip(times, opt_solution):
    u, pT, p = s.split()
    fig, axes = plt.subplots(1, 3, figsize=(15, 7))
    fig.suptitle(f"t = {t:.3f} s")
    panels = [(u, {}, "displacement u"),
              (pT, {"vmax": A, "vmin": -A}, "total pressure pT"),
              (p, {"vmax": A, "vmin": -A}, "fluid pressure p")]
    for ax, (field, plot_kwargs, panel_title) in zip(axes, panels):
        plt.sca(ax)  # dolfin's plot() draws into the current axes
        im = plot(field, **plot_kwargs)
        fig.colorbar(im, ax=ax)
        ax.set_title(panel_title)
    plt.show()  # render now so figures do not accumulate across iterations