Tests a 2D image autoencoder.

In [None]:

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from state_encoder_3d.models import CompNeRFImageEncoder, CNNImageDecoder
from state_encoder_3d.dataset import MNISTDataset

In [None]:
# Prefer the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the selected GPU the current CUDA device for this process.
    torch.cuda.set_device(device)
print(f"Using device {device}")

In [None]:
def plot_output_ground_truth(model_output, ground_truth, resolution):
    """Display a model output next to its ground-truth image.

    Each tensor is moved to the CPU, viewed with the dimensions given in
    ``resolution``, detached and converted to a NumPy array before being
    drawn with ``imshow``. The two images share a one-row figure with
    their axes hidden, and the figure is shown immediately.

    Args:
        model_output: Tensor drawn in the left panel.
        ground_truth: Tensor drawn in the right panel.
        resolution: Sequence of dimensions unpacked into ``Tensor.view``.
    """
    panels = (
        ("Trained MLP", model_output),
        ("Ground Truth", ground_truth),
    )
    _, axes = plt.subplots(1, len(panels), figsize=(12, 6), squeeze=False)

    # squeeze=False keeps ``axes`` two-dimensional, so row 0 holds both panels.
    for ax, (title, tensor) in zip(axes[0], panels):
        pixels = tensor.cpu().view(*resolution).detach().numpy()
        ax.imshow(pixels)
        ax.set_title(title)
        ax.set_axis_off()

    plt.show()

In [None]:
# Drop-in data source for overfitting on a single image.
def image_generator():
    """Yield the first item of a freshly built ``MNISTDataset`` forever.

    The dataset is indexed again on every iteration, so each ``next`` call
    performs a new ``dataset[0]`` lookup.
    """
    mnist = MNISTDataset()
    first_index = 0
    while True:
        yield mnist[first_index]
# Endless stream that repeats one sample (see ``image_generator``).
data_generator = image_generator()

# Shuffled mini-batches of 512 samples drawn from the dataset.
dataloader = torch.utils.data.DataLoader(
    MNISTDataset(),
    shuffle=True,
    batch_size=512,
)

In [None]:
# Number of channels in the latent code passed from the encoder to the decoder.
latent_dim = 64

encoder = CompNeRFImageEncoder(
    in_ch=1,
    out_ch=latent_dim,
    resnet_out_dim=2048,
).to(device)

# NOTE: num_up depends on the image resolution, because the decoder has to
# upsample back to that resolution.
decoder = CNNImageDecoder(
    in_ch=latent_dim,
    out_ch=1,
    hidden_ch=128,
    num_up=6,
).to(device)

In [None]:
# Loss
def img2mse(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    The element-wise difference is squared and then averaged over every
    element with ``torch.mean``.

    Args:
        x: First tensor.
        y: Second tensor; ``x - y`` must be a valid tensor subtraction.

    Returns:
        The result of ``torch.mean`` over the squared differences.
    """
    return torch.mean((x - y) ** 2)

# Both networks get their own Adam optimizer with the same learning rate.
lr = 1e-4
encoder_optim, decoder_optim = (
    torch.optim.Adam(network.parameters(), lr=lr)
    for network in (encoder, decoder)
)

In [None]:
# Train the encoder/decoder pair to reconstruct its input images.
num_steps = 5001
steps_til_summary = 500

# Keep ONE iterator alive across steps. Calling ``iter(dataloader)`` inside the
# loop would build a fresh iterator every step and only ever consume its first
# batch, so no pass over the dataset would be completed.
batch_iter = iter(dataloader)

for step in tqdm(range(num_steps)):
    try:
        image = next(batch_iter)
    except StopIteration:
        # The current pass over the dataset is exhausted: start a new one.
        batch_iter = iter(dataloader)
        image = next(batch_iter)
    image = image.to(device)

    # Forward pass: encode the batch and reconstruct it from the latent.
    latent = encoder(image)
    predicted_image = decoder(latent)

    loss = img2mse(image, predicted_image)

    # Clear stale gradients, backpropagate, then update both networks.
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()
    loss.backward()
    encoder_optim.step()
    decoder_optim.step()

    # Every so often, we want to show what our model has learned.
    # It would be boring otherwise!
    if step % steps_til_summary == 0:
        print(f"Step {step}: loss = {loss.item():.5f}")

        # Show the first sample of the batch next to its reconstruction.
        plot_output_ground_truth(predicted_image[0], image[0], resolution=(64, 64))
