In [None]:
# Change directory to the repository root so that relative path loads work correctly.
# Assumes the notebook lives one level below the root, and that the root is the
# directory containing the `experiments` package imported in the next cell.
# The guard makes the cell idempotent: re-running it no longer climbs one more
# directory each time.
import os

if not os.path.isdir("experiments"):
    try:
        os.chdir(os.pardir)
    except OSError as err:
        # Best effort: continue from the current directory, but say why.
        print(f"Could not change to the parent directory: {err}")
print(os.getcwd())

In [None]:
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

from experiments.A_proof_of_constraint.main import build_model_and_optimizer, get_data
from experiments.A_proof_of_constraint.visualize import (
    plot_constraints,
    plot_loss,
    plot_model_predictions,
    plot_time,
)

In [None]:
def collect_predictions(models, loaders):
    """Run every model over every batch of the given loaders, sorted by input value.

    Parameters
    ----------
    models : sequence of torch.nn.Module
        Models to evaluate. Each is switched to eval mode once, and all
        forward passes run under ``torch.no_grad()``.
    loaders : iterable of (iterable of (xb, yb), bool)
        Pairs of a batch iterable (e.g. a DataLoader) and a flag recording
        whether that iterable's samples count as training samples.

    Returns
    -------
    (inputs, outputs, is_training), predictions
        ``inputs``/``outputs``/``is_training`` are arrays ordered by ascending
        input value. ``predictions`` has one row per model, with columns in the
        same order. This assumes each model yields one value per sample; TODO
        confirm for multi-output models.
    """

    def to_numpy(tensor):
        # squeeze() collapses a single-sample batch of shape (1, 1) to a 0-d
        # array, which cannot be iterated or measured with len();
        # atleast_1d keeps it a length-1 sequence.
        return np.atleast_1d(tensor.squeeze().detach().cpu().numpy())

    inputs, outputs, is_training = [], [], []
    predictions = [[] for _ in models]
    for model in models:
        model.eval()
    with torch.no_grad():
        for loader, training_flag in loaders:
            for xb, yb in loader:
                batch_inputs = to_numpy(xb)
                inputs.extend(batch_inputs)
                outputs.extend(to_numpy(yb))
                is_training.extend([training_flag] * len(batch_inputs))
                for model_predictions, model in zip(predictions, models):
                    model_predictions.extend(to_numpy(model(xb)))
    # Sort every column by input value so the curves plot left to right.
    order = np.argsort(inputs)
    return (
        (
            np.array(inputs)[order],
            np.array(outputs)[order],
            np.array(is_training)[order],
        ),
        np.array(predictions)[:, order],
    )


def get_model_predictions(checkpoints):
    """Rebuild the models stored in `checkpoints` and collect their predictions.

    The train/test data come from the configuration of the first checkpoint.
    Training batches are flagged ``True`` and test batches ``False``.
    See ``collect_predictions`` for the layout of the return value.
    """
    # get_data also returns the parameterization and equation; they are unused here.
    train_dl, test_dl, _parameterization, _equation = get_data(
        checkpoints[0]["configuration"], return_equation=True
    )
    models = []
    for checkpoint in checkpoints:
        model, _optimizer = build_model_and_optimizer(checkpoint["configuration"])
        model.load_state_dict(checkpoint["model_state_dict"])
        models.append(model)
    return collect_predictions(models, [(train_dl, True), (test_dl, False)])

In [None]:
def convert_name_to_filename(model_name):
    """Turn a human-readable name into a filename stem.

    Each space becomes a hyphen and the result is lower-cased,
    e.g. ``"Final Models"`` -> ``"final-models"``.
    """
    hyphenated = "-".join(model_name.split(" "))
    return hyphenated.lower()

In [None]:
def get_model_name(checkpoint):
    """Build a display label "<method> <activation> (<Ext|Int> <epoch>)".

    The task is "Ext" when the configuration's ``training_sampling`` is
    ``"start"`` and "Int" otherwise. The activation label is ``str(model_act)``
    with its last two characters dropped. Presumably this strips a trailing
    "()" from a module repr such as "ReLU()"; TODO confirm.
    """
    cfg = checkpoint["configuration"]
    sampling_mode = cfg["training_sampling"]
    method_name = cfg["method"]
    activation = cfg["model_act"]
    epoch_number = checkpoint["epoch"]
    if sampling_mode == "start":
        task_label = "Ext"
    else:
        task_label = "Int"
    activation_label = str(activation)[:-2]
    return "{} {} ({} {})".format(method_name, activation_label, task_label, epoch_number)

In [None]:
def get_special_model_name(checkpoint, filename):
    """Label the hand-picked two-stage runs; fall back to get_model_name otherwise.

    A run is recognised by the timestamp embedded in `filename`, and the table
    is checked in order. Epochs up to and including 200 are labelled "Before",
    later epochs "After".
    """
    epoch = checkpoint["epoch"]
    phase = "Before" if epoch <= 200 else "After"
    special_runs = (
        ("2019-08-05-15-07", "Unconstrained --> Constrained"),
        ("2019-08-05-15-05", "Unconstrained --> Reduction"),
        ("2019-08-05-14-38", "Unconstrained --> No-loss"),
    )
    for timestamp, label in special_runs:
        if timestamp in filename:
            return f"{label} ({phase})"
    return get_model_name(checkpoint)

In [None]:
# Files to load
experiment_name = "A_proof_of_concept"
# NOTE(review): the experiment package is `A_proof_of_constraint`; "concept" may be a
# typo, but it is kept because it decides where (existing) results are written.
# Results root defaults to the original cluster path and can be overridden through the
# PYINSULATE_RESULTS_DIR environment variable instead of editing the notebook.
results_root = os.environ.get(
    "PYINSULATE_RESULTS_DIR", "/global/u1/g/gelijerg/Projects/pyinsulate/results"
)
save_directory = os.path.join(results_root, experiment_name, "")
# load_directory = os.path.expandvars(
#     "$SCRATCH/clones/20190731-160459/pyinsulate/results/checkpoints/"
# )
# checkpoint_pattern = "*.pth"
load_directory = os.path.expandvars("results/checkpoints/")
# History of checkpoint sets examined (kept for reference; only the last list is used):
# checkpoint_patterns = ["proof-of-constraint_2019-08-02-15-34-*_0???0.pth"] # 50 epoch test of frequency = 1
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-08-5[3|4]-*_0???0.pth"] # 100 epoch test of frequency = 5, amplitude = 0.04
# checkpoint_patterns = [ "proof-of-constraint_2019-08-05-08-59-*_0???0.pth", "proof-of-constraint_2019-08-05-09-00-*_0???0.pth"] # 100 epoch test of frequency = 0.2, amplitude = 0.04
# checkpoint_patterns = [ "proof-of-constraint_2019-08-05-09-0[4|5|7]-*_0???0.pth"] # 300 epoch test of frequency = 5
# checkpoint_patterns = [ "proof-of-constraint_2019-08-05-09-1[4|5]-*_0???0.pth"] # 100 epoch test of frequency = 5 after PDE modification
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-09-59-*_0??[0|5]0.pth", "proof-of-constraint_2019-08-05-10-0?-*_0??[0|5]0.pth"]
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-29-*_0???0.pth"] # clamp = 5
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-32-*_0???0.pth"]  # clamp = 1
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-3[8|9]-*_0???0.pth"]  # revised data, clamp = 1
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-42-*_0???0.pth"]  # revised data, clamp = None
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-4[8|9]-*_0???0.pth"]  # revised data, clamp = "log"
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-53-*_0???0.pth"]  # revised data, clamp = x : 1+x
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-11-5[8|9]-*_0???0.pth"]  # revised data, Huber(6)
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-12-0[1|2|6]-*_0???0.pth"]  # revised data, Huber(6), 500 epochs
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-12-1[5|6]-*_0???0.pth", "proof-of-constraint_2019-08-05-12-20-*_0???0.pth"]  # revised data, Huber(6), tanh
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-12-3[1|2]-*_0???0.pth"]  # revised data, Huber(6), 100 epochs
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-12-4[1|2]-*_0???0.pth"]
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-12-5[3|4]-*_0???0.pth"]
# checkpoint_patterns = ["proof-of-constraint_2019-08-05-14-23-28_0???0.pth"]
checkpoint_patterns = [
    # "proof-of-constraint_2019-08-05-14-38-*_0???0.pth",  # 200 epochs unconstrained followed by 200 epochs no-loss
    # "proof-of-constraint_2019-08-05-15-05-*_0???0.pth",  # 200 epochs unconstrained followed by 200 epochs reduction
    # "proof-of-constraint_2019-08-05-15-07-*_0???0.pth",  # 200 epochs unconstrained followed by 200 epochs constrained
    "proof-of-constraint_2019-08-05-15-38-*_0???0.pth",  # 400 epochs unconstrained
    "proof-of-constraint_2019-08-05-15-39-*_0???0.pth",  # 400 epochs reduction
]

In [None]:
# Load files: every checkpoint matching any configured pattern, deduplicated (patterns
# may overlap) and in filename order.
files = sorted(
    {
        path
        for pattern in checkpoint_patterns
        for path in glob.glob(os.path.join(load_directory, pattern))
    }
)
if not files:
    # Fail here with a clear message rather than with an obscure max() error later.
    raise FileNotFoundError(
        f"No checkpoints in {load_directory!r} match any of {checkpoint_patterns}"
    )
# NOTE: torch.load unpickles arbitrary objects; only load checkpoints you trust.
checkpoints = [torch.load(f) for f in files]
# model_names = [get_model_name(checkpoint) for checkpoint in checkpoints]
model_names = [
    get_special_model_name(checkpoint, filename)
    for checkpoint, filename in zip(checkpoints, files)
]
# Make sure directory to save exists
os.makedirs(save_directory, exist_ok=True)

In [None]:
# Do some plotting
# Group the loaded (checkpoint, model_name) pairs for the plotting tasks below.
named_checkpoints = list(zip(checkpoints, model_names))
# default=None keeps this cell from raising when no checkpoints were loaded.
max_epoch = max((checkpoint["epoch"] for checkpoint in checkpoints), default=None)
# NOTE(review): 190 appears to be the last saved epoch before the epoch-200 switch of the
# two-stage runs (get_special_model_name labels epochs <= 200 "Before"); confirm.
extra_final_epoch = 190
final_checkpoints = [
    (checkpoint, model_name)
    for checkpoint, model_name in named_checkpoints
    if checkpoint["epoch"] == max_epoch or checkpoint["epoch"] == extra_final_epoch
]


def checkpoints_with_method(pairs, method):
    """Return the (checkpoint, model_name) pairs whose configuration uses `method`."""
    return [
        (checkpoint, model_name)
        for checkpoint, model_name in pairs
        if checkpoint["configuration"]["method"] == method
    ]


approximation_checkpoints = checkpoints_with_method(named_checkpoints, "approximate")
reduction_checkpoints = checkpoints_with_method(named_checkpoints, "reduction")
unconstrained_checkpoints = checkpoints_with_method(named_checkpoints, "unconstrained")
constrained_checkpoints = checkpoints_with_method(named_checkpoints, "constrained")
soft_constrained_checkpoints = checkpoints_with_method(
    named_checkpoints, "soft-constrained"
)


# Each task is a (title, [(checkpoint, model_name), ...]) pair consumed by the plotting loop.
# NOTE(review): the disabled entries call `interpolation_task` / `extrapolation_task`,
# which are not defined in this notebook as shown; define them before re-enabling.
tasks = []
tasks.append(("Final Models", final_checkpoints))
# tasks.append(("Approximation Interpolation", interpolation_task(approximation_checkpoints)))
# tasks.append(("Approximation Extrapolation", extrapolation_task(approximation_checkpoints)))
# tasks.append(("Reduction Interpolation", interpolation_task(reduction_checkpoints)))
# tasks.append(("Reduction Extrapolation", extrapolation_task(reduction_checkpoints)))
# tasks.append(("Unconstrained Interpolation", interpolation_task(unconstrained_checkpoints)))
# tasks.append(("Unconstrained Extrapolation", extrapolation_task(unconstrained_checkpoints)))
# tasks.append(("Soft-Constrained Interpolation", interpolation_task(soft_constrained_checkpoints)))
# tasks.append(("Soft-Constrained Extrapolation", extrapolation_task(soft_constrained_checkpoints)))
# tasks.append(("Constrained Interpolation", interpolation_task(constrained_checkpoints)))
# tasks.append(("Constrained Extrapolation", extrapolation_task(constrained_checkpoints)))

def log_spaced_indices(count, base=1.5):
    """Return sorted unique indices: 0, count - 1, and int(base**k) for every base**k < count.

    Picks a roughly logarithmically spaced subset of `count` items. The
    exponents stop as soon as base**k reaches `count`, so the float powers
    never grow large enough to overflow when cast to int. An overflow could
    otherwise produce negative, wrapped-around indices.
    """
    if base <= 1:
        raise ValueError(f"base must be > 1, got {base}")
    if count <= 0:
        return np.array([], dtype=int)
    indices = {0, count - 1}
    exponent = 0
    while base**exponent < count:
        indices.add(int(base**exponent))
        exponent += 1
    return np.array(sorted(indices))


for task_name, task in tasks:
    print(task_name)
    if len(task) == 0:
        print(f"Nothing for task {task_name}")
        continue
    task_checkpoints = [checkpoint for checkpoint, _ in task]
    task_model_names = [model_name for _, model_name in task]
    task_monitors = [checkpoint["monitors"] for checkpoint in task_checkpoints]
    # Union of the relevant timing keys over all models. Sorted because the iteration
    # order of a set of strings changes between kernel sessions (hash randomization),
    # which would otherwise reshuffle the timing plot from run to run.
    time_keys = sorted(
        {
            key
            for monitors in task_monitors
            for key in monitors[0].timing[0][0].keys()
            if ("multipliers" in key or "total" in key or "compute" in key)
            and "error" not in key
            and "recomputed" not in key
        }
    )
    task_filename = convert_name_to_filename(task_name)

    if "final" in task_filename:
        fig = plot_time(
            [monitors[0] for monitors in task_monitors],
            task_model_names,
            f"{task_filename}_compute-time",
            time_keys=time_keys,
            log=True,
            directory=save_directory,
        )

        # The same (monitors[1], None) pairs feed the loss and constraint plots.
        loss_monitors = [(monitors[1], None) for monitors in task_monitors]
        fig = plot_loss(
            loss_monitors,
            task_model_names,
            f"{task_filename}_constrained-loss",
            title="Constrained Loss",
            log=True,
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            loss_monitors,
            task_model_names,
            f"{task_filename}_constraint",
            title="Constraint Magnitude",
            log=True,
            directory=save_directory,
        )
        fig = plot_loss(
            loss_monitors,
            task_model_names,
            f"{task_filename}_loss",
            title="Loss",
            log=True,
            directory=save_directory,
        )

        data, predictions = get_model_predictions(task_checkpoints)
        fig = plot_model_predictions(
            data,
            predictions,
            task_model_names,
            f"{task_filename}_predictions",
            title="Model Predictions",
            directory=save_directory,
        )

    # NOTE(review): "constrained" also matches "unconstrained" and "soft-constrained"
    # task names, so every per-method task lands here; confirm that is intended.
    if (
        "approximation" in task_filename
        or "reduction" in task_filename
        or "constrained" in task_filename
    ):
        # Reorder by epoch with plain lists: np.array() over dicts or monitor sequences
        # can build a multi-dimensional (or, for ragged input, invalid) array instead
        # of a 1-D sequence of the original objects.
        order = np.argsort([checkpoint["epoch"] for checkpoint in task_checkpoints])
        task_checkpoints = [task_checkpoints[i] for i in order]
        task_model_names = [task_model_names[i] for i in order]
        task_monitors = [task_monitors[i] for i in order]
        # Only the latest checkpoint's monitors are plotted for the training curves.
        latest_monitors = [monitors[1:] for monitors in task_monitors[-1:]]

        fig = plot_loss(
            latest_monitors,
            [task_name],
            f"{task_filename}_constrained-loss",
            title=f"{task_name} Constrained Loss",
            directory=save_directory,
            constrained=True,
        )
        fig = plot_constraints(
            latest_monitors,
            [task_name],
            f"{task_filename}_constraint",
            title=f"{task_name} Constraint Magnitude",
            directory=save_directory,
        )
        fig = plot_loss(
            latest_monitors,
            [task_name],
            f"{task_filename}_loss",
            title=f"{task_name} Loss",
            directory=save_directory,
        )

        limited_checkpoints = [
            task_checkpoints[i] for i in log_spaced_indices(len(task_checkpoints))
        ]
        data, predictions = get_model_predictions(limited_checkpoints)
        fig = plot_model_predictions(
            data,
            predictions,
            [f'Epoch {checkpoint["epoch"]}' for checkpoint in limited_checkpoints],
            f"{task_filename}_predictions",
            title=task_name,
            directory=save_directory,
        )

    # NOTE(review): presumably left disabled so the figures still render inline; with
    # many tasks, closing them avoids accumulating open matplotlib figures.
    # # close all those figures
    # plt.close("all")