# Denoising MNIST with U-Net

In this script, we denoise MNIST samples by treating denoising as a segmentation task. This is a small use case for the U-Net that is used later in DDPMs. The idea is to segment "written" (white) pixels from "background" (black) pixels, training on inputs at different noise scales and obtaining the ground truth by thresholding the pixel values:
* class 0 (background): pixel value < 0.5
* class 1 (written digit): pixel value >= 0.5
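
As a minimal sketch of this setup, the snippet below builds a noisy input and its thresholded target from a batch of clean images. It assumes raw pixel values in [0, 1] and additive Gaussian noise; the helper name `make_training_pair`, the noise model, and the noise scales are illustrative assumptions, not something the notebook fixes.

```python
import torch

def make_training_pair(clean, noise_scale):
    """Build a (noisy input, class target) pair from clean images in [0, 1].

    clean: float tensor of shape (B, 1, H, W) with raw pixel values in [0, 1].
    noise_scale: standard deviation of the added Gaussian noise (assumed noise model).
    """
    # Ground truth: class 1 where the clean pixel is >= 0.5, class 0 otherwise
    target = (clean >= 0.5).long().squeeze(1)  # shape (B, H, W)
    # Corrupt the input with Gaussian noise of the given scale
    noisy = clean + noise_scale * torch.randn_like(clean)
    return noisy, target

torch.manual_seed(0)
clean = torch.rand(4, 1, 32, 32)  # random stand-in for an MNIST batch
for noise_scale in (0.1, 0.5, 1.0):  # hypothetical noise scales
    noisy, target = make_training_pair(clean, noise_scale)
```

The target is always computed from the clean image, so it stays the same no matter how strongly the input is corrupted.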

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Resize to 32x32, convert to a tensor in [0, 1], then standardize with the MNIST mean and std
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = torchvision.datasets.MNIST('./../data', train=True, transform=transform, download=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

### Training Loop

In [None]:
from simple_unet import SimpleUNet
import matplotlib.pyplot as plt

denoiser = SimpleUNet(n_channels=1, n_classes=2)
optimizer = torch.optim.Adam(params=denoiser.parameters(), lr=1e-3)

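for i, (batch, labels) in enumerate(dataloader):

    # Undo the Normalize transform so the 0.5 threshold applies to raw pixel values in [0, 1]
    pixels = batch * 0.3081 + 0.1307
    # Class index per pixel: 0 = background, 1 = written digit; shape (B, H, W)
    ground_truth = (pixels >= 0.5).long().squeeze(1)

    # Per-class output of the network, expected shape (B, 2, H, W)
    prediction = denoiser(batch)

    if i % 100 == 0:
        # make_grid tiles the batch and keeps the two class channels: (2, H', W')
        imgs = torchvision.utils.make_grid(prediction).detach().cpu().numpy()
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        # Iterate over the axes directly so the batch index i is not overwritten
        for ax, img, title in zip(axes, imgs, ['background', 'handwritten']):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')

    # Stop after the first batch: no loss or optimizer step is implemented yet
    break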
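
The loop above computes a prediction and a ground truth but stops before any loss or optimizer step. One possible way to complete it is a per-pixel cross-entropy loss over the two classes, sketched below. Since `SimpleUNet` lives in a local module, the sketch uses a small hypothetical stand-in network and random data. The stand-in model, the `train_step` helper, and the noise scale of 0.3 are all assumptions for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for SimpleUNet: maps (B, 1, H, W) to per-class logits (B, 2, H, W)
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 2, kernel_size=3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(model, optimizer, noisy, target):
    """Run one optimization step and return the scalar loss value."""
    optimizer.zero_grad()
    logits = model(noisy)                   # (B, 2, H, W)
    loss = F.cross_entropy(logits, target)  # target: (B, H, W), values in {0, 1}
    loss.backward()
    optimizer.step()
    return loss.item()

clean = torch.rand(8, 1, 32, 32)               # random stand-in for images in [0, 1]
target = (clean >= 0.5).long().squeeze(1)      # thresholded ground truth
noisy = clean + 0.3 * torch.randn_like(clean)  # assumed Gaussian corruption
losses = [train_step(model, optimizer, noisy, target) for _ in range(5)]
```

In the notebook itself, `model` would be replaced by `denoiser`, and `noisy` and `target` would be derived from each MNIST batch instead of random data.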

    
    

    

