# In [None]:
# Plot the median time per epoch against the number of cells (elements) for different FE orders.
# fig, axes = plt.subplots(3,2, figsize=(10, 7))

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
 
# Add the latex scripts
from matplotlib import rc
from cycler import cycler
from matplotlib.ticker import LogLocator, LogFormatterSciNotation, LogFormatter, MaxNLocator
from matplotlib.ticker import NullFormatter
# Shared typography: every piece of text is typeset by LaTeX in a serif face,
# and all tick, title, axis-label and legend text uses one font size.
size_font = 16
show_l_infinity = False  # NOTE(review): not referenced in this cell
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "xtick.labelsize": size_font,
    "ytick.labelsize": size_font,
    "axes.titlesize": size_font,
    "axes.labelsize": size_font,
    "legend.fontsize": size_font,
})

# Curve styling: adding two equal-length cyclers pairs them element-wise, so
# the k-th curve in any Axes gets the k-th colour together with the k-th marker.
color_cycler = cycler(color=['darkblue', '#d62728', '#2ca02c', '#ff7f0e', '#bcbd22',
                             '#8c564b', '#17becf', '#9467bd', '#e377c2', '#7f7f7f'])
marker_cycler = cycler(marker=['o', 's', 'v', 'p', '*', 'h', 'H', '^', '<', '>'])
plt.rcParams['axes.prop_cycle'] = color_cycler + marker_cycler  # must precede subplots()

# 3x2 grid of panels. N_cells holds the x positions (number of elements)
# shared by every panel.
fig, axes = plt.subplots(3, 2, figsize=(8, 6))
N_cells = [1, 4, 16, 25, 64, 256, 400]
y_ticks = [0.01, 0.1]  # NOTE(review): not referenced in this cell

def _round_to_one_sig_fig(value, direction):
    """Round a positive *value* to one significant figure.

    *direction* is ``"down"`` (towards zero) or ``"up"`` (away from zero), so a
    tick rounded this way never moves inside the range it is meant to bracket.
    """
    scale = 10.0 ** np.floor(np.log10(value))
    # round() strips float noise (e.g. 2.9999999999) before floor/ceil.
    mantissa = round(value / scale, 9)
    mantissa = np.floor(mantissa) if direction == "down" else np.ceil(mantissa)
    return float(mantissa * scale)


def format_subplot(ax, median_times, title, remove_xticks=False):
    """Draw one log-log panel of median time per epoch vs. number of elements.

    Each row of *median_times* is drawn as one curve against the leading
    entries of the module-level ``N_cells``; a ``None`` entry is a missing
    measurement and shows up as a gap in that curve. The y axis gets exactly
    two labelled ticks that bracket the panel's data. The x axis is ticked
    and labelled at every ``N_cells`` value.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    median_times : list of list of (float or None)
        One row per curve. A row may be shorter than ``N_cells``. Timings are
        assumed positive (log scale).
    title : str
        Panel title, e.g. ``'(a)'``.
    remove_xticks : bool, optional
        If True, blank the x tick labels (used for panels above the bottom row).

    Raises
    ------
    ValueError
        If *median_times* contains no numeric (non-``None``) value at all.
    """
    for median_time in median_times:
        # dtype=float maps None to NaN, which matplotlib leaves as a gap.
        ax.plot(N_cells[:len(median_time)], np.array(median_time, dtype=float))

    # Scan all rows together, so a row that is entirely None is harmless.
    measured = [t for times in median_times for t in times if t is not None]
    if not measured:
        raise ValueError(f"panel {title!r} has no timings to plot")
    min_time, max_time = min(measured), max(measured)

    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_title(title)

    # Two y ticks bracketing the data with a ~20 % margin. They are rounded
    # outward to one significant figure so each '.0e' label equals the tick's
    # actual position.
    lower_tick = _round_to_one_sig_fig(0.8 * min_time, "down")
    upper_tick = _round_to_one_sig_fig(1.2 * max_time, "up")
    tick_values = [lower_tick, upper_tick]
    ax.set_yticks(tick_values)
    ax.set_yticklabels([f"{tick:.0e}" for tick in tick_values])
    ax.yaxis.set_minor_formatter(NullFormatter())  # no labels on log minor ticks
    # Pin the view to the ticks so they sit on the panel's bottom and top edges.
    ax.set_ylim(lower_tick, upper_tick)

    # Fixed labels at every N_cells value. A base-2 LogFormatter would label
    # only powers of two across this ~9-octave range and drop 25 and 400.
    ax.set_xticks(N_cells)
    ax.set_xticklabels([] if remove_xticks else [str(n) for n in N_cells])
    ax.xaxis.set_minor_formatter(NullFormatter())
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)

# Median time per epoch (s) for each panel. Each row is one curve plotted
# against the leading entries of N_cells; None marks a missing measurement.
# NOTE(review): rows presumably follow the FE orders in fe_order_labels
# ('2', '5', '10', '15', '20') in order -- confirm.

# Subplot (a)
# NOTE(review): the last two rows are identical; check for a copy-paste error
# in the recorded timings.
median_times_a = [
    [0.00937, 0.0097, 0.0109, 0.00986, 0.0109, 0.0113, 0.0166],
    [0.0095, 0.00981, 0.0109, 0.0105, 0.0105, 0.0125, 0.0184],
    [0.0107, 0.0103, 0.00971, 0.00989, 0.00995, 0.016, 0.0231],
    [0.0104, 0.0102, 0.00988, 0.0108, 0.0105, 0.0223, 0.0313],
    [0.0104, 0.0102, 0.00988, 0.0108, 0.0105, 0.0223, 0.0313]
]

# Subplot (b)
median_times_b = [
    [0.0098, 0.0109, 0.0101, 0.0103, 0.0108, 0.0378, 0.0565],
    [0.00961, 0.0097, 0.01, 0.0104, 0.0112, 0.0395, 0.0583],
    [0.00949, 0.00976, 0.0105, 0.0104, 0.0132, 0.0438, 0.0659],
    [0.00976, 0.00965, 0.0106, 0.0108, 0.0153, 0.0517, None],
    [0.00953, 0.00974, 0.0106, 0.0115, 0.0187, None, None]
]

# Subplot (c)
median_times_c = [
    [0.00975, 0.0101, 0.0108, 0.0168, 0.015, 0.0504, 0.101],
    [0.0105, 0.0111, 0.0102, 0.0168, 0.015, 0.0507, 0.101],
    [0.0098, 0.0104, 0.0104, 0.0167, 0.0161, 0.0565, 0.11],
    [0.00984, 0.0105, 0.0109, 0.0177, 0.0199, 0.0669, 0.125],
    [0.0098, 0.0104, 0.0104, 0.021, 0.0246, 0.0788, None]
]

# Subplot (d)
median_times_d = [
    [0.01, 0.0104, 0.0155, 0.0239, 0.0509, 0.246, 0.372],
    [0.01, 0.00985, 0.0156, 0.025, 0.05, 0.24, 0.369],
    [0.00971, 0.0103, 0.0158, 0.0246, 0.051, None, None],
    [0.0102, 0.0109, 0.0167, 0.0265, 0.0548, None, None],
    [0.0103, 0.01, 0.0193, 0.0294, 0.06, None, None]
]

# Subplot (e)
median_times_e = [
    [0.0105, 0.0107, 0.0107, 0.0105, 0.0107, 0.0113, 0.0111],
    [0.0111, 0.0108, 0.0131, 0.0129, 0.0115, 0.0124, 0.0123],
    [0.0116, 0.0116, 0.015, 0.0149, 0.013, 0.0158, 0.0161],
    [0.0134, 0.0129, 0.0185, 0.0179, 0.0157, 0.0223, 0.0222],
    [0.0151, 0.0142, 0.0241, 0.0239, 0.0181, 0.0277, 0.0292]
]

# Subplot (f)
median_times_f = [
    [0.0505, 0.0505, 0.0506, 0.0507, 0.0509, 0.0504, 0.0495],
    [0.0502, 0.0512, 0.0537, 0.0537, 0.0502, 0.0508, 0.0512],
    [0.0518, 0.0538, 0.0542, 0.0537, 0.051, 0.0565, 0.0583],
    [0.0524, 0.0534, 0.0566, 0.0565, 0.0548, 0.0667, 0.0702],
    [0.0541, 0.0551, 0.0616, 0.0612, 0.0598, 0.0789, 0.086]
]

# Draw the panels in (a)..(f) order, row-major over the 3x2 grid. Only the
# bottom row keeps its x tick labels and gets the x-axis label. format_subplot
# already places the x ticks on every panel, so no extra set_xticks is needed.
panels = [
    # (axes, timings, title, is_bottom_row)
    (axes[0, 0], median_times_a, '(a)', False),
    (axes[0, 1], median_times_b, '(b)', False),
    (axes[1, 0], median_times_c, '(c)', False),
    (axes[1, 1], median_times_d, '(d)', False),
    (axes[2, 0], median_times_e, '(e)', True),
    (axes[2, 1], median_times_f, '(f)', True),
]
for panel_ax, panel_times, panel_title, is_bottom_row in panels:
    format_subplot(panel_ax, panel_times, panel_title, remove_xticks=not is_bottom_row)
    if is_bottom_row:
        panel_ax.set_xlabel("Number of elements " + r'$\texttt{(N\_elem)}$')

# Common legend above the grid, with one entry per FE order. The curves are
# plotted without labels, so the line handles from panel (a) are paired with
# the FE-order names explicitly. Every Axes starts the same colour/marker cycle
# (rcParams 'axes.prop_cycle'), so the k-th handle matches the k-th curve in
# every panel.
fe_order_labels = ['2', '5', '10', '15', '20']
legend_handles = list(axes[0, 0].get_lines())[:len(fe_order_labels)]
fig.legend(legend_handles, fe_order_labels, loc='lower center', title='FE Order',
           bbox_to_anchor=(0.5, 1.0), ncol=len(fe_order_labels),
           title_fontsize=size_font, fontsize=size_font)

# Common y-axis label, placed just left of the figure's left edge (x < 0 in
# figure coordinates).
fig.text(-0.02, 0.5, 'Median Time per Epoch (s)', ha='center', va='center',
         rotation='vertical', fontsize=size_font)

# Adjust layout
plt.tight_layout(pad=0.4, w_pad=0.8, h_pad=1.0)

# The figure legend is anchored above the top edge (y = 1.0) and the y label
# sits left of x = 0, so both lie outside the canvas. bbox_inches='tight'
# enlarges the saved image to include them instead of clipping them off.
output_path = "/home/jovita/Projects/EdgeVPINNs/plots_paper/timing_vs_cells.png"
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

