In [23]:
import glob
import os

import torch
from backpack import extend, extensions
from cockpit import Cockpit, CockpitPlotter
from cockpit import quantities
from cockpit.utils import schedules
from cockpit.utils.configuration import configuration
from cockpit.utils.configuration import quantities_cls_for_configuration
from IPython.display import display, clear_output
from PIL import Image

from _utils_examples import fmnist_data, get_logpath

In [2]:
def locate_json_log(testproblem, optimizer_class, heredir=None):
    """Locate the single Cockpit JSON log for a test problem / optimizer pair.

    Searches ``<heredir>/results/<testproblem>/<optimizer_class.__name__>/*/``
    for a file whose name ends in ``__log.json``.

    Args:
        testproblem (str): Name of the test-problem sub-directory.
        optimizer_class (type): Optimizer class; its ``__name__`` is used as the
            optimizer sub-directory name.
        heredir (str, optional): Base directory that contains ``results``.
            Defaults to the current working directory, because no ``HEREDIR``
            (or ``__file__``) is defined inside this notebook.

    Returns:
        str: Path of the unique matching log file.

    Raises:
        AssertionError: If no file or more than one file matches.
    """
    if heredir is None:
        heredir = os.getcwd()
    run_dir = os.path.join(heredir, "results", testproblem, optimizer_class.__name__)
    run_pattern = os.path.join(run_dir, "*/*__log.json")
    # Sorted so the error message lists candidates in a deterministic order.
    run_match = sorted(glob.glob(run_pattern))
    # Explicit raise (instead of `assert`) so the check survives `python -O`.
    if len(run_match) != 1:
        raise AssertionError(f"Found no or multiple files: {run_match}")
    return run_match[0]


In [3]:
# Experiment identifiers: a label for the test problem and the optimizer class.
# NOTE(review): the model below is trained on Fashion-MNIST, so this label may be
# intended as "fmnist_logreg" -- confirm before using it to locate result files.
testproblem, optimizer_class = "mnist_logreg", torch.optim.SGD

In [12]:
# Build Fashion-MNIST classifier (logistic regression on flattened images).
# Seed torch so the weight initialization is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

data = fmnist_data(batch_size=32)
# 784 = 28 * 28 input pixels, 10 output classes.
model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10)))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
# Unreduced per-sample losses; handed to Cockpit as "individual_losses".
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# Create the optimizer from the class chosen in the configuration cell, so the
# optimizer used for training and `optimizer_class` cannot drift apart.
opt = optimizer_class(model.parameters(), lr=1e-2)

# Create Cockpit and a plotter.
# Customize the tracked quantities and their tracking schedule.
# Named `tracked_quantities` (not `quantities`) so the imported
# `cockpit.quantities` module is not shadowed; shadowing it makes a second run
# of this cell fail with AttributeError on the list.
tracked_quantities = [
    quantities.GradNorm(schedules.linear(interval=1)),
    quantities.Distance(schedules.linear(interval=1)),
    quantities.UpdateSize(schedules.linear(interval=1)),
    quantities.HessMaxEV(schedules.linear(interval=3)),
    quantities.GradHist1d(schedules.linear(interval=10), bins=10),
]
cockpit = Cockpit(model.parameters(), quantities=tracked_quantities)
plotter = CockpitPlotter()

In [13]:
# The training loop unpacks `(inputs, labels)` batches from `data`; fail early
# with the actual type if the loader helper returns something unexpected.
assert isinstance(data, torch.utils.data.DataLoader), type(data)
type(data)

torch.utils.data.dataloader.DataLoader

In [14]:
# Main training loop: at most `max_steps` optimizer steps, tracked by Cockpit.
logpath = get_logpath()  # resolved once; reused for figures and the JSON log
max_steps, global_step = 50, 0
plot_every = 10  # save an intermediate Cockpit figure every `plot_every` steps

for inputs, labels in data:
    opt.zero_grad()

    # forward pass: mean loss for backprop, per-sample losses for Cockpit
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    losses = individual_loss_fn(outputs, labels)

    # backward pass; Cockpit is asked whether this step needs `create_graph`.
    # NOTE(review): PyTorch warns about reference cycles when backward() is
    # called with create_graph=True (see the warning in this cell's output).
    with cockpit(
        global_step,
        extensions.DiagHessian(),  # Other BackPACK quantities can be computed as well
        info={
            "batch_size": inputs.shape[0],
            "individual_losses": losses,
            "loss": loss,
            "optimizer": opt,
        },
    ):
        loss.backward(create_graph=cockpit.create_graph(global_step))

    # optimizer step
    opt.step()
    global_step += 1

    if global_step % plot_every == 0:
        plotter.plot(
            cockpit,
            savedir=logpath,
            show_plot=False,
            save_plot=True,
            savename_append=str(global_step),
        )

    if global_step >= max_steps:
        break

# Write Cockpit to json file.
cockpit.write(logpath)

# Plot results from file
plotter.plot(
    logpath,
    savedir=logpath,
    show_plot=False,
    save_plot=True,
    savename_append="_final",
)

  Variable._execution_engine.run_backward(


[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__10.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__20.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__30.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__40.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__50.png
[cockpit] writing output to /data/github/deep_learning/pytorch/logfiles/cockpit_output.json
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary___final.png


In [16]:
# Show the directory where Cockpit stores its log and figures.
print(f"{get_logpath()}")

/data/github/deep_learning/pytorch/logfiles/cockpit_output


In [17]:
logpath = get_logpath()
plotter = CockpitPlotter()

# Animation frames and the final GIF go into the parent directory of the log path.
savedir = os.path.dirname(logpath)

# Load the tracking results logged at `logpath` (via the plotter's private
# reader) and collect every tracked iteration.
plotter._read_tracking_results(logpath)
track_events = [*plotter.tracking_data["iteration"]]

# Paths of rendered animation frames, filled by the animation cell.
frame_paths = []

In [20]:
# Column names of the tracking table (one per tracked quantity, plus "iteration").
plotter.tracking_data.columns

Index(['iteration', 'Distance', 'GradHist1d', 'GradNorm', 'HessMaxEV',
       'UpdateSize'],
      dtype='object')

In [21]:
# Iterations at which at least one quantity was logged.
display(track_events)

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49]

In [24]:
# Render one Cockpit frame per tracked iteration, then stitch them into a GIF.
# Reset the frame list here so re-running this cell does not append duplicates
# of frames collected by an earlier run.
frame_paths = []
for idx, global_step in enumerate(track_events):
    print(f"Plotting {idx:05d}/{len(track_events):05d}")

    # NOTE(review): `discard=global_step` presumably restricts the plot to data
    # up to this iteration, giving a progressive animation -- confirm against
    # the CockpitPlotter.plot documentation.
    plotter.plot(
        logpath,
        show_plot=False,
        save_plot=False,
        block=False,
        show_log_iter=True,
        discard=global_step,
    )
    this_frame_path = os.path.join(savedir, f"animation_frame_{idx:05d}.png")
    plotter.fig.savefig(this_frame_path)
    frame_paths.append(this_frame_path)

if not frame_paths:
    raise ValueError("No animation frames were rendered: `track_events` is empty.")

animation_savepath = os.path.join(savedir, "showcase.gif")

# Collect images and create Animation.
# Image.open is lazy and keeps each file open, so release all frames afterwards.
images = [Image.open(f) for f in frame_paths]
try:
    frame, *frames = images
    frame.save(
        fp=animation_savepath,
        format="GIF",
        append_images=frames,
        save_all=True,
        duration=200,  # display time per frame, in milliseconds
        loop=0,  # 0 = loop forever
    )
finally:
    for image in images:
        image.close()

Plotting 00000/00050
Plotting 00001/00050
Plotting 00002/00050
Plotting 00003/00050
Plotting 00004/00050
Plotting 00005/00050
Plotting 00006/00050
Plotting 00007/00050
Plotting 00008/00050
Plotting 00009/00050
Plotting 00010/00050
Plotting 00011/00050
Plotting 00012/00050
Plotting 00013/00050
Plotting 00014/00050
Plotting 00015/00050
Plotting 00016/00050
Plotting 00017/00050
Plotting 00018/00050
Plotting 00019/00050
Plotting 00020/00050
Plotting 00021/00050
Plotting 00022/00050
Plotting 00023/00050
Plotting 00024/00050
Plotting 00025/00050
Plotting 00026/00050
Plotting 00027/00050
Plotting 00028/00050
Plotting 00029/00050
Plotting 00030/00050
Plotting 00031/00050
Plotting 00032/00050
Plotting 00033/00050
Plotting 00034/00050
Plotting 00035/00050
Plotting 00036/00050
Plotting 00037/00050
Plotting 00038/00050
Plotting 00039/00050
Plotting 00040/00050
Plotting 00041/00050
Plotting 00042/00050
Plotting 00043/00050
Plotting 00044/00050
Plotting 00045/00050
Plotting 00046/00050
Plotting 0004