In [None]:
# --.. ..- .-.. .-.. --- --.. ..- .-.. .-.. --- --.. ..- .-.. .-.. ---
# Z3ST: An open-source FEniCSx framework for thermo-mechanical analysis
# Author: Giovanni Zullo
# Version: 0.1.0 (2025)
# --.. ..- .-.. .-.. --- --.. ..- .-.. .-.. --- --.. ..- .-.. .-.. ---

In [27]:
import pyvista as pv
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import yaml

from analytical_cases import *

# Interactive widgets
import ipywidgets as widgets
from ipywidgets import interact, Layout, interactive
from IPython.display import display

# Silence NumPy "divide" and "invalid" floating-point warnings globally, so that
# evaluating expressions at r = 0 (e.g. the analytical ε_θθ = u_r / r below) does
# not flood the output. np.seterr returns the *previous* error settings, which
# Jupyter echoes as this cell's output.
np.seterr(divide="ignore", invalid="ignore")

{'divide': 'ignore', 'over': 'warn', 'under': 'ignore', 'invalid': 'ignore'}

In [28]:
def interactive_analysis_plot(filename="output/fields.vtu", analytical_case=None, tol=1e-5):
    """
    A generalized interactive plot tool for visualizing scalar, vector, and tensor fields
    (Temperature, Displacement, Heat Flux, Stress, Strain) from a .vtu file.

    A top-level dropdown selects the field (only fields present in the file are
    offered). For that field, further dropdowns select the data source (point
    data, cell data, or both), the plot mode and the z-layer to slice:

    - Cylindrical: components are transformed to (r, θ) and averaged over rings
      of equal rounded radius, then plotted against r.
    - Rectangular: Cartesian components are taken along the line y ≈ 0 and
      plotted against x.

    Parameters
    ----------
    filename : str
        Path to the .vtu file, read with ``pyvista.read``.
    analytical_case : bool or None
        When truthy, overlay the reference curves listed under
        ``"analytical_funcs"`` for the selected field (Stress, Strain,
        Displacement). They use hard-coded parameters
        (a=0.02, b=0.03, pi=1e6, po=1e7) and are always plotted against r.
    tol : float
        Absolute tolerance used to pick nodes / cell centres on a z-layer and,
        in Rectangular mode, on the y ≈ 0 line.

    Returns
    -------
    None
        Widgets are displayed through ``ipywidgets.interact``. If the file is
        missing or holds no recognised field, a message is printed instead.
    """
    print(f"Loading data from: {filename}")
    try:
        grid = pv.read(filename)
    except FileNotFoundError:
        print(f"[ERROR] File not found: {filename}")
        return

    # Parameters of the reference case shared by every analytical curve below.
    lame_kw = dict(a=0.02, b=0.03, pi=1e6, po=1e7)

    # --- Master Configuration Dictionary ---
    # Defines how to handle each field type
    CONFIG = {
        "Stress": {
            "field_name_point": "Stress (points)",
            "field_name_cell": "Stress (cells)",
            "nature": "tensor",
            "var_prefix": "sig",
            "plot_unit": "Pa",
            "plot_symbol": "σ",
            "analytical_funcs": {
                "σ_rr (analytical)": lambda r: lame_solutions(r, **lame_kw)[0],
                "σ_θθ (analytical)": lambda r: lame_solutions(r, **lame_kw)[1],
                "σ_zz (analytical, plane strain)": lambda r: lame_solutions(r, **lame_kw)[2],
                "σ_zz (analytical, plane stress)": lambda r: lame_solutions(r, **lame_kw)[3],
            },
        },
        "Strain": {
            "field_name_point": "Strain (points)",
            "field_name_cell": "Strain (cells)",
            "nature": "tensor",
            "var_prefix": "eps",
            "plot_unit": "-",
            "plot_symbol": "ε",
            "analytical_funcs": {
                "ε_rr (analytical)": lambda r: radial_strain(r, **lame_kw),
                # ε_θθ = u_r / r (r = 0 yields inf/nan; warnings are silenced by np.seterr).
                # NOTE(review): unlike "u_r (analytical)" below, nu and E are not passed
                # here, so this relies on radial_displacement's defaults — confirm they
                # match nu=0.3, E=2e11.
                "ε_θθ (analytical)": lambda r: radial_displacement(r, **lame_kw) / r,
            },
        },
        "Displacement": {
            "field_name_point": "Displacement",
            "field_name_cell": None,  # Points only
            "nature": "vector",
            "var_prefix": "u",
            "plot_unit": "m",
            "plot_symbol": "u",
            "analytical_funcs": {
                "u_r (analytical)": lambda r: radial_displacement(r, **lame_kw, nu=0.3, E=2e11),
            },
        },
        "Heat flux": {
            "field_name_point": None,
            "field_name_cell": "Heat Flux",  # Cells only
            "nature": "vector",
            "var_prefix": "q",
            "plot_unit": "W/m^2",
            "plot_symbol": "q",
        },
        "Temperature": {
            "field_name_point": "Temperature",
            "field_name_cell": None,  # Points only
            "nature": "scalar",
            "var_prefix": "T",
            "plot_unit": "K",
            "plot_symbol": "T",
        },
    }

    # Filter available analysis types based on what exists in the file
    available_analysis_types = [
        name
        for name, cfg in CONFIG.items()
        if (cfg["field_name_point"] and cfg["field_name_point"] in grid.point_data)
        or (cfg["field_name_cell"] and cfg["field_name_cell"] in grid.cell_data)
    ]

    if not available_analysis_types:
        print(
            "[ERROR] No recognized fields (Stress, Strain, Temperature, Heat flux, Displacement) found in file."
        )
        return

    def setup_analysis(analysis_type):
        """Load the selected field and display the per-field plotting widgets."""
        cfg = CONFIG[analysis_type]
        prefix = cfg["var_prefix"]
        nature = cfg["nature"]

        # 1. Check availability and define display options
        has_points = cfg["field_name_point"] and cfg["field_name_point"] in grid.point_data
        has_cells = cfg["field_name_cell"] and cfg["field_name_cell"] in grid.cell_data

        display_options = []
        if has_points and has_cells:
            display_options.append("Compare points and cells")
        if has_points:
            display_options.append("Points")
        if has_cells:
            display_options.append("Cells")

        if not display_options:
            print(f"[ERROR] Data configuration mismatch for {analysis_type}.")
            return

        def _reshape(raw):
            """Give a raw VTK array the per-entity shape implied by the field nature."""
            if nature == "tensor":
                return raw.reshape((-1, 3, 3))
            if nature == "vector":
                return raw.reshape((-1, 3))
            return raw  # Scalar

        # 2. Load and reshape data
        point_coords, data_points, unique_point_z = None, None, np.array([])
        if has_points:
            point_coords = grid.points
            data_points = _reshape(grid.point_data[cfg["field_name_point"]])
            unique_point_z = np.unique(np.round(point_coords[:, 2], 6))

        cell_coords, data_cells, unique_cell_z = None, None, np.array([])
        if has_cells:
            cell_coords = grid.cell_centers().points
            data_cells = _reshape(grid.cell_data[cfg["field_name_cell"]])
            unique_cell_z = np.unique(np.round(cell_coords[:, 2], 6))

        # With cell data, the selectable z-levels are cell-centre levels.
        slice_z_levels = unique_cell_z if has_cells else unique_point_z
        if slice_z_levels.size == 0:
            print(f"[ERROR] No {analysis_type} sample locations found in {filename}.")
            return

        print(f"\n--- {analysis_type} analysis initialized ({nature.capitalize()}) ---")

        def _process_slice(coords, data_full, z_level, plot_mode, data_type_source):
            """Tabulate the samples lying on z_level (within tol) for plotting.

            Cylindrical: adds r/θ components and averages rows per rounded radius
            (index = r_group). Rectangular: keeps the row(s) closest to y = 0,
            sorted by x, with a fresh positional index.
            """
            mask = np.abs(coords[:, 2] - z_level) < tol
            if not np.any(mask):
                return pd.DataFrame()

            sliced_data = data_full[mask]
            df = pd.DataFrame(
                {
                    "x": coords[mask, 0],
                    "y": coords[mask, 1],
                    "r": np.sqrt(coords[mask, 0] ** 2 + coords[mask, 1] ** 2),
                }
            )

            theta = np.arctan2(df["y"], df["x"])
            c, s = np.cos(theta), np.sin(theta)

            if nature == "scalar":
                df[f"{prefix}"] = sliced_data

            elif nature == "vector":
                df[f"{prefix}_x"] = sliced_data[:, 0]
                df[f"{prefix}_y"] = sliced_data[:, 1]
                df[f"{prefix}_z"] = sliced_data[:, 2]
                if plot_mode == "Cylindrical":
                    # Vector transformation to cylindrical (Ur = Ux*cos + Uy*sin)
                    df[f"{prefix}_r"] = df[f"{prefix}_x"] * c + df[f"{prefix}_y"] * s
                    df[f"{prefix}_t"] = -df[f"{prefix}_x"] * s + df[f"{prefix}_y"] * c

            elif nature == "tensor":
                df[f"{prefix}_xx"] = sliced_data[:, 0, 0]
                df[f"{prefix}_yy"] = sliced_data[:, 1, 1]
                df[f"{prefix}_xy"] = sliced_data[:, 0, 1]
                df[f"{prefix}_zz"] = sliced_data[:, 2, 2]

                if plot_mode == "Cylindrical":
                    # In-plane tensor rotation to the (r, θ) frame.
                    c2, s2, sc = c**2, s**2, s * c
                    df[f"{prefix}_rr"] = (
                        df[f"{prefix}_xx"] * c2
                        + df[f"{prefix}_yy"] * s2
                        + 2 * df[f"{prefix}_xy"] * sc
                    )
                    df[f"{prefix}_tt"] = (
                        df[f"{prefix}_xx"] * s2
                        + df[f"{prefix}_yy"] * c2
                        - 2 * df[f"{prefix}_xy"] * sc
                    )

            # --- Grouping/Slicing for Plotting ---
            if plot_mode == "Cylindrical":
                df["r_group"] = df["r"].round(4)
                return df.groupby("r_group").mean()
            elif plot_mode == "Rectangular":
                df_line = (
                    df[np.abs(df["y"]) < tol].sort_values(by="x")
                    if data_type_source == "point"
                    else df[df["y"].abs() < df["y"].abs().min() + tol].sort_values(by="x")
                )
                return df_line.reset_index(drop=True)
            return pd.DataFrame()

        def _average_layers(df_below, df_above, plot_mode):
            """Average two processed point slices location-by-location."""
            stacked = pd.concat([df_below, df_above])
            if plot_mode == "Cylindrical":
                # Both frames are indexed by their r_group ring, so align on it.
                return stacked.groupby(level=0).mean()
            # Rectangular frames carry a positional index; align on x instead of
            # row number so layers with different node sets on y ≈ 0 still match.
            return (
                stacked.assign(x_group=stacked["x"].round(6))
                .groupby("x_group")
                .mean()
                .reset_index(drop=True)
            )

        def _point_frame(z_slice, plot_mode):
            """Point-data frame for z_slice.

            When cell data exists, z_slice is a cell-centre level and generally
            lies between two node layers; those two layers are then averaged.
            If z_slice coincides with a node layer (within tol), or there is a
            single node layer, that layer is sliced directly.
            """
            on_node_layer = np.any(np.abs(unique_point_z - z_slice) < tol)
            if not has_cells or unique_point_z.size <= 1 or on_node_layer:
                return _process_slice(point_coords, data_points, z_slice, plot_mode, "point")

            idx = np.searchsorted(unique_point_z, z_slice)
            if not 0 < idx < len(unique_point_z):
                return pd.DataFrame()  # z_slice lies outside the node layers
            z_below, z_above = unique_point_z[idx - 1], unique_point_z[idx]
            df_below = _process_slice(point_coords, data_points, z_below, plot_mode, "point")
            df_above = _process_slice(point_coords, data_points, z_above, plot_mode, "point")
            if df_below.empty or df_above.empty:
                return pd.DataFrame()
            return _average_layers(df_below, df_above, plot_mode)

        def _components(plot_mode):
            """(column suffix, label suffix) pairs to draw, in plotting order."""
            cylindrical = plot_mode == "Cylindrical"
            if nature == "scalar":
                return [("", "")]
            if nature == "vector":
                if cylindrical:
                    return [("_r", "_r"), ("_t", "_θ")]
                return [("_x", "_x"), ("_y", "_y")]
            # Tensor: in-plane normal components plus the out-of-plane zz component.
            if cylindrical:
                return [("_rr", "_rr"), ("_tt", "_θθ"), ("_zz", "_zz")]
            return [("_xx", "_xx"), ("_yy", "_yy"), ("_zz", "_zz")]

        def update_plot(display_mode, plot_mode, z_slice):
            """Redraw the figure for the current widget selection."""
            print("-" * 50)
            print(
                f"Analyzing: {analysis_type} | Mode: '{display_mode}', Plot: '{plot_mode}', Z-Slice: {z_slice:.4f}"
            )
            df_points, df_cells = pd.DataFrame(), pd.DataFrame()

            if display_mode in ["Compare points and cells", "Points"] and has_points:
                df_points = _point_frame(z_slice, plot_mode)

            # Cell data Processing (direct slice)
            if display_mode in ["Compare points and cells", "Cells"] and has_cells:
                df_cells = _process_slice(cell_coords, data_cells, z_slice, plot_mode, "cell")

            _, ax = plt.subplots(figsize=(12, 8))
            # Every component shares one abscissa, so all curves use the same axis.
            x_axis = "r" if plot_mode == "Cylindrical" else "x"
            components = _components(plot_mode)

            def draw_plots(df, data_source):
                if df.empty:
                    return
                p_style = (
                    {"marker": "o", "ms": 6, "ls": "-"}
                    if data_source == "points"
                    else {"marker": "x", "ms": 7, "mew": 2, "ls": "--"}
                )
                for col_suffix, label_suffix in components:
                    ax.plot(
                        df[x_axis],
                        df[f"{prefix}{col_suffix}"],
                        **p_style,
                        label=f"{cfg['plot_symbol']}{label_suffix} ({data_source})",
                    )

            draw_plots(df_points, "points")
            draw_plots(df_cells, "cells")

            # Plot analytical solutions. NOTE: they are evaluated on r in both plot
            # modes; in Rectangular mode this matches the x-axis only for x >= 0.
            if analytical_case and "analytical_funcs" in cfg:
                all_r = [df["r"] for df in [df_points, df_cells] if not df.empty]
                if all_r:
                    r_values = np.concatenate(all_r)
                    r_an = np.linspace(np.min(r_values), np.max(r_values), 200)
                    for label, func in cfg["analytical_funcs"].items():
                        ax.plot(r_an, func(r_an), "k--", lw=2, label=label)

            ax.set_title(f"{analysis_type} ({plot_mode}) at Z ≈ {z_slice:.4f}")
            ax.set_xlabel("Radius r (m)" if plot_mode == "Cylindrical" else "X-coordinate (m)")
            ax.set_ylabel(f"{analysis_type} {cfg['plot_symbol']} ({cfg['plot_unit']})")
            ax.grid(True, linestyle=":")
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()
            else:
                # Avoid Matplotlib's "No artists with labels" warning on empty slices.
                print(f"[WARNING] No {analysis_type} data found at Z ≈ {z_slice:.4f}.")
            plt.tight_layout()
            plt.show()

        # --- Inner widgets ---
        interact(
            update_plot,
            display_mode=widgets.Dropdown(
                options=display_options,
                value=display_options[0],
                description="Display mode:",
                style={"description_width": "initial"},
            ),
            plot_mode=widgets.Dropdown(
                options=["Cylindrical", "Rectangular"],
                value="Cylindrical",
                description="Plot mode:",
                style={"description_width": "initial"},
            ),
            z_slice=widgets.Dropdown(
                options=slice_z_levels,
                value=slice_z_levels[len(slice_z_levels) // 2],
                description="Select Z-layer:",
                style={"description_width": "initial"},
                layout=Layout(width="50%"),
            ),
        )

    # --- Top-level widget ---
    interact(
        setup_analysis,
        analysis_type=widgets.Dropdown(
            options=available_analysis_types,
            value=available_analysis_types[0],
            description="Analysis Type:",
            style={"description_width": "initial"},
        ),
    )


# Launch the interactive viewer on the solver output, overlaying the analytical curves.
interactive_analysis_plot(
    analytical_case=True,
    filename="output/fields.vtu",
)

Loading data from: output/fields.vtu


interactive(children=(Dropdown(description='Analysis Type:', options=('Stress', 'Strain', 'Displacement', 'Hea…