In [1]:
# clear_vision_complete.py

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.autograd import grad
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import random
from PIL import Image
import cv2
from tqdm import tqdm
import time
from io import BytesIO
import scipy
from scipy import linalg
import pickle
import json

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#----------------------------------------------------------------------------------
# Dataset and Data Loading Functions
#----------------------------------------------------------------------------------

# Define transformations - using smaller images for CPU processing
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Custom dataset class for CelebA
class CelebADataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [os.path.join(root_dir, f) for f in os.listdir(root_dir)
                          if f.endswith('.jpg') or f.endswith('.png')]
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image

# Load the dataset
def load_dataset(root_dir='/kaggle/input/celebahq-resized-256x256/celeba_hq_256', max_images=None):
    try:
        dataset = CelebADataset(root_dir, transform=transform)
        
        # Limit dataset if specified
        if max_images and max_images < len(dataset):
            indices = list(range(len(dataset)))
            random.shuffle(indices)
            indices = indices[:max_images]
            dataset = torch.utils.data.Subset(dataset, indices)
        
        # Split dataset into train and validation sets
        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        
        # Create data loaders - small batch size for CPU
        batch_size = 16
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        
        print(f"Dataset loaded: {len(dataset)} images")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")
        
        return train_loader, val_loader
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the CelebA dataset is correctly placed in the specified directory")
        return None, None

#----------------------------------------------------------------------------------
# Image Corruption Functions
#----------------------------------------------------------------------------------

# 1. Mask-based corruption (as in the original repo)
def apply_mask(images, mask_size=32, offset_x=None, offset_y=None):
    """Apply a square mask to the image at a specified or random position."""
    batch_size, channels, height, width = images.shape
    corrupted_images = images.clone()
    
    # For each image in the batch
    for i in range(batch_size):
        # Determine mask position (random if not specified)
        if offset_x is None:
            offset_x = random.randint(0, width - mask_size)
        if offset_y is None:
            offset_y = random.randint(0, height - mask_size)
        
        # Apply mask (set to 0)
        corrupted_images[i, :, offset_y:offset_y+mask_size, offset_x:offset_x+mask_size] = 0.0
    
    return corrupted_images

# 2. Gaussian noise
def apply_gaussian_noise(images, mean=0.0, std=0.1):
    """Add Gaussian noise to images."""
    noise = torch.randn_like(images) * std + mean
    corrupted_images = images + noise
    # Clip values to be in [-1, 1] range
    corrupted_images = torch.clamp(corrupted_images, -1, 1)
    return corrupted_images

# 3. Salt and pepper noise
def apply_salt_pepper_noise(images, salt_prob=0.02, pepper_prob=0.02):
    """Add salt and pepper noise to images."""
    corrupted_images = images.clone()
    batch_size, channels, height, width = images.shape
    
    # For each image in the batch
    for i in range(batch_size):
        # Generate salt noise (white pixels)
        salt_mask = torch.rand(channels, height, width, device=images.device) < salt_prob
        corrupted_images[i][salt_mask] = 1.0
        
        # Generate pepper noise (black pixels)
        pepper_mask = torch.rand(channels, height, width, device=images.device) < pepper_prob
        corrupted_images[i][pepper_mask] = -1.0
    
    return corrupted_images

# 4. Blur effects
def apply_gaussian_blur(images, kernel_size=7, sigma=1.5):
    """Apply Gaussian blur to images."""
    batch_size = images.shape[0]
    corrupted_images = images.clone()
    
    # Apply Gaussian blur to each image in the batch
    for i in range(batch_size):
        # Make sure kernel size is odd
        if kernel_size % 2 == 0:
            kernel_size += 1
        
        for c in range(3): # Process each channel
            img = images[i, c].cpu().numpy()
            blurred = cv2.GaussianBlur(img, (kernel_size, kernel_size), sigma)
            corrupted_images[i, c] = torch.from_numpy(blurred).to(images.device)
    
    return corrupted_images

# 5. JPEG compression artifacts
def apply_jpeg_artifacts(images, quality=10):
    """Apply JPEG compression artifacts to images."""
    batch_size = images.shape[0]
    corrupted_images = torch.zeros_like(images)
    
    # Apply JPEG compression to each image in the batch
    for i in range(batch_size):
        # Convert to PIL Image
        img = transforms.ToPILImage()((images[i] + 1) / 2) # Denormalize
        
        # Save as JPEG and load back to apply compression artifacts
        buffer = BytesIO()
        img.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        compressed_img = Image.open(buffer)
        
        # Convert back to tensor and normalize
        compressed_tensor = transforms.ToTensor()(compressed_img) * 2 - 1
        corrupted_images[i] = compressed_tensor
    
    return corrupted_images

# 6. Combined corruption (apply multiple types randomly)
def apply_combined_corruption(images):
    """Apply a random combination of corruption types."""
    corrupted_images = images.clone()
    
    # List of corruption functions
    corruptions = [
        lambda x: apply_mask(x, mask_size=random.randint(20, 40)),
        lambda x: apply_gaussian_noise(x, std=random.uniform(0.05, 0.2)),
        lambda x: apply_salt_pepper_noise(x, salt_prob=random.uniform(0.01, 0.05),
                                       pepper_prob=random.uniform(0.01, 0.05)),
        lambda x: apply_gaussian_blur(x, kernel_size=random.choice([3, 5, 7]),
                                    sigma=random.uniform(0.5, 2.0))
    ]
    
    # Apply 1-3 random corruptions
    num_corruptions = random.randint(1, 3)
    selected_corruptions = random.sample(corruptions, num_corruptions)
    
    for corruption_fn in selected_corruptions:
        corrupted_images = corruption_fn(corrupted_images)
    
    return corrupted_images

#----------------------------------------------------------------------------------
# Model Architecture
#----------------------------------------------------------------------------------

class ResidualBlock(nn.Module):
    """Residual block for better gradient flow and feature learning."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Skip connection if dimensions change
        self.skip = nn.Sequential()
        if in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        residual = x
        out = F.silu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.skip(residual)
        out = F.silu(out)
        return out

class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            
            ResidualBlock(32, 32),
            
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            
            ResidualBlock(64, 64),
            
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            
            ResidualBlock(128, 128),
            
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.SiLU(),
            
            ResidualBlock(256, 256)
        )
        
        # For 128x128 input, feature map size is now 8x8
        self.encoder_output_size = 8
        encoder_flatten_size = 256 * self.encoder_output_size * self.encoder_output_size
        
        # Latent space mapping
        self.fc_mu = nn.Linear(encoder_flatten_size, latent_dim)
        self.fc_logvar = nn.Linear(encoder_flatten_size, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, encoder_flatten_size)
        
        # Decoder
        self.decoder = nn.Sequential(
            ResidualBlock(256, 256),
            
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            
            ResidualBlock(128, 128),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            
            ResidualBlock(64, 64),
            
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            
            ResidualBlock(32, 32),
            
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1) # Flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        x = self.fc_decoder(z)
        x = x.view(x.size(0), 256, self.encoder_output_size, self.encoder_output_size)
        x = self.decoder(x)
        return x
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

#----------------------------------------------------------------------------------
# Loss Function and Training Utilities
#----------------------------------------------------------------------------------

# Enhanced VAE loss function
def vae_loss(reconstructed, original, mu, logvar, kld_weight=0.001):
    """
    VAE loss with balanced reconstruction and KL divergence terms.
    Lower KLD weight for better reconstructions.
    """
    # Reconstruction loss (MSE works better for image quality)
    recon_loss = F.mse_loss(reconstructed, original, reduction='sum') / original.size(0)
    
    # KL Divergence
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / original.size(0)
    
    # Total loss
    loss = recon_loss + kld_weight * kld_loss
    
    return loss, recon_loss, kld_loss

# Function to save sample images during training
def save_sample_images(model, data_loader, epoch, output_dir='/kaggle/working/samples'):
    """Save sample reconstructions during training."""
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    
    # Get a batch of images
    batch = next(iter(data_loader))
    if isinstance(batch, list):
        batch = batch[0]  # For ImageFolder dataset
    original_images = batch.to(device)[:8]  # Use 8 images
    
    # Apply different corruptions
    mask_images = apply_mask(original_images.clone())
    gaussian_images = apply_gaussian_noise(original_images.clone())
    salt_pepper_images = apply_salt_pepper_noise(original_images.clone())
    blur_images = apply_gaussian_blur(original_images.clone())
    jpeg_images = apply_jpeg_artifacts(original_images.clone())
    combined_images = apply_combined_corruption(original_images.clone())
    
    # Get reconstructions
    with torch.no_grad():
        mask_reconstructed, _, _ = model(mask_images)
        gaussian_reconstructed, _, _ = model(gaussian_images)
        salt_pepper_reconstructed, _, _ = model(salt_pepper_images)
        blur_reconstructed, _, _ = model(blur_images)
        jpeg_reconstructed, _, _ = model(jpeg_images)
        combined_reconstructed, _, _ = model(combined_images)
    
    # Create comparison grid: original, corrupted, reconstructed
    rows = [
        torch.cat([original_images, mask_images, mask_reconstructed], dim=0),
        torch.cat([original_images, gaussian_images, gaussian_reconstructed], dim=0),
        torch.cat([original_images, salt_pepper_images, salt_pepper_reconstructed], dim=0),
        torch.cat([original_images, blur_images, blur_reconstructed], dim=0),
        torch.cat([original_images, jpeg_images, jpeg_reconstructed], dim=0),
        torch.cat([original_images, combined_images, combined_reconstructed], dim=0)
    ]
    
    # Save grid
    comparison = torch.cat(rows, dim=0)
    save_image(comparison.cpu() * 0.5 + 0.5, f"{output_dir}/epoch_{epoch}.png",
              nrow=8, padding=2, normalize=False)
    
    print(f"Sample images saved for epoch {epoch}")
    model.train()

#----------------------------------------------------------------------------------
# New LPIPS Implementation
#----------------------------------------------------------------------------------

class LPIPSMetric:
    def __init__(self, device='cuda'):
        self.device = device
        # Try to load LPIPS package
        try:
            import lpips
            self.lpips_model = lpips.LPIPS(net='alex').to(device)
            self.has_lpips = True
            print("LPIPS module loaded successfully")
        except ImportError:
            print("Warning: LPIPS package not found. Using AlexNet features as approximation.")
            print("For better results, install LPIPS with: pip install lpips")
            # Use AlexNet features as approximation for LPIPS
            self.has_lpips = False
            self.alexnet = models.alexnet(pretrained=True).features[:9].eval().to(device)
            for param in self.alexnet.parameters():
                param.requires_grad = False
    
    def calculate_lpips(self, img1, img2):
        """Calculate LPIPS between two images"""
        # Ensure images have the same dimensions
        if img1.shape[2:] != img2.shape[2:]:
            img1 = F.interpolate(img1, size=img2.shape[2:], mode='bilinear', align_corners=False)
        
        if self.has_lpips:
            # Use official LPIPS if available
            with torch.no_grad():
                lpips_dist = self.lpips_model(img1, img2)
                return lpips_dist.mean().item()
        else:
            # Use AlexNet feature distance as approximation
            with torch.no_grad():
                # Make sure images are in range [-1, 1]
                if img1.min() < -1 or img1.max() > 1:
                    img1 = torch.clamp(img1, -1, 1)
                if img2.min() < -1 or img2.max() > 1:
                    img2 = torch.clamp(img2, -1, 1)
                
                # Convert from [-1,1] to [0,1] for AlexNet
                img1 = (img1 + 1) / 2
                img2 = (img2 + 1) / 2
                
                # Extract features
                feat1 = self.alexnet(img1)
                feat2 = self.alexnet(img2)
                
                # Calculate distance
                lpips_dist = torch.mean((feat1 - feat2).pow(2))
                return lpips_dist.item()

#----------------------------------------------------------------------------------
# New FID Score Implementation
#----------------------------------------------------------------------------------

class FIDMetric:
    def __init__(self, device='cuda'):
        self.device = device
        
        # Load Inception v3 model for FID computation
        try:
            self.inception_model = models.inception_v3(pretrained=True, transform_input=False)
            self.inception_model.fc = nn.Identity()  # Remove classification layer
            self.inception_model.eval().to(device)
            for param in self.inception_model.parameters():
                param.requires_grad = False
            self.has_inception = True
            print("Inception model loaded for FID calculation")
        except:
            print("Warning: Could not load Inception model. FID scores will not be accurate.")
            try:
                # Try with weights parameter instead
                self.inception_model = models.inception_v3(weights='DEFAULT', transform_input=False)
                self.inception_model.fc = nn.Identity()  # Remove classification layer
                self.inception_model.eval().to(device)
                for param in self.inception_model.parameters():
                    param.requires_grad = False
                self.has_inception = True
                print("Inception model loaded with updated API")
            except:
                self.has_inception = False
    
    def get_activations(self, images, batch_size=50):
        """Get Inception activations for a batch of images"""
        n_batches = len(images) // batch_size + 1
        act = np.empty((len(images), 2048))
        
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, len(images))
            
            if start_idx >= end_idx:
                break
                
            batch = images[start_idx:end_idx].to(self.device)
            
            # Resize to inception input size
            if batch.shape[2] != 299 or batch.shape[3] != 299:
                batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            
            # Normalize from [-1, 1] to [0, 1] to [0, 255]
            if batch.min() < 0:
                batch = (batch + 1) / 2  # [-1,1] -> [0,1]
            
            # Forward pass
            with torch.no_grad():
                pred = self.inception_model(batch)
            
            # Store activations
            act[start_idx:end_idx] = pred.cpu().numpy()
        
        return act
    
    def calculate_activation_statistics(self, images):
        """Calculate statistics of activations for FID"""
        activations = self.get_activations(images)
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma
    
    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2):
        """Calculate Fréchet distance between two multivariate Gaussians"""
        diff = mu1 - mu2
        
        # Product might be almost singular
        covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            offset = np.eye(sigma1.shape[0]) * 1e-6
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        
        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        
        tr_covmean = np.trace(covmean)
        
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
    
    def calculate_fid(self, real_images, generated_images):
        """Calculate FID between real and generated images"""
        if not self.has_inception:
            return -1  # Cannot calculate FID without inception model
        
        # Get statistics for real and generated images
        mu_real, sigma_real = self.calculate_activation_statistics(real_images)
        mu_gen, sigma_gen = self.calculate_activation_statistics(generated_images)
        
        # Calculate FID
        fid_score = self.calculate_frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen)
        return fid_score

#----------------------------------------------------------------------------------
# Evaluation Metrics
#----------------------------------------------------------------------------------

# Basic evaluation metrics
def compute_psnr(img1, img2):
    """Compute Peak Signal-to-Noise Ratio between image tensors."""
    # Convert to range [0, 1] if in [-1, 1]
    if img1.min() < 0:
        img1 = img1 * 0.5 + 0.5
    if img2.min() < 0:
        img2 = img2 * 0.5 + 0.5
    
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0
    psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
    return psnr.item()

def compute_ssim(img1, img2):
    """Compute Structural Similarity Index between image tensors."""
    # Convert to range [0, 1] if in [-1, 1]
    if img1.min() < 0:
        img1 = img1 * 0.5 + 0.5
    if img2.min() < 0:
        img2 = img2 * 0.5 + 0.5
    
    # Constants
    C1 = (0.01 * 1) ** 2
    C2 = (0.03 * 1) ** 2
    
    # Calculate mean, variance, covariance
    mu1 = torch.mean(img1, dim=(2, 3), keepdim=True)
    mu2 = torch.mean(img2, dim=(2, 3), keepdim=True)
    
    sigma1_sq = torch.mean((img1 - mu1) ** 2, dim=(2, 3), keepdim=True)
    sigma2_sq = torch.mean((img2 - mu2) ** 2, dim=(2, 3), keepdim=True)
    sigma12 = torch.mean((img1 - mu1) * (img2 - mu2), dim=(2, 3), keepdim=True)
    
    # SSIM formula
    numerator = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim = torch.mean(numerator / denominator)
    
    return ssim.item()

# Comprehensive evaluation with all metrics
def evaluate_model(model, data_loader, corruption_types=['mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined']):
    """Evaluate model on multiple corruption types with PSNR, SSIM, LPIPS and FID."""
    model.eval()
    results = {}
    
    # Initialize metrics
    lpips_metric = LPIPSMetric(device)
    fid_metric = FIDMetric(device)
    
    # Get a batch of images
    batch = next(iter(data_loader))
    if isinstance(batch, list):
        batch = batch[0]
    original_images = batch.to(device)[:32]  # Use 32 images for evaluation
    
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    # Create progress bar for corruption types evaluation
    corruption_bar = tqdm(corruption_types, desc="Evaluating corruption types")
    
    for corruption_type in corruption_bar:
        corruption_bar.set_description(f"Evaluating {corruption_type}")
        
        if corruption_type not in corruption_functions:
            print(f"Unknown corruption type: {corruption_type}")
            continue
        
        # Apply corruption
        corruption_fn = corruption_functions[corruption_type]
        corrupted_images = corruption_fn(original_images.clone())
        
        # Get reconstructions
        with torch.no_grad():
            start_time = time.time()
            reconstructed_images, _, _ = model(corrupted_images)
            inference_time = time.time() - start_time
        
        # Calculate metrics - silently, without progress bar
        avg_psnr = 0
        avg_ssim = 0
        lpips_scores = []
        
        # Use regular loop instead of tqdm progress bar to avoid per-percentage updates
        for i in range(len(original_images)):
            psnr = compute_psnr(original_images[i:i+1], reconstructed_images[i:i+1])
            ssim = compute_ssim(original_images[i:i+1], reconstructed_images[i:i+1])
            lpips = lpips_metric.calculate_lpips(original_images[i:i+1], reconstructed_images[i:i+1])
            
            avg_psnr += psnr
            avg_ssim += ssim
            lpips_scores.append(lpips)
        
        avg_psnr /= len(original_images)
        avg_ssim /= len(original_images)
        avg_lpips = sum(lpips_scores) / len(lpips_scores)
        
        # Calculate FID (this is computationally expensive)
        try:
            fid_score = fid_metric.calculate_fid(original_images, reconstructed_images)
        except:
            fid_score = -1
            print(f"Warning: Failed to calculate FID score for {corruption_type}")
        
        results[corruption_type] = {
            'psnr': avg_psnr,
            'ssim': avg_ssim,
            'lpips': avg_lpips,
            'fid': fid_score,
            'inference_time_ms': (inference_time / len(original_images)) * 1000
        }
        
        corruption_bar.set_postfix(psnr=f"{avg_psnr:.2f}dB", ssim=f"{avg_ssim:.4f}", lpips=f"{avg_lpips:.4f}")
    
    model.train()
    return results


#----------------------------------------------------------------------------------
# Training Function
#----------------------------------------------------------------------------------
def train_vae(model, train_loader, val_loader, num_epochs=100, lr=1e-4, kld_weight=0.001, 
              save_dir='/kaggle/working/checkpoints'):
    """Train the VAE model with multiple corruption types."""
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs('/kaggle/working/samples', exist_ok=True)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Track learning rate changes manually
    previous_lr = optimizer.param_groups[0]['lr']
    
    # Lists to store loss values
    train_losses = []
    val_losses = []
    
    # Use different corruption types during training
    corruption_types = ['mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined']
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    print("Starting training...")
    
    # Create overall progress bar for epochs
    try:
        epoch_bar = tqdm(range(num_epochs), desc="Training Progress", position=0)
        
        for epoch in epoch_bar:
            model.train()
            epoch_train_loss = 0
            epoch_recon_loss = 0
            epoch_kld_loss = 0
            batch_count = 0
            
            # Training loop - without detailed per-batch output
            for batch in train_loader:
                # Move batch to device
                if isinstance(batch, list):
                    batch = batch[0]
                original = batch.to(device)
                
                # Apply random corruption
                corruption_type = random.choice(corruption_types)
                corrupted = corruption_functions[corruption_type](original.clone())
                
                # Forward pass
                reconstructed, mu, logvar = model(corrupted)
                
                # Compute loss
                loss, recon_loss, kld_loss = vae_loss(reconstructed, original, mu, logvar, kld_weight)
                
                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Accumulate losses without printing
                epoch_train_loss += loss.item()
                epoch_recon_loss += recon_loss.item()
                epoch_kld_loss += kld_loss.item()
                batch_count += 1
            
            # Average training loss for this epoch
            avg_train_loss = epoch_train_loss / batch_count
            avg_recon_loss = epoch_recon_loss / batch_count
            avg_kld_loss = epoch_kld_loss / batch_count
            train_losses.append(avg_train_loss)
            
            # Validation - without detailed per-batch output
            model.eval()
            epoch_val_loss = 0
            val_batch_count = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    if isinstance(batch, list):
                        batch = batch[0]
                    original = batch.to(device)
                    
                    # Apply random corruption for validation
                    corruption_type = random.choice(corruption_types)
                    corrupted = corruption_functions[corruption_type](original.clone())
                    
                    # Forward pass
                    reconstructed, mu, logvar = model(corrupted)
                    
                    # Compute loss
                    loss, _, _ = vae_loss(reconstructed, original, mu, logvar, kld_weight)
                    epoch_val_loss += loss.item()
                    val_batch_count += 1
            
            # Average validation loss for this epoch
            avg_val_loss = epoch_val_loss / val_batch_count
            val_losses.append(avg_val_loss)
            
            # Update the main epoch progress bar
            epoch_bar.set_postfix(train_loss=f"{avg_train_loss:.4f}", val_loss=f"{avg_val_loss:.4f}")
            
            # Print epoch summary - ONLY PRINT ONCE PER EPOCH
            print(f"\nEpoch {epoch+1}/{num_epochs}:")
            print(f" Train Loss: {avg_train_loss:.6f} (Recon: {avg_recon_loss:.6f}, KLD: {avg_kld_loss:.6f})")
            print(f" Validation Loss: {avg_val_loss:.6f}")
            print(f" Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
            
            # Save model checkpoint
            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'latent_dim': model.latent_dim
                }, f"{save_dir}/vae_epoch_{epoch+1}.pth")
                print(f"Model checkpoint saved to {save_dir}/vae_epoch_{epoch+1}.pth")
            
            # Generate and save sample images
            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                save_sample_images(model, val_loader, epoch + 1)
            
            # Evaluate model on validation set (every 5 epochs or at the end)
            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                print(f"\nRunning evaluation at epoch {epoch+1}...")
                eval_results = evaluate_model(model, val_loader)
                print("\nValidation Metrics:")
                for corruption_type, metrics in eval_results.items():
                    print(f" {corruption_type}: PSNR = {metrics['psnr']:.2f}dB, SSIM = {metrics['ssim']:.4f}, " +
                        f"LPIPS = {metrics['lpips']:.4f}, FID = {metrics['fid']:.2f}, " +
                        f"Inference Time = {metrics['inference_time_ms']:.2f}ms")
                
                # Save metrics to file for later analysis
                metrics_file = f"{save_dir}/metrics_epoch_{epoch+1}.json"
                with open(metrics_file, 'w') as f:
                    json.dump(eval_results, f, indent=2)
            
            # Update learning rate
            scheduler.step(avg_val_loss)
            
            # Check for learning rate changes and log them
            current_lr = optimizer.param_groups[0]['lr']
            if current_lr != previous_lr:
                print(f"\nLearning rate adjusted from {previous_lr:.6f} to {current_lr:.6f}")
                previous_lr = current_lr
            
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
    except Exception as e:
        print(f"\nError during training: {e}")
    
    # Plot training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('/kaggle/working/loss_curve.png')
    plt.close()
    
    print("Training completed!")
    
    # Save the final model in Kaggle output directory
    final_model_path = '/kaggle/working/vae_final_model.pth'
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'latent_dim': model.latent_dim
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")
    
    return train_losses, val_losses
 
            
         
    
   
    
    

#----------------------------------------------------------------------------------
# Image Restoration and Visualization
#----------------------------------------------------------------------------------

def restore_image(model, image_path, corruption_type='combined', save_path=None):
    """Restore a single image using the trained VAE model."""
    if isinstance(model, dict):
        # If model is a state dict, create a new model instance
        latent_dim = model.get('latent_dim', 128)
        model_instance = VAE(latent_dim=latent_dim).to(device)
        model_instance.load_state_dict(model['model_state_dict'])
        model = model_instance
        
    model.eval()
    
    # Load and preprocess image
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    image = Image.open(image_path).convert('RGB')
    original = transform(image).unsqueeze(0).to(device)
    
    # Apply corruption
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    if corruption_type not in corruption_functions:
        print(f"Unknown corruption type: {corruption_type}. Using 'combined' instead.")
        corruption_type = 'combined'
    
    corrupted = corruption_functions[corruption_type](original.clone())
    
    # Restore image
    with torch.no_grad():
        restored, _, _ = model(corrupted)
    
    # Convert tensors to images for visualization
    # Convert from [-1, 1] range to [0, 1]
    original_img = original.cpu() * 0.5 + 0.5
    corrupted_img = corrupted.cpu() * 0.5 + 0.5
    restored_img = restored.cpu() * 0.5 + 0.5
    
    # Create comparison grid
    comparison = torch.cat([original_img, corrupted_img, restored_img], dim=0)
    
    # Save or display results
    if save_path:
        save_image(comparison, save_path, nrow=1, padding=5)
        print(f"Restored image saved to {save_path}")
    
    # Calculate metrics
    psnr = compute_psnr(original, restored)
    ssim = compute_ssim(original, restored)
    
    # Initialize LPIPS
    lpips_metric = LPIPSMetric(device)
    lpips_val = lpips_metric.calculate_lpips(original, restored)
    
    print(f"Restoration Metrics:")
    print(f" PSNR: {psnr:.2f}dB")
    print(f" SSIM: {ssim:.4f}")
    print(f" LPIPS: {lpips_val:.4f}")
    
    # For displaying in notebook
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(original_img.squeeze().permute(1, 2, 0).numpy())
    plt.title("Original")
    plt.axis('off')
    
    plt.subplot(1, 3, 2)
    plt.imshow(corrupted_img.squeeze().permute(1, 2, 0).numpy())
    plt.title(f"Corrupted ({corruption_type})")
    plt.axis('off')
    
    plt.subplot(1, 3, 3)
    plt.imshow(restored_img.squeeze().permute(1, 2, 0).numpy())
    plt.title(f"Restored (PSNR: {psnr:.2f}dB)")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return original, corrupted, restored, {'psnr': psnr, 'ssim': ssim, 'lpips': lpips_val}

#----------------------------------------------------------------------------------
# Loading Pretrained Model
#----------------------------------------------------------------------------------

def load_pretrained_model(checkpoint_path):
    """Load a pretrained VAE model."""
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        return None
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Get latent dimension from checkpoint or use default
    latent_dim = checkpoint.get('latent_dim', 128)
    
    # Initialize model
    model = VAE(latent_dim=latent_dim).to(device)
    
    # Load model state
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Assume it is a direct state dict
        model.load_state_dict(checkpoint)
    
    print(f"Loaded pretrained model from {checkpoint_path}")
    if 'epoch' in checkpoint:
        print(f"Trained for {checkpoint['epoch']} epochs")
    
    return model

#----------------------------------------------------------------------------------
# Main Pipeline
#------------------------------------------------------------------------------


Using device: cuda


In [2]:
def main(max_images=30000, num_epochs=100):
    """Run the complete pipeline: load data, train model, evaluate and test."""
    print("Starting ClearVision Pipeline...")
    
    # Set paths for Kaggle
    data_dir = '/kaggle/input/celebahq-resized-256x256/celeba_hq_256'
    output_dir = '/kaggle/working'
    
    # Create output directories
    os.makedirs(f'{output_dir}/checkpoints', exist_ok=True)
    os.makedirs(f'{output_dir}/samples', exist_ok=True)
    os.makedirs(f'{output_dir}/metrics', exist_ok=True)
    
    # Create a progress bar for the main pipeline steps
    pipeline_steps = ["Loading datasets", "Model initialization", "Training", "Evaluation", "Saving model"]
    pipeline_bar = tqdm(pipeline_steps, desc="ClearVision Pipeline")
    
    # Step 1: Load datasets with limit
    pipeline_bar.set_description("Loading datasets...")
    train_loader, val_loader = load_dataset(data_dir, max_images=max_images)
    
    if train_loader is None or val_loader is None:
        print("Failed to load datasets. Please check the dataset paths.")
        return
    pipeline_bar.update(1)
    
    # Step 2: Initialize model
    pipeline_bar.set_description("Initializing VAE model...")
    latent_dim = 128
    model = VAE(latent_dim=latent_dim).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    pipeline_bar.update(1)
    
    # Step 3: Train model
    pipeline_bar.set_description("Training model...")
    train_losses, val_losses = train_vae(model, train_loader, val_loader, 
                                       num_epochs=num_epochs, 
                                       lr=1e-4, 
                                       kld_weight=0.001,
                                       save_dir=f'{output_dir}/checkpoints')
    pipeline_bar.update(1)
    
    # Step 4: Final evaluation on validation set
    pipeline_bar.set_description("Evaluating final model...")
    eval_results = evaluate_model(model, val_loader)
    print("\nFinal Metrics:")
    for corruption_type, metrics in eval_results.items():
        print(f" {corruption_type}: PSNR = {metrics['psnr']:.2f}dB, SSIM = {metrics['ssim']:.4f}, LPIPS = {metrics.get('lpips', 'N/A')}")
        if 'inference_time_ms' in metrics:
            print(f" {corruption_type} Inference Time: {metrics['inference_time_ms']:.2f}ms")
    
    # Save metrics to file
    metrics_file = os.path.join(output_dir, 'metrics', 'final_metrics.json')
    with open(metrics_file, 'w') as f:
        import json
        # Convert values that might not be JSON serializable
        serializable_results = {}
        for k, v in eval_results.items():
            serializable_results[k] = {k2: float(v2) if not isinstance(v2, str) else v2 
                                     for k2, v2 in v.items()}
        json.dump(serializable_results, f, indent=4)
    
    print(f"Metrics saved to {metrics_file}")
    pipeline_bar.update(1)
    
    # Step 5: Save the final model
    pipeline_bar.set_description("Saving final model...")
    final_model_path = os.path.join(output_dir, 'checkpoints', 'vae_final_model.pth')
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'latent_dim': latent_dim
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")
    pipeline_bar.update(1)
    
    pipeline_bar.close()
    print("\nClearVision Pipeline completed!")
    
    # Create a sample image with all corruption types for the final model
    create_final_comparison_image(model, val_loader, 
                                output_path=os.path.join(output_dir, 'final_results.png'))
    
    return model, train_loader, val_loader

def create_final_comparison_image(model, val_loader, output_path):
    """Create a final comparison image with all corruption types"""
    model.eval()
    
    # Get a batch of images
    batch = next(iter(val_loader))
    if isinstance(batch, list):
        batch = batch[0]
    original_image = batch[0:1].to(device)  # Use first image
    
    corruption_types = ['mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined']
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    # Create rows for each corruption type
    rows = []
    
    for corruption_type in corruption_types:
        # Apply corruption
        corruption_fn = corruption_functions[corruption_type]
        corrupted = corruption_fn(original_image.clone())
        
        # Generate restored image
        with torch.no_grad():
            restored, _, _ = model(corrupted)
        
        # Add to rows: original, corrupted, restored
        row = torch.cat([original_image, corrupted, restored], dim=0)
        rows.append(row)
    
    # Create final grid
    grid = torch.cat(rows, dim=0)
    save_image(grid * 0.5 + 0.5, output_path, nrow=3)
    print(f"Final comparison image saved to {output_path}")


In [3]:
def test_model(model, val_loader, num_samples=5, output_dir='/kaggle/working/test_outputs'):
    """Test the trained model on the validation set and generate example outputs"""
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    
    # Get corruption types
    corruption_types = ['mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined']
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    # Evaluate and report metrics for each corruption type
    print("\nModel Testing Results:")
    
    for corruption_type in corruption_types:
        # Get a batch of samples
        batch = next(iter(val_loader))
        if isinstance(batch, list):
            batch = batch[0]
        
        # Select random samples
        indices = torch.randperm(len(batch))[:num_samples]
        original_images = batch[indices].to(device)
        
        # Apply corruption
        corruption_fn = corruption_functions[corruption_type]
        corrupted_images = corruption_fn(original_images)
        
        # Generate restored images
        with torch.no_grad():
            restored_images, _, _ = model(corrupted_images)
        
        # Calculate metrics
        avg_psnr = 0
        avg_ssim = 0
        lpips_scores = []
        
        # Initialize LPIPS metric
        lpips_metric = LPIPSMetric(device)
        
        for i in range(len(original_images)):
            psnr = compute_psnr(original_images[i:i+1], restored_images[i:i+1])
            ssim = compute_ssim(original_images[i:i+1], restored_images[i:i+1])
            lpips = lpips_metric.calculate_lpips(original_images[i:i+1], restored_images[i:i+1])
            
            avg_psnr += psnr
            avg_ssim += ssim
            lpips_scores.append(lpips)
        
        avg_psnr /= len(original_images)
        avg_ssim /= len(original_images)
        avg_lpips = sum(lpips_scores) / len(lpips_scores)
        
        print(f"{corruption_type.capitalize()} Corruption:")
        print(f"  PSNR: {avg_psnr:.2f}dB")
        print(f"  SSIM: {avg_ssim:.4f}")
        print(f"  LPIPS: {avg_lpips:.4f}")
        
        # Save example images
        for i in range(len(original_images)):
            comparison = torch.cat([
                original_images[i:i+1], 
                corrupted_images[i:i+1], 
                restored_images[i:i+1]
            ], dim=0)
            
            save_image(
                comparison * 0.5 + 0.5,
                f"{output_dir}/{corruption_type}_sample_{i+1}.png",
                nrow=3, 
                padding=5
            )
    
    # Create a comprehensive test grid with all corruption types
    create_test_grid(model, val_loader, output_dir)
    
    print(f"\nTest images saved to {output_dir}")
    return

def create_test_grid(model, val_loader, output_dir):
    """Create a comprehensive test grid with all corruption types"""
    model.eval()
    
    # Get a batch of images
    batch = next(iter(val_loader))
    if isinstance(batch, list):
        batch = batch[0]
    
    # Use 2 sample images
    original_images = batch[:2].to(device)
    
    corruption_types = ['mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined']
    corruption_functions = {
        'mask': apply_mask,
        'gaussian': apply_gaussian_noise,
        'salt_pepper': apply_salt_pepper_noise,
        'blur': apply_gaussian_blur,
        'jpeg': apply_jpeg_artifacts,
        'combined': apply_combined_corruption
    }
    
    # Process each image
    all_rows = []
    for img_idx in range(len(original_images)):
        image_rows = []
        original_image = original_images[img_idx:img_idx+1]
        
        for corruption_type in corruption_types:
            corruption_fn = corruption_functions[corruption_type]
            corrupted = corruption_fn(original_image.clone())
            
            with torch.no_grad():
                restored, _, _ = model(corrupted)
            
            # Create row: original, corrupted, restored
            image_rows.append(torch.cat([original_image, corrupted, restored], dim=0))
        
        # Combine all corruption types for this image
        all_rows.append(torch.cat(image_rows, dim=0))
    
    # Create final grid
    grid = torch.cat(all_rows, dim=0)
    save_image(grid * 0.5 + 0.5, f"{output_dir}/comprehensive_test.png", nrow=3)


In [4]:
if __name__ == '__main__':
    # Create necessary directories
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('samples', exist_ok=True)
    os.makedirs('outputs', exist_ok=True)
    
    # Run the main pipeline
    model, train_loader, val_loader = main()
    
    # Test the model
    test_model(model, val_loader)
    
    print("""
    ## ClearVision Image Restoration Complete
    
    The model has been trained to restore corrupted images using a VAE architecture.
    It can handle multiple corruption types:
    - Masked regions (inpainting)
    - Gaussian noise
    - Salt and pepper noise
    - Blur
    - JPEG compression artifacts
    - Combined corruptions
    
    To test on your own images, use:
    ```
    restore_image(model, 'path/to/your/image.jpg', corruption_type='combined')
    ```
    
    Available corruption types: 'mask', 'gaussian', 'salt_pepper', 'blur', 'jpeg', 'combined'
    """)


Starting ClearVision Pipeline...


Initializing VAE model...:  20%|██        | 1/5 [00:00<00:01,  2.26it/s]

Dataset loaded: 30000 images
Training set: 27000 images
Validation set: 3000 images


Training model...:  40%|████      | 2/5 [00:00<00:01,  2.88it/s]        

Model parameters: 10,828,739
Starting training...


Training Progress:   1%|          | 1/100 [01:48<2:59:02, 108.51s/it, train_loss=5607.4079, val_loss=2788.6650]


Epoch 1/100:
 Train Loss: 5607.407859 (Recon: 5428.766137, KLD: 178641.703507)
 Validation Loss: 2788.664988
 Learning Rate: 0.000100


Training Progress:   2%|▏         | 2/100 [02:50<2:12:25, 81.07s/it, train_loss=2496.5790, val_loss=2116.8730] 


Epoch 2/100:
 Train Loss: 2496.578962 (Recon: 2495.191087, KLD: 1387.874673)
 Validation Loss: 2116.872968
 Learning Rate: 0.000100


Training Progress:   3%|▎         | 3/100 [03:51<1:56:24, 72.00s/it, train_loss=2044.8931, val_loss=1846.3450]


Epoch 3/100:
 Train Loss: 2044.893136 (Recon: 2043.044502, KLD: 1848.635398)
 Validation Loss: 1846.345004
 Learning Rate: 0.000100


Training Progress:   4%|▍         | 4/100 [04:52<1:48:24, 67.76s/it, train_loss=1782.6863, val_loss=1711.0686]


Epoch 4/100:
 Train Loss: 1782.686287 (Recon: 1780.432669, KLD: 2253.617979)
 Validation Loss: 1711.068583
 Learning Rate: 0.000100


Training Progress:   5%|▌         | 5/100 [05:54<1:43:45, 65.53s/it, train_loss=1631.4528, val_loss=1729.6461]


Epoch 5/100:
 Train Loss: 1631.452836 (Recon: 1628.768580, KLD: 2684.255406)
 Validation Loss: 1729.646068
 Learning Rate: 0.000100


Training Progress:   6%|▌         | 6/100 [06:56<1:40:54, 64.41s/it, train_loss=1511.0873, val_loss=1430.4785]


Epoch 6/100:
 Train Loss: 1511.087275 (Recon: 1508.032345, KLD: 3054.929541)
 Validation Loss: 1430.478535
 Learning Rate: 0.000100


Training Progress:   7%|▋         | 7/100 [07:58<1:38:35, 63.60s/it, train_loss=1407.4761, val_loss=1317.8049]


Epoch 7/100:
 Train Loss: 1407.476077 (Recon: 1404.119467, KLD: 3356.609951)
 Validation Loss: 1317.804862
 Learning Rate: 0.000100


Training Progress:   8%|▊         | 8/100 [09:01<1:36:56, 63.23s/it, train_loss=1357.3072, val_loss=1412.2966]


Epoch 8/100:
 Train Loss: 1357.307206 (Recon: 1353.706891, KLD: 3600.316974)
 Validation Loss: 1412.296572
 Learning Rate: 0.000100


Training Progress:   9%|▉         | 9/100 [10:03<1:35:28, 62.95s/it, train_loss=1299.5268, val_loss=1296.2718]


Epoch 9/100:
 Train Loss: 1299.526834 (Recon: 1295.791007, KLD: 3735.827368)
 Validation Loss: 1296.271774
 Learning Rate: 0.000100


Training Progress:   9%|▉         | 9/100 [11:05<1:35:28, 62.95s/it, train_loss=1251.0787, val_loss=1307.5158]


Epoch 10/100:
 Train Loss: 1251.078733 (Recon: 1247.270206, KLD: 3808.528549)
 Validation Loss: 1307.515781
 Learning Rate: 0.000100
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_10.pth
Sample images saved for epoch 10

Running evaluation at epoch 10...
For better results, install LPIPS with: pip install lpips


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth

  0%|          | 0.00/233M [00:00<?, ?B/s][A
  7%|▋         | 16.4M/233M [00:00<00:01, 171MB/s][A
 17%|█▋        | 38.5M/233M [00:00<00:00, 206MB/s][A
 25%|██▍       | 58.2M/233M [00:00<00:00, 194MB/s][A
 35%|███▍      | 81.1M/233M [00:00<00:00, 211MB/s][A
 44%|████▍     | 103M/233M [00:00<00:00, 217MB/s] [A
 54%|█████▎    | 125M/233M [00:00<00:00, 221MB/s][A
 63%|██████▎   | 148M/233M [00:00<00:00, 227MB/s][A
 73%|███████▎  | 171M/233M [00:00<00:00, 231MB/s][A
 83%|████████▎ | 194M/233M [00:00<00:00, 233MB/s][A
100%|██████████| 233M/233M [00:01<00:00, 224MB/s]
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth

  0%|          | 0.00/104M [00:00<?, ?B/s][A
 14%|█▍        | 14.8M/104M [00:00<00:00, 154MB/s][A
 36%|███▌      | 37.6M/104

Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.8971, psnr=22.59dB, ssim=0.9486][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:33,  6.79s/it, lpips=1.8971, psnr=22.59dB, ssim=0.9486][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:33,  6.79s/it, lpips=1.8971, psnr=22.59dB, ssim=0.9486][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:13<00:33,  6.79s/it, lpips=1.8075, psnr=23.17dB, ssim=0.9561][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:13<00:26,  6.75s/it, lpips=1.8075, psnr=23.17dB, ssim=0.9561][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:13<00:26,  6.75s/it, lpips=1.8075, psnr=23.17dB, ssim=0.9561][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:26,  6.75s/it, lpips=1.8587, psnr=23.15dB, ssim=0.9550][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:19,  6.56s/it, lpips=1.8587, psnr=23.


Validation Metrics:
 mask: PSNR = 22.59dB, SSIM = 0.9486, LPIPS = 1.8971, FID = 133.49, Inference Time = 0.38ms
 gaussian: PSNR = 23.17dB, SSIM = 0.9561, LPIPS = 1.8075, FID = 129.99, Inference Time = 0.38ms
 salt_pepper: PSNR = 23.15dB, SSIM = 0.9550, LPIPS = 1.8587, FID = 133.01, Inference Time = 0.24ms
 blur: PSNR = 23.09dB, SSIM = 0.9544, LPIPS = 2.0039, FID = 135.06, Inference Time = 0.23ms
 jpeg: PSNR = 23.00dB, SSIM = 0.9548, LPIPS = 1.8164, FID = 132.37, Inference Time = 0.35ms
 combined: PSNR = 22.81dB, SSIM = 0.9523, LPIPS = 1.9229, FID = 137.06, Inference Time = 0.23ms


Training Progress:  11%|█         | 11/100 [12:52<1:46:58, 72.11s/it, train_loss=1227.3149, val_loss=1198.7621]


Epoch 11/100:
 Train Loss: 1227.314933 (Recon: 1223.447247, KLD: 3867.685020)
 Validation Loss: 1198.762115
 Learning Rate: 0.000100


Training Progress:  12%|█▏        | 12/100 [13:55<1:41:49, 69.43s/it, train_loss=1193.6644, val_loss=1233.2580]


Epoch 12/100:
 Train Loss: 1193.664378 (Recon: 1189.811971, KLD: 3852.407494)
 Validation Loss: 1233.258045
 Learning Rate: 0.000100


Training Progress:  13%|█▎        | 13/100 [14:57<1:37:27, 67.21s/it, train_loss=1173.6214, val_loss=1170.4304]


Epoch 13/100:
 Train Loss: 1173.621351 (Recon: 1169.808948, KLD: 3812.402380)
 Validation Loss: 1170.430425
 Learning Rate: 0.000100


Training Progress:  14%|█▍        | 14/100 [16:00<1:34:18, 65.80s/it, train_loss=1147.7629, val_loss=1274.8564]


Epoch 14/100:
 Train Loss: 1147.762911 (Recon: 1143.933677, KLD: 3829.234574)
 Validation Loss: 1274.856423
 Learning Rate: 0.000100


Training Progress:  15%|█▌        | 15/100 [17:03<1:32:08, 65.05s/it, train_loss=1144.1519, val_loss=1152.2000]


Epoch 15/100:
 Train Loss: 1144.151936 (Recon: 1140.367010, KLD: 3784.925272)
 Validation Loss: 1152.199952
 Learning Rate: 0.000100


Training Progress:  16%|█▌        | 16/100 [18:06<1:30:02, 64.32s/it, train_loss=1108.8514, val_loss=1144.9514]


Epoch 16/100:
 Train Loss: 1108.851379 (Recon: 1105.000047, KLD: 3851.331428)
 Validation Loss: 1144.951417
 Learning Rate: 0.000100


Training Progress:  17%|█▋        | 17/100 [19:08<1:28:09, 63.73s/it, train_loss=1093.0114, val_loss=1130.9746]


Epoch 17/100:
 Train Loss: 1093.011368 (Recon: 1089.182703, KLD: 3828.665497)
 Validation Loss: 1130.974629
 Learning Rate: 0.000100


Training Progress:  18%|█▊        | 18/100 [20:11<1:26:37, 63.38s/it, train_loss=1069.7432, val_loss=1130.4002]


Epoch 18/100:
 Train Loss: 1069.743206 (Recon: 1065.940995, KLD: 3802.210974)
 Validation Loss: 1130.400213
 Learning Rate: 0.000100


Training Progress:  19%|█▉        | 19/100 [21:13<1:25:08, 63.07s/it, train_loss=1057.9335, val_loss=1126.9585]


Epoch 19/100:
 Train Loss: 1057.933488 (Recon: 1054.185641, KLD: 3747.847287)
 Validation Loss: 1126.958545
 Learning Rate: 0.000100


Training Progress:  19%|█▉        | 19/100 [22:15<1:25:08, 63.07s/it, train_loss=1043.5253, val_loss=1169.7076]


Epoch 20/100:
 Train Loss: 1043.525285 (Recon: 1039.959054, KLD: 3566.230578)
 Validation Loss: 1169.707598
 Learning Rate: 0.000100
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_20.pth
Sample images saved for epoch 20

Running evaluation at epoch 20...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.7067, psnr=23.47dB, ssim=0.9584][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:31,  6.34s/it, lpips=1.7067, psnr=23.47dB, ssim=0.9584][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:31,  6.34s/it, lpips=1.7067, psnr=23.47dB, ssim=0.9584][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:12<00:31,  6.34s/it, lpips=1.6986, psnr=23.79dB, ssim=0.9612][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:12<00:25,  6.44s/it, lpips=1.6986, psnr=23.79dB, ssim=0.9612][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:12<00:25,  6.44s/it, lpips=1.6986, psnr=23.79dB, ssim=0.9612][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:25,  6.44s/it, lpips=1.7513, psnr=23.76dB, ssim=0.9601][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:18,  6.33s/it, lpips=1.7513, psnr=23.


Validation Metrics:
 mask: PSNR = 23.47dB, SSIM = 0.9584, LPIPS = 1.7067, FID = 123.28, Inference Time = 0.34ms
 gaussian: PSNR = 23.79dB, SSIM = 0.9612, LPIPS = 1.6986, FID = 122.65, Inference Time = 0.27ms
 salt_pepper: PSNR = 23.76dB, SSIM = 0.9601, LPIPS = 1.7513, FID = 123.26, Inference Time = 0.24ms
 blur: PSNR = 23.63dB, SSIM = 0.9588, LPIPS = 1.9502, FID = 127.43, Inference Time = 0.23ms
 jpeg: PSNR = 23.59dB, SSIM = 0.9599, LPIPS = 1.7165, FID = 125.63, Inference Time = 0.26ms
 combined: PSNR = 23.85dB, SSIM = 0.9593, LPIPS = 1.7771, FID = 125.53, Inference Time = 0.24ms


Training Progress:  21%|██        | 21/100 [24:00<1:34:18, 71.63s/it, train_loss=1026.1736, val_loss=1112.5511]


Epoch 21/100:
 Train Loss: 1026.173636 (Recon: 1022.563799, KLD: 3609.835974)
 Validation Loss: 1112.551122
 Learning Rate: 0.000100


Training Progress:  22%|██▏       | 22/100 [25:02<1:29:25, 68.79s/it, train_loss=1010.5471, val_loss=1100.6560]


Epoch 22/100:
 Train Loss: 1010.547114 (Recon: 1006.964180, KLD: 3582.934125)
 Validation Loss: 1100.655990
 Learning Rate: 0.000100


Training Progress:  23%|██▎       | 23/100 [26:05<1:25:56, 66.97s/it, train_loss=993.0964, val_loss=1116.1715]


Epoch 23/100:
 Train Loss: 993.096428 (Recon: 989.562336, KLD: 3534.092782)
 Validation Loss: 1116.171520
 Learning Rate: 0.000100


Training Progress:  24%|██▍       | 24/100 [27:08<1:23:15, 65.73s/it, train_loss=977.3512, val_loss=1106.3337]


Epoch 24/100:
 Train Loss: 977.351246 (Recon: 973.867742, KLD: 3483.503845)
 Validation Loss: 1106.333702
 Learning Rate: 0.000100


Training Progress:  25%|██▌       | 25/100 [28:10<1:20:57, 64.77s/it, train_loss=969.1106, val_loss=1129.0477]


Epoch 25/100:
 Train Loss: 969.110594 (Recon: 965.680288, KLD: 3430.305436)
 Validation Loss: 1129.047699
 Learning Rate: 0.000100


Training Progress:  26%|██▌       | 26/100 [29:13<1:18:59, 64.05s/it, train_loss=959.8862, val_loss=1107.8094]


Epoch 26/100:
 Train Loss: 959.886204 (Recon: 956.491980, KLD: 3394.223189)
 Validation Loss: 1107.809439
 Learning Rate: 0.000100


Training Progress:  27%|██▋       | 27/100 [30:15<1:17:21, 63.58s/it, train_loss=945.6759, val_loss=1113.3876]


Epoch 27/100:
 Train Loss: 945.675888 (Recon: 942.324690, KLD: 3351.197859)
 Validation Loss: 1113.387589
 Learning Rate: 0.000100


Training Progress:  28%|██▊       | 28/100 [31:18<1:15:51, 63.22s/it, train_loss=937.0472, val_loss=1105.1816]


Epoch 28/100:
 Train Loss: 937.047163 (Recon: 933.732472, KLD: 3314.690824)
 Validation Loss: 1105.181565
 Learning Rate: 0.000100

Learning rate adjusted from 0.000100 to 0.000050


Training Progress:  29%|██▉       | 29/100 [32:20<1:14:34, 63.02s/it, train_loss=884.1299, val_loss=1059.2368]


Epoch 29/100:
 Train Loss: 884.129946 (Recon: 880.967291, KLD: 3162.654265)
 Validation Loss: 1059.236821
 Learning Rate: 0.000050


Training Progress:  29%|██▉       | 29/100 [33:22<1:14:34, 63.02s/it, train_loss=873.8668, val_loss=1067.0756]


Epoch 30/100:
 Train Loss: 873.866804 (Recon: 870.834612, KLD: 3032.192908)
 Validation Loss: 1067.075556
 Learning Rate: 0.000050
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_30.pth
Sample images saved for epoch 30

Running evaluation at epoch 30...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:07<?, ?it/s, lpips=1.4504, psnr=24.12dB, ssim=0.9628][A
Evaluating mask:  17%|█▋        | 1/6 [00:07<00:38,  7.72s/it, lpips=1.4504, psnr=24.12dB, ssim=0.9628][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:07<00:38,  7.72s/it, lpips=1.4504, psnr=24.12dB, ssim=0.9628][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:13<00:38,  7.72s/it, lpips=1.4380, psnr=24.40dB, ssim=0.9647][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:13<00:27,  6.83s/it, lpips=1.4380, psnr=24.40dB, ssim=0.9647][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:13<00:27,  6.83s/it, lpips=1.4380, psnr=24.40dB, ssim=0.9647][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:20<00:27,  6.83s/it, lpips=1.4696, psnr=24.28dB, ssim=0.9636][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:20<00:19,  6.62s/it, lpips=1.4696, psnr=24.


Validation Metrics:
 mask: PSNR = 24.12dB, SSIM = 0.9628, LPIPS = 1.4504, FID = 117.69, Inference Time = 0.34ms
 gaussian: PSNR = 24.40dB, SSIM = 0.9647, LPIPS = 1.4380, FID = 114.94, Inference Time = 0.32ms
 salt_pepper: PSNR = 24.28dB, SSIM = 0.9636, LPIPS = 1.4696, FID = 116.82, Inference Time = 0.23ms
 blur: PSNR = 24.31dB, SSIM = 0.9630, LPIPS = 1.6873, FID = 121.86, Inference Time = 0.24ms
 jpeg: PSNR = 24.25dB, SSIM = 0.9637, LPIPS = 1.4493, FID = 117.03, Inference Time = 0.24ms
 combined: PSNR = 23.75dB, SSIM = 0.9582, LPIPS = 1.6319, FID = 120.65, Inference Time = 0.34ms


Training Progress:  31%|███       | 31/100 [35:06<1:21:56, 71.26s/it, train_loss=859.4232, val_loss=1063.2973]


Epoch 31/100:
 Train Loss: 859.423215 (Recon: 856.492334, KLD: 2930.880129)
 Validation Loss: 1063.297294
 Learning Rate: 0.000050


Training Progress:  32%|███▏      | 32/100 [36:08<1:17:32, 68.42s/it, train_loss=856.7144, val_loss=1055.5976]


Epoch 32/100:
 Train Loss: 856.714441 (Recon: 853.839059, KLD: 2875.382304)
 Validation Loss: 1055.597626
 Learning Rate: 0.000050


Training Progress:  33%|███▎      | 33/100 [37:11<1:14:30, 66.73s/it, train_loss=850.6236, val_loss=1125.1387]


Epoch 33/100:
 Train Loss: 850.623568 (Recon: 847.802097, KLD: 2821.470524)
 Validation Loss: 1125.138746
 Learning Rate: 0.000050


Training Progress:  34%|███▍      | 34/100 [38:13<1:12:01, 65.48s/it, train_loss=846.1860, val_loss=1066.6123]


Epoch 34/100:
 Train Loss: 846.186023 (Recon: 843.404943, KLD: 2781.079229)
 Validation Loss: 1066.612257
 Learning Rate: 0.000050


Training Progress:  35%|███▌      | 35/100 [39:16<1:10:01, 64.64s/it, train_loss=843.2329, val_loss=1069.4368]


Epoch 35/100:
 Train Loss: 843.232923 (Recon: 840.500012, KLD: 2732.911138)
 Validation Loss: 1069.436755
 Learning Rate: 0.000050


Training Progress:  36%|███▌      | 36/100 [40:18<1:08:07, 63.87s/it, train_loss=832.0369, val_loss=1072.3171]


Epoch 36/100:
 Train Loss: 832.036919 (Recon: 829.333324, KLD: 2703.594794)
 Validation Loss: 1072.317069
 Learning Rate: 0.000050


Training Progress:  37%|███▋      | 37/100 [41:20<1:06:35, 63.42s/it, train_loss=830.2073, val_loss=1084.9783]


Epoch 37/100:
 Train Loss: 830.207324 (Recon: 827.535574, KLD: 2671.749642)
 Validation Loss: 1084.978261
 Learning Rate: 0.000050


Training Progress:  38%|███▊      | 38/100 [42:22<1:05:05, 62.99s/it, train_loss=825.8330, val_loss=1082.5119]


Epoch 38/100:
 Train Loss: 825.833006 (Recon: 823.191902, KLD: 2641.102444)
 Validation Loss: 1082.511888
 Learning Rate: 0.000050

Learning rate adjusted from 0.000050 to 0.000025


Training Progress:  39%|███▉      | 39/100 [43:25<1:04:01, 62.97s/it, train_loss=800.0963, val_loss=1074.6966]


Epoch 39/100:
 Train Loss: 800.096348 (Recon: 797.514680, KLD: 2581.668073)
 Validation Loss: 1074.696610
 Learning Rate: 0.000025


Training Progress:  39%|███▉      | 39/100 [44:27<1:04:01, 62.97s/it, train_loss=793.0694, val_loss=1072.2231]


Epoch 40/100:
 Train Loss: 793.069437 (Recon: 790.541106, KLD: 2528.330693)
 Validation Loss: 1072.223145
 Learning Rate: 0.000025
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_40.pth
Sample images saved for epoch 40

Running evaluation at epoch 40...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.4902, psnr=23.99dB, ssim=0.9606][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:32,  6.53s/it, lpips=1.4902, psnr=23.99dB, ssim=0.9606][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:32,  6.53s/it, lpips=1.4902, psnr=23.99dB, ssim=0.9606][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:13<00:32,  6.53s/it, lpips=1.4221, psnr=24.41dB, ssim=0.9638][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:13<00:27,  6.75s/it, lpips=1.4221, psnr=24.41dB, ssim=0.9638][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:13<00:27,  6.75s/it, lpips=1.4221, psnr=24.41dB, ssim=0.9638][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:27,  6.75s/it, lpips=1.4790, psnr=24.30dB, ssim=0.9626][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:19,  6.56s/it, lpips=1.4790, psnr=24.


Validation Metrics:
 mask: PSNR = 23.99dB, SSIM = 0.9606, LPIPS = 1.4902, FID = 116.27, Inference Time = 0.33ms
 gaussian: PSNR = 24.41dB, SSIM = 0.9638, LPIPS = 1.4221, FID = 113.60, Inference Time = 0.35ms
 salt_pepper: PSNR = 24.30dB, SSIM = 0.9626, LPIPS = 1.4790, FID = 114.00, Inference Time = 0.23ms
 blur: PSNR = 24.23dB, SSIM = 0.9615, LPIPS = 1.6770, FID = 119.22, Inference Time = 0.24ms
 jpeg: PSNR = 24.29dB, SSIM = 0.9630, LPIPS = 1.4500, FID = 115.70, Inference Time = 0.24ms
 combined: PSNR = 24.28dB, SSIM = 0.9621, LPIPS = 1.6270, FID = 118.25, Inference Time = 0.24ms


Training Progress:  41%|████      | 41/100 [46:11<1:10:05, 71.27s/it, train_loss=790.9943, val_loss=1077.3381]


Epoch 41/100:
 Train Loss: 790.994319 (Recon: 788.499939, KLD: 2494.381097)
 Validation Loss: 1077.338091
 Learning Rate: 0.000025


Training Progress:  42%|████▏     | 42/100 [47:13<1:06:06, 68.39s/it, train_loss=789.0634, val_loss=1068.2051]


Epoch 42/100:
 Train Loss: 789.063362 (Recon: 786.599149, KLD: 2464.213096)
 Validation Loss: 1068.205126
 Learning Rate: 0.000025


Training Progress:  43%|████▎     | 43/100 [48:14<1:02:57, 66.27s/it, train_loss=780.8257, val_loss=1071.1597]


Epoch 43/100:
 Train Loss: 780.825676 (Recon: 778.384271, KLD: 2441.404194)
 Validation Loss: 1071.159736
 Learning Rate: 0.000025


Training Progress:  44%|████▍     | 44/100 [49:16<1:00:41, 65.02s/it, train_loss=778.8317, val_loss=1067.7319]


Epoch 44/100:
 Train Loss: 778.831724 (Recon: 776.415582, KLD: 2416.141917)
 Validation Loss: 1067.731906
 Learning Rate: 0.000025

Learning rate adjusted from 0.000025 to 0.000013


Training Progress:  45%|████▌     | 45/100 [50:19<58:57, 64.31s/it, train_loss=767.0676, val_loss=1091.3895]  


Epoch 45/100:
 Train Loss: 767.067590 (Recon: 764.678501, KLD: 2389.088922)
 Validation Loss: 1091.389492
 Learning Rate: 0.000013


Training Progress:  46%|████▌     | 46/100 [51:21<57:17, 63.66s/it, train_loss=764.8281, val_loss=1070.1285]


Epoch 46/100:
 Train Loss: 764.828091 (Recon: 762.465959, KLD: 2362.131665)
 Validation Loss: 1070.128485
 Learning Rate: 0.000013


Training Progress:  47%|████▋     | 47/100 [52:23<55:44, 63.10s/it, train_loss=765.0985, val_loss=1071.7205]


Epoch 47/100:
 Train Loss: 765.098530 (Recon: 762.755550, KLD: 2342.979137)
 Validation Loss: 1071.720500
 Learning Rate: 0.000013


Training Progress:  48%|████▊     | 48/100 [53:25<54:22, 62.75s/it, train_loss=760.7555, val_loss=1071.2059]


Epoch 48/100:
 Train Loss: 760.755509 (Recon: 758.428551, KLD: 2326.957928)
 Validation Loss: 1071.205893
 Learning Rate: 0.000013


Training Progress:  49%|████▉     | 49/100 [54:27<53:06, 62.48s/it, train_loss=758.9905, val_loss=1078.5356]


Epoch 49/100:
 Train Loss: 758.990510 (Recon: 756.678814, KLD: 2311.694652)
 Validation Loss: 1078.535586
 Learning Rate: 0.000013


Training Progress:  49%|████▉     | 49/100 [55:29<53:06, 62.48s/it, train_loss=757.7812, val_loss=1074.1720]


Epoch 50/100:
 Train Loss: 757.781160 (Recon: 755.483686, KLD: 2297.472796)
 Validation Loss: 1074.172012
 Learning Rate: 0.000013
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_50.pth
Sample images saved for epoch 50

Running evaluation at epoch 50...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.3610, psnr=24.19dB, ssim=0.9627][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:32,  6.43s/it, lpips=1.3610, psnr=24.19dB, ssim=0.9627][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:32,  6.43s/it, lpips=1.3610, psnr=24.19dB, ssim=0.9627][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:12<00:32,  6.43s/it, lpips=1.3622, psnr=24.45dB, ssim=0.9644][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:12<00:25,  6.41s/it, lpips=1.3622, psnr=24.45dB, ssim=0.9644][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:12<00:25,  6.41s/it, lpips=1.3622, psnr=24.45dB, ssim=0.9644][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:21<00:25,  6.41s/it, lpips=1.4139, psnr=24.35dB, ssim=0.9633][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:21<00:22,  7.36s/it, lpips=1.4139, psnr=24.


Validation Metrics:
 mask: PSNR = 24.19dB, SSIM = 0.9627, LPIPS = 1.3610, FID = 113.73, Inference Time = 0.34ms
 gaussian: PSNR = 24.45dB, SSIM = 0.9644, LPIPS = 1.3622, FID = 115.80, Inference Time = 0.29ms
 salt_pepper: PSNR = 24.35dB, SSIM = 0.9633, LPIPS = 1.4139, FID = 118.11, Inference Time = 0.23ms
 blur: PSNR = 24.33dB, SSIM = 0.9626, LPIPS = 1.6099, FID = 121.05, Inference Time = 0.23ms
 jpeg: PSNR = 24.32dB, SSIM = 0.9636, LPIPS = 1.3872, FID = 117.23, Inference Time = 0.35ms
 combined: PSNR = 24.12dB, SSIM = 0.9607, LPIPS = 1.4915, FID = 117.04, Inference Time = 0.23ms

Learning rate adjusted from 0.000013 to 0.000006


Training Progress:  51%|█████     | 51/100 [57:15<58:34, 71.72s/it, train_loss=751.3635, val_loss=1064.4696]  


Epoch 51/100:
 Train Loss: 751.363517 (Recon: 749.079072, KLD: 2284.445110)
 Validation Loss: 1064.469630
 Learning Rate: 0.000006


Training Progress:  52%|█████▏    | 52/100 [58:17<55:00, 68.77s/it, train_loss=752.1818, val_loss=1069.7245]


Epoch 52/100:
 Train Loss: 752.181800 (Recon: 749.906242, KLD: 2275.558417)
 Validation Loss: 1069.724479
 Learning Rate: 0.000006


Training Progress:  53%|█████▎    | 53/100 [59:19<52:17, 66.76s/it, train_loss=748.4470, val_loss=1075.9531]


Epoch 53/100:
 Train Loss: 748.446999 (Recon: 746.180868, KLD: 2266.131500)
 Validation Loss: 1075.953057
 Learning Rate: 0.000006


Training Progress:  54%|█████▍    | 54/100 [1:00:21<49:58, 65.19s/it, train_loss=747.1154, val_loss=1071.0401]


Epoch 54/100:
 Train Loss: 747.115353 (Recon: 744.856952, KLD: 2258.402045)
 Validation Loss: 1071.040072
 Learning Rate: 0.000006


Training Progress:  55%|█████▌    | 55/100 [1:01:22<48:00, 64.02s/it, train_loss=748.4092, val_loss=1075.1270]


Epoch 55/100:
 Train Loss: 748.409225 (Recon: 746.158842, KLD: 2250.382934)
 Validation Loss: 1075.127021
 Learning Rate: 0.000006


Training Progress:  56%|█████▌    | 56/100 [1:02:23<46:21, 63.21s/it, train_loss=746.6336, val_loss=1078.4251]


Epoch 56/100:
 Train Loss: 746.633646 (Recon: 744.391016, KLD: 2242.630627)
 Validation Loss: 1078.425059
 Learning Rate: 0.000006

Learning rate adjusted from 0.000006 to 0.000003


Training Progress:  57%|█████▋    | 57/100 [1:03:25<45:01, 62.82s/it, train_loss=741.8388, val_loss=1065.3149]


Epoch 57/100:
 Train Loss: 741.838825 (Recon: 739.601034, KLD: 2237.790630)
 Validation Loss: 1065.314874
 Learning Rate: 0.000003


Training Progress:  58%|█████▊    | 58/100 [1:04:27<43:45, 62.52s/it, train_loss=745.8267, val_loss=1064.4793]


Epoch 58/100:
 Train Loss: 745.826710 (Recon: 743.592040, KLD: 2234.670260)
 Validation Loss: 1064.479258
 Learning Rate: 0.000003


Training Progress:  59%|█████▉    | 59/100 [1:05:29<42:38, 62.40s/it, train_loss=741.4572, val_loss=1065.6586]


Epoch 59/100:
 Train Loss: 741.457167 (Recon: 739.226481, KLD: 2230.686591)
 Validation Loss: 1065.658561
 Learning Rate: 0.000003


Training Progress:  59%|█████▉    | 59/100 [1:06:31<42:38, 62.40s/it, train_loss=740.6482, val_loss=1071.8182]


Epoch 60/100:
 Train Loss: 740.648163 (Recon: 738.422474, KLD: 2225.688625)
 Validation Loss: 1071.818239
 Learning Rate: 0.000003
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_60.pth
Sample images saved for epoch 60

Running evaluation at epoch 60...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.3668, psnr=24.05dB, ssim=0.9617][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:32,  6.42s/it, lpips=1.3668, psnr=24.05dB, ssim=0.9617][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:32,  6.42s/it, lpips=1.3668, psnr=24.05dB, ssim=0.9617][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:13<00:32,  6.42s/it, lpips=1.3285, psnr=24.44dB, ssim=0.9646][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:13<00:26,  6.52s/it, lpips=1.3285, psnr=24.44dB, ssim=0.9646][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:13<00:26,  6.52s/it, lpips=1.3285, psnr=24.44dB, ssim=0.9646][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:26,  6.52s/it, lpips=1.3776, psnr=24.40dB, ssim=0.9639][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:19,  6.43s/it, lpips=1.3776, psnr=24.


Validation Metrics:
 mask: PSNR = 24.05dB, SSIM = 0.9617, LPIPS = 1.3668, FID = 114.49, Inference Time = 0.38ms
 gaussian: PSNR = 24.44dB, SSIM = 0.9646, LPIPS = 1.3285, FID = 114.46, Inference Time = 0.32ms
 salt_pepper: PSNR = 24.40dB, SSIM = 0.9639, LPIPS = 1.3776, FID = 116.44, Inference Time = 0.24ms
 blur: PSNR = 24.38dB, SSIM = 0.9633, LPIPS = 1.5559, FID = 119.59, Inference Time = 0.25ms
 jpeg: PSNR = 24.32dB, SSIM = 0.9639, LPIPS = 1.3458, FID = 116.57, Inference Time = 0.24ms
 combined: PSNR = 24.07dB, SSIM = 0.9620, LPIPS = 1.3553, FID = 114.99, Inference Time = 0.29ms


Training Progress:  61%|██████    | 61/100 [1:08:15<46:05, 70.91s/it, train_loss=740.6496, val_loss=1078.1531]


Epoch 61/100:
 Train Loss: 740.649626 (Recon: 738.428188, KLD: 2221.438305)
 Validation Loss: 1078.153092
 Learning Rate: 0.000003


Training Progress:  62%|██████▏   | 62/100 [1:09:18<43:22, 68.48s/it, train_loss=743.1301, val_loss=1080.1300]


Epoch 62/100:
 Train Loss: 743.130091 (Recon: 740.913848, KLD: 2216.243213)
 Validation Loss: 1080.129982
 Learning Rate: 0.000003

Learning rate adjusted from 0.000003 to 0.000002


Training Progress:  63%|██████▎   | 63/100 [1:10:21<41:12, 66.83s/it, train_loss=736.5861, val_loss=1070.5980]


Epoch 63/100:
 Train Loss: 736.586080 (Recon: 734.372604, KLD: 2213.476462)
 Validation Loss: 1070.597961
 Learning Rate: 0.000002


Training Progress:  64%|██████▍   | 64/100 [1:11:23<39:19, 65.54s/it, train_loss=739.7717, val_loss=1079.7443]


Epoch 64/100:
 Train Loss: 739.771659 (Recon: 737.560000, KLD: 2211.659763)
 Validation Loss: 1079.744338
 Learning Rate: 0.000002


Training Progress:  65%|██████▌   | 65/100 [1:12:26<37:48, 64.82s/it, train_loss=738.1664, val_loss=1076.1722]


Epoch 65/100:
 Train Loss: 738.166364 (Recon: 735.956535, KLD: 2209.829346)
 Validation Loss: 1076.172159
 Learning Rate: 0.000002


Training Progress:  66%|██████▌   | 66/100 [1:13:30<36:28, 64.36s/it, train_loss=740.1339, val_loss=1072.5933]


Epoch 66/100:
 Train Loss: 740.133915 (Recon: 737.926844, KLD: 2207.071979)
 Validation Loss: 1072.593333
 Learning Rate: 0.000002


Training Progress:  67%|██████▋   | 67/100 [1:14:32<35:02, 63.71s/it, train_loss=738.4122, val_loss=1067.7059]


Epoch 67/100:
 Train Loss: 738.412161 (Recon: 736.205497, KLD: 2206.663377)
 Validation Loss: 1067.705879
 Learning Rate: 0.000002


Training Progress:  68%|██████▊   | 68/100 [1:15:35<33:50, 63.44s/it, train_loss=740.9649, val_loss=1074.2909]


Epoch 68/100:
 Train Loss: 740.964942 (Recon: 738.761730, KLD: 2203.212947)
 Validation Loss: 1074.290882
 Learning Rate: 0.000002

Learning rate adjusted from 0.000002 to 0.000001


Training Progress:  69%|██████▉   | 69/100 [1:16:38<32:42, 63.31s/it, train_loss=734.9225, val_loss=1073.7082]


Epoch 69/100:
 Train Loss: 734.922488 (Recon: 732.719282, KLD: 2203.205683)
 Validation Loss: 1073.708150
 Learning Rate: 0.000001


Training Progress:  69%|██████▉   | 69/100 [1:17:40<32:42, 63.31s/it, train_loss=737.2982, val_loss=1074.3075]


Epoch 70/100:
 Train Loss: 737.298221 (Recon: 735.096414, KLD: 2201.808016)
 Validation Loss: 1074.307477
 Learning Rate: 0.000001
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_70.pth
Sample images saved for epoch 70

Running evaluation at epoch 70...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.3884, psnr=24.13dB, ssim=0.9620][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:31,  6.30s/it, lpips=1.3884, psnr=24.13dB, ssim=0.9620][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:31,  6.30s/it, lpips=1.3884, psnr=24.13dB, ssim=0.9620][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:12<00:31,  6.30s/it, lpips=1.3612, psnr=24.47dB, ssim=0.9649][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:12<00:25,  6.48s/it, lpips=1.3612, psnr=24.47dB, ssim=0.9649][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:12<00:25,  6.48s/it, lpips=1.3612, psnr=24.47dB, ssim=0.9649][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:25,  6.48s/it, lpips=1.3964, psnr=24.43dB, ssim=0.9640][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:19,  6.45s/it, lpips=1.3964, psnr=24.


Validation Metrics:
 mask: PSNR = 24.13dB, SSIM = 0.9620, LPIPS = 1.3884, FID = 114.91, Inference Time = 0.36ms
 gaussian: PSNR = 24.47dB, SSIM = 0.9649, LPIPS = 1.3612, FID = 116.44, Inference Time = 0.27ms
 salt_pepper: PSNR = 24.43dB, SSIM = 0.9640, LPIPS = 1.3964, FID = 118.15, Inference Time = 0.24ms
 blur: PSNR = 24.34dB, SSIM = 0.9631, LPIPS = 1.6057, FID = 120.44, Inference Time = 0.25ms
 jpeg: PSNR = 24.32dB, SSIM = 0.9640, LPIPS = 1.3833, FID = 117.66, Inference Time = 0.24ms
 combined: PSNR = 23.78dB, SSIM = 0.9582, LPIPS = 1.6236, FID = 119.72, Inference Time = 0.30ms


Training Progress:  71%|███████   | 71/100 [1:19:25<34:39, 71.72s/it, train_loss=734.2165, val_loss=1078.0374]


Epoch 71/100:
 Train Loss: 734.216494 (Recon: 732.015749, KLD: 2200.745309)
 Validation Loss: 1078.037360
 Learning Rate: 0.000001


Training Progress:  72%|███████▏  | 72/100 [1:20:26<32:04, 68.75s/it, train_loss=736.7036, val_loss=1081.3611]


Epoch 72/100:
 Train Loss: 736.703596 (Recon: 734.505046, KLD: 2198.549995)
 Validation Loss: 1081.361102
 Learning Rate: 0.000001


Training Progress:  73%|███████▎  | 73/100 [1:21:28<29:56, 66.56s/it, train_loss=734.9414, val_loss=1081.1134]


Epoch 73/100:
 Train Loss: 734.941378 (Recon: 732.743737, KLD: 2197.640597)
 Validation Loss: 1081.113400
 Learning Rate: 0.000001


Training Progress:  74%|███████▍  | 74/100 [1:22:29<28:07, 64.89s/it, train_loss=734.4293, val_loss=1069.9745]


Epoch 74/100:
 Train Loss: 734.429314 (Recon: 732.232666, KLD: 2196.648457)
 Validation Loss: 1069.974507
 Learning Rate: 0.000001

Learning rate adjusted from 0.000001 to 0.000000


Training Progress:  75%|███████▌  | 75/100 [1:23:30<26:32, 63.69s/it, train_loss=732.9946, val_loss=1083.4639]


Epoch 75/100:
 Train Loss: 732.994600 (Recon: 730.798937, KLD: 2195.663213)
 Validation Loss: 1083.463950
 Learning Rate: 0.000000


Training Progress:  76%|███████▌  | 76/100 [1:24:32<25:16, 63.18s/it, train_loss=734.9359, val_loss=1075.1659]


Epoch 76/100:
 Train Loss: 734.935852 (Recon: 732.740506, KLD: 2195.346172)
 Validation Loss: 1075.165867
 Learning Rate: 0.000000


Training Progress:  77%|███████▋  | 77/100 [1:25:33<24:02, 62.71s/it, train_loss=736.8772, val_loss=1077.4536]


Epoch 77/100:
 Train Loss: 736.877168 (Recon: 734.683155, KLD: 2194.012734)
 Validation Loss: 1077.453588
 Learning Rate: 0.000000


Training Progress:  78%|███████▊  | 78/100 [1:26:35<22:54, 62.48s/it, train_loss=735.4997, val_loss=1080.0774]


Epoch 78/100:
 Train Loss: 735.499717 (Recon: 733.305229, KLD: 2194.488152)
 Validation Loss: 1080.077371
 Learning Rate: 0.000000


Training Progress:  79%|███████▉  | 79/100 [1:27:38<21:54, 62.58s/it, train_loss=735.9558, val_loss=1075.8058]


Epoch 79/100:
 Train Loss: 735.955815 (Recon: 733.762493, KLD: 2193.322083)
 Validation Loss: 1075.805760
 Learning Rate: 0.000000


Training Progress:  79%|███████▉  | 79/100 [1:28:40<21:54, 62.58s/it, train_loss=735.6809, val_loss=1080.8910]


Epoch 80/100:
 Train Loss: 735.680904 (Recon: 733.487659, KLD: 2193.244013)
 Validation Loss: 1080.891018
 Learning Rate: 0.000000
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_80.pth
Sample images saved for epoch 80

Running evaluation at epoch 80...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:08<?, ?it/s, lpips=1.3472, psnr=23.87dB, ssim=0.9608][A
Evaluating mask:  17%|█▋        | 1/6 [00:08<00:42,  8.40s/it, lpips=1.3472, psnr=23.87dB, ssim=0.9608][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:08<00:42,  8.40s/it, lpips=1.3472, psnr=23.87dB, ssim=0.9608][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:14<00:42,  8.40s/it, lpips=1.3232, psnr=24.45dB, ssim=0.9647][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:14<00:28,  7.13s/it, lpips=1.3232, psnr=24.45dB, ssim=0.9647][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:14<00:28,  7.13s/it, lpips=1.3232, psnr=24.45dB, ssim=0.9647][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:20<00:28,  7.13s/it, lpips=1.3473, psnr=24.40dB, ssim=0.9640][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:20<00:20,  6.76s/it, lpips=1.3473, psnr=24.


Validation Metrics:
 mask: PSNR = 23.87dB, SSIM = 0.9608, LPIPS = 1.3472, FID = 111.78, Inference Time = 0.34ms
 gaussian: PSNR = 24.45dB, SSIM = 0.9647, LPIPS = 1.3232, FID = 114.05, Inference Time = 0.30ms
 salt_pepper: PSNR = 24.40dB, SSIM = 0.9640, LPIPS = 1.3473, FID = 115.66, Inference Time = 0.23ms
 blur: PSNR = 24.37dB, SSIM = 0.9634, LPIPS = 1.5478, FID = 119.19, Inference Time = 0.23ms
 jpeg: PSNR = 24.29dB, SSIM = 0.9639, LPIPS = 1.3412, FID = 116.26, Inference Time = 0.24ms
 combined: PSNR = 24.18dB, SSIM = 0.9612, LPIPS = 1.4941, FID = 118.38, Inference Time = 0.24ms

Learning rate adjusted from 0.000000 to 0.000000


Training Progress:  81%|████████  | 81/100 [1:30:27<22:42, 71.69s/it, train_loss=733.3484, val_loss=1083.1689]


Epoch 81/100:
 Train Loss: 733.348382 (Recon: 731.155779, KLD: 2192.603260)
 Validation Loss: 1083.168926
 Learning Rate: 0.000000


Training Progress:  82%|████████▏ | 82/100 [1:31:30<20:43, 69.10s/it, train_loss=736.4258, val_loss=1077.1343]


Epoch 82/100:
 Train Loss: 736.425763 (Recon: 734.233804, KLD: 2191.958167)
 Validation Loss: 1077.134342
 Learning Rate: 0.000000


Training Progress:  83%|████████▎ | 83/100 [1:32:32<19:01, 67.15s/it, train_loss=734.7262, val_loss=1071.2673]


Epoch 83/100:
 Train Loss: 734.726210 (Recon: 732.534050, KLD: 2192.159798)
 Validation Loss: 1071.267336
 Learning Rate: 0.000000


Training Progress:  84%|████████▍ | 84/100 [1:33:35<17:33, 65.87s/it, train_loss=733.8777, val_loss=1077.6119]


Epoch 84/100:
 Train Loss: 733.877671 (Recon: 731.685073, KLD: 2192.599432)
 Validation Loss: 1077.611885
 Learning Rate: 0.000000


Training Progress:  85%|████████▌ | 85/100 [1:34:39<16:18, 65.24s/it, train_loss=736.2459, val_loss=1081.5715]


Epoch 85/100:
 Train Loss: 736.245925 (Recon: 734.054865, KLD: 2191.059469)
 Validation Loss: 1081.571520
 Learning Rate: 0.000000


Training Progress:  86%|████████▌ | 86/100 [1:35:42<15:03, 64.51s/it, train_loss=736.0862, val_loss=1072.7279]


Epoch 86/100:
 Train Loss: 736.086219 (Recon: 733.895326, KLD: 2190.893765)
 Validation Loss: 1072.727900
 Learning Rate: 0.000000

Learning rate adjusted from 0.000000 to 0.000000


Training Progress:  87%|████████▋ | 87/100 [1:36:45<13:53, 64.09s/it, train_loss=735.6711, val_loss=1080.0843]


Epoch 87/100:
 Train Loss: 735.671112 (Recon: 733.480292, KLD: 2190.819696)
 Validation Loss: 1080.084329
 Learning Rate: 0.000000


Training Progress:  88%|████████▊ | 88/100 [1:37:48<12:44, 63.70s/it, train_loss=733.3625, val_loss=1071.6924]


Epoch 88/100:
 Train Loss: 733.362480 (Recon: 731.171363, KLD: 2191.117211)
 Validation Loss: 1071.692382
 Learning Rate: 0.000000


Training Progress:  89%|████████▉ | 89/100 [1:38:51<11:37, 63.42s/it, train_loss=733.8557, val_loss=1087.1347]


Epoch 89/100:
 Train Loss: 733.855741 (Recon: 731.664512, KLD: 2191.229531)
 Validation Loss: 1087.134687
 Learning Rate: 0.000000


Training Progress:  89%|████████▉ | 89/100 [1:39:52<11:37, 63.42s/it, train_loss=735.8126, val_loss=1077.5064]


Epoch 90/100:
 Train Loss: 735.812601 (Recon: 733.623389, KLD: 2189.211017)
 Validation Loss: 1077.506389
 Learning Rate: 0.000000
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_90.pth
Sample images saved for epoch 90

Running evaluation at epoch 90...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:08<?, ?it/s, lpips=1.4060, psnr=23.96dB, ssim=0.9615][A
Evaluating mask:  17%|█▋        | 1/6 [00:08<00:41,  8.38s/it, lpips=1.4060, psnr=23.96dB, ssim=0.9615][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:08<00:41,  8.38s/it, lpips=1.4060, psnr=23.96dB, ssim=0.9615][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:14<00:41,  8.38s/it, lpips=1.3477, psnr=24.47dB, ssim=0.9648][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:14<00:28,  7.07s/it, lpips=1.3477, psnr=24.47dB, ssim=0.9648][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:14<00:28,  7.07s/it, lpips=1.3477, psnr=24.47dB, ssim=0.9648][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:20<00:28,  7.07s/it, lpips=1.3786, psnr=24.40dB, ssim=0.9637][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:20<00:20,  6.68s/it, lpips=1.3786, psnr=24.


Validation Metrics:
 mask: PSNR = 23.96dB, SSIM = 0.9615, LPIPS = 1.4060, FID = 114.35, Inference Time = 0.35ms
 gaussian: PSNR = 24.47dB, SSIM = 0.9648, LPIPS = 1.3477, FID = 115.13, Inference Time = 0.33ms
 salt_pepper: PSNR = 24.40dB, SSIM = 0.9637, LPIPS = 1.3786, FID = 115.85, Inference Time = 0.25ms
 blur: PSNR = 24.33dB, SSIM = 0.9629, LPIPS = 1.5830, FID = 120.29, Inference Time = 0.23ms
 jpeg: PSNR = 24.33dB, SSIM = 0.9638, LPIPS = 1.3650, FID = 117.27, Inference Time = 0.23ms
 combined: PSNR = 24.23dB, SSIM = 0.9623, LPIPS = 1.5622, FID = 119.57, Inference Time = 0.36ms


Training Progress:  91%|█████████ | 91/100 [1:41:37<10:43, 71.47s/it, train_loss=733.5858, val_loss=1077.4575]


Epoch 91/100:
 Train Loss: 733.585789 (Recon: 731.395100, KLD: 2190.688289)
 Validation Loss: 1077.457543
 Learning Rate: 0.000000


Training Progress:  92%|█████████▏| 92/100 [1:42:39<09:09, 68.75s/it, train_loss=735.1114, val_loss=1079.3673]


Epoch 92/100:
 Train Loss: 735.111386 (Recon: 732.921307, KLD: 2190.079028)
 Validation Loss: 1079.367345
 Learning Rate: 0.000000

Learning rate adjusted from 0.000000 to 0.000000


Training Progress:  93%|█████████▎| 93/100 [1:43:40<07:45, 66.51s/it, train_loss=735.9137, val_loss=1083.0168]


Epoch 93/100:
 Train Loss: 735.913732 (Recon: 733.724095, KLD: 2189.635968)
 Validation Loss: 1083.016845
 Learning Rate: 0.000000


Training Progress:  94%|█████████▍| 94/100 [1:44:42<06:31, 65.18s/it, train_loss=730.7899, val_loss=1075.2229]


Epoch 94/100:
 Train Loss: 730.789902 (Recon: 728.598984, KLD: 2190.917361)
 Validation Loss: 1075.222917
 Learning Rate: 0.000000


Training Progress:  95%|█████████▌| 95/100 [1:45:44<05:21, 64.21s/it, train_loss=732.7542, val_loss=1078.1843]


Epoch 95/100:
 Train Loss: 732.754175 (Recon: 730.563917, KLD: 2190.256867)
 Validation Loss: 1078.184287
 Learning Rate: 0.000000


Training Progress:  96%|█████████▌| 96/100 [1:46:48<04:15, 63.89s/it, train_loss=732.6911, val_loss=1074.7921]


Epoch 96/100:
 Train Loss: 732.691053 (Recon: 730.501127, KLD: 2189.925704)
 Validation Loss: 1074.792131
 Learning Rate: 0.000000


Training Progress:  97%|█████████▋| 97/100 [1:47:51<03:11, 63.81s/it, train_loss=731.7401, val_loss=1083.5732]


Epoch 97/100:
 Train Loss: 731.740122 (Recon: 729.549243, KLD: 2190.879344)
 Validation Loss: 1083.573236
 Learning Rate: 0.000000


Training Progress:  98%|█████████▊| 98/100 [1:48:55<02:07, 63.77s/it, train_loss=734.1329, val_loss=1085.2506]


Epoch 98/100:
 Train Loss: 734.132929 (Recon: 731.943395, KLD: 2189.533941)
 Validation Loss: 1085.250553
 Learning Rate: 0.000000

Learning rate adjusted from 0.000000 to 0.000000


Training Progress:  99%|█████████▉| 99/100 [1:49:58<01:03, 63.69s/it, train_loss=735.8425, val_loss=1082.1703]


Epoch 99/100:
 Train Loss: 735.842501 (Recon: 733.653220, KLD: 2189.281521)
 Validation Loss: 1082.170311
 Learning Rate: 0.000000


Training Progress:  99%|█████████▉| 99/100 [1:51:02<01:03, 63.69s/it, train_loss=733.7148, val_loss=1079.8387]


Epoch 100/100:
 Train Loss: 733.714842 (Recon: 731.524655, KLD: 2190.187880)
 Validation Loss: 1079.838712
 Learning Rate: 0.000000
Model checkpoint saved to /kaggle/working/checkpoints/vae_epoch_100.pth
Sample images saved for epoch 100

Running evaluation at epoch 100...
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:08<?, ?it/s, lpips=1.3552, psnr=24.17dB, ssim=0.9627][A
Evaluating mask:  17%|█▋        | 1/6 [00:08<00:43,  8.67s/it, lpips=1.3552, psnr=24.17dB, ssim=0.9627][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:08<00:43,  8.67s/it, lpips=1.3552, psnr=24.17dB, ssim=0.9627][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:15<00:43,  8.67s/it, lpips=1.3444, psnr=24.45dB, ssim=0.9648][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:15<00:29,  7.40s/it, lpips=1.3444, psnr=24.45dB, ssim=0.9648][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:15<00:29,  7.40s/it, lpips=1.3444, psnr=24.45dB, ssim=0.9648][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:21<00:29,  7.40s/it, lpips=1.3789, psnr=24.39dB, ssim=0.9638][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:21<00:20,  6.98s/it, lpips=1.3789, psnr=24.


Validation Metrics:
 mask: PSNR = 24.17dB, SSIM = 0.9627, LPIPS = 1.3552, FID = 113.22, Inference Time = 0.34ms
 gaussian: PSNR = 24.45dB, SSIM = 0.9648, LPIPS = 1.3444, FID = 114.78, Inference Time = 0.26ms
 salt_pepper: PSNR = 24.39dB, SSIM = 0.9638, LPIPS = 1.3789, FID = 115.82, Inference Time = 0.24ms
 blur: PSNR = 24.33dB, SSIM = 0.9632, LPIPS = 1.5770, FID = 119.22, Inference Time = 0.24ms
 jpeg: PSNR = 24.30dB, SSIM = 0.9639, LPIPS = 1.3605, FID = 116.62, Inference Time = 0.26ms
 combined: PSNR = 24.07dB, SSIM = 0.9622, LPIPS = 1.3700, FID = 115.89, Inference Time = 0.29ms
Training completed!


Evaluating final model...:  60%|██████    | 3/5 [1:51:48<1:42:05, 3062.96s/it]

Final model saved to /kaggle/working/vae_final_model.pth
For better results, install LPIPS with: pip install lpips
Inception model loaded for FID calculation



Evaluating corruption types:   0%|          | 0/6 [00:00<?, ?it/s][A
Evaluating mask:   0%|          | 0/6 [00:00<?, ?it/s]            [A
Evaluating mask:   0%|          | 0/6 [00:06<?, ?it/s, lpips=1.4007, psnr=24.07dB, ssim=0.9617][A
Evaluating mask:  17%|█▋        | 1/6 [00:06<00:33,  6.69s/it, lpips=1.4007, psnr=24.07dB, ssim=0.9617][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:06<00:33,  6.69s/it, lpips=1.4007, psnr=24.07dB, ssim=0.9617][A
Evaluating gaussian:  17%|█▋        | 1/6 [00:13<00:33,  6.69s/it, lpips=1.3409, psnr=24.45dB, ssim=0.9649][A
Evaluating gaussian:  33%|███▎      | 2/6 [00:13<00:26,  6.65s/it, lpips=1.3409, psnr=24.45dB, ssim=0.9649][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:13<00:26,  6.65s/it, lpips=1.3409, psnr=24.45dB, ssim=0.9649][A
Evaluating salt_pepper:  33%|███▎      | 2/6 [00:19<00:26,  6.65s/it, lpips=1.3851, psnr=24.40dB, ssim=0.9639][A
Evaluating salt_pepper:  50%|█████     | 3/6 [00:19<00:19,  6.62s/it, lpips=1.3851, psnr=24.


Final Metrics:
 mask: PSNR = 24.07dB, SSIM = 0.9617, LPIPS = 1.4006945230066776
 mask Inference Time: 0.36ms
 gaussian: PSNR = 24.45dB, SSIM = 0.9649, LPIPS = 1.34086924046278
 gaussian Inference Time: 0.28ms
 salt_pepper: PSNR = 24.40dB, SSIM = 0.9639, LPIPS = 1.385096449404955
 salt_pepper Inference Time: 0.26ms
 blur: PSNR = 24.33dB, SSIM = 0.9631, LPIPS = 1.576094999909401
 blur Inference Time: 0.25ms
 jpeg: PSNR = 24.30dB, SSIM = 0.9638, LPIPS = 1.3601593896746635
 jpeg Inference Time: 0.25ms
 combined: PSNR = 24.21dB, SSIM = 0.9617, LPIPS = 1.5392706990242004
 combined Inference Time: 0.24ms
Metrics saved to /kaggle/working/metrics/final_metrics.json
Final model saved to /kaggle/working/checkpoints/vae_final_model.pth

ClearVision Pipeline completed!





Final comparison image saved to /kaggle/working/final_results.png

Model Testing Results:
For better results, install LPIPS with: pip install lpips
Mask Corruption:
  PSNR: 24.13dB
  SSIM: 0.9735
  LPIPS: 1.4343
For better results, install LPIPS with: pip install lpips
Gaussian Corruption:
  PSNR: 24.42dB
  SSIM: 0.9676
  LPIPS: 1.2685
For better results, install LPIPS with: pip install lpips
Salt_pepper Corruption:
  PSNR: 24.63dB
  SSIM: 0.9596
  LPIPS: 1.6075
For better results, install LPIPS with: pip install lpips
Blur Corruption:
  PSNR: 24.33dB
  SSIM: 0.9628
  LPIPS: 2.1067
For better results, install LPIPS with: pip install lpips
Jpeg Corruption:
  PSNR: 24.03dB
  SSIM: 0.9590
  LPIPS: 1.4422
For better results, install LPIPS with: pip install lpips
Combined Corruption:
  PSNR: 22.78dB
  SSIM: 0.9537
  LPIPS: 1.9756

Test images saved to /kaggle/working/test_outputs

    ## ClearVision Image Restoration Complete
    
    The model has been trained to restore corrupted images u