In [None]:
# Noise2Self Single-Shot RGB Image Denoising
# This notebook implements the Noise2Self algorithm for denoising a single 512x512 RGB image
# Based on: "Noise2Self: Blind Denoising by Self-Supervision" (Batson & Royer, 2019)



import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Seed every RNG the notebook draws from (PyTorch, NumPy, Python's random) for reproducibility.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
del _seed_fn

class Masker:
    """
    Noise2Self masker that randomly masks pixels during training.

    A random subset of spatial locations is hidden from the network by setting
    it to 0 in every channel. The training loss is then evaluated only on those
    locations, so the network has to predict each hidden pixel from its context.
    """
    def __init__(self, mask_ratio=0.5):
        """
        Args:
            mask_ratio: Probability that any given spatial location is masked;
                must lie in [0, 1].
        Raises:
            ValueError: if ``mask_ratio`` is outside [0, 1].
        """
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
        self.mask_ratio = mask_ratio

    def mask(self, images, step):
        """
        Randomly mask pixels in the input images.

        The mask is drawn from a private generator seeded with ``step``. The same
        step therefore always yields the same mask, and PyTorch's global RNG is
        left untouched. Re-seeding the global RNG here would silently reset every
        other consumer of it.

        Args:
            images: Input tensor of shape (B, C, H, W).
            step: Current training step; used as the seed for reproducible masking.
        Returns:
            masked_input: New tensor equal to ``images`` with masked locations set to 0.
            mask: Boolean tensor of shape (B, C, H, W), True at the masked pixels,
                i.e. the pixels the loss should be computed on. It is the same
                spatial mask broadcast across all channels (an expanded view).
        """
        B, C, H, W = images.shape

        generator = torch.Generator(device=images.device)
        generator.manual_seed(step)
        # One draw per spatial location, shared by all channels.
        mask = torch.rand(B, 1, H, W, generator=generator, device=images.device) < self.mask_ratio

        # Multiplication already allocates a new tensor, so no clone is needed.
        masked_input = images * (~mask).to(images.dtype)

        # Expand mask to match number of channels
        mask = mask.expand(-1, C, -1, -1)

        return masked_input, mask

class UNet(nn.Module):
    """
    Simple U-Net architecture for image denoising.

    Four encoder stages (each followed by 2x max-pooling) feed a bottleneck. Four
    decoder stages then upsample with transposed convolutions and concatenate the
    encoder features of the same scale (skip connections). A final 1x1 conv maps
    to ``out_channels``, with no output activation.

    The encoder halves the resolution four times, so the skip connections only
    line up when height and width are multiples of 16. Inputs of any other size
    are edge-padded up to the next multiple of 16, and the output is cropped back
    to the input's spatial size.
    """

    # Overall downsampling factor of the encoder: four MaxPool2d(2) stages.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder: each dec block takes upsampled features + the skip connection,
        # hence twice the channels of the upconv output.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self.conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # Output layer
        self.out_conv = nn.Conv2d(64, out_channels, 1)

        # Pooling
        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv + ReLU layers; padding=1 keeps the spatial size unchanged."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to a (B, out_channels, H, W) batch."""
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # Pad only bottom/right so the final crop is a plain [:H, :W] slice.
            # F.pad lists the last dim first: (left, right, top, bottom).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder
        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        # Output
        out = self.out_conv(d1)
        if pad_h or pad_w:
            out = out[..., :height, :width]
        return out

class SingleImageDataset(Dataset):
    """
    Dataset that serves one image ``num_samples`` times.

    Every item is an (input, target) pair, and both are the image with its
    leading batch dimension removed. The DataLoader re-adds a batch dimension
    when it collates items.
    """
    def __init__(self, image_tensor, num_samples=1000):
        # Keep the stored copy 4-D: a 3-D (C, H, W) input gets a leading batch axis.
        self.image = image_tensor[None] if image_tensor.dim() == 3 else image_tensor
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # The index is irrelevant: every position yields the same image.
        sample = self.image.squeeze(0)
        return sample, sample

def _make_synthetic_noisy_image(size, noise_std=0.1):
    """Build a smooth synthetic RGB image and a Gaussian-noised copy of it.

    Returns:
        (noisy, clean): ``noisy`` is a float32 tensor of shape (3, size, size);
        ``clean`` is a float64 NumPy array of shape (size, size, 3). Both are
        clipped to [0, 1].
    """
    coords = np.linspace(0, 1, size)
    X, Y = np.meshgrid(coords, coords)

    # Each channel gets a different smooth sinusoidal pattern.
    clean_r = 0.5 + 0.3 * np.sin(10 * X) * np.cos(10 * Y)
    clean_g = 0.5 + 0.3 * np.cos(8 * X) * np.sin(12 * Y)
    clean_b = 0.5 + 0.3 * np.sin(6 * X + 2) * np.cos(8 * Y + 1)
    clean_image = np.clip(np.stack([clean_r, clean_g, clean_b], axis=2), 0, 1)

    # Additive Gaussian noise from NumPy's global RNG, clipped back to the valid range.
    noisy_image = np.clip(clean_image + np.random.normal(0, noise_std, clean_image.shape), 0, 1)

    return torch.tensor(noisy_image, dtype=torch.float32).permute(2, 0, 1), clean_image


def load_and_preprocess_image(image_path, size=512):
    """Load an RGB image and resize it to ``size`` x ``size``.

    Args:
        image_path: Path to an image file, or None to synthesize a noisy demo image.
        size: Output height and width in pixels (default 512).

    Returns:
        (noisy, clean). ``noisy`` is a float32 tensor of shape (3, size, size)
        with values in [0, 1]. ``clean`` is the noise-free (size, size, 3) NumPy
        array for the synthetic demo, or None when loading from disk, because no
        ground truth is available then.
    """
    if image_path is None:
        print("No image provided. Creating synthetic noisy image for demonstration...")
        return _make_synthetic_noisy_image(size)

    # The context manager closes the file handle once the pixels are decoded;
    # convert() returns a new, fully loaded image independent of the file.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor()
    ])

    return transform(rgb_image), None

def train_noise2self(noisy_image, num_epochs=50, learning_rate=0.001,
                     steps_per_epoch=500, mask_ratio=0.5, log_every=200):
    """Train a U-Net with the Noise2Self objective on a single noisy image.

    Each step hides a random subset of pixels (via ``Masker``), runs the network
    on the masked image, and computes MSE only at the hidden pixels. The network
    therefore learns to predict each pixel from its neighbours.

    Args:
        noisy_image: Image tensor, (C, H, W) or (1, C, H, W). The model is built
            with 3 input channels, so C is expected to be 3.
        num_epochs: Number of epochs.
        learning_rate: Adam learning rate.
        steps_per_epoch: Optimizer steps per epoch (length of the single-image dataset).
        mask_ratio: Fraction of pixels masked at each step.
        log_every: Log the mean epoch loss every ``log_every`` epochs. The final
            epoch is always logged too.

    Returns:
        (model, losses): the trained model and the mean loss of each epoch.

    Raises:
        ValueError: if ``steps_per_epoch`` or ``log_every`` is less than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError(f"steps_per_epoch must be >= 1, got {steps_per_epoch}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    # Create model
    model = UNet(in_channels=3, out_channels=3).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    masker = Masker(mask_ratio=mask_ratio)

    # Create dataset and dataloader; every batch is the same (1, C, H, W) image.
    dataset = SingleImageDataset(noisy_image, num_samples=steps_per_epoch)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Training loop
    model.train()
    losses = []

    print("Starting Noise2Self training...")
    for epoch in tqdm(range(num_epochs)):
        epoch_loss = 0.0

        for step, (batch_images, _) in enumerate(dataloader):
            batch_images = batch_images.to(device)

            # Seed each mask with the global step index so masks are reproducible.
            masked_input, mask = masker.mask(batch_images, epoch * len(dataloader) + step)

            # Forward pass
            optimizer.zero_grad()
            output = model(masked_input)

            # Loss only on masked pixels: unmasked entries are zeroed on both sides.
            # They still count in MSELoss's mean, so the loss is scaled by the mask fraction.
            mask_f = mask.float()
            loss = criterion(output * mask_f, batch_images * mask_f)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)

        if epoch % log_every == 0 or epoch == num_epochs - 1:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.6f}")

    return model, losses

def denoise_image(model, noisy_image):
    """Run the trained model once on the full, unmasked image.

    Args:
        model: Trained denoiser. It is switched to eval mode and left there.
        noisy_image: Image tensor, expected (C, H, W); a batch axis is added here.

    Returns:
        The denoised image on the CPU, clamped to [0, 1], with the batch axis removed.
    """
    model.eval()
    batch = noisy_image.unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(batch).clamp(0, 1)
    return prediction.squeeze(0).cpu()

def visualize_results(noisy_image, denoised_image, clean_image=None, losses=None, crop_size=128):
    """Show noisy vs. denoised images (plus ground truth and crops if available).

    Args:
        noisy_image: (C, H, W) tensor.
        denoised_image: (C, H, W) tensor.
        clean_image: Optional (H, W, C) NumPy array. When given, a 2x3 grid is drawn
            with full images on top and centered ``crop_size`` crops below.
        losses: Optional per-epoch losses, plotted in a separate figure.
        crop_size: Side of the centered detail crop. It is clipped to the image size.
    """
    # Convert (C, H, W) tensors to (H, W, C) arrays for imshow.
    noisy_np = noisy_image.detach().cpu().permute(1, 2, 0).numpy()
    denoised_np = denoised_image.detach().cpu().permute(1, 2, 0).numpy()

    def _show(ax, img, title):
        # Draw one image panel without axis ticks.
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    if clean_image is not None:
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Top row: Full images
        _show(axes[0, 0], noisy_np, 'Noisy Image')
        _show(axes[0, 1], denoised_np, 'Denoised Image (Noise2Self)')
        _show(axes[0, 2], clean_image, 'Clean Image (Ground Truth)')

        # Bottom row: crop centered on the actual image size (192:320 for 512x512).
        height, width = noisy_np.shape[:2]
        crop_h, crop_w = min(crop_size, height), min(crop_size, width)
        top, left = (height - crop_h) // 2, (width - crop_w) // 2
        window = (slice(top, top + crop_h), slice(left, left + crop_w))

        _show(axes[1, 0], noisy_np[window], 'Noisy (Crop)')
        _show(axes[1, 1], denoised_np[window], 'Denoised (Crop)')
        _show(axes[1, 2], clean_image[window], 'Clean (Crop)')

    else:
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        _show(axes[0], noisy_np, 'Noisy Input Image')
        _show(axes[1], denoised_np, 'Denoised Output (Noise2Self)')

    fig.tight_layout()
    plt.show()

    # Plot training loss
    if losses is not None:
        loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
        loss_ax.plot(losses)
        loss_ax.set_title('Training Loss (Noise2Self)')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('MSE Loss')
        loss_ax.grid(True)
        plt.show()

def calculate_metrics(clean_image, denoised_image, data_range=1.0):
    """Calculate PSNR (dB) and MSE between a ground-truth image and a denoised one.

    Args:
        clean_image: Ground truth as an (H, W, C) array or tensor, or None.
        denoised_image: Denoised (C, H, W) tensor.
        data_range: Maximum possible pixel value (1.0 for images in [0, 1]).

    Returns:
        (psnr, mse) as Python floats, or (None, None) if ``clean_image`` is None.
        ``psnr`` is ``inf`` when the two images are identical.

    Raises:
        ValueError: if the images do not have the same shape once ``clean_image``
            is transposed to (C, H, W). Without this check, broadcasting could
            silently produce a meaningless MSE.
    """
    import math

    if clean_image is None:
        return None, None

    clean_tensor = torch.as_tensor(clean_image, dtype=torch.float32).permute(2, 0, 1)
    denoised_tensor = torch.as_tensor(denoised_image).detach().float().cpu()
    if clean_tensor.shape != denoised_tensor.shape:
        raise ValueError(
            f"shape mismatch: clean {tuple(clean_tensor.shape)} vs "
            f"denoised {tuple(denoised_tensor.shape)}"
        )

    mse = torch.mean((clean_tensor - denoised_tensor) ** 2).item()
    # PSNR = 10 * log10(MAX^2 / MSE); identical images give infinite PSNR.
    psnr = math.inf if mse == 0 else 10.0 * math.log10(data_range ** 2 / mse)

    return psnr, mse

# Main execution
def main(image_path="036.png", num_epochs=20, learning_rate=0.001):
    """Run the full pipeline: load, train, denoise, report metrics, visualize.

    Args:
        image_path: Path to the noisy RGB input image, or None to generate a
            synthetic noisy image. The synthetic path also enables PSNR/MSE,
            because its clean image is known.
        num_epochs: Training epochs passed to ``train_noise2self``.
        learning_rate: Adam learning rate passed to ``train_noise2self``.

    Returns:
        (model, noisy_image, denoised_image)
    """
    print("=== Noise2Self Single-Shot RGB Image Denoising ===\n")

    noisy_image, clean_image = load_and_preprocess_image(image_path)

    print(f"Image shape: {noisy_image.shape}")
    print(f"Image dtype: {noisy_image.dtype}")
    print(f"Image range: [{noisy_image.min():.3f}, {noisy_image.max():.3f}]")

    # Train Noise2Self model
    model, losses = train_noise2self(noisy_image, num_epochs=num_epochs, learning_rate=learning_rate)

    # Denoise the image
    print("\nApplying denoising...")
    denoised_image = denoise_image(model, noisy_image)

    # Metrics need ground truth, which only the synthetic demo provides.
    if clean_image is not None:
        psnr, mse = calculate_metrics(clean_image, denoised_image)
        print("\nDenoising Metrics:")
        print(f"PSNR: {psnr:.2f} dB")
        print(f"MSE: {mse:.6f}")

    # Visualize results
    print("\nVisualizing results...")
    visualize_results(noisy_image, denoised_image, clean_image, losses)

    print("\n=== Denoising Complete! ===")

    return model, noisy_image, denoised_image

# Run the main function (in a notebook kernel __name__ is "__main__", so this runs).
if __name__ == "__main__":
    model, noisy_image, denoised_image = main()

# Instructions for using your own image:
print("""
To use your own noisy RGB image:

1. Upload your image to Colab:
   - Click on Files tab (folder icon) in the left sidebar
   - Click Upload and select your image
   - Copy the file path

2. Point the code at it:
   - In main(), set image_path to your uploaded file's path, e.g.
     image_path = "/content/your_noisy_image.png"
   - Or set image_path = None to generate a synthetic noisy test image
     (this also reports PSNR/MSE against the known clean image)

3. Re-run the notebook

The algorithm will automatically:
- Resize your image to 512x512
- Train a U-Net using Noise2Self methodology
- Output the denoised image

Key parameters you can adjust:
- num_epochs (passed to train_noise2self in main()): more epochs usually improve denoising but take longer
- learning_rate: higher = faster training but might be unstable
- mask_ratio (Masker(mask_ratio=...) in train_noise2self): fraction of pixels masked at each training step
""")

Using device: cuda
=== Noise2Self Single-Shot RGB Image Denoising ===

Image shape: torch.Size([3, 512, 512])
Image dtype: torch.float32
Image range: [0.000, 1.000]
Starting Noise2Self training...


  5%|▌         | 1/20 [01:51<35:10, 111.07s/it]

Epoch 0/20, Loss: 0.002282


 80%|████████  | 16/20 [29:19<07:07, 106.82s/it]

In [None]:
# Convert the (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.
# Round to the nearest 8-bit level: a bare astype() truncates, biasing every pixel down by up to one level.
out_np = np.round(denoised_image.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
Image.fromarray(out_np).save('denoised_output.png')
print("Saved denoised image to denoised_output.png")

Saved denoised image to denoised_output.png
