In [None]:
import cv2
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader

from google.colab import drive
# Colab-only: mount Google Drive at /content/drive so the image folders under
# /content/drive/MyDrive (used below) can be read from the notebook.
drive.mount('/content/drive')

In [None]:
# Input folders on the mounted Drive: per their names, the original images and
# their blurred versions. Both are loaded by create_array below.
orig_imgs, blur_imgs = (
    '/content/drive/MyDrive/small_orig_imgs',
    '/content/drive/MyDrive/small_blur_imgs',
)

In [None]:
def create_array(folder_dir, size=(64, 128)):
  """Load every image file in ``folder_dir`` into one stacked numpy array.

  Files are read in sorted filename order, so folders whose files share names
  are loaded in matching order. Subdirectories are skipped.

  Args:
    folder_dir: directory containing the images.
    size: ``(width, height)`` passed to ``PIL.Image.resize``. The default
      (64, 128) yields per-image arrays of shape (128, 64, 3).

  Returns:
    ``np.ndarray`` of shape (num_images, height, width, 3), dtype uint8.
    If the folder holds no files, the result is an empty array.
  """
  # List the folder we were asked to load (not a fixed global folder).
  names = sorted(
      name for name in os.listdir(folder_dir)
      if os.path.isfile(os.path.join(folder_dir, name))
  )
  arr = []
  for name in names:
    # `with` closes the underlying file handle.
    # convert("RGB") guarantees 3 channels; grayscale, RGBA or palette images
    # would otherwise produce ragged or non-3-channel arrays.
    with Image.open(os.path.join(folder_dir, name)) as img:
      data = np.asarray(img.convert("RGB").resize(size))
    arr.append(data)
  return np.array(arr)

In [None]:
x = create_array(orig_imgs)
y = create_array(blur_imgs)
# Pairs are matched by position, so both folders must yield the same stack shape.
assert x.shape == y.shape, f"original/blurred arrays differ: {x.shape} vs {y.shape}"

BATCH_SIZE = 16

# HWC uint8 in [0, 255] -> NCHW float32 in [0, 1].
# Unscaled pixels inflate the MSE loss by a factor of ~255**2 and make optimisation unstable.
tensor_x = torch.Tensor(x).permute(0, 3, 1, 2) / 255.0
tensor_y = torch.Tensor(y).permute(0, 3, 1, 2) / 255.0

# Each dataset item is an (original, blurred) pair.
# tensor_x comes from orig_imgs; tensor_y comes from blur_imgs.
my_dataset = TensorDataset(tensor_x, tensor_y)
my_dataloader = DataLoader(my_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder: (N, 3, 128, 64) image batch -> (N, latent_dim) code.

    Three stride-2 convolutions halve each spatial dimension
    (H: 128 -> 64 -> 32 -> 16, W: 64 -> 32 -> 16 -> 8). The last feature map is
    therefore (N, 50, 16, 8) before it is flattened and projected.

    Args:
        latent_dim: size of the output code. The default of 128 matches the
            input size of the Decoder's first Linear layer.
    """
    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
                  # input_shape=(3, 128, 64)  -- (C, H, W)
                  nn.Conv2d(3, 10, kernel_size=3, stride=2, padding=1),
                  nn.ReLU(),
                  nn.Conv2d(10, 25, kernel_size=3, stride=2, padding=1),
                  nn.ReLU(),
                  nn.Conv2d(25, 50, kernel_size=3, stride=2, padding=1)
                )
        # Flattened size = 50 channels * 16 * 8 spatial positions.
        self.linear = nn.Linear(8 * 16 * 50, latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (N, 3, 128, 64) into a (N, latent_dim) tensor."""
        x = self.model(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x

In [None]:
class Decoder(nn.Module):
    """Decoder: (N, latent_dim) code -> (N, 3, 128, 64) image batch.

    The code is projected to a 3-channel 16x8 seed map. Three stride-2
    transposed convolutions then double H and W (16x8 -> 32x16 -> 64x32 ->
    128x64), and a final size-preserving transposed convolution maps to
    3 channels. No activation is applied to the output, so values are unbounded.

    Args:
        latent_dim: size of the input code (default 128).
    """
    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        self.linear = nn.Linear(latent_dim, 8 * 16 * 3)
        self.model = nn.Sequential(
                  nn.ConvTranspose2d(3, 50, kernel_size=4, stride=2, padding=1),
                  nn.ReLU(),
                  nn.ConvTranspose2d(50, 25, kernel_size=4, stride=2, padding=1),
                  nn.ReLU(),
                  nn.ConvTranspose2d(25, 10, kernel_size=4, stride=2, padding=1),
                  nn.ReLU(),
                  # kernel 3 / padding 1 / stride 1 keeps 128x64 unchanged.
                  nn.ConvTranspose2d(10, 3, kernel_size=3, padding=1)
                )

    def forward(self, x):
        """Decode ``x`` (..., latent_dim) into images of shape (N, 3, 128, 64)."""
        x = self.linear(x)
        # Unflatten into (N, 3, H=16, W=8) seed maps.
        x = torch.reshape(x, (-1, 3, 16, 8))
        x = self.model(x)
        return x

In [None]:
# Layer-by-layer output shapes for one (3, 128, 64) image; batch_size=-1 keeps the
# batch dimension symbolic in the printout.
# device="cpu": model1 lives on the CPU, while torchsummary defaults to "cuda" and
# then builds CUDA inputs whenever a GPU is present, causing a device mismatch.
model1 = Encoder()
summary(model1, input_size=(3, 128, 64), batch_size=-1, device="cpu")

In [None]:
# Decoder takes a flat 128-element code per sample, as produced by Encoder.
# device="cpu" for the same reason as the encoder summary: model2 is on the CPU.
model2 = Decoder()
summary(model2, input_size=(128,), batch_size=-1, device="cpu")

In [None]:
class AutoEncoder(nn.Module):
    """Chains an Encoder and a Decoder: images -> latent code -> images."""

    def __init__(self):
        super().__init__()
        # The attribute names `encoder` / `decoder` become state_dict key prefixes.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        code = self.encoder(x)
        reconstruction = self.decoder(code)
        return reconstruction

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoEncoder().to(device)

# Reconstruction loss: mean squared error between model output and target image.
loss_function = torch.nn.MSELoss()

# Adam at its library-default learning rate.
# lr=0.1 is 100x that default and is likely to make training unstable.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(),
                             lr=LEARNING_RATE,
                             weight_decay=1e-8)

In [None]:
epochs = 20
losses = []

# Put batches on whatever device the model's weights live on.
train_device = next(model.parameters()).device

model.train()
for epoch in range(epochs):
    print("epoch no:", epoch)
    # Dataset items are (original, blurred): TensorDataset(tensor_x, tensor_y).
    for target, image in my_dataloader:
        target = target.to(train_device)
        image = image.to(train_device)

        # Deblurring objective: feed the blurred image, compare to the original.
        deblurred_image = model(image)
        loss = loss_function(deblurred_image, target)

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() stores a plain float: no autograd graph or GPU memory is kept
        # alive, and the list can be plotted directly.
        losses.append(loss.item())

# Plot the last 100 per-batch loss values.
fig, ax = plt.subplots()
ax.plot(losses[-100:])
ax.set(xlabel='Iterations', ylabel='Loss', title='Training loss (last 100 iterations)');

In [None]:
# Side-by-side view of (part of) the last training batch: model input vs. output.
def _to_display(chw):
    """Turn a (C, H, W) tensor into an (H, W, C) array min-max scaled to [0, 1] for imshow."""
    # detach/cpu: the output carries autograd history and may live on the GPU.
    hwc = chw.detach().cpu().permute(1, 2, 0).numpy()
    lo, hi = hwc.min(), hwc.max()
    return (hwc - lo) / (hi - lo) if hi > lo else np.zeros_like(hwc)

n_show = min(4, len(image))
# squeeze=False keeps `axes` 2-D even when only one row is shown.
fig, axes = plt.subplots(n_show, 2, figsize=(6, 3 * n_show), squeeze=False)
for row, (inp, out) in enumerate(zip(image[:n_show], deblurred_image[:n_show])):
    axes[row, 0].imshow(_to_display(inp))
    axes[row, 0].set_title('input')
    axes[row, 1].imshow(_to_display(out))
    axes[row, 1].set_title('output')
    for ax in axes[row]:
        ax.axis('off')
fig.tight_layout()