In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
import time
import cv2

# Device configuration: use CUDA when PyTorch reports it as available,
# otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Test Dataset
class TestDenoisingDataset(Dataset):
    """Dataset yielding every image in a directory, resized to 256x256.

    Each item is a ``(image, filename)`` pair. ``image`` is the RGB image
    after the fixed resize and, when given, after ``transform``; ``filename``
    is the bare file name inside ``test_dir``.

    Args:
        test_dir: Directory that is scanned (non-recursively) for images.
        transform: Optional callable applied to the resized PIL image.
    """

    # Lower-case suffixes recognised as images; matching ignores case.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, test_dir, transform=None):
        self.test_dir = test_dir
        self.transform = transform
        self.image_files = self._list_images(test_dir)
        self.resize_transform = transforms.Resize((256, 256), antialias=True)

    @staticmethod
    def _list_images(test_dir):
        """Return the sorted image file names found directly in *test_dir*.

        The suffix test is case-insensitive so ``.PNG`` / ``.JPG`` files are
        not silently dropped, and the result is sorted because
        ``os.listdir`` returns entries in arbitrary order.
        """
        return sorted(
            name for name in os.listdir(test_dir)
            if name.lower().endswith(TestDenoisingDataset.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``test_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert to RGB, resize and optionally transform image *idx*."""
        filename = self.image_files[idx]
        image_path = os.path.join(self.test_dir, filename)
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image = self.resize_transform(image)
        if self.transform:
            image = self.transform(image)
        return image, filename

# Neural Network Base Class
class NeuralNetwork(nn.Module):
    """Root class for the networks in this file.

    Adds a ``device`` property on top of ``nn.Module`` and leaves
    ``forward`` for subclasses to implement.
    """

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, x):
        """Placeholder; concrete networks must override this method."""
        raise NotImplementedError

# NN Regressor Base Class
class NNRegressor(NeuralNetwork):
    """Network base class whose training criterion is ``nn.MSELoss``."""

    def __init__(self):
        super().__init__()
        # Loss module constructed with nn.MSELoss's default settings.
        self.mse = nn.MSELoss()

    def criterion(self, y, d):
        """Return the MSE loss between prediction ``y`` and target ``d``."""
        loss = self.mse(y, d)
        return loss

# DUDnCNN Model
class DUDnCNN(NNRegressor):
    """Residual denoising CNN built from dilated 3x3 convolutions.

    The network has ``D + 2`` convolutions: a 3->C input layer, ``D`` C->C
    layers each followed by batch norm and ReLU, and a C->3 output layer
    whose result is added to the input. Dilation rates are derived from
    ``D``; for ``D=6`` they are 1, 1, 2, 4, 8, 4, 2, 1.

    Args:
        D: Number of hidden (C->C) convolution layers.
        C: Number of feature channels in the hidden layers.
    """

    def __init__(self, D, C=64):
        super().__init__()
        self.D = D
        half = D // 2

        # Two per-layer counters; 2 ** (up - down) is the layer's dilation.
        ups = [0] + list(range(half)) + [half] * (half + 1)
        downs = [0] * (half + 1) + list(range(D + 1 - (half + 1))) + [half]
        holes = [2 ** (up - down) - 1 for up, down in zip(ups, downs)]
        dilations = [hole + 1 for hole in holes]

        def dilated_conv(in_ch, out_ch, rate):
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            return nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate)

        # Layers are created in the same order as they are applied.
        self.conv = nn.ModuleList()
        self.conv.append(dilated_conv(3, C, dilations[0]))
        for idx in range(D):
            self.conv.append(dilated_conv(C, C, dilations[idx + 1]))
        self.conv.append(dilated_conv(C, 3, dilations[-1]))

        self.bn = nn.ModuleList()
        for _ in range(D):
            self.bn.append(nn.BatchNorm2d(C))

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the network."""
        depth = self.D
        half = depth // 2
        act = nn.functional.relu

        out = act(self.conv[0](x))

        # First stage: keep each activation for the later skip connections.
        skips = []
        for idx in range(half - 1):
            out = act(self.bn[idx](self.conv[idx + 1](out)))
            skips.append(out)

        # Middle stage: two layers with no skip connection.
        for idx in range(half - 1, half + 1):
            out = act(self.bn[idx](self.conv[idx + 1](out)))

        # Last stage: merge with the stored activations, most recent first.
        for idx in range(half + 1, depth):
            back = idx - (half + 1) + 1
            merged = (out + skips[-back]) / 1.41421356237  # Precomputed sqrt(2)
            out = act(self.bn[idx](self.conv[idx + 1](merged)))

        return self.conv[depth + 1](out) + x

# Input pipeline: PIL image -> tensor, then per-channel (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Simplified Test-Time Augmentation (TTA) function
def apply_tta(net, image):
    """Denoise one image with horizontal-flip test-time augmentation.

    The network is run on the image and on its horizontally flipped copy;
    the second prediction is flipped back so both are aligned, and the two
    are averaged.

    Args:
        net: Module mapping a ``(1, C, H, W)`` batch to a same-shaped output.
        image: Single image tensor laid out as ``(C, H, W)``.

    Returns:
        Tensor shaped like ``image`` with the averaged prediction, on the
        device the network ran on.
    """
    from contextlib import nullcontext

    # Run on the device that holds the network's weights; a parameter-free
    # module simply runs where the image already is.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else image.device
    use_amp = target.type == 'cuda'

    # Each entry pairs an input with a flag telling whether its output has to
    # be flipped back. The flag is carried explicitly because the tensor is
    # rebound inside the loop, so an identity check against the list entry
    # can never succeed.
    variants = [
        (image, False),                   # Original
        (torch.flip(image, [2]), True),   # Horizontal flip (width axis of CHW)
    ]

    denoised_outputs = []
    with torch.no_grad():
        for aug_img, flipped in variants:
            batch = aug_img.unsqueeze(0).to(target, non_blocking=True)
            # Mixed precision is only requested when actually running on CUDA.
            amp_ctx = torch.amp.autocast('cuda') if use_amp else nullcontext()
            with amp_ctx:
                denoised = net(batch)
            # Reverse the augmentation (width axis of NCHW).
            if flipped:
                denoised = torch.flip(denoised, [3])
            denoised_outputs.append(denoised.squeeze(0))

    # Average the aligned outputs
    return torch.mean(torch.stack(denoised_outputs), dim=0)

# Revised Post-processing function
def post_process(denoised_img):
    """Convert a denoised CHW float image to a filtered HWC uint8 image.

    Args:
        denoised_img: NumPy array laid out as (channels, height, width) with
            values expected in [0, 1].

    Returns:
        uint8 array laid out as (height, width, channels) after bilateral
        smoothing and 3x3 sharpening.
    """
    img = np.moveaxis(denoised_img, 0, -1)  # CHW to HWC

    # Clip before the cast so out-of-range values saturate instead of wrapping
    # around, and round to the nearest level instead of truncating toward 0.
    img = np.rint(np.clip(img * 255.0, 0, 255)).astype(np.uint8)
    # moveaxis yields a strided view; hand OpenCV a plain C-contiguous buffer.
    img = np.ascontiguousarray(img)

    # Lighter bilateral filtering to avoid over-smoothing
    img = cv2.bilateralFilter(img, d=5, sigmaColor=25, sigmaSpace=25)

    # Sharpening kernel; the weights sum to 1 so overall brightness is kept.
    # float32 is used so OpenCV does not have to cast an int64 kernel.
    kernel = np.array(
        [[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.float32
    )
    img = cv2.filter2D(img, -1, kernel)

    return img

def _read_image_size(path):
    """Return the (width, height) of the image at *path* and close the file."""
    with Image.open(path) as img:
        return img.size


def test_model(checkpoint_path, test_dir, output_dir="/kaggle/working/denoised_images", batch_size=4):
    """Denoise every image in ``test_dir`` and write the results to ``output_dir``.

    A ``DUDnCNN(D=6, C=64)`` is loaded from the checkpoint's ``'Net'`` entry,
    each image is denoised with ``apply_tta``, post-processed, resized back
    to its original size and saved under its original file name. Progress
    and the average per-image runtime are printed.

    Args:
        checkpoint_path: Path to the checkpoint file read with ``torch.load``.
        test_dir: Directory containing the noisy input images.
        output_dir: Directory that receives the denoised images; created if
            it does not exist.
        batch_size: Number of images loaded per DataLoader batch.

    Returns:
        None. If the checkpoint cannot be loaded the error is printed and the
        function returns without processing any image.
    """
    # Initialize model
    net = DUDnCNN(D=6, C=64).to(device)
    net.eval()

    # Load checkpoint with weights_only=True
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint['Net'])
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    on_cuda = device.type == "cuda"

    test_dataset = TestDenoisingDataset(test_dir=test_dir, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=on_cuda,
        prefetch_factor=2
    )

    os.makedirs(output_dir, exist_ok=True)

    total_time = 0
    num_images = 0

    for noisy_images, filenames in test_loader:
        # perf_counter is monotonic, unlike wall-clock time.time().
        start_time = time.perf_counter()
        noisy_images = noisy_images.to(device, non_blocking=True)

        # Apply simplified TTA and denoise, one image at a time
        denoised_images = torch.stack([apply_tta(net, img) for img in noisy_images])

        # Denormalize: inverse of Normalize(mean=0.5, std=0.5), then clamp
        denoised_images = (denoised_images * 0.5) + 0.5
        denoised_images = torch.clamp(denoised_images, 0, 1)

        # Look up the original sizes; each file is closed right after reading.
        original_sizes = [
            _read_image_size(os.path.join(test_dataset.test_dir, fname))
            for fname in filenames
        ]
        denoised_batch = denoised_images.cpu().numpy()

        for i, (filename, orig_size) in enumerate(zip(filenames, original_sizes)):
            denoised_img = post_process(denoised_batch[i])
            output_path = os.path.join(output_dir, filename)

            img = Image.fromarray(denoised_img)
            img = img.resize(orig_size, Image.Resampling.LANCZOS)
            # NOTE(review): the data is always PNG-encoded, even when the
            # file name ends in .jpg/.jpeg -- confirm this is intended.
            img.save(output_path, 'PNG', compress_level=0)
            print(f"Saved denoised image: {output_path}")

        end_time = time.perf_counter()
        total_time += (end_time - start_time)
        num_images += len(filenames)

        # Clean up memory
        del noisy_images, denoised_images, denoised_batch
        if on_cuda:
            torch.cuda.empty_cache()

    avg_runtime = total_time / num_images if num_images > 0 else 0
    print(f"Average runtime per image: {avg_runtime:.4f} seconds")
    print("Denoising complete! Output directory:", output_dir)

if __name__ == '__main__':
    # Kaggle input locations for the trained weights and the noisy images.
    checkpoint_path = "/kaggle/input/checkpoint/checkpoint.pth.tar"
    test_dir = "/kaggle/input/test-data"
    # Output directory and batch size fall back to test_model's defaults.
    test_model(checkpoint_path=checkpoint_path, test_dir=test_dir)

Saved denoised image: /kaggle/working/denoised_images/0000089.png
Saved denoised image: /kaggle/working/denoised_images/0000014.png
Saved denoised image: /kaggle/working/denoised_images/0000026.png
Saved denoised image: /kaggle/working/denoised_images/0968.png
Saved denoised image: /kaggle/working/denoised_images/0947.png
Saved denoised image: /kaggle/working/denoised_images/1000.png
Saved denoised image: /kaggle/working/denoised_images/0950.png
Saved denoised image: /kaggle/working/denoised_images/0971.png
Saved denoised image: /kaggle/working/denoised_images/0940.png
Saved denoised image: /kaggle/working/denoised_images/0000056.png
Saved denoised image: /kaggle/working/denoised_images/0000019.png
Saved denoised image: /kaggle/working/denoised_images/0000030.png
Saved denoised image: /kaggle/working/denoised_images/0000021.png
Saved denoised image: /kaggle/working/denoised_images/0000070.png
Saved denoised image: /kaggle/working/denoised_images/0000099.png
Saved denoised image: /kaggl