In [1]:
import os
import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Define a function to resize images while loading
def load_and_resize_images(root_folder, num_images, image_size=(256, 256)):
    """Load up to ``num_images`` image files from ``root_folder`` as resized RGB images.

    Files are read in sorted filename order. ``os.listdir`` alone returns an
    arbitrary order, so two folders holding identically named files (here the
    raw and ground-truth folders) would not line up index-by-index. Entries
    that are not regular files (e.g. ``.ipynb_checkpoints``) are skipped.

    Args:
        root_folder: Directory containing the image files.
        num_images: Maximum number of files to load; ``None`` loads all.
        image_size: Target size passed to ``cv2.resize`` (OpenCV's
            ``dsize`` is ``(width, height)``).

    Returns:
        List of images converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If a file cannot be decoded by ``cv2.imread``.
    """
    images = []
    filenames = sorted(
        name for name in os.listdir(root_folder)
        if os.path.isfile(os.path.join(root_folder, name))
    )[:num_images]
    for filename in filenames:
        image_path = os.path.join(root_folder, filename)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None instead of raising. Skipping the file
            # would silently shift every later raw/gt pair, so fail loudly.
            raise ValueError(f"Could not read image: {image_path}")
        img = cv2.resize(img, image_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
        images.append(img)
    return images

# Load and resize paired images from the dataset folders.
num_images = 1000
dataset_root = '..//Underwater Image Enhancement//EUVP'  # Update with your dataset root folder
raw_image_folder = os.path.join(dataset_root, 'raw')
clear_image_folder = os.path.join(dataset_root, 'gt')

# Both folders are read with the same limit so the two lists pair up by index.
raw_images = load_and_resize_images(raw_image_folder, num_images)
clear_images = load_and_resize_images(clear_image_folder, num_images)

# train_test_split pairs the two lists by position, so they must be the same length.
assert len(raw_images) == len(clear_images), (
    f"raw/gt image count mismatch: {len(raw_images)} vs {len(clear_images)}"
)

# Hold out the 515/6128 validation ratio (~8.4%), scaled to the number of
# images actually loaded (84 for 1000 images). At least one is held out,
# because train_test_split rejects test_size=0.
VAL_FRACTION = 515 / 6128
num_val = max(1, int(VAL_FRACTION * len(raw_images)))
train_raw_images, val_raw_images, train_clear_images, val_clear_images = train_test_split(
    raw_images, clear_images, test_size=num_val, random_state=42)

# ImageNet mean/std: the pretrained VGG19 used by the perceptual loss expects
# inputs normalized this way. Inputs AND targets are normalized, so the model
# learns a mapping in normalized space.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Apply transformations to training and validation images
train_preprocessed_raw_images = [transform(img) for img in train_raw_images]
train_preprocessed_clear_images = [transform(img) for img in train_clear_images]

val_preprocessed_raw_images = [transform(img) for img in val_raw_images]
val_preprocessed_clear_images = [transform(img) for img in val_clear_images]

# Create training and validation datasets
train_dataset = torch.utils.data.TensorDataset(torch.stack(train_preprocessed_raw_images),
                                               torch.stack(train_preprocessed_clear_images))

val_dataset = torch.utils.data.TensorDataset(torch.stack(val_preprocessed_raw_images),
                                             torch.stack(val_preprocessed_clear_images))

# Create data loaders for training and validation
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


In [34]:
from torchvision.datasets import ImageFolder

# Test-time preprocessing: fixed 256x256 size, tensor conversion, then the same
# ImageNet normalization as the training data.
# NOTE: this rebinds the module-level `transform` defined in the data-loading cell.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ImageFolder treats every sub-directory of `data` as a class and loads all of
# their images. NOTE(review): the inference cell writes results to data/output,
# so those files would presumably be picked up here too — confirm intent.
batch_size = 32
test_dataset = ImageFolder(root="..//Underwater Image Enhancement//data", transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# import torch
# from torchvision import transforms

# # Define the normalization and transformation
# transform = transforms.Compose([
#     transforms.ToTensor(),  # Convert images to PyTorch tensors
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet mean and std
# ])

# # Apply the transformation to each image in raw_image_data and clear_image_data
# preprocessed_raw_images = [transform(img) for img in subset_raw_image_data]
# preprocessed_clear_images = [transform(img) for img in subset_clear_image_data]

In [7]:
# from torch.utils.data import DataLoader, TensorDataset

# # Convert preprocessed images to PyTorch tensors
# raw_images_tensor = torch.stack(preprocessed_raw_images)
# clear_images_tensor = torch.stack(preprocessed_clear_images)

# # Create a TensorDataset from preprocessed raw and clear images
# dataset = torch.utils.data.TensorDataset(raw_images_tensor, clear_images_tensor)

# # Define batch size for training
# batch_size = 16

# # Create data loader for training
# train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [52]:
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and then dropout.

    Stride 1 with padding 1 keeps the spatial size unchanged, so when
    ``in_channels == out_channels`` the output can be added back onto the
    input (as ShallowUWNet does).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        dropout_p: Dropout probability applied after each activation.
            Defaults to 0.5, the value this block always used.
    """

    def __init__(self, in_channels, out_channels, dropout_p=0.5):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # One parameter-free Dropout module, reused after both convolutions.
        self.dropout = nn.Dropout(dropout_p)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> ReLU -> dropout twice; spatial size is preserved."""
        out = self.relu(self.conv1(x))
        out = self.dropout(out)
        out = self.relu(self.conv2(out))
        out = self.dropout(out)
        return out

class ShallowUWNet(nn.Module):
    """Shallow residual enhancement network: 3-channel image in, 3-channel image out.

    Layout: a 3->64 conv with ReLU, three ``ConvBlock(64, 64)`` stages (each
    added back onto its own input), then a 64->3 conv with no output activation.
    All convolutions are 3x3, stride 1, padding 1, so height and width are
    preserved.

    ``skip_conv1``..``skip_conv3`` are registered but never used in
    ``forward``. They are kept so that saved ``state_dict`` checkpoints (which
    contain their keys) still load.
    """

    def __init__(self):
        super().__init__()
        width = 64  # feature channels used throughout the trunk

        # Registration order is kept identical to the original layout, so that
        # parameter ordering and seeded initialization stay the same.
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

        self.conv_block1 = ConvBlock(width, width)
        self.conv_block2 = ConvBlock(width, width)
        self.conv_block3 = ConvBlock(width, width)

        self.final_conv = nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1)

        # Unused 1x1 projections (see class docstring).
        self.skip_conv1 = nn.Conv2d(3, width, kernel_size=1)
        self.skip_conv2 = nn.Conv2d(width, width, kernel_size=1)
        self.skip_conv3 = nn.Conv2d(width, width, kernel_size=1)

    def forward(self, x):
        """Run the stem, three residual ConvBlocks, and the output projection."""
        features = self.relu(self.conv1(x))
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            features = features + stage(features)
        return self.final_conv(features)

In [53]:
import torchvision.models as models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define the VGG19 network without the final classification layer
class VGGFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG19 convolutional trunk for the perceptual loss.

    Keeps ``vgg19().features[:29]``, i.e. every layer up to and including the
    first convolution of block 5 (conv5_1, before its ReLU). The classifier
    head is discarded.

    The weights are frozen (``requires_grad=False``). The perceptual loss only
    needs gradients with respect to the *input* image. Without freezing, every
    backward pass would also compute and accumulate gradients on the VGG
    parameters, which no optimizer ever steps.
    """

    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        try:
            # torchvision >= 0.13: explicit weights enum. DEFAULT is
            # IMAGENET1K_V1, the same weights that `pretrained=True` loads.
            vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
        except AttributeError:
            # Older torchvision without the *_Weights enums.
            vgg = models.vgg19(pretrained=True)
        self.vgg_features = vgg.features[:29]
        for param in self.vgg_features.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return VGG19 feature maps for ``x``.

        Assumes ``x`` is an ImageNet-normalized RGB batch, which is how the
        pretrained weights were trained — TODO confirm for every caller.
        """
        return self.vgg_features(x)

# Initialize the VGG feature extractor used by the perceptual loss.
vgg_extractor = VGGFeatureExtractor().to(device)
# eval() only switches layers to inference behaviour; it does NOT stop gradient
# computation. The weights are therefore frozen explicitly, so backward passes
# do not compute or accumulate gradients on VGG parameters. Gradients still
# flow through VGG to the enhancement model's outputs.
vgg_extractor.eval()
vgg_extractor.requires_grad_(False)

# Mean-squared error, used for the pixel loss and on VGG feature maps.
criterion_mse = nn.MSELoss()

# Function to compute VGG perceptual loss
def perceptual_loss(vgg, x, y):
    """Mean-squared error between the feature maps that ``vgg`` produces for ``x`` and ``y``.

    Args:
        vgg: Feature-extractor callable (here the VGG19 trunk).
        x: Predicted images.
        y: Reference images.

    Returns:
        The value of the module-level ``criterion_mse`` on the two feature maps.
    """
    # Arguments are evaluated left to right: features of x first, then of y.
    return criterion_mse(vgg(x), vgg(y))

# Combined Loss function
def total_loss(outputs, targets):
    """Training objective: pixel MSE plus VGG perceptual loss, equally weighted.

    Args:
        outputs: Model predictions.
        targets: Ground-truth images, in the same (normalized) space as ``outputs``.

    Returns:
        ``criterion_mse(outputs, targets) + perceptual_loss(vgg_extractor, outputs, targets)``.
    """
    pixel_term = criterion_mse(outputs, targets)
    feature_term = perceptual_loss(vgg_extractor, outputs, targets)
    return pixel_term + feature_term

In [18]:
# NOTE(review): the training cell below builds a fresh ShallowUWNet and rebinds
# `model`, so this instance is never trained or used.
model = ShallowUWNet()

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed weight initialization, dropout and loader shuffling so runs are repeatable.
torch.manual_seed(42)

# Set the model, criterion, optimizer, and other parameters
model = ShallowUWNet().to(device)
criterion = total_loss  # pixel MSE + VGG perceptual loss (defined above)
optimizer = optim.Adam(model.parameters(), lr=0.0002)
batch_size = 1
num_epochs = 10
log_every = 100  # batches between progress prints
# NOTE(review): the inference cell loads '..//Underwater Image Enhancement//trained_model.pth';
# this relative path only matches that if the notebook runs from that folder — confirm.
model_save_path = 'trained_model.pth'

# Lower every dropout layer in the model to p=0.2 for this run.
for module in model.modules():
    if isinstance(module, nn.Dropout):
        module.p = 0.2

# Re-create the loaders with batch size 1 (overrides the earlier batch_size=16 loaders).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

train_losses = []
val_losses = []
for epoch in range(num_epochs):
    model.train()
    # Two accumulators: the epoch total feeds `train_losses`; the window total
    # feeds the periodic print and is reset after each print.
    epoch_loss = 0.0
    window_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        window_loss += batch_loss
        if batch_idx % log_every == log_every - 1:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_loader)}] Loss: {window_loss / log_every:.4f}")
            window_loss = 0.0

    # Mean training loss over every batch of the epoch.
    train_losses.append(epoch_loss / len(train_loader))

    # Validation after each epoch
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs)
            val_loss += criterion(val_outputs, val_targets).item()

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] Validation Loss: {avg_val_loss:.4f}")

# Plot losses across epochs
fig, ax = plt.subplots(figsize=(8, 5))
epoch_axis = range(1, num_epochs + 1)
ax.plot(epoch_axis, train_losses, label='Training Loss', marker='o')
ax.plot(epoch_axis, val_losses, label='Validation Loss', marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Losses')
ax.legend()
ax.grid(True)
plt.show()

torch.save(model.state_dict(), model_save_path)

In [33]:
import torch
import os
from PIL import Image


def get_image_list(raw_image_path, clear_image_path, is_train):
    """Build ``[raw_path, clear_path, file_name]`` entries for the files in ``raw_image_path``.

    Paths are joined with ``os.path.join``, so the folder arguments work with
    or without a trailing separator. Entries are sorted by file name so the
    dataset order is reproducible. Sub-directories (e.g. Jupyter's
    ``.ipynb_checkpoints``) are skipped because they cannot be opened as images.

    Args:
        raw_image_path: Folder of input images.
        clear_image_path: Folder of ground-truth images with the same file
            names; only read when ``is_train`` is true.
        is_train: When true, pair each raw file with the same-named file in
            ``clear_image_path``; otherwise the clear path is ``None``.

    Returns:
        List of ``[raw_path, clear_path_or_None, file_name]`` lists.
    """
    image_list = []
    for image_file in sorted(os.listdir(raw_image_path)):
        raw_image = os.path.join(raw_image_path, image_file)
        if not os.path.isfile(raw_image):
            continue
        clear_image = os.path.join(clear_image_path, image_file) if is_train else None
        image_list.append([raw_image, clear_image, image_file])
    return image_list


class UWNetDataSet(torch.utils.data.Dataset):
    """Image dataset read from disk with PIL and passed through ``transform``.

    In training mode each item is ``(raw_tensor, clear_tensor, "_")``;
    otherwise it is ``(raw_tensor, "_", file_name)``, so inference code can
    name its outputs after the inputs.

    Args:
        raw_image_path: Folder of input images.
        clear_image_path: Folder of ground-truth images (training mode only).
        transform: Callable applied to each PIL image.
        is_train: Whether to load and return the ground-truth image.
    """

    def __init__(self, raw_image_path, clear_image_path, transform, is_train=False):
        self.raw_image_path = raw_image_path
        self.clear_image_path = clear_image_path
        self.is_train = is_train
        self.image_list = get_image_list(self.raw_image_path, self.clear_image_path, is_train)
        self.transform = transform

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, convert it to 3-channel RGB, and close the file handle."""
        # convert() returns an independent, fully loaded copy, so closing the
        # source file on leaving the `with` block is safe. RGB conversion
        # guards against grayscale/RGBA inputs, which a 3-channel model cannot take.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        raw_path, clear_path, image_name = self.image_list[index]
        raw_image = self._load_rgb(raw_path)
        if self.is_train:
            clear_image = self._load_rgb(clear_path)
            return self.transform(raw_image), self.transform(clear_image), "_"
        return self.transform(raw_image), "_", image_name

    def __len__(self):
        return len(self.image_list)

In [31]:
# model.eval()

# # Iterate through the test dataset
# for inputs, targets in test_dataloader:
#     # Move inputs and targets to the same device as the model
#     inputs = inputs.to(device)
#     targets = targets.to(device)

#     # Forward pass (prediction)
#     with torch.no_grad():
#         outputs = model(inputs)

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.utils

# Define parameters or paths directly (remove argparse/config part)
snapshot_path = '..//Underwater Image Enhancement//trained_model.pth'
test_images_path = "..//Underwater Image Enhancement//data//input//"
output_images_path = '..//Underwater Image Enhancement//data/output//'
batch_size = 1
resize = 256
# Must match the Normalize() applied to both inputs and targets during training:
# the model maps normalized inputs to normalized outputs.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Load the trained weights into a fresh model instance.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ptmodel = ShallowUWNet()
ptmodel.load_state_dict(torch.load(snapshot_path, map_location=device))
ptmodel.to(device)
ptmodel.eval()

# Prepare the test dataset and dataloader (same normalization as training).
transform = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
test_dataset = UWNetDataSet(test_images_path, None, transform, False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


@torch.no_grad()
def test(test_dataloader, model, output_images_path, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Enhance every image in ``test_dataloader`` and save each one under its input file name.

    Args:
        test_dataloader: Yields ``(images, _, names)`` batches, as UWNetDataSet does in
            non-training mode.
        model: Trained enhancement model; it is switched to eval mode.
        output_images_path: Directory for the results; created if missing.
        mean, std: Normalization used on the inputs. Outputs are mapped back to
            [0, 1] with the inverse transform before saving.
    """
    model.eval()
    model_device = next(model.parameters()).device
    mean_t = torch.tensor(mean, device=model_device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=model_device).view(1, -1, 1, 1)
    os.makedirs(output_images_path, exist_ok=True)
    for img, _, names in test_dataloader:
        generated = model(img.to(model_device))
        # Undo normalization; save_image expects values in [0, 1].
        generated = (generated * std_t + mean_t).clamp(0.0, 1.0)
        # Save each image separately so batches larger than 1 are not written
        # as a single grid under the first file name.
        for image, name in zip(generated, names):
            torchvision.utils.save_image(image, os.path.join(output_images_path, name))


# Run inference with the weights loaded from `snapshot_path`.
test(test_dataloader, ptmodel, output_images_path)
