# Unsupervised Raindrop Removal System with Day/Night Handling
This notebook implements and improves upon a paper for unsupervised raindrop removal.

## Library Explanations
- `torch`: PyTorch is an open-source machine learning library used for applications such as computer vision and natural language processing.
- `torch.nn`: A subpackage of PyTorch that provides modules and classes to build neural networks.
- `torch.optim`: A subpackage of PyTorch that contains optimization algorithms.
- `torch.utils.data`: A PyTorch package that provides utilities for data loading and processing.
- `torchvision`: A PyTorch package that provides tools for image processing, including pre-trained models and transformations.
- `lpips`: A library for computing the Learned Perceptual Image Patch Similarity metric, used for evaluating image quality.
- `numpy`: A library for numerical computations in Python.
- `skimage.metrics`: A module from the scikit-image library that provides functions to measure image quality, such as PSNR and SSIM.
- `os`: A standard library in Python for interacting with the operating system.
- `PIL`: The Python Imaging Library, used for opening, manipulating, and saving images.
- `tqdm`: A library for creating progress bars in Python.


In [None]:
from IPython.display import clear_output

In [None]:
# Install the LPIPS package into the environment of the running notebook kernel.
%pip install lpips
# On Kaggle, use the shell form instead: !pip install lpips
# Clear the installer log from the cell output.
clear_output()

In [None]:
"""
Unsupervised Raindrop Removal System with Day/Night Handling
Paper Implementation + Improvements
"""
import glob
import os

import lpips
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from torchvision.utils import save_image
from tqdm import tqdm

# --------------------- CONFIG ---------------------
The following configuration parameters are used in the implementation:
- `device`: Specifies whether to use a GPU (if available) or CPU for training.
- `batch_size`: The number of samples processed before the model is updated.
- `num_epochs`: The number of complete passes through the training dataset.
- `lr`: The learning rate for the optimizer.
- `lambda_cycle`: The weight for the cycle consistency loss.
- `lambda_percep`: The weight for the perceptual loss.
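- `lambda_blur`: The weight for the deblurring loss (the L1 distance between the generator's output for a blurred image and the corresponding clear image).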
- `val_ratio`: The ratio of the dataset to be used for validation.


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation schedule.
batch_size = 6
num_epochs = 50
lr = 1e-4

# Relative weights of the generator loss terms.
lambda_cycle = 10
lambda_percep = 5
lambda_blur = 2

# Fraction of each dataset held out for validation.
val_ratio = 0.1

# --------------------- DATASET ---------------------


In [None]:
class MultiDomainRaindropDataset(Dataset):
    """Triplets of clear / blurred / raindrop images read from one root folder.

    ``root_dir`` must contain the sub-folders ``Clear``, ``Blur`` and ``Drop``.
    Each sub-folder is searched recursively and its image paths are sorted.
    Sample ``idx`` takes entry ``idx`` of every list, wrapping around any list
    that is shorter than the longest one.

    Args:
        root_dir: Folder holding the ``Clear``, ``Blur`` and ``Drop`` sub-folders.
        mode: Stored on the instance; it does not currently change the
            transform pipeline.
        test_split_ratio: Accepted for backward compatibility; unused.

    Raises:
        FileNotFoundError: If any of the three sub-folders contains no image.
    """

    # Only files with these (case-insensitive) suffixes are treated as images.
    # The previous '*.*' pattern also matched metadata files and directories
    # whose name contains a dot, which shifted the sorted lists against each
    # other and broke the index alignment between the three domains.
    IMAGE_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png', '.ppm', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, mode='train', test_split_ratio=0.1):
        super().__init__()
        self.mode = mode
        self.clear_path = os.path.join(root_dir, 'Clear')
        self.blur_path = os.path.join(root_dir, 'Blur')
        self.drop_path = os.path.join(root_dir, 'Drop')

        self.clear_images = self._list_images(self.clear_path)
        self.blur_images = self._list_images(self.blur_path)
        self.drop_images = self._list_images(self.drop_path)

        # Fail early with a clear message: an empty list would otherwise
        # surface later as a modulo-by-zero inside __getitem__.
        for folder, images in ((self.clear_path, self.clear_images),
                               (self.blur_path, self.blur_images),
                               (self.drop_path, self.drop_images)):
            if not images:
                raise FileNotFoundError(f"No image files found under '{folder}'")

        # Images are resized to 256x256 and converted to tensors; no further
        # normalisation is applied.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image file paths found recursively under ``directory``."""
        pattern = os.path.join(directory, '**', '*.*')
        return sorted(
            path for path in glob.glob(pattern, recursive=True)
            if os.path.isfile(path) and path.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def _load_image(self, path):
        """Open ``path``, convert it to RGB and apply the transform pipeline."""
        # The context manager releases the file handle as soon as the pixels
        # have been read.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __len__(self):
        """Return the length of the longest of the three image lists."""
        return max(len(self.clear_images), len(self.blur_images), len(self.drop_images))

    def __getitem__(self, idx):
        """Return ``{'clear', 'blur', 'drop'}`` tensors for sample ``idx``.

        If a sample cannot be loaded, the error is printed and the following
        index is tried instead. At most ``len(self)`` samples are attempted.

        Raises:
            RuntimeError: If no sample could be loaded at all.
        """
        total = len(self)
        last_error = None
        for offset in range(total):
            # The first attempt uses idx unchanged; fallbacks wrap around.
            candidate = idx if offset == 0 else (idx + offset) % total
            try:
                clear = self._load_image(self.clear_images[candidate % len(self.clear_images)])
                blur = self._load_image(self.blur_images[candidate % len(self.blur_images)])
                drop = self._load_image(self.drop_images[candidate % len(self.drop_images)])
            except Exception as e:
                # Best effort by design: skip the unreadable sample, but keep
                # the number of retries bounded instead of recursing forever.
                print(f"Error loading image: {e}")
                last_error = e
                continue
            return {'clear': clear, 'blur': blur, 'drop': drop}

        raise RuntimeError(
            f"Could not load any sample starting from index {idx}"
        ) from last_error

# --------------------- MODELS ---------------------


In [None]:
class DomainAwareGenerator(nn.Module):
    """Encoder-decoder generator conditioned on a domain label.

    Four stride-2 convolutions encode the input. The bottleneck is scaled
    channel-wise by a learned embedding of the domain label (0: day,
    1: night) and gated by a sigmoid attention map. Four stride-2 transposed
    convolutions decode it again, with additive skip connections from the
    first three encoder stages. The output goes through ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # Encoder shared by both domains.
        self.enc_conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.enc_conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)

        # One 512-d scaling vector per domain (0: day, 1: night).
        self.domain_embedding = nn.Embedding(num_embeddings=2, embedding_dim=512)

        # Decoder mirroring the encoder.
        self.dec_conv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec_conv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec_conv4 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

        # Per-element gate in (0, 1) computed from the bottleneck itself.
        self.attn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1), nn.Sigmoid())

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x, domain_label):
        """Map image batch ``x`` to an output of the same spatial size.

        Args:
            x: Image batch with 3 channels.
            domain_label: Integer tensor of domain ids, one per batch element.

        Returns:
            A 3-channel tensor with values in [-1, 1].
        """
        # Run the encoder, remembering every stage for the skip connections.
        stages = []
        feat = x
        for conv in (self.enc_conv1, self.enc_conv2, self.enc_conv3, self.enc_conv4):
            feat = self.leaky_relu(conv(feat))
            stages.append(feat)
        skip1, skip2, skip3, bottleneck = stages

        # Broadcast the domain vector over the two spatial dimensions.
        scale = self.domain_embedding(domain_label)[..., None, None]
        conditioned = bottleneck * scale

        # Self-gating attention.
        gated = conditioned * self.attn(conditioned)

        # Decoder with additive skips.
        up = self.relu(self.dec_conv1(gated) + skip3)
        up = self.relu(self.dec_conv2(up) + skip2)
        up = self.relu(self.dec_conv3(up) + skip1)
        return self.tanh(self.dec_conv4(up))

In [None]:
class MultiScaleDiscriminator(nn.Module):
    """Patch-style convolutional discriminator.

    Five 4x4 convolutions (three with stride 2, two with stride 1) map a
    3-channel image to a 1-channel score map. Instance normalisation follows
    the second, third and fourth convolution; LeakyReLU(0.2) follows the
    first four. Despite the class name, ``forward`` evaluates the input at a
    single scale.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv4 = nn.Conv2d(256, 512, 4, 1, 1)
        self.conv5 = nn.Conv2d(512, 1, 4, 1, 1)

        # Build the normalisation layers once and register them as
        # sub-modules instead of constructing new ones on every forward call.
        # With the default settings they hold no parameters or buffers, so
        # the state_dict keys are unchanged and old checkpoints still load.
        self.norm2 = nn.InstanceNorm2d(128)
        self.norm3 = nn.InstanceNorm2d(256)
        self.norm4 = nn.InstanceNorm2d(512)

        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return the raw (un-activated) score map for image batch ``x``."""
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.norm2(self.conv2(x)))
        x = self.leaky_relu(self.norm3(self.conv3(x)))
        x = self.leaky_relu(self.norm4(self.conv4(x)))
        return self.conv5(x)

# --------------------- SYSTEMS ---------------------


In [None]:
class RaindropRemovalSystem:
    """Bundle the generator, discriminators, optimizers and losses for training.

    A single ``DomainAwareGenerator`` serves both domains (label 0: day,
    label 1: night). ``D_clear`` is trained on clear images versus restored
    images, ``D_blur`` on blurred images versus deblurred images.
    """

    def __init__(self):
        self.G = DomainAwareGenerator().to(device)  # single generator for both domains
        self.D_clear = MultiScaleDiscriminator().to(device)
        self.D_blur = MultiScaleDiscriminator().to(device)

        self.optim_G = optim.Adam(self.G.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optim_D = optim.Adam(
            list(self.D_clear.parameters()) + list(self.D_blur.parameters()),
            lr=lr, betas=(0.5, 0.999)
        )

        self.vgg = self._build_vgg_features()
        self.lpips = lpips.LPIPS(net='vgg').to(device)

        self.criterion_gan = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_percep = nn.L1Loss()

        self.best_metrics = {'day': {'PSNR': 0, 'SSIM': 0}, 'night': {'PSNR': 0, 'SSIM': 0}}

    @staticmethod
    def _build_vgg_features():
        """Return the frozen first 16 feature layers of an ImageNet VGG-19."""
        # Prefer the weights enum where torchvision provides it; fall back to
        # the legacy ``pretrained`` flag on older releases.
        if hasattr(models, 'VGG19_Weights'):
            backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        features = backbone.features[:16].to(device).eval()
        # The network is a fixed feature extractor, so no gradients are
        # needed for its weights.
        for param in features.parameters():
            param.requires_grad_(False)
        return features

    def _gan_loss(self, prediction, is_real):
        """Least-squares GAN loss of ``prediction`` against an all-1 or all-0 target."""
        if is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.criterion_gan(prediction, target)

    def compute_metrics(self, pred, target):
        """Return batch-averaged PSNR, SSIM and LPIPS as plain Python floats.

        Args:
            pred: Generator output batch.
            target: Reference batch, expected in [0, 1].

        Returns:
            Dict with the keys ``'PSNR'``, ``'SSIM'`` and ``'LPIPS'``.
        """
        # The generator ends in tanh, so its output can leave [0, 1]; clamp it
        # to the range that ``data_range=1.0`` below assumes.
        pred = pred.detach().clamp(0.0, 1.0)
        target = target.detach()

        pred_np = pred.permute(0, 2, 3, 1).cpu().numpy()
        target_np = target.permute(0, 2, 3, 1).cpu().numpy()

        psnr_val = np.mean([psnr(t, p, data_range=1.0) for p, t in zip(pred_np, target_np)])
        ssim_val = np.mean([ssim(t, p, channel_axis=-1, data_range=1.0) for p, t in zip(pred_np, target_np)])

        # LPIPS works on [-1, 1] inputs; normalize=True rescales [0, 1] images.
        with torch.no_grad():
            lpips_val = self.lpips(pred, target, normalize=True).mean().item()

        # Plain floats keep numpy scalar objects out of saved checkpoints.
        return {'PSNR': float(psnr_val), 'SSIM': float(ssim_val), 'LPIPS': float(lpips_val)}

    def train_step(self, batch_day, batch_night):
        """Run one generator update followed by one discriminator update.

        Args:
            batch_day: Dict with ``'clear'``, ``'blur'`` and ``'drop'`` tensors
                from the day dataset.
            batch_night: Same structure, from the night dataset.

        Returns:
            Dict of scalar losses: ``'g_total'``, ``'d_total'``,
            ``'cycle_day'`` and ``'percep_day'``.
        """
        # Prepare data
        clear_day = batch_day['clear'].to(device)
        blur_day = batch_day['blur'].to(device)
        drop_day = batch_day['drop'].to(device)

        clear_night = batch_night['clear'].to(device)
        blur_night = batch_night['blur'].to(device)
        drop_night = batch_night['drop'].to(device)

        # Domain labels (0: day, 1: night)
        day_labels = torch.zeros(drop_day.size(0), dtype=torch.long, device=device)
        night_labels = torch.ones(drop_night.size(0), dtype=torch.long, device=device)

        # ---------------- Generator update ----------------
        self.optim_G.zero_grad()

        # Raindrop removal
        restored_day = self.G(drop_day, day_labels)
        restored_night = self.G(drop_night, night_labels)

        # Second pass through the same generator, compared with the raindrop
        # input below.
        cycle_day = self.G(restored_day, day_labels)
        cycle_night = self.G(restored_night, night_labels)

        # Deblurring
        deblur_day = self.G(blur_day, day_labels)
        deblur_night = self.G(blur_night, night_labels)

        # Adversarial terms: each discriminator is evaluated once per image.
        g_gan_day = self._gan_loss(self.D_clear(restored_day), True)
        g_gan_night = self._gan_loss(self.D_clear(restored_night), True)

        # NOTE(review): D_blur is trained below with *blurred* images as the
        # real class, so this term rewards deblurred outputs that D_blur
        # takes for blurred images - confirm this is intended.
        g_deblur_gan_day = self._gan_loss(self.D_blur(deblur_day), True)
        g_deblur_gan_night = self._gan_loss(self.D_blur(deblur_night), True)

        # Cycle loss
        cycle_loss_day = self.criterion_cycle(cycle_day, drop_day)
        cycle_loss_night = self.criterion_cycle(cycle_night, drop_night)

        # Perceptual loss; the targets are constants, so skip autograd there.
        with torch.no_grad():
            clear_feat_day = self.vgg(clear_day)
            clear_feat_night = self.vgg(clear_night)
        percep_loss_day = self.criterion_percep(self.vgg(restored_day), clear_feat_day)
        percep_loss_night = self.criterion_percep(self.vgg(restored_night), clear_feat_night)

        # Deblur loss
        deblur_loss_day = self.criterion_cycle(deblur_day, clear_day)
        deblur_loss_night = self.criterion_cycle(deblur_night, clear_night)

        # Total generator loss
        g_total = (
            g_gan_day + g_gan_night + g_deblur_gan_day + g_deblur_gan_night +
            lambda_cycle * (cycle_loss_day + cycle_loss_night) +
            lambda_percep * (percep_loss_day + percep_loss_night) +
            lambda_blur * (deblur_loss_day + deblur_loss_night)
        )

        g_total.backward()
        self.optim_G.step()

        # ---------------- Discriminator update ----------------
        # zero_grad also discards the gradients the generator step left on
        # the discriminator weights.
        self.optim_D.zero_grad()

        # D_clear: clear images are real, restored images are fake.
        d_real_clear = (self._gan_loss(self.D_clear(clear_day), True) +
                        self._gan_loss(self.D_clear(clear_night), True)) / 2
        d_fake_clear = (self._gan_loss(self.D_clear(restored_day.detach()), False) +
                        self._gan_loss(self.D_clear(restored_night.detach()), False)) / 2
        d_clear_loss = 0.5 * (d_real_clear + d_fake_clear)

        # D_blur: blurred images are real, deblurred images are fake.
        d_real_blur = (self._gan_loss(self.D_blur(blur_day), True) +
                       self._gan_loss(self.D_blur(blur_night), True)) / 2
        d_fake_blur = (self._gan_loss(self.D_blur(deblur_day.detach()), False) +
                       self._gan_loss(self.D_blur(deblur_night.detach()), False)) / 2
        d_blur_loss = 0.5 * (d_real_blur + d_fake_blur)

        # Total discriminator loss
        d_total = d_clear_loss + d_blur_loss
        d_total.backward()
        self.optim_D.step()

        return {
            'g_total': g_total.item(),
            'd_total': d_total.item(),
            'cycle_day': cycle_loss_day.item(),
            'percep_day': percep_loss_day.item()
        }

    def validate(self, val_loader, domain):
        """Evaluate the generator on ``val_loader`` and return per-image averages.

        Args:
            val_loader: Iterable of batches with ``'drop'`` and ``'clear'`` tensors.
            domain: ``'day'`` selects label 0; any other value selects label 1.

        Returns:
            Dict with average ``'PSNR'``, ``'SSIM'`` and ``'LPIPS'``.

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.G.eval()

        label_value = 0 if domain == 'day' else 1
        totals = {'PSNR': 0.0, 'SSIM': 0.0, 'LPIPS': 0.0}
        n_samples = 0

        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['drop'].to(device)
                targets = batch['clear'].to(device)
                n_batch = inputs.size(0)

                labels = torch.full((n_batch,), label_value, dtype=torch.long, device=device)
                outputs = self.G(inputs, labels)

                # Weight every batch by its size so that a short final batch
                # does not count as much as a full one.
                metrics = self.compute_metrics(outputs, targets)
                for key in totals:
                    totals[key] += metrics[key] * n_batch
                n_samples += n_batch

        if n_samples == 0:
            raise ValueError("validate() received an empty validation loader")

        return {key: value / n_samples for key, value in totals.items()}

    def save_checkpoint(self, path, epoch, is_best=False):
        """Write model, optimizer and metric state to ``path``.

        When ``is_best`` is true the same state is also written to
        ``best_model.pth`` in the current working directory.
        """
        state = {
            'G': self.G.state_dict(),
            'D_clear': self.D_clear.state_dict(),
            'D_blur': self.D_blur.state_dict(),
            'optim_G': self.optim_G.state_dict(),
            'optim_D': self.optim_D.state_dict(),
            'epoch': epoch,
            # Store plain floats so the file holds only tensors and built-in
            # Python types.
            'best_metrics': {
                domain: {name: float(value) for name, value in metrics.items()}
                for domain, metrics in self.best_metrics.items()
            }
        }
        torch.save(state, path)
        if is_best:
            torch.save(state, "best_model.pth")

    def load_checkpoint(self, path):
        """Restore the state written by ``save_checkpoint``; return its epoch."""
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(path, map_location=device)
        self.G.load_state_dict(checkpoint['G'])
        self.D_clear.load_state_dict(checkpoint['D_clear'])
        self.D_blur.load_state_dict(checkpoint['D_blur'])
        self.optim_G.load_state_dict(checkpoint['optim_G'])
        self.optim_D.load_state_dict(checkpoint['optim_D'])
        self.best_metrics = checkpoint['best_metrics']
        return checkpoint['epoch']

# --------------------- TRAINING SETUP ---------------------


In [None]:
def main():
    """Train on the day and night datasets and validate after every epoch.

    Each dataset is split into a training and a validation part according to
    ``val_ratio``. A checkpoint is written whenever a domain's validation
    PSNR improves, and additionally every tenth epoch.
    """

    def build_loaders(root):
        """Return ``(train_loader, val_loader)`` for the dataset under ``root``."""
        dataset = MultiDomainRaindropDataset(root)
        n_train = int((1 - val_ratio) * len(dataset))
        n_val = len(dataset) - n_train
        train_part, val_part = random_split(dataset, [n_train, n_val])
        train_loader = DataLoader(train_part, batch_size, shuffle=True, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_part, batch_size, shuffle=False, num_workers=2, pin_memory=True)
        return train_loader, val_loader

    day_train_loader, day_val_loader = build_loaders(
        '/kaggle/input/raindrop-daynight/DayRainDrop_Train/DayRainDrop_Train'
    )
    night_train_loader, night_val_loader = build_loaders(
        '/kaggle/input/raindrop-daynight/NightRainDrop_Train/NightRainDrop_Train'
    )

    system = RaindropRemovalSystem()
    start_epoch = 0

    def keep_if_best(domain, metrics, filename, epoch):
        """Record ``metrics`` and save a checkpoint when the domain's PSNR improved."""
        if metrics['PSNR'] > system.best_metrics[domain]['PSNR']:
            system.best_metrics[domain] = metrics
            system.save_checkpoint(filename, epoch, is_best=True)

    for epoch in range(start_epoch, num_epochs):
        # validate() leaves the generator in eval mode, so switch back here.
        system.G.train()

        # zip stops at the shorter loader; report that length to tqdm.
        steps = min(len(day_train_loader), len(night_train_loader))
        progress = tqdm(zip(day_train_loader, night_train_loader),
                        total=steps, desc=f"Epoch {epoch}")

        for batch_day, batch_night in progress:
            losses = system.train_step(batch_day, batch_night)
            progress.set_description(
                f"Epoch {epoch} | G: {losses['g_total']:.3f} | D: {losses['d_total']:.3f} | "
                f"Cycle: {losses['cycle_day']:.3f}"
            )

        # Validate each domain with its own label.
        day_metrics = system.validate(day_val_loader, 'day')
        night_metrics = system.validate(night_val_loader, 'night')

        # Keep the best model per domain, judged by PSNR.
        keep_if_best('day', day_metrics, f"best_day_model_epoch{epoch}.pth", epoch)
        keep_if_best('night', night_metrics, f"best_night_model_epoch{epoch}.pth", epoch)

        print(f"\nValidation @ Epoch {epoch}:")
        print(f"[Day] PSNR: {day_metrics['PSNR']:.2f} | SSIM: {day_metrics['SSIM']:.4f} | LPIPS: {day_metrics['LPIPS']:.4f}")
        print(f"[Night] PSNR: {night_metrics['PSNR']:.2f} | SSIM: {night_metrics['SSIM']:.4f} | LPIPS: {night_metrics['LPIPS']:.4f}")

        # Periodic checkpoint on epochs 0, 10, 20, ...
        if not epoch % 10:
            system.save_checkpoint(f"checkpoint_epoch{epoch}.pth", epoch)

In [None]:
# Start training only when this file is executed as the main module.
if __name__ == "__main__":
    main()