In [None]:
import json
import os
import subprocess

import ipywidgets as widgets  # type: ignore
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort  # type: ignore
import torch

from constants import hue_range, num_epochs, num_models, sidelength, size_range, z_range
from grid import make_standard_grid
from image import get_images
from model import VAE
from util import expand_floats, get_device, read_as_base64
from vaewidgets import GridViewer, evolution, mapping, model_comparison

# Turn off matplotlib's interactive mode so figures are only displayed where
# plt.show() is called explicitly (show_grid does this inside an Output widget).
plt.ioff();

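This notebook checks that the Python and JavaScript (JS) code paths are functionally equivalent. It covers the grid-construction code in both languages, and the PyTorch VAE encoder versus its ONNX export, run from Python and from JS.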

Helper functions that draw a grid of 2-D points with matplotlib, joining neighbouring points along each row and each column:

In [None]:
# Inches per screen pixel (1 / figure DPI): multiplying a pixel count by px gives
# a length usable in matplotlib's figsize.
px = 1 / plt.rcParams["figure.dpi"]  # Pixel in inches


def connect_rows(grid: np.ndarray, ax: "plt.Axes | None" = None) -> None:
    """Join horizontally adjacent points of ``grid`` with thin black segments.

    Args:
        grid: Grid of 2-D points indexed as ``grid[row][col] -> (x, y)``. Any
            number of rows and columns is supported, including an empty grid.
        ax: Axes to draw on. When omitted, draws on the current pyplot axes.
    """
    plotter = plt if ax is None else ax
    for row in grid:
        # Pair every point with its right-hand neighbour in the same row.
        for start, end in zip(row[:-1], row[1:]):
            plotter.plot(
                [start[0], end[0]],
                [start[1], end[1]],
                "k-",
                linewidth=0.2,
            )


def connect_columns(grid: np.ndarray, ax: "plt.Axes | None" = None) -> None:
    """Join vertically adjacent points of ``grid`` with thin black segments.

    Args:
        grid: Grid of 2-D points indexed as ``grid[row][col] -> (x, y)``. Any
            number of rows and columns is supported, including an empty grid.
        ax: Axes to draw on. When omitted, draws on the current pyplot axes.
    """
    plotter = plt if ax is None else ax
    # zip(*grid) yields the columns left to right, each as a sequence of points.
    for column in zip(*grid):
        # Pair every point with the one directly below it in the same column.
        for start, end in zip(column[:-1], column[1:]):
            plotter.plot(
                [start[0], end[0]],
                [start[1], end[1]],
                "k-",
                linewidth=0.2,
            )


def show_grid(
    out: widgets.Output, xlim: tuple[float, float], ylim: tuple[float, float], grid: np.ndarray
) -> None:
    """Render ``grid`` as a wireframe inside the ``out`` output widget.

    Creates a 250x250-pixel figure, fixes the axis limits to ``xlim`` and
    ``ylim``, joins neighbouring points along rows and then along columns,
    and shows the figure.
    """
    edge = 250 * px  # figure edge length: 250 pixels expressed in inches
    with out:
        _, axes = plt.subplots(figsize=(edge, edge))
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        # Both helpers draw on the current axes, i.e. the ones just created.
        connect_rows(grid)
        connect_columns(grid)
        plt.show()

Let's create the standard grid in Python.

In [None]:
# Reference grid of (size, hue) parameter pairs from the Python implementation;
# it is compared against the JS implementation below.
standard_grid = make_standard_grid(size_range, hue_range)

And the same grid in JS, by running the compiled widget code with Node.js:

In [None]:
js_code = """
import { makeStandardGrid } from '../widgets/dist/grid.js';
import { sizeRange, hueRange } from '../widgets/dist/constants.js';
const grid = makeStandardGrid(sizeRange, hueRange);
console.log(JSON.stringify(grid));
"""

# The snippet uses ES-module `import` syntax. Passing --input-type=module tells Node
# to evaluate it as a module explicitly, instead of relying on automatic
# module-syntax detection, which older Node versions do not have.
try:
    result = subprocess.run(
        ["node", "--input-type=module", "-e", js_code],
        capture_output=True,
        text=True,
        env={
            **os.environ,
            # Disable coloured output so stdout is plain JSON.
            "FORCE_COLOR": "0",
            "NO_COLOR": "1",
        },
        check=True,
    )
except subprocess.CalledProcessError as err:
    # CalledProcessError's message omits stderr, which is where Node reports the
    # actual failure (missing build output, syntax error, ...).
    raise RuntimeError(f"node exited with status {err.returncode}:\n{err.stderr}") from err

# Rounded to 3 decimals before the exact comparison with standard_grid.
js_grid = np.array(json.loads(result.stdout)).round(3)

Check that they are equal:

In [None]:
# Exact element-wise comparison that also requires matching shapes and, on failure,
# reports which entries differ.
# NOTE(review): js_grid is rounded to 3 decimals; exact equality assumes
# make_standard_grid rounds the same way. Confirm, or switch to assert_allclose.
np.testing.assert_array_equal(standard_grid, js_grid)

Display the grid with matplotlib (left) and with our frontend `GridViewer` widget (right):

In [None]:
# Left: matplotlib rendering; right: the same grid drawn by the frontend GridViewer widget.
out = widgets.Output()
show_grid(out, size_range, hue_range, standard_grid)
widgets.HBox(children=[out, GridViewer(size_range, hue_range, standard_grid.tolist())])

Generate one image for each (size, hue) point of the grid:

In [None]:
# One image per (size, hue) pair, in row-major order over the grid.
imgs = get_images(sidelength, list(map(tuple, standard_grid.reshape(-1, 2).tolist())))
# float32 tensor scaled by 1/255 (presumably 8-bit pixel values -> [0, 1]).
x = torch.from_numpy(imgs).to(torch.float32).div(255.0)

Import the PyTorch model:

In [None]:
device = get_device()
vae = VAE(2).to(device)
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved on a GPU still loads on a CPU-only machine (and vice versa).
vae.load_state_dict(torch.load("vae_0.pth", map_location=device))
vae.eval();

Encode the images in PyTorch:

In [None]:
# Inference only: autograd bookkeeping is not needed.
with torch.no_grad():
    mu, logvar = vae.encoder(x.to(device))
# Keep the encoder means and lay them out on the 10x10 grid again, as NumPy on the CPU.
converted_grid = mu.view(10, 10, 2).to("cpu").numpy()

Display the encodings (the encoder means) of the images generated from the grid parameters:

In [None]:
# Left: matplotlib rendering of the encodings; right: the same grid in GridViewer (z space).
out = widgets.Output()
show_grid(out, z_range, z_range, converted_grid)
widgets.HBox(children=[out, GridViewer(z_range, z_range, converted_grid.tolist())])

Run the ONNX encoder in Python (the result should look the same as above):

In [None]:
# Pin the CPU provider explicitly: GPU builds of onnxruntime require a providers list.
ort_sess = ort.InferenceSession("vae_0_encoder.onnx", providers=["CPUExecutionProvider"])
# The encoder outputs the means and the log-variances; only the means are plotted.
mu_onnx, logvar_onnx = ort_sess.run(None, {"image": x.numpy()})
converted_grid_onnx = mu_onnx.reshape(10, 10, 2)
# Numerical check in addition to the visual one. The tolerances allow for float32
# round-off and backend differences (e.g. when PyTorch ran on a GPU).
np.testing.assert_allclose(converted_grid_onnx, converted_grid, rtol=1e-3, atol=1e-4)
out = widgets.Output()
show_grid(out, z_range, z_range, converted_grid_onnx)
out

Display the mapping widget, which runs the ONNX model in JS (should show the same shape of the grid in z space again):

In [None]:
# The encoder and decoder ONNX files are handed to the frontend widget via
# read_as_base64, and the widget runs them in JS.
mapping(read_as_base64("vae_0_encoder.onnx"), read_as_base64("vae_0_decoder.onnx"))

Load losses and grid evolution data from training:

In [None]:
# Raw training artefacts; the widgets below decode them.
with open("losses.bin", "rb") as losses_file:
    losses_bytes = losses_file.read()

with open("grids.bin", "rb") as grids_file:
    grids_bytes = grids_file.read()

The grid shape should appear again here:

In [None]:
# Comparison widget fed the raw loss and grid bytes loaded in the previous cell.
model_comparison(losses_bytes, grids_bytes)

And here, in the training-evolution view for the first model:

In [None]:
# One 10x10 grid of 2-D points per (model, epoch).
grids = expand_floats(grids_bytes).reshape(num_models, num_epochs, 10, 10, 2)
# Decode the bytes already loaded above instead of reading losses.bin from disk
# a second time. copy() keeps the array writable, like np.fromfile's result;
# np.frombuffer alone returns a read-only view of the bytes.
loss_data = (
    np.frombuffer(losses_bytes, dtype=np.float32).copy().reshape(num_models, 2, num_epochs)
)

# Losses of the first model: index 0 = training, index 1 = validation.
train_losses = loss_data[0, 0]
val_losses = loss_data[0, 1]
grid_data = grids[0]
# NOTE(review): grid_data (first model only) is unused, and the grids of ALL models
# are passed alongside the first model's losses. Confirm which evolution() expects.
evolution(train_losses, val_losses, grids)