In [None]:
# Change directory to the repository root so that relative path loads work correctly.
# Guarded so that re-running this cell does not keep walking further up the tree.
import os

# NOTE(review): assumes the notebook lives one directory below the repo root and
# that the root contains the `experiments` package imported in the next cell — confirm.
try:
    if not os.path.isdir("experiments"):
        os.chdir(os.pardir)
    print(os.getcwd())
except OSError as err:
    # Best effort: report the problem instead of silently swallowing it.
    print(f"Could not switch to the repository root: {err}")

In [None]:
import sys

import numpy as np
import torch

from experiments.A_proof_of_constraint.main import build_model_and_optimizer, get_data
from experiments.A_proof_of_constraint.visualize import (
    plot_constraints,
    plot_loss,
    plot_model_predictions,
    plot_time,
)

In [None]:
def get_model_predictions(checkpoints):
    """Evaluate every checkpointed model on the combined train and test data.

    The data loaders are rebuilt from the configuration of the *first*
    checkpoint only, so all checkpoints are assumed to share the same dataset.

    Args:
        checkpoints: non-empty sequence of checkpoint dicts, each providing
            ``"configuration"`` and ``"model_state_dict"``.

    Returns:
        ``((inputs, outputs, is_training), predictions)``:
        - ``inputs``, ``outputs`` and the boolean ``is_training`` mask are 1-D
          arrays sorted by input value.
        - ``predictions`` is a 2-D array with one row per checkpoint, whose
          columns follow the same sorted order.
    """
    # Retrieve the data of the first checkpoint (parameterization/equation are unused here)
    train_dl, test_dl, _parameterization, _equation = get_data(
        checkpoints[0]["configuration"], return_equation=True
    )

    # Rebuild every model from its checkpoint and switch it to eval mode once
    models = []
    for checkpoint in checkpoints:
        model, _optimizer = build_model_and_optimizer(checkpoint["configuration"])
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        models.append(model)

    def _to_flat_numpy(tensor):
        # reshape(-1) instead of squeeze(): a batch holding a single sample would
        # otherwise collapse to a 0-d array, which cannot be iterated or len()'d.
        # NOTE(review): assumes one scalar input/target/prediction per sample — confirm.
        return tensor.detach().reshape(-1).cpu().numpy()

    inputs, outputs, is_training = [], [], []
    predictions = [[] for _ in models]
    for loader, from_training_set in ((train_dl, True), (test_dl, False)):
        for xb, yb in loader:
            batch_inputs = _to_flat_numpy(xb)
            inputs.extend(batch_inputs)
            outputs.extend(_to_flat_numpy(yb))
            is_training.extend([from_training_set] * len(batch_inputs))
            for model_predictions, model in zip(predictions, models):
                model_predictions.extend(_to_flat_numpy(model(xb)))

    # Sort all arrays consistently by input value
    order = np.argsort(inputs)
    inputs = np.asarray(inputs)[order]
    outputs = np.asarray(outputs)[order]
    is_training = np.asarray(is_training)[order]
    predictions = np.asarray(predictions)[:, order]
    return (inputs, outputs, is_training), predictions

In [None]:
# Files to load.
# The results root defaults to the original location but can be overridden with
# the PYINSULATE_RESULTS_DIR environment variable so the notebook runs elsewhere.
experiment_name = "test"
results_root = os.environ.get(
    "PYINSULATE_RESULTS_DIR", "/global/u1/g/gelijerg/Projects/pyinsulate/results"
)
# Trailing "" keeps the trailing separator of the original directory strings
save_directory = os.path.join(results_root, experiment_name, "")
load_directory = os.path.join(results_root, "checkpoints", "")

# Checkpoints of the unconstrained run were referenced at every 5 epochs from 5
# to 50; list here the epochs to compare against the constrained runs.
UNCONSTRAINED_RUN = "proof-of-constraint_2019-07-31-15-27-51"
UNCONSTRAINED_EPOCHS = (5, 25, 50)

files_and_names = [
    ("proof-of-constraint_2019-07-31-13-29-01_00002.pth", "Average"),
    ("proof-of-constraint_2019-07-31-13-29-26_00002.pth", "Batchwise"),
] + [
    (f"{UNCONSTRAINED_RUN}_{epoch:05d}.pth", f"Unconstrained ({epoch})")
    for epoch in UNCONSTRAINED_EPOCHS
]
files = [file for file, _ in files_and_names]
model_names = [name for _, name in files_and_names]

In [None]:
# Load files onto the CPU so checkpoints possibly saved from a GPU node can be
# opened on a CPU-only machine (load_state_dict later copies to the model's device).
# NOTE(review): the checkpoints appear to hold pickled monitor objects (see the
# "monitors" entry used below); with torch >= 2.6, torch.load defaults to
# weights_only=True and may reject them — pass weights_only=False there, and
# only for trusted files.
checkpoints = [
    torch.load(os.path.join(load_directory, file), map_location="cpu")
    for file in files
]
filenames = [os.path.splitext(file)[0] for file in files]
# Make sure directory to save exists
os.makedirs(save_directory, exist_ok=True)

In [None]:
# Do some plotting.
# Each checkpoint's "monitors" entry is unpacked as
# (training, evaluation-on-train, evaluation-on-test).
all_monitors = [checkpoint["monitors"] for checkpoint in checkpoints]
training_monitors = [monitors[0] for monitors in all_monitors]
evaluation_train_monitors = [monitors[1] for monitors in all_monitors]
evaluation_test_monitors = [monitors[2] for monitors in all_monitors]
# Loss/constraint plots receive the evaluation monitors (everything after the training one)
evaluation_monitors = [monitors[1:] for monitors in all_monitors]

# Timing keys related to multipliers / totals / compute, excluding any "error" keys
time_keys = [
    key
    for key in training_monitors[0].time_keys
    if ("multipliers" in key or "total" in key or "compute" in key)
    and "error" not in key
]
fig = plot_time(
    training_monitors,
    model_names,
    "compute-time",
    time_keys=time_keys,
    log=True,
    directory=save_directory,
)

fig = plot_loss(
    evaluation_monitors,
    model_names,
    "constrained-loss",
    directory=save_directory,
    constrained=True,
)
fig = plot_constraints(
    evaluation_monitors, model_names, "constraint", directory=save_directory
)
fig = plot_loss(evaluation_monitors, model_names, "loss", directory=save_directory)

data, predictions = get_model_predictions(checkpoints)
fig = plot_model_predictions(
    data, predictions, model_names, "predictions", directory=save_directory
)

# for checkpoint, filename in zip(checkpoints, filenames):

#     training_monitor, evaluation_train_monitor, evaluation_test_monitor = checkpoint[
#         "monitors"
#     ]

#     print(checkpoint["configuration"])
#     #     print(training_monitor)
#     #     print(evaluation_test_monitor.mean_loss)
#     plot_loss(
#         [evaluation_train_monitor, evaluation_test_monitor],
#         ["Training", "Testing"],
#         f"training-loss_{filename}",
#     )
#     plot_loss(
#         [evaluation_train_monitor, evaluation_test_monitor],
#         ["Training", "Testing"],
#         f"training-constrained-loss_{filename}",
#         title="Constrained losses",
#         constrained=True,
#     )
#     plot_constraints(
#         [evaluation_train_monitor, evaluation_test_monitor],
#         ["Training", "Testing"],
#         f"training-constraint_{filename}",
#     )