In [1]:
import wandb
import pandas as pd
import numpy as np

api = wandb.Api(overrides={'base-url': "https://rosewandb.ucsd.edu"})
runs = api.runs("cht028/Inference-new")

In [2]:
def create_summary_table(runs, row_name, multi_column_name, column_name, metric):
    # Initialize a dictionary for storing the lowest metric values
    lowest_metrics = {}
    run_mapping = {}

    # Find the run with the lowest metric for each combination
    for run in runs:
        try:
            name = run.config[row_name]
            dataset = run.config[multi_column_name]
            window_size = run.config[column_name] if column_name else None
            metric_value = float(run.summaryMetrics[metric])
            metric_value = np.round(metric_value, 3)
        except KeyError:
            continue

        key = (name, dataset) if column_name is None else (name, dataset, window_size)

        if key not in lowest_metrics:
            lowest_metrics[key] = metric_value
            run_mapping[key] = run

    # Prepare the data structure for DataFrame creation
    index = sorted(set(key[0] for key in run_mapping.keys()))
    if column_name:
        datasets = sorted(set(key[1] for key in run_mapping.keys()))
        window_sizes = sorted(set(key[2] for key in run_mapping.keys()))
        multi_columns = pd.MultiIndex.from_product([datasets, window_sizes], names=[multi_column_name, column_name])
    else:
        multi_columns = sorted(set(key[1] for key in run_mapping.keys()))

    # Create the DataFrame
    results_df = pd.DataFrame(index=index, columns=multi_columns)

    # Populate the DataFrame
    for key, run in run_mapping.items():
        if column_name:
            results_df.at[key[0], (key[1], key[2])] = run.summaryMetrics[metric]
        else:
            results_df.at[key[0], key[1]] = run.summaryMetrics[metric]

    # Drop rows and columns with all NaN values
    results_df.dropna(axis=0, how='all', inplace=True)
    results_df.dropna(axis=1, how='all', inplace=True)

    return results_df

In [10]:
results_df = create_summary_table(runs, 'model', 'dataset', 'window_size', 'Meteor Scores')
results_df.to_excel('results.xlsx')
results_df

dataset,climate,climate,climate,climate
window_size,1-1,2-2,3-3,4-4
input_copy,0.375136,0.369451,0.365435,0.363035
llama8b,0.357917,0.345632,0.329292,
llama8b-InContext,0.341697,0.35187,0.328018,
llama8b-InContext-mixed,0.315343,0.314049,0.300341,
llama8b-mixed,0.332154,0.323894,0.327684,
nlinear,,,,
nlinear_textEmbedding,,,,


In [10]:
def postprocess_df(df, rows_to_maintain, columns_to_remove, precision):
    # Define a lambda to round and format the float to the desired precision
    format_float = lambda x: f"{x:.{precision}f}" if isinstance(x, float) else x
    
    # Remove specified rows and columns
    rows_to_remove = [x for x in df.index if x not in rows_to_maintain]
    df = df.drop(index=rows_to_remove)
    df = df.drop(columns=columns_to_remove, axis=1)

    # Maintain the specified rows
    # df = df[[x for x in df.columns if x in rows_to_maintain]]
    
    # Convert all values to the specified precision
    df = df.applymap(lambda x: format_float(round(float(x), precision)) if pd.notnull(x) and isinstance(x, (int, float)) else x)
    
    # Define a function to apply bold styling to the minimum value in each column
    def highlight_min(s):
        s = pd.to_numeric(s, errors='coerce')
        is_min = s == s.min()
        return ['font-weight: bold' if v else '' for v in is_min]
    
    # Apply the styling with the Styler object
    styled_df = df.style.apply(highlight_min, axis=0)
    
    return styled_df, df


# Example usage:
# rows_to_remove = ['llama', 'm2zeroshot'] # replace with your actual row values to remove
row_to_maintain = ['nlinear', 'nlinear_embedding', 'mixed-mixed-dc/finetune', 'text-text-dc/finetune', 'mixed-mixed/finetune', 'text-text/finetune']
# columns_to_remove = [('Yelp', 24), ('Mimic', 14), ('Climate', 30), ('Climate', 14)] # replace with your actual column values to remove
columns_to_remove = []
precision = 3

processed_df, df = postprocess_df(results_df, row_to_maintain, columns_to_remove, precision)
processed_df

  df = df.applymap(lambda x: format_float(round(float(x), precision)) if pd.notnull(x) and isinstance(x, (int, float)) else x)


dataset,climate,climate,climate,gas,gas,gas,medical,medical,medical
window_size,1,2,3,1,2,3,1,2,3
mixed-mixed-dc/finetune,4.642,4.368,4.483,,,,,,
mixed-mixed/finetune,,,,,,,6.821,6.816,6.529
nlinear,4.217,4.544,4.705,0.107,0.143,0.147,5.233,4.8,4.697
nlinear_embedding,5.799,5.03,5.47,0.153,0.16,0.174,5.462,4.826,4.934
text-text-dc/finetune,,,,,,,,,
text-text/finetune,,,,,,,,,
