# Imports

In [None]:
from time import time
import subprocess as sp
import tempfile
import sys
import os
import json

import tabulate as tbl
import numpy as np
import pandas as pd

from subprocess import run

# Definitions

In [None]:
# Benchmark driver script, expected in the notebook's current working directory.
test_filepath = os.path.join(os.getcwd(), 'parallel_3d_calc.py')
# Raise explicitly instead of using `assert`: asserts are stripped under `python -O`.
if not os.path.isfile(test_filepath):
    raise FileNotFoundError(f"benchmark script not found: {test_filepath}")

num_cells_per_axis = 22  # forwarded to the script as --num-cells
mpi_np_values = [1, 12]  # axis 0 of the result tables: process counts passed to `mpirun -np`
linear_solver_types = ['lu', 'gmres']  # axis 1 of the result tables: values for --linear-solver

# Measure

In [None]:
from time import perf_counter

# Result tables: one row per MPI process count, one column per linear solver.
results_shape = (len(mpi_np_values), len(linear_solver_types))
solve_time_s = {
    "wall": np.zeros(results_shape),   # end-to-end duration of each mpirun call, timed here
    "rank0": np.zeros(results_shape),  # 'solve_time_rank0_s' read from the run's results JSON
}

# 'dof' from the most recent successful run; stays None if every run fails.
num_dof = None

with tempfile.TemporaryDirectory() as temp_dir:

    for idx0, mpi_np_value in enumerate(mpi_np_values):
        for idx1, linear_solver_type in enumerate(linear_solver_types):

            run_results_filepath = os.path.join(temp_dir, f'run_results_{idx0}_{idx1}.json')

            cmd_list = [
                'mpirun', '-np', str(mpi_np_value),
                sys.executable, test_filepath,
                f'--num-cells={num_cells_per_axis}',
                f'--results-filepath={run_results_filepath}',
                f'--linear-solver={linear_solver_type}',
            ]

            # perf_counter is monotonic, so the interval is unaffected by system clock changes.
            # Stop the clock as soon as the subprocess returns.
            run_start_time_s = perf_counter()
            result = sp.run(cmd_list)
            run_solve_time_s = perf_counter() - run_start_time_s

            if result.returncode != 0:
                # Best-effort sweep: mark this configuration as missing and keep going,
                # but report the failure instead of hiding it.
                # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
                print(f"cmd failed (exit code {result.returncode}): {' '.join(cmd_list)}",
                      file=sys.stderr)
                solve_time_s['rank0'][idx0, idx1] = np.nan
                solve_time_s['wall'][idx0, idx1] = np.nan
                continue

            # A zero exit status without a results file means the driver script misbehaved.
            if not os.path.isfile(run_results_filepath):
                raise FileNotFoundError(
                    f"results file not written: {run_results_filepath} "
                    f"(cmd: {' '.join(cmd_list)})"
                )

            solve_time_s['wall'][idx0, idx1] = run_solve_time_s

            with open(run_results_filepath, 'r') as f:
                run_result = json.load(f)

            num_dof = run_result['dof']
            solve_time_s['rank0'][idx0, idx1] = run_result['solve_time_rank0_s']

# table creation

# Print one table per timing kind ('wall', 'rank0'): rows are MPI process counts,
# columns are linear solvers. Failed runs appear as NaN.
for name, times in solve_time_s.items():

    # The row-label loop variable is `n`, not `np`, so it does not shadow the numpy alias.
    df = pd.DataFrame(times,
                      columns=linear_solver_types,
                      index=[f"mpi np={n}" for n in mpi_np_values])

    print(f"solve time (s) - {name} - DOF: {num_dof}")
    print(tbl.tabulate(df, headers='keys', tablefmt='psql'))
    print("\n")