In [None]:
import os
import matplotlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
matplotlib.use("agg")

from exp_spec_info import *
from plot_info import *
from select_data import *

In [None]:
# Root of the processed thesis data. Defaults to the original author's local
# checkout; set the THESIS_DATA_DIR environment variable to run elsewhere.
thesis_data_root = os.environ.get(
    "THESIS_DATA_DIR", "C:\\Users\\dosre\\dev\\thesis-data"
)

# Processed pickle paths
extended_data_path = os.path.join(thesis_data_root, "extended_data.pkl")
median_data_path = os.path.join(thesis_data_root, "median_data.pkl")

# Plot output directory, with one sub-directory for the basic trajectory plots
plot_output_root = os.path.join(thesis_data_root, "plots")
basic_traj_plots_dir = os.path.join(plot_output_root, "plots_basic_traj")
# os.makedirs also creates plot_output_root as an intermediate directory
os.makedirs(basic_traj_plots_dir, exist_ok=True)

In [None]:
# Load the extended (trajectory) data and the median summary data.
# NOTE(review): unpickling can execute arbitrary code -- only point these
# paths at trusted, locally produced files.
extended_data, median_data = (
    pd.read_pickle(pickle_path)
    for pickle_path in (extended_data_path, median_data_path)
)

##### Plot Basic Trajectories for All Experiments

In [None]:
def _plot_solver_trajectories(ax, sub_data, solvers):
    """Plot the convergence trajectory of every solver in ``solvers`` on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        sub_data: Extended data already restricted to one
            setup/matrix/restart_param combination; rows are split by their
            ``"solver"`` column.
        solvers: Iterable of solver ids; each id also keys SOLVER_CLR_DICT.
    """
    for solver in solvers:
        solver_exp_iteration_data = sub_data[sub_data["solver"] == solver]
        plot_exp_iters_conv_traj(
            ax,
            solver_exp_iteration_data,
            N_EXPERIMENT_ITERATIONS,
            solver,
            SOLVER_CLR_DICT[solver]
        )


def _build_median_table(median_sub_data):
    """Return the median results formatted for ``Axes.table``, one row per solver.

    Drops the identifying columns and the median statistics that are not
    shown. Casts ``med_inner_iter`` to ``int`` and renders ``med_rel_res``
    with two significant digits. Rows are ordered by SOLVER_ID_ORDER, and
    ``solver`` becomes the first column.

    Raises:
        KeyError: if one of the dropped/formatted columns is missing, or if a
            solver listed in SOLVER_ID_ORDER has no row in ``median_sub_data``.
    """
    table_data = median_sub_data.drop(
        columns=[
            "setup", "matrix", "restart_param", "med_outer_iter",
            "med_rel_res_frac_err", "med_rel_time"
        ]
    )
    # NOTE(review): int() raises ValueError (not KeyError) on NaN, which the
    # caller does not catch -- confirm medians are never NaN.
    table_data["med_inner_iter"] = table_data["med_inner_iter"].apply(int)
    table_data["med_rel_res"] = table_data["med_rel_res"].apply(
        lambda elem: f"{elem:.2g}"
    )
    # Select rows with a *list* of labels. Star-unpacking the ids into
    # `.loc[...]` would need Python 3.11+, would return a Series for a single
    # id, and would be read as a (row, column) lookup for exactly two ids.
    return (
        table_data
        .set_index("solver")
        .loc[list(SOLVER_ID_ORDER)]
        .reset_index()
    )


for setup in SETUPS:
    setup_plot_dir = os.path.join(basic_traj_plots_dir, setup)
    os.makedirs(setup_plot_dir, exist_ok=True)
    for matrix in SETUP_MATRIX_MAPPING[setup]:
        for restart_param in RESTART_PARAMS:

            # Two trajectory panels on top of a table panel
            fig, axs = plt.subplots(
                3, 1, figsize=(8, 8), height_ratios=[0.35, 0.35, 0.3]
            )
            ax1, ax2, ax3 = axs

            extended_sub_data = df_sel_setup_matrix_restart(
                extended_data, setup, matrix, restart_param
            )
            median_sub_data = df_sel_setup_matrix_restart(
                median_data, setup, matrix, restart_param
            )

            # FP_SOLVERS on the first panel, GMRES_M_SOLVERS on the second
            _plot_solver_trajectories(ax1, extended_sub_data, FP_SOLVERS)
            _plot_solver_trajectories(ax2, extended_sub_data, GMRES_M_SOLVERS)

            # Shared x range for both trajectory panels. Prepending 0 keeps
            # nanmax from raising on an empty selection.
            max_inner_iter = np.nanmax(
                np.hstack([0, extended_sub_data["inner_iters"].to_numpy()])
            )
            for ax in (ax1, ax2):
                ax.set_xlim(0, max_inner_iter)
                ax.set_ylabel("$\\|b-Ax_i\\|_2/\\|b-Ax_0\\|_2$")
                ax.legend()
                ax.grid()
            ax1.set_xticklabels([])
            ax2.set_xlabel("Inner Iteration")

            try:
                table_data = _build_median_table(median_sub_data)
            except KeyError as e:
                # Missing columns/solvers for this configuration: report it and
                # leave the table panel empty instead of aborting the sweep.
                print(f"KeyError: {setup} {matrix} {restart_param} ({e})")
            else:
                plt_table = ax3.table(
                    cellText=table_data.values,
                    colLabels=table_data.columns,
                    loc="center",
                    bbox=[-0.1, 0, 1.2, 1.]
                )
                plt_table.auto_set_font_size(False)
                plt_table.set_fontsize(9)
            ax3.axis('off')

            fig.suptitle(
                f"{matrix} {SETUP_NAME_MAPPING[setup]} GMRES({restart_param})"
            )
            fig.tight_layout()

            # Save this figure explicitly rather than pyplot's "current" one.
            # No extension is given, so matplotlib picks the default format.
            fig.savefig(
                os.path.join(
                    setup_plot_dir,
                    f"{setup}_{matrix}_{restart_param:03d}"
                )
            )

            # Close each figure so the sweep does not accumulate open figures
            plt.close(fig)