In [None]:
# Change directory to the repository root so that relative-path loads work
# correctly (the `experiments.*` imports in the next cell and the
# "results/checkpoints" save directory used by the run cell).
import os

# Guard so re-running this cell does not keep climbing the directory tree:
# only move up when the `experiments` package imported in the next cell is not
# already visible from the current working directory.
# NOTE(review): assumes the notebook lives exactly one level below the repo root — confirm.
if not os.path.isdir("experiments"):
    try:
        os.chdir(os.path.join(os.getcwd(), ".."))
    except OSError as exc:
        # Best-effort: report the failure rather than silently swallowing it
        # (a bare `except:` would also hide KeyboardInterrupt / SystemExit).
        print(f"Could not change to parent directory: {exc}")
print(os.getcwd())

In [None]:
from datetime import datetime

import torch

from experiments.A_proof_of_constraint.experiment_definition import dictionary_product
from experiments.A_proof_of_constraint.main import run_experiment

In [None]:
# Saving utilities
def get_savefile(base_name="proof-of-constraint", now=None):
    """Build a timestamped filename for a run summary.

    Args:
        base_name: Prefix of the filename. Defaults to "proof-of-constraint".
        now: ``datetime`` used for the timestamp. Defaults to the current local
            time. Exposed so callers (and tests) can get a deterministic name.

    Returns:
        A string of the form ``"{base_name}_YYYY-MM-DD-HH-MM-SS.pth"``.
        The timestamp has one-second resolution, so two calls within the same
        second return the same name; callers must de-duplicate if that matters.
    """
    if now is None:
        now = datetime.now()
    time_string = now.strftime("%Y-%m-%d-%H-%M-%S")
    return f"{base_name}_{time_string}.pth"


def save_out(
    summary, savefile, directory="/global/u1/g/gelijerg/Projects/pyinsulate/results"
):
    """Serialize ``summary`` with ``torch.save`` to ``directory/savefile``.

    Args:
        summary: Any object ``torch.save`` can serialize (in this notebook, a
            dict of the run configuration and monitors).
        savefile: File name (not a path) to create inside ``directory``.
        directory: Destination directory; created if it does not exist.
            NOTE(review): the default is an absolute, user-specific path, while
            checkpoints in this notebook go to the relative "results/checkpoints";
            pass ``directory`` explicitly when running on another machine.

    Returns:
        The full path that was written.
    """
    # torch.save raises if the parent directory is missing; create it up front.
    os.makedirs(directory, exist_ok=True)
    full_file = os.path.join(directory, savefile)
    print(f"Saving to file {full_file}")
    torch.save(summary, full_file)
    return full_file

In [None]:
# Defaults shared by every run; each revision below overrides some of these keys.
base_configuration = {
    "training_sampling": "uniform",
    "num_points": 1000,
    "num_training": 500,
    "batch_size": 100,
    "model_size": [20, 20, 20],
    "method": "average",
    "learning_rate": 1e-3,  # 1e-2 works for unconstrained well
}

# Sweep definition. Alternatives explored previously (add as keyword args to sweep):
#   method=["unconstrained", "average", "batchwise", "no-loss"]
#   learning_rate=[1e-2, 1e-3, 1e-4, 1e-5]
#   training_sampling=["start", "uniform", "random"]
configuration_revisions = list(
    dictionary_product(
        method=["average"],
    )
)

# One full configuration per revision: base values first, revision values win.
# (Shallow merge, like dict.copy() + update(): nested lists are shared.)
configurations = [
    {**base_configuration, **revision} for revision in configuration_revisions
]

num_epochs = 2

In [None]:
# Run every configuration: each run writes periodic checkpoints under
# `checkpoint_directory` and a final summary via `save_out`.
checkpoint_directory = "results/checkpoints"  # relative to the cwd set in the first cell
all_savefiles = list()
final_checkpoints = list()
for configuration in configurations:
    savefile = get_savefile()
    # get_savefile() has one-second resolution; if a previous run in this loop
    # finished within the same second, disambiguate so neither the summary nor
    # the checkpoints (whose names derive from it) are silently overwritten.
    if savefile in all_savefiles:
        stem, extension = os.path.splitext(savefile)
        suffix = 1
        while f"{stem}-{suffix}{extension}" in all_savefiles:
            suffix += 1
        savefile = f"{stem}-{suffix}{extension}"
    all_savefiles.append(savefile)
    print(f"Running proof of constraint with savefile {savefile}")
    checkpoint_save_file_base = os.path.splitext(savefile)[0]
    # NOTE(review): assumes run_experiment names its last checkpoint
    # "{save_file}_{num_epochs}.pth" inside save_directory — confirm.
    final_checkpoints.append(
        os.path.join(
            checkpoint_directory, f"{checkpoint_save_file_base}_{num_epochs}.pth"
        )
    )
    final_result = run_experiment(
        num_epochs,
        log=print,
        save_directory=checkpoint_directory,
        save_file=checkpoint_save_file_base,
        save_interval=100,
        **configuration,
    )
    print(f"Completed run with savefile {savefile}")
    # Unpack into a new name instead of rebinding the loop variable `configuration`.
    (
        run_configuration,
        (trainer, train_evaluator, test_evaluator),
        (training_monitor, evaluation_train_monitor, evaluation_test_monitor),
    ) = final_result

    save_out(
        {
            "configuration": run_configuration,
            "training_monitor": training_monitor,
            "evaluation_train_monitor": evaluation_train_monitor,
            "evaluation_test_monitor": evaluation_test_monitor,
        },
        savefile=savefile,
    )
print(f"\nFiles were saved to {all_savefiles}")
print(f"\nCheckpoints were saved to {final_checkpoints}")

In [None]:
# Sentinel output: seeing this confirms Run All reached the final cell without raising.
print("done!", end="\n")