In [1]:
import wandb
import os
import os.path as path
import pandas as pd
# Shared W&B API client used by the maintenance cells below.
api = wandb.Api()

# Local log root; its entries (minus the excluded ones) are scanned for
# per-experiment files when stale experiments are moved to the trash folder.
PATH = "../logs/CIFAR10-Ensemble"

# Entries of PATH that are excluded from the experiment-folder scan.
EXCLUDED_FOLDERS = {"wandb", "moments"}

# Filter instead of list.remove(): remove() raises ValueError when an excluded
# entry is absent, and plain files in PATH would break later os.listdir calls.
folders = [
    entry
    for entry in os.listdir(PATH)
    if entry not in EXCLUDED_FOLDERS and path.isdir(path.join(PATH, entry))
]

In [2]:
# Group runs: every run belongs to the group named by the part of its name
# that precedes "-train".
for r in api.runs("PyTorch-StudioGAN-src"):
    expected_group = r.name.split("-train")[0]
    if r.group == expected_group:
        continue  # group already matches; nothing to update
    r.group = expected_group
    r.update()

In [3]:
# Move every experiment that has no matching W&B run into a sibling trash folder.

trash_path = path.join(PATH, "..", "CIFAR10-Ensemble-trash")

# Create the trash root and mirror the folder layout inside it.
# exist_ok=True replaces the listdir-based existence checks and keeps the
# cell safe to re-run.
os.makedirs(trash_path, exist_ok=True)
for folder in folders:
    os.makedirs(path.join(trash_path, folder), exist_ok=True)

kept_runs = [r.name for r in api.runs("PyTorch-StudioGAN-src")]
# Set copy for constant-time membership tests inside the nested loop.
kept_run_names = set(kept_runs)

for folder in folders:
    for experiment in os.listdir(path.join(PATH, folder)):
        # An entry is matched to a run by its name without the file extension.
        if path.splitext(experiment)[0] not in kept_run_names:
            os.rename(path.join(PATH, folder, experiment), path.join(trash_path, folder, experiment))


In [4]:
# Old TinyImageNet runs: rewrite "TinyImageNet" in the run name to "Tiny_ImageNet".
for r in api.runs("PyTorch-StudioGAN-src"):
    if "TinyImageNet" not in r.name:
        continue  # name does not contain the old spelling
    r.name = r.name.replace("TinyImageNet", "Tiny_ImageNet")
    r.update()

# Aggregate Statistics

In [7]:
import numpy as np
import wandb

def get_filtered_runs(project_name, group_name):
    """Query W&B for the runs of one group inside a project.

    The filter passed to ``Api.runs`` requests runs that
    * have ``group`` equal to ``group_name``,
    * have a ``host`` matching the regex ``^clara``,
    * do not carry the tag ``old``, and
    * are either in state ``finished`` or tagged ``almost_finished``.

    Args:
        project_name: Passed to ``Api.runs`` as ``path``.
        group_name: Value the run's ``group`` field must have.

    Returns:
        Whatever ``wandb.Api().runs`` returns for this query.
    """
    finished_or_almost = [
        {"state": "finished"},
        {"tags": "almost_finished"},
    ]
    run_filters = {
        "group": group_name,
        "host": {"$regex": "^clara"},
        "tags": {"$nin": ["old"]},
        "$or": finished_or_almost,
    }
    return wandb.Api().runs(path=project_name, filters=run_filters)

def _mean_std_se(values, round_digits: int, ddof: int = 0):
    """Return the rounded (mean, std, standard error) of a 1-D sequence of numbers."""
    arr = np.asarray(values, dtype=float)
    std = np.std(arr, ddof=ddof)
    # Derive the standard error from the *unrounded* std so that rounding
    # error is not compounded into the reported value.
    se = std / np.sqrt(arr.size)
    return (
        round(np.mean(arr), round_digits),
        round(std, round_digits),
        round(se, round_digits),
    )


def compute_metric_stats(runs, metric_name: str, use_max: bool = True, round_digits: int = 4,
                         ddof: int = 0, include_last: bool = False) -> dict:
    """Compute statistics across WandB runs for a given metric.

    For every run the full history of ``metric_name`` is fetched, NaN entries
    are dropped, and two numbers are recorded: the best value (max or min,
    depending on ``use_max``) and the last value. Runs without any value for
    the metric are skipped.

    Args:
        runs: Iterable of run objects exposing ``lastHistoryStep`` and
            ``history(samples=..., keys=[...])`` returning a DataFrame.
        metric_name: Name of the history column to aggregate.
        use_max: If True the per-run optimum is the maximum, otherwise the minimum.
        round_digits: Number of decimals every returned statistic is rounded to.
        ddof: Delta degrees of freedom forwarded to ``np.std``. The default 0
            (population std) matches the previous behaviour; 1 gives the sample
            std (which is NaN when only one run is available).
        include_last: If True, also return ``last_mean``, ``last_std`` and
            ``last_se`` computed over the last value of each run.

    Returns:
        ``{}`` if no run contributed a value. Otherwise a dict with the keys
        ``opt_max_value``/``opt_min_value`` (best value over all runs),
        ``opt_mean``, ``opt_std`` and ``"opt_se "`` (statistics over the
        per-run optima), plus the ``last_*`` keys when requested.
    """
    pick_best = max if use_max else min
    last_values, peak_values = [], []

    for run in runs:
        history = run.history(samples=run.lastHistoryStep + 1, keys=[metric_name])
        if metric_name not in history:
            # The metric was never logged for this run.
            continue
        values = history[metric_name].dropna().to_numpy()
        if values.size == 0:
            continue
        last_values.append(values[-1])
        peak_values.append(pick_best(values))

    if not peak_values:
        return {}

    opt_mean, opt_std, opt_se = _mean_std_se(peak_values, round_digits, ddof)
    stats = {
        f"opt_{'max' if use_max else 'min'}_value": round(pick_best(peak_values), round_digits),
        "opt_mean": opt_mean,
        "opt_std": opt_std,
        # NOTE: the trailing space in this key is kept on purpose so existing
        # consumers of the returned dict keep working.
        "opt_se ": opt_se,
    }

    if include_last:
        last_mean, last_std, last_se = _mean_std_se(last_values, round_digits, ddof)
        stats["last_mean"] = last_mean
        stats["last_std"] = last_std
        stats["last_se"] = last_se

    return stats

In [None]:
# Metrics to aggregate (another candidate: "valid/variational_performance").
METRICS = ["FID score", "IS score"]

PROJECT = "PyTorch-StudioGAN-src"

# --- Group names per dataset / architecture, listed by ensemble size ---------
CIFAR10_DCGAN = [
    "CIFAR10-DCGAN-ens-1",
    "CIFAR10-DCGAN-ens-2-ew3",
    "CIFAR10-DCGAN-ens-3-ew3",
    "CIFAR10-DCGAN-ens-5-ew3",
    "CIFAR10-DCGAN-ens-10-ew3",
]
CIFAR100_DCGAN = [
    "CIFAR100-DCGAN-ens-1",
    "CIFAR100-DCGAN-ens-2-ew",
    "CIFAR100-DCGAN-ens-3-ew",
    "CIFAR100-DCGAN-ens-5-ew",
    "CIFAR100-DCGAN-ens-10-ew",
]
CIFAR10_WGANGP = [
    "CIFAR10-WGAN-GP-ens-1",
    "CIFAR10-WGAN-GP-ens-2",
    "CIFAR10-WGAN-GP-ens-3-ew",
    "CIFAR10-WGAN-GP-ens-5-ew",
    "CIFAR10-WGAN-GP-ens-10-ew",
]
TImageNet_WGANGP = [
    "Tiny_ImageNet-WGAN-GP-ens-1",
    "Tiny_ImageNet-WGAN-GP-ens-2-ew",
    "Tiny_ImageNet-WGAN-GP-ens-3-ew",
    "Tiny_ImageNet-WGAN-GP-ens-5-ew",
    "Tiny_ImageNet-WGAN-GP-ens-10-ew",
]

# --- Group names of the weighting variants for the 5-member ensembles --------
CIFAR10_DCGAN_weightings = [
    "CIFAR10-DCGAN-ens-5-ew3",
    "CIFAR10-DCGAN-ens-5-normal",
    "CIFAR10-DCGAN-ens-5-uniform",
    "CIFAR10-DCGAN-ens-5-bernoulli2",
    "CIFAR10-DCGAN-ens-5-bernoulli_split",
    "CIFAR10-DCGAN-ens-5-ew3-gn",
    "CIFAR10-DCGAN-ens-5-bernoulli-gn",
    "CIFAR10-DCGAN-ens-5-soft-logits",
]
CIFAR100_DCGAN_weightings = [
    "CIFAR100-DCGAN-ens-5-ew",
    "CIFAR100-DCGAN-ens-5-normal",
    "CIFAR100-DCGAN-ens-5-uniform",
    "CIFAR100-DCGAN-ens-5-bernoulli",
    "CIFAR100-DCGAN-ens-5-bernoulli_split",
    "CIFAR100-DCGAN-ens-5-ew-gn",
    "CIFAR100-DCGAN-ens-5-bernoulli-gn",
    "CIFAR100-DCGAN-ens-5-soft-logits",
]
CIFAR10_WGANGP_weightings = [
    "CIFAR10-WGAN-GP-ens-5-ew",
    "CIFAR10-WGAN-GP-ens-5-normal",
    "CIFAR10-WGAN-GP-ens-5-uniform",
    "CIFAR10-WGAN-GP-ens-5-bernoulli",
    "CIFAR10-WGAN-GP-ens-5-bernoulli_split",
    "CIFAR10-WGAN-GP-ens-5-ew-gn",
    "CIFAR10-WGAN-GP-ens-5-bernoulli-gn",
]

# --- Predefined selections ----------------------------------------------------
STATISTIC_TEST = [
    "CIFAR10-DCGAN-ens-1",
    "CIFAR10-DCGAN-ens-5-ew3",
    "CIFAR100-DCGAN-ens-1",
    "CIFAR100-DCGAN-ens-5-ew",
    "CIFAR10-WGAN-GP-ens-1",
    "CIFAR10-WGAN-GP-ens-5-ew",
    "Tiny_ImageNet-WGAN-GP-ens-1",
    "Tiny_ImageNet-WGAN-GP-ens-5-ew",
]
PRIOR_WORK_COMPARISON = CIFAR10_DCGAN + CIFAR10_WGANGP + CIFAR100_DCGAN

# Selection that the reporting loop iterates over. Alternatives:
#   GROUPS = CIFAR10_DCGAN + CIFAR100_DCGAN + CIFAR10_WGANGP + TImageNet_WGANGP
#   GROUPS = CIFAR10_DCGAN_weightings + CIFAR100_DCGAN_weightings + CIFAR10_WGANGP_weightings
GROUPS = PRIOR_WORK_COMPARISON



# Print, for every selected group, its runs followed by the aggregated
# statistics of each metric.
for group in GROUPS:
    runs = get_filtered_runs(PROJECT, group)
    print("GROUP: ", group, f"({len(runs)})")
    for r in runs:
        print("\t", r.name, r.lastHistoryStep)

    for metric in METRICS:
        print("")
        print("METRIC: ", metric)
        is_fid = "FID" in metric
        # For metrics with "FID" in the name the per-run optimum is the minimum,
        # otherwise the maximum.
        metrics = compute_metric_stats(runs, metric, round_digits=3, use_max=not is_fid)
        # FID values are printed with one decimal, everything else with two.
        value_format = "{:.1f}" if is_fid else "{:.2f}"
        for key, value in metrics.items():
            value = value_format.format(value)
            print(f"{key}:\t\t{value}")

    print("")
    print("")

GROUP:  CIFAR10-DCGAN-ens-1 (5)
	 CIFAR10-DCGAN-ens-1-train-2024_05_08_18_13_01 200000
	 CIFAR10-DCGAN-ens-1-train-2024_05_10_17_41_30 200000
	 CIFAR10-DCGAN-ens-1-train-2024_05_14_18_34_08 200000
	 CIFAR10-DCGAN-ens-1-train-2024_05_14_18_35_16 200000
	 CIFAR10-DCGAN-ens-1-train-2024_05_15_12_58_28 200000

METRIC:  FID score
opt_min_value:		39.3
opt_mean:		42.9
opt_std:		2.7
opt_se :		1.2

METRIC:  IS score
opt_max_value:		6.988
opt_mean:		6.645
opt_std:		0.244
opt_se :		0.109


GROUP:  CIFAR10-DCGAN-ens-2-ew3 (5)
	 CIFAR10-DCGAN-ens-2-ew3-train-2024_06_06_15_21_25 200000
	 CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_06_13 200000
	 CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_01 200000
	 CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_23 200000
	 CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_27 200000

METRIC:  FID score
opt_min_value:		29.2
opt_mean:		32.6
opt_std:		2.8
opt_se :		1.3

METRIC:  IS score
opt_max_value:		7.357
opt_mean:		7.160
opt_std:		0.221
opt_se :		0.099


GROUP:  

In [None]:
# Download one metric for all runs of the selected groups and store it as CSV.

METRIC = "FID score"
GROUPS = ["CIFAR10-DCGAN-ens-1", "CIFAR10-DCGAN-ens-2-ew3", "CIFAR10-DCGAN-ens-3-ew3", "CIFAR10-DCGAN-ens-5-ew3", "CIFAR10-DCGAN-ens-10-ew3"]

PROJECT = "PyTorch-StudioGAN-src"

# Factor that turns the (1-based) row number into the step label of the CSV index.
# NOTE(review): assumes METRIC is logged every 2000 steps, first at step 2000 - confirm.
EVAL_INTERVAL = 2000
OUTPUT_DIR = "02"

# Retrieve all runs, in GROUPS order.
run_list = []
for group in GROUPS:
    run_list.extend(get_filtered_runs(PROJECT, group))

# Retrieve all histories and build the frame in a single step. Assigning the
# columns one by one to an empty DataFrame aligns every later history to the
# index of the first run, silently dropping rows of longer histories; the
# dict constructor keeps the union of all indices (NaN-padded) instead.
series_by_run = {
    run.name: run.history(samples=run.lastHistoryStep + 1, keys=[METRIC])[METRIC]
    for run in run_list
}
metric_df = pd.DataFrame(series_by_run)

metric_df.index = (metric_df.index + 1) * EVAL_INTERVAL

# Make sure the target directory exists before writing.
os.makedirs(OUTPUT_DIR, exist_ok=True)
metric_df.to_csv("{}/{}_{}.csv".format(OUTPUT_DIR, GROUPS[0], METRIC))

metric_df.shape

In [18]:
# Display the per-run metric table assembled by the CSV-export cell.
metric_df

Unnamed: 0,CIFAR10-DCGAN-ens-1-train-2024_05_08_18_13_01,CIFAR10-DCGAN-ens-1-train-2024_05_10_17_41_30,CIFAR10-DCGAN-ens-1-train-2024_05_14_18_34_08,CIFAR10-DCGAN-ens-1-train-2024_05_14_18_35_16,CIFAR10-DCGAN-ens-1-train-2024_05_15_12_58_28,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_06_15_21_25,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_01,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_23,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_27,...,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_10_31,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_01,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_10,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_34,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_10_31,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_21_27,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_21_34,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_07_20_38_07
0,102.480068,111.678438,103.170108,124.281023,139.023868,118.827773,88.389067,94.086444,102.807745,84.923108,...,92.390380,99.659494,88.037551,100.229986,84.826133,88.566955,89.883304,100.575085,80.811464,71.684488
2000,87.644259,77.551991,76.628881,80.042925,81.322032,81.798428,90.309244,64.203283,69.853090,59.784288,...,66.713225,88.328413,74.038965,70.682366,72.177672,57.712161,58.772100,75.221776,66.745850,56.060295
4000,73.882321,64.842894,68.829652,72.283578,74.078261,65.422611,64.662059,69.153628,57.316098,55.108173,...,62.193450,67.090907,51.699995,56.625845,56.627351,56.030074,53.836681,63.866022,61.148583,50.217165
6000,61.385311,66.129480,60.755017,60.959751,87.009582,65.975151,51.819082,51.740919,49.635588,63.535634,...,58.026170,58.155622,48.405428,59.262422,52.868863,42.130314,45.433722,57.897897,55.800221,49.076546
8000,57.919498,57.331446,60.253031,63.381205,60.784397,62.762812,59.455806,49.090874,52.566044,58.281898,...,53.617615,56.918616,52.873727,49.382249,54.761568,42.175410,46.753459,54.565713,52.790867,41.150798
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
190000,126.140172,126.639605,155.256489,142.804076,129.407830,111.040257,55.752426,63.030322,79.295913,49.376007,...,51.565573,45.832828,45.924838,42.536582,45.084595,41.455043,44.664939,41.561105,41.410136,44.224369
192000,135.273105,139.146250,129.050192,128.812495,126.152163,104.350868,59.746571,49.885633,86.227502,56.136488,...,47.375453,44.387541,41.943671,48.759176,45.473864,46.099812,42.655240,49.610694,41.573231,52.992120
194000,126.726919,137.131073,142.469990,121.876314,129.619378,99.127877,60.124235,48.220591,74.239774,47.725129,...,48.663203,57.710101,47.626539,44.015730,53.444591,46.110876,36.365403,46.006129,35.593099,47.154189
196000,138.613818,136.403206,138.243261,126.731342,125.200507,102.054059,65.132553,56.023354,81.755529,47.440771,...,49.986236,57.792208,49.326090,42.844304,47.948315,46.759536,39.993106,47.824417,37.521211,51.279072


Unnamed: 0,CIFAR10-DCGAN-ens-1-train-2024_05_08_18_13_01,CIFAR10-DCGAN-ens-1-train-2024_05_10_17_41_30,CIFAR10-DCGAN-ens-1-train-2024_05_14_18_34_08,CIFAR10-DCGAN-ens-1-train-2024_05_14_18_35_16,CIFAR10-DCGAN-ens-1-train-2024_05_15_12_58_28,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_06_15_21_25,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_01,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_23,CIFAR10-DCGAN-ens-2-ew3-train-2024_06_07_11_40_27,...,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_01,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_10,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_06_15_21_34,CIFAR10-DCGAN-ens-5-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_10_31,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_21_27,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_06_15_21_34,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_07_11_06_13,CIFAR10-DCGAN-ens-10-ew3-train-2024_06_07_20_38_07,step
0,102.480068,111.678438,103.170108,124.281023,139.023868,118.827773,88.389067,94.086444,102.807745,84.923108,...,99.659494,88.037551,100.229986,84.826133,88.566955,89.883304,100.575085,80.811464,71.684488,20000
1,87.644259,77.551991,76.628881,80.042925,81.322032,81.798428,90.309244,64.203283,69.853090,59.784288,...,88.328413,74.038965,70.682366,72.177672,57.712161,58.772100,75.221776,66.745850,56.060295,40000
2,73.882321,64.842894,68.829652,72.283578,74.078261,65.422611,64.662059,69.153628,57.316098,55.108173,...,67.090907,51.699995,56.625845,56.627351,56.030074,53.836681,63.866022,61.148583,50.217165,60000
3,61.385311,66.129480,60.755017,60.959751,87.009582,65.975151,51.819082,51.740919,49.635588,63.535634,...,58.155622,48.405428,59.262422,52.868863,42.130314,45.433722,57.897897,55.800221,49.076546,80000
4,57.919498,57.331446,60.253031,63.381205,60.784397,62.762812,59.455806,49.090874,52.566044,58.281898,...,56.918616,52.873727,49.382249,54.761568,42.175410,46.753459,54.565713,52.790867,41.150798,100000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,126.140172,126.639605,155.256489,142.804076,129.407830,111.040257,55.752426,63.030322,79.295913,49.376007,...,45.832828,45.924838,42.536582,45.084595,41.455043,44.664939,41.561105,41.410136,44.224369,1920000
96,135.273105,139.146250,129.050192,128.812495,126.152163,104.350868,59.746571,49.885633,86.227502,56.136488,...,44.387541,41.943671,48.759176,45.473864,46.099812,42.655240,49.610694,41.573231,52.992120,1940000
97,126.726919,137.131073,142.469990,121.876314,129.619378,99.127877,60.124235,48.220591,74.239774,47.725129,...,57.710101,47.626539,44.015730,53.444591,46.110876,36.365403,46.006129,35.593099,47.154189,1960000
98,138.613818,136.403206,138.243261,126.731342,125.200507,102.054059,65.132553,56.023354,81.755529,47.440771,...,57.792208,49.326090,42.844304,47.948315,46.759536,39.993106,47.824417,37.521211,51.279072,1980000


In [5]:
# Scratch cell: iterate over every run of PROJECT; after the loop `test`
# holds the last run that was yielded (it stays undefined if there are none).
# NOTE(review): relies on PROJECT having been defined by another cell.
api = wandb.Api()
runs = api.runs(PROJECT)
for r in runs:
    test = r
