# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import os
import random
from typing import List, Tuple, Dict
import logging
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
import math
import time
import platform

# Platform-specific setup.  Non-Windows hosts use two DataLoader worker
# processes.  On Windows the multiprocessing start method is forced to
# 'spawn', KMP_DUPLICATE_LIB_OK is set (presumably to work around the
# duplicate OpenMP runtime error — confirm), and DataLoader workers are
# disabled.
if platform.system() != 'Windows':
    DATALOADER_NUM_WORKERS = 2
else:
    try:
        torch.multiprocessing.set_start_method('spawn', force=True)
    except RuntimeError:
        pass  # start method could not be changed; keep the current one
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    DATALOADER_NUM_WORKERS = 0

# Suppress every UserWarning emitted for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

class Config:
    """Global hyper-parameters and paths, read as class attributes."""

    # --- Optimisation --------------------------------------------------
    BATCH_SIZE: int = 2
    LEARNING_RATE_G: float = 2e-4   # generator Adam learning rate
    LEARNING_RATE_D: float = 1e-4   # discriminator Adam learning rate
    EPOCHS: int = 700

    # --- Image geometry (single-channel / grayscale) ------------------
    IMG_HEIGHT: int = 512
    IMG_WIDTH: int = 512
    CHANNELS: int = 1

    # --- Model sizes --------------------------------------------------
    LATENT_DIM: int = 128
    FEATURE_DIM: int = 64

    # --- Generator loss weights ---------------------------------------
    LAMBDA_ADVERSARIAL: float = 1.0
    LAMBDA_L1: float = 100.0
    LAMBDA_PERCEPTUAL: float = 20.0
    LAMBDA_SSIM: float = 10.0
    LAMBDA_EDGE: float = 15.0
    LAMBDA_TV: float = 2.0
    LAMBDA_TEXT_STRUCTURE: float = 30.0

    # --- Quality targets (reference lines in plots) -------------------
    TARGET_SSIM: float = 0.85
    TARGET_PSNR: float = 30.0

    # --- Filesystem locations -----------------------------------------
    REAL_IMAGES_DIR: str = "Dataset/Ancient palm leaf documents/Jathakam/Original/dummy_sample"
    OUTPUT_DIR: str = "Enhanced_GAN_Output2"
    MODELS_DIR: str = "Enhanced_GAN_Models2"
    GENERATED_IMAGES_DIR: str = "Generated_Images3"

    # Prefer the GPU whenever CUDA is available.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Training stability -------------------------------------------
    GRADIENT_PENALTY_LAMBDA: float = 10.0
    SPECTRAL_NORM: bool = False

    # --- Synthetic document layout ------------------------------------
    MIN_FONT_SIZE: int = 14
    MAX_FONT_SIZE: int = 20
    LINE_SPACING: int = 25        # vertical advance per rendered text line (px)
    PARAGRAPH_SPACING: int = 35
    CHAR_DETAIL_LEVEL: int = 5

# Module-wide configuration instance read by every component below.
config = Config()

# Root logging: INFO level with timestamped, level-tagged messages.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ==================== ENHANCED LOSS FUNCTIONS ====================
class TextStructureLoss(nn.Module):
    """L1 distance between horizontal- and vertical-line filter responses.

    Both 3x3 kernels carry the profile (-1, 2, -1) across the line direction
    and are constant along it, so they respond to thin horizontal / vertical
    strokes.  The kernels have a single input channel, so inputs are expected
    to be (B, 1, H, W).
    """

    def __init__(self):
        super(TextStructureLoss, self).__init__()

        profile = torch.tensor([-1.0, 2.0, -1.0], dtype=torch.float32)
        # Rows follow `profile`, columns are constant -> horizontal-line detector.
        horizontal = profile.unsqueeze(1) * torch.ones(1, 3, dtype=torch.float32)
        # Its transpose is the vertical-line detector.
        vertical = horizontal.t().contiguous()

        self.register_buffer('h_kernel', horizontal.view(1, 1, 3, 3))
        self.register_buffer('v_kernel', vertical.view(1, 1, 3, 3))

    def forward(self, pred, target):
        def response(image, kernel):
            return F.conv2d(image, kernel, padding=1)

        horizontal_term = F.l1_loss(response(pred, self.h_kernel),
                                    response(target, self.h_kernel))
        vertical_term = F.l1_loss(response(pred, self.v_kernel),
                                  response(target, self.v_kernel))
        return horizontal_term + vertical_term

class EnhancedPerceptualLoss(nn.Module):
    """Multi-scale pixel MSE (no pretrained feature network is involved).

    The MSE is taken at full resolution and after 2x2 and 4x4 average pooling.
    The three terms are weighted 2.0, 1.0 and 0.5 (i.e. 2 / scale) and summed.
    """

    # (pooling factor, weight) pairs, evaluated in this order.
    _SCALE_WEIGHTS = ((1, 2.0), (2, 1.0), (4, 0.5))

    def __init__(self):
        super(EnhancedPerceptualLoss, self).__init__()

    def forward(self, pred, target):
        total_loss = 0.0
        for factor, weight in self._SCALE_WEIGHTS:
            if factor == 1:
                pred_view, target_view = pred, target
            else:
                pred_view = F.avg_pool2d(pred, factor)
                target_view = F.avg_pool2d(target, factor)
            total_loss += weight * F.mse_loss(pred_view, target_view)
        return total_loss

class EnhancedEdgeLoss(nn.Module):
    """L1 distance between edge maps of prediction and target.

    Edge map = Sobel gradient magnitude (with a 1e-8 stabiliser inside the
    square root) + 0.5 * |Laplacian|.  The kernels are single-channel, so
    inputs are expected to be (B, 1, H, W).
    """

    def __init__(self):
        super(EnhancedEdgeLoss, self).__init__()

        kernels = {
            'sobel_x': [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            'sobel_y': [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
            'laplacian': [[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
        }
        for buffer_name, rows in kernels.items():
            kernel = torch.tensor(rows, dtype=torch.float32).view(1, 1, 3, 3)
            self.register_buffer(buffer_name, kernel)

    def get_edges(self, x):
        grad_x, grad_y, second_order = (
            F.conv2d(x, kernel, padding=1)
            for kernel in (self.sobel_x, self.sobel_y, self.laplacian)
        )
        # The epsilon keeps sqrt differentiable where the gradient is exactly zero.
        magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)
        return magnitude + 0.5 * second_order.abs()

    def forward(self, pred, target):
        return F.l1_loss(self.get_edges(pred), self.get_edges(target))

class SimpleSSIMLoss(nn.Module):
    """``1 - SSIM`` with an 11x11 (by default) Gaussian window, sigma = 1.5.

    The SSIM luminance term ``(2*mu1*mu2 + C1) / (mu1^2 + mu2^2 + C1)`` only
    behaves for non-negative intensities, and ``C1 = 0.01^2``,
    ``C2 = 0.03^2`` assume a dynamic range of 1.  Inputs are therefore mapped
    from ``value_range`` onto [0, 1] first.  The default (-1, 1) matches the
    Tanh generator output and the ``Normalize([0.5], [0.5])`` dataset
    transform.  Without the rescaling, two mid-grey images that differ by 0.1
    would get a loss of about 1.9 because the luminance term changes sign.

    Args:
        window_size: side length of the Gaussian window.
        size_average: if True return a scalar, otherwise one loss per sample.
        value_range: ``(min, max)`` of the input tensors.

    Raises:
        ValueError: if ``value_range`` does not satisfy ``min < max``.
    """

    def __init__(self, window_size=11, size_average=True, value_range=(-1.0, 1.0)):
        super(SimpleSSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Kept for backward compatibility; forward() adapts to the input's channel count.
        self.channel = 1

        low, high = value_range
        if not high > low:
            raise ValueError(f"value_range must satisfy min < max, got {value_range!r}")
        self.value_min = float(low)
        self.value_span = float(high - low)

        sigma = 1.5
        gauss = torch.Tensor([math.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                              for x in range(window_size)])
        gauss = gauss / gauss.sum()

        # Separable 2-D Gaussian: outer product of the normalised 1-D profile.
        window = gauss.unsqueeze(1) @ gauss.unsqueeze(0)
        window = window.unsqueeze(0).unsqueeze(0)

        self.register_buffer('window', window)

    def forward(self, pred, target):
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        # Map both inputs onto [0, 1] so the constants above are correct.
        pred = (pred - self.value_min) / self.value_span
        target = (target - self.value_min) / self.value_span

        # One copy of the window per channel; groups=channels filters each channel separately.
        channels = pred.size(1)
        window = self.window.repeat(channels, 1, 1, 1)
        padding = self.window_size // 2

        def blur(x):
            return F.conv2d(x, window, padding=padding, groups=channels)

        mu1 = blur(pred)
        mu2 = blur(target)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = blur(pred * pred) - mu1_sq
        sigma2_sq = blur(target * target) - mu2_sq
        sigma12 = blur(pred * target) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if self.size_average:
            return 1 - ssim_map.mean()
        # Per-sample loss: average over channels and spatial positions.
        return 1 - ssim_map.flatten(1).mean(1)

# ==================== GENERATOR ARCHITECTURE ====================
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over spatial positions with a learned gate.

    Queries and keys are 1x1-projected to ``in_dim // 8`` channels, values
    keep ``in_dim``.  The output is ``gamma * attended + x``.  Since ``gamma``
    starts at zero, the layer is the identity at initialisation.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.in_dim = in_dim

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, 1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n, channels, height, width = x.size()
        positions = height * width

        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)  # (n, HW, C/8)
        keys = self.key_conv(x).view(n, -1, positions)                       # (n, C/8, HW)
        values = self.value_conv(x).view(n, -1, positions)                   # (n, C, HW)

        # weights[b, i, j]: how much position i attends to position j (rows sum to 1).
        weights = self.softmax(torch.bmm(queries, keys))
        attended = torch.bmm(values, weights.transpose(1, 2)).view(n, channels, height, width)

        return self.gamma * attended + x

class EnhancedGenerator(nn.Module):
    """Latent vector -> ``config.CHANNELS``-channel image bounded by Tanh.

    A linear layer (with BatchNorm1d and ReLU) projects ``z`` to a 1024x8x8
    seed map.  Six stride-2 transposed-conv stages (``layer1`` ... ``layer6``)
    each double the spatial size (8 -> 512) while halving channels
    (1024 -> 16).  The first two stages include self-attention.
    ``text_refine`` then narrows 16 -> 8 -> 4 channels, and ``final`` maps to
    the output channels.
    """

    # (in_channels, out_channels, use_attention) for layer1 ... layer6, in order.
    _STAGE_SPECS = (
        (1024, 512, True),
        (512, 256, True),
        (256, 128, False),
        (128, 64, False),
        (64, 32, False),
        (32, 16, False),
    )

    def __init__(self):
        super(EnhancedGenerator, self).__init__()

        seed_units = 1024 * 8 * 8
        self.fc = nn.Sequential(
            nn.Linear(config.LATENT_DIM, seed_units),
            nn.BatchNorm1d(seed_units),
            nn.ReLU(True),
        )

        # Registered as layer1..layer6 (same attribute names / state_dict keys).
        for index, (c_in, c_out, attend) in enumerate(self._STAGE_SPECS, start=1):
            setattr(self, f"layer{index}", self._make_layer(c_in, c_out, use_attention=attend))

        self.text_refine = nn.Sequential(
            nn.Conv2d(16, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(inplace=True),
            nn.Conv2d(8, 4, 3, 1, 1), nn.BatchNorm2d(4), nn.ReLU(inplace=True),
        )

        self.final = nn.Sequential(nn.Conv2d(4, config.CHANNELS, 3, 1, 1), nn.Tanh())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Conv / transposed-conv weights ~ N(0, 0.02) with zero bias; BatchNorm2d
        scale ~ N(1, 0.02) and zero shift.  Linear and BatchNorm1d keep PyTorch defaults."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
            nn.init.constant_(m.bias, 0)

    def _make_layer(self, in_channels, out_channels, use_attention=False):
        """2x upsampling stage: ConvT/BN/ReLU, optional attention, then Conv/BN/ReLU."""
        upsample = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        attention = [SelfAttention(out_channels)] if use_attention else []
        refine = [
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*upsample, *attention, *refine)

    def forward(self, z):
        x = self.fc(z).view(z.size(0), 1024, 8, 8)
        for index in range(1, len(self._STAGE_SPECS) + 1):
            x = getattr(self, f"layer{index}")(x)
        return self.final(self.text_refine(x))

class EnhancedDiscriminator(nn.Module):
    """Convolutional discriminator producing a sigmoid probability map.

    Six stride-2 4x4 convolutions widen 32 -> 1024 channels.  Every block but
    the first uses BatchNorm; all use LeakyReLU(0.2).  A final 8x8 valid
    convolution and a Sigmoid follow.  For a 512x512 input this yields one
    score per image; ``forward`` flattens the output to 1-D.
    """

    def __init__(self):
        super(EnhancedDiscriminator, self).__init__()

        widths = [32, 64, 128, 256, 512, 1024]

        # Module order (and therefore nn.Sequential indices) matches main.0 ... main.18.
        stack = [
            nn.Conv2d(config.CHANNELS, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            stack += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        stack += [
            nn.Conv2d(widths[-1], 1, 8, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stack)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Conv weights ~ N(0, 0.02) (zero bias if any); BatchNorm2d scale ~ N(1, 0.02), zero shift."""
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x).view(-1)

# ==================== DATASET ====================
class ReadableDocumentDataset(Dataset):
    """Grayscale document-image dataset, optionally seeded with synthetic pages.

    With ``create_samples=True`` the constructor renders 200 synthetic text
    pages (``readable_doc_XXXX.png``) into ``images_dir``.  If the directory
    still holds no images afterwards, 100 placeholder pages made of dark
    rectangles (``fallback_doc_XXXX.png``) are written instead.

    Each item is a single-channel tensor resized to
    ``(config.IMG_HEIGHT, config.IMG_WIDTH)`` and normalised to [-1, 1].

    Args:
        images_dir: directory read from (and written to when generating).
        create_samples: whether to (re)render the synthetic text pages.
    """

    # Font sizes loaded by _get_available_fonts; _get_font indexes this order.
    _FONT_SIZES = (12, 14, 16, 18, 20)
    # Case-insensitive file suffixes treated as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

    def __init__(self, images_dir, create_samples=True):
        self.images_dir = os.path.abspath(images_dir)
        self.fonts = self._get_available_fonts()

        if create_samples:
            self._create_readable_documents()

        self.image_paths = self._get_image_paths()

        if len(self.image_paths) == 0:
            self._create_font_fallback()
            self.image_paths = self._get_image_paths()

        # ToTensor maps 8-bit grayscale to [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((config.IMG_HEIGHT, config.IMG_WIDTH)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        logger.info(f"Dataset loaded with {len(self.image_paths)} images")

    def _get_available_fonts(self):
        """Return one font per size in ``_FONT_SIZES``.

        Uses the first candidate TrueType file that loads at every size.  It
        falls back to Pillow's default font, and finally to ``None``
        placeholders, in which case pages are rendered without text.
        """
        font_paths = [
            "C:/Windows/Fonts/arial.ttf",
            "C:/Windows/Fonts/times.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Arial.ttf",
        ]

        fonts = []
        for font_path in font_paths:
            if not os.path.exists(font_path):
                continue
            try:
                # Build the complete set before keeping it, so a failure part-way
                # through cannot leave a partial list that shifts _get_font's indices.
                fonts = [ImageFont.truetype(font_path, size) for size in self._FONT_SIZES]
            except (OSError, ValueError):
                continue
            break

        if not fonts:
            try:
                fonts = [ImageFont.load_default() for _ in self._FONT_SIZES]
            except Exception as exc:  # best-effort: keep going without text rendering
                logger.warning(f"Could not load Pillow's default font: {exc}")
                fonts = [None] * len(self._FONT_SIZES)

        logger.info(f"Found {len([f for f in fonts if f is not None])} fonts")
        return fonts

    def _get_image_paths(self):
        """Sorted paths of image files directly inside ``images_dir`` ([] if it is missing)."""
        if not os.path.exists(self.images_dir):
            return []

        return sorted(
            os.path.join(self.images_dir, name)
            for name in os.listdir(self.images_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _text_width(draw, text, font, fallback_char_width):
        """Pixel width of ``text``; estimated per character if it cannot be measured."""
        try:
            left, _, right, _ = draw.textbbox((0, 0), text, font=font)
        except (AttributeError, TypeError, ValueError, OSError):
            # Older Pillow without textbbox, or a font that cannot report a bbox.
            return len(text) * fallback_char_width
        return right - left

    @staticmethod
    def _wrap_words(words, max_width, measure):
        """Greedily pack ``words`` into lines no wider than ``max_width``.

        ``measure(text)`` returns the rendered width of ``text``.  A single word
        wider than ``max_width`` is placed on its own line rather than dropped.
        """
        lines = []
        current = ""
        for word in words:
            candidate = f"{current} {word}" if current else word
            if current and measure(candidate) > max_width:
                lines.append(current)
                current = word
            else:
                current = candidate
        if current:
            lines.append(current)
        return lines

    def _create_readable_documents(self):
        """Render 200 synthetic text pages as ``readable_doc_XXXX.png`` in ``images_dir``."""
        logger.info("Creating readable document samples...")
        os.makedirs(self.images_dir, exist_ok=True)

        document_contents = [
            {
                "title": "ARTIFICIAL INTELLIGENCE IN MODERN COMPUTING",
                "sections": [
                    {
                        "heading": "1. Introduction",
                        "content": [
                            "Artificial Intelligence has revolutionized the way we approach complex problems",
                            "in computer science. Machine learning algorithms have become increasingly",
                            "sophisticated, enabling computers to learn from data and make predictions",
                            "with remarkable accuracy. Deep learning, a subset of machine learning,",
                            "has shown particular promise in areas such as image recognition, natural",
                            "language processing, and autonomous systems."
                        ]
                    },
                    {
                        "heading": "2. Methodology",
                        "content": [
                            "Our research employs a comprehensive approach to evaluate AI performance",
                            "across multiple domains. We collected datasets from various sources and",
                            "implemented state-of-the-art neural network architectures. The training",
                            "process involved careful hyperparameter tuning and cross-validation to",
                            "ensure robust results. Performance metrics included accuracy, precision,",
                            "recall, and F1-score for classification tasks."
                        ]
                    }
                ]
            },
            {
                "title": "DATA SCIENCE AND STATISTICAL ANALYSIS",
                "sections": [
                    {
                        "heading": "Abstract",
                        "content": [
                            "This paper presents a comprehensive analysis of statistical methods used",
                            "in modern data science applications. We examine the effectiveness of",
                            "various techniques including regression analysis, hypothesis testing,",
                            "and Bayesian inference. The results demonstrate significant improvements",
                            "in prediction accuracy when proper statistical foundations are applied",
                            "to machine learning models."
                        ]
                    }
                ]
            }
        ]

        left_margin = 40
        right_margin = config.IMG_WIDTH - 40
        text_width = right_margin - left_margin
        top_margin = 50

        for i in range(200):
            img = Image.new('RGB', (config.IMG_WIDTH, config.IMG_HEIGHT), color=(250, 250, 250))
            draw = ImageDraw.Draw(img)

            doc_content = random.choice(document_contents)
            y_pos = top_margin

            # Centred title followed by a horizontal rule.
            title_font = self._get_font(18)
            if title_font:
                title = doc_content["title"]
                title_width = self._text_width(draw, title, title_font, 10)
                # Clamp at 0 so an over-wide title is clipped on the right, not the left.
                title_x = max(0, (config.IMG_WIDTH - title_width) // 2)
                draw.text((title_x, y_pos), title, fill=(0, 0, 0), font=title_font)
                y_pos += 50

                draw.line([(left_margin, y_pos), (right_margin, y_pos)], fill=(0, 0, 0), width=2)
                y_pos += 30

            for section in doc_content["sections"]:
                # Do not start a new section in the bottom 100 px of the page.
                if y_pos > config.IMG_HEIGHT - 100:
                    break

                heading_font = self._get_font(16)
                if heading_font:
                    draw.text((left_margin, y_pos), section["heading"], fill=(0, 0, 0), font=heading_font)
                    y_pos += 35

                content_font = self._get_font(14)
                if content_font:
                    def measure(text):
                        return self._text_width(draw, text, content_font, 8)

                    for line in section["content"]:
                        for wrapped in self._wrap_words(line.split(), text_width, measure):
                            draw.text((left_margin, y_pos), wrapped,
                                      fill=(0, 0, 0), font=content_font)
                            y_pos += config.LINE_SPACING

                    y_pos += 20

            # Convert to grayscale and add faint Gaussian noise (std 0.5 grey levels).
            img_array = np.array(img.convert('L'))
            noise = np.random.normal(0, 0.5, img_array.shape)
            img_array = np.clip(img_array.astype(float) + noise, 0, 255).astype(np.uint8)

            # A 2-D uint8 array is inferred as mode 'L'.
            Image.fromarray(img_array).save(os.path.join(self.images_dir, f'readable_doc_{i:04d}.png'))

        logger.info("Completed creating readable documents")

    def _get_font(self, size):
        """Return the loaded font for ``size`` (the 14 pt slot for unknown sizes), or None."""
        if not self.fonts or all(f is None for f in self.fonts):
            return None

        font_index = self._FONT_SIZES.index(size) if size in self._FONT_SIZES else 1

        if font_index < len(self.fonts):
            return self.fonts[font_index]
        return self.fonts[0]

    def _create_font_fallback(self):
        """Write 100 placeholder pages of dark rectangles standing in for words."""
        logger.info("Creating fallback documents...")
        os.makedirs(self.images_dir, exist_ok=True)

        for i in range(100):
            img = np.ones((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.uint8) * 248

            y = 80
            while y < config.IMG_HEIGHT - 80:
                # About 90% of rows get a line of "words"; the rest stay blank.
                if random.random() > 0.1:
                    x = 60
                    while x < config.IMG_WIDTH - 100:
                        word_length = random.randint(30, 120)
                        char_height = random.randint(14, 18)

                        img[y:y+char_height, x:x+word_length] = random.randint(20, 40)

                        x += word_length + random.randint(10, 20)

                        if x >= config.IMG_WIDTH - 150:
                            break

                y += random.randint(20, 28)

                # Occasional larger gap, like a paragraph break.
                if random.random() > 0.85:
                    y += random.randint(20, 35)

            Image.fromarray(img).save(
                os.path.join(self.images_dir, f'fallback_doc_{i:04d}.png')
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx % len(self.image_paths)]

        try:
            with Image.open(img_path) as image:
                image = image.convert('L')
                return self.transform(image)
        except Exception as e:
            # Unreadable file: substitute a blank light-grey page so training continues.
            logger.warning(f"Error loading {img_path}: {e}")
            fallback = np.ones((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.uint8) * 245
            return self.transform(Image.fromarray(fallback))

# ==================== TRAINER ====================
class EnhancedGANTrainer:
    """Enhanced GAN trainer."""
    
    def __init__(self):
        """Build both networks, their Adam optimisers and plateau schedulers, the
        loss modules, an empty metrics history, and the output directories."""
        device = config.DEVICE

        # The generator is constructed first; construction order determines which
        # random draws each network's weight initialisation consumes.
        self.generator = EnhancedGenerator().to(device)
        self.discriminator = EnhancedDiscriminator().to(device)

        def count_trainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        g_params = count_trainable(self.generator)
        d_params = count_trainable(self.discriminator)
        logger.info(f"Generator parameters: {g_params:,}")
        logger.info(f"Discriminator parameters: {d_params:,}")

        # Both optimisers share beta1 = 0.5 and a light weight decay; only the lr differs.
        adam_settings = {'betas': (0.5, 0.999), 'weight_decay': 1e-5}
        self.optimizer_G = optim.Adam(self.generator.parameters(),
                                      lr=config.LEARNING_RATE_G, **adam_settings)
        self.optimizer_D = optim.Adam(self.discriminator.parameters(),
                                      lr=config.LEARNING_RATE_D, **adam_settings)

        self.adversarial_loss = nn.BCELoss()
        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.perceptual_loss = EnhancedPerceptualLoss().to(device)
        self.ssim_loss = SimpleSSIMLoss().to(device)
        self.edge_loss = EnhancedEdgeLoss().to(device)
        self.text_structure_loss = TextStructureLoss().to(device)

        self.history = {
            key: [] for key in (
                'g_loss', 'd_loss', 'ssim_scores', 'psnr_scores',
                'g_adv_loss', 'g_content_loss',
            )
        }

        # Halve each learning rate once its monitored value has not improved
        # for more than 50 consecutive scheduler steps.
        self.scheduler_G, self.scheduler_D = (
            optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=50)
            for opt in (self.optimizer_G, self.optimizer_D)
        )

        for directory in (config.OUTPUT_DIR, config.MODELS_DIR, config.GENERATED_IMAGES_DIR):
            os.makedirs(directory, exist_ok=True)

        logger.info(f"Trainer initialized on {config.DEVICE}")
    
    def train_discriminator(self, real_images):
        """Run one discriminator update.

        The loss is the mean of BCE on the real batch (target 0.9) and BCE on a
        freshly generated fake batch (target 0).  To this is added
        ``config.GRADIENT_PENALTY_LAMBDA`` times the interpolation gradient
        penalty.  Gradients are clipped to total norm 1.0 before the step.

        Args:
            real_images: real batch on ``config.DEVICE``; presumably
                (B, C, H, W) in [-1, 1] as produced by the dataset — confirm.

        Returns:
            ``(d_loss, mean D(real), mean D(fake))`` as Python floats.
        """
        self.optimizer_D.zero_grad()

        batch_size = real_images.size(0)

        # One-sided label smoothing: only the real targets are softened (to 0.9);
        # fake targets stay exactly 0.
        real_labels = torch.ones(batch_size, device=config.DEVICE) * 0.9
        fake_labels = torch.zeros(batch_size, device=config.DEVICE)

        real_output = self.discriminator(real_images)
        real_loss = self.adversarial_loss(real_output, real_labels)

        noise = torch.randn(batch_size, config.LATENT_DIM, device=config.DEVICE)
        # The generator is not updated here, so no autograd graph is recorded for it.
        with torch.no_grad():
            fake_images = self.generator(noise)

        fake_output = self.discriminator(fake_images)
        fake_loss = self.adversarial_loss(fake_output, fake_labels)

        gradient_penalty = self.compute_gradient_penalty(real_images, fake_images)

        d_loss = (real_loss + fake_loss) / 2 + config.GRADIENT_PENALTY_LAMBDA * gradient_penalty
        d_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0)

        self.optimizer_D.step()

        return d_loss.item(), real_output.mean().item(), fake_output.mean().item()
    
    def compute_gradient_penalty(self, real_images, fake_images):
        """Return ``mean((||grad D(x_hat)||_2 - 1) ** 2)`` over the batch.

        ``x_hat`` mixes each real/fake pair with its own uniform coefficient.
        The returned tensor keeps its graph (``create_graph=True``), so it can
        be added to the discriminator loss and backpropagated.
        """
        n = real_images.size(0)

        # One mixing coefficient per sample, broadcast over the remaining dims.
        mix = torch.rand(n, 1, 1, 1, device=config.DEVICE)
        x_hat = mix * real_images + (1 - mix) * fake_images
        x_hat.requires_grad_(True)

        scores = self.discriminator(x_hat)

        (grad,) = torch.autograd.grad(
            outputs=scores,
            inputs=x_hat,
            grad_outputs=torch.ones_like(scores),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )

        per_sample_norm = grad.view(n, -1).norm(2, dim=1)
        return (per_sample_norm - 1).pow(2).mean()
    
    def train_generator(self, real_images):
        """Run one generator update and score the generated batch.

        The generator loss is ``config.LAMBDA_ADVERSARIAL`` times the BCE of
        D(G(z)) against target 1, plus a weighted content loss.  The content
        loss compares the generated batch with ``real_images`` using L1, L2,
        multi-scale MSE, SSIM, edge, text-structure and total-variation terms.
        Gradients are clipped to total norm 1.0 before the Adam step.

        Args:
            real_images: real batch on ``config.DEVICE``; presumably
                (B, C, H, W) in [-1, 1] as produced by the dataset — confirm.

        Returns:
            dict with:
            - float ``g_loss``, ``adversarial_loss`` and ``content_loss``;
            - ``ssim`` / ``psnr`` averaged over up to four samples (0.0 when
              none could be scored);
            - ``fake_images``, the generated batch detached from autograd.
        """
        self.optimizer_G.zero_grad()

        batch_size = real_images.size(0)
        noise = torch.randn(batch_size, config.LATENT_DIM, device=config.DEVICE)
        fake_images = self.generator(noise)

        fake_output = self.discriminator(fake_images)
        real_labels = torch.ones(batch_size, device=config.DEVICE)
        adversarial_loss = self.adversarial_loss(fake_output, real_labels)

        l1_loss = self.l1_loss(fake_images, real_images)
        l2_loss = self.l2_loss(fake_images, real_images)
        perceptual_loss = self.perceptual_loss(fake_images, real_images)
        ssim_loss = self.ssim_loss(fake_images, real_images)
        edge_loss = self.edge_loss(fake_images, real_images)
        text_loss = self.text_structure_loss(fake_images, real_images)

        tv_loss = self.total_variation_loss(fake_images)

        content_loss = (
            config.LAMBDA_L1 * l1_loss +
            config.LAMBDA_L1 * 0.5 * l2_loss +
            config.LAMBDA_PERCEPTUAL * perceptual_loss +
            config.LAMBDA_SSIM * ssim_loss +
            config.LAMBDA_EDGE * edge_loss +
            config.LAMBDA_TEXT_STRUCTURE * text_loss +
            config.LAMBDA_TV * tv_loss
        )

        g_loss = config.LAMBDA_ADVERSARIAL * adversarial_loss + content_loss

        g_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)

        self.optimizer_G.step()

        # fake_images still requires grad.  On CPU, `.cpu()` returns the same
        # tensor, so `.numpy()` would raise unless the tensor is detached first.
        fake_images = fake_images.detach()
        avg_ssim, avg_psnr = self._quality_metrics(fake_images, real_images)

        return {
            'g_loss': g_loss.item(),
            'adversarial_loss': adversarial_loss.item(),
            'content_loss': content_loss.item(),
            'ssim': avg_ssim,
            'psnr': avg_psnr,
            'fake_images': fake_images
        }

    @staticmethod
    def _quality_metrics(fake_images, real_images, max_samples=4):
        """Return ``(mean SSIM, mean PSNR)`` over the first ``max_samples`` image pairs.

        Channel 0 of each image is mapped from [-1, 1] to 8-bit [0, 255].
        Values are clipped first, so out-of-range pixels cannot wrap in the
        uint8 cast.  Pairs scikit-image rejects are skipped.  A metric is 0.0
        when no pair could be scored.
        """
        fake_np = fake_images.detach().cpu().numpy()
        real_np = real_images.detach().cpu().numpy()

        ssim_scores = []
        psnr_scores = []

        for fake, real in zip(fake_np[:max_samples], real_np[:max_samples]):
            fake_img = (np.clip((fake[0] + 1.0) / 2.0, 0.0, 1.0) * 255).astype(np.uint8)
            real_img = (np.clip((real[0] + 1.0) / 2.0, 0.0, 1.0) * 255).astype(np.uint8)

            try:
                ssim_score = ssim(real_img, fake_img, data_range=255)
                psnr_score = psnr(real_img, fake_img, data_range=255)
            except ValueError as exc:
                # e.g. an image too small for SSIM's window; skip this pair.
                logger.debug(f"Skipping quality metrics for one sample: {exc}")
                continue
            ssim_scores.append(ssim_score)
            psnr_scores.append(psnr_score)

        avg_ssim = np.mean(ssim_scores) if ssim_scores else 0.0
        avg_psnr = np.mean(psnr_scores) if psnr_scores else 0.0
        return avg_ssim, avg_psnr
    
    def total_variation_loss(self, images):
        """Anisotropic total variation of a 4-D image batch.

        Sum of the mean absolute difference between vertically adjacent pixels
        (dim 2) and the mean absolute difference between horizontally adjacent
        pixels (dim 3).
        """
        vertical = (images[:, :, 1:, :] - images[:, :, :-1, :]).abs().mean()
        horizontal = (images[:, :, :, 1:] - images[:, :, :, :-1]).abs().mean()
        return vertical + horizontal
    
    def save_training_progress(self, epoch, real_images, fake_images, num_samples=8):
        """Save a real-vs-generated comparison grid and the generated images.

        Writes ``comparison.png`` (top row real, bottom row generated) and
        ``generated_XX.png`` files into
        ``<config.OUTPUT_DIR>/epoch_<epoch:04d>/``.

        Args:
            epoch: epoch number used in the folder name and figure title.
            real_images: real batch in [-1, 1]; presumably (B, 1, H, W) — confirm.
            fake_images: generated batch in [-1, 1], same layout.
            num_samples: maximum number of images shown / saved.
        """
        progress_folder = os.path.join(config.OUTPUT_DIR, f"epoch_{epoch:04d}")
        os.makedirs(progress_folder, exist_ok=True)

        # The grid needs a real/fake pair per column; the PNGs only need fakes.
        n_pairs = min(num_samples, real_images.size(0), fake_images.size(0))
        n_fake = min(num_samples, fake_images.size(0))

        self.generator.eval()
        with torch.no_grad():
            # [-1, 1] -> [0, 1]; the clamp also keeps the uint8 cast below in range.
            real_denorm = ((real_images.detach() + 1.0) / 2.0).clamp(0.0, 1.0)
            fake_denorm = ((fake_images.detach() + 1.0) / 2.0).clamp(0.0, 1.0)

            if n_pairs > 0:
                # squeeze=False keeps `axes` 2-D even when there is a single column.
                fig, axes = plt.subplots(2, n_pairs, figsize=(n_pairs * 3, 6), squeeze=False)
                try:
                    fig.suptitle(f'Training Progress - Epoch {epoch}', fontsize=16)

                    rows = ((real_denorm, 'Real'), (fake_denorm, 'Generated'))
                    for i in range(n_pairs):
                        for row, (batch, title) in enumerate(rows):
                            ax = axes[row, i]
                            # Fixed vmin/vmax show true intensities instead of per-image auto-contrast.
                            ax.imshow(batch[i].cpu().squeeze().numpy(), cmap='gray', vmin=0.0, vmax=1.0)
                            ax.set_title(title, fontsize=10)
                            ax.axis('off')

                    fig.tight_layout()
                    fig.savefig(os.path.join(progress_folder, 'comparison.png'),
                                dpi=150, bbox_inches='tight')
                finally:
                    plt.close(fig)

            for i in range(n_fake):
                img_array = fake_denorm[i].cpu().squeeze().numpy()
                img_array = (img_array * 255).astype(np.uint8)

                # A 2-D uint8 array is inferred as mode 'L'.
                Image.fromarray(img_array).save(
                    os.path.join(progress_folder, f'generated_{i+1:02d}.png')
                )

        self.generator.train()
        logger.info(f"Saved progress for epoch {epoch}")
    
    def save_loss_plots(self):
        """Render a 2x2 grid of training curves to ``<config.OUTPUT_DIR>/training_metrics.png``.

        Panels, in row-major order:
        - generator and discriminator loss;
        - SSIM against ``config.TARGET_SSIM``;
        - PSNR against ``config.TARGET_PSNR``;
        - the generator's adversarial vs content loss.

        Nothing is written until ``self.history['g_loss']`` holds at least two
        epochs.
        """
        hist = self.history
        n_epochs = len(hist['g_loss'])
        if n_epochs < 2:
            return

        def finish(ax, ylabel, title):
            # Common decoration for every populated panel.
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig, grid = plt.subplots(2, 2, figsize=(15, 10))
        loss_ax, ssim_ax, psnr_ax, parts_ax = grid.ravel()
        x = range(1, n_epochs + 1)

        loss_ax.plot(x, hist['g_loss'], 'b-', label='Generator', linewidth=2)
        loss_ax.plot(x, hist['d_loss'], 'r-', label='Discriminator', linewidth=2)
        finish(loss_ax, 'Loss', 'Training Losses')

        if len(hist['ssim_scores']) > 0:
            ssim_ax.plot(x, hist['ssim_scores'], 'g-', label='SSIM', linewidth=2)
            ssim_ax.axhline(y=config.TARGET_SSIM, color='g', linestyle='--', alpha=0.7, label='Target')
            finish(ssim_ax, 'SSIM Score', 'SSIM Quality')

        if len(hist['psnr_scores']) > 0:
            psnr_ax.plot(x, hist['psnr_scores'], 'm-', label='PSNR', linewidth=2)
            psnr_ax.axhline(y=config.TARGET_PSNR, color='m', linestyle='--', alpha=0.7, label='Target')
            finish(psnr_ax, 'PSNR (dB)', 'PSNR Quality')

        if len(hist['g_adv_loss']) > 0:
            parts_ax.plot(x, hist['g_adv_loss'], 'c-', label='Adversarial', linewidth=2)
            parts_ax.plot(x, hist['g_content_loss'], 'y-', label='Content', linewidth=2)
            finish(parts_ax, 'Loss Component', 'Generator Components')

        fig.tight_layout()
        fig.savefig(os.path.join(config.OUTPUT_DIR, 'training_metrics.png'),
                    dpi=150, bbox_inches='tight')
        plt.close(fig)
    
    def train(self, dataloader, start_epoch=0, best_ssim=None):
        """Run the GAN training loop from ``start_epoch`` up to ``config.EPOCHS``.

        Per batch, one discriminator step is followed by one generator step.

        Per epoch:
        - averages are appended to ``self.history``;
        - both LR schedulers are stepped on the epoch losses;
        - sample images are written every 5 epochs and loss plots every 10;
        - the best-SSIM model is saved whenever SSIM improves;
        - a periodic checkpoint is written every 25 epochs.

        Args:
            dataloader: Iterable yielding batches of real images.
            start_epoch: 0-based epoch to begin at; > 0 when resuming.
            best_ssim: Best epoch SSIM reached so far. ``None`` (default)
                means 0.0 for a fresh run. When resuming, it means the best
                finite value already in ``self.history['ssim_scores']``.
                Without this, the first resumed epoch would always overwrite
                ``best_model.pth``.
        """
        logger.info("Starting training...")

        if best_ssim is None:
            best_ssim = self._best_recorded_ssim() if start_epoch > 0 else 0.0

        for epoch in range(start_epoch, config.EPOCHS):
            epoch_start_time = time.time()

            d_losses = []
            g_losses = []
            g_adv_losses = []
            g_content_losses = []
            ssim_scores = []
            psnr_scores = []

            for batch_idx, real_images in enumerate(dataloader):
                real_images = real_images.to(config.DEVICE)

                # D(real) / D(fake) scores are returned but not aggregated.
                d_loss, _real_score, _fake_score = self.train_discriminator(real_images)
                d_losses.append(d_loss)

                g_results = self.train_generator(real_images)
                g_losses.append(g_results['g_loss'])
                g_adv_losses.append(g_results['adversarial_loss'])
                g_content_losses.append(g_results['content_loss'])
                ssim_scores.append(g_results['ssim'])
                psnr_scores.append(g_results['psnr'])

                if batch_idx % 10 == 0:
                    logger.info(
                        f'Epoch [{epoch+1:3d}/{config.EPOCHS}] '
                        f'Batch [{batch_idx:3d}/{len(dataloader)}] '
                        f'D_Loss: {d_loss:.4f} G_Loss: {g_results["g_loss"]:.4f} '
                        f'SSIM: {g_results["ssim"]:.3f} PSNR: {g_results["psnr"]:.2f}'
                    )

                # Snapshots: every 50th batch of the first epoch, and the first
                # batch of every 25th (0-based) epoch.
                if (epoch == 0 and batch_idx % 50 == 0) or (epoch % 25 == 0 and batch_idx == 0):
                    self.save_training_progress(epoch + 1, real_images, g_results['fake_images'])

            epoch_d_loss = self._safe_mean(d_losses)
            epoch_g_loss = self._safe_mean(g_losses)
            epoch_ssim = self._mean_of_positive(ssim_scores)
            epoch_psnr = self._mean_of_positive(psnr_scores)

            self.history['d_loss'].append(epoch_d_loss)
            self.history['g_loss'].append(epoch_g_loss)
            self.history['g_adv_loss'].append(self._safe_mean(g_adv_losses))
            self.history['g_content_loss'].append(self._safe_mean(g_content_losses))
            self.history['ssim_scores'].append(epoch_ssim)
            self.history['psnr_scores'].append(epoch_psnr)

            self.scheduler_G.step(epoch_g_loss)
            self.scheduler_D.step(epoch_d_loss)

            epoch_time = time.time() - epoch_start_time

            logger.info(
                f'Epoch [{epoch+1:3d}/{config.EPOCHS}] completed in {epoch_time:.2f}s - '
                f'D_Loss: {epoch_d_loss:.4f} G_Loss: {epoch_g_loss:.4f} '
                f'SSIM: {epoch_ssim:.3f} PSNR: {epoch_psnr:.2f}'
            )

            if (epoch + 1) % 5 == 0:
                self.save_generated_images_only(epoch + 1)

            if (epoch + 1) % 10 == 0:
                self.save_loss_plots()

            # NaN never compares greater, so an epoch without SSIM is never "best".
            if epoch_ssim > best_ssim:
                best_ssim = epoch_ssim
                torch.save(self._checkpoint_state(epoch, best_ssim),
                           os.path.join(config.MODELS_DIR, 'best_model.pth'))
                logger.info(f"New best SSIM: {best_ssim:.4f} - Model saved!")

            if (epoch + 1) % 25 == 0:
                # Periodic checkpoints carry optimizer state too, so training
                # can be resumed from them and not only evaluated.
                torch.save(self._checkpoint_state(epoch, best_ssim),
                           os.path.join(config.MODELS_DIR, f'checkpoint_epoch_{epoch+1:04d}.pth'))

        logger.info(f"Training completed! Best SSIM: {best_ssim:.4f}")
        self.save_loss_plots()

    @staticmethod
    def _safe_mean(values):
        """Return the mean of *values* as a plain float; NaN (no warning) if empty.

        Plain floats (rather than NumPy scalars) keep ``history`` made of
        builtin types when it is saved into checkpoints.
        """
        return float(np.mean(values)) if len(values) > 0 else float('nan')

    @staticmethod
    def _mean_of_positive(values):
        """Return the mean of the strictly positive entries of *values*; NaN if none.

        ``train_generator`` reports 0.0 for SSIM/PSNR when its per-image score
        list is empty. Such placeholders are excluded instead of averaged in.
        """
        positive = [v for v in values if v > 0]
        return float(np.mean(positive)) if positive else float('nan')

    def _best_recorded_ssim(self):
        """Return the largest finite SSIM already in ``self.history``, or 0.0."""
        finite = [float(s) for s in self.history.get('ssim_scores', []) if np.isfinite(s)]
        return max(finite, default=0.0)

    def _checkpoint_state(self, epoch, best_ssim):
        """Collect model/optimizer state and history for a resumable checkpoint."""
        return {
            'epoch': epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'optimizer_D_state_dict': self.optimizer_D.state_dict(),
            'history': self.history,
            'best_ssim': best_ssim,
        }
    
    def save_generated_images_only(self, epoch, num_samples=16):
        """Sample the generator from random noise and save sharpened PNGs.

        Images go to ``<config.GENERATED_IMAGES_DIR>/epoch_<epoch>/generated_XX.png``.

        Args:
            epoch: Epoch number used in the folder name.
            num_samples: Number of noise vectors (``config.LATENT_DIM`` each)
                to generate.
        """
        epoch_folder = os.path.join(config.GENERATED_IMAGES_DIR, f"epoch_{epoch:04d}")
        os.makedirs(epoch_folder, exist_ok=True)

        # 3x3 sharpening kernel (weights sum to 1, so flat regions are kept).
        # Built once, as float32, matching the floating-point kernel type
        # documented for cv2.filter2D.
        sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.float32)

        self.generator.eval()
        with torch.no_grad():
            noise = torch.randn(num_samples, config.LATENT_DIM, device=config.DEVICE)
            fake_images = self.generator(noise)

            # Map [-1, 1] -> [0, 1]; clamp so out-of-range outputs cannot wrap
            # around in the uint8 cast.
            fake_denorm = ((fake_images + 1.0) / 2.0).clamp(0.0, 1.0)

            for i in range(fake_images.size(0)):
                img_array = (fake_denorm[i].cpu().squeeze().numpy() * 255).astype(np.uint8)
                # ddepth=-1 keeps uint8 output; OpenCV saturates to [0, 255].
                img_array = cv2.filter2D(img_array, -1, sharpen_kernel)

                Image.fromarray(img_array, 'L').save(
                    os.path.join(epoch_folder, f'generated_{i+1:02d}.png')
                )

            logger.info(f"Saved {num_samples} images to {epoch_folder}")

        self.generator.train()

# ==================== MAIN ====================
def _save_emergency_checkpoint(trainer, filename):
    """Best-effort save of weights and history after training stops early.

    Failures are logged rather than raised, so a failed save cannot mask the
    original interruption or error.

    Returns:
        True if ``config.MODELS_DIR/filename`` was written, else False.
    """
    try:
        os.makedirs(config.MODELS_DIR, exist_ok=True)
        torch.save({
            'generator_state_dict': trainer.generator.state_dict(),
            'discriminator_state_dict': trainer.discriminator.state_dict(),
            'history': trainer.history
        }, os.path.join(config.MODELS_DIR, filename))
    except Exception:
        logger.exception("Could not save checkpoint")
        return False
    return True


def main():
    """Seed RNGs, build the dataset/dataloader and trainer, then run training.

    On Ctrl-C or an unexpected training error, the current weights and
    history are written to ``config.MODELS_DIR``. The file is
    ``interrupted_checkpoint.pth`` or ``error_checkpoint.pth`` respectively.
    """
    logger.info("Starting Enhanced Document GAN")
    logger.info(f"Resolution: {config.IMG_WIDTH}x{config.IMG_HEIGHT}")
    logger.info(f"Device: {config.DEVICE}")

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        torch.cuda.manual_seed_all(42)

    logger.info("Creating dataset...")
    dataset = ReadableDocumentDataset(config.REAL_IMAGES_DIR, create_samples=True)

    dataloader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=DATALOADER_NUM_WORKERS,
        drop_last=True,
        pin_memory=config.DEVICE.type == 'cuda'
    )

    logger.info(f"Dataset: {len(dataset)} images, Batches: {len(dataloader)}")

    # Fail fast if the dataloader cannot produce a batch (empty dataset, bad files).
    try:
        test_batch = next(iter(dataloader))
        logger.info(f"Batch shape: {test_batch.shape}")
        logger.info(f"Value range: [{test_batch.min():.3f}, {test_batch.max():.3f}]")
    except Exception as e:
        logger.error(f"Error testing dataloader: {e}")
        return

    try:
        trainer = EnhancedGANTrainer()
        logger.info("Trainer initialized successfully")
    except Exception as e:
        logger.exception(f"Error initializing trainer: {e}")
        return

    try:
        logger.info("Starting training...")
        trainer.train(dataloader)

        logger.info("Training completed!")
        logger.info(f"Models: {config.MODELS_DIR}")
        logger.info(f"Images: {config.GENERATED_IMAGES_DIR}")

    except KeyboardInterrupt:
        logger.info("Training interrupted")
        if _save_emergency_checkpoint(trainer, 'interrupted_checkpoint.pth'):
            logger.info("Checkpoint saved")

    except Exception as e:
        logger.exception(f"Training failed: {e}")
        if _save_emergency_checkpoint(trainer, 'error_checkpoint.pth'):
            logger.info("Error checkpoint saved")


if __name__ == "__main__":
    main()

2025-10-18 14:17:17,713 - INFO - Starting Enhanced Document GAN
2025-10-18 14:17:17,714 - INFO - Resolution: 512x512
2025-10-18 14:17:17,714 - INFO - Device: cpu
2025-10-18 14:17:17,717 - INFO - Creating dataset...
2025-10-18 14:17:17,718 - INFO - Found 5 fonts
2025-10-18 14:17:17,718 - INFO - Creating readable document samples...
2025-10-18 14:17:57,475 - INFO - Completed creating readable documents
2025-10-18 14:17:57,480 - INFO - Dataset loaded with 449 images
2025-10-18 14:17:57,480 - INFO - Dataset: 449 images, Batches: 56
2025-10-18 14:17:57,551 - INFO - Batch shape: torch.Size([8, 1, 512, 512])
2025-10-18 14:17:57,551 - INFO - Value range: [-1.000, 0.984]
2025-10-18 14:17:57,869 - INFO - Generator parameters: 23,328,363
2025-10-18 14:17:57,869 - INFO - Discriminator parameters: 11,243,904
2025-10-18 14:17:57,869 - INFO - Trainer initialized on cpu
2025-10-18 14:17:57,878 - INFO - Trainer initialized successfully
2025-10-18 14:17:57,878 - INFO - Starting training...
2025-10-18 14

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import seaborn as sns

# Global plot styling shared by every figure produced in this cell.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': 10,
})

class GANMetricsPlotter:
    """Plot comprehensive metrics for GAN training analysis.

    Reads the ``history`` dict stored in a training checkpoint and renders
    figures and console statistics from it. The history holds per-epoch lists
    under keys such as ``g_loss``, ``d_loss``, ``g_adv_loss``,
    ``g_content_loss``, ``ssim_scores`` and ``psnr_scores``.

    Per-epoch SSIM/PSNR entries may be NaN (an epoch in which no positive
    score was recorded). All statistics and trend fits below therefore skip
    non-finite values.

    Args:
        models_dir: Directory holding the ``*.pth`` checkpoints.
        output_dir: Directory the figures are saved to (created if missing).
        target_ssim: Reference SSIM drawn as a dashed line.
        target_psnr: Reference PSNR (dB) drawn as a dashed line.
    """

    def __init__(self, models_dir="Enhanced_GAN_Models", output_dir="Enhanced_GAN_Output",
                 target_ssim=0.85, target_psnr=30.0):
        self.models_dir = models_dir
        self.output_dir = output_dir
        self.target_ssim = target_ssim
        self.target_psnr = target_psnr
        # Dict of per-epoch metric lists once a checkpoint has been loaded.
        self.history = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _finite(values):
        """Return *values* as a 1-D float array with NaN/inf entries removed."""
        arr = np.asarray(values, dtype=float).ravel()
        return arr[np.isfinite(arr)]

    @staticmethod
    def _epochs(values):
        """Return 1-based epoch numbers matching ``len(values)``."""
        return range(1, len(values) + 1)

    def _has(self, *keys):
        """Return True if every key exists in the history with at least one entry."""
        return all(len(self.history.get(key, [])) > 0 for key in keys)

    def _save_figure(self, fig, save_name):
        """Save *fig* under ``output_dir`` (creating it), close it, and return the path."""
        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_name)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        return save_path

    @staticmethod
    def _plot_trend(ax, values, style):
        """Overlay a linear least-squares trend fitted to the finite points of *values*."""
        y = np.asarray(values, dtype=float)
        x = np.arange(1, len(y) + 1)
        mask = np.isfinite(y)
        if np.count_nonzero(mask) < 2:
            return  # a line needs at least two points; NaNs would break polyfit
        slope, intercept = np.polyfit(x[mask], y[mask], 1)
        ax.plot(x, slope * x + intercept, style, alpha=0.5, label='Trend')

    @staticmethod
    def _torch_load(path):
        """Load a checkpoint onto the CPU.

        The training loop stores ``np.mean`` results (NumPy scalars) in
        ``history``. The safe ``weights_only=True`` default of torch>=2.6
        refuses to unpickle those, so it is disabled here.

        SECURITY: ``weights_only=False`` runs arbitrary pickle code. Only load
        checkpoints you produced yourself.
        """
        try:
            return torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # torch releases that predate the ``weights_only`` argument.
            return torch.load(path, map_location='cpu')

    def _summary_text(self):
        """Build the text shown in the dashboard's statistics panel."""
        sections = (
            # (heading, history key, extreme label, extreme fn, number format, unit)
            ("Generator Loss", 'g_loss', "Min", np.min, ".4f", ""),
            ("Discriminator Loss", 'd_loss', "Min", np.min, ".4f", ""),
            ("SSIM Score", 'ssim_scores', "Best", np.max, ".4f", ""),
            ("PSNR Score", 'psnr_scores', "Best", np.max, ".2f", " dB"),
        )
        parts = ["Training Summary Statistics\n" + "=" * 40 + "\n\n"]
        for heading, key, extreme_label, extreme_fn, fmt, unit in sections:
            values = self._finite(self.history.get(key, []))
            if values.size == 0:
                continue
            parts.append(
                f"{heading}:\n"
                f"  Final: {values[-1]:{fmt}}{unit}\n"
                f"  Mean: {values.mean():{fmt}}{unit}\n"
                f"  {extreme_label}: {extreme_fn(values):{fmt}}{unit}\n\n"
            )
        return "".join(parts).rstrip("\n")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def load_checkpoint(self, checkpoint_name='best_model.pth'):
        """Load the training history from ``models_dir/checkpoint_name``.

        When the file is missing, the ``.pth`` files that do exist in
        ``models_dir`` are listed instead.

        Returns:
            True if the history was loaded, False otherwise.
        """
        checkpoint_path = os.path.join(self.models_dir, checkpoint_name)

        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found: {checkpoint_path}")
            print("Available checkpoints:")
            if os.path.isdir(self.models_dir):
                for file in sorted(os.listdir(self.models_dir)):
                    if file.endswith('.pth'):
                        print(f"  - {file}")
            return False

        try:
            checkpoint = self._torch_load(checkpoint_path)
            self.history = checkpoint.get('history') or {}

            print(f"Loaded checkpoint: {checkpoint_name}")
            print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")
            print(f"Best SSIM: {checkpoint.get('best_ssim', 'N/A')}")
            print(f"Available metrics: {list(self.history.keys())}")

            return True
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return False

    def plot_all_metrics(self, save_name='comprehensive_metrics.png'):
        """Create a 3x3 dashboard of all training metrics and save it to ``output_dir``.

        Panels:
        1. generator and discriminator losses;
        2. SSIM against its target;
        3. PSNR against its target;
        4. generator loss components;
        5. G/D loss ratio;
        6. normalised SSIM vs PSNR;
        7. moving-average losses;
        8. SSIM histogram;
        9. a text summary.

        Panels whose metrics are absent from the history stay empty.
        """
        if not self.history:
            print("No history data loaded. Please load a checkpoint first.")
            return

        h = self.history
        fig = plt.figure(figsize=(20, 12))
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

        # 1. Generator and Discriminator Loss
        ax1 = fig.add_subplot(gs[0, 0])
        if self._has('g_loss', 'd_loss'):
            ax1.plot(self._epochs(h['g_loss']), h['g_loss'], 'b-', label='Generator Loss', linewidth=2, alpha=0.8)
            ax1.plot(self._epochs(h['d_loss']), h['d_loss'], 'r-', label='Discriminator Loss', linewidth=2, alpha=0.8)
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training Losses', fontweight='bold')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

        # 2. SSIM Score with Target Line
        ax2 = fig.add_subplot(gs[0, 1])
        if self._has('ssim_scores'):
            ssim = h['ssim_scores']
            ax2.plot(self._epochs(ssim), ssim, 'g-', label='SSIM Score', linewidth=2)
            ax2.axhline(y=self.target_ssim, color='g', linestyle='--', alpha=0.7,
                        label=f'Target SSIM ({self.target_ssim:g})')
            ax2.fill_between(self._epochs(ssim), ssim, alpha=0.3, color='g')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('SSIM Score')
            ax2.set_title('Structural Similarity Index (SSIM)', fontweight='bold')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
            ax2.set_ylim([0, 1])

        # 3. PSNR Score with Target Line
        ax3 = fig.add_subplot(gs[0, 2])
        if self._has('psnr_scores'):
            psnr = h['psnr_scores']
            ax3.plot(self._epochs(psnr), psnr, 'm-', label='PSNR Score', linewidth=2)
            ax3.axhline(y=self.target_psnr, color='m', linestyle='--', alpha=0.7,
                        label=f'Target PSNR ({self.target_psnr:g} dB)')
            ax3.fill_between(self._epochs(psnr), psnr, alpha=0.3, color='m')
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('PSNR (dB)')
            ax3.set_title('Peak Signal-to-Noise Ratio (PSNR)', fontweight='bold')
            ax3.legend()
            ax3.grid(True, alpha=0.3)

        # 4. Generator Loss Components
        ax4 = fig.add_subplot(gs[1, 0])
        if self._has('g_adv_loss', 'g_content_loss'):
            adv, content = h['g_adv_loss'], h['g_content_loss']
            ax4.plot(self._epochs(adv), adv, 'c-', label='Adversarial Loss', linewidth=2, alpha=0.8)
            ax4.plot(self._epochs(content), content, 'y-', label='Content Loss', linewidth=2, alpha=0.8)
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Loss Value')
            ax4.set_title('Generator Loss Components', fontweight='bold')
            ax4.legend()
            ax4.grid(True, alpha=0.3)
            # A log axis cannot display zero or negative values, so only use
            # it when every finite component value is strictly positive.
            finite = np.concatenate([self._finite(adv), self._finite(content)])
            if finite.size and np.all(finite > 0):
                ax4.set_yscale('log')

        # 5. Loss Ratio (G/D Balance)
        ax5 = fig.add_subplot(gs[1, 1])
        if self._has('g_loss', 'd_loss'):
            g = np.asarray(h['g_loss'], dtype=float)
            d = np.asarray(h['d_loss'], dtype=float)
            n = min(len(g), len(d))
            loss_ratio = g[:n] / (d[:n] + 1e-8)
            ax5.plot(range(1, n + 1), loss_ratio, 'purple', linewidth=2, alpha=0.8)
            ax5.axhline(y=1.0, color='gray', linestyle='--', alpha=0.7, label='Perfect Balance')
            ax5.set_xlabel('Epoch')
            ax5.set_ylabel('G_Loss / D_Loss Ratio')
            ax5.set_title('Training Balance (Generator/Discriminator)', fontweight='bold')
            ax5.legend()
            ax5.grid(True, alpha=0.3)

        # 6. Quality Improvement Over Time
        ax6 = fig.add_subplot(gs[1, 2])
        if self._has('ssim_scores', 'psnr_scores'):
            ssim_norm = np.asarray(h['ssim_scores'], dtype=float)
            psnr_norm = np.asarray(h['psnr_scores'], dtype=float) / 50.0  # rough 0-50 dB -> 0-1
            ax6.plot(self._epochs(ssim_norm), ssim_norm, 'g-', label='SSIM (normalized)', linewidth=2, alpha=0.8)
            ax6.plot(self._epochs(psnr_norm), psnr_norm, 'm-', label='PSNR (normalized)', linewidth=2, alpha=0.8)
            ax6.set_xlabel('Epoch')
            ax6.set_ylabel('Normalized Score')
            ax6.set_title('Quality Metrics Comparison', fontweight='bold')
            ax6.legend()
            ax6.grid(True, alpha=0.3)

        # 7. Loss Convergence Analysis (moving average)
        ax7 = fig.add_subplot(gs[2, 0])
        if self._has('g_loss', 'd_loss'):
            window = min(10, len(h['g_loss']) // 10)
            if window > 1:
                kernel = np.ones(window) / window
                g_loss_ma = np.convolve(h['g_loss'], kernel, mode='valid')
                d_loss_ma = np.convolve(h['d_loss'], kernel, mode='valid')
                # A 'valid' average covering epochs e-window+1..e belongs to
                # epoch e, so the first averaged point sits at epoch `window`.
                ax7.plot(range(window, window + len(g_loss_ma)), g_loss_ma,
                         'b-', label=f'Generator (MA-{window})', linewidth=2)
                ax7.plot(range(window, window + len(d_loss_ma)), d_loss_ma,
                         'r-', label=f'Discriminator (MA-{window})', linewidth=2)
                ax7.set_xlabel('Epoch')
                ax7.set_ylabel('Loss (Moving Average)')
                ax7.set_title('Loss Convergence Trend', fontweight='bold')
                ax7.legend()
                ax7.grid(True, alpha=0.3)

        # 8. Quality Score Distribution (NaNs would make np.histogram fail)
        ax8 = fig.add_subplot(gs[2, 1])
        ssim_finite = self._finite(h.get('ssim_scores', []))
        if ssim_finite.size:
            mean_ssim = ssim_finite.mean()
            ax8.hist(ssim_finite, bins=30, alpha=0.7, color='g', edgecolor='black')
            ax8.axvline(mean_ssim, color='darkgreen', linestyle='--', linewidth=2,
                        label=f'Mean: {mean_ssim:.3f}')
            ax8.set_xlabel('SSIM Score')
            ax8.set_ylabel('Frequency')
            ax8.set_title('SSIM Score Distribution', fontweight='bold')
            ax8.legend()
            ax8.grid(True, alpha=0.3)

        # 9. Training Summary Statistics
        ax9 = fig.add_subplot(gs[2, 2])
        ax9.axis('off')
        ax9.text(0.1, 0.9, self._summary_text(), transform=ax9.transAxes,
                 fontfamily='monospace', fontsize=10, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        fig.suptitle('Enhanced Document GAN - Training Metrics Analysis',
                     fontsize=16, fontweight='bold')

        save_path = self._save_figure(fig, save_name)
        print(f"Comprehensive metrics plot saved to: {save_path}")

    def plot_accuracy_analysis(self, save_name='accuracy_analysis.png'):
        """Plot SSIM/PSNR evolution with trend lines and a combined quality score.

        Also plots the per-epoch SSIM change. The figure is saved to
        ``output_dir``.
        """
        if not self.history:
            print("No history data loaded.")
            return

        h = self.history
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))

        # 1. Quality Score Evolution
        ax1 = axes[0, 0]
        if self._has('ssim_scores'):
            ssim = h['ssim_scores']
            ax1.plot(self._epochs(ssim), ssim, 'g-', linewidth=2, label='SSIM')
            self._plot_trend(ax1, ssim, 'g--')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('SSIM Score')
            ax1.set_title('SSIM Accuracy Evolution', fontweight='bold')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

        # 2. PSNR Evolution
        ax2 = axes[0, 1]
        if self._has('psnr_scores'):
            psnr = h['psnr_scores']
            ax2.plot(self._epochs(psnr), psnr, 'm-', linewidth=2, label='PSNR')
            self._plot_trend(ax2, psnr, 'm--')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('PSNR (dB)')
            ax2.set_title('PSNR Accuracy Evolution', fontweight='bold')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

        # 3. Combined Quality Score
        ax3 = axes[1, 0]
        if self._has('ssim_scores', 'psnr_scores'):
            ssim_arr = np.asarray(h['ssim_scores'], dtype=float)
            psnr_arr = np.asarray(h['psnr_scores'], dtype=float)
            n = min(len(ssim_arr), len(psnr_arr))
            # Map SSIM 0.5 -> 0, 1.0 -> 1 and PSNR 20 -> 0, 50 dB -> 1, then average.
            ssim_norm = (ssim_arr[:n] - 0.5) / 0.5
            psnr_norm = (psnr_arr[:n] - 20) / 30
            combined = (ssim_norm + psnr_norm) / 2
            ep = range(1, n + 1)

            ax3.plot(ep, combined, 'b-', linewidth=2, label='Combined Quality Score')
            ax3.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Good Quality Threshold')
            ax3.fill_between(ep, combined, alpha=0.3)
            ax3.set_xlabel('Epoch')
            ax3.set_ylabel('Normalized Combined Score')
            ax3.set_title('Overall Quality Score (SSIM + PSNR)', fontweight='bold')
            ax3.legend()
            ax3.grid(True, alpha=0.3)

        # 4. Improvement Rate: change from epoch e-1 to e, plotted at e.
        ax4 = axes[1, 1]
        if len(h.get('ssim_scores', [])) > 1:
            improvement = np.diff(np.asarray(h['ssim_scores'], dtype=float))
            ax4.plot(range(2, len(improvement) + 2), improvement, 'orange', linewidth=2)
            ax4.axhline(y=0, color='red', linestyle='-', alpha=0.5)
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('SSIM Improvement')
            ax4.set_title('Quality Improvement Rate per Epoch', fontweight='bold')
            ax4.grid(True, alpha=0.3)

        fig.tight_layout()
        save_path = self._save_figure(fig, save_name)
        print(f"Accuracy analysis plot saved to: {save_path}")

    def print_detailed_statistics(self):
        """Print per-metric statistics and the relative improvement.

        For each metric, the printed statistics are initial, final, best,
        mean and standard deviation.

        Only finite values are considered. "Best" is the maximum for score
        metrics (keys containing ``score``) and the minimum for losses.
        Improvement is the change in the "better" direction relative to
        ``|initial|``, so it keeps its sign meaning for negative losses. It is
        reported as N/A when the initial value is 0.
        """
        if not self.history:
            print("No history data loaded.")
            return

        print("\n" + "="*60)
        print("DETAILED TRAINING STATISTICS")
        print("="*60 + "\n")

        metrics = {
            'Generator Loss': 'g_loss',
            'Discriminator Loss': 'd_loss',
            'SSIM Score': 'ssim_scores',
            'PSNR Score': 'psnr_scores',
            'Adversarial Loss': 'g_adv_loss',
            'Content Loss': 'g_content_loss'
        }

        for name, key in metrics.items():
            values = self._finite(self.history.get(key, []))
            if values.size == 0:
                continue

            higher_is_better = 'score' in key.lower()
            best = values.max() if higher_is_better else values.min()
            print(f"{name}:")
            print(f"  Initial:  {values[0]:.4f}")
            print(f"  Final:    {values[-1]:.4f}")
            print(f"  Best:     {best:.4f}")
            print(f"  Mean:     {values.mean():.4f}")
            print(f"  Std Dev:  {values.std():.4f}")

            if values.size > 1:
                baseline = abs(values[0])
                if baseline == 0:
                    # A relative change from zero is undefined.
                    print("  Improvement: N/A (initial value is 0)")
                else:
                    gain = values[-1] - values[0] if higher_is_better else values[0] - values[-1]
                    print(f"  Improvement: {gain / baseline * 100:+.2f}%")
            print()


def main():
    """Load ``best_model.pth`` and produce every metrics plot plus a text report.

    NOTE(review): ``GANMetricsPlotter()`` defaults to ``Enhanced_GAN_Models`` /
    ``Enhanced_GAN_Output``, while the training cell's ``Config`` writes to
    ``Enhanced_GAN_Models2`` / ``Enhanced_GAN_Output2``. Confirm which run
    should be plotted.
    """
    banner = "=" * 60
    print("GAN Training Metrics Plotter")
    print(banner)

    plotter = GANMetricsPlotter()
    if not plotter.load_checkpoint('best_model.pth'):
        print("\nFailed to load checkpoint. Please ensure training has been completed.")
        return

    print("\nGenerating plots...")
    plotter.plot_all_metrics('comprehensive_metrics.png')
    plotter.plot_accuracy_analysis('accuracy_analysis.png')
    plotter.print_detailed_statistics()

    print("\n" + banner)
    print("All plots generated successfully!")
    print(banner)


if __name__ == "__main__":
    main()