In [None]:
# Change directory to the root so that relative path loads work correctly
import os

try:
    # Step one level up from the current working directory so that the
    # relative paths used later in the notebook resolve from there.
    # Note: this is relative to the *current* directory, so re-running this
    # cell moves up one more level each time.
    os.chdir(os.path.join(os.getcwd(), ".."))
    print(os.getcwd())
except OSError as error:
    # Best effort: stay where we are, but report the failure instead of
    # silently swallowing it (a bare ``except`` would also hide
    # KeyboardInterrupt and programming errors).
    print(f"Could not change to the parent directory: {error}")

In [None]:
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

from experiments.A_proof_of_constraint.main import build_model_and_optimizer, get_data
from experiments.A_proof_of_constraint.visualize import (
    plot_constraints,
    plot_loss,
    plot_model_predictions,
    plot_time,
)

In [None]:
def get_model_predictions(checkpoints):
    """Evaluate the model of every checkpoint on the training and test data.

    The data loaders are built from the configuration of the *first*
    checkpoint only; every model is then evaluated on those same batches.

    Args:
        checkpoints: non-empty sequence of checkpoint dicts, each providing
            the keys ``"configuration"`` and ``"model_state_dict"``.

    Returns:
        A pair ``((inputs, outputs, is_training), predictions)`` where
        ``inputs``, ``outputs`` and ``is_training`` are arrays ordered by
        ascending input value, and ``predictions`` is an array with one row
        per checkpoint whose columns follow the same ordering.

    NOTE(review): the sort below assumes one scalar input per sample
    (e.g. batches of shape ``(batch, 1)``) -- confirm against ``get_data``.
    """

    def _as_1d(tensor):
        # ``squeeze()`` turns a batch holding a single sample into a 0-d
        # tensor, which has no ``len()`` and cannot be passed to
        # ``list.extend``; ``atleast_1d`` restores the sample axis.
        return np.atleast_1d(tensor.squeeze().detach().numpy())

    # Retrieve the data and equation of the first checkpoint
    train_dl, test_dl, _parameterization, _equation = get_data(
        checkpoints[0]["configuration"], return_equation=True
    )

    # Rebuild each model and restore its weights; the optimizer is not needed.
    models = list()
    for checkpoint in checkpoints:
        model, _ = build_model_and_optimizer(checkpoint["configuration"])
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()  # evaluation mode, set once per model
        models.append(model)

    # Collect the data and every model's predictions, training batches first
    inputs = list()
    outputs = list()
    is_training = list()
    predictions = [list() for _ in models]
    for dataloader, from_training in ((train_dl, True), (test_dl, False)):
        for xb, yb in dataloader:
            batch_inputs = _as_1d(xb)
            inputs.extend(batch_inputs)
            outputs.extend(_as_1d(yb))
            is_training.extend([from_training] * len(batch_inputs))
            for i, model in enumerate(models):
                predictions[i].extend(_as_1d(model(xb)))

    # Sort everything by input value so the results can be drawn as curves
    idxs = np.argsort(inputs)
    inputs = np.array(inputs)[idxs]
    outputs = np.array(outputs)[idxs]
    is_training = np.array(is_training)[idxs]
    predictions = np.array(predictions)[:, idxs]
    return (inputs, outputs, is_training), predictions

In [None]:
def convert_name_to_filename(model_name):
    """Turn a display name into a filename stem.

    The name is lower-cased and every space character becomes a hyphen.
    """
    lowered = model_name.lower()
    return "-".join(lowered.split(" "))

In [None]:
def get_model_name(checkpoint):
    """Build a short display name for a checkpoint.

    The name has the form ``"<method> <activation> (<task> <epoch>)"`` where
    ``<task>`` is ``"Ext"`` when the configured ``training_sampling`` is
    ``"start"`` and ``"Int"`` otherwise.

    Args:
        checkpoint: dict with an ``"epoch"`` entry and a ``"configuration"``
            dict holding ``"training_sampling"``, ``"method"`` and
            ``"model_act"``.

    Returns:
        The display name as a string.
    """
    config = checkpoint["configuration"]
    sampling = config["training_sampling"]
    method = config["method"]
    model_act = config["model_act"]
    epoch = checkpoint["epoch"]
    task = "Ext" if sampling == "start" else "Int"
    # Keep only the text before the first "(": "Tanh()" -> "Tanh".  Slicing
    # off the last two characters only works for an empty argument list and
    # mangles reprs such as "LeakyReLU(negative_slope=0.01)".
    activation = str(model_act).split("(", 1)[0]
    return f"{method} {activation} ({task} {epoch})"

In [None]:
# Where to read checkpoints from and where to write output to.
experiment_name = "A_proof_of_concept"

# Output location, one sub-directory per experiment name.
# NOTE(review): absolute, user-specific path -- adjust before running elsewhere.
save_directory = f"/global/u1/g/gelijerg/Projects/pyinsulate/results/{experiment_name}/"

# Checkpoints are read relative to the current working directory.  The path
# contains no "$VAR" references, so no environment expansion is needed here.
load_directory = "results/checkpoints/"
checkpoint_patterns = ["proof-of-constraint_2019-08-02-15-34-*_0???0.pth"]

# Alternative source kept for reference (expands $SCRATCH, loads every file):
# load_directory = os.path.expandvars(
#     "$SCRATCH/clones/20190731-160459/pyinsulate/results/checkpoints/"
# )
# checkpoint_pattern = "*.pth"

In [None]:
# Load files
files = list()
for pattern in checkpoint_patterns:
    # glob returns matches in arbitrary order; sort them so that the order of
    # checkpoints (and hence of the plotted series) is the same on every run.
    files.extend(sorted(glob.glob(os.path.join(load_directory, pattern))))
if not files:
    # Make an empty match visible instead of silently producing no plots.
    print(f"No checkpoints matched {checkpoint_patterns} in {load_directory}")
# NOTE(review): torch.load unpickles the file contents, so only load
# checkpoints from a trusted source.
checkpoints = [torch.load(f) for f in files]
model_names = [get_model_name(checkpoint) for checkpoint in checkpoints]
# Make sure directory to save exists
os.makedirs(save_directory, exist_ok=True)

In [None]:
# Do some plotting
# Pair every loaded checkpoint with its display name once; each grouping below
# is a filtered selection of these (checkpoint, name) pairs.
_named_checkpoints = list(zip(checkpoints, model_names))


def _with_method(method):
    """Return the (checkpoint, name) pairs whose configured method is ``method``."""
    return [
        pair
        for pair in _named_checkpoints
        if pair[0]["configuration"]["method"] == method
    ]


# Checkpoints saved at epoch 50, across all methods
final_checkpoints = [pair for pair in _named_checkpoints if pair[0]["epoch"] == 50]
# Per-method groupings (all saved epochs)
approximation_checkpoints = _with_method("approximate")
reduction_checkpoints = _with_method("reduction")
unconstrained_checkpoints = _with_method("unconstrained")
constrained_checkpoints = _with_method("constrained")


def interpolation_task(checkpoints_list):
    """Keep the (checkpoint, name) pairs whose name contains ``"Int"``."""
    selected = []
    for entry, label in checkpoints_list:
        if "Int" in label:
            selected.append((entry, label))
    return selected


def extrapolation_task(checkpoints_list):
    """Keep the (checkpoint, name) pairs whose name contains ``"Ext"``."""
    selected = []
    for entry, label in checkpoints_list:
        if "Ext" in label:
            selected.append((entry, label))
    return selected


# Each checkpoint grouping yields two tasks: its interpolation runs followed by
# its extrapolation runs.  Every task is a (task name, [(checkpoint, name), ...])
# pair.
_task_groups = (
    ("Final Epoch", final_checkpoints),
    ("Approximation", approximation_checkpoints),
    ("Reduction", reduction_checkpoints),
    ("Unconstrained", unconstrained_checkpoints),
    ("Constrained", constrained_checkpoints),
)
tasks = [
    (f"{label} {kind}", select(group))
    for label, group in _task_groups
    for kind, select in (
        ("Interpolation", interpolation_task),
        ("Extrapolation", extrapolation_task),
    )
]

# Produce the figures for every task.  "Final ..." tasks compare the different
# models against each other; the per-method tasks show the latest checkpoint's
# curves and how the predictions evolve over the saved epochs.
for task_name, task in tasks:
    print(task_name)
    if len(task) == 0:
        print(f"Nothing for task {task_name}")
        continue
    task_checkpoints = [x[0] for x in task]
    task_model_names = [x[1] for x in task]
    # Per checkpoint: element 0 is used for timing, elements 1: for the
    # loss/constraint plots.
    task_monitors = [checkpoint["monitors"] for checkpoint in task_checkpoints]
    # Timing entries to show, taken from the first checkpoint's timing monitor
    time_keys = [
        key
        for key in task_monitors[0][0].time_keys
        if ("multipliers" in key or "total" in key or "compute" in key)
        and "error" not in key
    ]
    task_filename = convert_name_to_filename(task_name)

    if "final" in task_filename:
        fig = plot_time(
            [monitors[0] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_compute-time",
            time_keys=time_keys,
            log=True,
            directory=save_directory,
        )

        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_constrained-loss",
            title="Constrained Loss",
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            [monitors[1:] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_constraint",
            title="Constraint Magnitude",
            directory=save_directory,
        )
        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_loss",
            title="Loss",
            directory=save_directory,
        )

        data, predictions = get_model_predictions(task_checkpoints)
        fig = plot_model_predictions(
            data,
            predictions,
            task_model_names,
            f"{task_filename}_predictions",
            title="Model Predictions",
            directory=save_directory,
        )

    # Substring test: "constrained" is also contained in "unconstrained-...",
    # so all four per-method task families take this branch.
    if (
        "approximation" in task_filename
        or "reduction" in task_filename
        or "constrained" in task_filename
    ):

        # Order by epoch so that the last entry is the latest checkpoint.
        # Plain lists are indexed instead of converting to numpy arrays:
        # np.array() on lists of dicts / monitor lists is fragile (ragged
        # lists raise on recent numpy, sequence-like objects get unpacked).
        # The stable sort keeps the incoming order for equal epochs.
        idxs = np.argsort(
            [checkpoint["epoch"] for checkpoint in task_checkpoints], kind="stable"
        )
        task_checkpoints = [task_checkpoints[i] for i in idxs]
        task_model_names = [task_model_names[i] for i in idxs]
        task_monitors = [task_monitors[i] for i in idxs]

        # Curves of the latest checkpoint only
        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_constrained-loss",
            title=f"{task_name} Constrained Loss",
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_constraint",
            title=f"{task_name} Constraint Magnitude",
            directory=save_directory,
        )
        fig = plot_loss(
            [monitors[1:] for monitors in task_monitors[-1:]],
            [task_name],
            f"{task_filename}_loss",
            title=f"{task_name} Loss",
            directory=save_directory,
        )

        # Pick a geometrically spaced subset of checkpoints (indices
        # int(1.5 ** k)), always including the first and the last one, so the
        # prediction plot stays readable.
        extra_idxs = np.power(1.5, np.arange(len(task_checkpoints))).astype(int)
        extra_idxs = extra_idxs[extra_idxs < len(task_checkpoints)]
        limited_idxs = np.unique(np.r_[0, extra_idxs, len(task_checkpoints) - 1])
        limited_checkpoints = [task_checkpoints[i] for i in limited_idxs]
        data, predictions = get_model_predictions(limited_checkpoints)
        fig = plot_model_predictions(
            data,
            predictions,
            [f'Epoch {checkpoint["epoch"]}' for checkpoint in limited_checkpoints],
            f"{task_filename}_predictions",
            title=task_name,
            directory=save_directory,
        )

    # # close all those figures
    # plt.close("all")