In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
from fenics import *
from multiphenics import *
import numpy as np
from braininversion.IOHandling import (read_mesh_from_h5, write_to_xdmf, 
                                       xdmf_to_unstructuredGrid, read_xdmf_timeseries)
from braininversion.PlottingHelper import (plot_pressures_and_forces_timeslice, 
                                           plot_pressures_and_forces_cross_section,
                                           extract_cross_section, style_dict)
from braininversion.SourceExpression import get_source_expression
from matplotlib.backends.backend_agg import FigureCanvasAgg
import yaml
import pyvista as pv
from pathlib import Path
import imageio
from cmocean import cm
import warnings
warnings.filterwarnings("ignore")

In [None]:
try:
    # Inputs provided by snakemake when the notebook runs inside the workflow.
    mesh_file = snakemake.input["subdomain_file"]
    sim_file = snakemake.input["sim_results"]
    sim_config_file = snakemake.input["sim_config_file"]
except NameError:
    # Interactive use: fall back to a default example case.
    mesh = "MRIExampleSegmentation_Ncoarse"
    sim_name = "stdBrainSim"
    mesh_file = f"../meshes/{mesh}/{mesh}.xdmf"
    sim_file = f"../results/{mesh}_{sim_name}/results.xdmf"
    sim_config_file = f"../results/{mesh}_{sim_name}/config.yml"


mesh_name = f"{Path(mesh_file).parent}/{Path(mesh_file).stem}"
mesh_config_file = f"{mesh_name}_config.yml"

boundary_file = f"{mesh_name}_boundaries.xdmf"
label_boundary_file = f"{mesh_name}_label_boundaries.xdmf"
label_file = f"{mesh_name}_labels.xdmf"

# `mesh` / `sim_name` only exist in the fallback branch above, so derive the old
# results file from `sim_file` (same directory). This is identical for the fallback
# case and also works when the inputs come from snakemake.
sim_file_old = f"{Path(sim_file).parent}/results_old.xdmf"


with open(sim_config_file) as conf_file:
    sim_config = yaml.load(conf_file, Loader=yaml.FullLoader)

with open(mesh_config_file) as conf_file:
    mesh_config = yaml.load(conf_file, Loader=yaml.FullLoader)

T = sim_config["T"]
num_steps = sim_config["num_steps"]

mmHg2Pa = 133.322  # 1 mmHg = 133.322 Pa (was mistyped as 132.32)
porous_id = 1
dt = T / num_steps
times = np.linspace(0, T, num_steps + 1)
mesh_grid = pv.read(label_file)

# Read the FEniCS mesh and its subdomain markers (MeshFunction of dimension gdim).
infile_mesh = XDMFFile(mesh_file)
mesh = Mesh()
infile_mesh.read(mesh)
gdim = mesh.geometric_dimension()
subdomain_marker = MeshFunction("size_t", mesh, gdim)
infile_mesh.read(subdomain_marker)
infile_mesh.close()

In [None]:
%%capture
# Domain names (matched against mesh_config["domains"][...]["name"]) that make up
# the ventricular system.
ventricular_system = [
    "lateral_ventricles",
    "foramina",
    "aqueduct",
    "third_ventricle",
    "fourth_ventricle",
    "median_aperture",
]

def scale_grid(grid, fac):
    """Multiply every point-data array of ``grid`` by ``fac`` (modifies ``grid`` in place)."""
    for array_name, _unused in grid.point_arrays.items():
        grid.point_arrays[array_name] *= fac
        
def sum_grids(grid1, grid2):
    """Add each point-data array of ``grid2`` onto the same-named array of ``grid1`` (in place)."""
    addend_arrays = grid2.point_arrays
    for array_name, _unused in grid1.point_arrays.items():
        grid1.point_arrays[array_name] += addend_arrays[array_name]


"""
if np.isclose(idx%1, 0.0):
            data = extract_data(mg, var, parts, int(idx))
        else:
            data = extract_data(mg, var, parts, int(np.floor(idx)))
            data2 = extract_data(mg, var, parts, int(np.ceil(idx)))
            sum_grids(data, data2)
            scale_grid(data, 0.5)
"""

def compute_stat(stat, mg, var, parts, idx):
    """Apply ``stat`` to the values of ``var`` on ``parts`` at time index ``int(idx)``."""
    subset = extract_data(mg, var, parts, int(idx))
    values = subset[var]
    return stat(values)

def compute_glob_stat(stat, mg, var, parts, indices):
    """Apply ``stat`` per time index, then apply ``stat`` again over those per-step results."""
    per_step_values = list(map(lambda step: compute_stat(stat, mg, var, parts, step), indices))
    return stat(per_step_values)

def extract_and_interpolate(mg, var, parts, idx):
    """Return ``var`` on ``parts`` at a possibly fractional time index ``idx``.

    An (almost) integer ``idx`` loads that time step directly. Otherwise the two
    neighbouring steps are blended linearly. The step at ``floor(idx)`` gets weight
    ``ceil(idx) - idx`` and the step at ``ceil(idx)`` gets weight
    ``idx - floor(idx)``, so the closer step dominates. All point arrays of the
    returned grid are blended this way.
    """
    if np.isclose(idx % 1, 0.0):
        return extract_data(mg, var, parts, int(idx))

    lower = int(np.floor(idx))
    upper = int(np.ceil(idx))
    data_lower = extract_data(mg, var, parts, lower)
    data_upper = extract_data(mg, var, parts, upper)

    # Linear interpolation. The weights were previously swapped, which pulled frames
    # toward the farther time step.
    # NOTE(review): every point array is scaled, including label arrays such as
    # "subdomains". The weights sum to 1, so values are preserved up to rounding.
    # Confirm the array dtypes allow in-place float scaling.
    scale_grid(data_lower, upper - idx)
    scale_grid(data_upper, idx - lower)
    sum_grids(data_lower, data_upper)
    return data_lower

def extract_data(mg, var, parts, idx):
    """Load ``var`` at time index ``idx`` onto a copy of ``mg`` and keep only ``parts``.

    Parameters
    ----------
    mg : pyvista grid carrying a "subdomains" array used for thresholding.
        It is copied, so the caller's grid is not modified.
    var : name of the result variable read from ``sim_file_old``.
    parts : collection of domain names, matched against
        ``mesh_config["domains"][...]["name"]``.
    idx : time-step index in the results file.

    Returns
    -------
    The merged grid of all domains listed in ``parts``.

    Raises
    ------
    ValueError
        If none of ``parts`` names a domain in ``mesh_config``.
    """
    mg = mg.copy()
    grid = xdmf_to_unstructuredGrid(sim_file_old, variables=[var], idx=[idx])

    # Attach the freshly read result arrays to the (copied) label mesh.
    for name, values in grid.point_arrays.items():
        mg.point_arrays[name] = values

    # Keep only the requested domains (one threshold per domain id).
    dom_meshes = [
        mg.threshold([dom["id"], dom["id"]], scalars="subdomains")
        for dom in mesh_config["domains"]
        if dom["name"] in parts
    ]
    if not dom_meshes:
        raise ValueError(f"None of the requested parts {list(parts)} matches a domain in the mesh config.")
    return dom_meshes[0].merge(dom_meshes[1:])


def plot_partial_3D(mg, scenes, idx, cpos, interactive=False):
    """Build a pyvista plotter that shows several variables on selected mesh parts.

    Each entry of ``scenes`` is a dict with the keys:
      - "var": variable name, loaded via ``extract_and_interpolate`` at ``idx``
      - "mesh_parts": domain names to show
      - "static" / "interactive": keyword options forwarded to ``add_mesh``
        (chosen by ``interactive``; copied, so the caller's dicts are not modified)
      - optional "clip": positional (list/tuple) or keyword (dict) args for ``clip``
      - optional "warp" (+ "warp_fac"): warp the data by the vector field ``var``
      - optional "arrow" (+ "vec_scale"): draw ``var`` as glyph arrows

    In static mode, non-arrow scenes without an explicit "clim" share one colour
    range spanning the data of all non-arrow scenes. Arrow scenes keep their own
    range instead of inheriting the range of an unrelated scalar field.

    Returns ``(plotter, (min_val, max_val))``. The range covers the non-arrow scenes
    and is ``(inf, -inf)`` if every scene is an arrow scene.
    """
    p = pv.PlotterITK() if interactive else pv.Plotter(notebook=False)

    # First pass: load and clip each scene's data; collect the shared scalar range.
    # Data is stored per scene (not per var) so two scenes may share a variable.
    scene_data = []
    min_val, max_val = np.inf, -np.inf
    for scene in scenes:
        var = scene["var"]
        data = extract_and_interpolate(mg, var, scene["mesh_parts"], idx)
        if "clip" in scene:
            clip_args = scene["clip"]
            if isinstance(clip_args, dict):
                data = data.clip(**clip_args)
            else:
                data = data.clip(*clip_args)
        scene_data.append(data)
        if "arrow" in scene:
            continue
        max_val = max(max_val, data[var].max())
        min_val = min(min_val, data[var].min())

    # Second pass: add every scene to the plotter.
    for scene, data in zip(scenes, scene_data):
        var = scene["var"]
        is_arrow = "arrow" in scene
        if "warp" in scene:
            data = data.warp_by_vector(var, scene["warp_fac"])
        if interactive:
            options = dict(scene["interactive"])
        else:
            options = dict(scene["static"])
            if not is_arrow and options.get("clim") is None:
                options["clim"] = (min_val, max_val)
        if is_arrow:
            arrows = data.glyph(scale=var, factor=scene["vec_scale"], orient=var,
                                clamping=True, rng=[0., 2])
            p.add_mesh(arrows, **options)
        else:
            p.add_mesh(data, scalars=var, **options)

    # camera position, focal point, and view up
    p.camera_position = cpos
    return p, (min_val, max_val)

In [None]:
source_conf = sim_config["source_data"]
source_expr = get_source_expression(source_conf, mesh, subdomain_marker, porous_id, times)
# Evaluate the source expression at the origin for every output time.
source_eval_point = Point([0] * gdim)
source_series = []
for t in times:
    source_expr.t = t
    source_series.append(source_expr(source_eval_point))
source_series

In [None]:
# camera position, focal point and view-up vector for close-up views
cpos_close = [
    (0.15, 0.1, -0.01),
    (-0.0, 0.0, -0.0),
    (0.0, 0.0, 1.0),
]


def plot_source_rgb_raw(source_series, times, t, size, dpi):
    """Render the source term over ``times`` with a red marker at ``t`` and return its pixels.

    The figure (``size`` in inches at ``dpi``) is drawn on an Agg canvas. The RGBA
    buffer is returned without its last channel, as a ``(rows, cols, 3)`` array.
    """
    figure = plt.Figure(figsize=size, dpi=dpi)
    agg_canvas = FigureCanvasAgg(figure)

    axis = figure.add_subplot(111)
    axis.plot(times, source_series)
    axis.axvline(t, color="red")
    axis.set_xlabel("t in s")
    axis.set_ylabel("g in 1/s")

    agg_canvas.draw()
    rgba_pixels = np.asarray(agg_canvas.buffer_rgba())
    return rgba_pixels[:, :, :3]


# create video
def create_movie(path, times, plot_generator, fps=10, interpolate_frames=1,
                 source_values=None, inset_size=(4, 3), inset_dpi=70):
    """Render ``plot_generator`` over ``times`` into an animation written to ``path``.

    Every time step except the last contributes ``interpolate_frames`` frames, rendered
    at the fractional indices ``i + k / interpolate_frames``. The last step
    contributes a single frame. An inset showing the source term, with a marker at
    the frame's time, is pasted into the top-left corner of each frame.

    Parameters
    ----------
    path : output file, passed to ``imageio.get_writer``.
    times : output times, one per time step.
    plot_generator : callable ``(idx, interactive) -> (plotter, _)``.
    fps : frames per second of the animation.
    interpolate_frames : frames per time step (>= 1).
    source_values : series drawn in the inset; defaults to the notebook-global
        ``source_series``.
    inset_size, inset_dpi : figure size (inches) and dpi of the inset.

    Raises
    ------
    ValueError
        If ``interpolate_frames`` < 1.
    """
    if interpolate_frames < 1:
        raise ValueError("interpolate_frames must be at least 1")
    if source_values is None:
        source_values = source_series

    last = len(times) - 1
    # Stream frames straight to the writer rather than keeping every screenshot in
    # memory; the context manager closes the writer even if rendering fails.
    with imageio.get_writer(path, fps=fps) as writer:
        for i in range(len(times)):
            # The last step has no successor to interpolate towards.
            sub_steps = 1 if i == last else interpolate_frames
            for k in range(sub_steps):
                frac = k / interpolate_frames
                p, _ = plot_generator(i + frac, False)
                p.show()
                img = p.screenshot(transparent_background=True, return_img=True, window_size=None)
                p.close()

                # Frame time: linear between times[i] and times[i + 1].
                t = times[i] if k == 0 else times[i] + frac * (times[i + 1] - times[i])
                inset = plot_source_rgb_raw(source_values, times, t, inset_size, inset_dpi)
                rows, cols, channels = inset.shape
                img[:rows, :cols, :channels] = inset
                if img.shape[2] > channels:
                    # The screenshot has an extra (alpha) channel: make the inset
                    # opaque (assumes an 8-bit image).
                    img[:rows, :cols, channels:] = 255
                writer.append_data(img)

In [None]:
p, _ = plot_partial_3D(
    mesh_grid,
    [
        {
            "var": "u",
            "mesh_parts": ventricular_system,
            "vec_scale": 3,
            "arrow": True,
            "interactive": {"color": "red", "opacity": 1},
            "static": {"color": "red", "opacity": 1},
        },
        {
            "var": "pF",
            "mesh_parts": ventricular_system,
            "interactive": {"opacity": 0.3},
            "static": {"color": "red", "opacity": 1},
        },
    ],
    5,
    cpos_close,
    interactive=True,
)
p.show()

In [None]:
p, _ = plot_partial_3D(
    mesh_grid,
    [
        {
            "var": "d",
            "mesh_parts": ["parenchyma"],
            "vec_scale": 20,
            "clip": [(1, 0, 0)],
            "interactive": {"color": "red", "opacity": 1},
            "static": {"color": "red", "opacity": 1},
        },
        {
            "var": "pF",
            "mesh_parts": ventricular_system,
            "interactive": {"opacity": 0.3},
            "static": {"color": "red", "opacity": 1},
        },
    ],
    5,
    cpos_close,
    interactive=True,
)
p.show()

In [None]:
cpos_far = [(0.3, 0.2, -0.05), (-0.0, 0.0, -0.02), (0.0, 0.0, 1.0)]
# Shared scalar-bar styling; the u and pF bars differ only in position_x.
sargs = dict(
    title_font_size=20,
    label_font_size=16,
    shadow=True,
    n_labels=3,
    italic=True,
    font_family="arial",
    height=0.4,
    vertical=True,
    position_y=0.05,
)
sargs_u = {"position_x": 0.95, **sargs}
sargs_pF = {"position_x": 0.05, **sargs}
p_clim = [-20, 20]


def func(idx, interactive):
    """Plot CSF velocity arrows (u, clipped with normal (1, 0, 0)) and ventricular pressure pF."""
    velocity_scene = {
        "var": "u",
        "mesh_parts": ["csf"],
        "vec_scale": 1,
        "clip": [(1, 0, 0)],
        "arrow": True,
        "interactive": {"color": "blue", "opacity": 1},
        "static": {"cmap": cm.speed, "opacity": 1, "stitle": " mag(u)",
                   "scalar_bar_args": sargs_u},
    }
    pressure_scene = {
        "var": "pF",
        "mesh_parts": ventricular_system,
        "interactive": {"color": "white", "opacity": 0.5},
        "static": {"cmap": "coolwarm", "opacity": 0.8,
                   "scalar_bar_args": sargs_pF, "clim": p_clim},
    }
    return plot_partial_3D(mesh_grid, [velocity_scene, pressure_scene], idx, cpos_far,
                           interactive=interactive)


p, _ = func(2, False)
p.show()


In [None]:
%%capture
# Make sure the output directory exists so the cell also runs on a fresh checkout.
Path("plots").mkdir(parents=True, exist_ok=True)
create_movie("plots/movie.gif", times, func, fps=10, interpolate_frames=5)

In [None]:
def func(idx, interactive):
    """Plot parenchyma displacement d, clipped with normal (1, 0, 0) and warped by d (factor 1e1)."""
    displacement_scene = {
        "var": "d",
        "mesh_parts": ["parenchyma"],
        "warp": True,
        "clip": [(1, 0, 0)],
        "warp_fac": 1e1,
        "vec_scale": 1e1,
        "interactive": {"color": "blue", "opacity": 1},
        "static": {"cmap": cm.speed, "opacity": 1, "stitle": " mag(d)",
                   "scalar_bar_args": sargs_u},
    }
    return plot_partial_3D(mesh_grid, [displacement_scene], idx, cpos_far,
                           interactive=interactive)


p, _ = func(10, False)
p.show()

In [None]:
%%capture
#create_movie("plots/test_warp.mp4", times, func, fps=2, interpolate_frames=1)

In [None]:
sargs_phi = {**sargs_pF, "position_x": 0.95}


def func(idx, clip, cpos, clim):
    """Static clipped view of phi and d (parenchyma) plus pF and u (ventricles + CSF).

    ``clip`` is forwarded as each scene's "clip" entry, and ``clim`` as the colour
    limits of the phi and pF scenes.
    """
    fluid_parts = ventricular_system + ["csf"]
    scenes = [
        {"var": "phi", "mesh_parts": ["parenchyma"], "clip": clip,
         "static": {"cmap": "coolwarm", "scalar_bar_args": sargs_phi, "clim": clim}},
        {"var": "d", "mesh_parts": ["parenchyma"], "clip": clip, "arrow": True, "vec_scale": 20,
         "static": {"color": "green", "scalar_bar_args": sargs_phi}},
        {"var": "pF", "mesh_parts": fluid_parts, "clip": clip,
         "static": {"cmap": "coolwarm", "scalar_bar_args": sargs_pF, "clim": clim}},
        {"var": "u", "mesh_parts": fluid_parts, "clip": clip, "arrow": True, "vec_scale": 2,
         "static": {"color": "red", "scalar_bar_args": sargs_pF}},
    ]
    return plot_partial_3D(mesh_grid, scenes, idx, cpos, interactive=False)

In [None]:
i = 1  # which clip/camera pair (index into `clips` and `cpos`) to preview
indices = [5, 10]  # time-step indices shown as rows of the overview figure

origin = (0, 0, 0.001)
clips = ["y", "x", "z"]
dist = 0.4
# One camera (position, focal point, view-up) per entry of `clips`.
cpos = [
    [(0, dist, 0), (0, 0, 0), (0, 0, 1)],
    [(dist, 0, 0), (0, 0, 0), (0, 0, 1)],
    [(0, 0, dist * 1.3), (0, 0, 0), (0, 1, 0)],
]
nind = len(indices)
size = 20  # per-panel size unit for the overview figure's figsize
idx = 10
clim = None
p, _ = func(idx, dict(normal=clips[i], origin=origin), cpos[i], clim)
p.show()

In [None]:
%%capture

# Extrema of phi and pF over all time steps (min_* previously used `max` by mistake).
# NOTE(review): these values are not used for the colour limits below.
fluid_parts = ventricular_system + ["csf"]
max_phi = compute_glob_stat(max, mesh_grid, "phi", ["parenchyma"], range(num_steps))
max_pF = compute_glob_stat(max, mesh_grid, "pF", fluid_parts, range(num_steps))
min_phi = compute_glob_stat(min, mesh_grid, "phi", ["parenchyma"], range(num_steps))
min_pF = compute_glob_stat(min, mesh_grid, "pF", fluid_parts, range(num_steps))

# squeeze=False keeps `axes` 2-D even for a single row, so axes[j, i] always works.
fig, axes = plt.subplots(nind, len(clips), figsize=(len(clips) * size, size * nind),
                         squeeze=False)
for j, idx in enumerate(indices):
    # The colour range of the first clip in each row is reused for the row's other clips.
    clim = None
    for i, c in enumerate(clips):
        p, clim1 = func(idx, {"normal": c, "origin": origin}, cpos[i], clim)
        if i == 0:
            clim = clim1
        p.show()
        img = p.screenshot(transparent_background=True, return_img=True, window_size=None)
        axes[j, i].imshow(img)
        axes[j, i].set_title(f"t = {times[idx]}")

Path("plots").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/pressure_evolution.pdf")