In [None]:
# Change directory to the root so that relative path loads work correctly.
# NOTE: this moves up one level each time the cell is executed, so re-running
# it without restarting the kernel walks further up the directory tree.
import os

try:
    os.chdir(os.path.join(os.getcwd(), ".."))
    print(os.getcwd())
except OSError as error:
    # Best effort: keep working from the current directory, but report the
    # failure. A bare ``except`` would also hide KeyboardInterrupt and
    # genuine programming errors.
    print(f"Could not change to the parent directory: {error}")

In [None]:
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

from experiments.A_proof_of_constraint.main import build_model_and_optimizer, get_data
from experiments.A_proof_of_constraint.visualize import (
    plot_constraints,
    plot_loss,
    plot_model_predictions,
    plot_time,
)

In [None]:
def get_model_predictions(checkpoints):
    """Evaluate every checkpoint's model on a fixed test parameterization.

    For each checkpoint a model is rebuilt from ``checkpoint["configuration"]``,
    its weights are restored from ``checkpoint["model_state_dict"]``, and it is
    evaluated on the test loader returned by ``get_data`` for a copy of the
    configuration whose ``"testing_parameterizations"`` entry is replaced by
    the fixed parameterization defined below.

    Inputs and targets are collected from the first checkpoint's test loader
    only and reused for all checkpoints.
    NOTE(review): this assumes ``get_data`` yields the same points in the same
    order for every checkpoint -- confirm against ``get_data``.

    :param checkpoints: non-empty sequence of checkpoint dictionaries with the
        keys ``"configuration"`` and ``"model_state_dict"``
    :returns: ``(inputs, outputs, predictions)`` as numpy arrays, all ordered
        by ascending input; ``predictions`` has one row per checkpoint
    :raises IndexError: if ``checkpoints`` is empty
    """
    # len() rather than truthiness: callers may pass a numpy object array.
    if len(checkpoints) == 0:
        raise IndexError("get_model_predictions requires at least one checkpoint")

    # Retrieve the testing data of these parameterizations
    parameterizations = {
        "amplitudes": [1.0],
        "frequencies": [1.0],
        "phases": [0.0],
        "num_points": 500,
        "sampling": "uniform",
    }

    inputs = list()
    outputs = list()
    predictions = list()
    for index, checkpoint in enumerate(checkpoints):
        # The optimizer is not needed for inference
        model, _ = build_model_and_optimizer(checkpoint["configuration"])
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        modified_configuration = checkpoint["configuration"].copy()
        modified_configuration["testing_parameterizations"] = parameterizations
        _, test_dl = get_data(modified_configuration)

        preds = list()
        for (xb, params), yb in test_dl:
            if index == 0:
                # Keep a single copy of the data, taken from the first checkpoint.
                # atleast_1d keeps a one-element batch iterable after squeeze().
                inputs.extend(np.atleast_1d(xb.squeeze().detach().numpy()))
                outputs.extend(np.atleast_1d(yb.squeeze().detach().numpy()))
            # Only the model belonging to this checkpoint contributes to its
            # row; every row therefore has one prediction per test point.
            preds.extend(np.atleast_1d(model(xb, params).squeeze().detach().numpy()))
        predictions.append(preds)

    # Sort everything by ascending input
    order = np.argsort(inputs)
    inputs = np.array(inputs)[order]
    outputs = np.array(outputs)[order]
    predictions = np.array(predictions)[:, order]
    return inputs, outputs, predictions

In [None]:
def convert_name_to_filename(model_name):
    """Return ``model_name`` in lower case with every space turned into a hyphen."""
    return "-".join(model_name.split(" ")).lower()

In [None]:
def get_model_name(checkpoint):
    """Build the display label ``"<method> <activation> (<epoch>)"`` of a checkpoint.

    The method and activation are read from ``checkpoint["configuration"]``
    and the epoch from ``checkpoint["epoch"]``.
    """
    configuration = checkpoint["configuration"]
    training_method = configuration["method"]
    # Drop the last two characters of the activation's string form
    # (presumably the "()" of a module repr such as "Tanh()" -- verify).
    activation_label = str(configuration["model_act"])[:-2]
    return "{} {} ({})".format(training_method, activation_label, checkpoint["epoch"])

In [None]:
def get_special_model_name(checkpoint, filename):
    """Return a hand-picked label for checkpoints from a few known runs.

    Runs are recognised by a timestamp fragment in ``filename``. Their label
    is tagged "Before" when ``checkpoint["epoch"]`` is at most 200 and "After"
    otherwise. Any other file falls back to ``get_model_name``.
    """
    if checkpoint["epoch"] <= 200:
        phase = "Before"
    else:
        phase = "After"

    # Checked in this order; the first timestamp found in the filename wins.
    known_runs = (
        ("2019-08-05-15-07", "Constrained"),
        ("2019-08-05-15-05", "Reduction"),
        ("2019-08-05-14-38", "No-loss"),
    )
    for timestamp, target in known_runs:
        if timestamp in filename:
            return f"Unconstrained --> {target} ({phase})"
    return get_model_name(checkpoint)

In [None]:
# Where results are written and which checkpoint files are read.
experiment_name = "A_proof_of_concept"
# Figures are saved below this absolute, user-specific directory.
save_directory = "/global/u1/g/gelijerg/Projects/pyinsulate/results/{}/".format(
    experiment_name
)
# Relative to the working directory selected in the first cell.
load_directory = os.path.expandvars("results/checkpoints/")
# Glob patterns, resolved inside ``load_directory``.
checkpoint_patterns = [
    "proof-of-constraint_2019-08-06-11-37-43_00400.pth",
]

In [None]:
# Load files
files = list()
for pattern in checkpoint_patterns:
    # os.path.join avoids the doubled separator an f-string would produce when
    # load_directory already ends with a slash.
    files.extend(glob.glob(os.path.join(load_directory, pattern)))
files.sort()
if not files:
    # Nothing matched: say so here rather than failing later with an obscure
    # error on an empty checkpoint list.
    print(f"No checkpoint files matched {checkpoint_patterns} in {load_directory}")
# map_location="cpu" lets checkpoints saved on a GPU be opened on a CPU-only
# machine; the predictions are converted with .numpy(), which needs CPU tensors.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): the checkpoints also store monitor objects; recent PyTorch
# releases that default to weights_only=True may refuse them -- confirm.
checkpoints = [torch.load(f, map_location="cpu") for f in files]
# model_names = [get_model_name(checkpoint) for checkpoint in checkpoints]
model_names = [
    get_special_model_name(checkpoint, filename)
    for checkpoint, filename in zip(checkpoints, files)
]
# Make sure directory to save exists
os.makedirs(save_directory, exist_ok=True)

In [None]:
# Do some plotting: first group the loaded checkpoints.
def _select_checkpoints(keep):
    """Return the (checkpoint, model_name) pairs whose checkpoint satisfies ``keep``."""
    return [
        (candidate, label)
        for candidate, label in zip(checkpoints, model_names)
        if keep(candidate)
    ]


def _select_method(method):
    """Return the (checkpoint, model_name) pairs trained with ``method``."""
    return _select_checkpoints(
        lambda candidate: candidate["configuration"]["method"] == method
    )


max_epoch = max([checkpoint["epoch"] for checkpoint in checkpoints])
# Checkpoints at the last recorded epoch, plus any saved at epoch 190
final_checkpoints = _select_checkpoints(
    lambda candidate: candidate["epoch"] == max_epoch or candidate["epoch"] == 190
)
# One group per training method
approximation_checkpoints = _select_method("approximate")
reduction_checkpoints = _select_method("reduction")
unconstrained_checkpoints = _select_method("unconstrained")
constrained_checkpoints = _select_method("constrained")
soft_constrained_checkpoints = _select_method("soft-constrained")


# Each task is a (title, [(checkpoint, model_name), ...]) pair. The title,
# converted to a filename, decides which group of plots is produced below.
tasks = [("Final Models", final_checkpoints)]

for task_name, task in tasks:
    print(task_name)
    if len(task) == 0:
        print(f"Nothing for task {task_name}")
        continue
    task_checkpoints = [x[0] for x in task]
    task_model_names = [x[1] for x in task]
    task_monitors = [checkpoint["monitors"] for checkpoint in task_checkpoints]
    # Collect the timing keys to plot from the first timing record of the
    # first monitor of every checkpoint.
    time_keys = set()
    for monitors in task_monitors:
        time_keys.update(
            [
                key
                for key in monitors[0].timing[0][0].keys()
                if ("multipliers" in key or "total" in key or "compute" in key)
                and "error" not in key
                and "recomputed" not in key
            ]
        )
    # Sorted so the key order (and with it the plot) is the same on every
    # run; iterating a set of strings has no stable order across sessions.
    time_keys = sorted(time_keys)
    task_filename = convert_name_to_filename(task_name)

    if "final" in task_filename:
        # Comparison across models: timing uses monitors[0], the loss and
        # constraint curves use monitors[1].
        fig = plot_time(
            [monitors[0] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_compute-time",
            time_keys=time_keys,
            log=True,
            directory=save_directory,
        )

        fig = plot_loss(
            [(monitors[1], None) for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_constrained-loss",
            title="Constrained Loss",
            log=True,
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            [(monitors[1], None) for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_constraint",
            title="Constraint Magnitude",
            log=True,
            directory=save_directory,
        )
        fig = plot_loss(
            [(monitors[1], None) for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_loss",
            title="Loss",
            log=True,
            directory=save_directory,
        )

        inputs, outputs, predictions = get_model_predictions(task_checkpoints)
        fig = plot_model_predictions(
            inputs,
            outputs,
            predictions,
            task_model_names,
            f"{task_filename}_predictions",
            title="Model Predictions",
            directory=save_directory,
        )

    if (
        "approximation" in task_filename
        or "reduction" in task_filename
        or "constrained" in task_filename
    ):
        # Single-method task: order the checkpoints by epoch. Plain list
        # indexing is used instead of np.array(...) so that the dictionaries
        # and monitor containers are never coerced into a numpy array.
        idxs = np.argsort([checkpoint["epoch"] for checkpoint in task_checkpoints])
        task_checkpoints = [task_checkpoints[i] for i in idxs]
        task_model_names = [task_model_names[i] for i in idxs]
        task_monitors = [task_monitors[i] for i in idxs]

        # Curves are drawn from the monitors of the latest checkpoint only
        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_constrained-loss",
            title=f"{task_name} Constrained Loss",
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_constraint",
            title=f"{task_name} Constraint Magnitude",
            directory=save_directory,
        )
        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_loss",
            title=f"{task_name} Loss",
            directory=save_directory,
        )

        # Pick a geometrically spaced subset of checkpoints (indices
        # int(1.5 ** k)), always including the first and the last one.
        extra_idxs = np.power(1.5, np.arange(len(task_checkpoints))).astype(int)
        extra_idxs = extra_idxs[extra_idxs < len(task_checkpoints)]
        limited_idxs = np.unique(np.r_[0, extra_idxs, len(task_checkpoints) - 1])
        limited_checkpoints = [task_checkpoints[i] for i in limited_idxs]

        inputs, outputs, predictions = get_model_predictions(limited_checkpoints)
        fig = plot_model_predictions(
            inputs,
            outputs,
            predictions,
            [f'Epoch {checkpoint["epoch"]}' for checkpoint in limited_checkpoints],
            f"{task_filename}_predictions",
            title=task_name,
            directory=save_directory,
        )

    # # close all those figures
    # plt.close("all")