In [1]:
import wandb
import pandas as pd

# Connection settings for the W&B server and the path of the runs to summarize.
# NOTE(review): the wandb.Api documentation spells the override key `base_url`;
# confirm that 'base-url' is actually honored and not silently ignored.
WANDB_OVERRIDES = {'base-url': "https://rosewandb.ucsd.edu"}
RUNS_PATH = "cht028/Inference"

api = wandb.Api(overrides=WANDB_OVERRIDES)
runs = api.runs(RUNS_PATH)

In [2]:
# Sanity check: show the text form of the runs collection returned by api.runs(...)
print(runs)

<Runs cht028/Inference>


In [3]:
def create_summary_table(runs, row_name, multi_column_name, column_name, metric):
    """Build a table holding the lowest ``metric`` value per configuration.

    Runs are grouped by ``run.config[row_name]`` (table rows) and
    ``run.config[multi_column_name]`` (table columns). When ``column_name`` is
    truthy, the columns become a two-level MultiIndex whose second level is
    ``run.config[column_name]``. For every group the smallest metric value
    found among its runs is kept.

    Args:
        runs: Iterable of run objects exposing a ``config`` mapping and a
            ``summaryMetrics`` mapping.
        row_name: Config key whose values become the row labels.
        multi_column_name: Config key whose values become the (outer) column
            labels.
        column_name: Config key whose values become the inner column labels,
            or ``None``/empty for a single-level column index.
        metric: Key looked up in ``run.summaryMetrics``.

    Returns:
        pandas.DataFrame of float metric values; cells with no matching run
        are NaN, and rows/columns that are entirely NaN are dropped.

    Runs that lack any of the requested config keys or the metric, or whose
    metric cannot be converted to a finite-or-infinite float (``None``,
    non-numeric text, NaN), are skipped.
    """
    lowest_metrics = {}

    # Keep the smallest metric value seen for each configuration key
    for run in runs:
        try:
            name = run.config[row_name]
            dataset = run.config[multi_column_name]
            window_size = run.config[column_name] if column_name else None
            metric_value = float(run.summaryMetrics[metric])
        except (KeyError, TypeError, ValueError):
            # Missing config/metric entry, or a metric that is not numeric
            continue

        # NaN never compares lower than anything, so it must not seed a group
        if pd.isna(metric_value):
            continue

        key = (name, dataset, window_size) if column_name else (name, dataset)
        best_so_far = lowest_metrics.get(key)
        if best_so_far is None or metric_value < best_so_far:
            lowest_metrics[key] = metric_value

    # Prepare the row and column labels for DataFrame creation
    index = sorted({key[0] for key in lowest_metrics})
    datasets = sorted({key[1] for key in lowest_metrics})
    if column_name:
        window_sizes = sorted({key[2] for key in lowest_metrics})
        columns = pd.MultiIndex.from_product(
            [datasets, window_sizes], names=[multi_column_name, column_name]
        )
    else:
        columns = datasets

    # Create the (initially all-NaN) DataFrame
    results_df = pd.DataFrame(index=index, columns=columns)

    # Populate the DataFrame with the parsed minimum values
    for key, value in lowest_metrics.items():
        column = (key[1], key[2]) if column_name else key[1]
        results_df.at[key[0], column] = value

    # Drop rows and columns with all NaN values (from_product can create
    # dataset/window-size combinations that no run actually covers)
    return results_df.dropna(axis=0, how='all').dropna(axis=1, how='all')

In [4]:
# Summarize 'RMSE Scores' with one row per 'case' and one column per
# ('dataset', 'window_size') pair.
results_df = create_summary_table(
    runs,
    row_name='case',
    multi_column_name='dataset',
    column_name='window_size',
    metric='RMSE Scores',
)
results_df

dataset,Gas,Gas,Gas
window_size,1,2,5
mixed-mixed-fact/finetune,0.045173,0.056874,
mixed-mixed-fact/zeroshot,0.225459,0.104919,
mixed-mixed/finetune,0.046673,,
mixed-mixed/zeroshot,0.133108,0.087971,
text-text-fact/finetune,0.0,0.0,
text-text-fact/zeroshot,0.0,0.0,
text-text/finetune,0.0,,
text-text/zeroshot,0.0,0.0,0.0


In [8]:
def postprocess_df(df, rows_to_remove, columns_to_remove, precision):
    # Define a lambda to round and format the float to the desired precision
    format_float = lambda x: f"{x:.{precision}f}" if isinstance(x, float) else x
    
    # Remove specified rows and columns
    df = df.drop(index=rows_to_remove)
    df = df.drop(columns=columns_to_remove, axis=1)
    
    # Convert all values to the specified precision
    df = df.applymap(lambda x: format_float(round(float(x), precision)) if pd.notnull(x) and isinstance(x, (int, float)) else x)
    
    # Define a function to apply bold styling to the minimum value in each column
    def highlight_min(s):
        s = pd.to_numeric(s, errors='coerce')
        is_min = s == s.min()
        return ['font-weight: bold' if v else '' for v in is_min]
    
    # Apply the styling with the Styler object
    styled_df = df.style.apply(highlight_min, axis=0)
    
    return styled_df, df


# Example usage. The row and column labels below are placeholders: replace
# them with labels that actually occur in results_df.
precision = 3
rows_to_remove = ['llama', 'm2zeroshot']
columns_to_remove = [
    ('Yelp', 24),
    ('Mimic', 14),
    ('Climate', 30),
    ('Climate', 14),
]

processed_df, df = postprocess_df(
    results_df,
    rows_to_remove=rows_to_remove,
    columns_to_remove=columns_to_remove,
    precision=precision,
)
processed_df

KeyError: "['llama', 'm2zeroshot'] not found in axis"

In [11]:
# Persist the formatted (unstyled) table next to the notebook
df.to_csv(path_or_buf='results.csv')