In [138]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os


if not os.path.exists('./img'):
    os.mkdir('./img')

In [139]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(True),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid())

    def forward(self, x):
        encoded = self.encoder(x)
        output = self.decoder(encoded)
        return output
    
def add_noise(img,sigma):
    noise = torch.normal(0,std=torch.ones_like(img)) * sigma
    return img + noise



In [140]:
with np.load('denoising-challenge-01-data.npz') as fh:
    training_images_clean = torch.tensor(fh['training_images_clean']).view(20000,28*28)
    validation_images_noisy = torch.tensor(fh['validation_images_noisy']).view(2000,28*28)
    validation_images_clean = torch.tensor(fh['validation_images_clean']).view(2000,28*28)
    test_images_noisy = torch.tensor(fh['test_images_noisy']).view(2000,28*28)
    

# TRAINING DATA: CLEAN
# 1. INDEX: IMAGE SERIAL NUMBER (20000)
# 2. INDEX: COLOR CHANNEL (1)
# 3/4. INDEX: PIXEL VALUE (28 x 28)
"""print(training_images_clean.shape, training_images_clean.dtype)"""
# VALIDATION DATA: CLEAN + NOISY
"""print(validation_images_clean.shape, validation_images_clean.dtype)
print(validation_images_noisy.shape, validation_images_noisy.dtype)"""
# TEST DATA: NOISY
"""print(test_images_noisy.shape, test_images_noisy.dtype)"""

#Daten ausgeben
#n=(0,19999)
"""n=19
test=training_images_clean"""
#n=(0,1999)
"""n=5
test=validation_images_noisy"""
#n=(0,1999)
"""n=1999
test=validation_images_clean"""
#n=(0,1999)
"""n=1999
test=test_images_noisy"""
"""plt.figure
plt.imshow(test[n].view(28,28),cmap='gray')
plt.show()"""

# TRAIN MODEL ON training_images_clean
num_epochs = 120
batch_size = 128
learning_rate = 1e-3

standard_deviation=(validation_images_noisy-validation_images_clean).std(dim=1)
sigma=standard_deviation.sum()/standard_deviation.shape[0]

autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate, weight_decay=1e-5)
loss_func = nn.MSELoss()

train_loader = DataLoader(dataset=training_images_clean, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for data in train_loader:
        img = data
        noisy_img = add_noise(img,sigma)
        # ===================forward=====================
        output = autoencoder(noisy_img)
        loss = loss_func(output,img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ====================log========================

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, loss.data[0]))
    if epoch % 10 == 0:
        save_in = img.view(output.size(0), 1, 28, 28)
        save_out = output.view(output.size(0), 1, 28, 28)
        save_image(save_in, './img/image_{}_in.png'.format(epoch))
        save_image(save_out, './img/image_{}_out.png'.format(epoch))

torch.save(autoencoder.state_dict(), './autoencoder.pth')
print('Finished Training')
# CHECK YOUR MODEL USING (validation_images_clean, validation_images_noisy)
with torch.no_grad():
    output=autoencoder(validation_images_noisy)
    fehler = ((((output-validation_images_clean)**2).sum())**(1/2))/2000
    print('====> validation: Average loss: {:.4f}'.format(
        fehler))
save_in = validation_images_clean.view(validation_images_clean.size(0), 1, 28, 28)
save_out = output.view(output.shape[0],1,28,28)
save_image(save_in, './img/image_val_in.png')
save_image(save_out, './img/image_val_out.png')
# DENOISE IMAGES (test_images_clean) USING test_images_noisy
with torch.no_grad():
    test_images_clean=autoencoder(test_images_noisy)
save_in = test_images_noisy.view(test_images_noisy.shape[0], 1, 28, 28)
save_out = test_images_clean.view(test_images_clean.shape[0],1,28,28)
save_image(save_in, './img/image_denoised_in.png')
save_image(save_out, './img/image_denoised_out.png')
test_images_clean=test_images_clean.view(test_images_clean.shape[0],1,28,28).numpy()
# MAKE SURE THAT YOU HAVE THE RIGHT FORMAT
assert test_images_clean.ndim == 4
assert test_images_clean.shape[0] == 2000
assert test_images_clean.shape[1] == 1
assert test_images_clean.shape[2] == 28
assert test_images_clean.shape[3] == 28

# AND SAVE EXACTLY AS SHOWN BELOW
np.save('test_images_clean.npy', test_images_clean)



====> Epoch: 0 Average loss: 0.0397
====> Epoch: 1 Average loss: 0.0264
====> Epoch: 2 Average loss: 0.0233
====> Epoch: 3 Average loss: 0.0194
====> Epoch: 4 Average loss: 0.0173
====> Epoch: 5 Average loss: 0.0145
====> Epoch: 6 Average loss: 0.0145
====> Epoch: 7 Average loss: 0.0129
====> Epoch: 8 Average loss: 0.0117
====> Epoch: 9 Average loss: 0.0132
====> Epoch: 10 Average loss: 0.0118
====> Epoch: 11 Average loss: 0.0139
====> Epoch: 12 Average loss: 0.0140
====> Epoch: 13 Average loss: 0.0127
====> Epoch: 14 Average loss: 0.0105
====> Epoch: 15 Average loss: 0.0117
====> Epoch: 16 Average loss: 0.0111
====> Epoch: 17 Average loss: 0.0107
====> Epoch: 18 Average loss: 0.0108
====> Epoch: 19 Average loss: 0.0105
====> Epoch: 20 Average loss: 0.0102
====> Epoch: 21 Average loss: 0.0099
====> Epoch: 22 Average loss: 0.0110
====> Epoch: 23 Average loss: 0.0108
====> Epoch: 24 Average loss: 0.0114
====> Epoch: 25 Average loss: 0.0092
====> Epoch: 26 Average loss: 0.0096
====> Epoch



====> validation: Average loss: 0.0081


In [149]:
with torch.no_grad():
    output=autoencoder(validation_images_noisy)
    fehler = ((((output-validation_images_clean)**2).sum())**(1/2))/2000
    print('====> validation: Average loss: {:.4f}'.format(
        fehler))

====> validation: Average loss: 0.0564


In [146]:
4**(1/2)

2.0