In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split  # Added random_split here
# Pick the compute device once for the whole notebook: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [None]:
# Input (low-light) and reference (normal-light) image directories on Kaggle.
low_dir, high_dir = "/kaggle/input/low-ass", "/kaggle/input/high-ass"

In [None]:

# ====================================================
# 1. Dataset Preparation
# ====================================================
class LowLightDataset(Dataset):
    """Paired low-light / reference image dataset.

    Images are paired purely by position: both directory listings are sorted
    and the i-th low-light file is matched with the i-th reference file. The
    two directories must therefore contain the same number of image files,
    ideally with matching file names.

    Only regular, non-hidden files are listed, so sub-directories and files
    such as ``.DS_Store`` do not end up in the pairing.

    Args:
        low_dir: directory holding the low-light (input) images.
        high_dir: directory holding the reference (target) images.
        transform: optional callable applied to each PIL image. It is called
            separately on the input and on the target, so it must be
            deterministic (e.g. Resize + ToTensor); a random augmentation
            would give the two images of a pair different crops/flips.

    Raises:
        ValueError: if the two directories contain a different number of files.
    """

    def __init__(self, low_dir, high_dir, transform=None):
        self.low_dir = low_dir
        self.high_dir = high_dir
        self.low_images = self._list_image_files(low_dir)
        self.high_images = self._list_image_files(high_dir)
        # Pairing is positional, so a count mismatch would silently shift
        # every pair after the first missing file. Fail fast instead.
        if len(self.low_images) != len(self.high_images):
            raise ValueError(
                f"low_dir has {len(self.low_images)} files but high_dir has "
                f"{len(self.high_images)}; images are paired by sorted "
                f"position, so the counts must match"
            )
        self.transform = transform

    @staticmethod
    def _list_image_files(directory):
        """Sorted names of the regular, non-hidden files in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.low_images)

    def __getitem__(self, idx):
        """Return the ``(low, high)`` pair at ``idx``, transformed if a transform is set."""
        low_img_path = os.path.join(self.low_dir, self.low_images[idx])
        high_img_path = os.path.join(self.high_dir, self.high_images[idx])

        # convert() returns a fully loaded copy, so the file can be closed
        # immediately instead of relying on the garbage collector.
        with Image.open(low_img_path) as src:
            low_img = src.convert('RGB')
        with Image.open(high_img_path) as src:
            high_img = src.convert('RGB')

        if self.transform:
            low_img = self.transform(low_img)
            high_img = self.transform(high_img)

        return low_img, high_img


In [None]:

# ====================================================
# 2. Model Architecture 
# ====================================================
class DenoisingAutoencoder(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    Encoder: five double-conv stages (32 -> 64 -> 128 -> 256 -> 512 channels)
    separated by four 2x max-pool layers. Decoder: four 2x transposed
    convolutions, each concatenated with the encoder feature map of the same
    resolution (skip connection) and followed by two convolutions. A final
    1x1 convolution plus sigmoid produces the output in [0, 1].

    The four pooling stages mean the core network needs spatial sizes that
    are multiples of 16. ``forward`` pads any other size up to the next
    multiple of 16 by edge replication and crops the result back, so an
    ``H x W`` input always yields an ``H x W`` output.
    """

    # 2 ** (number of 2x pooling stages): required divisor of H and W.
    _SIZE_MULTIPLE = 16

    def __init__(self):
        super(DenoisingAutoencoder, self).__init__()

        # Encoder
        self.conv1_1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # Decoder: conv*_1 input channels are doubled by the skip concatenation.
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        # Output
        self.conv10_1 = nn.Conv2d(32, 3, kernel_size=1, stride=1)

    def forward(self, x):
        """Enhance a batch of images.

        Args:
            x: tensor of shape ``(N, 3, H, W)``; H and W may be any size.

        Returns:
            Tensor of shape ``(N, 3, H, W)`` with values in [0, 1].
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # Pad bottom/right only so cropping back is a simple slice.
            # 'replicate' (unlike 'reflect') has no minimum-input-size limit.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        # Decoder with skip connections
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        # Output
        conv10 = self.conv10_1(conv9)
        out = torch.sigmoid(conv10)  # Normalize to [0, 1]
        # Drop the padding added above (a no-op view when none was added).
        return out[..., :height, :width]

    def lrelu(self, x):
        """Leaky ReLU with negative slope 0.2, computed as ``max(0.2 * x, x)``."""
        return torch.max(0.2 * x, x)


In [None]:

# ====================================================
# 3. Training Setup
# ====================================================
# Hyperparameters
BATCH_SIZE = 4
EPOCHS = 50
LR = 0.0001
SEED = 42  # fixes weight init, shuffling and the train/val split

# Seed torch's global RNG (model initialisation, DataLoader shuffling).
torch.manual_seed(SEED)

# Transformations: one deterministic transform, applied to inputs and targets alike.
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Original paper uses 512x512
    transforms.ToTensor(),
])

# Full dataset, reusing the directory constants instead of repeating the paths.
full_dataset = LowLightDataset(low_dir, high_dir, transform=transform)

# 80/20 train/validation split. A dedicated seeded generator makes the split
# identical across kernel restarts, so a checkpoint reloaded later is still
# evaluated on images it never trained on.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Dataloaders (driven by BATCH_SIZE rather than a repeated literal)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model, loss, optimizer
model = DenoisingAutoencoder().to(device)
criterion = nn.L1Loss()  # MAE loss as in the paper
optimizer = optim.Adam(model.parameters(), lr=LR)



In [None]:

# ====================================================
# 4. Training Loop
# ====================================================
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs,
                device=None, checkpoint_path="best_model.pth"):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch makes one pass over ``train_loader`` with gradient updates and
    one pass over ``val_loader`` without gradients, then prints the mean
    per-batch losses. Whenever the validation loss beats the best seen so far,
    ``model.state_dict()`` is written to ``checkpoint_path``.

    Args:
        model: network to train (updated in place).
        train_loader: sized iterable of ``(low, high)`` batches used for updates.
        val_loader: sized iterable of ``(low, high)`` batches used for model selection.
        criterion: loss called as ``criterion(outputs, high)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over the training data.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, so batches always land next to the weights.
        checkpoint_path: file the best ``state_dict`` is saved to.

    Returns:
        The same ``model`` object holding the weights from the *last* epoch,
        which are not necessarily the best ones. Reload ``checkpoint_path``
        for the best-validation weights.

    Raises:
        ValueError: if either loader has length 0. Without this check the
            per-epoch averaging would raise ZeroDivisionError only after a
            full epoch of work.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            f"train_model needs non-empty loaders (got {len(train_loader)} "
            f"training and {len(val_loader)} validation batches)"
        )
    if device is None:
        device = next(model.parameters()).device

    best_loss = float('inf')

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for low, high in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            low, high = low.to(device), high.to(device)

            optimizer.zero_grad()
            outputs = model(low)
            loss = criterion(outputs, high)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for low, high in val_loader:
                low, high = low.to(device), high.to(device)
                outputs = model(low)
                val_loss += criterion(outputs, high).item()

        # Mean of per-batch losses (a smaller final batch counts as one batch).
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return model


In [None]:
# Start training. The returned model carries the final-epoch weights; the
# best-validation weights are written to best_model.pth along the way.
model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=EPOCHS,
)


In [None]:
def evaluate_metrics(model, val_loader, device=None):
    """Report mean PSNR and SSIM between model outputs and reference images.

    Every ``(low, high)`` batch from ``val_loader`` is enhanced by ``model``.
    Each prediction/target pair is transposed from ``(C, H, W)`` to
    ``(H, W, C)``, clipped to [0, 1], and scored with skimage's PSNR and SSIM
    (``data_range=1.0``). The means are printed and returned.

    Args:
        model: network mapping a batch of low-light images to enhanced images.
        val_loader: iterable of ``(low, high)`` tensor batches.
        device: device used for inference. Defaults to the device of the
            model's first parameter.

    Returns:
        ``(avg_psnr, avg_ssim)``, averaged over the images actually scored.
        Counting the scored images, rather than using ``len(dataset)``, keeps
        the mean correct when the loader drops a partial last batch.

    Raises:
        ValueError: if the loader yields no images.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    n_images = 0
    with torch.no_grad():
        for low, high in tqdm(val_loader, desc="Evaluating"):
            outputs = model(low.to(device)).cpu().numpy()
            targets = high.cpu().numpy()
            for pred, target in zip(outputs, targets):
                pred_img = np.clip(pred.transpose(1, 2, 0), 0, 1)
                true_img = np.clip(target.transpose(1, 2, 0), 0, 1)
                total_psnr += psnr(true_img, pred_img, data_range=1.0)
                total_ssim += ssim(true_img, pred_img, data_range=1.0, channel_axis=2)
                n_images += 1

    if n_images == 0:
        raise ValueError("evaluate_metrics: val_loader produced no images")

    avg_psnr = total_psnr / n_images
    avg_ssim = total_ssim / n_images
    print(f"\n📊 Evaluation Results — PSNR: {avg_psnr:.2f} dB | SSIM: {avg_ssim:.4f}")
    return avg_psnr, avg_ssim


In [None]:
# Restore the best-validation weights. map_location remaps tensors that were
# saved on a GPU onto the current device, so this also works in a CPU-only kernel.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)


In [None]:
# Score the restored checkpoint on the held-out validation split.
evaluate_metrics(model=model, val_loader=val_loader)


In [None]:

# ====================================================
# 5. Inference and Save Enhanced Images
# ====================================================
def enhance_and_save(model, input_dir, output_dir, transform, device=None,
                     keep_original_size=False):
    """Enhance every image file in ``input_dir`` and save it to ``output_dir``.

    Each regular file is opened as RGB, turned into a tensor by ``transform``
    (assumed to return a ``(3, H, W)`` tensor, e.g. the training transform),
    run through ``model``, and saved under the same file name. The output
    format therefore follows the file name's extension.

    Args:
        model: trained enhancement network.
        input_dir: directory of images to enhance. Sub-directories are skipped.
        output_dir: destination directory, created if missing.
        transform: callable mapping a PIL image to a tensor.
        device: inference device. Defaults to the device of the model's first
            parameter.
        keep_original_size: if True, resize each result back to its source
            image's size (with Pillow's default resampling filter). If False
            (the default), results keep whatever size ``transform`` produced.
    """
    if device is None:
        device = next(model.parameters()).device
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    to_pil = transforms.ToPILImage()

    # Sorted for a deterministic order. Non-files are skipped because
    # Image.open cannot read a directory.
    img_names = sorted(
        name for name in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, name))
    )

    for img_name in tqdm(img_names):
        img_path = os.path.join(input_dir, img_name)
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            enhanced_tensor = model(img_tensor)

        # Index the batch dimension explicitly: .squeeze() would also drop
        # a spatial dimension of size 1.
        enhanced_img = to_pil(enhanced_tensor[0].cpu())
        if keep_original_size and enhanced_img.size != img.size:
            enhanced_img = enhanced_img.resize(img.size)
        enhanced_img.save(os.path.join(output_dir, img_name))

# Load best model (map_location lets a GPU-saved checkpoint load on any device)
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)

# Enhance every image in low_dir. NOTE: low_dir is the directory the
# training/validation pairs were drawn from, so these are not held-out test images.
enhance_and_save(model, low_dir, "output/enhanced/", transform)