In [11]:
import wandb
import os
import os.path as path
import numpy as np
# Module-level W&B public-API client; the run-regrouping loops in the following
# cell list runs through it (get_filtered_runs builds its own client instead).
api = wandb.Api()


In [12]:
# For both projects, make sure every run whose name contains "_2024_" sits in
# the group named by the part of its name that precedes "_2024_". Runs that are
# already in that group are left alone, so re-running this cell is a no-op.
for project in ("CTGAN-ctgan", "CTGAN"):
    for r in api.runs(project):
        if "_2024_" not in r.name:
            continue
        expected_group = r.name.split("_2024_")[0]
        if r.group == expected_group:
            continue
        r.group = expected_group
        r.update()  # push the changed group back to W&B

# Aggregate Statistics

In [13]:
def get_filtered_runs(project_name, group_name, state="finished", host_regex="^paul"):
    """Fetch the runs of one group in a W&B project.

    Args:
        project_name: Project path, passed as ``path`` to ``wandb.Api().runs``.
        group_name: Value the run's ``group`` field must equal.
        state: Value the run's ``state`` field must equal. Defaults to
            ``"finished"``; pass ``None`` to drop this filter.
        host_regex: Regular expression applied to the run's ``host`` field via
            a ``$regex`` filter. Defaults to ``"^paul"``; pass ``None`` to drop
            this filter.

    Returns:
        The object returned by ``wandb.Api().runs`` for these filters.
    """
    # Built incrementally so optional filters can be omitted; with the default
    # arguments the dict is identical to the previously hard-coded one.
    filters = {"group": group_name}
    if state is not None:
        filters["state"] = state
    if host_regex is not None:
        filters["host"] = {"$regex": host_regex}

    # A fresh client is created per call, as before.
    return wandb.Api().runs(path=project_name, filters=filters)

def compute_metric_stats(runs, metric_name: str, use_max: bool = True, round_digits: int = 4) -> dict:
    """Summarise one logged metric across several WandB runs.

    For each run the history of ``metric_name`` is scanned; the last logged
    value and the best value (max or min, depending on ``use_max``) are kept.
    Runs that never logged the metric, or logged only ``None`` for it, are
    skipped rather than aborting the whole aggregation.

    Args:
        runs: Iterable of run objects exposing ``scan_history(keys=[...])``,
            which yields mapping-like rows.
        metric_name: Key of the metric inside each history row.
        use_max: If True the "best" value is the maximum, otherwise the minimum.
        round_digits: Number of decimals every returned statistic is rounded to.

    Returns:
        ``{}`` if no run contributed a value. Otherwise a dict with:
          * ``"max_value"`` / ``"min_value"``: best value over all runs,
          * ``"last_mean"``: mean of the per-run last values,
          * ``"last_std"``: their standard deviation (``np.std`` default,
            i.e. population std with ddof=0),
          * ``"last_se"``: that std divided by sqrt(number of contributing runs).
    """
    pick_peak = max if use_max else min
    last_values, peak_values = [], []

    for run in runs:
        history = run.scan_history(keys=[metric_name])
        values = [
            row[metric_name]
            for row in history
            if metric_name in row and row[metric_name] is not None
        ]
        if not values:
            # Nothing logged for this metric in this run; values[-1] would raise.
            continue
        last_values.append(values[-1])
        peak_values.append(pick_peak(values))

    if not last_values:
        return {}

    last_arr = np.array(last_values)
    std = np.std(last_arr)
    # Derive the standard error from the unrounded std so that rounding is
    # applied exactly once per reported statistic.
    se = std / np.sqrt(len(last_arr))

    return {
        f"{'max' if use_max else 'min'}_value": round(pick_peak(peak_values), round_digits),
        "last_mean": round(np.mean(last_arr), round_digits),
        "last_std": round(std, round_digits),
        "last_se": round(se, round_digits),
    }

In [15]:
# Metrics summarised for every group ("valid/variational_performance" is
# currently left out).
METRICS = ["valid/Column_Shape", "valid/Column_Pair_Trend"]

# W&B project and the candidate group lists (presumably one list per dataset).
PROJECT = "CTGAN"
GROUPS_ADULTS = ["1ew", "2ew", "3ew", "5ew", "10ew"]
GROUPS_CANCER = ["cancer_1ew", "cancer_2ew", "cancer_3ew", "cancer_5ew", "cancer_10ew"]
GROUPS_SUPERSTORE = ["superstore_1ew", "superstore_2ew", "superstore_3ew", "superstore_5ew", "superstore_10ew"]

# Group list that the report below iterates over.
GROUPS = GROUPS_ADULTS

for group in GROUPS:
    print("GROUP: ", group)
    runs = get_filtered_runs(PROJECT, group)
    for run in runs:
        print("\t", run.name)

    for metric in METRICS:
        print("")
        print("METRIC: ", metric)
        metrics = compute_metric_stats(runs, metric, round_digits=3)
        for stat_name, stat_value in metrics.items():
            if stat_value != 0:
                # Three decimals without the leading zero, e.g. 0.869 -> .869
                shown = f"{stat_value:.3f}".lstrip('0')
            else:
                shown = '.000'
            print(f"{stat_name}:\t\t{shown}")

    # Two blank lines separate consecutive groups.
    print("")
    print("")


GROUP:  1ew
	 1ew_2024_09_17_20_21_43
	 1ew_2024_09_17_20_27_45
	 1ew_2024_09_17_20_27_47
	 1ew_2024_09_18_13_26_59
	 1ew_2024_09_18_13_27_08

METRIC:  valid/Column_Shape
max_value:		.895
last_mean:		.869
last_std:		.009
last_se:		.004

METRIC:  valid/Column_Pair_Trend




max_value:		.881
last_mean:		.840
last_std:		.016
last_se:		.007


GROUP:  2ew
	 2ew_2024_09_17_20_21_43
	 2ew_2024_09_17_20_27_45
	 2ew_2024_09_17_20_27_47
	 2ew_2024_09_18_13_26_59
	 2ew_2024_09_18_13_27_08

METRIC:  valid/Column_Shape
max_value:		.903
last_mean:		.866
last_std:		.016
last_se:		.007

METRIC:  valid/Column_Pair_Trend
max_value:		.895
last_mean:		.838
last_std:		.028
last_se:		.013


GROUP:  3ew
	 3ew_2024_09_17_20_21_43
	 3ew_2024_09_17_20_27_45
	 3ew_2024_09_17_20_27_47
	 3ew_2024_09_18_13_26_59
	 3ew_2024_09_18_13_27_07

METRIC:  valid/Column_Shape
max_value:		.899
last_mean:		.864
last_std:		.008
last_se:		.004

METRIC:  valid/Column_Pair_Trend
max_value:		.900
last_mean:		.836
last_std:		.014
last_se:		.006


GROUP:  5ew
	 5ew_2024_09_17_20_21_43
	 5ew_2024_09_17_20_27_45
	 5ew_2024_09_17_20_27_47
	 5ew_2024_09_18_13_26_59
	 5ew_2024_09_18_13_27_07

METRIC:  valid/Column_Shape
max_value:		.902
last_mean:		.872
last_std:		.012
last_se:		.005

METRIC:  valid/Column_