In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import linalg as LA
import torch.optim as optim
from torchvision import datasets, transforms
from types import SimpleNamespace
import matplotlib.pyplot as plt
import numpy as np
from src import AutoEncoder, SmoothSailing, kappa

In [7]:
# Hyperparameters for the notebook, gathered in a single namespace.
# The visible cells read batch_size, test_batch_size, seed, end_layer_size,
# mid_layer_size and noise_level; the remaining fields are not used in the
# cells shown here (presumably consumed by training code -- TODO confirm).
args = SimpleNamespace(
    batch_size=32,
    test_batch_size=1000,
    epochs=10,
    lr=0.0001,
    momentum=0.5,
    seed=1,
    log_interval=100,
    beta_end=0.1,
    beta_mid=0.005,
    end_layer_size=256,
    mid_layer_size=32,
    noise_level=0.5,
)

# Seed PyTorch's global RNG before any random draws happen.
torch.manual_seed(args.seed)

# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
# Worker / pinned-memory options are only passed to the loaders when CUDA
# is in use; on CPU the loaders run with their defaults.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Preprocessing applied to both splits: convert to a tensor, then
# Normalize(mean=0.1307, std=0.3081) on the single channel.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the training split requests a download; the test split expects the
# files to be present under '../data' already.
mnist_train = datasets.MNIST('../data', train=True, download=True,
                             transform=mnist_transform)
mnist_test = datasets.MNIST('../data', train=False,
                            transform=mnist_transform)

# Both loaders shuffle; they differ only in dataset and batch size.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [24]:
# Build two AutoEncoder instances with the same layer sizes and move both to
# `device`. They are filled from separate checkpoint files further down.
layer_sizes = dict(end_layer_size=args.end_layer_size,
                   mid_layer_size=args.mid_layer_size)
model = AutoEncoder(**layer_sizes).to(device)
model_reg = AutoEncoder(**layer_sizes).to(device)

In [25]:
# Read the two saved checkpoints, mapping every stored tensor onto the CPU so
# the files load even on a machine without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
checkpoint_device = torch.device('cpu')
m = torch.load('denoise_model.pt', map_location=checkpoint_device)
m_reg = torch.load('denoise_model_reg.pt', map_location=checkpoint_device)

In [27]:
# Switch both networks to evaluation mode before they are used for inference.
# eval() only changes behaviour for layers such as dropout / batch-norm; if
# AutoEncoder has none it is a harmless no-op. It is called *before* the
# loads so that the cell's last expression (and therefore its displayed
# output) is still the load_state_dict result.
model.eval()
model_reg.eval()

# Copy the checkpoint tensors into the networks; load_state_dict reports
# any missing or unexpected keys.
model.load_state_dict(m)
model_reg.load_state_dict(m_reg)

<All keys matched successfully>

In [39]:
# Plot denoised images.
# Row 1: noisy inputs, row 2: output of `model`, row 3: output of `model_reg`;
# one column per sampled image.
N_EXAMPLES = 5
GRID_ROWS, GRID_COLS = 3, 18  # subplot grid geometry (18 columns, N_EXAMPLES used)

plt.figure(figsize=(60, 10))

# Fetch a single batch once instead of building a new loader iterator (and a
# new batch) on every loop iteration.
data, train_labels = next(iter(train_loader))

# Inference only: no autograd graph is needed for the forward passes.
with torch.no_grad():
    for i in range(N_EXAMPLES):
        clean_image = data[i][0]
        noisy_data = clean_image + args.noise_level * torch.randn(clean_image.shape)

        # The models live on `device`; the input has to be moved there too,
        # otherwise the forward pass fails when CUDA is in use.
        flat_input = noisy_data.view(-1, 28*28).to(device)
        output = model(flat_input).view(28, 28).cpu().numpy()
        output_reg = model_reg(flat_input).view(28, 28).cpu().numpy()

        panels = (noisy_data.cpu().numpy(), output, output_reg)
        for row, image in enumerate(panels):
            # Column i of grid row `row` (subplot indices are 1-based).
            plt.subplot(GRID_ROWS, GRID_COLS, row * GRID_COLS + i + 1)
            plt.imshow(image, cmap='gray')
            plt.axis('off')

plt.subplots_adjust(wspace=0, hspace=0)
plt.tight_layout()
#plt.savefig('MNIST_denoising_softy.png')
plt.show()


	beta 0: 8.60%
	beta 0.001: 7.76%
	beta 0.01: 7.83%
	beta 0.1: 8.66%
	beta 1: 8.94%
	beta 0: 16.96%
	beta 0.001: 15.60%
	beta 0.01: 15.65%
	beta 0.1: 17.41%
	beta 1: 17.98%
	beta 0: 25.23%
	beta 0.001: 23.29%
	beta 0.01: 23.34%
	beta 0.1: 25.90%
	beta 1: 27.00%
	beta 0: 33.80%
	beta 0.001: 31.36%
	beta 0.01: 31.46%
	beta 0.1: 34.60%
	beta 1: 36.06%
	beta 0: 42.26%
	beta 0.001: 39.08%
	beta 0.01: 39.24%
	beta 0.1: 43.32%
	beta 1: 44.93%
	beta 0: 50.69%
	beta 0.001: 46.87%
	beta 0.01: 47.16%
	beta 0.1: 51.91%
	beta 1: 53.96%
	beta 0: 59.28%
	beta 0.001: 54.93%
	beta 0.01: 55.27%
	beta 0.1: 60.66%
	beta 1: 62.93%
	beta 0: 67.62%
	beta 0.001: 62.78%
	beta 0.01: 63.20%
	beta 0.1: 69.28%
	beta 1: 71.73%
	beta 0: 75.82%
	beta 0.001: 70.45%
	beta 0.01: 71.06%
	beta 0.1: 77.74%
	beta 1: 80.51%
	beta 0: 84.24%
	beta 0.001: 77.94%
	beta 0.01: 78.83%
	beta 0.1: 86.29%
	beta 1: 89.35%
