In [None]:
import numpy as np
import torch
from tqdm.notebook import trange

from constants import hue_range, sidelength, size_range
from dataset import generate_dataset
from grid import make_standard_grid
from image import get_images
from model import VAE
from training import train
from util import compress_floats, get_device
from vaewidgets import model_comparison

In [None]:
# Build the training and validation sets. The valset_* ranges pick the
# (size, hue) region used for validation — presumably a sub-region of the
# full size_range / hue_range; confirm in dataset.generate_dataset.
trainset_coords, valset_coords, trainset, valset = generate_dataset(
    num_samples=2000,
    size_range=size_range,
    hue_range=hue_range,
    valset_size_range=(0.6, 0.9),
    valset_hue_range=(0.4, 0.7),
)

In [None]:
# Quick sanity check: display both dataset shapes side by side.
(trainset.shape, valset.shape)

In [None]:
# Select the torch device once (via util.get_device); it is used to place the
# grid tensor and is passed to every train() call below.
device = get_device()

In [None]:
# Render one image per point of the standard (size, hue) grid, flattening the
# grid to a list of 2-tuples first. Pixel values are then scaled by 1/255
# (presumably 8-bit images -> [0, 1]) and moved onto `device`.
standard_grid = make_standard_grid(size_range, hue_range)
imgs = get_images(
    sidelength,
    list(map(tuple, standard_grid.reshape(-1, 2).tolist())),
)
grid_x = torch.from_numpy(imgs).float().div(255.0).to(device)

In [None]:
# Train `num_runs` models with identical settings so their loss curves and
# processed grids can be compared afterwards. train() is handed only a
# checkpoint filename, never a model object — presumably it builds and saves
# each model itself (confirm in training.train). A `VAE(2)` that used to be
# allocated here was never passed to train(), so it has been dropped.
num_runs = 9
num_epochs = 100
batch_size_train = 256
batch_size_val = 64

losses = []           # one (train_losses, val_losses) pair per run
processed_grids = []  # one model_processed_grids result per run
for run in trange(num_runs):
    train_losses, val_losses, model_processed_grids = train(
        device,
        trainset,
        valset,
        f"vae_{run}.pth",
        num_epochs,
        batch_size_train,
        batch_size_val,
        grid_x,
        True,  # NOTE(review): positional flag whose meaning lives in training.train — confirm
    )
    processed_grids.append(model_processed_grids)
    losses.append((train_losses, val_losses))

In [None]:
# Stack the per-run loss pairs into a single float32 array (built directly in
# float32, with no intermediate float64 copy). Each entry of `losses` is
# (train_losses, val_losses), so the layout is (runs, 2, epochs) — assuming
# train() returns one loss value per epoch (TODO confirm).
data = np.array(losses, dtype=np.float32)

In [None]:
# Persist the loss array as raw float32 bytes (C order, no header).
with open("losses.bin", "wb") as losses_file:
    losses_file.write(data.tobytes())

In [None]:
# Reload the raw float32 loss dump. It was written from np.array(losses), whose
# C-order layout is (runs, 2 [train, val], epochs). Reshaping the flat buffer
# straight to (runs, epochs, 2) would silently mis-pair values (the element
# count matches either way). So reshape to the written layout, then swap the
# last two axes to get per-epoch (train, val) pairs. The run count is inferred
# from the file size.
data = (
    np.fromfile("losses.bin", dtype=np.float32)
    .reshape(-1, 2, num_epochs)
    .transpose(0, 2, 1)
)

In [None]:
# Inspect what was read back: array shape and element dtype.
(data.shape, data.dtype)

In [None]:
# Stack every run's per-epoch processed grids (10 x 10 x 2 each, per the
# reshape) and write them in compressed form. The run axis is taken from
# len(processed_grids), so it follows however many runs were trained. The
# reshape still raises if the total element count does not match.
all_processed_grids = np.array(processed_grids).reshape(
    len(processed_grids), num_epochs, 10, 10, 2
)
with open("grids.bin", "wb") as f:
    f.write(compress_floats(all_processed_grids))

In [None]:
# Read the compressed grids back. The file is read as a little-endian float32
# (min, max) header followed by one uint8 code per value; this is assumed to
# match what util.compress_floats writes (TODO confirm). The run axis is
# inferred (-1) so the loader does not hard-code the number of runs.
with open("grids.bin", "rb") as f:
    minval, maxval = np.fromfile(f, dtype="<f4", count=2)
    all_processed_grids_loaded = np.fromfile(f, dtype=np.uint8).reshape(
        -1, num_epochs, 10, 10, 2
    )

In [None]:
# Undo the 8-bit quantization: map each code q in [0, 255] linearly back onto
# [minval, maxval].
extent = maxval - minval
normalized = all_processed_grids_loaded.astype(np.float32) / 255.0
all_processed_grids_reconstructed = minval + normalized * extent

In [None]:
# Round-trip check of the compression. One uint8 code spans extent / 255 of the
# value range. If compress_floats quantizes as (x - minval) / extent * 255 with
# rounding or truncation (assumed — confirm), no value should be off by more
# than one step. 1% slack absorbs float32 arithmetic error, and rtol covers
# float32 rounding of minval/maxval (including the extent == 0 case). A fixed
# atol of 0.1 is unrelated to the data range: too loose for narrow ranges and
# too strict for wide ones.
quantization_step = extent / 255.0
np.allclose(
    all_processed_grids,
    all_processed_grids_reconstructed,
    rtol=1e-5,
    atol=1.01 * quantization_step,
)

In [None]:
# Display the dimensions of the stacked grid array.
np.shape(all_processed_grids)

In [None]:
# Load both binary dumps as raw bytes and hand them to the comparison widget.
# The call is the cell's final expression, so Jupyter displays its result.
with open("losses.bin", "rb") as losses_f, open("grids.bin", "rb") as grids_f:
    losses_bytes, grids_bytes = losses_f.read(), grids_f.read()

model_comparison(losses_bytes, grids_bytes)