In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from ml_aos.dataloader import Donuts
from ml_aos.lightning import WaveNetSystem, DonutLoader
from ml_aos.utils import get_root, convert_zernikes


In [None]:
# Repository root; plot_model_predictions uses it to locate
# lightning_logs/version_<v>/checkpoints.
root = get_root()


In [None]:
# Grab the first (unshuffled) batch of each split. batch_size=20 matches the
# 5x4 panel grids below, which index donuts 0..19 of each batch.
def _first_batch(dataloader):
    """Return the first batch yielded by ``dataloader``.

    Raises RuntimeError if the loader is empty. A bare ``for ...: break``
    would instead leave the batch variable undefined, or stale from an
    earlier kernel run.
    """
    try:
        return next(iter(dataloader))
    except StopIteration:
        raise RuntimeError("dataloader yielded no batches") from None


train_donuts = _first_batch(DonutLoader(batch_size=20, shuffle=False).train_dataloader())
val_donuts = _first_batch(DonutLoader(batch_size=20, shuffle=False).val_dataloader())
test_donuts = _first_batch(DonutLoader(batch_size=20, shuffle=False).test_dataloader())


In [None]:
def plot_model_predictions(versions: list, donut_set: str, ymax=0.5):
    """Plot true vs. predicted Zernikes for the cached batch of one split.

    Draws a 5x4 grid of panels, one per donut (indices 0..19) of the batch
    selected by ``donut_set``. Each panel shows the true Zernikes as a black
    dashed line. For each requested model version, it also shows the
    Zernikes predicted by a checkpoint from
    ``<root>/lightning_logs/version_<v>/checkpoints``.

    Parameters
    ----------
    versions : int or list of int
        Lightning log version number(s) whose checkpoint is loaded. A scalar
        is accepted because the value goes through ``np.atleast_1d``.
    donut_set : str
        One of "train", "val", "test". Selects ``train_donuts``,
        ``val_donuts`` or ``test_donuts`` respectively.
    ymax : float, optional
        Symmetric y-axis limit applied to every panel (default 0.5).

    Raises
    ------
    KeyError
        If ``donut_set`` is not one of the names above.
    AssertionError
        If a version's checkpoint directory is missing or empty.
    """
    fig, axes = plt.subplots(5, 4, figsize=(10, 6), dpi=120, constrained_layout=True)

    donuts = {"train": train_donuts, "val": val_donuts, "test": test_donuts}[donut_set]

    # x positions: Noll indices 4..22 inclusive
    noll = np.arange(4, 23)

    for i, ax in enumerate(axes.flatten()):
        ax.set(
            xticks=np.arange(4, 23, 4),
            xlim=(3.5, 22.5),
            ylim=(-ymax, ymax),
        )
        ax.axhline(0, c="silver", alpha=0.5, lw=1)
        # true values, drawn on top (high zorder) of the predictions
        ax.plot(
            noll,
            convert_zernikes(donuts["zernikes"][i]),
            c="k",
            lw=1,
            ls="--",
            zorder=10,
            alpha=0.6,
        )
    # only the outer panels keep tick labels / axis labels
    for ax in axes[:-1].flatten():
        ax.set(xticklabels=[])
    for ax in axes[-1]:
        ax.set(xlabel="Noll index")
    for ax in axes[:, 1:].flatten():
        ax.set(yticklabels=[])
    for ax in axes[:, 0]:
        # raw string: "\D" is not a valid Python escape sequence
        ax.set(ylabel=r'$\Delta$PSF (")')

    fig.suptitle(donut_set, fontsize=18)

    for v in np.atleast_1d(versions):
        # load the model
        ckpt_dir = root / "lightning_logs" / f"version_{v}" / "checkpoints"
        assert ckpt_dir.exists(), f"directory {ckpt_dir} does not exist."
        # sort so the chosen checkpoint (lexicographically first) does not
        # depend on filesystem listing order
        ckpts = sorted(ckpt_dir.glob("*"))
        assert ckpts, f"directory {ckpt_dir} contains no checkpoints."
        model = WaveNetSystem.load_from_checkpoint(ckpts[0])

        # Calling predict_step directly (outside a Trainer) does not switch
        # the model to eval mode or disable autograd, so do both explicitly.
        model.eval()
        with torch.no_grad():
            zk_pred, _ = model.predict_step(donuts, None)

        # loop over axes and plot
        for zk, ax in zip(zk_pred, axes.flatten()):
            ax.plot(noll, convert_zernikes(zk.detach()))


In [None]:
# checkpoint from lightning_logs/version_0, evaluated on the training batch
plot_model_predictions(versions=[0], donut_set="train")


In [None]:
# checkpoint from lightning_logs/version_0, evaluated on the validation batch
plot_model_predictions(versions=[0], donut_set="val")


In [None]:
# checkpoint from lightning_logs/version_0, evaluated on the test batch
plot_model_predictions(versions=[0], donut_set="test")


Now let's test the exported (TorchScript) model by calling it directly through its `forward` interface:

In [None]:
# Load the test split with transform=False (no transformations applied).
# The exported model below is called on these samples directly.
# NOTE(review): presumably the exported model performs any preprocessing it
# needs internally — confirm against the export code.
test0 = Donuts("test", transform=False)


In [None]:
# Load the exported TorchScript model.
# The default is an absolute path on the original author's filesystem. Set
# the ML_AOS_EXPORTED_MODEL environment variable to point at another copy.
mlFile = os.environ.get(
    "ML_AOS_EXPORTED_MODEL",
    "/astro/store/epyc/users/jfc20/ml-aos/models/v0_2023-06-19_09:41:19.pt",
)
assert os.path.isfile(mlFile), f"exported model {mlFile} does not exist."
# Map to CPU so the model sits on the same device as the dataset samples fed
# to it below (assumed to be CPU tensors).
model = torch.jit.load(mlFile, map_location="cpu")
model.eval()


In [None]:
# Compare the exported model's predictions (colored) with the true Zernikes
# (black dashed) for the first 20 test donuts, one per panel.
fig, axes = plt.subplots(5, 4, figsize=(10, 6), dpi=120, constrained_layout=True)

# x positions: Noll indices 4..22 inclusive
noll = np.arange(4, 23)

for i, ax in enumerate(axes.flatten()):
    donut = test0[i]
    zk_true = donut["zernikes"]
    # [None, ...] prepends a length-1 leading axis (presumably the batch axis).
    # Field angles are scaled by 180/pi, i.e. radians -> degrees.
    # NOTE(review): assumes the dataset stores field angles in radians and the
    # exported model expects degrees — confirm against the export code.
    with torch.no_grad():
        zk_pred = model(
            donut["image"][None, ...],
            donut["field_x"][None, ...] * 180 / torch.pi,
            donut["field_y"][None, ...] * 180 / torch.pi,
            donut["intrafocal"][None, ...],
            donut["band"][None, ...],
        )
    # no .detach() needed: outputs computed under no_grad carry no graph
    zk_pred = zk_pred.cpu().squeeze()

    ax.plot(noll, convert_zernikes(zk_true), c="k", ls="--")
    # NOTE(review): the /1000 presumably converts the exported model's output
    # units to the dataset's Zernike units (e.g. nm -> microns) — confirm.
    ax.plot(noll, convert_zernikes(zk_pred / 1000))
    ax.axhline(0, c="silver", alpha=0.5, lw=1)
    # same axis limits as plot_model_predictions (with ymax=0.5)
    ax.set(xticks=np.arange(4, 23, 4), xlim=(3.5, 22.5), ylim=(-0.5, 0.5))

# only the outer panels keep tick labels / axis labels
for ax in axes[:-1].flatten():
    ax.set(xticklabels=[])
for ax in axes[-1]:
    ax.set(xlabel="Noll index")
for ax in axes[:, 1:].flatten():
    ax.set(yticklabels=[])
for ax in axes[:, 0]:
    # raw string: "\D" is not a valid Python escape sequence
    ax.set(ylabel=r'$\Delta$PSF (")')

fig.suptitle("test (exported model)", fontsize=18)
