In [None]:
%load_ext autoreload
%autoreload 2
%env CUBLAS_WORKSPACE_CONFIG=:4096:8
from GraPL import hyperparameter_profiles, segment_bsds
import glob
import tqdm
import matplotlib.pyplot as plt
from GraPL.evaluate import bsds_score_directory
import numpy as np
import json
import os

In [None]:
def _graph_variant_name(change):
    """Return the results label for one dict of hyperparameter overrides."""
    # Flags are checked in priority order; only the presence of the key
    # matters, not its value. No recognised flag means the baseline graph.
    for flag, label in (
        ("use_fully_connected", "fully_connected"),
        ("use_color_distance_weights", "color_distance"),
        ("use_embeddings", "dino"),
    ):
        if flag in change:
            return label
    return "local"


def _mean_per_metric(score_dicts):
    """Average every metric across a sequence of ``{metric: value}`` dicts.

    The first entry defines the metric set, so every other entry must carry
    the same keys. An empty sequence yields an empty dict.
    """
    score_dicts = list(score_dicts)
    if not score_dicts:
        return {}
    return {
        metric: np.mean([scores[metric] for scores in score_dicts])
        for metric in score_dicts[0]
    }


def run_experiment(base_params, changes, num_trials=2, images_per_trial=200):
    """Run ``segment_bsds`` for each graph variant and average its scores.

    For every entry of ``changes`` the overrides are layered on top of
    ``base_params``, ``segment_bsds`` is run ``num_trials`` times with
    ``seed`` set to the trial index, the per-image scores of each trial are
    averaged, and those trial means are averaged again.

    Args:
        base_params: Mapping of keyword arguments for ``segment_bsds``.
            It is never modified.
        changes: Iterable of override dicts, one per graph variant. The
            variant label is derived from which ``use_*`` key is present.
        num_trials: Number of seeded repetitions per variant (at least 1).
        images_per_trial: Progress-bar units expected per trial; presumably
            the number of images ``segment_bsds`` processes -- confirm
            against the dataset in use.

    Returns:
        Dict mapping each variant label to a ``{metric: mean}`` dict.

    Raises:
        ValueError: If ``num_trials`` is smaller than 1.
    """
    # Fail fast instead of hitting an IndexError after the (empty) trial loop.
    if num_trials < 1:
        raise ValueError(f'num_trials must be at least 1, got {num_trials}')

    results_base_dir = 'experiment_results/graph_weights'
    aggregate_scores = {}
    for change in changes:
        change_name = _graph_variant_name(change)
        print(f'Running experiment with graph={change_name}...')
        results_dir = f'{results_base_dir}/{change_name}'
        trials = []
        # One bar spans all trials of this variant; it is handed to
        # segment_bsds, which presumably advances it per image.
        with tqdm.tqdm(total=num_trials * images_per_trial) as progress_bar:
            for trial_num in range(num_trials):
                # Fresh dict per trial: base values, then overrides, then seed.
                params = {**base_params, **change, 'seed': trial_num}
                per_image_scores = segment_bsds(
                    results_dir=f'{results_dir}/{trial_num}',
                    progress_bar=progress_bar,
                    **params,
                )
                # per_image_scores maps image id -> {metric: value}.
                trials.append(_mean_per_metric(per_image_scores.values()))
        mean_over_trials = _mean_per_metric(trials)
        aggregate_scores[change_name] = mean_over_trials
        # .get() keeps a missing metric from aborting the sweep on a log line.
        print(f'Scores for change={change_name}: {mean_over_trials.get("accuracy")}')
    return aggregate_scores

# Build a fresh dict rather than aliasing the shared profile: assigning the
# flags directly on hyperparameter_profiles["best_miou"] would silently change
# that profile for every later user in the same session.
base_params = {
    **hyperparameter_profiles["best_miou"],
    # Baseline graph: all alternative graph constructions switched off.
    "use_fully_connected": False,
    "use_color_distance_weights": False,
    "use_embeddings": False,
}

# One entry per graph variant; the empty dict is the unmodified baseline.
changes = [
    {},
    {"use_fully_connected": True},
    {"use_color_distance_weights": True},
    {"use_embeddings": True},
]

aggregate_scores = run_experiment(base_params, changes, num_trials=10)