In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import os


In [2]:
# Dataset class for sharp and blurred images
class ImageDataset(Dataset):
    """Paired dataset of blurred images and their sharp counterparts.

    Pairing is positional: after sorting both directory listings by name, the
    i-th blurred file is matched with the i-th sharp file. Sorting is required
    because ``os.listdir`` returns entries in arbitrary order, so two separate
    unsorted listings are not guaranteed to line up.

    Args:
        blurred_dir: directory containing the blurred input images.
        sharp_dir: directory containing the sharp target images.
        transform: optional callable applied to the blurred and the sharp
            PIL image independently (two separate calls), so a transform with
            internal randomness would not be synchronised across the pair.

    Raises:
        ValueError: if the two directories do not contain the same number of
            entries, since positional pairing would then be unreliable.
    """

    def __init__(self, blurred_dir, sharp_dir, transform=None):
        self.blurred_dir = blurred_dir
        self.sharp_dir = sharp_dir
        self.transform = transform
        # Sort both listings so that index i refers to the same pair on every
        # run and on every filesystem.
        self.filenames = sorted(os.listdir(blurred_dir))
        # NOTE: attribute name (with the doubled "s") kept for backward
        # compatibility with any code that reads it.
        self.filenamess = sorted(os.listdir(sharp_dir))
        if len(self.filenames) != len(self.filenamess):
            raise ValueError(
                f"Directory size mismatch: {len(self.filenames)} entries in "
                f"{blurred_dir!r} vs {len(self.filenamess)} entries in {sharp_dir!r}"
            )

    def __len__(self):
        """Return the number of (blurred, sharp) pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB images and return ``(blurred, sharp)``.

        When a transform was supplied, both returned items are the transform's
        outputs; otherwise they are PIL images.
        """
        blurred_path = os.path.join(self.blurred_dir, self.filenames[idx])
        sharp_path = os.path.join(self.sharp_dir, self.filenamess[idx])
        blurred_img = Image.open(blurred_path).convert("RGB")
        sharp_img = Image.open(sharp_path).convert("RGB")
        if self.transform:
            blurred_img = self.transform(blurred_img)
            sharp_img = self.transform(sharp_img)
        return blurred_img, sharp_img
    



In [3]:
# define the CNN architecture
class ImageSharpeningCNN(nn.Module):
    """Fully convolutional encoder/decoder that maps a 3-channel image to a 3-channel image.

    Every convolution uses a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved end to end. The encoder widens the
    channels 3 -> 64 -> 128 -> 256, the decoder narrows them back
    256 -> 128 -> 64 -> 3, and a final sigmoid bounds the output to (0, 1).
    """

    def __init__(self):
        super().__init__()
        # Shared settings for all convolutions in the network.
        same_conv = dict(kernel_size=3, stride=1, padding=1)

        # Encoder: progressively increase the number of feature channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **same_conv),
            nn.ReLU(),
            nn.Conv2d(64, 128, **same_conv),
            nn.ReLU(),
            nn.Conv2d(128, 256, **same_conv),
            nn.ReLU(),
        )

        # Decoder: mirror of the encoder, ending in a sigmoid instead of ReLU.
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, **same_conv),
            nn.ReLU(),
            nn.Conv2d(128, 64, **same_conv),
            nn.ReLU(),
            nn.Conv2d(64, 3, **same_conv),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder and return the result."""
        return self.decoder(self.encoder(x))

In [4]:
# define training parameters, transformation and dataloader

# training setup: use CUDA when torch reports it as available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ImageSharpeningCNN()
model.to(device)  # nn.Module.to moves the parameters in place
criterion = nn.MSELoss()  # mean squared error between model output and sharp target
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# preprocessing applied to every image: resize to 256x256, then convert to a tensor
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# dataset of (blurred, sharp) pairs and a shuffling loader with batches of 16
blurred_dir = "/blurred_images"
sharp_dir = "/original_images"
dataset = ImageDataset(
    blurred_dir=blurred_dir,
    sharp_dir=sharp_dir,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [None]:
# train the model: 60 passes over the dataloader, reporting the mean batch loss per epoch
epochs = 60
for epoch in range(epochs):
    model.train()
    running_loss = 0.0  # sum of the per-batch losses for this epoch
    for blurred_imgs, sharp_imgs in dataloader:
        # move the batch to the same device as the model
        blurred_imgs, sharp_imgs = blurred_imgs.to(device), sharp_imgs.to(device)

        # clear gradients left over from the previous step
        optimizer.zero_grad()

        # forward pass: predict sharp images and compare with the targets
        outputs = model(blurred_imgs)
        loss = criterion(outputs, sharp_imgs)

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # average over the number of batches in the loader
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}")

# persist only the learned parameters (state dict), not the whole module object
torch.save(model.state_dict(), "/model_path_directory.pth")
print("Model saved")

In [None]:
import matplotlib.pyplot as plt
# inference
# load blurred image
def load_image(image_path, transform):
    """Open an image as RGB, apply ``transform`` and add a leading dimension.

    Args:
        image_path: path of the image file to open.
        transform: callable applied to the RGB PIL image; its result must
            support ``.unsqueeze(0)``.

    Returns:
        The transformed image with an extra dimension inserted at index 0.
    """
    rgb = Image.open(image_path).convert("RGB")
    return transform(rgb).unsqueeze(0)

# visualize the deblurred image and the original image
def show_images(input_image, output_image):
    """Display the input and the output image side by side in one figure.

    Both arguments are handled the same way: dimension 0 is squeezed, the
    remaining axes are reordered with ``permute(1, 2, 0)``, and the values are
    multiplied by 255 and cast to ``uint8`` before being drawn.

    Args:
        input_image: tensor shown in the left panel, titled "Input Image".
        output_image: tensor shown in the right panel, titled "Output Image".
    """

    def to_pil(tensor):
        # presumably values are in [0, 1] so that * 255 fits in uint8 -- verify against caller
        array = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
        return Image.fromarray((array * 255).astype("uint8"))

    panels = [
        ("Input Image", to_pil(input_image)),
        ("Output Image", to_pil(output_image)),
    ]

    # one row, two columns; axes are hidden so that only the pictures show
    plt.figure(figsize=(10, 5))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.axis("off")
        plt.imshow(picture)
    plt.show()

# load the model
model = ImageSharpeningCNN().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a CUDA machine can still be loaded on a CPU-only one
state_dict = torch.load("/model_path_directory.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch the module to evaluation mode for inference

# preprocess the input image the same way as during training:
# resize to 256x256, then convert to a tensor
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# path to the blurred input image
input_image_path = "/input_image.jpg"
blurred_img = load_image(input_image_path, transform).to(device)

# pass the blurred image through the model; no gradients are needed for inference
with torch.no_grad():
    sharpened_img = model(blurred_img)

# show the blurred input next to the model's sharpened output
show_images(blurred_img.cpu(), sharpened_img.cpu())


In [None]:
# calculate metrics
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import cv2

def _load_rgb_256(path):
    """Read an image with OpenCV, resize it to 256x256 and convert BGR to RGB.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None``, which it does
            instead of raising when the file is missing or cannot be decoded.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, (256, 256))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _compare(reference, candidate):
    """Return ``(MSE, SSIM, PSNR)`` of ``candidate`` measured against ``reference``."""
    return (
        mse(reference, candidate),
        ssim(reference, candidate, channel_axis=-1),
        psnr(reference, candidate),
    )


# load, resize and convert the three images to compare
# NOTE(review): no visible cell writes "/output_image.jpg"; the model output
# must be saved to that path before this cell is run.
sharp_image = _load_rgb_256("/original_image.jpg")
blurred_image = _load_rgb_256("/input_image.jpg")
output_image = _load_rgb_256("/output_image.jpg")

# create figure
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = axes.ravel()

# calculate mse, ssim and psnr, always against the sharp reference image
# (the sharp-vs-sharp row is a baseline; its MSE is 0, so PSNR is expected
# to be infinite)
mse_sharp, ssim_sharp, psnr_sharp = _compare(sharp_image, sharp_image)
mse_blurred, ssim_blurred, psnr_blurred = _compare(sharp_image, blurred_image)
mse_output, ssim_output, psnr_output = _compare(sharp_image, output_image)

# plot images with their metrics in the panel titles
ax[0].axis('off')
ax[0].imshow(sharp_image)
ax[0].set_title(f"Sharp Image\nMSE: {mse_sharp}\nSSIM: {ssim_sharp}\nPSNR: {psnr_sharp}")
ax[1].axis('off')
ax[1].imshow(blurred_image)
ax[1].set_title(f"Blurred Image\nMSE: {mse_blurred}\nSSIM: {ssim_blurred}\nPSNR: {psnr_blurred}")
ax[2].axis('off')
ax[2].imshow(output_image)
ax[2].set_title(f"Output Image\nMSE: {mse_output}\nSSIM: {ssim_output}\nPSNR: {psnr_output}")

# show figure
plt.show()

