## importing libraries

In [1]:
import os
import time
from IPython.display import clear_output
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow, show, figure
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import RandomCrop, ToTensor, Compose, RandomHorizontalFlip, RandomVerticalFlip, ToPILImage
from denoising_dataset import DenoisingDataset
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from skimage import color, metrics

## Dataset paths

In [2]:
# Relative directory handed to DenoisingDataset to build the training set.
TRAIN_DATA_PATH = "data/train/"
# Relative directory the validation example image is loaded from.
VALIDATION_DATA_PATH = "data/val/"

### Define model architecture

In [3]:
class DnCNN(nn.Module):
    """Residual convolutional network in the DnCNN layout.

    The layer stack is ``Conv + ReLU``, followed by ``depth - 2`` blocks of
    ``Conv (+ BatchNorm) + ReLU``, followed by a final ``Conv`` that maps back
    to ``image_channels``. ``forward`` returns ``x - stack(x)``, i.e. the stack
    learns the residual that is subtracted from the input.

    Args:
        depth: Total number of convolution layers (first + middle + last).
            Values of 2 or less produce only the first and last convolution.
        n_channels: Number of feature maps in the hidden layers.
        image_channels: Channel count of the input and of the output.
        use_bnorm: When True (default) each middle convolution is followed by
            ``BatchNorm2d`` and has no bias. When False the normalisation layers
            are omitted and the middle convolutions get a bias instead.
        kernel_size: Positive odd convolution kernel size. Padding is
            ``kernel_size // 2`` so the spatial size is preserved.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd number (an even
            kernel cannot preserve the spatial size, which ``forward`` needs).
    """

    def __init__(self, depth=7, n_channels=16, image_channels=3, use_bnorm=True, kernel_size=3):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got {}".format(kernel_size)
            )
        # "Same" padding: output H x W equals input H x W, so x - out is well defined.
        padding = kernel_size // 2

        layers = [
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                      kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(inplace=True),
        ]
        for _ in range(depth - 2):
            # A bias directly before BatchNorm is redundant; enable it only when
            # the normalisation layer is left out.
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                    kernel_size=kernel_size, padding=padding,
                                    bias=not use_bnorm))
            if use_bnorm:
                # NOTE(review): in PyTorch `momentum` is the weight of the NEW batch
                # statistic (running = (1 - m) * running + m * batch). 0.95 therefore
                # makes the running stats follow the latest batch almost entirely.
                # If this value was carried over from a Keras-style definition the
                # equivalent would be 0.05 -- confirm before changing.
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                                kernel_size=kernel_size, padding=padding, bias=False))

        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        """Return ``x`` minus the convolutional stack's output for ``x``."""
        out = self.dncnn(x)
        return x - out

    def _initialize_weights(self):
        """Orthogonal init for conv weights, zero biases, identity BatchNorm affine."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

### Set the model, dataset, loss function and optimizer

In [4]:
# Network built with the DnCNN constructor defaults.
model = DnCNN()

# Training samples, served by DenoisingDataset (imported from denoising_dataset.py).
dataset = DenoisingDataset(TRAIN_DATA_PATH)

# Squared error summed (not averaged) over every element of a batch.
criterion = nn.MSELoss(reduction="sum")

# Adam over all of the model's parameters.
optimizer = Adam(params=model.parameters(), lr=1e-2)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'data/train/'

### Model training


In [None]:
# Run-length and batch-size constants used by the training cell.
NUM_EPOCHS, BATCH_SIZE = 50, 16

dataset = DenoisingDataset(TRAIN_DATA_PATH)
# NOTE(review): despite its name this loader wraps the TRAIN_DATA_PATH dataset,
# yields one sample per batch and does not use BATCH_SIZE -- confirm this is intended.
test_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

# Halve the learning rate after every 20 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=20, gamma=0.5)


In [None]:
# Train for up to NUM_EPOCHS epochs. After every epoch: report mean PSNR/SSIM over
# the loader, step the learning-rate scheduler, and check the early-stop threshold.
for epoch_id in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0  # running sum of the per-batch losses for this epoch
    num_batches = len(test_loader)

    for iter_id, (input_images, target_images) in enumerate(test_loader):
        predicted_images = model(input_images)
        loss = criterion(predicted_images, target_images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Normalise by the number of samples actually in this batch. Dividing by
        # BATCH_SIZE under-reported the loss because the loader does not batch by it.
        print("\rEpoch {} Iteration {} Loss {}".format(epoch_id, iter_id, batch_loss / input_images.size(0)), end="")

    # Validation
    # NOTE(review): this iterates the same loader that was just trained on, so the
    # metrics below are not held-out numbers -- confirm whether a loader over
    # VALIDATION_DATA_PATH was intended.
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0

    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)

            # ToPILImage converts float tensors to 8-bit. Clamp to [0, 1] first so
            # out-of-range predictions saturate instead of overflowing in that cast.
            img_true = transforms.ToPILImage()(targets.squeeze(0).clamp(0, 1))
            img_pred = transforms.ToPILImage()(outputs.squeeze(0).clamp(0, 1))

            # rgb2gray turns the 8-bit RGB images into floats in [0, 1].
            img_true_gray = color.rgb2gray(img_true)
            img_pred_gray = color.rgb2gray(img_pred)

            # data_range is the span of POSSIBLE values (1.0 here), not the span
            # observed in one image; the observed span is 0 for a flat image and
            # makes scores incomparable across images.
            psnr = metrics.peak_signal_noise_ratio(img_true_gray, img_pred_gray, data_range=1.0)
            ssim = metrics.structural_similarity(img_true_gray, img_pred_gray, data_range=1.0)

            total_psnr += psnr
            total_ssim += ssim

    average_psnr = total_psnr / num_batches
    average_ssim = total_ssim / num_batches

    print("\n Avg. PSNR: {:.2f}, Avg. SSIM: {:.4f}".format(average_psnr, average_ssim))

    # Learning rate scheduling
    lr_scheduler.step()

    # Early stopping condition (you may need to customize this): stop once the mean
    # per-batch training loss drops below the threshold, but never before epoch 11.
    if epoch_id > 10 and total_loss / num_batches < 0.001:
        print("Early stopping: average training loss fell below 0.001.")
        break


## Model evaluation

In [None]:
# Per-sample loss of the last training batch. `loss` and `input_images` are the
# values left over from the final training iteration; divide by that batch's real
# size rather than BATCH_SIZE, which the loader does not use.
loss.item() / input_images.size(0)

In [None]:
# Mean PSNR computed in the evaluation pass of the last epoch that ran.
average_psnr

In [None]:
# Mean SSIM computed in the evaluation pass of the last epoch that ran.
average_ssim

## Testing the model with an image

In [None]:
from IPython.display import display, HTML, clear_output
from matplotlib.pyplot import figure, imshow, show

def visualize_validation(model, image_name="pier.png"):
    """Show one validation image next to the model's output for it.

    Loads ``image_name`` from ``VALIDATION_DATA_PATH``, runs it through ``model``
    without gradients, and draws the input (left half) and the clamped output
    (right half) in a single figure.

    Args:
        model: Module called with a ``(1, 3, H, W)`` float32 tensor scaled to
            [0, 1]. Its output is assumed to be a CPU tensor with the same
            shape -- ``.numpy()`` and the side-by-side stacking rely on that.
        image_name: File name inside ``VALIDATION_DATA_PATH``. Defaults to
            ``"pier.png"``, the file this function previously hard-coded.

    Side effects:
        Puts ``model`` into eval mode, clears the current cell output and
        displays a matplotlib figure. Returns nothing.
    """
    model.eval()
    image_path = os.path.join(VALIDATION_DATA_PATH, image_name)

    # Force 3-channel RGB so grayscale, palette and RGBA files all yield an
    # (H, W, 3) array; the context manager closes the file handle afterwards.
    with Image.open(image_path) as image_file:
        original_uint8 = np.array(image_file.convert("RGB"))

    scaled_input = original_uint8.astype("float32") / 255.
    model_input = torch.from_numpy(scaled_input).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        result = model(model_input)

    # Clamp before quantising so out-of-range predictions saturate, and round
    # rather than truncate so the conversion is not biased towards darker values.
    result_image = result[0].clamp(0, 1).permute(1, 2, 0).numpy()
    result_uint8 = (result_image * 255).round().astype("uint8")

    # Use the decoded 8-bit pixels for the left half instead of re-quantising
    # the float copy that was fed to the model.
    stacked_images = np.concatenate([original_uint8, result_uint8], axis=1)

    # Display images with labels
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(28, 28))
    ax.imshow(stacked_images)
    ax.set_title("Input (left) vs. model output (right)")
    ax.axis("off")
    plt.show()


In [None]:
# Display the sample validation image and the model's output for it side by side.
visualize_validation(model)