In [1]:
import torch

# Seed torch's RNGs (all devices) so data shuffling, weight init and any
# torch-based noise are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prefer Apple's Metal (MPS) backend, then CUDA, and fall back to the CPU so the
# notebook also runs on machines without an Apple-silicon GPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ToTensor yields floats in [0, 1]; Normalize then applies (x - 0.5) / 0.5 per
# channel, so inputs to the model lie in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


def _cifar10_split(is_train):
    """Build one CIFAR-10 split and its loader (batches of 64; only the train split is shuffled)."""
    split = datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    return split, DataLoader(split, batch_size=64, shuffle=is_train)


train_dataset, train_loader = _cifar10_split(True)
test_dataset, test_loader = _cifar10_split(False)

In [None]:
from denoising_autoencoder import DenoisingAutoencoder, train

# Training hyperparameters, named so they are easy to find and tune.
EPOCHS = 10
LEARNING_RATE = 0.01

# Place the model on `device` explicitly so later inference on tensors that live
# on `device` works whether or not `train` relocates the model itself.
# NOTE(review): assumes DenoisingAutoencoder is a torch.nn.Module (it exposes
# .eval() and is called like one), whose .to() moves it in place and returns it.
model = DenoisingAutoencoder().to(device)
train(model, train_loader, epochs=EPOCHS, learning_rate=LEARNING_RATE, device=device)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

N_EXAMPLES = 5  # how many test images to show (capped at the size of one test batch)


def denormalize(images):
    """Undo the (x - 0.5) / 0.5 input normalization, i.e. return images * 0.5 + 0.5."""
    return images * 0.5 + 0.5


def _to_hwc_numpy(batch):
    """Denormalize a (N, C, H, W) tensor and return it as a (N, H, W, C) numpy array for imshow."""
    return np.transpose(denormalize(batch).cpu().numpy(), (0, 2, 3, 1))


test_batch, _ = next(iter(test_loader))
originals = test_batch[:N_EXAMPLES].to(device)
# Number of rows actually plotted; smaller than N_EXAMPLES if the batch is short,
# which previously caused an IndexError in the plotting loop.
n = len(originals)

with torch.no_grad():
    # Noise is added before switching to eval mode, keeping the original call
    # order in case add_noise consults the module's training flag.
    noisy = model.add_noise(originals)
    model.eval()
    denoised = model(noisy)

images = _to_hwc_numpy(originals)
noisy_images = _to_hwc_numpy(noisy)
denoised_images = _to_hwc_numpy(denoised)

columns = (("Original", images), ("Noisy", noisy_images), ("Denoised", denoised_images))
# squeeze=False keeps `axes` 2-D even when only one row is plotted.
fig, axes = plt.subplots(n, len(columns), figsize=(12, 4 * n), squeeze=False)
for row in range(n):
    for col, (title, panel) in enumerate(columns):
        ax = axes[row, col]
        # Clip because denormalized reconstructions/noise can fall outside [0, 1].
        ax.imshow(np.clip(panel[row], 0.0, 1.0))
        ax.axis('off')
        if row == 0:
            ax.set_title(title)

fig.tight_layout()
plt.show()