In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import numpy as np
import os

In [2]:
# Convolutional encoder-decoder (named UNet, but it has no downsampling or skip connections)
class UNet(nn.Module):
    """Small convolutional encoder-decoder mapping a 3-channel image to 3 channels.

    Despite its name, this network has no pooling/strided downsampling and no
    skip connections. Every layer uses kernel_size=3, stride=1, padding=1, so
    the spatial size of the input is preserved end to end
    (channels: 3 -> 64 -> 128 -> 64 -> 3). The last layer has no activation,
    so outputs are not bounded to any range.

    The attribute names (``encoder`` / ``decoder``) and the layer order inside
    each ``nn.Sequential`` determine the ``state_dict`` keys; changing them
    would break loading of previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        # Convolutions sit at indices 0 and 2 -> keys encoder.0.*, encoder.2.*
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, **conv_opts),
            nn.ReLU(),
            nn.Conv2d(64, 128, **conv_opts),
            nn.ReLU(),
        )
        # Mirrors the encoder; no activation after the final layer.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, **conv_opts),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, **conv_opts),
        )

    def forward(self, x):
        """Encode then decode ``x``; the result has 3 channels and the same H, W."""
        features = self.encoder(x)
        return self.decoder(features)

In [3]:
# Dataset
class SharpenDataset(Dataset):
    """Paired (blurred, sharp) image dataset read from two directories.

    Files are paired by position after sorting each directory's listing by
    name, so the i-th blurred file (alphabetically) is matched with the i-th
    sharp file. ``os.listdir`` returns entries in arbitrary order, so the
    sorting is what makes the pairing deterministic; it is correct when both
    directories use corresponding file names.

    Args:
        blurred_dir: directory containing the blurred (input) images.
        sharp_dir: directory containing the sharp (target) images.
        transform: optional callable applied separately to each PIL image.
            It is called twice per item, so it must be deterministic (no random
            augmentation) or input and target would be transformed differently.

    Raises:
        ValueError: if the two directories contain a different number of files.
    """

    def __init__(self, blurred_dir, sharp_dir, transform=None):
        self.blurred_dir = blurred_dir
        self.sharp_dir = sharp_dir
        self.transform = transform
        self.image_names = self._list_images(blurred_dir)
        self.sharp_names = self._list_images(sharp_dir)
        # Backward-compatible alias for the previous (misspelled) attribute name.
        self.image_namess = self.sharp_names
        # Index-based pairing is meaningless if the counts differ.
        if len(self.image_names) != len(self.sharp_names):
            raise ValueError(
                f"Found {len(self.image_names)} blurred images in {blurred_dir!r} "
                f"but {len(self.sharp_names)} sharp images in {sharp_dir!r}; "
                "the directories must contain matching pairs."
            )

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``.

        Subdirectories (e.g. Jupyter's ``.ipynb_checkpoints``) and hidden files
        (e.g. ``.DS_Store``) are skipped because they are not images.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            # convert() loads the pixel data and returns an independent image,
            # so it stays valid after the file is closed.
            return img.convert('RGB')

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        blurred_path = os.path.join(self.blurred_dir, self.image_names[idx])
        sharp_path = os.path.join(self.sharp_dir, self.sharp_names[idx])
        blurred_image = self._load_rgb(blurred_path)
        sharp_image = self._load_rgb(sharp_path)
        if self.transform:
            blurred_image = self.transform(blurred_image)
            sharp_image = self.transform(sharp_image)
        return blurred_image, sharp_image

In [4]:
# function to train model
def train_model(model, dataloader, criterion, optimizer, device, epochs=10):
    """Train ``model`` in place, printing the mean batch loss after each epoch.

    Args:
        model: network to optimise; moved to ``device`` and set to train mode.
        dataloader: iterable yielding ``(blurred, sharp)`` tensor pairs. It does
            not need to support ``len()``.
        criterion: loss function called as ``criterion(output, sharp)``.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device the model and every batch are moved to.
        epochs: number of full passes over ``dataloader``.

    Returns:
        list[float]: mean per-batch loss of each epoch (same values as printed).

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch (previously
            this surfaced as a ZeroDivisionError).
    """
    model.to(device)
    history = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for blurred, sharp in dataloader:
            blurred, sharp = blurred.to(device), sharp.to(device)
            optimizer.zero_grad()
            output = model(blurred)
            loss = criterion(output, sharp)
            loss.backward()
            optimizer.step()
            # .item() yields a Python float, so the autograd graph is not retained.
            epoch_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("dataloader produced no batches; check the dataset paths")
        mean_loss = epoch_loss / num_batches
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
    return history

In [None]:
# Hyperparameters
batch_size = 16
learning_rate = 1e-4
num_epochs = 60
seed = 42  # fixed seed for reproducible weight init and shuffle order

# Seed before building the model and iterating the DataLoader so both the
# initial weights and the shuffle order are reproducible. torch.manual_seed
# seeds all devices, but some GPU kernels may still be nondeterministic.
torch.manual_seed(seed)

# Paths
# NOTE(review): these are absolute paths at the filesystem root — confirm that
# is intended (a path relative to the notebook may have been meant).
blurred_dir = "/blurred_images"
sharp_dir = "/original_images"

# Transformations (keep in sync with the inference transform in main())
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Dataset and DataLoader
dataset = SharpenDataset(blurred_dir, sharp_dir, transform=transform)
assert len(dataset) > 0, f"No training images found in {blurred_dir!r}"
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, Loss, Optimizer
model = UNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_model(model, dataloader, criterion, optimizer, device, num_epochs)

# Save model
torch.save(model.state_dict(), "unet_model.pt")
print("Model saved to unet_model.pt")


In [6]:
# function to load model
def load_model(model_path, device, model_factory=None):
    """Build a model, load weights from ``model_path`` and return it in eval mode.

    Args:
        model_path: path to a ``state_dict`` saved with ``torch.save``.
        device: device the stored tensors are mapped to and the model moved onto.
        model_factory: zero-argument callable building the architecture the
            weights belong to; defaults to ``UNet`` (resolved at call time).

    Returns:
        The model on ``device`` with ``eval()`` applied.
    """
    if model_factory is None:
        model_factory = UNet
    model = model_factory()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# function to preprocess image
def preprocess_image(image_path, transform, device):
    """Load an image as RGB and turn it into a batched model input.

    Args:
        image_path: path of the image file to read.
        transform: callable mapping a PIL image to a tensor (presumably the
            same Resize + ToTensor used for training — confirm at call sites).
        device: device the input tensor is moved to.

    Returns:
        tuple: ``(input_image, image)``. ``input_image`` is the transformed
        tensor with a new leading batch dimension, on ``device``; ``image`` is
        the RGB PIL image.
    """
    # Context manager closes the file handle promptly; convert() loads the
    # pixels and returns an independent image that outlives the open file.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    input_image = transform(image).unsqueeze(0).to(device)  # Add batch dimension
    return input_image, image

# function to sharpen image
def sharpen_image(model, input_image):
    """Run ``model`` on a batched input without autograd and clip the result.

    ``squeeze(0)`` drops the leading dimension only when it has size 1, and
    values are clamped into ``[0, 1]`` so the tensor can be turned into an image.
    """
    with torch.no_grad():
        prediction = model(input_image)
    single = prediction.squeeze(0)  # remove batch dimension (no-op if batch > 1)
    return single.clamp(0, 1)

# function to postprocess image
def postprocess_image(output_tensor):
    """Copy an image tensor to the CPU and convert it into a PIL image."""
    cpu_tensor = output_tensor.cpu()
    to_pil = transforms.ToPILImage()
    return to_pil(cpu_tensor)


In [7]:
# Run inference on a single test image, save the result and display it next to the input
import matplotlib.pyplot as plt
def main(model_path="unet_model.pt",
         input_image_path="/input_blurred_image.jpg",
         output_image_path="/output_image.jpg"):
    """Sharpen one image with the trained model, save it and plot it beside the input.

    Args:
        model_path: trained ``state_dict`` written by the training cell.
        input_image_path: blurred image to sharpen.
        output_image_path: where the sharpened image is written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Same Resize + ToTensor as the training transform, so the model sees
    # inputs of the size and value range it was trained on.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    model = load_model(model_path, device)
    input_image, original_image = preprocess_image(input_image_path, transform, device)
    output_tensor = sharpen_image(model, input_image)
    output_image = postprocess_image(output_tensor)

    # The saved image is 256x256 regardless of the input's original size.
    output_image.save(output_image_path)

    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (ax_in, "Original Blurred Image", original_image),
        (ax_out, "Sharpened Image", output_image),
    )
    for ax, title, img in panels:
        ax.set_title(title)
        ax.imshow(img)
        ax.axis("off")
    fig.tight_layout()
    plt.show()


In [None]:
# Run single-image inference with the saved checkpoint and show the comparison.
main();

In [None]:
# Compare the sharp reference with the blurred input and the model output (MSE, SSIM, PSNR)
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import cv2

def load_rgb(path, size=(256, 256)):
    """Read ``path`` with OpenCV, resize it to ``size`` (width, height), return RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path)
    # cv2.imread returns None instead of raising when a file is missing/unreadable.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, size)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def image_metrics(reference, image):
    """Return ``(MSE, SSIM, PSNR)`` of ``image`` measured against ``reference``."""
    err = mse(reference, image)
    similarity = ssim(reference, image, channel_axis=-1)
    # PSNR is infinite for identical images; skip the call to avoid a
    # divide-by-zero warning.
    peak = psnr(reference, image) if err > 0 else float("inf")
    return err, similarity, peak


sharp_image = load_rgb("/original_image.jpg")
blurred_image = load_rgb("/input_blurred_image.jpg")
output_image = load_rgb("/output_image.jpg")

panels = [
    ("Sharp Image", sharp_image),
    ("Blurred Image", blurred_image),
    ("Output Image", output_image),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (title, image) in zip(axes.ravel(), panels):
    mse_val, ssim_val, psnr_val = image_metrics(sharp_image, image)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title(f"{title}\nMSE: {mse_val:.2f}\nSSIM: {ssim_val:.4f}\nPSNR: {psnr_val:.2f} dB")

# show figure
plt.show()

