In [1]:
import os
import torch
import numpy as np
import torch.nn as nn
from PIL import Image
from torchvision import models
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchsummary import summary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from skimage.metrics import structural_similarity as ssim

2024-11-02 10:43:19.268718: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/gokul/.mujoco/mujoco210/bin:/opt/ros/noetic/lib:/opt/ros/noetic/lib/x86_64-linux-gnu:/usr/lib/nvidia
2024-11-02 10:43:19.268755: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
class DenoisingDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None, patch_size=None, is_train=True):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = os.listdir(image_dir)
        self.patch_size = patch_size
        self.is_train = is_train
        self.num_patches_per_image = (256 // patch_size)**2

    def __len__(self):
        return len(self.image_names)*self.num_patches_per_image

    def extract_patches(self, image, patch_size):
        _, h, w = image.size()
        patches = []

        for i in range(0, h, patch_size):
            for j in range(0, w, patch_size):
                patch = image[:, i:i + patch_size, j:j + patch_size]
                if patch.size(1) == patch_size and patch.size(2) == patch_size:
                    patches.append(patch)

        return patches

    def __getitem__(self, idx):
        img_idx = idx // self.num_patches_per_image
        patch_idx = idx % self.num_patches_per_image

        img_name = self.image_names[img_idx]

        image = Image.open(os.path.join(
            self.image_dir, img_name)).convert('RGB')
        label = Image.open(os.path.join(
            self.label_dir, img_name)).convert('RGB')

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        if self.is_train:
            image_patches = self.extract_patches(image, self.patch_size)
            label_patches = self.extract_patches(label, self.patch_size)
            return image_patches[patch_idx], label_patches[patch_idx]

        return image, label

In [4]:
image_size = 256

Transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

train_image_dir = '../Project/dataset/data/Train'
train_label_dir = '../Project/dataset/label/Train'
val_image_dir = '../Project/dataset/data/Val'
val_label_dir = '../Project/dataset/label/Val'

trainDataset = DenoisingDataset(
    train_image_dir, train_label_dir, Transform, patch_size=64, is_train=True)
valDataset = DenoisingDataset(
    val_image_dir, val_label_dir, Transform, patch_size=256, is_train=False)

batchSize = 16
trainLoader = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
valLoader = DataLoader(valDataset, batch_size=batchSize, shuffle=True)

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
print(len(trainLoader))

932


In [6]:
class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, dilation=1, padding=1)
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, dilation=2, padding=2)

        self.conv3 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, dilation=3, padding=3)
        self.conv4 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, dilation=4, padding=4)

        self.conv5 = nn.Conv2d(
            in_channels=128, out_channels=64, kernel_size=3, padding=1)

        self.conv6 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(
            in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.norm = nn.LayerNorm([64, 256, 256])

        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6,
                      self.conv7, self.conv8, self.conv9, self.conv10, self.conv11, self.conv12]:
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, image):
        x1 = F.silu(self.conv1(image))
        x1 = F.silu(self.conv2(x1))

        x2 = F.silu(self.conv3(image))
        x2 = F.silu(self.conv4(x2))

        x1_x2 = torch.cat([x1, x2], dim=1)
        x3 = F.silu(self.conv5(x1_x2))
        add1 = image + x3

        x4 = F.silu(self.conv6(add1))
        x4 = self.conv7(x4)
        add2 = x4 + add1
        add2 = F.silu(add2)

        x5 = F.silu(self.conv8(add2))
        x5 = F.silu(self.conv9(x5))
        x5 = self.conv10(x5)

        add3 = add2 + x5
        add3 = F.silu(add3)

        gap = F.adaptive_avg_pool2d(add3, (1, 1))
        x6 = F.silu(self.conv11(gap))
        x6 = torch.sigmoid(self.conv12(x6))

        mul = x6 * add3
        output = image + mul

        return output


class RIDNet(nn.Module):
    def __init__(self):
        super(RIDNet, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.eam = nn.Sequential(
            EAM(),
            EAM(),
            EAM(),
            EAM()
        )
        self.conv2 = nn.Conv2d(
            in_channels=64, out_channels=3, kernel_size=3, padding=1)

        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.xavier_uniform_(self.conv2.weight)

    def forward(self, image):
        x = self.conv1(image)
        x = self.eam(x)
        x = self.conv2(x)

        output = image + x

        return output

In [7]:
os.makedirs('saved_models', exist_ok=True)
writer = SummaryWriter(log_dir='logs')

In [8]:
model = RIDNet()
# model = model.to(device)

X, y = next(iter(valLoader))
output = model(X[:8,:,:,:])
# output = model(X.to(device))
print(output.shape)

# summary(model, (3, 256, 256), device=device)

torch.Size([8, 3, 256, 256])


In [9]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
# Define loss function, optimizer, and scheduler
model_path = "/home/keerthi/KLA project/ridnet3_epoch_16patch29oct.pth"

model.load_state_dict(torch.load(model_path))
print("Model Loaded Successfully !")
model.to(device)
lossfn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

# Training parameters
epochs = 40
trainLosses, valLosses = [], []

# Training Loop
for epoch in range(epochs):
    model.train()
    runningLoss = 0.0

    # Use tqdm for batch loop progress
    with tqdm(enumerate(trainLoader), total=len(trainLoader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
        for batchIdx, (images, labels) in pbar:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            yHat = model(images)
            l1_loss = lossfn(yHat, labels)
            total_loss = l1_loss

            # Backpropagation and optimization step
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Log progress
            writer.add_scalar('Loss/Train', total_loss.item(),
                              epoch*len(trainLoader)+batchIdx)
            runningLoss += total_loss.item()

            # Update tqdm progress bar with current loss
            pbar.set_postfix({'loss': total_loss.item()})

    # Scheduler step at the end of the epoch
    scheduler.step()

    # Calculate average training loss for the epoch
    avg_train_loss = runningLoss / len(trainLoader)
    trainLosses.append(avg_train_loss)
    writer.add_scalar('Training Loss', avg_train_loss, epoch)

    # Print and log the loss for the epoch
    print(f'Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss: .5f}')

    # Save model weights every 5 epochs
    if (epoch + 1) % 1 == 0:
        torch.save(model.state_dict(), f"ridnet3_epoch_16patch29oct.pth")
        print(f"Model weights saved at epoch {epoch+1}")

# Final save of the model weights after training
torch.save(model.state_dict(), "ridnet3_model.pth")
print("Model Weights Saved Successfully")

# Close the writer to ensure everything is written to disk
writer.close()

  model.load_state_dict(torch.load(model_path))


Model Loaded Successfully !


Epoch 1/40: 100%|██████████| 932/932 [15:59<00:00,  1.03s/batch, loss=0.00085] 


Epoch 1/40, Training Loss:  0.00097
Model weights saved at epoch 1


Epoch 2/40: 100%|██████████| 932/932 [16:00<00:00,  1.03s/batch, loss=0.000629]


Epoch 2/40, Training Loss:  0.00097
Model weights saved at epoch 2


Epoch 3/40: 100%|██████████| 932/932 [15:57<00:00,  1.03s/batch, loss=0.00112] 


Epoch 3/40, Training Loss:  0.00097
Model weights saved at epoch 3


Epoch 4/40: 100%|██████████| 932/932 [15:58<00:00,  1.03s/batch, loss=0.00126] 


Epoch 4/40, Training Loss:  0.00097
Model weights saved at epoch 4


Epoch 5/40: 100%|██████████| 932/932 [15:54<00:00,  1.02s/batch, loss=0.000924]


Epoch 5/40, Training Loss:  0.00097
Model weights saved at epoch 5


Epoch 6/40: 100%|██████████| 932/932 [15:56<00:00,  1.03s/batch, loss=0.000641]


Epoch 6/40, Training Loss:  0.00097
Model weights saved at epoch 6


Epoch 7/40: 100%|██████████| 932/932 [15:54<00:00,  1.02s/batch, loss=0.000662]


Epoch 7/40, Training Loss:  0.00097
Model weights saved at epoch 7


Epoch 8/40: 100%|██████████| 932/932 [15:52<00:00,  1.02s/batch, loss=0.000831]


Epoch 8/40, Training Loss:  0.00097
Model weights saved at epoch 8


Epoch 9/40: 100%|██████████| 932/932 [15:51<00:00,  1.02s/batch, loss=0.000658]


Epoch 9/40, Training Loss:  0.00097
Model weights saved at epoch 9


Epoch 10/40: 100%|██████████| 932/932 [15:53<00:00,  1.02s/batch, loss=0.00118] 


Epoch 10/40, Training Loss:  0.00097
Model weights saved at epoch 10


Epoch 11/40: 100%|██████████| 932/932 [15:53<00:00,  1.02s/batch, loss=0.00125] 


Epoch 11/40, Training Loss:  0.00097
Model weights saved at epoch 11


Epoch 12/40: 100%|██████████| 932/932 [15:52<00:00,  1.02s/batch, loss=0.000963]


Epoch 12/40, Training Loss:  0.00096
Model weights saved at epoch 12


Epoch 13/40: 100%|██████████| 932/932 [15:54<00:00,  1.02s/batch, loss=0.00131] 


Epoch 13/40, Training Loss:  0.00096
Model weights saved at epoch 13


Epoch 14/40: 100%|██████████| 932/932 [15:49<00:00,  1.02s/batch, loss=0.00152] 


Epoch 14/40, Training Loss:  0.00096
Model weights saved at epoch 14


Epoch 15/40: 100%|██████████| 932/932 [15:51<00:00,  1.02s/batch, loss=0.00117] 


Epoch 15/40, Training Loss:  0.00096
Model weights saved at epoch 15


Epoch 16/40: 100%|██████████| 932/932 [15:49<00:00,  1.02s/batch, loss=0.00133] 


Epoch 16/40, Training Loss:  0.00096
Model weights saved at epoch 16


Epoch 17/40: 100%|██████████| 932/932 [15:50<00:00,  1.02s/batch, loss=0.000686]


Epoch 17/40, Training Loss:  0.00096
Model weights saved at epoch 17


Epoch 18/40: 100%|██████████| 932/932 [15:50<00:00,  1.02s/batch, loss=0.000795]


Epoch 18/40, Training Loss:  0.00096
Model weights saved at epoch 18


Epoch 19/40: 100%|██████████| 932/932 [15:50<00:00,  1.02s/batch, loss=0.00126] 


Epoch 19/40, Training Loss:  0.00096
Model weights saved at epoch 19


Epoch 20/40: 100%|██████████| 932/932 [15:51<00:00,  1.02s/batch, loss=0.000722]


Epoch 20/40, Training Loss:  0.00096
Model weights saved at epoch 20


Epoch 21/40: 100%|██████████| 932/932 [15:52<00:00,  1.02s/batch, loss=0.000796]


Epoch 21/40, Training Loss:  0.00096
Model weights saved at epoch 21


Epoch 22/40: 100%|██████████| 932/932 [15:51<00:00,  1.02s/batch, loss=0.00104] 


Epoch 22/40, Training Loss:  0.00095
Model weights saved at epoch 22


Epoch 23/40: 100%|██████████| 932/932 [15:47<00:00,  1.02s/batch, loss=0.00106] 


Epoch 23/40, Training Loss:  0.00095
Model weights saved at epoch 23


Epoch 24/40: 100%|██████████| 932/932 [15:49<00:00,  1.02s/batch, loss=0.000581]


Epoch 24/40, Training Loss:  0.00095
Model weights saved at epoch 24


Epoch 25/40: 100%|██████████| 932/932 [15:48<00:00,  1.02s/batch, loss=0.00161] 


Epoch 25/40, Training Loss:  0.00095
Model weights saved at epoch 25


Epoch 26/40: 100%|██████████| 932/932 [15:47<00:00,  1.02s/batch, loss=0.000914]


Epoch 26/40, Training Loss:  0.00095
Model weights saved at epoch 26


Epoch 27/40: 100%|██████████| 932/932 [15:46<00:00,  1.02s/batch, loss=0.00104] 


Epoch 27/40, Training Loss:  0.00095
Model weights saved at epoch 27


Epoch 28/40: 100%|██████████| 932/932 [15:49<00:00,  1.02s/batch, loss=0.00161] 


Epoch 28/40, Training Loss:  0.00095
Model weights saved at epoch 28


Epoch 29/40: 100%|██████████| 932/932 [15:46<00:00,  1.02s/batch, loss=0.000685]


Epoch 29/40, Training Loss:  0.00095
Model weights saved at epoch 29


Epoch 30/40: 100%|██████████| 932/932 [15:51<00:00,  1.02s/batch, loss=0.00117] 


Epoch 30/40, Training Loss:  0.00095
Model weights saved at epoch 30


Epoch 31/40: 100%|██████████| 932/932 [15:47<00:00,  1.02s/batch, loss=0.000989]


Epoch 31/40, Training Loss:  0.00095
Model weights saved at epoch 31


Epoch 32/40: 100%|██████████| 932/932 [15:52<00:00,  1.02s/batch, loss=0.000972]


Epoch 32/40, Training Loss:  0.00095
Model weights saved at epoch 32


Epoch 33/40: 100%|██████████| 932/932 [15:49<00:00,  1.02s/batch, loss=0.000702]


Epoch 33/40, Training Loss:  0.00095
Model weights saved at epoch 33


Epoch 34/40: 100%|██████████| 932/932 [15:52<00:00,  1.02s/batch, loss=0.00104] 


Epoch 34/40, Training Loss:  0.00094
Model weights saved at epoch 34


Epoch 35/40: 100%|██████████| 932/932 [16:03<00:00,  1.03s/batch, loss=0.00108] 


Epoch 35/40, Training Loss:  0.00094
Model weights saved at epoch 35


Epoch 36/40: 100%|██████████| 932/932 [15:58<00:00,  1.03s/batch, loss=0.00147] 


Epoch 36/40, Training Loss:  0.00094
Model weights saved at epoch 36


Epoch 37/40: 100%|██████████| 932/932 [15:56<00:00,  1.03s/batch, loss=0.000994]


Epoch 37/40, Training Loss:  0.00094
Model weights saved at epoch 37


Epoch 38/40: 100%|██████████| 932/932 [15:55<00:00,  1.03s/batch, loss=0.00112] 


Epoch 38/40, Training Loss:  0.00094
Model weights saved at epoch 38


Epoch 39/40: 100%|██████████| 932/932 [16:10<00:00,  1.04s/batch, loss=0.00144] 


Epoch 39/40, Training Loss:  0.00094
Model weights saved at epoch 39


Epoch 40/40: 100%|██████████| 932/932 [15:55<00:00,  1.03s/batch, loss=0.000914]


Epoch 40/40, Training Loss:  0.00094
Model weights saved at epoch 40
Model Weights Saved Successfully


In [9]:
model_path = "../Project/saved_models/ridnet3_epoch_16patch29oct.pth"

model.load_state_dict(torch.load(model_path))
print("Model Loaded Successfully !")
model = model.to('cpu')

  model.load_state_dict(torch.load(model_path))


Model Loaded Successfully !


In [13]:
import torch
from torchmetrics import PeakSignalNoiseRatio
from skimage.metrics import structural_similarity as ssim
import numpy as np

# Initialize PSNR metric
model = model.to('cpu')
psnr_metric = PeakSignalNoiseRatio().to('cpu')

# Accumulators for PSNR and SSIM values
total_psnr = 0
total_ssim = 0
num_batches = 0

model.eval()  # Set model to evaluation mode
with torch.no_grad():  # Disable gradient calculation
    for Xval, yVal in valLoader:
        # Ensure data is on the CPU
        Xval, yVal = Xval.cpu(), yVal.cpu()

        # Get model prediction
        yval_pred = model(Xval).cpu()  # Keep everything on CPU

        # Calculate PSNR for the batch
        total_psnr += psnr_metric(yval_pred, yVal).item()

        # Calculate SSIM for each image in the batch, using win_size=3 for small images
        batch_ssim_values = []
        for i in range(yVal.size(0)):
            # Ensure conversion to NumPy
            pred_img = np.array(yval_pred[i].squeeze().cpu())
            true_img = np.array(yVal[i].squeeze().cpu())
            # Calculate SSIM, ensuring compatibility with data_range and win_size
            ssim_value = ssim(
                true_img, pred_img, data_range=true_img.max() - true_img.min(), win_size=3)
            batch_ssim_values.append(ssim_value)

        # Average SSIM for the batch and accumulate
        batch_ssim = np.mean(batch_ssim_values)
        total_ssim += batch_ssim
        num_batches += 1

# Calculate average PSNR and SSIM across all batches
average_psnr = total_psnr / num_batches
average_ssim = total_ssim / num_batches

print(f"Average PSNR: {average_psnr:.4f}")
print(f"Average SSIM: {average_ssim:.4f}")

Average PSNR: 29.5089
Average SSIM: 0.8579


In [10]:
valLoader = DataLoader(valDataset, batch_size=batchSize, num_workers=0)

In [11]:
import matplotlib.pyplot as plt

# Get a batch from the validation loader and move to CPU
Xval, yVal = next(iter(valLoader))
Xval = Xval.to('cpu')
yVal = yVal.to('cpu')
model = model.to('cpu')

# Run the model on the entire batch
yvalPred = model(Xval)

# Set up a figure with 3 columns to display the images side-by-side
_, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# Select the specific sample within the batch (0 to 15)
select_sample = 0  # Adjust this to select a different sample within the batch range

# Display Noisy Image (input)
ax[0].imshow(Xval[select_sample].detach().numpy().transpose(1, 2, 0))
ax[0].set_title("Noisy Image")

# Display Clean Image (target)
ax[1].imshow(yVal[select_sample].detach().numpy().transpose(1, 2, 0))
ax[1].set_title("Clean Image")

# Display Restored Image (model output)
ax[2].imshow(yvalPred[select_sample].detach().numpy().transpose(1, 2, 0))
ax[2].set_title("Restored Image")

# Show the plots
plt.show()

: 

### PSNR

In [13]:
def PSNR(ground_truth, predicted_image):
    mse = F.mse_loss(predicted_image, ground_truth)
    if mse == 0:
        return 100
    max_pixel = 1.0
    psnr = 20 * torch.log10(torch.tensor(max_pixel)/torch.sqrt(mse))
    return psnr.item()


psnrVals = []
for i in range(0, 8):
    psnr = PSNR(output[i, :, :, :], y[i, :, :, :])
    psnrVals.append(psnr)
psnrVals

[7.512714862823486,
 8.647907257080078,
 8.253816604614258,
 8.396913528442383,
 8.402873039245605,
 8.349210739135742,
 7.125802040100098,
 8.017643928527832]

### SSIM

In [21]:
def calculate_ssim(img1, img2):
    img1 = img1.cpu().detach().numpy()
    img2 = img2.cpu().detach().numpy()

    if img1.ndim == 3 and img1.shape[0] == 3:
        img1 = np.moveaxis(img1, 0, -1)
        img2 = np.moveaxis(img2, 0, -1)

    score, _ = ssim(img1, img2, full=True, channel_axis=2,
                    data_range=img1.max() - img1.min())

    return score


calculate_ssim(yvalPred[1, :, :, :], yVal[33, :, :, :])

0.76714975