In [None]:
# Change directory to the root so that relative path loads work correctly
import os

# Resolve the project root (the parent of the directory the kernel started in)
# only once. Re-running this cell then reuses the stored root instead of
# climbing one more directory on every execution.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())

try:
    os.chdir(PROJECT_ROOT)
    print(os.getcwd())
except OSError as err:
    # Best effort: keep working from the current directory, but report why the
    # change failed instead of silently swallowing every exception.
    print(f"Could not change directory to {PROJECT_ROOT}: {err}")

In [None]:
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

from experiments.C_nonlinear_projection.main import build_model_and_optimizer, get_data
from experiments.C_nonlinear_projection.visualize import (  # plot_model_predictions,
    plot_epoch_wise,
    plot_epoch_wise_distribution,
    plot_model_predictions,
    plot_time,
    retrieve_object,
)

In [None]:
def convert_name_to_filename(model_name):
    """Turn a human-readable model/task name into a filename-friendly slug.

    Spaces become hyphens, the text is lower-cased, and the characters
    ``(``, ``)``, ``:`` and ``,`` are stripped out.

    Args:
        model_name: display name such as ``"Soft Constrained (0.1) Epoch 5"``.

    Returns:
        The slug, e.g. ``"soft-constrained-0.1-epoch-5"``.
    """
    slug = model_name.replace(" ", "-").lower()
    for unwanted_char in "():,":
        slug = slug.replace(unwanted_char, "")
    return slug

In [None]:
def _regularization_label(checkpoint):
    """Describe a run's constraint setting from its regularization weight.

    Reads ``checkpoint["configuration"]["regularization_weight"]``. A weight of
    exactly 0.0 is labelled "Unconstrained"; any other weight is labelled
    "Soft Constrained (<weight>)". Both public naming helpers below share this
    helper so their labels cannot drift apart.
    """
    weight = checkpoint["configuration"]["regularization_weight"]
    # `:g` drops trailing zeros, e.g. 0.1 -> "0.1" rather than "0.100000".
    return "Unconstrained" if weight == 0.0 else f"Soft Constrained ({weight:g})"


def get_model_name(checkpoint):
    """Return the display name for one checkpoint: "<label> Epoch <epoch>".

    Args:
        checkpoint: loaded checkpoint dict with a ``"configuration"`` entry
            (containing ``"regularization_weight"``) and an ``"epoch"`` entry.
    """
    return f"{_regularization_label(checkpoint)} Epoch {checkpoint['epoch']}"


def get_group_name(checkpoint):
    """Return the display name for a group of checkpoints from one run.

    This is the same label as :func:`get_model_name`, without the epoch suffix.
    """
    return _regularization_label(checkpoint)

In [None]:
def get_special_model_name(checkpoint, filename):
    """Return the display name used for a loaded checkpoint.

    ``filename`` is part of the signature so naming could depend on the
    checkpoint's path, but it is currently ignored: the name comes entirely
    from ``get_model_name(checkpoint)``.
    """
    display_name = get_model_name(checkpoint)
    return display_name

In [None]:
# Files to load: the experiment whose checkpoints are read, and where plots go.
experiment_name = "C_nonlinear_projection"
# NOTE(review): absolute, user-specific output path; adjust for your machine.
# The trailing "" component keeps the trailing slash on the directory.
save_directory = os.path.join(
    "/global/u1/g/gelijerg/Projects/pyinsulate/results", experiment_name, ""
)
# $SCRATCH is expanded from the environment (left as-is if unset).
load_directory = os.path.expandvars(
    os.path.join("$SCRATCH/results/checkpoints", experiment_name)
)
# Glob patterns relative to load_directory; `?` matches exactly one character,
# presumably selecting zero-padded epoch numbers 00000-00090 of a single run.
checkpoint_patterns = ["nonlinear-projection_2019-08-16-12-06-02_000?0.pth"]

In [None]:
# Load files: every checkpoint matching any configured pattern, sorted by path.
files = sorted(
    path
    for pattern in checkpoint_patterns
    for path in glob.glob(f"{load_directory}/{pattern}")
)
print(files)
# Fail fast with a clear message; otherwise the plotting cell later dies with
# an opaque "max() arg is an empty sequence".
if not files:
    raise FileNotFoundError(
        f"No checkpoints matching {checkpoint_patterns} found in {load_directory}"
    )
# torch.load unpickles the file, so only load checkpoints you produced yourself.
# Everything is mapped to CPU so GPU-saved checkpoints load without a GPU.
checkpoints = [torch.load(f, map_location=torch.device("cpu")) for f in files]
model_names = [
    get_special_model_name(checkpoint, filename)
    for checkpoint, filename in zip(checkpoints, files)
]
# Make sure directory to save exists
os.makedirs(save_directory, exist_ok=True)

In [None]:
# Do some plotting
# First task: compare every run at the latest epoch among the loaded checkpoints.
max_epoch = max(checkpoint["epoch"] for checkpoint in checkpoints)
final_checkpoints = [
    pair for pair in zip(checkpoints, model_names) if pair[0]["epoch"] == max_epoch
]
tasks = [("Final Models", final_checkpoints)]

# Remaining tasks: one per run. A run is keyed by the checkpoint file's base
# name with everything from its last "_" onward removed.
checkpoint_groups = {}
for path, checkpoint, model_name in zip(files, checkpoints, model_names):
    run_key = os.path.basename(path[: path.rfind("_")])
    checkpoint_groups.setdefault(run_key, []).append((checkpoint, model_name))
tasks.extend(
    (get_group_name(pairs[0][0]), pairs) for pairs in checkpoint_groups.values()
)

def _projection_pairs(task_monitors, task_model_names, key):
    """Collect unprojected/projected inference curves of ``key`` for each model.

    For model ``i`` this appends two entries to every returned list: the
    ``original=True`` and ``final=True`` variants of ``key``, both read from the
    model's ``monitors[2]``. Both entries use color index ``i``, with line style
    ":" (unprojected) and "--" (projected).

    Returns:
        (monitors_to_plot, names_to_plot, data_to_plot, colors, line_styles),
        ready to pass to ``plot_epoch_wise``.
    """
    monitors_to_plot, names_to_plot, data_to_plot = [], [], []
    colors, line_styles = [], []
    for i, (monitors, name) in enumerate(zip(task_monitors, task_model_names)):
        monitors_to_plot.extend([monitors[2], monitors[2]])
        names_to_plot.extend([f"{name} (Unprojected)", f"{name} (Projected)"])
        data_to_plot.extend(
            [
                retrieve_object(monitors[2], key, original=True),
                retrieve_object(monitors[2], key, final=True),
            ]
        )
        colors.extend([i, i])
        line_styles.extend([":", "--"])
    return monitors_to_plot, names_to_plot, data_to_plot, colors, line_styles


def _log_spaced_indices(num_items, base=1.5):
    """Return the sorted unique indices ``0``, ``int(base**k)`` and ``num_items - 1``.

    Here ``k = 0, 1, ...``, and only values inside ``range(num_items)`` are kept.
    This picks a roughly logarithmically spaced subset of epochs to plot.
    Returns an empty integer array when ``num_items <= 0``.
    """
    if num_items <= 0:
        return np.array([], dtype=int)
    # Compare in floating point *before* casting to int. base**k exceeds the
    # int64 range (and eventually overflows to inf) for large k. Casting such
    # values yields garbage, typically hugely negative, indices that would slip
    # through an `< num_items` check applied after the cast.
    with np.errstate(over="ignore"):
        powers = np.power(float(base), np.arange(num_items))
    extra = powers[powers < num_items].astype(int)
    return np.unique(np.r_[0, extra, num_items - 1])


for task_name, task in tasks:
    print(task_name)
    if len(task) == 0:
        print(f"Nothing for task {task_name}")
        continue
    task_checkpoints = [checkpoint for checkpoint, _ in task]
    task_model_names = [model_name for _, model_name in task]
    # monitors[0] feeds the training plots below; monitors[2] feeds the inference ones.
    task_monitors = [checkpoint["monitors"] for checkpoint in task_checkpoints]
    # Union of timing keys across models. dict.fromkeys de-duplicates while
    # keeping first-seen order, so the key order is stable from run to run.
    # Iteration order of a set of strings is not stable, because of hash
    # randomization.
    time_keys = list(
        dict.fromkeys(
            key for monitors in task_monitors for key in monitors[0].timing_keys
        )
    )
    task_filename = convert_name_to_filename(task_name)

    if "final" in task_filename:
        training_monitors = [monitors[0] for monitors in task_monitors]

        # TRAINING
        fig = plot_time(
            training_monitors,
            task_model_names,
            f"{task_filename}_compute-time",
            time_keys=time_keys,
            log=True,
            directory=save_directory,
        )
        fig = plot_epoch_wise(
            training_monitors,
            task_model_names,
            [retrieve_object(monitors[0], "total_loss") for monitors in task_monitors],
            f"{task_filename}_training-total-loss",
            title="Training Total Loss",
            ylabel="Average loss",
            log=True,
            directory=save_directory,
        )
        fig = plot_epoch_wise(
            training_monitors,
            task_model_names,
            [retrieve_object(monitors[0], "mean_loss") for monitors in task_monitors],
            f"{task_filename}_training-loss",
            title="Training Data Loss",
            ylabel="Average loss",
            log=True,
            directory=save_directory,
        )
        fig = plot_epoch_wise(
            training_monitors,
            task_model_names,
            [
                retrieve_object(monitors[0], "constraints_error")
                for monitors in task_monitors
            ],
            f"{task_filename}_training-constraints-error",
            title="Training Constraint Error",
            ylabel="Average constraint error",
            log=True,
            directory=save_directory,
        )
        fig = plot_epoch_wise_distribution(
            training_monitors,
            task_model_names,
            [
                retrieve_object(monitors[0], "constraints", absolute_value=False)
                for monitors in task_monitors
            ],
            f"{task_filename}_training-constraint-distribution",
            title="Training Distribution of Constraint Residual",
            ylabel="Constraint value",
            log=False,
            directory=save_directory,
        )
        fig = plot_epoch_wise_distribution(
            training_monitors,
            task_model_names,
            [
                retrieve_object(monitors[0], "constraints", absolute_value=True)
                for monitors in task_monitors
            ],
            f"{task_filename}_training-constraint-distribution-magnitude",
            title="Training Distribution of Magnitude of Constraint Residual",
            ylabel="Magnitude of constraint value",
            log=True,
            directory=save_directory,
        )

        # Inference: each model's curve before ("Unprojected") and after
        # ("Projected") projection, sharing one color per model.
        monitors_to_plot, names_to_plot, data_to_plot, colors, line_styles = (
            _projection_pairs(task_monitors, task_model_names, "mean_loss")
        )
        fig = plot_epoch_wise(
            monitors_to_plot,
            names_to_plot,
            data_to_plot,
            f"{task_filename}_testing-loss",
            colors=colors,
            line_styles=line_styles,
            title="Inference Data Loss",
            ylabel="Average loss",
            log=True,
            directory=save_directory,
        )

        monitors_to_plot, names_to_plot, data_to_plot, colors, line_styles = (
            _projection_pairs(task_monitors, task_model_names, "constraints_error")
        )
        fig = plot_epoch_wise(
            monitors_to_plot,
            names_to_plot,
            data_to_plot,
            f"{task_filename}_testing-constraints-error",
            colors=colors,
            line_styles=line_styles,
            title="Inference Constraint Error",
            ylabel="Average constraint error",
            log=True,
            directory=save_directory,
        )
        # Maybe put the constraint distributions here, but I think it might be too much

        # Last recorded predictions of every model, before and after projection.
        inputs = None
        outputs = None
        prediction_sets = list()
        # Must be a fresh list. Reusing the labels from the plot above would
        # leave twice as many labels as prediction sets.
        names_to_plot = list()
        colors = list()
        line_styles = list()
        for i, (monitors, name) in enumerate(zip(task_monitors, task_model_names)):
            # Assumption: all models have the same test set
            if inputs is None:
                inputs = monitors[2].inputs
            if outputs is None:
                outputs = monitors[2].outputs
            prediction_sets.extend(
                [monitors[2].original_out[-1], monitors[2].final_out[-1]]
            )
            names_to_plot.extend([f"{name} (Unprojected)", f"{name} (Projected)"])
            colors.extend([i, i])
            line_styles.extend([":", "--"])
        fig = plot_model_predictions(
            inputs,
            outputs,
            prediction_sets,
            names_to_plot,
            f"{task_filename}_predictions",
            colors=colors,
            line_styles=line_styles,
            title="Model Predictions",
            directory=save_directory,
        )

        # Model-wise
        for monitors, model_name in zip(task_monitors, task_model_names):
            model_filename = convert_name_to_filename(model_name)

            fig = plot_epoch_wise_distribution(
                [monitors[0]],
                [model_name],
                [retrieve_object(monitors[0], "model_parameters", gradients=False)],
                f"{task_filename}_{model_filename}_parameter-distribution",
                title="Distribution of Parameter Values",
                ylabel="Parameter values",
                log=False,
                directory=save_directory,
            )
            fig = plot_epoch_wise_distribution(
                [monitors[0]],
                [model_name],
                [retrieve_object(monitors[0], "model_parameters", gradients=True)],
                f"{task_filename}_{model_filename}_gradient-distribution",
                title="Distribution of Parameter Gradients",
                ylabel="Parameter gradients",
                log=False,
                directory=save_directory,
            )
            fig = plot_epoch_wise_distribution(
                [monitors[2]],
                [model_name],
                [retrieve_object(monitors[2], "model_parameters", differences=True)],
                f"{task_filename}_{model_filename}_parameter-differences",
                title="Distribution of Parameter Differences During Inference",
                ylabel="Parameter differences",
                log=False,
                directory=save_directory,
            )

            # Plot process of projection: the unprojected output, every
            # intermediate iterate in all_out[-1], then the projected output.
            inputs = monitors[2].inputs
            outputs = monitors[2].outputs
            iterates = monitors[2].all_out[-1]
            prediction_sets = [monitors[2].original_out[-1]]
            names_to_plot = ["Unprojected"]
            line_styles = [":"]
            for i in range(len(iterates)):
                prediction_sets.append(iterates[i])
                names_to_plot.append(f"Iteration {i}")
                line_styles.append("--")
            prediction_sets.append(monitors[2].final_out[-1])
            names_to_plot.append("Projected")
            line_styles.append("-")
            fig = plot_model_predictions(
                inputs,
                outputs,
                prediction_sets,
                names_to_plot,
                f"{task_filename}_{model_filename}_projections",
                colors=None,
                line_styles=line_styles,
                title=f"{model_name} Model Predictions",
                directory=save_directory,
            )

            # Plot predictions over training at roughly log-spaced epochs.
            prediction_sets = list()
            names_to_plot = list()
            colors = list()
            line_styles = list()
            for i, idx in enumerate(_log_spaced_indices(len(monitors[2].epoch))):
                prediction_sets.extend(
                    [monitors[2].original_out[idx], monitors[2].final_out[idx]]
                )
                names_to_plot.extend(
                    [f"Epoch {idx+1} (Unprojected)", f"Epoch {idx+1} (Projected)"]
                )
                colors.extend([i, i])
                line_styles.extend([":", "--"])
            fig = plot_model_predictions(
                inputs,
                outputs,
                prediction_sets,
                names_to_plot,
                f"{task_filename}_{model_filename}_predictions",
                colors=colors,
                line_styles=line_styles,
                title=f"{model_name} Model Predictions",
                directory=save_directory,
            )

    else:  # this is probably a checkpoint group for a single run

        print(f"Nothing implemented for {task_name}")

#         fig = plot_constraints_distribution(
#             [[monitors[1]] for monitors in task_monitors],
#             task_model_names,
#             f"{task_filename}_constraint-distribution",
#             title="Distribution of Constraint Residual",
#             ylabel="Constraint value",
#             log=False,
#             directory=save_directory,
#             absolute_value=False,
#         )
#         fig = plot_constraints_distribution(
#             [[monitors[1]] for monitors in task_monitors],
#             task_model_names,
#             f"{task_filename}_constraint-distribution-magnitude",
#             title="Distribution of Magnitude of Constraint Residual",
#             ylabel="Magnitude of constraint value",
#             log=True,
#             directory=save_directory,
#             absolute_value=True,
#         )

#         inputs, outputs, predictions = get_model_predictions(task_checkpoints)
#         fig = plot_model_predictions(
#             inputs,
#             outputs,
#             predictions,
#             task_model_names,
#             f"{task_filename}_predictions",
#             title="Model Predictions",
#             directory=save_directory,
#         )

#     if (
#         "approximation" in task_filename
#         or "reduction" in task_filename
#         or "constrained" in task_filename
#     ):

#         idxs = np.argsort([checkpoint["epoch"] for checkpoint in task_checkpoints])
#         task_checkpoints = np.array(task_checkpoints)[idxs]
#         task_model_names = np.array(task_model_names)[idxs]
#         task_monitors = np.array(task_monitors)[idxs]

#         fig = plot_loss(
#             [monitors[1:] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_constrained-loss",
#             title=f"Total Loss for {task_name}",
#             directory=save_directory,
#             constrained=True,
#         )
#         fig = plot_loss(
#             [monitors[1:] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_loss",
#             title=f"Data Loss for {task_name}",
#             directory=save_directory,
#         )
#         fig = plot_constraints_distribution(
#             [monitors[1:] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_constraint-distribution",
#             title=f"Distribution of Constraint Residual for {task_name}",
#             ylabel="Constraint value",
#             log=False,
#             directory=save_directory,
#             absolute_value=False,
#         )
#         fig = plot_constraints_distribution(
#             [monitors[1:] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_constraint-distribution-magnitude",
#             title=f"Distribution of Magnitude of Constraint Residual for {task_name}",
#             ylabel="Magnitude of constraint value",
#             log=True,
#             directory=save_directory,
#             absolute_value=True,
#         )
#         fig = plot_parameters_distribution(
#             [monitors[:1] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_parameter-distribution",
#             title=f"Distribution of Parameter Values for {task_name}",
#             ylabel="Parameter values",
#             log=False,
#             directory=save_directory,
#             gradients=False,
#         )
#         fig = plot_parameters_distribution(
#             [monitors[:1] for monitors in task_monitors[-1:]],
#             [task_name],
#             f"{task_filename}_gradient-distribution",
#             title=f"Distribution of Parameter Gradients for {task_name}",
#             ylabel="Parameter gradients",
#             log=False,
#             directory=save_directory,
#             gradients=True,
#         )

#         #         for idx, diagnostic_name in enumerate(
#         #             ["LHS", "RHS", r"$\nabla_{\mathrm{inputs}}(\mathrm{outputs})$"]
#         #         ):
#         #             fig = plot_constraints_diagnostics(
#         #                 [monitors[1:] for monitors in task_monitors[-1:]],
#         #                 [task_name],
#         #                 f"{task_filename}_constraint-diagnostics{idx}",
#         #                 diagnostics_index=idx,
#         #                 title=f"{task_name} {diagnostic_name}",
#         #                 ylabel=f"Average {diagnostic_name}",
#         #                 directory=save_directory,
#         #             )

#         extra_idxs = np.power(1.5, np.arange(len(task_checkpoints))).astype(int)
#         extra_idxs = extra_idxs[extra_idxs < len(task_checkpoints)]
#         limited_idxs = np.unique(np.r_[0, extra_idxs, len(task_checkpoints) - 1])
#         limited_checkpoints = task_checkpoints[limited_idxs]

#         inputs, outputs, predictions = get_model_predictions(limited_checkpoints)
#         fig = plot_model_predictions(
#             inputs,
#             outputs,
#             predictions,
#             [f'Epoch {checkpoint["epoch"]}' for checkpoint in limited_checkpoints],
#             f"{task_filename}_predictions",
#             title=task_name,
#             directory=save_directory,
#         )

# # close all those figures
# plt.close("all")