In [None]:
import numpy as np
import torch
from tqdm.notebook import trange

from constants import hue_range, num_epochs, num_models, sidelength, size_range
from dataset import generate_dataset
from grid import make_standard_grid
from image import get_images
from model import VAE
from training import train
from util import compress_floats, expand_floats, get_device, onnx_export_to_files, stringify_coords
from vaewidgets import model_comparison

In [None]:
# Build the training and validation splits. The size/hue ranges come from
# `constants`; the validation split gets its own explicit size and hue ranges.
# `generate_dataset` returns four values: coordinates for each split followed
# by the image data for each split.
dataset_splits = generate_dataset(
    size_range=size_range,
    hue_range=hue_range,
    valset_size_range=(0.6, 0.9),
    valset_hue_range=(0.4, 0.7),
    num_samples=2000,
)
trainset_coords, valset_coords, trainset, valset = dataset_splits

In [None]:
# Sanity check: display the shapes of the training and validation image sets.
trainset.shape, valset.shape

In [None]:
# Persist both splits to disk: coordinates as small hand-assembled JSON objects
# of the form {"x": [...], "y": [...]}, images as raw byte dumps of the tensors.

# --- coordinates -----------------------------------------------------------
trainset_coords_x, trainset_coords_y = stringify_coords(trainset_coords)
trainset_coords_json = f'{{"x": [{trainset_coords_x}], "y": [{trainset_coords_y}]}}'
with open("trainset_coords.json", "w") as f:
    f.write(trainset_coords_json)

valset_coords_x, valset_coords_y = stringify_coords(valset_coords)
valset_coords_json = f'{{"x": [{valset_coords_x}], "y": [{valset_coords_y}]}}'
with open("valset_coords.json", "w") as f:
    f.write(valset_coords_json)

# --- images ----------------------------------------------------------------
# The .bin files carry no header; a reader has to know the shape and dtype.
trainset_images_bytes = trainset.numpy().tobytes()
with open("trainset_images.bin", "wb") as f:
    f.write(trainset_images_bytes)

valset_images_bytes = valset.numpy().tobytes()
with open("valset_images.bin", "wb") as f:
    f.write(valset_images_bytes)

In [None]:
# Device handle from the project helper; used below to place the evaluation
# grid tensor and passed to `train`.
device = get_device()

In [None]:
# Build the fixed evaluation grid over the size/hue ranges and render one image
# per grid point.
standard_grid = make_standard_grid(size_range, hue_range)

# Flatten the grid to a list of 2-tuples, one per grid point.
grid_points = list(map(tuple, standard_grid.reshape(-1, 2).tolist()))
imgs = get_images(sidelength, grid_points)

# Convert to a float tensor scaled by 1/255 and move it to the chosen device.
# NOTE(review): presumably `imgs` holds 8-bit pixel values — confirm in `image.get_images`.
grid_x = torch.from_numpy(imgs).float().div(255.0).to(device)

In [None]:
# Train `num_models` models with identical settings, one checkpoint filename
# per run (vae_0.pth, vae_1.pth, ...). For each run, keep the train/validation
# loss curves and the third value returned by `train` (collected as the
# per-run processed grids).
#
# No VAE is instantiated in this cell: `train` is given only the device, the
# data and a filename, so a notebook-level model object would go unused.
losses = []
processed_grids = []
batch_size_train = 256
batch_size_val = 64
for i in trange(num_models):
    filename = f"vae_{i}.pth"
    train_losses, val_losses, model_processed_grids = train(
        device,
        trainset,
        valset,
        filename,
        num_epochs,
        batch_size_train,
        batch_size_val,
        grid_x,
        True,  # NOTE(review): positional flag whose meaning is defined by `train` — confirm and name it
    )
    processed_grids.append(model_processed_grids)
    losses.append((train_losses, val_losses))

In [None]:
# Stack the per-model (train_losses, val_losses) pairs into one float32 array.
# Expected shape is (num_models, 2, num_epochs), matching the reshape used when
# the file is read back.
loss_data = np.array(losses, dtype=np.float32)

In [None]:
# Write the float32 loss array to disk as raw bytes. The file stores no shape
# or dtype, so readers must supply both (see the reshape in the check below).
loss_data.tofile("losses.bin")

In [None]:
# Round-trip check: reading losses.bin back as float32 and restoring the
# (num_models, 2, num_epochs) shape must reproduce the in-memory array exactly.
losses_from_disk = np.fromfile("losses.bin", dtype=np.float32)
losses_from_disk = losses_from_disk.reshape(num_models, 2, num_epochs)
assert np.array_equal(loss_data, losses_from_disk)

In [None]:
# Stack the grids collected during training into one array with layout
# (model, epoch, 10, 10, 2), then store them in compressed form.
grids_shape = (num_models, num_epochs, 10, 10, 2)
all_processed_grids = np.array(processed_grids).reshape(grids_shape)

grids_payload = compress_floats(all_processed_grids)
with open("grids.bin", "wb") as f:
    f.write(grids_payload)

In [None]:
# Read the compressed grids back and restore the (model, epoch, 10, 10, 2)
# layout so they can be compared with the in-memory array.
with open("grids.bin", "rb") as f:
    grids_raw = f.read()

grids_reconstructed = expand_floats(grids_raw).reshape(num_models, num_epochs, 10, 10, 2)

In [None]:
# Round-trip check with deliberately loose tolerances (rtol = atol = 0.1).
# NOTE(review): presumably compress_floats/expand_floats are lossy — confirm in `util`.
assert np.allclose(all_processed_grids, grids_reconstructed, rtol=1e-01, atol=1e-01)

In [None]:
# Display the stacked grid shape; from the reshape above this is
# (num_models, num_epochs, 10, 10, 2).
all_processed_grids.shape

In [None]:
# Load the two exported binary files exactly as written to disk and pass their
# raw bytes to the comparison widget.
with open("losses.bin", "rb") as losses_f:
    losses_bytes = losses_f.read()

with open("grids.bin", "rb") as grids_f:
    grids_bytes = grids_f.read()

model_comparison(losses_bytes, grids_bytes)

In [None]:
# Inspect element 1 of the training coordinates.
trainset_coords[1]

In [None]:
# Check the element dtype of the training images; the raw trainset_images.bin
# dump above is written in this dtype.
trainset.dtype

In [None]:
# Export every trained checkpoint to ONNX, writing the encoder and decoder of
# each model to separate files. The loop runs over `num_models` so it always
# matches the set of vae_{i}.pth files named in the training loop.
for i in range(num_models):
    vae = VAE(2)
    # map_location="cpu": the fresh model lives on the CPU, and this lets the
    # checkpoint load even when it was saved from a device not available here.
    vae.load_state_dict(torch.load(f"vae_{i}.pth", map_location="cpu"))
    vae.eval()  # switch to inference mode before export
    onnx_export_to_files(vae.encoder, vae.decoder, f"vae_{i}_encoder.onnx", f"vae_{i}_decoder.onnx")