In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

from dataset import ReconstructionDataset
from encoder import Encoder
from decoder import IMDecoder
from utils import create_coord_map, initialize_enc_dec

In [None]:
# Training parameters and device
# Prefer Apple's MPS backend (the device this notebook was written for), but fall
# back to CUDA and then the CPU so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

BATCH_SIZE = 64  # samples per optimisation step
Z_DIM = 32       # forwarded to initialize_enc_dec; presumably the latent feature size (confirm in utils)
LR = 0.001       # Adam learning rate

# We can train progressively with increasing resolutions, but MNIST is simple enough that we can directly train at 28x28
PROGRESSIVE_TRAINING_RESOLUTIONS = [28]
# Each resolution is trained for twice its side length in epochs (28 -> 56 epochs).
TRAINING_EPOCHS_PER_RESOLUTION = [i * 2 for i in PROGRESSIVE_TRAINING_RESOLUTIONS]

dataset = ReconstructionDataset()
enc, dec = initialize_enc_dec(Z_DIM, DEV)
# A single optimiser updates encoder and decoder jointly.
opt = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr = LR)
# Reconstruction objective: mean squared error between predicted and target values.
crit = nn.MSELoss()

In [None]:
# Main training loop
# For every (resolution, epoch count) pair: retarget the dataset, train encoder and
# decoder jointly with MSE, and plot one reconstruction after each epoch.
for resolution, epochs in zip(PROGRESSIVE_TRAINING_RESOLUTIONS, TRAINING_EPOCHS_PER_RESOLUTION):
    # Setting target resolution for dataset object and creating a new loader
    dataset.set_target_resolution(resolution)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Same coordinate map is used for all predictions of a resolution, so init outside loop
    coord_map = create_coord_map(resolution).to(DEV)
    for e in range(epochs):
        loop = tqdm(loader, total=len(loader), position=0)
        loop.set_description_str(f"Resolution: {resolution}x{resolution} | Epoch: {e}")
        for input_img, target_img in loop:
            opt.zero_grad()
            input_img, target_img = input_img.to(DEV), target_img.to(DEV)
            # Round the target values to the nearest integer before computing the loss.
            target_img = target_img.round()
            feature_vector = enc(input_img)
            # One copy of the coordinate map per sample; the batch size is read from
            # the encoder output because the final batch may be smaller than BATCH_SIZE.
            batch_coords = coord_map.unsqueeze(0).repeat(feature_vector.shape[0], 1, 1)
            predicted_img = dec(feature_vector, batch_coords)
            # Targets are flattened per sample before being compared with the decoder output.
            loss = crit(predicted_img, target_img.view(target_img.shape[0], -1))
            loss.backward()
            opt.step()
            loop.set_postfix(loss = loss.item())

        # Output a single example each epoch for sanity check
        # Run the preview in eval mode so layers such as dropout / batch-norm (if the
        # models contain any) behave as they would at inference time, then restore
        # whatever mode each model was in so training continues unchanged.
        enc_was_training, dec_was_training = enc.training, dec.training
        enc.eval()
        dec.eval()
        with torch.no_grad():
            sample_input = dataset[0][0].unsqueeze(0)
            features = enc(sample_input.to(DEV))
            # Reuse the coordinate map already on DEV instead of rebuilding and
            # re-uploading it every epoch.
            output = dec(features, coord_map.unsqueeze(0)).view(resolution, resolution)

            fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
            ax = ax.flatten()
            ax[0].imshow(sample_input[0][0])
            ax[0].axis(False)
            # No .detach() needed: the tensor was produced under no_grad.
            ax[1].imshow(output.cpu())
            ax[1].axis(False)
            plt.show()
            # Release the figure so one-per-epoch plots do not accumulate in memory.
            plt.close(fig)
        enc.train(enc_was_training)
        dec.train(dec_was_training)