In [108]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.utils import save_image
import torchvision.datasets as datasets
import os

# Path
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')
    
# To save images
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

# Parameters
learning_rate = 0.01
num_epochs = 10
batch_size = 32

# To transform to tensor and padding to 32*32
transforms = torchvision.transforms.Compose([torchvision.transforms.Pad((2, 2), 0), torchvision.transforms.ToTensor()])

# Dataset for training, validation and test sets as tensors
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms)
mnist_trainset, mnist_validationset = torch.utils.data.random_split(mnist_trainset, [50000, 10000])

# Data loader for train, test and validation sets
trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size, num_workers=2, shuffle=True)
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size, num_workers=2, shuffle=True)
validationloader = torch.utils.data.DataLoader(mnist_validationset, batch_size=batch_size, num_workers=2, shuffle=True)

# Compute loss for test and validation
def Loss(dataLoader):
    sum_loss = 0
    for i, data in enumerate(dataloader):
        x, label = data
        inputs, labels = data
        outputs = model(inputs)
        
        # Error catching for partial batches
        dim = outputs.shape
        if dim[0] != batch_size:
            continue

        model.zero_grad()
        
#         inputs = inputs.view(batch_size, -1)
        loss = loss_func(outputs, inputs)
        
        loss.backward()
        optimizer.step()
        
        sum_loss = sum_loss + loss
    
    return(sum_loss)

class autoencoder(nn.Module):

    # Defining the network
    def __init__(self):
        super(autoencoder, self).__init__()
        
        # Defining the layer sizes
        self.conv1 = nn.Conv2d(1, 16, 5, stride=1)
        self.conv2 = nn.Conv2d(16, 8, 5, stride=1)
        self.conv3 = nn.Conv2d(8, 4, 2, stride=1)
        
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        
        self.deconv1 = nn.ConvTranspose2d(4, 8, 2, stride=1)
        self.deconv2 = nn.ConvTranspose2d(8, 16, 5, stride=1)
        self.deconv3 = nn.ConvTranspose2d(16, 1, 5, stride=1)
        
        self.unpool = nn.MaxUnpool2d(2, stride=2)

    # Forward propagation
    def forward(self, x):
        
        # Encoder
        x = self.conv1(x)
        x, indices1 = self.pool(x)
        x = self.conv2(x)
        x, indices2 = self.pool(x)
        x = self.conv3(x)
        x, indices3 = self.pool(x)
        
        # Decoder
        x = self.unpool(x, indices3)
        x = self.deconv1(x)
        x = self.unpool(x, indices2)
        x = self.deconv2(x)
        x = self.unpool(x, indices1)
        x = self.deconv3(x)
        return x

# Getting our model
model = autoencoder()

# Defining optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_func = nn.MSELoss()

# Starting Training
for epoch in range (0, num_epochs):
    epoch_loss = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        outputs = model(inputs)
        
        # Error catching for partial batches
        dim = outputs.shape
        if dim[0] != batch_size:
            continue

        model.zero_grad()
        
        inputs = inputs.squeeze(1).view(batch_size, -1)
        outputs = outputs.view(batch_size, -1).float()
        
        loss = loss_func(outputs, inputs)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss = epoch_loss + loss

        
    # Saving images
    if (epoch == 0 or epoch == 3 or epoch == 5):
        for i in range(1, 6):
            input_img, _ = mnist_trainset[i]
            print(input_img.shape)
            [z,r,c] = input_img.shape
            input_img = np.reshape(input_img, (batch_size, z, r, c))
            output_img = model(input_img)

            # Save input image just once
            if epoch == 1:
                pic = to_img(input_img.cpu().data)
                save_image(pic, './dc_img/input_image_{}.png'.format(i))

            # Save output image at epochs 1, 3, 5
            pic = to_img(output_img.cpu().data)
            save_image(pic, './dc_img/output_image_{}_{}.png'.format(i, epoch))


torch.Size([1, 32, 32])


ValueError: cannot reshape array of size 1024 into shape (32,1,32,32)