# In[ ]:
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import itertools

# Source dataset: every sub-folder of these directories is expected to hold a
# `Clear.jpg` and a `rain_0.jpg` (see clear_name / rain_name below).
trainset_dir = "/kaggle/input/syrahq/SyRa-HQ/trainset"
testset_dir = "/kaggle/input/syrahq/SyRa-HQ/testset"

# Output layout read back by the training code: A = clear images, B = rainy images.
target_train_A = 'derain/trainA'
target_train_B = 'derain/trainB'
target_test_A = 'derain/testA'
target_test_B = 'derain/testB'

# Number of images exported per domain (so twice this many folders are used).
max_train = 500
max_test = 50

clear_name = 'Clear.jpg'
rain_name = 'rain_0.jpg'


def _export_unpaired_split(src_dir, dst_A, dst_B, per_domain):
    """Copy a clear/rain split with disjoint source folders into dst_A / dst_B.

    The first `per_domain` sub-folders of `src_dir` (in sorted order, so the
    split is reproducible across runs) contribute their clear image to
    `dst_A`; the next `per_domain` sub-folders contribute their rain image to
    `dst_B`. Because the two domains come from different folders, the
    exported images are unpaired. Each copy is named by its running index
    (`0.jpg`, `1.jpg`, ...; domain-B names therefore start at `per_domain`).

    Non-directory entries of `src_dir` are ignored.
    """
    # os.listdir order is arbitrary; sorting makes the selection deterministic.
    folders = sorted(
        name for name in os.listdir(src_dir)
        if os.path.isdir(os.path.join(src_dir, name))
    )[:per_domain * 2]
    for i, folder_name in enumerate(folders):
        if i < per_domain:
            src_file, dst_dir = clear_name, dst_A
        else:
            src_file, dst_dir = rain_name, dst_B
        shutil.copy(os.path.join(src_dir, folder_name, src_file),
                    os.path.join(dst_dir, f"{i}.jpg"))


for _target_dir in (target_train_A, target_train_B, target_test_A, target_test_B):
    os.makedirs(_target_dir, exist_ok=True)

_export_unpaired_split(trainset_dir, target_train_A, target_train_B, max_train)
_export_unpaired_split(testset_dir, target_test_A, target_test_B, max_test)


# In[ ]:
# Select the compute device once; models and batches below are moved onto it.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"Using device: {device}")

# Dataset class
class ImageDataset(Dataset):
    """Two-domain image dataset yielding ``{'A': img_A, 'B': img_B}`` dicts.

    Index ``i`` maps to the ``i % len(files_A)``-th file of domain A and the
    ``i % len(files_B)``-th file of domain B (both sorted by path), so the
    shorter domain is cycled and the dataset length is that of the longer
    one. Images are loaded as RGB and passed through ``transform`` if given.

    Args:
        root_A: directory containing domain-A images.
        root_B: directory containing domain-B images.
        transform: optional callable applied to each PIL image.
        mode: kept as ``self.mode`` for callers; not otherwise used here.

    Raises:
        ValueError: if either directory contains no regular files.
    """

    def __init__(self, root_A, root_B, transform=None, mode='train'):
        self.transform = transform
        self.mode = mode
        self.files_A = self._list_files(root_A)
        self.files_B = self._list_files(root_B)
        # Fail early with a clear message; an empty domain would otherwise
        # surface later as a ZeroDivisionError from the modulo in __getitem__.
        for root, files in ((root_A, self.files_A), (root_B, self.files_B)):
            if not files:
                raise ValueError(f"No image files found in {root!r}")

    @staticmethod
    def _list_files(root):
        """Return the sorted paths of regular files directly inside `root`."""
        # Sub-directories (e.g. a stray checkpoint folder) cannot be opened
        # as images, so they are skipped.
        paths = (os.path.join(root, name) for name in os.listdir(root))
        return sorted(path for path in paths if os.path.isfile(path))

    @staticmethod
    def _load(path):
        """Open `path`, return an RGB copy, and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        img_A = self._load(self.files_A[index % len(self.files_A)])
        img_B = self._load(self.files_B[index % len(self.files_B)])

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

# Generator Network (U-Net with skip connections)
class Generator(nn.Module):
    """U-Net generator with four encoder stages, a bottleneck and four decoder stages.

    Each decoder stage consumes the previous output concatenated (on the
    channel axis) with the mirrored encoder feature map. The spatial size is
    halved five times (four encoders plus the bottleneck conv) and restored
    by five stride-2 transposed convolutions, so inputs whose height and
    width are multiples of 32 come back at the same size. The final Tanh
    bounds the output to [-1, 1].

    Submodule names (``encoder1``..``decoder1``, ``bottleneck``) and their
    internal layer order are what the saved ``state_dict`` keys refer to.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super().__init__()

        def down(c_in, c_out, norm=True):
            # Conv(4, stride 2) [-> BatchNorm] -> LeakyReLU(0.2)
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        def up(c_in, c_out):
            # ConvTranspose(4, stride 2) -> BatchNorm -> ReLU
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # Encoder: widths ngf, 2ngf, 4ngf, 8ngf; the first stage is un-normalized.
        self.encoder1 = down(input_nc, ngf, norm=False)
        self.encoder2 = down(ngf, ngf * 2)
        self.encoder3 = down(ngf * 2, ngf * 4)
        self.encoder4 = down(ngf * 4, ngf * 8)

        # Bottleneck: one extra downsample, then straight back up.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
        )

        # Decoder: every input width is doubled by the skip concatenation.
        self.decoder4 = up(ngf * 16, ngf * 4)
        self.decoder3 = up(ngf * 8, ngf * 2)
        self.decoder2 = up(ngf * 4, ngf)
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        skips = []
        h = x
        for stage in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            h = stage(h)
            skips.append(h)

        h = self.bottleneck(h)

        # Deepest decoder pairs with the deepest encoder output (e4), and so on.
        decoders = (self.decoder4, self.decoder3, self.decoder2, self.decoder1)
        for stage, skip in zip(decoders, reversed(skips)):
            h = stage(torch.cat([h, skip], 1))
        return h

# PatchGAN Discriminator Network
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Returns a map of unnormalized real/fake scores, one per receptive-field
    patch, instead of a single scalar. There is no final sigmoid, which fits
    an MSE (LSGAN) objective. For a 256x256 input with the default widths,
    the map is 30x30.

    Args:
        input_nc: number of input image channels.
        ndf: channel width of the first conv; doubled per stage up to ``ndf * 8``.
        norm_type: ``'batchnorm'`` or ``'instancenorm'`` (case-insensitive).

    Raises:
        ValueError: for any other ``norm_type``.
    """

    _NORMS = {'batchnorm': nn.BatchNorm2d, 'instancenorm': nn.InstanceNorm2d}

    def __init__(self, input_nc=3, ndf=64, norm_type='batchnorm'):
        super().__init__()

        key = norm_type.lower()
        if key not in self._NORMS:
            raise ValueError(
                f"norm_type must be 'batchnorm' or 'instancenorm', got {norm_type!r}"
            )
        norm = self._NORMS[key]

        # First layer - no normalization
        self.down1 = nn.Sequential(
            nn.Conv2d(input_nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Second and third layers: stride-2 conv -> norm -> LeakyReLU
        self.down2 = nn.Sequential(
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            norm(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            norm(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Fourth layer: explicit zero padding followed by a stride-1 conv
        self.zero_pad1 = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(ndf * 4, ndf * 8, 4, 1, 0, bias=False)
        self.norm1 = norm(ndf * 8)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

        # Final layer: one score channel per patch
        self.zero_pad2 = nn.ZeroPad2d(1)
        self.last = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """DCGAN-style init: conv ~ N(0, 0.02); norm scale ~ N(1, 0.02), shift 0."""
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            # InstanceNorm2d defaults to affine=False and then has no
            # weight/bias, so only initialize the parameters that exist.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Shape comments assume a 256x256 input and ndf=64.
        x = self.down1(x)        # (bs, 64, 128, 128)
        x = self.down2(x)        # (bs, 128, 64, 64)
        x = self.down3(x)        # (bs, 256, 32, 32)

        x = self.zero_pad1(x)    # (bs, 256, 34, 34)
        x = self.conv(x)         # (bs, 512, 31, 31)
        x = self.norm1(x)
        x = self.leaky_relu(x)

        x = self.zero_pad2(x)    # (bs, 512, 33, 33)
        x = self.last(x)         # (bs, 1, 30, 30)

        return x  # per-patch scores

# Replay Buffer for storing generated images
class ReplayBuffer:
    """Pool of previously generated images for discriminator training.

    The first ``max_size`` images pushed are stored and returned unchanged.
    After that, each incoming image is either returned as-is (probability
    1/2) or swapped with a uniformly chosen stored image, which is returned
    instead (probability 1/2). The discriminator therefore sees a mix of
    current and older fakes.

    Args:
        max_size: capacity of the pool; must be positive.

    Raises:
        ValueError: if ``max_size`` is not positive.
    """

    def __init__(self, max_size=50):
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch ``data`` (leading batch dim) and return a same-sized batch.

        Returned elements are built from ``data.data``, so they carry no
        autograd history.
        """
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif np.random.uniform(0, 1) > 0.5:
                # np.random.randint excludes its upper bound, so pass max_size
                # to make every slot (including the last) eligible.
                i = np.random.randint(0, self.max_size)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

# Training function
def train_cyclegan():
    """Train a CycleGAN between clear (domain A) and rainy (domain B) images.

    Reads images from ``derain/trainA``/``derain/trainB`` (train) and
    ``derain/testA``/``derain/testB`` (samples). It trains two U-Net
    generators (G_AB: clear->rain, G_BA: rain->clear) and two PatchGAN
    discriminators with an MSE (LSGAN) adversarial loss, plus L1
    cycle-consistency and L1 identity losses.

    Side effects:
        - prints per-epoch losses;
        - saves sample grids every ``SAMPLE_EVERY`` epochs;
        - writes ``G_AB.pth``, ``G_BA.pth``, ``D_A.pth`` and ``D_B.pth``;
        - writes ``comprehensive_training_losses.png``.

    Returns:
        (G_AB, G_BA): the trained generators, still on ``device``.

    Raises:
        RuntimeError: if the training loader yields no batches.
    """
    # Hyperparameters
    IMG_SIZE = 256
    BATCH_SIZE = 1
    LEARNING_RATE = 0.0002
    NUM_EPOCHS = 20
    LAMBDA_CYCLE = 10.0
    LAMBDA_IDENTITY = 0.5
    SAMPLE_EVERY = 10  # epochs between sample-image dumps

    # Data transforms: resize, then map pixels to [-1, 1] to match the Tanh output.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Create datasets and dataloaders
    train_dataset = ImageDataset('derain/trainA', 'derain/trainB', transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    num_batches = len(train_loader)
    if num_batches == 0:
        raise RuntimeError("Training set is empty; check derain/trainA and derain/trainB")

    test_dataset = ImageDataset('derain/testA', 'derain/testB', transform)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Initialize networks
    G_AB = Generator().to(device)  # Clear to Rain
    G_BA = Generator().to(device)  # Rain to Clear
    D_A = Discriminator(norm_type='batchnorm').to(device)  # Discriminator for Clear images
    D_B = Discriminator(norm_type='batchnorm').to(device)  # Discriminator for Rain images

    # Initialize weights for generators only (discriminators handle their own)
    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

    G_AB.apply(weights_init)
    G_BA.apply(weights_init)

    # Loss functions
    criterion_GAN = nn.MSELoss()  # LSGAN loss for PatchGAN
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # Optimizers
    optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                             lr=LEARNING_RATE, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    # Replay buffers
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Per-epoch mean of every tracked loss. Component entries are stored
    # already multiplied by their lambda, so they sum to the 'G' entry.
    loss_names = ('G', 'D_A', 'D_B', 'identity', 'gan_AB', 'gan_BA', 'cycle_A', 'cycle_B')
    history = {name: [] for name in loss_names}

    print("Starting training...")
    for epoch in range(NUM_EPOCHS):
        running = dict.fromkeys(loss_names, 0.0)

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            # ============ Train Generators ============
            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image already in its target
            # domain should leave it unchanged.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss. Targets take the discriminator output's shape, so
            # they stay valid if IMG_SIZE or the discriminator changes.
            fake_B = G_AB(real_A)
            pred_fake_B = D_B(fake_B)
            loss_GAN_AB = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

            fake_A = G_BA(real_B)
            pred_fake_A = D_A(fake_A)
            loss_GAN_BA = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

            # Cycle loss
            recovered_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recovered_A, real_A)

            recovered_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recovered_B, real_B)

            # Total generator loss
            loss_G = (loss_GAN_AB + loss_GAN_BA
                      + LAMBDA_CYCLE * (loss_cycle_A + loss_cycle_B)
                      + LAMBDA_IDENTITY * loss_identity)
            loss_G.backward()
            optimizer_G.step()

            # ============ Train Discriminators ============
            # (the zero_grad inside each step also clears the gradients that
            # the generator pass above accumulated in the discriminators)
            loss_D_A = _discriminator_step(D_A, optimizer_D_A, criterion_GAN, real_A,
                                           fake_A_buffer.push_and_pop(fake_A))
            loss_D_B = _discriminator_step(D_B, optimizer_D_B, criterion_GAN, real_B,
                                           fake_B_buffer.push_and_pop(fake_B))

            step_losses = {
                'G': loss_G.item(),
                'D_A': loss_D_A.item(),
                'D_B': loss_D_B.item(),
                'identity': LAMBDA_IDENTITY * loss_identity.item(),
                'gan_AB': loss_GAN_AB.item(),
                'gan_BA': loss_GAN_BA.item(),
                'cycle_A': LAMBDA_CYCLE * loss_cycle_A.item(),
                'cycle_B': LAMBDA_CYCLE * loss_cycle_B.item(),
            }
            for name, value in step_losses.items():
                running[name] += value

        for name in loss_names:
            history[name].append(running[name] / num_batches)

        # Print losses every epoch
        _print_epoch_losses(history, epoch, NUM_EPOCHS)

        if (epoch + 1) % SAMPLE_EVERY == 0:
            save_sample_images(G_AB, G_BA, test_loader, epoch + 1, device)

    # Save models
    torch.save(G_AB.state_dict(), 'G_AB.pth')
    torch.save(G_BA.state_dict(), 'G_BA.pth')
    torch.save(D_A.state_dict(), 'D_A.pth')
    torch.save(D_B.state_dict(), 'D_B.pth')

    _plot_training_losses(history, LAMBDA_CYCLE, LAMBDA_IDENTITY)
    _print_training_summary(history)

    return G_AB, G_BA


def _discriminator_step(D, optimizer, criterion, real, fake):
    """Run one LSGAN update of discriminator `D` and return its loss tensor.

    `fake` is detached, so no gradient reaches the generator that produced it.
    """
    optimizer.zero_grad()

    pred_real = D(real)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))

    pred_fake = D(fake.detach())
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

    loss = (loss_real + loss_fake) / 2
    loss.backward()
    optimizer.step()
    return loss


def _print_epoch_losses(history, epoch, num_epochs):
    """Print the latest epoch's mean losses as an indented tree."""
    print(f"Epoch [{epoch+1}/{num_epochs}]:")
    print(f"  Total G_loss: {history['G'][-1]:.4f}")
    print(f"    ├─ Identity: {history['identity'][-1]:.4f}")
    print(f"    ├─ GAN_AB: {history['gan_AB'][-1]:.4f}")
    print(f"    ├─ GAN_BA: {history['gan_BA'][-1]:.4f}")
    print(f"    ├─ Cycle_A: {history['cycle_A'][-1]:.4f}")
    print(f"    └─ Cycle_B: {history['cycle_B'][-1]:.4f}")
    print(f"  D_A_loss: {history['D_A'][-1]:.4f}")
    print(f"  D_B_loss: {history['D_B'][-1]:.4f}")
    print("-" * 50)


def _plot_training_losses(history, lambda_cycle, lambda_identity):
    """Draw a 2x3 grid of per-epoch loss curves and save it to
    ``comprehensive_training_losses.png``."""
    plt.figure(figsize=(20, 12))

    def panel(index, title, curves):
        # curves: sequence of (values, plt.plot keyword arguments) pairs
        plt.subplot(2, 3, index)
        for values, kwargs in curves:
            plt.plot(values, **kwargs)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    panel(1, 'Main Training Losses', [
        (history['G'], dict(label='Generator Loss', linewidth=2)),
        (history['D_A'], dict(label='Discriminator A Loss', linewidth=2)),
        (history['D_B'], dict(label='Discriminator B Loss', linewidth=2)),
    ])
    panel(2, 'Adversarial Losses', [
        (history['gan_AB'], dict(label='GAN AB (Clear→Rain)', linewidth=2, color='blue')),
        (history['gan_BA'], dict(label='GAN BA (Rain→Clear)', linewidth=2, color='red')),
    ])
    panel(3, 'Cycle Consistency Losses', [
        (history['cycle_A'], dict(label='Cycle A (Clear→Rain→Clear)', linewidth=2, color='green')),
        (history['cycle_B'], dict(label='Cycle B (Rain→Clear→Rain)', linewidth=2, color='orange')),
    ])
    panel(4, 'Identity Loss', [
        (history['identity'], dict(label='Identity Loss', linewidth=2, color='purple')),
    ])
    # Legend multipliers come from the actual lambdas rather than literals.
    panel(5, 'Generator Loss Components', [
        (history['gan_AB'], dict(label='GAN AB', alpha=0.7)),
        (history['gan_BA'], dict(label='GAN BA', alpha=0.7)),
        (history['cycle_A'], dict(label=f'Cycle A (×{lambda_cycle:g})', alpha=0.7)),
        (history['cycle_B'], dict(label=f'Cycle B (×{lambda_cycle:g})', alpha=0.7)),
        (history['identity'], dict(label=f'Identity (×{lambda_identity:g})', alpha=0.7)),
    ])
    avg_discriminator = [(d_a + d_b) / 2 for d_a, d_b in zip(history['D_A'], history['D_B'])]
    panel(6, 'Generator vs Discriminator', [
        (history['G'], dict(label='Total Generator', linewidth=2, color='blue')),
        (avg_discriminator, dict(label='Average Discriminator', linewidth=2, color='red')),
    ])

    plt.suptitle('CycleGAN Training Losses - Comprehensive View', fontsize=16)
    plt.tight_layout()
    plt.savefig('comprehensive_training_losses.png', dpi=300, bbox_inches='tight')
    plt.show()


def _print_training_summary(history):
    """Print the final epoch's losses and each generator term's share of the total."""
    g_total = history['G'][-1]

    def share(name):
        value = history[name][-1]
        return f"{value:.4f} ({value/g_total*100:.1f}%)"

    print("\n" + "="*60)
    print("FINAL TRAINING SUMMARY")
    print("="*60)
    print(f"Total Generator Loss: {g_total:.4f}")
    print(f"├─ Identity Loss: {share('identity')}")
    print(f"├─ GAN AB Loss: {share('gan_AB')}")
    print(f"├─ GAN BA Loss: {share('gan_BA')}")
    print(f"├─ Cycle A Loss: {share('cycle_A')}")
    print(f"└─ Cycle B Loss: {share('cycle_B')}")
    print(f"Discriminator A Loss: {history['D_A'][-1]:.4f}")
    print(f"Discriminator B Loss: {history['D_B'][-1]:.4f}")
    print("="*60)

# Function to save sample images
def save_sample_images(G_AB, G_BA, test_loader, epoch, device):
    """Plot the first test batch's translations and save ``sample_epoch_{epoch}.png``.

    Top row: real A, G_AB(real A), real B, G_BA(real B).
    Bottom row: G_BA(G_AB(real A)) in columns 0-1 and G_AB(G_BA(real B)) in
    columns 2-3.

    Both generators run in eval mode without gradients. Their previous
    train/eval modes are restored afterwards, even if plotting raises.
    """
    previous_modes = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()

    def tensor_to_image(tensor):
        # Undo the (x - 0.5) / 0.5 normalization and convert to a PIL image.
        image = tensor.cpu().clone()
        image = image.squeeze(0)
        image = image * 0.5 + 0.5
        return transforms.ToPILImage()(image)

    try:
        with torch.no_grad():
            batch = next(iter(test_loader))
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            fake_B = G_AB(real_A)
            fake_A = G_BA(real_B)
            # Each reconstruction is computed once and reused for two panels.
            recovered_A = G_BA(fake_B)
            recovered_B = G_AB(fake_A)

        fig, axes = plt.subplots(2, 4, figsize=(16, 8))

        images = [real_A[0], fake_B[0], real_B[0], fake_A[0]]
        titles = ['Real Clear', 'Generated Rain', 'Real Rain', 'Generated Clear']

        for i in range(4):
            # Top row
            axes[0, i].imshow(tensor_to_image(images[i]))
            axes[0, i].set_title(titles[i])
            axes[0, i].axis('off')

            # Bottom row - cycle consistency
            if i < 2:
                axes[1, i].imshow(tensor_to_image(recovered_A[0]))
                axes[1, i].set_title('Recovered Clear')
            else:
                axes[1, i].imshow(tensor_to_image(recovered_B[0]))
                axes[1, i].set_title('Recovered Rain')
            axes[1, i].axis('off')

        plt.suptitle(f'CycleGAN Results - Epoch {epoch}')
        plt.tight_layout()
        plt.savefig(f'sample_epoch_{epoch}.png', dpi=150, bbox_inches='tight')
        plt.show()
        # Release the figure; with a non-interactive backend show() does not,
        # and figures would accumulate across epochs.
        plt.close(fig)
    finally:
        G_AB.train(previous_modes[0])
        G_BA.train(previous_modes[1])

# Inference function
def inference_and_visualization(G_AB_path='G_AB.pth', G_BA_path='G_BA.pth'):
    """Load trained generators and plot clear<->rain translations on test data.

    Up to 6 samples, one per row, are drawn from a single pass over a
    shuffled loader on ``derain/testA`` / ``derain/testB``, so rows do not
    repeat unless the dataset is smaller than 6. Each row shows, in order:
    real clear, generated rain, recovered clear, real rain, generated clear,
    and recovered rain. The grid is saved to ``inference_results.png``.

    Args:
        G_AB_path: path to the clear->rain generator's saved state_dict.
        G_BA_path: path to the rain->clear generator's saved state_dict.
    """
    num_examples = 6

    # Load models
    G_AB = Generator().to(device)
    G_BA = Generator().to(device)

    G_AB.load_state_dict(torch.load(G_AB_path, map_location=device))
    G_BA.load_state_dict(torch.load(G_BA_path, map_location=device))

    G_AB.eval()
    G_BA.eval()

    # Same preprocessing as training: resize, then map pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = ImageDataset('derain/testA', 'derain/testB', transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    def tensor_to_numpy(tensor):
        # CHW tensor in [-1, 1] -> HWC numpy array in (roughly) [0, 1].
        image = tensor.cpu().squeeze(0)
        image = image * 0.5 + 0.5
        return image.permute(1, 2, 0).numpy()

    titles = ['Real Clear', 'Generated Rain', 'Recovered Clear',
              'Real Rain', 'Generated Clear', 'Recovered Rain']

    with torch.no_grad():
        fig = plt.figure(figsize=(20, 12))

        # One pass over the loader: calling next(iter(loader)) per row would
        # restart the shuffle every time and could repeat samples.
        for i, batch in zip(range(num_examples), test_loader):
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            fake_B = G_AB(real_A)  # Clear to Rain
            fake_A = G_BA(real_B)  # Rain to Clear

            # Cycle consistency
            recovered_A = G_BA(fake_B)
            recovered_B = G_AB(fake_A)

            images = [
                tensor_to_numpy(real_A[0]),
                tensor_to_numpy(fake_B[0]),
                tensor_to_numpy(recovered_A[0]),
                tensor_to_numpy(real_B[0]),
                tensor_to_numpy(fake_A[0]),
                tensor_to_numpy(recovered_B[0])
            ]

            for j, image in enumerate(images):
                plt.subplot(num_examples, 6, i * 6 + j + 1)
                plt.imshow(np.clip(image, 0, 1))
                if i == 0:
                    plt.title(titles[j])
                plt.axis('off')

        plt.suptitle('CycleGAN Inference Results: Clear ↔ Rain Translation', fontsize=16)
        plt.tight_layout()
        plt.savefig('inference_results.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

# Main execution
if __name__ == "__main__":
    # Entry point: train both generators, then run inference, which reloads
    # the checkpoints from their default paths.
    print("Starting CycleGAN training for Clear/Rain translation...")

    # Unpacked at module scope so the trained generators remain available
    # for interactive use after the script/cell finishes.
    trained_generators = train_cyclegan()
    G_AB, G_BA = trained_generators

    print("\nTraining completed! Running inference...")
    inference_and_visualization()

    print("All done! Check the generated images and training plots.")