In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset


In [None]:

# Define a custom dataset class to load noisy and ground truth images
class ColorCorrectionDataset(Dataset):
    """Paired dataset yielding ``(noisy_image, ground_truth_image)`` tuples.

    Both directories are read with ``torchvision.datasets.ImageFolder`` using
    the same ``transform``; the ImageFolder class labels are discarded.
    Pairing is purely positional (item ``i`` of one folder is matched with
    item ``i`` of the other), so both folders must list corresponding files
    in the same order.

    Args:
        noisy_dir: Root folder of the noisy inputs (ImageFolder layout).
        gt_dir: Root folder of the ground-truth targets (ImageFolder layout).
        transform: Optional transform applied to every image in both folders.

    Raises:
        AssertionError: If the two folders contain different numbers of images.
    """

    def __init__(self, noisy_dir, gt_dir, transform=None):
        noisy_folder = datasets.ImageFolder(noisy_dir, transform=transform)
        self.noisy_images = noisy_folder
        gt_folder = datasets.ImageFolder(gt_dir, transform=transform)
        self.gt_images = gt_folder
        assert len(noisy_folder) == len(gt_folder), "Number of noisy and GT images must be the same."

    def __getitem__(self, index):
        # Each ImageFolder item is (image, class_index); keep only the image.
        (noisy, _noisy_label), (clean, _clean_label) = (
            self.noisy_images[index],
            self.gt_images[index],
        )
        return noisy, clean

    def __len__(self):
        # Equal to len(self.gt_images), which __init__ asserts.
        return len(self.noisy_images)


In [None]:

# Define the convolutional neural network architecture
class ColorCorrectionCNN(nn.Module):
    """Three-layer fully convolutional image-to-image network.

    Layers: conv1 (3 -> 64), conv2 (64 -> 64), conv3 (64 -> 3), each a 3x3
    convolution with stride 1 and padding 1, so height and width are
    preserved. A ReLU follows the first two convolutions. The final layer is
    linear, so outputs are not clamped to any range.

    The attribute names (conv1/conv2/conv3/relu) define the state_dict keys,
    so they are kept stable for saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, 64, **same_size)
        self.conv2 = nn.Conv2d(64, 64, **same_size)
        self.conv3 = nn.Conv2d(64, 3, **same_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 3, H, W) batch."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        return self.conv3(features)

# Define SSIM loss function
class SSIMLoss(nn.Module):
    """Differentiable SSIM loss: ``1 - mean(SSIM(output, target))``.

    Everything is computed with torch ops rather than NumPy/scikit-image.
    A NumPy round-trip detaches the value from the autograd graph, so
    ``loss.backward()`` would produce no gradients for the model and the
    optimizer would never update it.

    SSIM is computed per channel with a Gaussian-weighted window (Wang et al.,
    2004) over valid (unpadded) positions. It is then averaged over the batch,
    channels and spatial positions.

    Args:
        data_range: Dynamic range of pixel values (max - min). The default 1.0
            matches tensors produced by ``transforms.ToTensor``.
        window_size: Side length of the Gaussian window. It is shrunk to the
            largest odd size that fits when the images are smaller.
        sigma: Standard deviation of the Gaussian window.
    """

    def __init__(self, data_range=1.0, window_size=11, sigma=1.5):
        super(SSIMLoss, self).__init__()
        self.data_range = data_range
        self.window_size = window_size
        self.sigma = sigma

    def _gaussian_window(self, size, channels, dtype, device):
        """Return a normalised 2-D Gaussian kernel shaped (channels, 1, size, size),
        suitable as a depthwise (``groups=channels``) conv2d weight."""
        coords = torch.arange(size, dtype=dtype, device=device) - (size - 1) / 2.0
        gauss = torch.exp(-(coords ** 2) / (2.0 * self.sigma ** 2))
        gauss = gauss / gauss.sum()
        window = gauss[:, None] * gauss[None, :]
        return window.expand(channels, 1, size, size).contiguous()

    def forward(self, output, target):
        """Compute ``1 - mean SSIM`` between ``output`` and ``target``.

        Args:
            output: Predicted images, (N, C, H, W) or a single (C, H, W) image.
            target: Reference images with the same shape as ``output``.

        Returns:
            A scalar tensor that is differentiable with respect to ``output``.

        Raises:
            ValueError: If the shapes differ, or the inputs are not 3-D/4-D.
        """
        if output.shape != target.shape:
            raise ValueError(
                f"output and target must have the same shape, got "
                f"{tuple(output.shape)} and {tuple(target.shape)}")
        if output.dim() == 3:
            output, target = output.unsqueeze(0), target.unsqueeze(0)
        if output.dim() != 4:
            raise ValueError(f"expected a 3-D or 4-D tensor, got {output.dim()}-D")

        target = target.to(dtype=output.dtype)
        channels = output.shape[1]
        # Valid convolution needs the window to fit inside the image; keep the
        # size odd so the window has a well-defined centre.
        size = min(self.window_size, output.shape[-2], output.shape[-1])
        if size % 2 == 0:
            size -= 1
        window = self._gaussian_window(size, channels, output.dtype, output.device)

        def local_mean(x):
            # Depthwise conv: each channel is filtered independently.
            return nn.functional.conv2d(x, window, groups=channels)

        mu_x = local_mean(output)
        mu_y = local_mean(target)
        mu_x_sq, mu_y_sq, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y
        sigma_x_sq = local_mean(output * output) - mu_x_sq
        sigma_y_sq = local_mean(target * target) - mu_y_sq
        sigma_xy = local_mean(output * target) - mu_xy

        # Stabilising constants from the SSIM paper (K1 = 0.01, K2 = 0.03).
        c1 = (0.01 * self.data_range) ** 2
        c2 = (0.03 * self.data_range) ** 2
        ssim_map = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
            (mu_x_sq + mu_y_sq + c1) * (sigma_x_sq + sigma_y_sq + c2))
        return 1 - ssim_map.mean()


In [None]:

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Data locations and hyperparameters.
# NOTE(review): these are the absolute paths from the original run; point them
# at your local data directories before running.
TRAIN_NOISY_DIR, TRAIN_GT_DIR = '/train/noisy', '/train/GT'
TEST_NOISY_DIR, TEST_GT_DIR = '/test/noisy', '/test/GT'
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Define transforms: resize every image to 256x256 so a batch can be stacked,
# then convert to a float tensor (ToTensor scales pixel values into [0, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Create datasets and dataloaders. Noisy/GT images are paired by index; the
# test loader uses batch_size=1 and a fixed order.
train_dataset = ColorCorrectionDataset(noisy_dir=TRAIN_NOISY_DIR, gt_dir=TRAIN_GT_DIR, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = ColorCorrectionDataset(noisy_dir=TEST_NOISY_DIR, gt_dir=TEST_GT_DIR, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Create the model, loss function, and optimizer
model = ColorCorrectionCNN().to(device)
criterion = SSIMLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop: prints the mean loss over each window of LOG_EVERY steps and
# records the mean loss of every epoch in `loss_values` (plotted later).
loss_values = []
num_epochs = 10
LOG_EVERY = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # reset after every log line; used for console logging only
    epoch_loss = 0.0    # accumulated over the whole epoch
    for i, (noisy_img, gt_img) in enumerate(train_loader):
        noisy_img, gt_img = noisy_img.to(device), gt_img.to(device)

        optimizer.zero_grad()
        outputs = model(noisy_img)
        loss = criterion(outputs, gt_img)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if (i + 1) % LOG_EVERY == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/LOG_EVERY:.4f}")
            running_loss = 0.0
    # Record the epoch mean. `running_loss` is not used here because it is reset
    # every LOG_EVERY steps and would only hold the tail of the epoch (or 0.0).
    loss_values.append(epoch_loss / max(len(train_loader), 1))

print("Training finished.")


In [None]:

# Evaluate the model on the test set: average of the per-batch SSIM loss
# (with batch_size=1 this is the per-image average).
model.eval()
total_ssim_loss = 0.0
with torch.no_grad():
    for noisy_img, gt_img in test_loader:
        noisy_img, gt_img = noisy_img.to(device), gt_img.to(device)
        outputs = model(noisy_img)
        total_ssim_loss += criterion(outputs, gt_img).item()

if len(test_loader) > 0:
    print(f"Average SSIM loss on test set: {total_ssim_loss/len(test_loader):.4f}")
else:
    print("Test set is empty; no SSIM loss computed.")

# Visualize the per-epoch training loss recorded in `loss_values`.
# The x-axis follows len(loss_values) so the plot still works if training
# was stopped early.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(loss_values) + 1), loss_values, marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
plt.show()





In [None]:

# Save the trained model: only the learned parameters (state_dict) are written,
# not the class definition, so ColorCorrectionCNN must be rebuilt before
# loading them back with load_state_dict.
checkpoint_path = 'color_correction_model.pth'
learned_weights = model.state_dict()
torch.save(learned_weights, checkpoint_path)
print("Model saved.")
