# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile, ImageDraw, ImageFont
import os
import glob
from tqdm import tqdm
import json
from datetime import datetime
import warnings
import traceback
import logging
from pathlib import Path
import random
import cv2
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import seaborn as sns
from scipy import ndimage
import pandas as pd

try:
    from skimage.metrics import structural_similarity as ssim
except ImportError:
    # scikit-image is missing: install it on the fly, then retry the import.
    # pip is invoked through the running interpreter so the package is
    # installed into the environment this code is executing in; a bare
    # "pip" found on PATH may belong to a different Python installation.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-image"])
    from skimage.metrics import structural_similarity as ssim

# Process-wide settings applied once at import time.
# Ask PIL to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')
# Render figures with the non-interactive Agg backend (figures are saved to
# disk by the code below, never shown).
plt.switch_backend('Agg')

# Setup logging: INFO level, "<timestamp> - <level> - <message>" records,
# and a module-level logger shared by the classes below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class EnhancedLoss(nn.Module):
    """Generator and discriminator objectives built from L1 and MSE terms.

    ``config`` must expose ``lambda_l1`` and ``lambda_adversarial``, the
    weights applied to the two generator terms.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

    def generator_loss(self, generated, target, disc_output):
        """Return the generator's weighted total loss and its two components.

        Args:
            generated: Generator output.
            target: Ground-truth image compared against ``generated``.
            disc_output: Discriminator prediction for the generated image.

        Returns:
            dict with ``'total'`` (weighted sum), ``'l1'`` (reconstruction
            term) and ``'adversarial'`` (MSE between ``disc_output`` and an
            all-ones tensor of the same shape).
        """
        wanted_verdict = torch.ones_like(disc_output)

        reconstruction = self.l1_loss(generated, target)
        adversarial = self.mse_loss(disc_output, wanted_verdict)

        weighted_l1 = self.config.lambda_l1 * reconstruction
        weighted_adv = self.config.lambda_adversarial * adversarial

        return dict(
            total=weighted_l1 + weighted_adv,
            l1=reconstruction,
            adversarial=adversarial,
        )

    def discriminator_loss(self, real_output, fake_output):
        """Return the mean of the real-vs-ones and fake-vs-zeros MSE terms."""
        loss_on_real = self.mse_loss(real_output, torch.ones_like(real_output))
        loss_on_fake = self.mse_loss(fake_output, torch.zeros_like(fake_output))
        return 0.5 * (loss_on_real + loss_on_fake)

class EnhancedTrainer:
    """Enhanced trainer with validation support"""
    
    def __init__(self, config):
        """Seed the RNGs and build networks, optimizers, loss, metrics and history.

        Reads from ``config``: ``random_seed``, ``input_nc``, ``output_nc``,
        ``ngf``, ``ndf``, ``device``, ``learning_rate_g``, ``learning_rate_d``,
        ``beta1`` and ``beta2``.
        """
        self.config = config

        # Seed torch, NumPy and the stdlib RNG (in that order) before any
        # network is constructed.
        for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
            seed_rng(config.random_seed)

        # Networks: the discriminator receives input and output channels together.
        generator = EnhancedGenerator(config.input_nc, config.output_nc, config.ngf)
        self.generator = generator.to(config.device)

        discriminator = EnhancedDiscriminator(config.input_nc + config.output_nc, config.ndf)
        self.discriminator = discriminator.to(config.device)

        # One Adam optimizer per network; they differ only in learning rate.
        adam_betas = (config.beta1, config.beta2)
        self.g_optimizer = torch.optim.Adam(
            self.generator.parameters(), lr=config.learning_rate_g, betas=adam_betas
        )
        self.d_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=config.learning_rate_d, betas=adam_betas
        )

        self.criterion = EnhancedLoss(config)
        self.metrics = EnhancedMetrics(config.device)

        # Per-epoch series; every key starts with its own empty list.
        tracked_series = (
            'train_g_loss', 'train_d_loss', 'train_psnr', 'train_mae',
            'val_g_loss', 'val_d_loss', 'val_psnr', 'val_mae', 'val_ssim',
        )
        self.history = {series: [] for series in tracked_series}

        self.best_val_psnr = 0
        self.best_epoch = 0

        logger.info("Enhanced trainer initialized successfully")
    
    def train_epoch(self, train_loader, epoch):
        """Run one optimisation pass over ``train_loader``.

        For each usable batch the discriminator is updated first (on the real
        pair and on generator output produced without gradients), then the
        generator is updated with its gradient norm clipped to 1.0. A batch
        that is ``None`` is skipped; a batch that raises is logged and skipped.

        Args:
            train_loader: Iterable of batches; each usable batch is indexed
                with ``'input'`` and ``'target'``.
            epoch: Epoch number, used only in the progress-bar label.

        Returns:
            dict with ``g_loss``, ``d_loss``, ``psnr`` and ``mae`` averaged
            over the batches that completed.

        Raises:
            RuntimeError: If no batch completed.
        """
        for network in (self.generator, self.discriminator):
            network.train()

        running = dict(g_loss=0, d_loss=0, psnr=0, mae=0)
        completed = 0

        progress = tqdm(train_loader, desc=f'Training Epoch {epoch}')

        for step, sample in enumerate(progress):
            if sample is None:
                continue

            try:
                source = sample['input'].to(self.config.device)
                reference = sample['target'].to(self.config.device)

                # ---- Discriminator step ----
                self.d_optimizer.zero_grad()

                with torch.no_grad():
                    frozen_fake = self.generator(source)

                verdict_real = self.discriminator(source, reference)
                verdict_fake = self.discriminator(source, frozen_fake.detach())

                d_loss = self.criterion.discriminator_loss(verdict_real, verdict_fake)
                d_loss.backward()
                self.d_optimizer.step()

                # ---- Generator step ----
                self.g_optimizer.zero_grad()

                synthesized = self.generator(source)
                verdict_fake = self.discriminator(source, synthesized)

                g_loss = self.criterion.generator_loss(synthesized, reference, verdict_fake)['total']

                g_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)
                self.g_optimizer.step()

                # ---- Batch metrics (no gradients needed) ----
                with torch.no_grad():
                    batch_psnr = self.metrics.calculate_psnr(synthesized, reference).item()
                    batch_mae = F.l1_loss(synthesized, reference).item()

                g_value = g_loss.item()
                d_value = d_loss.item()

                running['g_loss'] += g_value
                running['d_loss'] += d_value
                running['psnr'] += batch_psnr
                running['mae'] += batch_mae
                completed += 1

                progress.set_postfix({
                    'G_Loss': f'{g_value:.4f}',
                    'D_Loss': f'{d_value:.4f}',
                    'PSNR': f'{batch_psnr:.2f}'
                })

            except Exception as e:
                logger.warning(f"Error in batch {step}: {e}")
                continue

        if completed == 0:
            raise RuntimeError("No valid batches processed!")

        # Turn the running sums into per-batch averages.
        return {name: total / completed for name, total in running.items()}
    
    def validate_epoch(self, val_loader, epoch):
        """Evaluate both networks on ``val_loader`` without updating them.

        Args:
            val_loader: Iterable of batches; each usable batch is indexed
                with ``'input'`` and ``'target'``.
            epoch: Epoch number, used only in the progress-bar label.

        Returns:
            dict with ``g_loss``, ``d_loss``, ``psnr``, ``mae`` and ``ssim``
            averaged over the batches that completed, or ``None`` when no
            batch completed.
        """
        for network in (self.generator, self.discriminator):
            network.eval()

        running = dict(g_loss=0, d_loss=0, psnr=0, mae=0, ssim=0)
        completed = 0

        with torch.no_grad():
            progress = tqdm(val_loader, desc=f'Validation Epoch {epoch}')

            for step, sample in enumerate(progress):
                if sample is None:
                    continue

                try:
                    source = sample['input'].to(self.config.device)
                    reference = sample['target'].to(self.config.device)

                    synthesized = self.generator(source)

                    # Discriminator objective on the real and generated pairs
                    verdict_real = self.discriminator(source, reference)
                    verdict_fake = self.discriminator(source, synthesized)
                    d_loss = self.criterion.discriminator_loss(verdict_real, verdict_fake)

                    # Generator objective (only the weighted total is tracked)
                    g_loss = self.criterion.generator_loss(synthesized, reference, verdict_fake)['total']

                    quality = self.metrics.evaluate_comprehensive(synthesized, reference)

                    running['g_loss'] += g_loss.item()
                    running['d_loss'] += d_loss.item()
                    for metric_name in ('psnr', 'mae', 'ssim'):
                        running[metric_name] += quality[metric_name]
                    completed += 1

                    progress.set_postfix({
                        'Val_PSNR': f'{quality["psnr"]:.2f}',
                        'Val_SSIM': f'{quality["ssim"]:.3f}'
                    })

                except Exception as e:
                    logger.warning(f"Error in validation batch {step}: {e}")
                    continue

        if completed == 0:
            logger.warning("No valid validation batches processed!")
            return None

        # Turn the running sums into per-batch averages.
        return {name: total / completed for name, total in running.items()}
    
    def update_history(self, train_results, val_results):
        """Update training history"""
        # Training metrics
        self.history['train_g_loss'].append(train_results['g_loss'])
        self.history['train_d_loss'].append(train_results['d_loss'])
        self.history['train_psnr'].append(train_results['psnr'])
        self.history['train_mae'].append(train_results['mae'])
        
        # Validation metrics
        if val_results:
            self.history['val_g_loss'].append(val_results['g_loss'])
            self.history['val_d_loss'].append(val_results['d_loss'])
            self.history['val_psnr'].append(val_results['psnr'])
            self.history['val_mae'].append(val_results['mae'])
            self.history['val_ssim'].append(val_results['ssim'])
        else:
            # Append None or last value if validation failed
            for key in ['val_g_loss', 'val_d_loss', 'val_psnr', 'val_mae', 'val_ssim']:
                if self.history[key]:
                    self.history[key].append(self.history[key][-1])
                else:
                    self.history[key].append(0)
    
    def save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        try:
            checkpoint = {
                'epoch': epoch,
                'generator_state_dict': self.generator.state_dict(),
                'discriminator_state_dict': self.discriminator.state_dict(),
                'g_optimizer_state_dict': self.g_optimizer.state_dict(),
                'd_optimizer_state_dict': self.d_optimizer.state_dict(),
                'history': self.history,
                'best_val_psnr': self.best_val_psnr,
                'best_epoch': self.best_epoch
            }
            
            # Save checkpoint
            checkpoint_path = os.path.join(self.config.experiment_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pth')
            torch.save(checkpoint, checkpoint_path)
            
            if is_best:
                best_path = os.path.join(self.config.experiment_dir, 'checkpoints', 'best_model.pth')
                torch.save(checkpoint, best_path)
                logger.info(f"💾 Best model saved at epoch {epoch} (PSNR: {self.best_val_psnr:.2f})")
                
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
    
    def generate_samples(self, dataloader, epoch, num_samples=8, split_name='train'):
        """Generate sample images"""
        self.generator.eval()
        
        try:
            with torch.no_grad():
                batch = next(iter(dataloader))
                if batch is None:
                    return
                
                input_img = batch['input'][:num_samples].to(self.config.device)
                target_img = batch['target'][:num_samples].to(self.config.device)
                
                generated_img = self.generator(input_img)
                
                # Save comparison images
                for i in range(min(num_samples, input_img.size(0))):
                    # Convert to [0, 1] range
                    input_np = (input_img[i].cpu() + 1) / 2
                    generated_np = (generated_img[i].cpu() + 1) / 2
                    target_np = (target_img[i].cpu() + 1) / 2
                    
                    # Create side-by-side comparison
                    comparison = torch.cat([input_np, generated_np, target_np], dim=2)
                    
                    # Save image
                    save_path = os.path.join(
                        self.config.experiment_dir, 
                        'generated_samples', 
                        f'{split_name}_epoch_{epoch}_sample_{i}.png'
                    )
                    transforms.ToPILImage()(comparison).save(save_path)
                
                logger.info(f"📸 {split_name.capitalize()} sample images saved for epoch {epoch}")
                
        except Exception as e:
            logger.warning(f"Error generating {split_name} samples: {e}")
        
        self.generator.train()
    
    def plot_training_progress(self):
        """Plot the recorded training/validation curves and save them as a PNG.

        Draws a 2x3 grid: generator loss, discriminator loss, PSNR and MAE
        (train and validation on each), validation SSIM, and a text panel
        with best/final values. The figure is written to
        ``<experiment_dir>/training_progress/training_curves.png``.

        Failures are logged, not raised. The figure is closed whether or not
        plotting succeeds, so a failed call does not leave a figure open.
        """
        fig = None
        try:
            fig, axes = plt.subplots(2, 3, figsize=(18, 10))
            epochs = range(1, len(self.history['train_g_loss']) + 1)

            def finish_panel(ax, title, ylabel):
                # Decoration shared by all curve panels.
                ax.set_title(title)
                ax.set_xlabel('Epoch')
                ax.set_ylabel(ylabel)
                ax.legend()
                ax.grid(True, alpha=0.3)

            # (axes, history suffix, legend name, panel title, y-axis label)
            paired_panels = (
                (axes[0, 0], 'g_loss', 'Generator', 'Generator Loss', 'Loss'),
                (axes[0, 1], 'd_loss', 'Discriminator', 'Discriminator Loss', 'Loss'),
                (axes[0, 2], 'psnr', 'PSNR', 'PSNR Over Time', 'PSNR (dB)'),
                (axes[1, 0], 'mae', 'MAE', 'MAE Over Time', 'MAE'),
            )
            for ax, suffix, legend_name, title, ylabel in paired_panels:
                ax.plot(epochs, self.history[f'train_{suffix}'], 'b-',
                        label=f'Train {legend_name}', alpha=0.7)
                ax.plot(epochs, self.history[f'val_{suffix}'], 'r-',
                        label=f'Val {legend_name}', alpha=0.7)
                finish_panel(ax, title, ylabel)

            # SSIM is only recorded for validation
            axes[1, 1].plot(epochs, self.history['val_ssim'], 'g-', label='Val SSIM', alpha=0.7)
            finish_panel(axes[1, 1], 'SSIM Over Time', 'SSIM')

            # Summary statistics
            summary_ax = axes[1, 2]
            if epochs:
                final_train_psnr = self.history['train_psnr'][-1] if self.history['train_psnr'] else 0
                final_val_psnr = self.history['val_psnr'][-1] if self.history['val_psnr'] else 0
                final_val_ssim = self.history['val_ssim'][-1] if self.history['val_ssim'] else 0

                summary_ax.text(0.1, 0.8, f'Best Val PSNR: {self.best_val_psnr:.2f} dB (Epoch {self.best_epoch})',
                                transform=summary_ax.transAxes, fontsize=12, weight='bold')
                summary_lines = (
                    (0.7, f'Final Train PSNR: {final_train_psnr:.2f} dB'),
                    (0.6, f'Final Val PSNR: {final_val_psnr:.2f} dB'),
                    (0.5, f'Final Val SSIM: {final_val_ssim:.3f}'),
                    (0.4, f'Total Epochs: {len(epochs)}'),
                )
                for y_pos, line in summary_lines:
                    summary_ax.text(0.1, y_pos, line, transform=summary_ax.transAxes, fontsize=10)

            summary_ax.set_title('Training Summary')
            summary_ax.axis('off')

            fig.tight_layout()

            # Make sure the destination exists rather than failing on save.
            progress_dir = os.path.join(self.config.experiment_dir, 'training_progress')
            os.makedirs(progress_dir, exist_ok=True)
            progress_path = os.path.join(progress_dir, 'training_curves.png')
            fig.savefig(progress_path, dpi=300, bbox_inches='tight')

            logger.info(" Training progress plot saved")

        except Exception as e:
            logger.error(f"Error plotting progress: {e}")
        finally:
            # Close this specific figure even when plotting or saving failed;
            # otherwise each failed call would leave one more figure open.
            if fig is not None:
                plt.close(fig)

class TestEvaluator:
    """Comprehensive test evaluation with detailed analysis"""
    
    def __init__(self, config, model_path):
        """Load the generator from ``model_path`` and set up empty result storage.

        Args:
            config: Settings object; ``config.device`` is read here.
            model_path: Checkpoint path handed to ``_load_model``.
        """
        self.config, self.model_path = config, model_path
        self.device = config.device

        # The generator must be loaded after config/model_path/device are set,
        # because _load_model reads all three.
        self.generator = self._load_model()
        self.metrics_calculator = EnhancedMetrics(self.device)

        # Per-image results accumulated during evaluation
        self.test_results = {
            bucket: [] for bucket in ('metrics', 'filenames', 'individual_results')
        }

        logger.info("Test evaluator initialized")
    
    def _load_model(self):
        """Build a generator, load its weights from ``self.model_path``, return it in eval mode.

        The checkpoint is expected to contain a ``'generator_state_dict'`` entry.

        Raises:
            Exception: Any loading failure is logged and then re-raised.
        """
        try:
            saved_state = torch.load(self.model_path, map_location=self.device)

            cfg = self.config
            network = EnhancedGenerator(cfg.input_nc, cfg.output_nc, cfg.ngf)
            network = network.to(self.device)

            network.load_state_dict(saved_state['generator_state_dict'])
            network.eval()

            logger.info(f" Model loaded from {self.model_path}")
            return network
        except Exception as e:
            logger.error(f" Failed to load model: {e}")
            raise
    
    def evaluate_test_dataset(self, test_loader):
        """Run the full test evaluation and return the aggregate metrics.

        Steps, in order: create the output folders under
        ``<experiment_dir>/test_results``, process every test batch, aggregate
        the per-image metrics, generate the visualizations, and save the
        detailed results.

        Args:
            test_loader: Iterable of test batches, passed to ``_process_test_dataset``.

        Returns:
            The value returned by ``_calculate_aggregate_metrics``.
        """
        banner = "🧪 " + "=" * 60
        print(banner, "🧪 COMPREHENSIVE TEST EVALUATION", banner, sep="\n")

        results_root = os.path.join(self.config.experiment_dir, 'test_results')

        self._create_output_directories(results_root)
        self._process_test_dataset(test_loader, results_root)

        summary = self._calculate_aggregate_metrics()

        self._generate_comprehensive_visualizations(results_root, summary)
        self._save_comprehensive_results(results_root, summary)

        print("\n🎉 Test evaluation completed!")
        print(f"📁 Results saved to: {results_root}")

        return summary
    
    def _create_output_directories(self, base_dir):
        """Create output directories for test results"""
        subdirs = [
            'generated_images', 'comparisons', 'error_maps',
            'metrics_analysis', 'quality_assessment', 'detailed_reports'
        ]
        
        for subdir in subdirs:
            os.makedirs(os.path.join(base_dir, subdir), exist_ok=True)
    
    def _process_test_dataset(self, test_loader, output_dir):
        """Run the generator on every test batch and process each image individually.

        ``None`` batches are skipped; a batch that raises is logged and
        skipped. Each image is handed to ``_process_single_test_image`` with
        the index ``batch_idx * config.batch_size + position_in_batch``.

        Args:
            test_loader: Iterable of batches indexed with ``'input'``,
                ``'target'`` and ``'filename'``.
            output_dir: Root folder that receives the per-image outputs.
        """
        print("\n📊 Processing test dataset...")

        with torch.no_grad():
            batches = tqdm(test_loader, desc="Processing test batches")

            for batch_idx, batch in enumerate(batches):
                if batch is None:
                    continue

                try:
                    sources = batch['input'].to(self.config.device)
                    references = batch['target'].to(self.config.device)
                    names = batch['filename']

                    outputs = self.generator(sources)

                    # Slice with i:i+1 so each image keeps its batch dimension.
                    for position in range(sources.size(0)):
                        single = slice(position, position + 1)
                        self._process_single_test_image(
                            sources[single], outputs[single], references[single],
                            names[position], output_dir,
                            batch_idx * self.config.batch_size + position
                        )

                except Exception as e:
                    logger.warning(f" Error processing batch {batch_idx}: {e}")
                    continue
    
    def _process_single_test_image(self, input_img, generated_img, target_img,
                                 filename, output_dir, img_idx):
        """Score one test image, record the result and write its visual outputs.

        Args:
            input_img: Input image, batch dimension kept.
            generated_img: Generator output for ``input_img``.
            target_img: Ground-truth image.
            filename: Name used for the stored record and the output files.
            output_dir: Root folder for the visual outputs.
            img_idx: Index stored with the record.
        """
        scores = self.metrics_calculator.evaluate_comprehensive(generated_img, target_img)

        # Record the result in all three result buckets
        results = self.test_results
        results['metrics'].append(scores)
        results['filenames'].append(filename)
        results['individual_results'].append(
            dict(filename=filename, index=img_idx, metrics=scores)
        )

        # Comparison image first, then the error analysis figure
        self._save_visual_results(input_img, generated_img, target_img, filename, output_dir, scores)
        self._create_detailed_error_map(generated_img, target_img, filename, output_dir)
    
    def _save_visual_results(self, input_img, generated_img, target_img, filename, output_dir, metrics):
        """Save the generated image and an annotated input/generated/target comparison.

        Writes ``<output_dir>/generated_images/<filename>_generated.png`` and
        ``<output_dir>/comparisons/<filename>_detailed_comparison.png``. The
        comparison has an 80-pixel white header holding the panel labels,
        the metric values and a colour-coded quality rating.

        Args:
            input_img, generated_img, target_img: Image tensors with a leading
                batch dimension, which is squeezed away before conversion.
            filename: Base name for the output files.
            output_dir: Root folder containing ``generated_images`` and ``comparisons``.
            metrics: Mapping with ``'psnr'``, ``'ssim'`` and ``'mae'`` values.
        """
        header_height = 80  # Extra space for text above the images

        # Convert tensors to PIL images
        input_pil = self._tensor_to_pil(input_img.squeeze(0))
        generated_pil = self._tensor_to_pil(generated_img.squeeze(0))
        target_pil = self._tensor_to_pil(target_img.squeeze(0))

        # Save individual generated image
        generated_path = os.path.join(output_dir, 'generated_images', f'{filename}_generated.png')
        generated_pil.save(generated_path)

        # Create detailed comparison with metrics
        comparison_width = input_pil.width * 3
        comparison_height = input_pil.height + header_height
        comparison = Image.new('RGB', (comparison_width, comparison_height), (255, 255, 255))

        # Paste images
        comparison.paste(input_pil, (0, header_height))
        comparison.paste(generated_pil, (input_pil.width, header_height))
        comparison.paste(target_pil, (input_pil.width * 2, header_height))

        # Add labels and metrics
        draw = ImageDraw.Draw(comparison)
        try:
            font = ImageFont.truetype("arial.ttf", 16)
            small_font = ImageFont.truetype("arial.ttf", 12)
        except (OSError, ImportError):
            # Arial is not available (or FreeType support is missing): fall
            # back to PIL's built-in font. Only these errors are caught so
            # that KeyboardInterrupt/SystemExit are no longer swallowed.
            font = ImageFont.load_default()
            small_font = font

        # Labels
        draw.text((10, 10), "Input", fill=(0, 0, 0), font=font)
        draw.text((input_pil.width + 10, 10), "Generated", fill=(0, 0, 0), font=font)
        draw.text((input_pil.width * 2 + 10, 10), "Target", fill=(0, 0, 0), font=font)

        # Metrics
        metrics_text = f"PSNR: {metrics['psnr']:.2f}dB | SSIM: {metrics['ssim']:.3f} | MAE: {metrics['mae']:.4f}"
        draw.text((10, 40), metrics_text, fill=(0, 0, 0), font=small_font)

        # Quality assessment
        quality = self._assess_quality(metrics['psnr'])
        quality_color = self._get_quality_color(quality)
        draw.text((10, 60), f"Quality: {quality}", fill=quality_color, font=small_font)

        # Save comparison
        comparison_path = os.path.join(output_dir, 'comparisons', f'{filename}_detailed_comparison.png')
        comparison.save(comparison_path)
    
    def _tensor_to_pil(self, tensor):
        """Map a [-1, 1] image tensor to [0, 1], clamp it and return a PIL image."""
        unit_range = ((tensor + 1) / 2).clamp(0, 1)
        to_pil = transforms.ToPILImage()
        return to_pil(unit_range)
    
    def _assess_quality(self, psnr):
        """Assess image quality based on PSNR"""
        if psnr >= 35:
            return "Excellent"
        elif psnr >= 30:
            return "Very Good"
        elif psnr >= 25:
            return "Good"
        elif psnr >= 20:
            return "Fair"
        else:
            return "Poor"
    
    def _get_quality_color(self, quality):
        """Get color for quality assessment"""
        colors = {
            "Excellent": (0, 128, 0),
            "Very Good": (50, 205, 50),
            "Good": (255, 165, 0),
            "Fair": (255, 140, 0),
            "Poor": (255, 0, 0)
        }
        return colors.get(quality, (0, 0, 0))
    
    def _create_detailed_error_map(self, generated, target, filename, output_dir):
        """Save a four-panel error analysis figure for one image.

        The absolute difference between ``generated`` and ``target`` is
        averaged over dimension 1 and shown as a heatmap, a histogram, a mask
        of values above the 90th percentile, and a text panel of summary
        statistics. The figure is written to
        ``<output_dir>/error_maps/<filename>_error_analysis.png``.
        """
        # Mean absolute error over dimension 1, as a NumPy array
        abs_diff = (generated - target).abs()
        per_pixel = abs_diff.mean(dim=1, keepdim=True)
        error_np = per_pixel.squeeze().detach().cpu().numpy()

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        (heat_ax, hist_ax), (mask_ax, stats_ax) = axes

        # Panel 1: heatmap
        heat_image = heat_ax.imshow(error_np, cmap='hot', interpolation='nearest')
        heat_ax.set_title('Pixel-wise Error (MAE)')
        heat_ax.axis('off')
        plt.colorbar(heat_image, ax=heat_ax)

        # Panel 2: distribution of the error values
        hist_ax.hist(error_np.flatten(), bins=50, alpha=0.7, color='red')
        hist_ax.set_title('Error Distribution')
        hist_ax.set_xlabel('Error Value')
        hist_ax.set_ylabel('Frequency')
        hist_ax.grid(True, alpha=0.3)

        # Panel 3: mask of values above the 90th percentile
        cutoff = np.percentile(error_np, 90)
        mask_ax.imshow(error_np > cutoff, cmap='Reds', interpolation='nearest')
        mask_ax.set_title(f'High Error Regions (>{cutoff:.3f})')
        mask_ax.axis('off')

        # Panel 4: summary statistics as text
        summary = (
            ('Mean Error', np.mean(error_np)),
            ('Std Error', np.std(error_np)),
            ('Max Error', np.max(error_np)),
            ('Min Error', np.min(error_np)),
            ('95th Percentile', np.percentile(error_np, 95)),
        )
        summary_text = '\n'.join(f'{label}: {value:.4f}' for label, value in summary)
        stats_ax.text(0.1, 0.5, summary_text, transform=stats_ax.transAxes,
                      fontsize=12, verticalalignment='center')
        stats_ax.set_title('Error Statistics')
        stats_ax.axis('off')

        plt.suptitle(f'Error Analysis: {filename}', fontsize=14)
        plt.tight_layout()

        destination = os.path.join(output_dir, 'error_maps', f'{filename}_error_analysis.png')
        plt.savefig(destination, dpi=200, bbox_inches='tight')
        plt.close()
    
    def _calculate_aggregate_metrics(self):
        """Calculate comprehensive aggregate metrics"""
        if not self.test_results['metrics']:
            return {}
        
        # Collect all metric values
        all_metrics = {}
        for metric_dict in self.test_results['metrics']:
            for key, value in metric_dict.items():
                if key not in all_metrics:
                    all_metrics[key] = []
                all_metrics[key].append(value)
        
        # Calculate comprehensive statistics
        aggregate = {}
        for metric_name, values in all_metrics.items():
            if values:
                aggregate[metric_name] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values),
                    'median': np.median(values),
                    'q25': np.percentile(values, 25),
                    'q75': np.percentile(values, 75),
                    'values': values
                }
        
        return aggregate
    
    def _generate_comprehensive_visualizations(self, output_dir, aggregate_metrics):
        """Generate comprehensive test visualizations.

        Runs the five visualization steps in a fixed order, passing
        ``output_dir`` and ``aggregate_metrics`` to each.
        """
        print("\n Generating comprehensive visualizations...")

        # Each step is looked up only when its turn comes.
        step_names = (
            '_create_metrics_dashboard',     # 1. Metrics dashboard
            '_create_quality_distribution',  # 2. Quality distribution analysis
            '_create_correlation_analysis',  # 3. Performance correlation analysis
            '_create_examples_showcase',     # 4. Best and worst examples
            '_create_statistical_analysis',  # 5. Detailed statistical analysis
        )
        for step_name in step_names:
            getattr(self, step_name)(output_dir, aggregate_metrics)
    
    def _create_metrics_dashboard(self, output_dir, aggregate_metrics):
        """Render the test dashboard and save it as ``test_evaluation_dashboard.png``.

        Layout on a 3x4 grid: summary table and PSNR histogram on the top
        row, PSNR-vs-SSIM scatter and quality pie chart in the middle, box
        plots across the bottom. The PSNR panels are drawn only when
        ``'psnr'`` is present; the scatter also needs ``'ssim'``.
        """
        dashboard = plt.figure(figsize=(20, 12))
        grid = dashboard.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

        # Top-left: summary statistics table
        table_ax = dashboard.add_subplot(grid[0, :2])
        self._plot_summary_table(table_ax, aggregate_metrics)

        has_psnr = 'psnr' in aggregate_metrics

        # Top-right: PSNR histogram with mean/median markers
        if has_psnr:
            psnr_stats = aggregate_metrics['psnr']
            hist_ax = dashboard.add_subplot(grid[0, 2:])
            hist_ax.hist(psnr_stats['values'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            for stat_key, line_color, caption in (('mean', 'red', 'Mean'), ('median', 'green', 'Median')):
                hist_ax.axvline(psnr_stats[stat_key], color=line_color, linestyle='--',
                                label=f"{caption}: {psnr_stats[stat_key]:.2f}")
            hist_ax.set_title('PSNR Distribution')
            hist_ax.set_xlabel('PSNR (dB)')
            hist_ax.set_ylabel('Frequency')
            hist_ax.legend()
            hist_ax.grid(True, alpha=0.3)

        # Middle-left: SSIM against PSNR, coloured by PSNR
        if has_psnr and 'ssim' in aggregate_metrics:
            scatter_ax = dashboard.add_subplot(grid[1, :2])
            x_psnr = aggregate_metrics['psnr']['values']
            y_ssim = aggregate_metrics['ssim']['values']
            points = scatter_ax.scatter(x_psnr, y_ssim, alpha=0.6, c=x_psnr, cmap='viridis')
            scatter_ax.set_xlabel('PSNR (dB)')
            scatter_ax.set_ylabel('SSIM')
            scatter_ax.set_title('PSNR vs SSIM Correlation')
            scatter_ax.grid(True, alpha=0.3)
            plt.colorbar(points, ax=scatter_ax, label='PSNR (dB)')

        # Middle-right: quality pie chart
        if has_psnr:
            pie_ax = dashboard.add_subplot(grid[1, 2:])
            self._plot_quality_pie_chart(pie_ax, aggregate_metrics['psnr']['values'])

        # Bottom row: box plots for all metrics
        box_ax = dashboard.add_subplot(grid[2, :])
        self._plot_metrics_boxplots(box_ax, aggregate_metrics)

        plt.suptitle('Test Evaluation Dashboard', fontsize=20, y=0.98)
        dashboard_path = os.path.join(output_dir, 'test_evaluation_dashboard.png')
        plt.savefig(dashboard_path, dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_summary_table(self, ax, aggregate_metrics):
        """Render a table of per-metric summary statistics onto ``ax``.

        One row is produced for every entry of ``aggregate_metrics`` whose
        value is a dict containing a ``'mean'`` key; the row shows the
        upper-cased metric name followed by mean, std, median, min and max,
        each formatted to four decimals. When no entry qualifies only the
        title is drawn. The axis frame is always switched off.

        Args:
            ax: Axes-like object the table and title are drawn on.
            aggregate_metrics: Mapping of metric name to a statistics dict.
        """
        column_names = ['Metric', 'Mean', 'Std', 'Median', 'Min', 'Max']
        stat_keys = ('mean', 'std', 'median', 'min', 'max')

        rows = [
            [name.upper()] + [f"{stats[key]:.4f}" for key in stat_keys]
            for name, stats in aggregate_metrics.items()
            if isinstance(stats, dict) and 'mean' in stats
        ]

        if rows:
            summary = ax.table(cellText=rows, colLabels=column_names,
                               cellLoc='center', loc='center')
            summary.auto_set_font_size(False)
            summary.set_fontsize(9)
            summary.scale(1.2, 1.5)

            # Row 0 is the header row: green background, bold white text.
            for col, _ in enumerate(column_names):
                header_cell = summary[(0, col)]
                header_cell.set_facecolor('#4CAF50')
                header_cell.set_text_props(weight='bold', color='white')

        ax.set_title('Summary Statistics', fontsize=14, weight='bold')
        ax.axis('off')
    
    def _plot_quality_pie_chart(self, ax, psnr_values):
        """Draw a pie chart of how many images fall into each PSNR quality tier.

        Tiers are Excellent (>=35 dB), Very Good (30-35), Good (25-30),
        Fair (20-25) and Poor (<20). Tiers with no images are omitted; when
        every tier is empty only the title is drawn.

        Args:
            ax: Axes-like object to draw on.
            psnr_values: Iterable of per-image PSNR values in dB.
        """
        tiers = [
            ('Excellent\n(≥35dB)', 'darkgreen'),
            ('Very Good\n(30-35dB)', 'green'),
            ('Good\n(25-30dB)', 'orange'),
            ('Fair\n(20-25dB)', 'gold'),
            ('Poor\n(<20dB)', 'red'),
        ]

        # Single pass over the values; a NaN matches no branch and is not counted.
        tally = [0] * len(tiers)
        for value in psnr_values:
            if value >= 35:
                tally[0] += 1
            elif value >= 30:
                tally[1] += 1
            elif value >= 25:
                tally[2] += 1
            elif value >= 20:
                tally[3] += 1
            elif value < 20:
                tally[4] += 1

        # Keep only the tiers that actually contain images.
        populated = [
            (count, label, color)
            for count, (label, color) in zip(tally, tiers)
            if count > 0
        ]
        if populated:
            counts, names, palette = zip(*populated)

            _, _, pct_texts = ax.pie(counts, labels=names, colors=palette,
                                     autopct='%1.1f%%', startangle=90)

            # Make the percentage labels readable on the coloured wedges.
            for pct_text in pct_texts:
                pct_text.set_color('white')
                pct_text.set_weight('bold')

        ax.set_title('Quality Distribution by PSNR', fontsize=12, weight='bold')
    
    def _plot_metrics_boxplots(self, ax, aggregate_metrics):
        """Draw one colour-coded box plot per available metric on ``ax``.

        The metrics considered, in order, are psnr, ssim, mae, mse, lpips and
        edge_similarity. A metric is plotted only when ``aggregate_metrics``
        holds a dict for it that contains a ``'values'`` sequence. All boxes
        share one y-axis, so metrics on very different scales will compress
        each other.

        Args:
            ax: Matplotlib axes to draw on.
            aggregate_metrics: Mapping of metric name to a statistics dict.
        """
        metrics_to_plot = ['psnr', 'ssim', 'mae', 'mse', 'lpips', 'edge_similarity']
        available_metrics = [
            m for m in metrics_to_plot
            if isinstance(aggregate_metrics.get(m), dict)
            and 'values' in aggregate_metrics[m]
        ]

        if available_metrics:
            data_to_plot = [aggregate_metrics[m]['values'] for m in available_metrics]
            tick_names = [m.upper() for m in available_metrics]

            bp = ax.boxplot(data_to_plot, patch_artist=True)

            # Label the boxes explicitly instead of via boxplot(labels=...):
            # that keyword was renamed to tick_labels in Matplotlib 3.9 and the
            # old spelling is deprecated. Boxes sit at positions 1..N by default.
            ax.set_xticks(range(1, len(tick_names) + 1))
            ax.set_xticklabels(tick_names)

            # Give every box its own pastel colour from the Set3 colormap.
            colors = plt.cm.Set3(np.linspace(0, 1, len(bp['boxes'])))
            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)

        ax.set_title('Metrics Distribution (Box Plots)', fontsize=12, weight='bold')
        ax.grid(True, alpha=0.3)
    
    def _create_quality_distribution(self, output_dir, aggregate_metrics):
        """Save a 2x2 figure analysing the distribution of PSNR values.

        Panels: histogram with shaded quality zones, empirical cumulative
        distribution, image count per quality tier, and selected percentiles.
        The figure is written to
        ``<output_dir>/quality_assessment/quality_distribution_analysis.png``
        (the sub-directory is created if missing).

        Nothing is produced when ``aggregate_metrics`` has no ``'psnr'`` entry
        or when its ``'values'`` sequence is empty.

        Args:
            output_dir: Root directory of the evaluation outputs.
            aggregate_metrics: Mapping of metric name to a statistics dict;
                only ``aggregate_metrics['psnr']['values']`` is read.
        """
        if 'psnr' not in aggregate_metrics:
            return

        psnr_values = aggregate_metrics['psnr']['values']
        if len(psnr_values) == 0:
            # min()/max()/percentile below are undefined without data.
            return

        lowest, highest = min(psnr_values), max(psnr_values)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # PSNR histogram with quality zones
        hist_ax = axes[0, 0]
        hist_ax.hist(psnr_values, bins=30, alpha=0.7, color='skyblue', edgecolor='black')

        # The open-ended zones are drawn only when the data reaches into them;
        # otherwise axvspan(35, max) / axvspan(min, 20) would be a reversed span
        # painting e.g. "Excellent" over a range that is not excellent.
        if highest > 35:
            hist_ax.axvspan(35, highest, alpha=0.2, color='green', label='Excellent')
        hist_ax.axvspan(30, 35, alpha=0.2, color='lightgreen', label='Very Good')
        hist_ax.axvspan(25, 30, alpha=0.2, color='orange', label='Good')
        hist_ax.axvspan(20, 25, alpha=0.2, color='gold', label='Fair')
        if lowest < 20:
            hist_ax.axvspan(lowest, 20, alpha=0.2, color='red', label='Poor')

        hist_ax.set_title('PSNR Distribution with Quality Zones')
        hist_ax.set_xlabel('PSNR (dB)')
        hist_ax.set_ylabel('Frequency')
        hist_ax.legend()
        hist_ax.grid(True, alpha=0.3)

        # Cumulative distribution: fraction of images at or below each PSNR.
        cdf_ax = axes[0, 1]
        sorted_psnr = np.sort(psnr_values)
        y_vals = np.arange(1, len(sorted_psnr) + 1) / len(sorted_psnr)
        cdf_ax.plot(sorted_psnr, y_vals, linewidth=2)
        cdf_ax.set_title('Cumulative PSNR Distribution')
        cdf_ax.set_xlabel('PSNR (dB)')
        cdf_ax.set_ylabel('Cumulative Probability')
        cdf_ax.grid(True, alpha=0.3)

        # Quality summary bar chart
        quality_counts = {
            'Excellent': sum(1 for x in psnr_values if x >= 35),
            'Very Good': sum(1 for x in psnr_values if 30 <= x < 35),
            'Good': sum(1 for x in psnr_values if 25 <= x < 30),
            'Fair': sum(1 for x in psnr_values if 20 <= x < 25),
            'Poor': sum(1 for x in psnr_values if x < 20)
        }

        qualities = list(quality_counts)
        counts = list(quality_counts.values())
        colors = ['darkgreen', 'green', 'orange', 'gold', 'red']

        bar_ax = axes[1, 0]
        bars = bar_ax.bar(qualities, counts, color=colors, alpha=0.7)
        bar_ax.set_title('Quality Distribution Count')
        bar_ax.set_ylabel('Number of Images')

        # Add count labels on top of the non-empty bars
        for bar, count in zip(bars, counts):
            if count > 0:
                bar_ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1,
                            str(count), ha='center', va='bottom', fontweight='bold')

        # Performance percentiles
        pct_ax = axes[1, 1]
        percentiles = [10, 25, 50, 75, 90, 95, 99]
        psnr_percentiles = [np.percentile(psnr_values, p) for p in percentiles]

        pct_ax.plot(percentiles, psnr_percentiles, 'o-', linewidth=2, markersize=8)
        pct_ax.set_title('PSNR Percentiles')
        pct_ax.set_xlabel('Percentile')
        pct_ax.set_ylabel('PSNR (dB)')
        pct_ax.grid(True, alpha=0.3)

        # Annotate each marker with its PSNR value
        for p, psnr in zip(percentiles, psnr_percentiles):
            pct_ax.annotate(f'{psnr:.1f}', (p, psnr), textcoords="offset points",
                            xytext=(0, 10), ha='center')

        plt.tight_layout()
        target_dir = os.path.join(output_dir, 'quality_assessment')
        os.makedirs(target_dir, exist_ok=True)
        plt.savefig(os.path.join(target_dir, 'quality_distribution_analysis.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_correlation_analysis(self, output_dir, aggregate_metrics):
        """Save a figure showing how the evaluation metrics correlate.

        Left panel: heatmap of the pairwise correlation matrix
        (``DataFrame.corr``) of every metric that provides a ``'values'``
        list. Right panel: PSNR-vs-SSIM scatter with a linear trend line and
        the correlation coefficient; it is hidden when either metric is
        missing. The figure is written to
        ``<output_dir>/metrics_analysis/correlation_analysis.png``.

        Nothing is produced when fewer than two metrics provide values.

        Args:
            output_dir: Root directory of the evaluation outputs.
            aggregate_metrics: Mapping of metric name to a statistics dict.
        """
        # Collect the per-image value lists of every metric that has them.
        metrics_data = {
            name: data['values']
            for name, data in aggregate_metrics.items()
            if isinstance(data, dict) and 'values' in data
        }

        if len(metrics_data) < 2:
            return

        # Create correlation matrix
        df = pd.DataFrame(metrics_data)
        correlation_matrix = df.corr()

        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        heatmap_ax, scatter_ax = axes

        # Correlation heatmap
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
                    square=True, linewidths=0.5, ax=heatmap_ax)
        heatmap_ax.set_title('Metrics Correlation Matrix', fontsize=14, weight='bold')

        # Pairwise scatter plot for the two key metrics
        if 'psnr' in metrics_data and 'ssim' in metrics_data:
            psnr_scores = np.asarray(metrics_data['psnr'], dtype=float)
            ssim_scores = np.asarray(metrics_data['ssim'], dtype=float)

            scatter_ax.scatter(psnr_scores, ssim_scores, alpha=0.6, s=50)

            # A line fit and a correlation coefficient need at least two
            # samples with distinct PSNR values.
            fit_possible = len(psnr_scores) >= 2 and psnr_scores.max() > psnr_scores.min()
            if fit_possible:
                slope, intercept = np.polyfit(psnr_scores, ssim_scores, 1)
                # Draw the line between its two end points only: plotting it
                # through the unsorted samples retraces the same line back and
                # forth and smears the dashed style.
                x_line = np.array([psnr_scores.min(), psnr_scores.max()])
                scatter_ax.plot(x_line, slope * x_line + intercept, "r--", alpha=0.8)
                corr_coef = np.corrcoef(psnr_scores, ssim_scores)[0, 1]
            else:
                corr_coef = float('nan')

            scatter_ax.text(0.05, 0.95, f'Correlation: {corr_coef:.3f}',
                            transform=scatter_ax.transAxes, fontsize=12,
                            bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

            scatter_ax.set_xlabel('PSNR (dB)')
            scatter_ax.set_ylabel('SSIM')
            scatter_ax.set_title('PSNR vs SSIM Correlation', fontsize=14, weight='bold')
            scatter_ax.grid(True, alpha=0.3)
        else:
            # Nothing to show here; avoid leaving an empty framed axes.
            scatter_ax.axis('off')

        plt.tight_layout()
        target_dir = os.path.join(output_dir, 'metrics_analysis')
        os.makedirs(target_dir, exist_ok=True)
        plt.savefig(os.path.join(target_dir, 'correlation_analysis.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_examples_showcase(self, output_dir, aggregate_metrics):
        """Build showcase grids for the best and worst images by PSNR.

        Ranks ``aggregate_metrics['psnr']['values']`` and hands the indices of
        the five highest (best first) and the five lowest (worst first) values
        to ``self._create_examples_grid``. Does nothing when there is no
        ``'psnr'`` entry.

        Args:
            output_dir: Root directory of the evaluation outputs.
            aggregate_metrics: Mapping of metric name to a statistics dict.
        """
        if 'psnr' in aggregate_metrics:
            # Indices ordered by ascending PSNR.
            ranking = np.argsort(aggregate_metrics['psnr']['values'])

            showcases = (
                (ranking[-5:][::-1], "Best", "📈 Top Performing Examples"),
                (ranking[:5], "Worst", "📉 Lowest Performing Examples"),
            )
            for chosen, category, heading in showcases:
                self._create_examples_grid(chosen, category, output_dir, heading)
    
    def _create_examples_grid(self, indices, category, output_dir, title):
        """Save a grid figure showcasing the selected test examples.

        Each row corresponds to one index and has four columns: input,
        generated and target image (cut out of the previously saved comparison
        image ``<output_dir>/comparisons/<filename>_detailed_comparison.png``)
        plus a text box with that example's metrics. Filenames and metrics are
        read from ``self.test_results``. The figure is written to
        ``<output_dir>/quality_assessment/<category lower-cased>_examples_showcase.png``.

        Rows whose comparison image is missing show a placeholder instead of
        the three image panels; metrics absent from an example's metrics dict
        are left out of the text box. Nothing is produced for empty
        ``indices``.

        Args:
            indices: Sequence of positions into ``self.test_results``.
            category: Label used (lower-cased) in the output file name.
            output_dir: Root directory of the evaluation outputs.
            title: Figure super-title.
        """
        n_rows = len(indices)
        if n_rows == 0:
            # plt.subplots rejects a zero-row grid.
            return

        # squeeze=False keeps ``axes`` two-dimensional even for a single row.
        fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

        # Rows of the comparison image skipped as header text.
        # NOTE(review): 80 px and the three-equal-panels layout are assumptions
        # about how the comparison image was rendered -- confirm against the
        # code that writes it.
        header_rows = 80
        panel_titles = ('Input', 'Generated', 'Target')
        metric_lines = (
            ('psnr', "PSNR: {:.2f} dB"),
            ('ssim', "SSIM: {:.3f}"),
            ('mae', "MAE: {:.4f}"),
            ('mse', "MSE: {:.4f}"),
            ('lpips', "LPIPS: {:.4f}"),
            ('edge_similarity', "Edge Sim: {:.3f}"),
        )

        for row, idx in enumerate(indices):
            filename = self.test_results['filenames'][idx]
            metrics = self.test_results['metrics'][idx]

            comparison_path = os.path.join(output_dir, 'comparisons',
                                           f'{filename}_detailed_comparison.png')
            if os.path.exists(comparison_path):
                comparison_img = plt.imread(comparison_path)
                panel_width = comparison_img.shape[1] // 3

                for col, panel_title in enumerate(panel_titles):
                    panel = comparison_img[header_rows:,
                                           col * panel_width:(col + 1) * panel_width]
                    axes[row, col].imshow(panel)
                    axes[row, col].set_title(panel_title, fontsize=12)
                    axes[row, col].axis('off')
            else:
                # No image to show: blank the panels rather than leaving
                # empty tick frames, and say why in the first one.
                for col in range(len(panel_titles)):
                    axes[row, col].axis('off')
                axes[row, 0].text(0.5, 0.5, 'Comparison image not found',
                                  transform=axes[row, 0].transAxes,
                                  ha='center', va='center', fontsize=10)

            # Metrics text (only the metrics actually recorded for this example)
            present = [template.format(metrics[key])
                       for key, template in metric_lines if key in metrics]
            metrics_text = f"Filename: {filename}\n\n" + "\n".join(present)

            text_ax = axes[row, 3]
            text_ax.text(0.1, 0.9, metrics_text, transform=text_ax.transAxes,
                         fontsize=10, verticalalignment='top',
                         bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))
            text_ax.set_title('Metrics', fontsize=12)
            text_ax.axis('off')

        plt.suptitle(title, fontsize=16, weight='bold', y=0.98)
        plt.tight_layout()
        target_dir = os.path.join(output_dir, 'quality_assessment')
        os.makedirs(target_dir, exist_ok=True)
        plt.savefig(os.path.join(target_dir, f'{category.lower()}_examples_showcase.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _create_statistical_analysis(self, output_dir, aggregate_metrics):
        """Save a 2x2 figure with a statistical breakdown of the test metrics.

        Panels:
            1. Violin plots of PSNR, SSIM and MAE rescaled for side-by-side
               comparison (PSNR / 40, SSIM unchanged, 1 - MAE).
            2. Horizontal bar ranking of the best and worst images by PSNR.
            3. PSNR box plot with points outside 1.5 * IQR highlighted.
            4. Table of mean, std, Q25, median and Q75 per metric.

        The figure is written to
        ``<output_dir>/metrics_analysis/statistical_analysis.png``. Panels
        whose data is missing are left empty.

        Args:
            output_dir: Root directory of the evaluation outputs.
            aggregate_metrics: Mapping of metric name to a statistics dict
                (``'values'`` plus ``'mean'``, ``'std'``, ``'q25'``,
                ``'median'``, ``'q75'`` are read).
        """
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        psnr_values = aggregate_metrics['psnr']['values'] if 'psnr' in aggregate_metrics else []
        # Ranking and percentiles are undefined without at least one value.
        has_psnr = len(psnr_values) > 0

        # 1. Metrics comparison violin plot
        if len(aggregate_metrics) > 1:
            ax1 = axes[0, 0]
            # Rescaling applied to each metric so higher always means better.
            # PSNR is divided by 40 dB, so values above 40 dB exceed 1.
            rescale = {
                'psnr': lambda v: v / 40.0,
                'ssim': lambda v: v,
                'mae': lambda v: 1 - v,
            }
            available_metrics = [
                m for m in rescale
                if m in aggregate_metrics and len(aggregate_metrics[m]['values']) > 0
            ]

            if available_metrics:
                data_normalized = [
                    rescale[m](np.array(aggregate_metrics[m]['values']))
                    for m in available_metrics
                ]
                labels = [m.upper() for m in available_metrics]

                ax1.violinplot(data_normalized, positions=range(len(labels)), showmeans=True)
                ax1.set_xticks(range(len(labels)))
                ax1.set_xticklabels(labels)
                ax1.set_title('Normalized Metrics Distribution', fontsize=12, weight='bold')
                ax1.set_ylabel('Normalized Score (0-1)')
                ax1.grid(True, alpha=0.3)

        # 2. Performance ranking
        ax2 = axes[0, 1]
        if has_psnr:
            filenames = self.test_results['filenames']

            # Up to 10 per group, but never let the two groups overlap: with
            # fewer than 20 images a plain [-10:] / [:10] split would list the
            # same image as both "top" and "bottom".
            sorted_indices = np.argsort(psnr_values)
            n_bottom = min(10, len(sorted_indices) // 2)
            n_top = min(10, len(sorted_indices) - n_bottom)
            top_10_idx = sorted_indices[::-1][:n_top]
            bottom_10_idx = sorted_indices[:n_bottom]

            # Top group occupies y >= 0, bottom group sits below with a gap.
            y_pos_top = np.arange(len(top_10_idx))
            y_pos_bottom = np.arange(len(bottom_10_idx)) - len(bottom_10_idx) - 1

            ax2.barh(y_pos_top, [psnr_values[i] for i in top_10_idx],
                     color='green', alpha=0.7, label='Top 10')
            ax2.barh(y_pos_bottom, [psnr_values[i] for i in bottom_10_idx],
                     color='red', alpha=0.7, label='Bottom 10')

            def _short(name):
                # Keep tick labels compact.
                return name[:15] + '...' if len(name) > 15 else name

            top_labels = [_short(filenames[i]) for i in top_10_idx]
            bottom_labels = [_short(filenames[i]) for i in bottom_10_idx]

            ax2.set_yticks(list(y_pos_top) + list(y_pos_bottom))
            ax2.set_yticklabels(top_labels + bottom_labels, fontsize=8)
            ax2.set_xlabel('PSNR (dB)')
            ax2.set_title('Performance Ranking (Top/Bottom 10)', fontsize=12, weight='bold')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

        # 3. Outlier analysis
        ax3 = axes[1, 0]
        if has_psnr:
            # Tukey fences: anything beyond 1.5 * IQR from the quartiles.
            q1 = np.percentile(psnr_values, 25)
            q3 = np.percentile(psnr_values, 75)
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr

            outliers_low = [x for x in psnr_values if x < lower_bound]
            outliers_high = [x for x in psnr_values if x > upper_bound]

            bp = ax3.boxplot(psnr_values, patch_artist=True)
            bp['boxes'][0].set_facecolor('lightblue')

            # Highlight outliers (the single box sits at x = 1)
            if outliers_low:
                ax3.scatter([1] * len(outliers_low), outliers_low,
                            color='red', s=50, alpha=0.7, label='Low Outliers')
            if outliers_high:
                ax3.scatter([1] * len(outliers_high), outliers_high,
                            color='green', s=50, alpha=0.7, label='High Outliers')

            ax3.set_title('Outlier Analysis (PSNR)', fontsize=12, weight='bold')
            ax3.set_ylabel('PSNR (dB)')
            if outliers_low or outliers_high:
                ax3.legend()
            ax3.grid(True, alpha=0.3)

        # 4. Summary statistics table
        ax4 = axes[1, 1]

        stat_keys = ('mean', 'std', 'q25', 'median', 'q75')
        stats_data = [
            [metric_name.upper()] + [f"{metric_data[key]:.4f}" for key in stat_keys]
            for metric_name, metric_data in aggregate_metrics.items()
            if isinstance(metric_data, dict) and 'mean' in metric_data
        ]

        if stats_data:
            headers = ['Metric', 'Mean', 'Std', 'Q25', 'Median', 'Q75']
            table = ax4.table(cellText=stats_data, colLabels=headers,
                              cellLoc='center', loc='center')
            table.auto_set_font_size(False)
            table.set_fontsize(10)
            table.scale(1.2, 1.8)

            # Style the header row
            for col, _ in enumerate(headers):
                table[(0, col)].set_facecolor('#2E8B57')
                table[(0, col)].set_text_props(weight='bold', color='white')

        ax4.set_title('Detailed Statistics', fontsize=12, weight='bold')
        ax4.axis('off')

        plt.tight_layout()
        target_dir = os.path.join(output_dir, 'metrics_analysis')
        os.makedirs(target_dir, exist_ok=True)
        plt.savefig(os.path.join(target_dir, 'statistical_analysis.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def _save_comprehensive_results(self, output_dir, aggregate_metrics):
        """Persist the evaluation results as JSON, CSV and a markdown report.

        Writes, below ``<output_dir>/detailed_reports`` (created if missing):
            * ``comprehensive_test_results.json`` -- evaluation summary,
              aggregate metrics and the individual results.
            * ``test_metrics_detailed.csv`` -- one row per individual result
              (filename, index and its metrics); skipped when there are no
              individual results.
        Then delegates to ``self._create_comprehensive_report`` for the
        markdown report.

        Args:
            output_dir: Root directory of the evaluation outputs.
            aggregate_metrics: Mapping of metric name to a statistics dict.
        """

        def _to_jsonable(value):
            """Fallback encoder for objects ``json`` cannot serialise itself."""
            # NumPy scalars/arrays become native numbers/lists so metrics stay
            # numeric in the JSON instead of being written as strings.
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            return str(value)

        results_summary = {
            'evaluation_summary': {
                'total_images': len(self.test_results['metrics']),
                'model_path': self.model_path,
                'evaluation_date': datetime.now().isoformat(),
                'configuration': {
                    'image_size': self.config.image_size,
                    'device': str(self.config.device),
                    'data_root': self.config.data_root
                }
            },
            'aggregate_metrics': aggregate_metrics,
            'individual_results': self.test_results['individual_results']
        }

        reports_dir = os.path.join(output_dir, 'detailed_reports')
        os.makedirs(reports_dir, exist_ok=True)

        # Save to JSON
        results_path = os.path.join(reports_dir, 'comprehensive_test_results.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results_summary, f, indent=2, default=_to_jsonable)

        # Save to CSV for analysis
        individual_results = self.test_results['individual_results']
        if individual_results:
            df_data = []
            for result in individual_results:
                row = {
                    'filename': result['filename'],
                    'index': result['index']
                }
                row.update(result['metrics'])
                df_data.append(row)

            csv_path = os.path.join(reports_dir, 'test_metrics_detailed.csv')
            pd.DataFrame(df_data).to_csv(csv_path, index=False)

        # Create comprehensive markdown report
        self._create_comprehensive_report(output_dir, aggregate_metrics, results_summary)

        logger.info(f"📊 Comprehensive test results saved to {output_dir}")
    
    def _create_comprehensive_report(self, output_dir, aggregate_metrics, results_summary):
        """Write the markdown evaluation report ``COMPREHENSIVE_TEST_REPORT.md``.

        The report contains an executive summary, a PSNR quality-tier table,
        a statistics table for every metric with a ``'mean'`` entry, textual
        insights, the five best and worst images by PSNR, the list of
        generated files, and recommendations derived from the metrics.
        Sections that depend on PSNR or SSIM are omitted when that metric is
        absent from ``aggregate_metrics``.

        The file is written as UTF-8 because the report text contains emoji
        and other non-ASCII characters.

        Args:
            output_dir: Directory the report file is written into.
            aggregate_metrics: Mapping of metric name to a statistics dict
                (``'values'``, ``'mean'``, ``'std'``, ``'min'``, ``'max'``,
                ``'median'``, ``'q25'``, ``'q75'``).
            results_summary: Dict providing
                ``['evaluation_summary']['total_images']``.
        """
        report_path = os.path.join(output_dir, 'COMPREHENSIVE_TEST_REPORT.md')
        has_psnr = 'psnr' in aggregate_metrics
        has_ssim = 'ssim' in aggregate_metrics

        def _example_line(rank, idx, psnr_values):
            """Format one ranked entry of the best/worst example lists."""
            filename = self.test_results['filenames'][idx]
            # Format SSIM separately: a conditional cannot live inside an
            # f-string format spec, and "N/A" cannot take a float format.
            if has_ssim:
                ssim_text = f"{aggregate_metrics['ssim']['values'][idx]:.3f}"
            else:
                ssim_text = "N/A"
            return (f"{rank}. **{filename}** - PSNR: {psnr_values[idx]:.2f}dB, "
                    f"SSIM: {ssim_text}\n")

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# 🧪 Comprehensive Pix2Pix Test Evaluation Report\n\n")
            f.write(f"**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write(f"**Model:** `{self.model_path}`\n\n")
            f.write(f"**Dataset:** {self.config.data_root}\n\n")
            f.write(f"**Total Test Images:** {results_summary['evaluation_summary']['total_images']}\n\n")

            f.write("---\n\n")
            f.write("##  Executive Summary\n\n")

            # Key findings
            if has_psnr:
                psnr_data = aggregate_metrics['psnr']
                f.write(f"- **Average PSNR:** {psnr_data['mean']:.2f} dB (± {psnr_data['std']:.2f})\n")
                f.write(f"- **PSNR Range:** {psnr_data['min']:.2f} - {psnr_data['max']:.2f} dB\n")

            if has_ssim:
                ssim_data = aggregate_metrics['ssim']
                f.write(f"- **Average SSIM:** {ssim_data['mean']:.3f} (± {ssim_data['std']:.3f})\n")

            # Quality assessment
            if has_psnr:
                psnr_values = aggregate_metrics['psnr']['values']
                total = len(psnr_values)
                tiers = [
                    ("Excellent (≥35dB)", sum(1 for x in psnr_values if x >= 35)),
                    ("Very Good (30-35dB)", sum(1 for x in psnr_values if 30 <= x < 35)),
                    ("Good (25-30dB)", sum(1 for x in psnr_values if 25 <= x < 30)),
                    ("Fair (20-25dB)", sum(1 for x in psnr_values if 20 <= x < 25)),
                    ("Poor (<20dB)", sum(1 for x in psnr_values if x < 20)),
                ]

                f.write("\n### 🎯 Quality Distribution\n\n")
                f.write("| Quality Level | Count | Percentage |\n")
                f.write("|---------------|-------|------------|\n")
                for tier_label, count in tiers:
                    # Guard against an empty value list (no division by zero).
                    share = count / total * 100 if total else 0.0
                    f.write(f"| {tier_label} | {count} | {share:.1f}% |\n")

            f.write("\n---\n\n")
            f.write("## 📈 Detailed Metrics Analysis\n\n")

            # Comprehensive metrics table
            f.write("### Summary Statistics\n\n")
            f.write("| Metric | Mean | Std | Min | Max | Median | Q25 | Q75 |\n")
            f.write("|--------|------|-----|-----|-----|--------|-----|-----|\n")

            for metric_name, metric_data in aggregate_metrics.items():
                if isinstance(metric_data, dict) and 'mean' in metric_data:
                    f.write(f"| {metric_name.upper()} | {metric_data['mean']:.4f} | "
                            f"{metric_data['std']:.4f} | {metric_data['min']:.4f} | "
                            f"{metric_data['max']:.4f} | {metric_data['median']:.4f} | "
                            f"{metric_data['q25']:.4f} | {metric_data['q75']:.4f} |\n")

            f.write("\n### 🔍 Key Insights\n\n")

            # Generate insights based on metrics
            if has_psnr and has_ssim:
                psnr_mean = aggregate_metrics['psnr']['mean']
                ssim_mean = aggregate_metrics['ssim']['mean']

                if psnr_mean >= 30:
                    f.write("✅ **Excellent overall performance** - High PSNR indicates very good reconstruction quality.\n\n")
                elif psnr_mean >= 25:
                    f.write("✅ **Good performance** - PSNR values show satisfactory reconstruction quality.\n\n")
                else:
                    f.write("⚠️ **Room for improvement** - PSNR suggests reconstruction quality could be enhanced.\n\n")

                if ssim_mean >= 0.8:
                    f.write("✅ **Strong structural similarity** - Generated images preserve structural information well.\n\n")
                elif ssim_mean >= 0.6:
                    f.write("✅ **Moderate structural preservation** - Generated images maintain reasonable structural fidelity.\n\n")
                else:
                    f.write("⚠️ **Structural improvements needed** - Consider enhancing structural preservation in the model.\n\n")

            f.write("---\n\n")
            f.write("## 🏆 Performance Examples\n\n")

            # Best and worst examples by PSNR
            if has_psnr:
                psnr_values = aggregate_metrics['psnr']['values']
                ranking = np.argsort(psnr_values)
                best_indices = ranking[-5:][::-1]
                worst_indices = ranking[:5]

                f.write("### 🥇 Top Performing Images\n\n")
                for rank, idx in enumerate(best_indices, 1):
                    f.write(_example_line(rank, idx, psnr_values))

                f.write("\n### 📉 Areas for Improvement\n\n")
                for rank, idx in enumerate(worst_indices, 1):
                    f.write(_example_line(rank, idx, psnr_values))

            f.write("\n---\n\n")
            f.write("## 📁 Generated Files\n\n")
            f.write("This evaluation generated the following analysis files:\n\n")
            f.write("### 🖼️ Visual Results\n")
            f.write("- `generated_images/` - Individual generated images\n")
            f.write("- `comparisons/` - Side-by-side input/generated/target comparisons\n")
            f.write("- `error_maps/` - Pixel-wise error analysis visualizations\n\n")

            f.write("### 📊 Analysis Reports\n")
            f.write("- `test_evaluation_dashboard.png` - Comprehensive metrics dashboard\n")
            f.write("- `quality_assessment/quality_distribution_analysis.png` - Quality distribution analysis\n")
            f.write("- `metrics_analysis/correlation_analysis.png` - Metrics correlation analysis\n")
            f.write("- `metrics_analysis/statistical_analysis.png` - Detailed statistical analysis\n")
            f.write("- `quality_assessment/best_examples_showcase.png` - Top performing examples\n")
            f.write("- `quality_assessment/worst_examples_showcase.png` - Examples needing improvement\n\n")

            f.write("### 📋 Data Files\n")
            f.write("- `detailed_reports/comprehensive_test_results.json` - Complete results in JSON format\n")
            f.write("- `detailed_reports/test_metrics_detailed.csv` - Metrics data for further analysis\n")
            f.write("- `COMPREHENSIVE_TEST_REPORT.md` - This detailed report\n\n")

            f.write("---\n\n")
            f.write("## 🎯 Recommendations\n\n")

            # Collect the applicable recommendations first so they can be
            # numbered consecutively (the list no longer starts at "4." when
            # the conditional items do not apply).
            recommendations = []
            if has_psnr:
                psnr_std = aggregate_metrics['psnr']['std']
                psnr_mean = aggregate_metrics['psnr']['mean']

                if psnr_std > 5:
                    recommendations.append("**Consistency Improvement**: High PSNR variance suggests inconsistent performance across images. Consider data augmentation or model regularization.")

                if psnr_mean < 25:
                    recommendations.append("**Quality Enhancement**: Consider increasing model capacity, adjusting loss functions, or extending training duration.")

                if 'mae' in aggregate_metrics and aggregate_metrics['mae']['mean'] > 0.1:
                    recommendations.append("**Reconstruction Accuracy**: High MAE suggests room for improvement in pixel-level accuracy. Consider adjusting the L1 loss weight.")

            recommendations.append("**Focus Areas**: Pay special attention to images in the 'Areas for Improvement' section for targeted enhancements.")
            recommendations.append("**Validation**: Use the top-performing examples as benchmarks for future model iterations.")

            for number, text in enumerate(recommendations, 1):
                f.write(f"{number}. {text}\n\n")

            f.write("---\n\n")
            f.write("*Report generated by Enhanced Pix2Pix Evaluation System*\n")

def train_pix2pix_complete_pipeline():  #MAIN TRAINING FUNCTION WITH COMPREHENSIVE PIPELINE
    """Run the complete pipeline: data split, training with validation, test evaluation.

    The function builds a ``Config``, splits the discovered image pairs into
    train/val/test sets, wraps them in dataloaders, and trains an
    ``EnhancedTrainer`` for ``config.num_epochs`` epochs. After every epoch it
    validates and tracks the best validation PSNR. When a ``best_model.pth``
    checkpoint exists afterwards it is evaluated on the test split. Finally a
    summary is written to ``final_results/complete_pipeline_summary.json``
    inside the experiment directory.

    Returns:
        tuple: ``(experiment_dir, final_summary)`` where ``final_summary`` is
        the dictionary that was serialized to the JSON summary file. Its
        ``'test_results'`` entry is ``None`` when test evaluation was skipped.

    Raises:
        Exception: Any error outside the per-epoch handler is logged together
            with its traceback and re-raised. An error inside an epoch is
            re-raised for epochs 1-10 and skipped (after logging) afterwards.
    """
    
    print("🚀 " + "="*70)
    print("🚀 ENHANCED PIX2PIX - COMPLETE PIPELINE WITH TRAIN/VAL/TEST")
    print("🚀 " + "="*70)
    
    try:
        # Initialize configuration
        config = Config()
        
        # Data splitting and preparation
        logger.info("📊 Preparing data splits...")
        data_splitter = DataSplitter(config)
        train_pairs, val_pairs, test_pairs = data_splitter.find_and_split_data()
        
        # Create datasets
        train_dataset = Pix2PixDataset(train_pairs, config, 'train')
        val_dataset = Pix2PixDataset(val_pairs, config, 'val')
        test_dataset = Pix2PixDataset(test_pairs, config, 'test')
        
        # Create dataloaders (only the training loader shuffles and drops the
        # last incomplete batch)
        train_loader = DataLoader(
            train_dataset, batch_size=config.batch_size, shuffle=True,
            num_workers=config.num_workers, drop_last=True
        )
        
        val_loader = DataLoader(
            val_dataset, batch_size=config.batch_size, shuffle=False,
            num_workers=config.num_workers, drop_last=False
        )
        
        test_loader = DataLoader(
            test_dataset, batch_size=config.batch_size, shuffle=False,
            num_workers=config.num_workers, drop_last=False
        )
        
        logger.info(f"📊 Dataloaders created:")
        logger.info(f"   📚 Training: {len(train_loader)} batches")
        logger.info(f"   🔍 Validation: {len(val_loader)} batches")
        logger.info(f"   🧪 Testing: {len(test_loader)} batches")
        
        # Initialize trainer
        logger.info(" Initializing enhanced trainer...")
        trainer = EnhancedTrainer(config)
        
        # Training loop with validation
        logger.info("🏋️ Starting training with validation...")
        
        for epoch in range(1, config.num_epochs + 1):
            try:
                print(f"\n📅 Epoch {epoch}/{config.num_epochs}")
                
                # Training phase
                train_results = trainer.train_epoch(train_loader, epoch)
                
                # Validation phase
                val_results = trainer.validate_epoch(val_loader, epoch)
                
                # Update history
                trainer.update_history(train_results, val_results)
                
                # Log results
                logger.info(f"📊 Epoch {epoch} Results:")
                logger.info(f"   🏋️ Train - G_Loss: {train_results['g_loss']:.4f}, D_Loss: {train_results['d_loss']:.4f}, PSNR: {train_results['psnr']:.2f}dB")
                if val_results:
                    logger.info(f"   🔍 Val   - G_Loss: {val_results['g_loss']:.4f}, D_Loss: {val_results['d_loss']:.4f}, PSNR: {val_results['psnr']:.2f}dB, SSIM: {val_results['ssim']:.3f}")
                
                # Check for best model (highest validation PSNR so far)
                is_best = False
                if val_results and val_results['psnr'] > trainer.best_val_psnr:
                    trainer.best_val_psnr = val_results['psnr']
                    trainer.best_epoch = epoch
                    is_best = True
                
                # Save samples and checkpoints
                if epoch % config.save_frequency == 0:
                    trainer.generate_samples(train_loader, epoch, split_name='train')
                    # A DataLoader is falsy when it yields zero batches
                    if val_loader:
                        trainer.generate_samples(val_loader, epoch, split_name='val')
                    trainer.plot_training_progress()
                
                if epoch % config.save_frequency == 0 or is_best:
                    trainer.save_checkpoint(epoch, is_best)
                
            except Exception as e:
                logger.error(f" Error in epoch {epoch}: {e}")
                # Keep the traceback: later epochs are skipped rather than
                # aborted, so this log entry is the only record of the failure.
                logger.error(f"💡 Traceback: {traceback.format_exc()}")
                if epoch > 10:
                    continue
                else:
                    raise
        
        # Training completed
        print("\n🎉 " + "="*70)
        print("🎉 TRAINING COMPLETED SUCCESSFULLY!")
        print("🎉 " + "="*70)
        
        # Final training progress
        trainer.plot_training_progress()
        
        # Comprehensive test evaluation
        print("\n🧪 " + "="*70)
        print("🧪 STARTING COMPREHENSIVE TEST EVALUATION")
        print("🧪 " + "="*70)
        
        # Bound up front so the summary below never depends on which branch ran
        test_results = None
        
        # Load best model for testing
        best_model_path = os.path.join(config.experiment_dir, 'checkpoints', 'best_model.pth')
        if os.path.exists(best_model_path):
            test_evaluator = TestEvaluator(config, best_model_path)
            test_results = test_evaluator.evaluate_test_dataset(test_loader)
            
            print("\n📊 TEST EVALUATION SUMMARY:")
            if test_results:
                for metric, data in test_results.items():
                    if isinstance(data, dict) and 'mean' in data:
                        print(f"   {metric.upper()}: {data['mean']:.4f} ± {data['std']:.4f}")
        else:
            logger.warning("⚠️ Best model not found, skipping test evaluation")
        
        # Save final comprehensive summary
        final_summary = {
            'training_summary': {
                'total_epochs': config.num_epochs,
                'best_val_psnr': trainer.best_val_psnr,
                'best_epoch': trainer.best_epoch,
                'final_train_psnr': trainer.history['train_psnr'][-1] if trainer.history['train_psnr'] else 0,
                'final_val_psnr': trainer.history['val_psnr'][-1] if trainer.history['val_psnr'] else 0
            },
            'data_split_summary': {
                'total_images': len(train_pairs) + len(val_pairs) + len(test_pairs),
                'train_images': len(train_pairs),
                'val_images': len(val_pairs),
                'test_images': len(test_pairs)
            },
            'test_results': test_results,
            'configuration': {
                'data_root': config.data_root,
                'image_size': config.image_size,
                'batch_size': config.batch_size,
                'learning_rate': config.learning_rate_g,
                'train_ratio': config.train_ratio,
                'val_ratio': config.val_ratio,
                'test_ratio': config.test_ratio
            }
        }
        
        summary_path = os.path.join(config.experiment_dir, 'final_results', 'complete_pipeline_summary.json')
        with open(summary_path, 'w') as f:
            # default=str stringifies anything json cannot encode natively
            json.dump(final_summary, f, indent=2, default=str)
        
        print(f"\n FINAL PIPELINE RESULTS:")
        print(f"    Best Validation PSNR: {trainer.best_val_psnr:.2f} dB (Epoch {trainer.best_epoch})")
        print(f"    Training Images: {len(train_pairs)}")
        print(f"    Validation Images: {len(val_pairs)}")
        print(f"    Test Images: {len(test_pairs)}")
        print(f"    Results saved to: {config.experiment_dir}")
        
        print(f"\n COMPLETE PIPELINE OUTPUTS:")
        print(f"    Training samples: {config.experiment_dir}/generated_samples/")
        print(f"    Model checkpoints: {config.experiment_dir}/checkpoints/")
        print(f"    Training progress: {config.experiment_dir}/training_progress/")
        print(f"    Data splits info: {config.experiment_dir}/data_splits/")
        print(f"    Validation results: {config.experiment_dir}/validation_results/")
        print(f"    Test evaluation: {config.experiment_dir}/test_results/")
        print(f"    Final summary: {config.experiment_dir}/final_results/")
        
        return config.experiment_dir, final_summary
        
    except Exception as e:
        logger.error(f"❌ Pipeline failed: {e}")
        logger.error(f"💡 Traceback: {traceback.format_exc()}")
        raise

def main_complete(): #MAIN EXECUTION WITH ENHANCED FEATURES
    """Print the feature overview, then drive the complete pipeline.

    When ``DATA_ROOT`` does not exist, setup instructions are printed and the
    function returns without training. Otherwise the pipeline is run and the
    mean of every aggregated test metric is printed. ``KeyboardInterrupt`` and
    any other exception raised inside the guarded section are reported on
    stdout instead of propagating.
    """
    rule = "=" * 80
    
    print("🎮 " + rule)
    print("🎮 ENHANCED PIX2PIX - COMPLETE TRAIN/VAL/TEST PIPELINE")
    print("🎮 " + rule)
    
    feature_overview = f"""
📋 ENHANCED FEATURES:

🔄 **Automatic Data Splitting:**
   - Intelligent train/validation/test splits ({TRAIN_RATIO*100:.0f}%/{VAL_RATIO*100:.0f}%/{TEST_RATIO*100:.0f}%)
   - Reproducible splits with fixed random seed
   - Comprehensive data validation

📊 **Advanced Training:**
   - Training with validation monitoring
   - Early stopping based on validation PSNR
   - Comprehensive loss tracking (Generator + Discriminator)
   - Advanced data augmentation for training set

📈 **Real-time Monitoring:**
   - Training vs Validation curves
   - Multiple metrics tracking (PSNR, SSIM, MAE, MSE)
   - Best model checkpointing
   - Progress visualization

🧪 **Comprehensive Test Evaluation:**
   - Detailed metrics analysis (6+ metrics)
   - Quality distribution assessment
   - Error map visualizations
   - Statistical analysis and outlier detection
   - Performance correlation analysis
   - Best/worst examples showcase

📁 **Professional Reporting:**
   - Comprehensive markdown reports
   - JSON/CSV data exports
   - Interactive visualizations
   - Executive summaries with insights

🎯 **Data Organization:**
   Data structure: {DATA_ROOT}/
   ├── input/  (or existing folder names)
   ├── target/ (or existing folder names)
   
   Results: {OUTPUT_FOLDER}/experiment_[timestamp]/
   ├── checkpoints/
   ├── data_splits/
   ├── training_progress/
   ├── validation_results/
   ├── test_results/
   └── final_results/
"""
    print(feature_overview)
    
    try:
        # Without a data root there is nothing to train on: explain and stop
        if not os.path.exists(DATA_ROOT):
            setup_help = (
                f"⚠️  DATA ROOT NOT FOUND: {DATA_ROOT}",
                "\n🔧 SETUP INSTRUCTIONS:",
                f"1. Create data root: {DATA_ROOT}",
                "2. Add subfolders: input/ and target/ (or use existing folder structure)",
                "3. Place matching image pairs in both folders",
                "4. Run this script again",
                "\n💡 The script will auto-detect existing folder structures",
            )
            for help_line in setup_help:
                print(help_line)
            return
        
        # Hand over to the training/evaluation pipeline
        print("\n🚀 Starting complete pipeline...")
        experiment_dir, summary = train_pix2pix_complete_pipeline()
        
        print("\n🎉 " + rule)
        print("🎉 COMPLETE PIPELINE FINISHED SUCCESSFULLY!")
        print("🎉 " + rule)
        
        # Report the mean of each aggregated test metric, when there are any
        final_metrics = None
        if summary and 'test_results' in summary:
            final_metrics = summary['test_results']
        
        if final_metrics:
            print("\n📊 FINAL PERFORMANCE SUMMARY:")
            for metric, data in final_metrics.items():
                if not isinstance(data, dict) or 'mean' not in data:
                    continue
                print(f"   {metric.upper()}: {data['mean']:.4f}")
        
        print(f"\n📁 Check {experiment_dir} for complete results!")
        
    except KeyboardInterrupt:
        print("\n⏹️  Pipeline stopped by user")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        troubleshooting = (
            "\n🔧 TROUBLESHOOTING:",
            "1. Check DATA_ROOT path exists and contains image folders",
            "2. Ensure matching image names in input and target folders",
            "3. Try reducing BATCH_SIZE if memory issues",
            "4. Set DEVICE = 'cpu' for GPU issues",
            "5. Check file permissions for OUTPUT_FOLDER",
        )
        for hint in troubleshooting:
            print(hint)

def create_demo_data(demo_root="demo_data", num_samples=100): #UTILITY FUNCTIONS
    """Create synthetic sketch/colour image pairs for demonstration runs.

    Each pair consists of a white 256x256 "sketch" with black marks (the
    input) and a randomly coloured image in which those marks are preserved
    (the target). The sketch style cycles through three variants by sample
    index: horizontal line strokes, filled circles, and small squares. Both
    images of a pair are saved as PNG under the same file name
    (``demo_000.png``, ``demo_001.png``, ...).

    Args:
        demo_root: Directory that receives the ``input/`` and ``target/``
            sub-folders; created if missing. Defaults to ``"demo_data"``.
        num_samples: Number of image pairs to generate. Defaults to 100.

    Returns:
        The ``demo_root`` path that was written to.
    """
    
    print("🎨 Creating demonstration data...")
    
    input_dir = os.path.join(demo_root, "input")
    target_dir = os.path.join(demo_root, "target")
    
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)
    
    # Create diverse sample data
    for i in range(num_samples):
        # Create input (sketch-like with various styles) on a white canvas
        input_array = np.ones((256, 256, 3), dtype=np.uint8) * 255
        
        # Different sketch styles
        if i % 3 == 0:  # Line drawings
            for _ in range(np.random.randint(10, 20)):
                y = np.random.randint(20, 236)
                x_start = np.random.randint(20, 200)
                x_end = x_start + np.random.randint(20, 80)
                thickness = np.random.randint(1, 3)
                # x_end may exceed the canvas width; slicing clips it
                input_array[y-thickness:y+thickness, x_start:x_end] = 0
        
        elif i % 3 == 1:  # Geometric shapes
            for _ in range(np.random.randint(3, 8)):
                center_x = np.random.randint(50, 206)
                center_y = np.random.randint(50, 206)
                radius = np.random.randint(10, 30)
                
                # Boolean disc mask built from open coordinate grids
                y, x = np.ogrid[:256, :256]
                mask = (x - center_x)**2 + (y - center_y)**2 <= radius**2
                input_array[mask] = 0
        
        else:  # Mixed patterns
            for _ in range(np.random.randint(15, 25)):
                y = np.random.randint(10, 246)
                x = np.random.randint(10, 246)
                size = np.random.randint(2, 8)
                input_array[y:y+size, x:x+size] = 0
        
        # Create corresponding target (colored version)
        target_array = np.random.randint(80, 220, (256, 256, 3), dtype=np.uint8)
        
        # Add some coherent color regions
        for _ in range(np.random.randint(2, 5)):
            color = np.random.randint(50, 255, 3)
            y_start = np.random.randint(0, 128)
            x_start = np.random.randint(0, 128)
            y_end = y_start + np.random.randint(64, 128)
            x_end = x_start + np.random.randint(64, 128)
            target_array[y_start:y_end, x_start:x_end] = color
        
        # Preserve sketch lines in target: copy every dark input pixel over
        mask = (input_array < 128).any(axis=2)
        target_array[mask] = input_array[mask]
        
        # Save images
        input_img = Image.fromarray(input_array)
        target_img = Image.fromarray(target_array)
        
        filename = f"demo_{i:03d}.png"
        input_img.save(os.path.join(input_dir, filename))
        target_img.save(os.path.join(target_dir, filename))
    
    print(f"✅ Created {num_samples} demo image pairs in {demo_root}/")
    print(f"📁 Structure:")
    print(f"   {demo_root}/")
    print(f"   ├── input/")
    print(f"   └── target/")
    
    return demo_root

# Entry point
# NOTE(review): main_complete() reads DATA_ROOT / OUTPUT_FOLDER / TRAIN_RATIO /
# VAL_RATIO / TEST_RATIO in its banner, and the pipeline it starts needs Config,
# DataSplitter and the other classes. All of those are defined *below* this
# guard in the file. If the file is executed top to bottom as a plain script,
# this call is reached before those names exist and fails with NameError.
# Presumably this was written for cell-by-cell execution in a different order;
# confirm, or move this guard to the very end of the file.
if __name__ == "__main__":
    # Uncomment to create demo data
    # demo_path = create_demo_data()
    # print(f"Demo data created at: {demo_path}")
    # print("Update DATA_ROOT to point to this demo data and run again!")
    
    # Run complete pipeline
    main_complete()


# User-editable settings. Config.__init__ copies each of these module-level
# values onto the Config instance, so edit them here before building a Config.

# Config looks for "input"/"target" sub-folders under DATA_ROOT and falls back
# to sub-folders named "16" and "Original" when those do not exist.
DATA_ROOT = "G:/Jafar/Luna16/"                    # 🔴 CHANGE THIS: Root folder with input/ and target/ subfolders
OUTPUT_FOLDER = "G:/Jafar/Luna16/Generated/"      # 🔴 CHANGE THIS: Where to save results (one experiment_<timestamp>/ folder per Config)

IMAGE_SIZE = 256                     # Image resolution (256 or 512); images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 16                      # Samples per batch; lower it (e.g. to 4) if GPU memory is limited
NUM_EPOCHS = 300                     # Number of training epochs
LEARNING_RATE = 0.0002               # Learning rate, used for both generator and discriminator
DEVICE = 'auto'                      # 'auto', 'cuda', or 'cpu' ('auto' picks CUDA when available)

# Split fractions. DataSplitter._split_data sizes the train and validation
# sets from TRAIN_RATIO / VAL_RATIO and gives the test set whatever remains,
# so TEST_RATIO is only reported (banner, saved split info), not used to slice.
TRAIN_RATIO = 0.7                    # 70% for training
VAL_RATIO = 0.15                     # 15% for validation
TEST_RATIO = 0.15                    # 15% for testing

class Config:  #ENHANCED CONFIGURATION CLASS
    """Single container for every setting the pipeline uses.

    Values are taken from the module-level constants (``DATA_ROOT``,
    ``OUTPUT_FOLDER``, ``IMAGE_SIZE``, ...) plus a few fixed model and
    optimizer defaults. Building an instance has side effects: the compute
    device is probed, a timestamped experiment directory (with a
    ``config.json`` snapshot) is created, and missing input/target folders
    are created.
    """
    
    def __init__(self):
        # --- Data locations ---
        self.data_root = DATA_ROOT
        self.input_folder = self._resolve_subfolder("input", "16")
        self.target_folder = self._resolve_subfolder("target", "Original")
        self.output_folder = OUTPUT_FOLDER
        
        # --- Train/val/test fractions ---
        self.train_ratio, self.val_ratio, self.test_ratio = TRAIN_RATIO, VAL_RATIO, TEST_RATIO
        
        # --- Training settings ---
        self.image_size = IMAGE_SIZE
        self.batch_size = BATCH_SIZE
        self.num_epochs = NUM_EPOCHS
        # Generator and discriminator share the same learning rate
        self.learning_rate_g = self.learning_rate_d = LEARNING_RATE
        self.device = self._setup_device(DEVICE)
        
        # --- Model settings ---
        self.input_nc = self.output_nc = 3
        self.ngf = self.ndf = 64
        
        # --- Optimizer / loss weights ---
        self.beta1 = 0.5
        self.beta2 = 0.999
        self.lambda_l1 = 100.0
        self.lambda_adversarial = 1.0
        
        # --- System settings ---
        self.num_workers = 0
        self.save_frequency = 10
        self.random_seed = 42
        
        # The snapshot written here contains every attribute assigned above
        self.experiment_dir = self._setup_experiment_dir()
        
        # Make sure the data folders exist
        self._validate_paths()
        
        announcements = (
            ("Input folder", self.input_folder),
            ("Target folder", self.target_folder),
            ("Results will be saved to", self.experiment_dir),
        )
        for label, location in announcements:
            logger.info(f"📁 {label}: {location}")
    
    @staticmethod
    def _resolve_subfolder(preferred, fallback):
        """Return DATA_ROOT/preferred if it exists, otherwise DATA_ROOT/fallback."""
        candidate = os.path.join(DATA_ROOT, preferred)
        if os.path.exists(candidate):
            return candidate
        return os.path.join(DATA_ROOT, fallback)
    
    def _setup_device(self, device):
        """Resolve ``'auto'`` to a concrete device and verify it is usable.

        Returns the device string, or ``'cpu'`` when moving a small probe
        tensor onto the requested device fails.
        """
        if device == 'auto':
            if not torch.cuda.is_available():
                device = 'cpu'
                logger.info("💻 Using CPU (GPU not available)")
            else:
                device = 'cuda'
                gpu_name = torch.cuda.get_device_name(0)
                logger.info(f"🎮 Using GPU: {gpu_name}")
        
        try:
            # Probe: allocating on the device surfaces driver/memory problems early
            torch.randn(1, 3, 64, 64).to(device)
            logger.info(f"✅ Device '{device}' ready")
            return device
        except Exception as e:
            logger.warning(f"⚠️ Device '{device}' failed, using CPU: {e}")
            return 'cpu'
    
    def _validate_paths(self):
        """Create the input and target folders when they are missing."""
        checks = (
            ("Input", "input", self.input_folder),
            ("Target", "target", self.target_folder),
        )
        for title, noun, folder in checks:
            if os.path.exists(folder):
                continue
            logger.warning(f"⚠️ {title} folder not found: {folder}")
            os.makedirs(folder, exist_ok=True)
            logger.info(f"📁 Created {noun} folder: {folder}")
    
    def _setup_experiment_dir(self):
        """Create ``experiment_<timestamp>/`` with its sub-folders and config.json.

        Returns the path of the new experiment directory.
        """
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        root = os.path.join(self.output_folder, f"experiment_{stamp}")
        
        for name in (
            'checkpoints', 'generated_samples', 'training_progress',
            'data_splits', 'validation_results', 'test_results', 'final_results',
        ):
            os.makedirs(os.path.join(root, name), exist_ok=True)
        
        # Snapshot the public attributes set so far (underscore names are skipped)
        snapshot = {}
        for key, value in self.__dict__.items():
            if key.startswith('_'):
                continue
            snapshot[key] = value
        
        with open(os.path.join(root, 'config.json'), 'w') as f:
            json.dump(snapshot, f, indent=2, default=str)
        
        return root

class DataSplitter:   #DATA SPLITTING AND MANAGEMENT
    """Discover input/target image pairs and split them into train/val/test.

    The splitter reads ``input_folder``, ``target_folder``, ``train_ratio``,
    ``val_ratio``, ``test_ratio``, ``random_seed`` and ``experiment_dir`` from
    the supplied config object.
    """
    
    def __init__(self, config):
        self.config = config
        self.image_pairs = []   # every validated (input_path, target_path) pair
        self.train_pairs = []
        self.val_pairs = []
        self.test_pairs = []
        
    def find_and_split_data(self):
        """Find image pairs, split them, and persist the split description.

        When no valid pair is found, synthetic sample pairs are written into
        the configured folders and the scan is repeated.

        Returns:
            tuple: ``(train_pairs, val_pairs, test_pairs)``, each a list of
            ``(input_path, target_path)`` tuples.
        """
        
        logger.info("🔍 Scanning for image pairs...")
        self._find_image_pairs()
        
        if not self.image_pairs:
            logger.warning("⚠️ No valid image pairs found! Creating sample data...")
            self._create_sample_data()
            self._find_image_pairs()
        
        logger.info(f"📊 Found {len(self.image_pairs)} total image pairs")
        
        # Split data
        self._split_data()
        
        # Save split information
        self._save_split_info()
        
        logger.info(f"✅ Data split completed:")
        logger.info(f"   📚 Training: {len(self.train_pairs)} pairs")
        logger.info(f"   🔍 Validation: {len(self.val_pairs)} pairs")
        logger.info(f"   🧪 Testing: {len(self.test_pairs)} pairs")
        
        return self.train_pairs, self.val_pairs, self.test_pairs
    
    def _list_input_files(self):
        """Return the sorted, de-duplicated candidate image paths in the input folder.

        Both the lower- and upper-case form of each extension is globbed. On a
        case-insensitive filesystem both patterns match the same files, so the
        results are collected in a set; otherwise every image would be listed
        twice and could end up in two different splits. Sorting makes the
        seeded shuffle in ``_split_data`` independent of directory order.
        """
        patterns = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
        found = set()
        for pattern in patterns:
            for variant in (pattern, pattern.upper()):
                found.update(glob.glob(os.path.join(self.config.input_folder, variant)))
        return sorted(found)
    
    def _find_image_pairs(self):
        """Populate ``self.image_pairs`` with loadable input/target pairs.

        A target matches an input when it has the same file stem and any of
        the supported extensions (lower- or upper-case).
        """
        input_files = self._list_input_files()
        
        logger.info(f"📸 Found {len(input_files)} input images")
        
        # Find matching targets and validate
        valid_pairs = []
        for input_path in tqdm(input_files, desc="Validating image pairs"):
            input_name = Path(input_path).stem
            
            # First existing candidate wins
            target_path = None
            for ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
                for target_ext in [ext, ext.upper()]:
                    candidate = os.path.join(self.config.target_folder, input_name + target_ext)
                    if os.path.exists(candidate):
                        target_path = candidate
                        break
                if target_path:
                    break
            
            if target_path and self._validate_image_pair(input_path, target_path):
                valid_pairs.append((input_path, target_path))
        
        self.image_pairs = valid_pairs
    
    def _validate_image_pair(self, input_path, target_path):
        """Return True when both images can be opened, converted and resized.

        Any exception raised while loading counts as "invalid" and yields
        False. The files are opened in a ``with`` block so their handles are
        closed even when validation fails.
        """
        try:
            with Image.open(input_path) as input_img, Image.open(target_path) as target_img:
                # Converting and resizing forces a full decode of both images
                input_img.convert('RGB').resize((64, 64), Image.LANCZOS)
                target_img.convert('RGB').resize((64, 64), Image.LANCZOS)
            return True
        except Exception:
            return False
    
    def _split_data(self):
        """Shuffle the pairs with the configured seed and slice them into splits.

        Train and validation sizes are ``int(total * ratio)``; the test split
        receives the remainder. Note that this reseeds the global ``random``
        and ``numpy.random`` generators.
        """
        
        # Set random seed for reproducible splits
        random.seed(self.config.random_seed)
        np.random.seed(self.config.random_seed)
        
        # Shuffle a copy so self.image_pairs keeps its original order
        shuffled_pairs = self.image_pairs.copy()
        random.shuffle(shuffled_pairs)
        
        # Calculate split indices
        total_pairs = len(shuffled_pairs)
        train_size = int(total_pairs * self.config.train_ratio)
        val_size = int(total_pairs * self.config.val_ratio)
        
        # Split the data
        self.train_pairs = shuffled_pairs[:train_size]
        self.val_pairs = shuffled_pairs[train_size:train_size + val_size]
        self.test_pairs = shuffled_pairs[train_size + val_size:]
        
        # Ensure we have data in each split (only when the donor keeps >= 1 pair)
        if len(self.test_pairs) == 0 and len(self.val_pairs) > 1:
            # Move one from val to test
            self.test_pairs.append(self.val_pairs.pop())
        
        if len(self.val_pairs) == 0 and len(self.train_pairs) > 1:
            # Move one from train to val
            self.val_pairs.append(self.train_pairs.pop())
    
    def _save_split_info(self):
        """Write ``split_info.json`` and one CSV per non-empty split.

        All files go to ``<experiment_dir>/data_splits/``.
        """
        split_info = {
            'total_pairs': len(self.image_pairs),
            'train_pairs': len(self.train_pairs),
            'val_pairs': len(self.val_pairs),
            'test_pairs': len(self.test_pairs),
            'train_ratio': self.config.train_ratio,
            'val_ratio': self.config.val_ratio,
            'test_ratio': self.config.test_ratio,
            'random_seed': self.config.random_seed,
            'split_timestamp': datetime.now().isoformat(),
            'train_files': [(str(i), str(t)) for i, t in self.train_pairs],
            'val_files': [(str(i), str(t)) for i, t in self.val_pairs],
            'test_files': [(str(i), str(t)) for i, t in self.test_pairs]
        }
        
        split_path = os.path.join(self.config.experiment_dir, 'data_splits', 'split_info.json')
        with open(split_path, 'w') as f:
            json.dump(split_info, f, indent=2)
        
        # Create CSV files for easy viewing
        splits = [
            ('train', self.train_pairs),
            ('val', self.val_pairs),
            ('test', self.test_pairs)
        ]
        
        for split_name, pairs in splits:
            df_data = [
                {
                    'input_path': input_path,
                    'target_path': target_path,
                    'filename': Path(input_path).name
                }
                for input_path, target_path in pairs
            ]
            
            # Empty splits produce no CSV file
            if df_data:
                df = pd.DataFrame(df_data)
                csv_path = os.path.join(self.config.experiment_dir, 'data_splits', f'{split_name}_files.csv')
                df.to_csv(csv_path, index=False)
    
    def _create_sample_data(self):
        """Write 50 synthetic sketch/colour PNG pairs into the configured folders.

        Used as a fallback so the pipeline can run when no real data exists.
        """
        logger.info("🎨 Creating sample data for demonstration...")
        
        num_samples = 50
        for i in range(num_samples):
            # Create input image (sketch-like): black strokes on white
            input_array = np.ones((256, 256, 3), dtype=np.uint8) * 255
            for _ in range(15):
                y = np.random.randint(20, 236)
                x_start = np.random.randint(20, 200)
                x_end = x_start + np.random.randint(20, 60)
                thickness = np.random.randint(1, 4)
                input_array[y-thickness:y+thickness, x_start:x_end] = 0
            
            # Create target image (colored) and keep the strokes visible in it
            target_array = np.random.randint(50, 200, (256, 256, 3), dtype=np.uint8)
            mask = (input_array < 128).any(axis=2)
            target_array[mask] = input_array[mask]
            
            # Save images
            input_img = Image.fromarray(input_array)
            target_img = Image.fromarray(target_array)
            
            filename = f"sample_{i:03d}.png"
            input_img.save(os.path.join(self.config.input_folder, filename))
            target_img.save(os.path.join(self.config.target_folder, filename))
        
        logger.info(f"✅ Created {num_samples} sample image pairs")

class Pix2PixDataset(Dataset):  #ENHANCED DATASET WITH SPLIT SUPPORT
    """Paired-image dataset for one split (train, val or test).

    Each item is a dict with the ``'input'`` and ``'target'`` tensors (scaled
    to ``[-1, 1]``), both file paths, and the input file stem under
    ``'filename'``. Only the training split is augmented.

    Args:
        image_pairs: Sequence of ``(input_path, target_path)`` tuples.
        config: Object providing ``image_size``.
        split_type: ``'train'`` enables augmentation; any other value
            disables it.
    """
    
    def __init__(self, image_pairs, config, split_type='train'):
        self.image_pairs = image_pairs
        self.config = config
        self.split_type = split_type
        
        # Data augmentation for training
        if split_type == 'train':
            self.transform = transforms.Compose([
                transforms.Resize((config.image_size, config.image_size), Image.LANCZOS),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)  # Normalize to [-1, 1]
            ])
        else:
            # No augmentation for validation/test
            self.transform = transforms.Compose([
                transforms.Resize((config.image_size, config.image_size), Image.LANCZOS),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * 2.0 - 1.0)  # Normalize to [-1, 1]
            ])
    
    def __len__(self):
        """Return the number of image pairs in this split."""
        return len(self.image_pairs)
    
    def __getitem__(self, idx):
        """Load, transform and return the pair at ``idx``.

        If anything fails while loading, a warning is logged and an all-zero
        placeholder item (with ``'filename'`` set to ``'error'``) is returned
        instead of raising.
        """
        try:
            input_path, target_path = self.image_pairs[idx]
            
            # Load images
            input_img = Image.open(input_path).convert('RGB')
            target_img = Image.open(target_path).convert('RGB')
            
            # Apply same random transformations to both images
            if self.split_type == 'train':
                # Reseed every RNG with the same value before each transform
                # call so the random flip/jitter draws match for both images.
                # Note that this overwrites the global RNG state.
                seed = np.random.randint(2147483647)
                
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                input_tensor = self.transform(input_img)
                
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                target_tensor = self.transform(target_img)
            else:
                input_tensor = self.transform(input_img)
                target_tensor = self.transform(target_img)
            
            return {
                'input': input_tensor,
                'target': target_tensor,
                'input_path': input_path,
                'target_path': target_path,
                'filename': Path(input_path).stem
            }
            
        except Exception as e:
            logger.warning(f"Error loading item {idx}: {e}")
            # Return dummy data so one unreadable file does not abort the epoch
            dummy = torch.zeros(3, self.config.image_size, self.config.image_size)
            return {
                'input': dummy,
                'target': dummy,
                'input_path': '',
                'target_path': '',
                'filename': 'error'
            }


class EnhancedGenerator(nn.Module):  #ENHANCED MODELS (UNCHANGED FROM ORIGINAL)
    """U-Net style generator: eight stride-2 encoder stages (``e1``..``e8``),
    seven decoder stages (``d1``..``d7``) fed by skip connections, and a final
    transposed convolution followed by ``Tanh``.

    Args:
        input_nc: Channels of the input image.
        output_nc: Channels of the generated image.
        ngf: Base number of filters; deeper stages use multiples of it.
    """
    
    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super().__init__()
        
        wide = ngf * 8  # channel count of the deepest stages
        
        # (in_channels, out_channels, normalize) for e1..e8
        encoder_plan = [
            (input_nc, ngf, False),
            (ngf, ngf * 2, True),
            (ngf * 2, ngf * 4, True),
            (ngf * 4, wide, True),
            (wide, wide, True),
            (wide, wide, True),
            (wide, wide, True),
            (wide, wide, False),
        ]
        for stage, (c_in, c_out, norm) in enumerate(encoder_plan, start=1):
            setattr(self, f"e{stage}", self._make_layer(c_in, c_out, normalize=norm))
        
        # (in_channels, out_channels, dropout) for d1..d7; from d2 on the input
        # is the previous decoder output concatenated with an encoder skip
        decoder_plan = [
            (wide, wide, True),
            (wide * 2, wide, True),
            (wide * 2, wide, True),
            (wide * 2, wide, False),
            (wide * 2, ngf * 4, False),
            (ngf * 8, ngf * 2, False),
            (ngf * 4, ngf, False),
        ]
        for stage, (c_in, c_out, drop) in enumerate(decoder_plan, start=1):
            setattr(self, f"d{stage}", self._make_up_layer(c_in, c_out, dropout=drop))
        
        self.final = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1),
            nn.Tanh()
        )
        
        self._initialize_weights()
    
    def _make_layer(self, in_channels, out_channels, normalize=True):
        """Downsampling stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        stage = nn.Sequential(nn.Conv2d(in_channels, out_channels, 4, 2, 1))
        if normalize:
            stage.append(nn.BatchNorm2d(out_channels))
        stage.append(nn.LeakyReLU(0.2, True))
        return stage
    
    def _make_up_layer(self, in_channels, out_channels, dropout=False):
        """Upsampling stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional Dropout2d(0.5)."""
        stage = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        if dropout:
            stage.append(nn.Dropout2d(0.5))
        return stage
    
    def _initialize_weights(self):
        """Draw conv weights from N(0, 0.02) and BatchNorm weights from N(1, 0.02); zero BatchNorm biases."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, 1.0, 0.02)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, 0.0, 0.02)
    
    def forward(self, x):
        """Encode ``x``, then decode while concatenating the mirrored encoder outputs."""
        # Encoder: remember every stage output for the skip connections
        skips = []
        features = x
        for stage in range(1, 9):
            features = getattr(self, f"e{stage}")(features)
            skips.append(features)
        
        # The deepest encoder output feeds the first decoder stage directly
        features = self.d1(skips.pop())
        
        # Remaining decoder stages consume the skips from deepest to shallowest
        for stage in range(2, 8):
            features = getattr(self, f"d{stage}")(torch.cat([features, skips.pop()], 1))
        
        return self.final(torch.cat([features, skips.pop()], 1))

class EnhancedDiscriminator(nn.Module):
    """Convolutional patch discriminator over an (input, target) image pair.

    The two images are concatenated along the channel dimension and pushed
    through five 4x4 convolutions (strides 2, 2, 2, 1, 1; padding 1). The
    three middle convolutions are followed by batch normalization, every
    convolution except the last by LeakyReLU(0.2). The last convolution
    emits a single-channel score map with no activation applied.
    """

    def __init__(self, input_nc=6, ndf=64):
        """Assemble the layer stack and initialize its weights.

        Args:
            input_nc: Channel count of the concatenated image pair.
            ndf: Filter count of the first convolution; later layers use
                2x, 4x and 8x this value.
        """
        super().__init__()

        # First stage has no normalization.
        stack = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (input multiplier, output multiplier, stride) for the normalized stages.
        for mult_in, mult_out, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            stack.extend([
                nn.Conv2d(ndf * mult_in, ndf * mult_out,
                          kernel_size=4, stride=stride, padding=1),
                nn.BatchNorm2d(ndf * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Final projection to a one-channel score map.
        stack.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*stack)
        self._initialize_weights()

    def _initialize_weights(self):
        """Draw conv weights from N(0, 0.02) and batch-norm weights from
        N(1, 0.02); zero the batch-norm biases. Conv biases are untouched."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean=1.0, std=0.02)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_img, target_img):
        """Score an image pair.

        Args:
            input_img: Conditioning image tensor.
            target_img: Real or generated image tensor; must be
                concatenable with ``input_img`` along dim 1.

        Returns:
            The score map produced by ``self.model``.
        """
        pair = torch.cat([input_img, target_img], dim=1)
        return self.model(pair)

class EnhancedMetrics:
    """Image-quality metrics for batches of generated vs. target images.

    All methods take image batches as 4-D tensors with channels on dim 1
    (the code averages over dim 1 and, for SSIM, transposes each sample
    from channel-first to channel-last).
    """

    # 3x3 Sobel kernels, stored as plain tuples so a tensor matching the
    # input's dtype and device can be built on demand.
    _SOBEL_X = ((-1.0, 0.0, 1.0), (-2.0, 0.0, 2.0), (-1.0, 0.0, 1.0))
    _SOBEL_Y = ((-1.0, -2.0, -1.0), (0.0, 0.0, 0.0), (1.0, 2.0, 1.0))

    def __init__(self, device):
        """Store the device handle.

        Args:
            device: Kept on the instance as ``self.device``; the metric
                computations themselves follow the device of their inputs.
        """
        self.device = device

    def _sobel_gradients(self, images):
        """Return ``(grad_x, grad_y)`` Sobel responses of the channel mean.

        The kernels are created with the dtype and device of ``images`` so
        the convolution works for any floating dtype, not only float32.
        """
        gray = torch.mean(images, dim=1, keepdim=True)
        kernel_x = gray.new_tensor(self._SOBEL_X).view(1, 1, 3, 3)
        kernel_y = gray.new_tensor(self._SOBEL_Y).view(1, 1, 3, 3)
        grad_x = F.conv2d(gray, kernel_x, padding=1)
        grad_y = F.conv2d(gray, kernel_y, padding=1)
        return grad_x, grad_y

    def calculate_psnr(self, img1, img2, max_val=1.0):
        """Calculate Peak Signal-to-Noise Ratio over the whole batch.

        Args:
            img1: First image tensor.
            img2: Second image tensor, broadcast-compatible with ``img1``.
            max_val: Peak signal value used in the PSNR formula.

        Returns:
            A 0-dim tensor holding ``20 * log10(max_val / sqrt(mse))``.
            Identical inputs give a 0-dim tensor containing ``inf`` (a
            tensor rather than a Python float, so callers can use
            ``.item()`` uniformly).
        """
        mse = torch.mean((img1 - img2) ** 2)
        if mse == 0:
            return mse.new_tensor(float('inf'))
        return 20 * torch.log10(max_val / torch.sqrt(mse))

    def calculate_ssim(self, img1, img2):
        """Calculate the mean Structural Similarity Index over a batch.

        Each sample is moved to NumPy, transposed to channel-last, mapped
        with ``(x + 1) / 2`` and clipped to [0, 1] before being scored.

        Args:
            img1: First image batch.
            img2: Second image batch with the same shape as ``img1``.

        Returns:
            The NumPy mean of the per-sample SSIM values.
        """
        img1_np = img1.detach().cpu().numpy()
        img2_np = img2.detach().cpu().numpy()

        ssim_values = []
        for i in range(img1_np.shape[0]):
            im1 = np.transpose(img1_np[i], (1, 2, 0))
            im2 = np.transpose(img2_np[i], (1, 2, 0))

            # Map to [0, 1]; the (x + 1) / 2 step assumes inputs in [-1, 1].
            im1 = np.clip((im1 + 1) / 2, 0, 1)
            im2 = np.clip((im2 + 1) / 2, 0, 1)

            # NOTE(review): both the legacy `multichannel` and the newer
            # `channel_axis` keywords are passed; `multichannel` is
            # deprecated in recent scikit-image releases - confirm the
            # minimum supported version before dropping it.
            ssim_val = ssim(im1, im2, multichannel=True, data_range=1.0, channel_axis=2)
            ssim_values.append(ssim_val)

        return np.mean(ssim_values)

    def calculate_lpips(self, img1, img2):
        """Calculate a simplified gradient-based perceptual distance.

        This is not the learned LPIPS metric: it is the mean absolute
        difference between the Sobel x/y gradients of the two batches'
        channel-averaged images.

        Args:
            img1: First image batch.
            img2: Second image batch with the same shape as ``img1``.

        Returns:
            The distance as a Python float (0.0 for identical inputs).
        """
        grad1_x, grad1_y = self._sobel_gradients(img1)
        grad2_x, grad2_y = self._sobel_gradients(img2)

        diff = torch.mean(torch.abs(grad1_x - grad2_x) + torch.abs(grad1_y - grad2_y))
        return diff.item()

    def evaluate_comprehensive(self, generated, target):
        """Calculate all metrics for a batch.

        Args:
            generated: Batch of generated images.
            target: Batch of reference images with the same shape.

        Returns:
            A dict of Python/NumPy scalars with keys ``psnr``, ``mae``,
            ``mse``, ``ssim``, ``lpips`` and ``edge_similarity``.
        """
        with torch.no_grad():
            metrics = {}

            # Basic metrics.
            # NOTE(review): calculate_ssim rescales inputs from [-1, 1] to
            # [0, 1], but PSNR/MAE/MSE are computed on the raw tensors with
            # max_val=1.0. If inputs really span [-1, 1], PSNR reads about
            # 6.02 dB lower than on [0, 1] data - confirm this is intended.
            # float() accepts both tensors and plain floats, so identical
            # images (PSNR = inf) no longer crash here.
            metrics['psnr'] = float(self.calculate_psnr(generated, target))
            metrics['mae'] = F.l1_loss(generated, target).item()
            metrics['mse'] = F.mse_loss(generated, target).item()

            # Advanced metrics
            metrics['ssim'] = self.calculate_ssim(generated, target)
            metrics['lpips'] = self.calculate_lpips(generated, target)

            # Edge preservation: cosine similarity of per-sample edge maps.
            generated_edges = self._detect_edges(generated)
            target_edges = self._detect_edges(target)
            metrics['edge_similarity'] = F.cosine_similarity(
                generated_edges.flatten(1), target_edges.flatten(1)
            ).mean().item()

            return metrics

    def _detect_edges(self, images):
        """Return the Sobel gradient magnitude of the channel-averaged images.

        Args:
            images: Image batch with channels on dim 1.

        Returns:
            A tensor with a single channel on dim 1 holding
            ``sqrt(grad_x**2 + grad_y**2)``.
        """
        grad_x, grad_y = self._sobel_gradients(images)
        return torch.sqrt(grad_x**2 + grad_y**2)
