In [None]:
#default_exp solution

In [None]:
#exporti
import numpy as np

from dl4to.plotting import plot_scalar_field, pyvista_plot_scalar_field

## Code

In [None]:
#exporti
def add_corners(solution):
    """Force the density to 1 at a fixed set of voxels, in place, and return `solution`.

    For every index of the first axis of ``solution._θ``, the value is set to ``1.``
    at 15 voxels of the trailing three axes:
    - the eight grid corners (every combination of index 0 and -1), and
    - the remaining points of the {0, 2}^3 index lattice.

    In this notebook it is only invoked when plotting with pyvista and
    ``smooth_iters > 0``; presumably it pins the plotted extent before smoothing
    (NOTE(review): confirm intent).

    Raises IndexError if any trailing axis is shorter than 3 (index 2 must exist).
    """
    voxel_indices = (
        # The eight corners of the grid.
        (0, 0, 0), (0, 0, -1), (0, -1, 0), (-1, 0, 0),
        (-1, -1, -1), (-1, -1, 0), (0, -1, -1), (-1, 0, -1),
        # Points of the {0, 2}^3 lattice ((0, 0, 0) is already covered above).
        (0, 0, 2), (0, 2, 0), (2, 0, 0),
        (2, 2, 2), (2, 2, 0), (0, 2, 2), (2, 0, 2),
    )
    for ix, iy, iz in voxel_indices:
        solution._θ[:, ix, iy, iz] = 1.
    return solution

In [None]:
#exporti
class PlottingForSolution:
    """Callable that plots a solution's density and, optionally, its PDE fields.

    ``__call__`` is a ``staticmethod``, so an instance is used like a plain
    function: ``PlottingForSolution()(solution, ...)``. No instance state is used.
    """

    @staticmethod
    def __call__(
        solution, 
        binary=False,
        solve_pde=True,
        normalize_σ_vm=True,
        threshold=0.,
        display=True, 
        file_path=None, 
        camera_position=(0,.1,.12), 
        show_design_space=False,
        use_pyvista=False,
        window_size=(800,800),
        smooth_iters=0,
        show_colorbar=True,
        show_axislabels=False, 
        show_ticklabels=False, 
        export_png=False,
    ):
        """Plot the density θ of `solution` and, if requested, |u| and σ_vm.

        One plotting call is made per panel:
        1. the density distribution θ (always);
        2. the norm of the displacement field |u| (if the PDE is solved);
        3. the von Mises stress σ_vm, optionally divided by
           ``solution.problem.σ_ys`` (if the PDE is solved).

        Parameters
        ----------
        solution : object
            Must expose ``get_θ(binary=...)``, ``problem``, ``pde_solver`` and
            ``solve_pde(p=..., binary=...)``. The returned tensors must support
            ``.cpu().detach().numpy()``.
        binary : bool
            Forwarded to ``solution.get_θ`` and ``solution.solve_pde``.
        solve_pde : bool
            Whether to also plot displacements and stresses. If
            ``solution.pde_solver`` is None, a message is printed and only the
            density is plotted.
        normalize_σ_vm : bool
            Plot ``σ_vm / solution.problem.σ_ys`` instead of raw ``σ_vm``.
        threshold : float
            Forwarded to the plotting function.
        file_path : str or None
            matplotlib backend only. Each panel is written to
            ``f"{file_path}_{suffix}"``, with suffix ``'density'``,
            ``'displacement'`` or ``'stress'``. Ignored when ``use_pyvista`` is True.
        display, camera_position, show_design_space
            Forwarded to either backend, together with ``solution.problem``.
        use_pyvista : bool
            Use ``pyvista_plot_scalar_field`` instead of ``plot_scalar_field``.
        window_size, smooth_iters
            pyvista backend only.
        show_colorbar, show_axislabels, show_ticklabels, export_png
            matplotlib backend only.

        Notes
        -----
        With ``use_pyvista`` and ``smooth_iters > 0``, ``add_corners`` modifies
        `solution` in place before θ is read. The change persists after plotting
        and is also seen by ``solution.solve_pde``.
        """
        if use_pyvista and smooth_iters > 0:
            solution = add_corners(solution)
        θ = solution.get_θ(binary=binary).cpu().detach().numpy()

        # Keyword arguments understood by both plotting backends.
        plotting_kwargs = {'problem': solution.problem,
                           'display': display,
                           'camera_position': camera_position,
                           'show_design_space': show_design_space}
        if use_pyvista:
            plotting_kwargs.update(window_size=window_size,
                                   smooth_iters=smooth_iters)
        else:
            plotting_kwargs.update(show_colorbar=show_colorbar,
                                   show_axislabels=show_axislabels,
                                   show_ticklabels=show_ticklabels,
                                   export_png=export_png)

        # Each panel is (data passed alongside θ, title, file-name suffix).
        # The density panel passes data=None.
        panels = [(None, 'Density distribution θ', 'density')]

        if solve_pde and solution.pde_solver is None:
            print("Cannot plot PDE solution because no PDE solver is attached to solution.problem.")
            solve_pde = False

        if solve_pde:
            u, _, σ_vm = solution.solve_pde(p=1., binary=binary)
            # Euclidean norm over the first axis (the displacement components).
            u_norm = np.linalg.norm(u.cpu().detach().numpy(), axis=0)
            panels.append((u_norm, 'Normed displacements |u|', 'displacement'))
            σ_vm_ = σ_vm.cpu().detach().numpy()
            if normalize_σ_vm:
                panels.append((σ_vm_ / solution.problem.σ_ys,
                               'Normalized von Mises stresses σ_vm/σ_ys',
                               'stress'))
            else:
                panels.append((σ_vm_, 'Von Mises stresses σ_vm', 'stress'))

        for data, title, suffix in panels:
            if use_pyvista:
                # The pyvista backend receives no file path.
                pyvista_plot_scalar_field(
                    scalar_field=θ,
                    data=data,
                    threshold=threshold,
                    title=title,
                    **plotting_kwargs
                )
            else:
                file_path_ = None if file_path is None else f"{file_path}_{suffix}"
                plot_scalar_field(
                    scalar_field=θ,
                    data=data,
                    threshold=threshold,
                    title=title,
                    file_path=file_path_,
                    **plotting_kwargs
                )