# Import Pytorch

In [0]:
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os

# Import the MNIST dataset

In [0]:
os.makedirs("./data/mnist", exist_ok=True)

epochs = 100
batch_size = 16
learning_rate = 1e-3

transform = transforms.Compose([
    transforms.ToTensor()
])

train_data = datasets.MNIST("./data", transform=transform,
                           download=True)
data_loader = DataLoader(train_data,
                         batch_size=batch_size,
                         shuffle=True)

# Let's design the Autoencoder model

In [0]:
class AutoEncoder(nn.Module):
  def __init__(self):
    super(AutoEncoder, self).__init__()
    
    self.encoder = nn.Sequential(
        nn.Conv2d(1, 16, 3, stride=3, padding=1),
        nn.ReLU(True),
        nn.MaxPool2d(2, stride=2),
        nn.Conv2d(16, 8, 3, stride=2, padding=1),
        nn.ReLU(True),
        nn.MaxPool2d(2, stride=1)
    )
    
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(8, 16, 3, stride=2),
        nn.ReLU(True),
        nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),
        nn.ReLU(True),
        nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),
        nn.Tanh()
    )
  
  def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)
    
    return x
  
model = AutoEncoder()

# Checking CUDA


In [0]:
cuda = True if torch.cuda.is_available else False

if cuda:
  model.cuda()

# Define loss function and optimizer

In [0]:
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                            weight_decay=1e-5)


# Some code for saving the images we generate

In [0]:
import matplotlib.pyplot as plt
import numpy as np

os.makedirs("./images", exist_ok=True)

# def imshow(img):
#   img = img / 2 + 0.5
#   np_img = img.numpy()
#   print(np_img.shape)
#   plt.imshow(np.transpose(np_img, (1,2,0)), cmap="gray")
#   plt.show()

def to_img(x):
  x = 0.5 * (x + 1)
  x = x.clamp(0,1)
  x = x.view(x.size(0), 1, 28, 28)
  return x

# Modify the dataset by removing a square in the middle

In [0]:
from PIL import Image
from PIL import ImageDraw

im = Image.new("L", (14,14))

def apply_space(img):
  
  tensor = torch.zeros((16,28,28))
  for i in range(img.size(0)):
    t1 = transforms.ToPILImage()
    pil_img = t1(img[i])
    pil_img.paste(im, (7,7))
    t2 = transforms.ToTensor()
    tensor[i,:,:] = t2(pil_img)
#     noise = torch.randn(tensor_img[i].size()) * 0.5
  tensor.unsqueeze_(1)
  return tensor

# Let's train!

In [0]:
for epoch in range(epochs):
  
  for data in data_loader:
    img, _ = data
    
    
    incomplete_img = apply_space(img)
    
    incomplete_img = Variable(incomplete_img).cuda()
    
    img = Variable(img).cuda()
    
    pred = model(incomplete_img)
    loss = loss_function(pred, img)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
  print("Epoch: %d, Loss: %f" % (epoch, loss.item()))
  
  if epoch % 1 == 0:
    original_imgs = to_img(img.cpu().data)
    incomplete_imgs = to_img(incomplete_img.cpu().data)
    pred_imgs = to_img(pred.cpu().data)
    total_images = torch.cat((original_imgs, incomplete_imgs,
                             pred_imgs), 2)
    save_image(total_images, "./images/%d.png" % epoch)
    
    

Epoch: 0, Loss: 0.037023
Epoch: 1, Loss: 0.038475
Epoch: 2, Loss: 0.037703
Epoch: 3, Loss: 0.036824
Epoch: 4, Loss: 0.034728
Epoch: 5, Loss: 0.039180
Epoch: 6, Loss: 0.036615
Epoch: 7, Loss: 0.034779
Epoch: 8, Loss: 0.030224
Epoch: 9, Loss: 0.038075
Epoch: 10, Loss: 0.031634
Epoch: 11, Loss: 0.037840
Epoch: 12, Loss: 0.033648
Epoch: 13, Loss: 0.037401
Epoch: 14, Loss: 0.035007
Epoch: 15, Loss: 0.031566
Epoch: 16, Loss: 0.038816
Epoch: 17, Loss: 0.025228
Epoch: 18, Loss: 0.031177
Epoch: 19, Loss: 0.035323
Epoch: 20, Loss: 0.034319
Epoch: 21, Loss: 0.035218
Epoch: 22, Loss: 0.033967
Epoch: 23, Loss: 0.038811
Epoch: 24, Loss: 0.033302
Epoch: 25, Loss: 0.029803
Epoch: 26, Loss: 0.031040
