In [18]:
import torch
from _utils_examples import fmnist_data, get_logpath
from backpack import extend, extensions

from cockpit import Cockpit, CockpitPlotter
from cockpit.utils.configuration import configuration
from cockpit.utils import schedules
from PIL import Image
from IPython.display import display, clear_output

from cockpit import quantities
from cockpit.utils.configuration import quantities_cls_for_configuration

In [19]:
# Build Fashion-MNIST classifier
data = fmnist_data(batch_size=32)
model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10)))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# Create SGD Optimizer
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

# Create Cockpit and a plotter
cockpit = Cockpit(model.parameters(), quantities=configuration("full"))
plotter = CockpitPlotter()

In [22]:
# Main training loop
max_steps, global_step = 50, 0
for inputs, labels in iter(data):
    opt.zero_grad()

    # forward pass
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    losses = individual_loss_fn(outputs, labels)

    # backward pass
    with cockpit(
        global_step,
        extensions.DiagHessian(),  # Other BackPACK quantities can be computed as well
        info={
            "batch_size": inputs.shape[0],
            "individual_losses": losses,
            "loss": loss,
            "optimizer": opt,
        },
    ):
        loss.backward(create_graph=cockpit.create_graph(global_step))

    # optimizer step
    opt.step()
    global_step += 1

    if global_step % 10 == 0:
        plotter.plot(
            cockpit,
            savedir=get_logpath(),
            show_plot=False,
            save_plot=True,
            savename_append="epoch__"+str(global_step),
        )

    if global_step >= max_steps:
        break

# Write Cockpit to json file.
cockpit.write(get_logpath())

# Plot results from file
plotter.plot(
    get_logpath(),
    savedir=get_logpath(),
    show_plot=False,
    save_plot=True,
    savename_append="_final",
)

[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__epoch__10.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__epoch__20.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__epoch__30.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__epoch__40.png
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary__epoch__50.png
[cockpit] writing output to /data/github/deep_learning/pytorch/logfiles/cockpit_output.json
[cockpit|plot] Saving figure in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary___final.png


In [36]:
import glob
logpath = get_logpath() + '/cockpit'
screen = 'primary'
pattern = os.path.splitext(logpath)[0] + f"__{screen}__epoch__*.png"

frame_paths = sorted(glob.glob(pattern))
frame, *frames = [Image.open(f) for f in frame_paths]

In [38]:
frames

[<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=2160x1080 at 0x7FB7868E5A60>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=2160x1080 at 0x7FB786D85BB0>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=2160x1080 at 0x7FB7866ADCA0>,
 <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=2160x1080 at 0x7FB7866ADEE0>]

In [33]:
logpath

'/data/github/deep_learning/pytorch/logfiles/cockpit_output'

In [35]:
os.path.splitext(logpath)[0] + f"__{screen}__epoch__*.png"

'/data/github/deep_learning/pytorch/logfiles/cockpit_output__primary__epoch__*.png'

In [39]:
plotter.build_animation(get_logpath() + '/cockpit')

[cockpit|animate] Saving GIF in /data/github/deep_learning/pytorch/logfiles/cockpit_output/cockpit__primary.gif
