# In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.tri as tri
import meshio

# Add the latex scripts
from matplotlib import rc
from cycler import cycler
from matplotlib.ticker import LogLocator, LogFormatterSciNotation, LogFormatter, MaxNLocator
from matplotlib.ticker import NullFormatter
# Global plotting style: LaTeX-rendered serif text, one shared font size for
# titles/labels/ticks/legends, and a property cycle pairing colours with markers.
size_font = 16
# Not read anywhere in this cell; presumably consumed by other cells — verify.
show_l_infinity = False

plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    # Optional: uncomment to force the Computer Modern serif face.
    # "font.serif": ["Computer Modern"],
    "axes.titlesize": size_font,
    "axes.labelsize": size_font,
    "xtick.labelsize": size_font,
    "ytick.labelsize": size_font,
    "legend.fontsize": size_font,
})

color_cycler = cycler(color=[
    'darkblue', '#d62728', '#2ca02c', '#ff7f0e', '#bcbd22',
    '#8c564b', '#17becf', '#9467bd', '#e377c2', '#7f7f7f',
])
marker_cycler = cycler(marker=['o', 's', 'v', 'p', '*', 'h', 'H', '^', '<', '>'])

# Adding two cyclers zips them element-wise (both have 10 entries), so the
# i-th colour is always drawn with the i-th marker.
plt.rcParams['axes.prop_cycle'] = color_cycler + marker_cycler

def apply_mask(triang, alpha=0.4):
    """Mask every triangle of ``triang`` that has an edge longer than ``alpha``.

    ``triang`` is expected to behave like a ``matplotlib.tri.Triangulation``:
    it must expose ``x``, ``y``, ``triangles`` (one row of vertex indices per
    triangle) and ``set_mask``. The computed boolean mask replaces any mask
    already set on it. Nothing is returned.
    """
    vertex_ids = triang.triangles
    corner_x = triang.x[vertex_ids]
    corner_y = triang.y[vertex_ids]

    # Rolling along the corner axis pairs each vertex with its predecessor
    # (wrapping around), so every row yields the three edge vectors of its
    # triangle.
    edge_dx = corner_x - np.roll(corner_x, 1, axis=1)
    edge_dy = corner_y - np.roll(corner_y, 1, axis=1)

    longest_edge = np.sqrt(edge_dx**2 + edge_dy**2).max(axis=1)
    too_long = longest_edge > alpha
    triang.set_mask(too_long)

def plot_tricontourf_subplot(txt_file, mesh_file, ax, title, alpha=0.06, *,
                             levels=100, cmap='jet'):
    """Draw a filled contour plot of per-vertex solution values on ``ax``.

    The x, y coordinates of the mesh vertices are read from ``mesh_file`` with
    ``meshio``, and the solution values from ``txt_file`` with
    ``np.loadtxt``. The points are Delaunay-triangulated by
    ``tri.Triangulation`` (the mesh file's own cell connectivity is not used).
    Triangles with an edge longer than ``alpha`` are masked out via
    ``apply_mask``. The field is then drawn with ``tricontourf`` together with
    a colorbar for this axes.

    Parameters
    ----------
    txt_file : str
        Text file loadable by ``np.loadtxt``. It must yield exactly one value
        per mesh vertex, assumed to be in the same order as the mesh points.
    mesh_file : str
        Mesh file readable by ``meshio.read``. Only the first two coordinate
        columns are used.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    title : str
        Axes title.
    alpha : float, optional
        Edge-length threshold forwarded to ``apply_mask``.
    levels : int or array-like, optional
        Contour levels for ``tricontourf`` (default 100, as before).
    cmap : str or Colormap, optional
        Colormap for ``tricontourf`` (default ``'jet'``, as before).

    Returns
    -------
    The contour set created by ``ax.tricontourf``.

    Raises
    ------
    ValueError
        If the solution array does not contain exactly one value per vertex.
    """
    mesh = meshio.read(mesh_file)
    x = mesh.points[:, 0]
    y = mesh.points[:, 1]

    # atleast_1d: a file holding a single number loads as a 0-d array.
    solution = np.atleast_1d(np.loadtxt(txt_file))
    if solution.shape != x.shape:
        raise ValueError(
            f"{txt_file}: expected {x.size} solution values (one per vertex of "
            f"{mesh_file}), got an array of shape {solution.shape}"
        )

    triang = tri.Triangulation(x, y)
    # Delaunay also triangulates across holes/concavities; hide those triangles.
    apply_mask(triang, alpha)

    contour = ax.tricontourf(triang, solution, levels=levels, cmap=cmap)
    ax.set_title(title)

    # Use the figure that owns ``ax`` rather than pyplot's "current" figure,
    # so the colorbar is correct even if another figure is current.
    cbar = ax.figure.colorbar(contour, ax=ax)
    cbar.ax.tick_params(labelsize=8)

    # Always show axis tick labels in scientific notation.
    ax.ticklabel_format(style='scientific', axis='both', scilimits=(0, 0))
    return contour

def plot_all_subplots(mesh_file, exact_solution_file, pred_solution_file, error_file,
                      output_path, filename, *, second_row_files=None, dpi=300):
    """Plot exact, predicted and error fields on a 2x3 grid, save it as PNG and show it.

    Row 1 (panels (a)-(c)) shows ``exact_solution_file``, ``pred_solution_file``
    and ``error_file``. Row 2 (panels (d)-(f)) shows ``second_row_files``, a
    3-item sequence in the same (exact, predicted, error) order. When
    ``second_row_files`` is None, row 2 repeats the row-1 files, which was the
    only behaviour before this parameter existed. Every panel uses
    ``mesh_file`` for the vertex coordinates.

    The figure is written to ``<output_path>/<filename>.png`` at ``dpi``
    resolution. ``output_path`` is created if it does not exist. The figure is
    then shown with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``second_row_files`` is given but does not contain exactly 3 paths.
    """
    from pathlib import Path

    first_row = (exact_solution_file, pred_solution_file, error_file)
    if second_row_files is None:
        second_row = first_row
    else:
        second_row = tuple(second_row_files)
        if len(second_row) != 3:
            raise ValueError(
                f"second_row_files must contain 3 paths, got {len(second_row)}"
            )

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    panel_titles = (('(a)', '(b)', '(c)'), ('(d)', '(e)', '(f)'))

    # Panels are drawn row by row, left to right: (a) ... (f).
    for row_axes, row_files, row_titles in zip(axes, (first_row, second_row), panel_titles):
        for ax, txt_file, title in zip(row_axes, row_files, row_titles):
            plot_tricontourf_subplot(txt_file, mesh_file, ax, title)

    # Operate on this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()

    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)  # savefig fails on a missing directory
    fig.savefig(out_dir / f"{filename}.png", dpi=dpi)

    plt.show()

# Example usage. The guard keeps a plain import of this file free of plotting
# and file I/O. In a notebook kernel or when run as a script, __name__ is
# "__main__", so the example still runs there.
if __name__ == "__main__":
    PROJECT_ROOT = '/home/jovita/Projects/EdgeVPINNs'
    RESULTS_ROOT = f"{PROJECT_ROOT}/output/gear_test_seeding"

    plot_all_subplots(
        f"{PROJECT_ROOT}/meshes/gear.mesh",
        f"{RESULTS_ROOT}/base/105/exact.txt",
        f"{RESULTS_ROOT}/Tucker/16_105/prediction.txt",
        f"{RESULTS_ROOT}/Tucker/16_105/error.txt",
        f"{PROJECT_ROOT}/plots_paper",
        'forward_gear_cd2d_decomposed',
    )
