In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import make_grid, save_image

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from tqdm.notebook import tqdm
import random

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators
np.random.seed(SEED)
random.seed(SEED)
# Seeding alone is not enough on GPU: cuDNN may select non-deterministic
# convolution algorithms (and benchmark mode picks them per run). Pin them so
# repeated runs give the same results, at some cost in speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
# CGNet model implementation
class CGBlock(nn.Module):
    """Two conv/BatchNorm stages with an optional identity shortcut.

    Data flow: conv1 -> bn1 -> ReLU -> conv2 -> bn2, then the block input is
    added back when (and only when) it has exactly the same shape as that
    result, followed by a final ReLU. Both convolutions use a 3x3 kernel with
    padding=1, so spatial size is preserved; in practice the shortcut is
    therefore applied iff ``in_channels == out_channels``, and otherwise the
    block silently runs without a residual connection.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # Creation order (conv1, conv2, bn1, bn2) is kept so parameter
        # initialisation consumes the RNG in the same sequence and the
        # state_dict key order is unchanged.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = x
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Residual only when shapes line up (i.e. channel counts match)
        if hidden.shape == shortcut.shape:
            hidden += shortcut
        return self.relu(hidden)

class CGNet(nn.Module):
    """Residual image denoiser: conv stem, a stack of CGBlocks, conv head.

    The network predicts a correction that is added to its own input (global
    residual), so the output has the same shape as the input. The output is
    not clamped to any range.

    Args:
        in_channels: number of image channels (1 for grayscale, 3 for RGB).
        num_blocks: number of CGBlocks in the body.
        base_channels: width of the internal feature maps.
    """

    def __init__(self, in_channels=1, num_blocks=4, base_channels=64):
        super().__init__()
        # Stem: lift the image to `base_channels` feature maps
        self.init_conv = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
        # Body: channel-preserving CG blocks (created in order, as before)
        self.blocks = nn.ModuleList(
            [CGBlock(base_channels, base_channels, base_channels) for _ in range(num_blocks)]
        )
        # Head: project features back to the image channel count
        self.final_conv = nn.Conv2d(base_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        features = self.init_conv(x)
        for block in self.blocks:
            features = block(features)
        # Global residual: the network learns the correction to the input
        return x + self.final_conv(features)

In [7]:
# Dataset class for noisy and clean image pairs
class DenoisingDataset(Dataset):
    """Wrap an image dataset so it yields ``(noisy, clean, label)`` triples.

    Gaussian noise with standard deviation ``noise_level`` is added to each
    image and the result is clamped to [0, 1]. The clean target is a copy of
    the original image.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)`` pairs.
        noise_level: standard deviation of the additive Gaussian noise.
        seed: optional base seed. With ``None`` (default), fresh noise is drawn
            from the global torch RNG on every access, so the same index gives
            a different noisy image each time (augmentation for training).
            With an int, the noise for index ``idx`` comes from a generator
            seeded with ``seed + idx``. Every access then returns the identical
            noisy image, which keeps a test set fixed across epochs.
    """

    def __init__(self, dataset, noise_level=0.1, seed=None):
        self.dataset = dataset
        self.noise_level = noise_level
        self.seed = seed

    def __len__(self):
        return len(self.dataset)

    def _unit_noise_like(self, img, idx):
        """Return standard-normal noise with the shape/dtype/device of ``img``."""
        if self.seed is None:
            return torch.randn_like(img)
        # Private CPU generator: deterministic per index and does not disturb
        # the global RNG stream used elsewhere.
        generator = torch.Generator().manual_seed(self.seed + idx)
        return torch.randn(img.shape, generator=generator, dtype=img.dtype).to(img.device)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]

        # Clean image (target); cloned so callers cannot mutate the source data
        clean = img.clone()

        # Noisy image (input), kept in the valid [0, 1] pixel range
        noise = self._unit_noise_like(img, idx) * self.noise_level
        noisy = torch.clamp(img + noise, 0, 1)

        return noisy, clean, label

# Function to load dataset
def load_dataset(dataset_name='mnist', batch_size=64, noise_level=0.2,
                 data_root='./data', num_workers=2):
    """Build noisy/clean train and test DataLoaders for MNIST or CIFAR-10.

    Args:
        dataset_name: 'mnist' or 'cifar10' (case-insensitive).
        batch_size: batch size for both loaders.
        noise_level: std of the Gaussian noise added by DenoisingDataset.
        data_root: directory the torchvision datasets are stored in or
            downloaded to.
        num_workers: DataLoader worker processes. Classes defined inside a
            notebook may not be picklable into workers under the 'spawn'
            start method (the default on Windows/macOS); use 0 there.

    Returns:
        (train_loader, test_loader, channels). ``channels`` is 1 for MNIST
        and 3 for CIFAR-10.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    key = dataset_name.lower()
    if key == 'mnist':
        dataset_cls, channels = MNIST, 1
    elif key == 'cifar10':
        dataset_cls, channels = CIFAR10, 3
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Both datasets only need conversion to float tensors in [0, 1]
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    base_train = dataset_cls(data_root, train=True, download=True, transform=transform)
    base_test = dataset_cls(data_root, train=False, download=True, transform=transform)

    # Wrap with noise to produce (noisy, clean, label) triples
    train_dataset = DenoisingDataset(base_train, noise_level=noise_level)
    test_dataset = DenoisingDataset(base_test, noise_level=noise_level)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, test_loader, channels

In [8]:
# Training function
def train(model, train_loader, optimizer, epoch, log_interval=100):
    """Run one training epoch with MSE loss and return the epoch's average loss.

    Each batch's mean-reduced MSE is weighted by its batch size, so a short
    final batch does not skew the result. For equally sized images this
    equals the per-pixel MSE over every training sample seen this epoch.

    Args:
        model: network mapping noisy images to denoised images.
        train_loader: iterable of ``(noisy, clean, label)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for the progress-bar label / log line.
        log_interval: update the progress-bar loss every this many batches.

    Returns:
        The sample-weighted average training loss as a float.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, (noisy, clean, _) in enumerate(pbar):
        noisy, clean = noisy.to(device), clean.to(device)

        optimizer.zero_grad()
        output = model(noisy)

        # MSE Loss (mean over all elements in the batch)
        loss = F.mse_loss(output, clean)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        batch_size = noisy.size(0)
        total_loss += batch_loss * batch_size
        total_samples += batch_size

        if batch_idx % log_interval == 0:
            pbar.set_postfix(loss=f"{batch_loss:.6f}")

    if total_samples == 0:
        raise ValueError("train_loader yielded no batches")
    avg_loss = total_loss / total_samples
    print(f"Epoch {epoch}: Average loss = {avg_loss:.6f}")
    return avg_loss

# Testing function
def test(model, test_loader):
    model.eval()
    test_loss = 0
    
    with torch.no_grad():
        for noisy, clean, _ in test_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            output = model(noisy)
            test_loss += F.mse_loss(output, clean, reduction='sum').item()
    
    test_loss /= len(test_loader.dataset)
    print(f"Test loss: {test_loss:.6f}")
    return test_loss

In [9]:
# Visualization functions
def show_images(noisy_images, denoised_images, clean_images, num_images=5):
    """Plot noisy / denoised / clean images as a 3-row grid.

    Args:
        noisy_images, denoised_images, clean_images: sequences (or batched
            tensors) of ``(C, H, W)`` images. C == 1 is drawn in grayscale,
            anything else as RGB with channels last.
        num_images: number of columns. It is capped at the number of images
            available, so passing fewer images than ``num_images`` is safe.
    """
    num_images = min(num_images, len(noisy_images), len(denoised_images), len(clean_images))
    rows = (
        ("Noisy", noisy_images),
        ("Denoised", denoised_images),
        ("Clean", clean_images),
    )
    # squeeze=False keeps `axes` 2-D even when num_images == 1
    fig, axes = plt.subplots(3, num_images, figsize=(15, 9), squeeze=False)

    for row, (title, images) in enumerate(rows):
        for col in range(num_images):
            # The denoiser's residual output is unbounded; clamp to the valid
            # pixel range so every panel is drawn on the same [0, 1] scale.
            img = images[col].detach().cpu().clamp(0, 1)
            ax = axes[row, col]
            if img.shape[0] == 1:  # Grayscale
                # Fixed vmin/vmax: without them imshow rescales each image to
                # its own min/max, making panels incomparable.
                ax.imshow(img.squeeze(0).numpy(), cmap='gray', vmin=0, vmax=1)
            else:  # RGB
                ax.imshow(img.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

def visualize_results(model, test_loader, num_images=5):
    """Denoise the first ``num_images`` images of the loader's first batch and plot them.

    Args:
        model: network mapping noisy images to denoised images.
        test_loader: iterable of ``(noisy, clean, label)`` batches.
        num_images: how many images to show. Fewer are shown if the first
            batch is smaller.
    """
    model.eval()

    # Take the first batch only
    noisy_images, clean_images, _ = next(iter(test_loader))

    # Select a subset of images
    noisy_subset = noisy_images[:num_images].to(device)
    clean_subset = clean_images[:num_images]

    # Generate denoised images
    with torch.no_grad():
        denoised_subset = model(noisy_subset).cpu()

    # Pass the actual subset size: show_images defaults to 5 columns and would
    # index past the end when fewer images were selected (e.g. num_images=3).
    show_images(noisy_subset.cpu(), denoised_subset, clean_subset,
                num_images=noisy_subset.size(0))

# Function to calculate PSNR
def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio, in dB, between two same-shaped tensors.

    PSNR = 20 * log10(max_val / sqrt(MSE)), where the MSE is averaged over
    every element of the inputs.

    Args:
        img1, img2: tensors of the same shape.
        max_val: peak possible pixel value (1.0 for [0, 1] images, 255 for
            8-bit images).

    Returns:
        A 0-dim tensor. It is ``+inf`` when the inputs are identical (MSE == 0).
    """
    # Integer inputs (e.g. uint8) would wrap on subtraction and torch.mean
    # rejects integer dtypes, so compare in floating point.
    if not torch.is_floating_point(img1):
        img1 = img1.float()
    if not torch.is_floating_point(img2):
        img2 = img2.float()
    mse = torch.mean((img1 - img2) ** 2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

In [10]:
# Set up the experiment
# These names are module-level globals; later cells read them (loss plot,
# checkpoint filename, comparison grid, final summary).
dataset_name = 'mnist'  # 'mnist' or 'cifar10'
batch_size = 64
noise_level = 0.2  # std of the additive Gaussian noise applied by DenoisingDataset
num_epochs = 10
learning_rate = 0.001

# Load dataset; `channels` is 1 for MNIST and 3 for CIFAR-10
train_loader, test_loader, channels = load_dataset(dataset_name, batch_size, noise_level)

# Create model
model = CGNet(in_channels=channels, num_blocks=4, base_channels=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
# One value per completed epoch is appended below. If the loop is interrupted,
# these lists are shorter than num_epochs.
train_losses = []
test_losses = []

print(f"Starting training on {dataset_name} dataset with noise level {noise_level}")

for epoch in range(1, num_epochs + 1):
    train_loss = train(model, train_loader, optimizer, epoch)
    test_loss = test(model, test_loader)
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    
    # Visualize intermediate results every few epochs
    # (every even epoch, plus the final epoch)
    if epoch % 2 == 0 or epoch == num_epochs:
        visualize_results(model, test_loader)

Starting training on mnist dataset with noise level 0.2


Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1: Average loss = 0.012433
Test loss: 2.532931


Epoch 2:   0%|          | 0/938 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Plot training and test loss curves.
# The x-axis is derived from the number of recorded losses rather than
# num_epochs, so the cell still works after an interrupted training run
# (the lists are then shorter than num_epochs).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(test_losses) + 1), test_losses, label='Test Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('Training and Test Loss')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Evaluate the model on test set with PSNR metric
def evaluate_psnr(model, test_loader, max_val=1.0):
    """Average per-image PSNR (dB) of ``model``'s outputs against the clean targets.

    PSNR is computed per image as 20 * log10(max_val / sqrt(MSE)), using the
    same formula as ``calculate_psnr``, and then averaged over all images.

    Args:
        model: network mapping noisy images to denoised images.
        test_loader: iterable of ``(noisy, clean, label)`` batches.
        max_val: peak possible pixel value (1.0 for [0, 1] images).

    Returns:
        The mean PSNR as a float. It is ``inf`` if any image is reproduced
        exactly.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    psnr_values = []

    with torch.no_grad():
        for noisy, clean, _ in test_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            denoised = model(noisy)

            # Per-image MSE over all non-batch dimensions, computed for the
            # whole batch at once. This avoids a Python loop with one
            # .item() device sync per image.
            per_image_mse = ((denoised - clean) ** 2).flatten(start_dim=1).mean(dim=1)
            batch_psnr = 20 * torch.log10(max_val / torch.sqrt(per_image_mse))
            psnr_values.extend(batch_psnr.tolist())

    if not psnr_values:
        raise ValueError("test_loader yielded no samples")
    avg_psnr = sum(psnr_values) / len(psnr_values)
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    return avg_psnr

# Evaluate the model
# The result is kept in the global `avg_psnr`, which the final summary cell prints.
avg_psnr = evaluate_psnr(model, test_loader)

In [None]:
# Save the trained model
# Only the state_dict (parameters and buffers) is written, to the working
# directory. To reload, rebuild CGNet with the same constructor arguments and
# call load_state_dict.
torch.save(model.state_dict(), f'cgnet_{dataset_name}_noise{noise_level}.pth')

# Generate more visualization examples
def generate_comparison_grid(model, test_loader, num_samples=8):
    """Show and save a 3-row grid: noisy (top), denoised (middle), clean (bottom).

    Enough leading batches are gathered from ``test_loader`` to hold
    ``num_samples`` images, a random subset is chosen, and the result is
    displayed and saved to ``comparison_grid_<dataset_name>.png``.

    Args:
        model: network mapping noisy images to denoised images.
        test_loader: iterable of ``(noisy, clean, label)`` batches.
        num_samples: number of columns. Fewer are shown if the loader holds
            fewer images.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()

    # Collect just enough batches. Count the samples actually gathered rather
    # than assuming the global `batch_size`, which need not match this loader.
    noisy_batches, clean_batches = [], []
    collected = 0
    for noisy, clean, _ in test_loader:
        noisy_batches.append(noisy)
        clean_batches.append(clean)
        collected += noisy.size(0)
        if collected >= num_samples:
            break
    if not noisy_batches:
        raise ValueError("test_loader yielded no batches")

    all_noisy = torch.cat(noisy_batches)
    all_clean = torch.cat(clean_batches)

    # Select random samples
    indices = torch.randperm(len(all_noisy))[:num_samples]
    selected_noisy = all_noisy[indices]
    selected_clean = all_clean[indices]
    n_shown = len(indices)

    # Generate denoised outputs
    with torch.no_grad():
        denoised = model(selected_noisy.to(device)).cpu()

    # Clamp to the valid [0, 1] range instead of normalize=True. Per-grid
    # min-max normalisation would put the three rows on different intensity
    # scales and hide the denoiser's over/undershoot.
    grids = [
        make_grid(images.clamp(0, 1), nrow=n_shown, padding=2)
        for images in (selected_noisy, denoised, selected_clean)
    ]

    # Stack the (C, H, W) grids vertically along the height axis
    final_grid = torch.cat(grids, dim=1)

    # make_grid expands single-channel input to 3 channels, so the grid can
    # always be shown as an RGB image.
    plt.figure(figsize=(15, 8))
    plt.imshow(final_grid.permute(1, 2, 0).numpy())
    plt.title("Comparison: Noisy (top) → Denoised (middle) → Clean (bottom)")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    # Save the comparison grid
    save_image(final_grid, f'comparison_grid_{dataset_name}.png')

# Generate and display the comparison grid
# (this also writes comparison_grid_<dataset_name>.png to the working directory)
generate_comparison_grid(model, test_loader)

In [None]:
# Try with different noise levels
def test_different_noise_levels(model, dataset_name='mnist', noise_levels=(0.1, 0.2, 0.3, 0.5)):
    """Evaluate ``model`` on the test split at several noise levels and plot PSNR.

    For each level, a fresh test loader is built (batch size 16). The average
    PSNR is computed, a few examples are shown, and then PSNR is plotted
    against noise level.

    Args:
        model: trained denoiser.
        dataset_name: dataset passed to ``load_dataset``.
        noise_levels: iterable of noise standard deviations to test. A tuple
            default avoids the shared mutable-default pitfall.

    Returns:
        A list of ``(noise_level, avg_psnr)`` tuples, in input order.
    """
    results = []

    for level in noise_levels:
        print(f"\nTesting noise level: {level}")
        _, level_loader, _ = load_dataset(dataset_name, batch_size=16, noise_level=level)

        # Evaluate PSNR
        psnr = evaluate_psnr(model, level_loader)
        results.append((level, psnr))

        # Visualize some examples
        visualize_results(model, level_loader, num_images=3)

    if not results:
        # Nothing to plot; zip(*[]) cannot be unpacked into two sequences
        return results

    # Plot PSNR vs Noise Level (separate names so the parameter is not rebound)
    levels, psnrs = zip(*results)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(levels, psnrs, 'o-')
    ax.set_xlabel('Noise Level')
    ax.set_ylabel('PSNR (dB)')
    ax.set_title('Denoising Performance vs Noise Level')
    ax.grid(True)
    plt.show()

    return results

# Test the model with different noise levels
# NOTE: each level builds its own test loader via load_dataset, which calls
# the torchvision constructors with download=True.
test_different_noise_levels(model)

In [None]:
# Load and try with a custom image
def denoise_custom_image(model, image_path, noise_level=0.2):
    """Add noise to an image file, denoise it with ``model``, and plot the result.

    The image is converted to grayscale or RGB to match the global
    ``channels``. It is resized to 28x28 for MNIST and 32x32 otherwise,
    mirroring the training data.

    Args:
        model: trained denoiser.
        image_path: path to an image file readable by PIL.
        noise_level: std of the Gaussian noise added before denoising.

    Returns:
        ``(noisy, denoised, clean)`` tensors, each with a leading batch
        dimension of 1. ``denoised`` is the raw, unclamped model output.
    """
    mode = 'L' if channels == 1 else 'RGB'
    # The context manager closes the underlying file handle. convert() returns
    # a new in-memory image, which stays valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert(mode)

    # Prepare transforms
    size = (28, 28) if dataset_name == 'mnist' else (32, 32)
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor()
    ])

    # Transform the image and add a batch dimension
    clean_tensor = transform(img).unsqueeze(0)

    # Add noise, keeping pixels in [0, 1]
    noise = torch.randn_like(clean_tensor) * noise_level
    noisy_tensor = torch.clamp(clean_tensor + noise, 0, 1)

    # Denoise
    model.eval()
    with torch.no_grad():
        denoised_tensor = model(noisy_tensor.to(device)).cpu()

    # Display the results. The model output is unbounded, so clamp it for
    # display, and pin the grayscale range so the panels share one scale.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ('Noisy Image', noisy_tensor),
        ('Denoised Image', denoised_tensor.clamp(0, 1)),
        ('Original Image', clean_tensor),
    )
    for ax, (title, tensor) in zip(axes, panels):
        image = tensor[0]  # drop the batch dimension -> (C, H, W)
        if channels == 1:
            ax.imshow(image[0].numpy(), cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return noisy_tensor, denoised_tensor, clean_tensor

# You can uncomment and run this if you have a custom image
# denoise_custom_image(model, 'path/to/your/image.jpg', noise_level=0.2)

In [None]:
# Conclusion and final model summary
print("Model Summary:")
print(model)

print("\nTraining Summary:")
print(f"Dataset: {dataset_name}")
print(f"Noise Level: {noise_level}")
print(f"Epochs: {num_epochs}")
print(f"Epochs completed: {len(train_losses)}")
# An interrupted training run can leave the loss lists empty. Indexing [-1]
# would then raise IndexError and hide the rest of the summary.
if train_losses:
    print(f"Final Train Loss: {train_losses[-1]:.6f}")
else:
    print("Final Train Loss: n/a (no completed epochs)")
if test_losses:
    print(f"Final Test Loss: {test_losses[-1]:.6f}")
else:
    print("Final Test Loss: n/a (no completed epochs)")
print(f"Average PSNR: {avg_psnr:.2f} dB")