In [11]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
from torchvision.transforms import ToTensor
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

In [12]:
# Pick the compute device once for the whole notebook: a CUDA GPU when one is
# visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [13]:
# Root of the LOL-v2 dataset. Set the LOL_V2_ROOT environment variable to point
# at a local copy instead of editing the hard-coded default below.
LOL_V2_ROOT = os.environ.get("LOL_V2_ROOT", "/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2")
_REAL_DIR = os.path.join(LOL_V2_ROOT, "Real_captured")
_SYNT_DIR = os.path.join(LOL_V2_ROOT, "Synthetic")

# Real-captured pairs: "Low" holds the low-light inputs, "Normal" the reference images.
train_low_folder = os.path.join(_REAL_DIR, "Train", "Low")
train_high_folder = os.path.join(_REAL_DIR, "Train", "Normal")
eval_low_folder = os.path.join(_REAL_DIR, "Test", "Low")
eval_high_folder = os.path.join(_REAL_DIR, "Test", "Normal")

# Synthetic pairs, same Train/Test x Low/Normal layout.
train_low_synt_folder = os.path.join(_SYNT_DIR, "Train", "Low")
train_high_synt_folder = os.path.join(_SYNT_DIR, "Train", "Normal")
eval_low_synt_folder = os.path.join(_SYNT_DIR, "Test", "Low")
eval_high_synt_folder = os.path.join(_SYNT_DIR, "Test", "Normal")

In [14]:
def _list_png_files(folder):
    """Return the names of the ``.png`` files directly inside ``folder``.

    Order follows ``os.listdir`` (arbitrary); the next cell sorts the lists.
    The extension check is exact and case-insensitive, so ``.PNG`` is accepted.
    ``.apng`` or an extension-less name that merely ends in "png" is rejected; the old
    substring test ``"png" in ext`` got both of these wrong.
    """
    return [
        filename
        for filename in os.listdir(folder)
        if os.path.splitext(filename.strip())[1].lower() == ".png"
    ]


# Real-captured image names.
train_low_images_list = _list_png_files(train_low_folder)
train_high_images_list = _list_png_files(train_high_folder)
eval_low_images_list = _list_png_files(eval_low_folder)
eval_high_images_list = _list_png_files(eval_high_folder)

# Synthetic image names.
train_low_synt_images_list = _list_png_files(train_low_synt_folder)
train_high_synt_images_list = _list_png_files(train_high_synt_folder)
eval_low_synt_images_list = _list_png_files(eval_low_synt_folder)
eval_high_synt_images_list = _list_png_files(eval_high_synt_folder)

In [15]:
# os.listdir returns names in arbitrary order; sorting every list makes the
# low-light and normal-light lists line up position by position.
(train_low_images_list, train_high_images_list,
 eval_low_images_list, eval_high_images_list) = (
    sorted(names)
    for names in (train_low_images_list, train_high_images_list,
                  eval_low_images_list, eval_high_images_list)
)

(train_low_synt_images_list, train_high_synt_images_list,
 eval_low_synt_images_list, eval_high_synt_images_list) = (
    sorted(names)
    for names in (train_low_synt_images_list, train_high_synt_images_list,
                  eval_low_synt_images_list, eval_high_synt_images_list)
)

In [16]:
def _join_all(folder, filenames):
    """Return ``os.path.join(folder, name)`` for each name, keeping the order."""
    return [os.path.join(folder, filename) for filename in filenames]


def _read_all(paths):
    """Decode every path with ``cv2.imread`` (which yields None for unreadable files)."""
    return [cv2.imread(path) for path in paths]


# CustomDataset pairs low-light and normal-light images purely by list index,
# so each low/normal pair of lists must have the same length. Check every subset
# separately (before concatenation), so offsetting mismatches cannot hide each other.
for _low, _high, _split in (
    (train_low_images_list, train_high_images_list, "real train"),
    (eval_low_images_list, eval_high_images_list, "real test"),
    (train_low_synt_images_list, train_high_synt_images_list, "synthetic train"),
    (eval_low_synt_images_list, eval_high_synt_images_list, "synthetic test"),
):
    assert len(_low) == len(_high), f"{_split}: {len(_low)} low vs {len(_high)} normal images"

# Full paths of the low-light ("Low") and normal-light ("Normal") images.
train_low_images_fld_list = _join_all(train_low_folder, train_low_images_list)
train_high_images_fld_list = _join_all(train_high_folder, train_high_images_list)
eval_low_images_fld_list = _join_all(eval_low_folder, eval_low_images_list)
eval_high_images_fld_list = _join_all(eval_high_folder, eval_high_images_list)

train_low_synt_images_fld_list = _join_all(train_low_synt_folder, train_low_synt_images_list)
train_high_synt_images_fld_list = _join_all(train_high_synt_folder, train_high_synt_images_list)
eval_low_synt_images_fld_list = _join_all(eval_low_synt_folder, eval_low_synt_images_list)
eval_high_synt_images_fld_list = _join_all(eval_high_synt_folder, eval_high_synt_images_list)

# Decoded images (OpenCV arrays in BGR channel order).
# NOTE(review): nothing visible in this notebook uses these arrays -- CustomDataset
# reopens the files from the path lists -- and decoding the whole dataset up front
# costs a lot of memory; consider dropping them.
train_low_images = _read_all(train_low_images_fld_list)
train_high_images = _read_all(train_high_images_fld_list)
eval_low_images = _read_all(eval_low_images_fld_list)
eval_high_images = _read_all(eval_high_images_fld_list)

train_low_synt_images = _read_all(train_low_synt_images_fld_list)
train_high_synt_images = _read_all(train_high_synt_images_fld_list)
eval_low_synt_images = _read_all(eval_low_synt_images_fld_list)
eval_high_synt_images = _read_all(eval_high_synt_images_fld_list)

# Merge the subsets: real-captured pairs first, then synthetic pairs.
train_low_images = train_low_images + train_low_synt_images
train_high_images = train_high_images + train_high_synt_images
eval_low_images = eval_low_images + eval_low_synt_images
eval_high_images = eval_high_images + eval_high_synt_images

train_low_images_fld_list = train_low_images_fld_list + train_low_synt_images_fld_list
train_high_images_fld_list = train_high_images_fld_list + train_high_synt_images_fld_list
eval_low_images_fld_list = eval_low_images_fld_list + eval_low_synt_images_fld_list
eval_high_images_fld_list = eval_high_images_fld_list + eval_high_synt_images_fld_list

In [17]:
import os
import cv2
from PIL import Image

class CustomDataset(Dataset):
    """Paired low-light / reference images, matched by list index.

    Each item is ``(low, high)``. Both files are opened with PIL and converted to RGB.
    They are then resized to ``resize`` with bicubic interpolation and, if a
    ``transform`` is given, passed through it.

    Args:
        low_res_images: Paths of the input (low-light) images.
        high_res_images: Paths of the reference images, in the same order as
            ``low_res_images`` (item ``i`` pairs element ``i`` of both lists).
        transform: Optional callable applied to each resized PIL image.
        resize: Target size passed to ``PIL.Image.Image.resize`` (``(width, height)``).

    Raises:
        ValueError: If the two lists have different lengths, which would misalign pairs.
    """

    def __init__(self, low_res_images, high_res_images, transform=None, resize=(384, 384)):
        if len(low_res_images) != len(high_res_images):
            raise ValueError(
                f"Got {len(low_res_images)} low-res images but {len(high_res_images)} "
                "high-res images; pairs are matched by index, so the lists must have equal length"
            )
        self.low_res_images = low_res_images
        self.high_res_images = high_res_images
        self.transform = transform
        self.resize = resize

    def __len__(self):
        return len(self.low_res_images)

    def _load_image(self, path):
        """Open one image file and return it as RGB, resized and (optionally) transformed."""
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        image = image.resize(self.resize, Image.BICUBIC)
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return the ``(low, high)`` pair at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range. It is not swallowed, so plain
                iteration over the dataset terminates.
            RuntimeError: If either file cannot be opened or decoded. The message
                names both paths. The error is raised rather than returned as
                ``(None, None)``, which the DataLoader's default collate cannot batch.
        """
        low_path = self.low_res_images[idx]
        high_path = self.high_res_images[idx]
        try:
            return self._load_image(low_path), self._load_image(high_path)
        except Exception as e:
            raise RuntimeError(
                f"Error loading image at index {idx} ({low_path!r}, {high_path!r}): {e}"
            ) from e

# The only preprocessing step: PIL image -> float tensor via ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

# Paired datasets built from the merged (real + synthetic) path lists.
train_dataset = CustomDataset(train_low_images_fld_list, train_high_images_fld_list, transform=transform)
eval_dataset = CustomDataset(eval_low_images_fld_list, eval_high_images_fld_list, transform=transform)

# One image pair per batch; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=1, shuffle=True),
    DataLoader(eval_dataset, batch_size=1, shuffle=False),
)


In [8]:
# Define the model architecture
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages with an identity skip connection.

    ``forward`` computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Every conv
    uses ``padding=1`` and keeps ``channels`` channels, so the skip connection adds
    tensors of equal shape.

    Args:
        channels: Number of input and output channels.
        share_layers: When True (default, the original behaviour), the SAME ``conv1``
            and ``batchnorm`` modules serve as both stages. That ties the two conv
            weights together, and the BatchNorm running statistics are then updated
            from two different activations per forward pass. When False, a separate
            ``conv2``/``batchnorm2`` is used for the second stage. The default is kept
            so existing checkpoints (state_dict keys) still load.
    """

    def __init__(self, channels, share_layers=True):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.batchnorm = nn.BatchNorm2d(channels)
        self.activation = nn.ReLU()
        self.share_layers = share_layers
        if not share_layers:
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.batchnorm2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        second_conv = self.conv1 if self.share_layers else self.conv2
        second_bn = self.batchnorm if self.share_layers else self.batchnorm2
        x = self.activation(self.batchnorm(self.conv1(x)))
        x = second_bn(second_conv(x))
        # Out-of-place add: avoids mutating a tensor that autograd may still reference.
        x = x + residual
        return self.activation(x)

class DiffusionModel(nn.Module):
    """Residual CNN mapping a 3-channel image to a 3-channel image in (0, 1).

    Pipeline: conv(3->64) + BN + ReLU, eight 64-channel ResidualBlocks,
    conv(64->64) + BN + ReLU, conv(64->3), sigmoid. All convolutions are 3x3
    with padding 1, so the spatial size is preserved. Despite the name, the
    forward pass is a single feed-forward pass; no noise schedule or iterative
    denoising appears in this class.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order define the state_dict keys and the
        # parameter-initialisation order; keep them stable for checkpoints.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU()
        self.residual_blocks = nn.ModuleList(ResidualBlock(64) for _ in range(8))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        features = self.activation(self.batchnorm1(self.conv1(x)))
        for res_block in self.residual_blocks:
            features = res_block(features)
        features = self.activation(self.batchnorm2(self.conv2(features)))
        return self.sigmoid(self.conv3(features))

In [9]:
# import torch
# from torch import nn, optim
# from torch.utils.data import DataLoader
# from torchvision import transforms
# from PIL import Image
# import os

# # -----------------------------------
# # 1. Generator and Residual Block
# # -----------------------------------
# class ResidualBlock(nn.Module):
#     def __init__(self, in_features):
#         super(ResidualBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.ReflectionPad2d(1),
#             nn.Conv2d(in_features, in_features, 3),
#             nn.InstanceNorm2d(in_features),
#             nn.ReLU(inplace=True),
#             nn.ReflectionPad2d(1),
#             nn.Conv2d(in_features, in_features, 3),
#             nn.InstanceNorm2d(in_features)
#         )

#     def forward(self, x):
#         return x + self.block(x)

# class Generator(nn.Module):
#     def __init__(self, input_channels, output_channels, n_residual_blocks=9):
#         super(Generator, self).__init__()
#         model = [nn.ReflectionPad2d(3),
#                  nn.Conv2d(input_channels, 64, 7),
#                  nn.InstanceNorm2d(64),
#                  nn.ReLU(inplace=True)]
#         in_features = 64
#         out_features = in_features * 2
#         for _ in range(2):
#             model.extend([nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
#                           nn.InstanceNorm2d(out_features),
#                           nn.ReLU(inplace=True)])
#             in_features = out_features
#             out_features *= 2
#         for _ in range(n_residual_blocks):
#             model.append(ResidualBlock(in_features))
#         out_features = in_features // 2
#         for _ in range(2):
#             model.extend([nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
#                           nn.InstanceNorm2d(out_features),
#                           nn.ReLU(inplace=True)])
#             in_features = out_features
#             out_features = in_features // 2
#         model.extend([nn.ReflectionPad2d(3),
#                       nn.Conv2d(64, output_channels, 7),
#                       nn.Tanh()])
#         self.model = nn.Sequential(*model)

#     def forward(self, x):
#         return self.model(x)

# # -----------------------------------
# # 2. Discriminator
# # -----------------------------------
# class Discriminator(nn.Module):
#     def __init__(self, input_channels):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(64, 128, 4, stride=2, padding=1),
#             nn.InstanceNorm2d(128),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(128, 256, 4, stride=2, padding=1),
#             nn.InstanceNorm2d(256),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(256, 512, 4, padding=1),
#             nn.InstanceNorm2d(512),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(512, 1, 4, padding=1)
#         )

#     def forward(self, x):
#         return self.model(x)

# # -----------------------------------
# # 3. Initialize Models
# # -----------------------------------
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# G_AB = Generator(3, 3).to(device)
# G_BA = Generator(3, 3).to(device)
# D_A = Discriminator(3).to(device)
# D_B = Discriminator(3).to(device)
# print('Models moved to GPU.')

# # -----------------------------------
# # 4. Losses and Optimizers
# # -----------------------------------
# criterion_GAN = nn.MSELoss()
# criterion_cycle = nn.L1Loss()
# criterion_identity = nn.L1Loss()

# optimizer_G = optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))
# optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
# optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))


Models moved to GPU.


In [10]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import exposure

def psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two tensors of the same shape.

    Computed over all elements as ``20 * log10(max_val / sqrt(MSE))``.

    Args:
        img1, img2: Floating-point tensors, presumably with values in ``[0, max_val]``.
        max_val: Largest possible pixel value (default 1.0, the range ToTensor produces).

    Returns:
        A 0-dim tensor. Identical inputs give ``inf`` as a tensor on the same
        device/dtype, so callers can always use ``.item()``. A plain ``float``
        would have no ``.item()``.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(max_val / torch.sqrt(mse))


# Load the checkpoint(s) to compare. map_location remaps tensors saved on another
# device (e.g. a GPU checkpoint opened on a CPU-only machine) onto `device`.
# NOTE(review): the saved output of this cell shows diffusion_model.pth holds
# encoder_conv*/decoder_conv* keys, i.e. it was saved from a different
# architecture than the DiffusionModel class defined above. Point this at a
# matching checkpoint (or restore the matching class) before running.
model1 = DiffusionModel().to(device)
model1.load_state_dict(torch.load('diffusion_model.pth', map_location=device))
model1.eval()  # BatchNorm uses its running statistics at inference

# Additional checkpoints: uncomment a model's three lines to include it.
# model4/model5 need the Generator class from the commented-out CycleGAN cell.
# model2 = DiffusionModel().to(device)
# model2.load_state_dict(torch.load('diffusion_model_100.pth', map_location=device))
# model2.eval()
# model3 = DiffusionModel().to(device)
# model3.load_state_dict(torch.load('diffusion_model_actual.pth', map_location=device))
# model3.eval()
# model4 = Generator(3, 3).to(device)
# model4.load_state_dict(torch.load('G_AB_epoch_100.pth', map_location=device))
# model4.eval()
# model5 = Generator(3, 3).to(device)
# model5.load_state_dict(torch.load('G_AB_200epoch_200.pth', map_location=device))
# model5.eval()

def adjust_color_balance(image, gain=1.25, channel=None):
    """Scale an image by a constant gain and clip the result to [0, 1].

    Args:
        image: Numpy array. Presumably (H, W, C) RGB in [0, 1], as built by the
            evaluation loop -- TODO confirm for other callers.
        gain: Multiplicative factor (default 1.25, the original hard-coded value).
        channel: Index on the last axis of the single channel to scale (e.g. 0 for
            red in an RGB array). ``None`` (default) scales every channel, which
            is what the original code did despite its "red channel" comment.

    Returns:
        A new clipped array. The input is left unmodified; the original scaled
        it in place.
    """
    image = np.asarray(image)
    # Float copy: keeps float32/float64 as-is, promotes integers (which would
    # otherwise fail on an in-place float multiply), and protects the caller's array.
    adjusted = image.astype(np.result_type(image.dtype, np.float32), copy=True)
    if channel is None:
        adjusted *= gain
    else:
        adjusted[..., channel] *= gain

    # Clip values to ensure they are in the valid range [0, 1]
    return np.clip(adjusted, 0, 1)

from itertools import islice

# Models to compare, in display order: (variable name, panel title). Only the
# variables actually defined by the loading cell are used, so commenting out a
# checkpoint there does not break this loop with a NameError.
_MODEL_CANDIDATES = [
    ("model1", "diffusion_model"),
    ("model2", "diffusion_model_100"),
    ("model3", "diffusion_model_actual"),
    ("model4", "G_AB_epoch_100"),
    ("model5", "G_AB_200epoch_200"),
]
eval_models = [(var, title, globals()[var]) for var, title in _MODEL_CANDIDATES if var in globals()]
NUM_DISPLAY = 10  # number of test images to evaluate and plot


def _to_hwc(tensor):
    """Convert a (1, C, H, W) batch tensor into an (H, W, C) numpy array for imshow."""
    return tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()


def _min_max(tensor):
    """Rescale a tensor to [0, 1]; a constant tensor maps to zeros instead of 0/0 NaNs."""
    lo, hi = tensor.min(), tensor.max()
    if hi > lo:
        return (tensor - lo) / (hi - lo)
    return torch.zeros_like(tensor)


for i, (low_img, high_img) in enumerate(islice(test_loader, NUM_DISPLAY)):
    # The models live on `device`; the DataLoader yields CPU tensors.
    low_img = low_img.to(device)
    high_img = high_img.to(device)

    with torch.no_grad():
        outputs = [(var, title, model(low_img)) for var, title, model in eval_models]

        for var, title, output in outputs:
            # float() accepts both a 0-dim tensor and a plain float result.
            score = float(psnr(output, high_img))
            print(f"PSNR for {var} ({title}) on image {i+1}: {score} dB")

    # Row 0: reference, input and raw model outputs.
    # Row 1: each output min-max normalised, colour-adjusted and histogram-equalised.
    n_cols = 2 + len(outputs)
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)

    top_row = [('High Resolution', high_img), ('Low Resolution', low_img)]
    top_row += [(title, output) for _, title, output in outputs]
    for col, (title, image) in enumerate(top_row):
        # NOTE(review): the commented-out Generator ends in Tanh ([-1, 1]); clamping
        # shows only its non-negative part -- rescale with (x + 1) / 2 if it was
        # trained on [-1, 1] targets.
        axes[0, col].imshow(_to_hwc(image.clamp(0, 1)))
        axes[0, col].set_title(title)

    for col, (_, title, output) in enumerate(outputs, start=2):
        adjusted = adjust_color_balance(_to_hwc(_min_max(output)))
        adjusted = exposure.equalize_hist(adjusted)
        axes[1, col].imshow(adjusted)
        axes[1, col].set_title(f'{title} (adjusted)')

    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # one figure is created per test image; free it once shown



RuntimeError: Error(s) in loading state_dict for DiffusionModel:
	Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "batchnorm1.weight", "batchnorm1.bias", "batchnorm1.running_mean", "batchnorm1.running_var", "residual_blocks.0.conv1.weight", "residual_blocks.0.conv1.bias", "residual_blocks.0.batchnorm.weight", "residual_blocks.0.batchnorm.bias", "residual_blocks.0.batchnorm.running_mean", "residual_blocks.0.batchnorm.running_var", "residual_blocks.1.conv1.weight", "residual_blocks.1.conv1.bias", "residual_blocks.1.batchnorm.weight", "residual_blocks.1.batchnorm.bias", "residual_blocks.1.batchnorm.running_mean", "residual_blocks.1.batchnorm.running_var", "residual_blocks.2.conv1.weight", "residual_blocks.2.conv1.bias", "residual_blocks.2.batchnorm.weight", "residual_blocks.2.batchnorm.bias", "residual_blocks.2.batchnorm.running_mean", "residual_blocks.2.batchnorm.running_var", "residual_blocks.3.conv1.weight", "residual_blocks.3.conv1.bias", "residual_blocks.3.batchnorm.weight", "residual_blocks.3.batchnorm.bias", "residual_blocks.3.batchnorm.running_mean", "residual_blocks.3.batchnorm.running_var", "residual_blocks.4.conv1.weight", "residual_blocks.4.conv1.bias", "residual_blocks.4.batchnorm.weight", "residual_blocks.4.batchnorm.bias", "residual_blocks.4.batchnorm.running_mean", "residual_blocks.4.batchnorm.running_var", "residual_blocks.5.conv1.weight", "residual_blocks.5.conv1.bias", "residual_blocks.5.batchnorm.weight", "residual_blocks.5.batchnorm.bias", "residual_blocks.5.batchnorm.running_mean", "residual_blocks.5.batchnorm.running_var", "residual_blocks.6.conv1.weight", "residual_blocks.6.conv1.bias", "residual_blocks.6.batchnorm.weight", "residual_blocks.6.batchnorm.bias", "residual_blocks.6.batchnorm.running_mean", "residual_blocks.6.batchnorm.running_var", "residual_blocks.7.conv1.weight", "residual_blocks.7.conv1.bias", "residual_blocks.7.batchnorm.weight", "residual_blocks.7.batchnorm.bias", "residual_blocks.7.batchnorm.running_mean", "residual_blocks.7.batchnorm.running_var", "conv2.weight", "conv2.bias", 
"batchnorm2.weight", "batchnorm2.bias", "batchnorm2.running_mean", "batchnorm2.running_var", "conv3.weight", "conv3.bias". 
	Unexpected key(s) in state_dict: "encoder_conv1.weight", "encoder_conv1.bias", "encoder_conv2.weight", "encoder_conv2.bias", "decoder_conv1.weight", "decoder_conv1.bias", "decoder_conv2.weight", "decoder_conv2.bias", "decoder_conv3.weight", "decoder_conv3.bias". 