In [38]:
import json
from itertools import groupby

In [51]:
# Read the per-station metrics produced by the two-week and four-week runs
def _read_json(path):
    """Return the parsed contents of the JSON file at `path`."""
    with open(path, 'r') as handle:
        return json.load(handle)

two_week_data = _read_json('results/two_week/station_metrics.json')
four_week_data = _read_json('results/four_week/station_metrics.json')

In [52]:
def extract_metrics(station_data, model_name):
    """
    Collect one model's metrics for a single station as display strings.

    Looks up ``{model_name}_{metric}_{stat}`` in ``station_data`` for every
    metric in (rmse, mae, mbe) and stat in (mean, std), and returns a dict
    keyed by ``{metric}_{stat}`` whose values are formatted with two decimals.
    Raises KeyError if any expected key is absent from ``station_data``.
    """
    formatted = {}
    for name in ('rmse', 'mae', 'mbe'):
        for stat in ('mean', 'std'):
            raw_value = station_data[f'{model_name}_{name}_{stat}']
            formatted[f'{name}_{stat}'] = f'{raw_value:.2f}'
    return formatted

def fill_table(station_name, two_week_data, four_week_data):
    """
    Build the table rows for one station.

    Produces one tuple per (sequence, model) pair, iterating the '2 weeks'
    data first and then the '4 weeks' data, with models in the order listed
    below. Each tuple is
    (station, sequence, model display name,
     rmse_mean, rmse_std, mae_mean, mae_std, mbe_mean, mbe_std),
    where the six metric fields are the strings returned by extract_metrics.
    """
    # Key prefix used in the metrics data -> label printed in the table
    display_names = {
        'mlp': 'MLP',
        'baseline': 'Richard et al.',
        'gapt_naive': 'GapT (naive)',
        'gapt': 'GapT'
    }
    metric_columns = ('rmse_mean', 'rmse_std', 'mae_mean', 'mae_std', 'mbe_mean', 'mbe_std')

    rows = []
    for sequence_label, dataset in (('2 weeks', two_week_data), ('4 weeks', four_week_data)):
        station_metrics = dataset[station_name]
        for key, label in display_names.items():
            formatted = extract_metrics(station_metrics, key)
            rows.append((station_name, sequence_label, label,
                         *(formatted[column] for column in metric_columns)))
    return rows

# One block of rows per station, in the key order of the four-week data
stations = list(four_week_data)
full_table_data = [
    row
    for station_name in stations
    for row in fill_table(station_name, two_week_data, four_week_data)
]

# Index the rows as grouped_data[station][sequence] -> list of rows
grouped_data = {}
for row in full_table_data:
    station_name, sequence_label = row[0], row[1]
    per_station = grouped_data.setdefault(station_name, {})
    per_station.setdefault(sequence_label, []).append(row)

# For every (station, sequence) group, record the value to highlight in each
# metric column: the minimum, except for the MBE mean where the value closest
# to zero is chosen.
lowest_values = {}
for station_name, per_sequence in grouped_data.items():
    station_best = {}
    for sequence_label, rows in per_sequence.items():
        # Row fields 3..8 hold the six formatted metric strings
        columns = [[float(row[field]) for row in rows] for field in range(3, 9)]
        rmse_means, rmse_stds, mae_means, mae_stds, mbe_means, mbe_stds = columns
        station_best[sequence_label] = (
            min(rmse_means),
            min(rmse_stds),
            min(mae_means),
            min(mae_stds),
            min(mbe_means, key=abs),
            min(mbe_stds),
        )
    lowest_values[station_name] = station_best

# Write the per-station metrics as a multi-page LaTeX longtable.
#
# Reads `grouped_data` (station -> sequence -> rows) and `lowest_values`
# (station -> sequence -> six highlight values). The \multirow spans are
# derived from the number of rows actually present, so the table stays
# well-formed if the number of models or sequences changes.

# Number of stations written before a page break is forced
STATIONS_PER_PAGE = 5

# Column titles, shared by the first-page head and the continuation head
HEADER_ROWS = (
    '\tStation & Sequence & Model & \\multicolumn{2}{c}{RMSE} & \\multicolumn{2}{c}{MAE} & \\multicolumn{2}{c}{MBE} \\\\\n'
    '\t& & & mean & std & mean & std & mean & std \\\\\n'
)

with open('results/station_metrics.tex', 'w') as out_file:
    out_file.write('\\begin{longtable}{cllcccccc}\n')

    # The two heads differ only in their caption text and closing command
    for caption, end_command in (('Results.', '\\endfirsthead'),
                                 ('Results (continued).', '\\endhead')):
        out_file.write('\t\\caption{' + caption + '} \\\\\n')
        out_file.write('\t\\hline\n')
        out_file.write(HEADER_ROWS)
        out_file.write('\t\\hline\n')
        out_file.write('\t' + end_command + ' \n')

    for station_index, (station, sequences) in enumerate(grouped_data.items(), start=1):
        # Rule between consecutive stations (not before the first one)
        if station_index > 1:
            out_file.write('\t\\hline\n')

        # The station cell spans every row of every sequence of this station.
        # NOTE(review): the station name is inserted verbatim; names containing
        # LaTeX special characters (e.g. '_') would need escaping - confirm.
        station_rows = sum(len(entries) for entries in sequences.values())
        out_file.write('\t\\multirow{' + str(station_rows) + '}{*}{' + station + '}\n')

        last_sequence_index = len(sequences) - 1
        for sequence_index, (sequence, entries) in enumerate(sequences.items()):
            min_values = lowest_values[station][sequence]

            for row_index, entry in enumerate(entries):
                _, _, model, *cells = entry

                # Only the first row of a sequence carries the sequence label
                if row_index == 0:
                    out_file.write('\t& \\multirow{' + str(len(entries)) + '}{*}{' + sequence + '}')
                else:
                    out_file.write('\t&')
                out_file.write(' & {}'.format(model))

                # Bold every cell equal to the highlight value of its column
                for cell, best in zip(cells, min_values):
                    if float(cell) == best:
                        out_file.write(' & \\textbf{' + cell + '}')
                    else:
                        out_file.write(' & ' + cell)
                out_file.write(' \\\\\n')

            # Partial rule between sequences of the same station
            if sequence_index < last_sequence_index:
                out_file.write('\t\\cline{2-9}\n')

        if station_index % STATIONS_PER_PAGE == 0:
            out_file.write('\t\\hline\n')
            out_file.write('\t\\pagebreak\n')

    out_file.write('\t\\hline\n')
    # NOTE(review): this emits a second caption after the one in the table
    # head, and the label is a placeholder - confirm both are intended.
    out_file.write('\t\\caption{Results.}\n')
    out_file.write('\t\\label{tab:my_label}\n')
    out_file.write('\\end{longtable}\n')

In [61]:
def compute_average(data, metric_prefixes, model_names):
    """
    Average each model's metrics over all stations.

    ``data`` maps a station to a dict of metric values. For every model in
    ``model_names`` and every prefix in ``metric_prefixes`` (e.g. 'rmse'),
    the values stored under ``{model}_{prefix}_mean`` and
    ``{model}_{prefix}_std`` are summed over all stations and divided by the
    number of stations. Returns a dict using those same keys.
    """
    station_records = list(data.values())
    station_count = len(data)

    averages = {}
    for model in model_names:
        for prefix in metric_prefixes:
            for suffix in ('mean', 'std'):
                key = f"{model}_{prefix}_{suffix}"
                averages[key] = sum(record[key] for record in station_records) / station_count
    return averages

def bold_best(data_avg, metric, model_names, metric_suffix='mean', closest_to_zero=False):
    """
    Pick the model whose value should be set in bold for one metric column.

    Compares ``data_avg[f'{model}_{metric}_{metric_suffix}']`` across
    ``model_names`` and returns the model with the smallest value, or, when
    ``closest_to_zero`` is true, the smallest absolute value. On ties the
    model listed first wins.
    """
    if closest_to_zero:
        score = abs
    else:
        def score(value):
            return value

    candidates = [
        (model, data_avg[model + f'_{metric}_{metric_suffix}'])
        for model in model_names
    ]
    best_model, _ = min(candidates, key=lambda pair: score(pair[1]))
    return best_model

# Metric prefixes and model names based on provided data
metric_prefixes = ["rmse", "mae", "mbe"]
metric_suffixes = ["mean", "std"]
model_names = ["mlp", "baseline", "gapt_naive", "gapt"]

# Model key -> (display name, parameter count) as printed in the table
model_info = {
    "mlp": ("MLP", "867k"),
    "baseline": ("Richard et al.", "804k"),
    "gapt_naive": ("GapT (naive)", "820k"),
    "gapt": ("GapT", "818k"),
}

# Station-averaged metrics for each sequence length. These are computed here
# so the cell runs on a fresh kernel instead of relying on variables left
# over from an earlier session.
two_week_avg = compute_average(two_week_data, metric_prefixes, model_names)
four_week_avg = compute_average(four_week_data, metric_prefixes, model_names)

# Determine which model is bolded in each "<metric>_<suffix>" column.
# Lowest value wins, except for the MBE mean where closest to zero wins.
bold_metrics_two_week = {}
bold_metrics_four_week = {}
for metric in metric_prefixes:
    for suffix in metric_suffixes:
        closest_to_zero = metric == 'mbe' and suffix == 'mean'
        bold_metrics_two_week[f"{metric}_{suffix}"] = bold_best(two_week_avg, metric, model_names, suffix, closest_to_zero)
        bold_metrics_four_week[f"{metric}_{suffix}"] = bold_best(four_week_avg, metric, model_names, suffix, closest_to_zero)

# Output LaTeX table
with open('results/average_metrics.tex', 'w') as out_file:
    out_file.write("\\begin{table}\n")
    out_file.write("\t\\caption{Results.}\n")
    out_file.write("\t\\begin{tabular}{clccccccc}\n")
    out_file.write("\t\\hline\n")
    out_file.write("Sequence & Model & \\# Params & \\multicolumn{2}{c}{RMSE} & \\multicolumn{2}{c}{MAE} & \\multicolumn{2}{c}{MBE} \\\\ \n")
    out_file.write("\t &  &  & mean & std & mean & std & mean & std \\\\ \n")
    out_file.write("\t\\hline\n")

    for data_avg, sequence_name, bold_metrics in [(two_week_avg, "2 weeks", bold_metrics_two_week), (four_week_avg, "4 weeks", bold_metrics_four_week)]:
        for row_index, model in enumerate(model_names):
            model_display_name, params = model_info[model]

            # Cells in column order: rmse mean/std, mae mean/std, mbe mean/std.
            # Each cell is bolded by the entry for its own column, so the MBE
            # mean is decided by 'mbe_mean' rather than 'mbe_std'.
            cells = []
            for metric in metric_prefixes:
                for suffix in metric_suffixes:
                    column = f"{metric}_{suffix}"
                    value = data_avg[f"{model}_{column}"]
                    text = f"{value:.2f}"
                    if model == bold_metrics[column]:
                        text = f"\\textbf{{{text}}}"
                    cells.append(text)

            # The sequence label is printed once, on the first model row.
            # The row count is wrapped in doubled braces so the f-string
            # emits \multirow{4}{*}{...} with a braced argument.
            if row_index == 0:
                out_file.write(f"\t\\multirow{{{len(model_names)}}}{{*}}{{{sequence_name}}}\t")
            else:
                out_file.write("\t")

            row = " & ".join([model_display_name, params, *cells])
            out_file.write(f"& {row} \\\\ \n")

        out_file.write("\t\\hline\n")

    out_file.write("\t\\end{tabular}\n")
    out_file.write("\t\\label{tab:results}\n")
    out_file.write("\\end{table}\n")