In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np

MODEL_PATH = "G:/Sparse2/Results/Chest/10/checkpoints/best_model.pth"  # Path to your best_model.pth
INPUT_IMAGE_PATH = "G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png"  # Input image to process
# Where to save the result. Must differ from INPUT_IMAGE_PATH, otherwise the
# source image is overwritten; the "_generated" suffix matches the naming used
# by Pix2PixGenerator.generate_batch.
OUTPUT_PATH = "G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001_generated.png"

# Model configuration (should match training config)
IMAGE_SIZE = 256  # side length (pixels) of the square model input/output
DEVICE = 'auto'  # 'auto', 'cuda', or 'cpu'

class EnhancedGenerator(nn.Module):
    """U-Net generator for Pix2Pix; must mirror the definition used in training.

    Eight stride-2 down-sampling stages (``e1``..``e8``), seven stride-2
    up-sampling stages (``d1``..``d7``) fed with skip connections, and a final
    transposed convolution followed by Tanh.

    The attribute names and the layer order inside every ``nn.Sequential``
    define the ``state_dict`` keys, so they must stay exactly as they are or
    checkpoints saved during training will no longer load.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        ngf: base number of convolution filters.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super().__init__()
        wide = ngf * 8  # channel count of the deep encoder/decoder stages

        # Encoder: every stage halves H and W. The first and last stages
        # have no BatchNorm.
        self.e1 = self._make_layer(input_nc, ngf, normalize=False)
        self.e2 = self._make_layer(ngf, ngf * 2)
        self.e3 = self._make_layer(ngf * 2, ngf * 4)
        self.e4 = self._make_layer(ngf * 4, wide)
        self.e5 = self._make_layer(wide, wide)
        self.e6 = self._make_layer(wide, wide)
        self.e7 = self._make_layer(wide, wide)
        self.e8 = self._make_layer(wide, wide, normalize=False)

        # Decoder: every stage doubles H and W. From d2 on, the input is the
        # previous decoder output concatenated with the mirrored encoder
        # output, which is why the input channel counts are doubled.
        self.d1 = self._make_up_layer(wide, wide, dropout=True)
        self.d2 = self._make_up_layer(wide * 2, wide, dropout=True)
        self.d3 = self._make_up_layer(wide * 2, wide, dropout=True)
        self.d4 = self._make_up_layer(wide * 2, wide)
        self.d5 = self._make_up_layer(wide * 2, ngf * 4)
        self.d6 = self._make_up_layer(ngf * 8, ngf * 2)
        self.d7 = self._make_up_layer(ngf * 4, ngf)

        # Last up-sampling step back to input resolution; Tanh bounds the
        # output to [-1, 1].
        self.final = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, 4, 2, 1),
            nn.Tanh(),
        )

    def _make_layer(self, in_channels, out_channels, normalize=True):
        """Down-sampling stage: Conv2d(4, stride 2, pad 1) [+ BatchNorm2d] + LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1)
        norm = [nn.BatchNorm2d(out_channels)] if normalize else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, True))

    def _make_up_layer(self, in_channels, out_channels, dropout=False):
        """Up-sampling stage: ConvTranspose2d(4, stride 2, pad 1) + BatchNorm2d + ReLU [+ Dropout2d(0.5)]."""
        tail = [nn.Dropout2d(0.5)] if dropout else []
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            *tail,
        )

    def forward(self, x):
        """Run the U-Net: encode x, then decode with skip connections."""
        encoders = (self.e1, self.e2, self.e3, self.e4,
                    self.e5, self.e6, self.e7, self.e8)
        features = []
        h = x
        for stage in encoders:
            h = stage(h)
            features.append(h)

        # features[-1] is the bottleneck (e8); features[:-1] are e1..e7.
        h = self.d1(features.pop())
        decoders = (self.d2, self.d3, self.d4, self.d5, self.d6, self.d7)
        # d2..d7 pair with e7..e2 (deepest skip first).
        for stage, skip in zip(decoders, reversed(features[1:])):
            h = stage(torch.cat([h, skip], 1))

        # The outermost skip (e1) feeds the final layer.
        return self.final(torch.cat([h, features[0]], 1))

class Pix2PixGenerator:  # IMAGE GENERATOR CLASS
    """Simple interface for generating images with a trained Pix2Pix model.

    Handles device selection, checkpoint loading, preprocessing, inference and
    saving of results, either for one image (``generate_image``) or for every
    image in a folder (``generate_batch``).
    """

    def __init__(self, model_path, device='auto'):
        """Pick the device, load the generator weights and build the transforms.

        Args:
            model_path: path to a checkpoint dict containing 'generator_state_dict'.
            device: 'auto', 'cuda' or 'cpu'; 'auto' prefers CUDA when available.
        """
        self.device = self._setup_device(device)
        self.model = self._load_model(model_path)
        self.transform = self._setup_transforms()

        print(" Pix2Pix Generator ready!")
        print(f" Device: {self.device}")
        print(f" Model: {model_path}")

    def _setup_device(self, device):
        """Resolve the computation device.

        'auto' becomes 'cuda' when CUDA is available, else 'cpu'. The resulting
        device is probed with a tiny allocation; if that raises, 'cpu' is
        returned instead.
        """
        if device == 'auto':
            if torch.cuda.is_available():
                device = 'cuda'
                print(f" Using GPU: {torch.cuda.get_device_name(0)}")
            else:
                device = 'cpu'
                print(" Using CPU")

        try:
            # Probe: fails e.g. when 'cuda' is requested on a machine without a GPU.
            torch.zeros(1).to(device)
        except Exception as e:
            print(f" Device '{device}' failed, using CPU: {e}")
            return 'cpu'
        return device

    def _load_model(self, model_path):
        """Build an EnhancedGenerator and load its weights from model_path.

        The checkpoint must be a dict with a 'generator_state_dict' entry;
        'best_val_psnr' is reported when present. The model is moved to
        self.device and put in eval mode.

        Raises:
            Whatever torch.load / load_state_dict raise (re-raised after printing).
        """
        try:
            # NOTE(review): torch.load unpickles arbitrary Python objects (newer
            # PyTorch versions warn about this). Only load checkpoints you trust;
            # consider weights_only=True once the checkpoint is confirmed to hold
            # only tensors and plain Python types.
            checkpoint = torch.load(model_path, map_location=self.device)

            generator = EnhancedGenerator(input_nc=3, output_nc=3, ngf=64)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            generator.to(self.device)
            # eval(): disables dropout and makes BatchNorm use running statistics.
            generator.eval()

            print(" Model loaded successfully!")
            if 'best_val_psnr' in checkpoint:
                print(f" Best validation PSNR: {checkpoint['best_val_psnr']:.2f} dB")

            return generator

        except Exception as e:
            print(f" Error loading model: {e}")
            raise

    def _setup_transforms(self):
        """Build preprocessing: Lanczos resize to IMAGE_SIZE x IMAGE_SIZE,
        convert to a [0, 1] tensor, then rescale to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS),
            transforms.ToTensor(),
            # [0, 1] -> [-1, 1], the same range as the generator's Tanh output.
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ])

    @staticmethod
    def _comparison_path(output_path):
        """Return the side-by-side comparison path for output_path.

        '_comparison' is inserted before the file extension only, so dots in
        directory names are left alone: 'out/img.png' -> 'out/img_comparison.png'.
        """
        root, ext = os.path.splitext(output_path)
        return f"{root}_comparison{ext}"

    def generate_image(self, input_path, output_path, save_comparison=True):
        """Generate an image from input_path and save it to output_path.

        The input is converted to RGB, run through the model at
        IMAGE_SIZE x IMAGE_SIZE, and the result is resized back to the input's
        original size before saving.

        Args:
            input_path: Path to input image.
            output_path: Path to save generated image (must differ from input_path).
            save_comparison: Whether to also save an input|output side-by-side
                image next to output_path (see _comparison_path).

        Returns:
            The generated PIL image.

        Raises:
            ValueError: if output_path refers to input_path (would overwrite the input).
            Any error from loading, inference or saving (re-raised after printing).
        """
        try:
            print(f"\n🖼️ Processing: {input_path}")

            # Refuse to destroy the source image by writing the result over it.
            same_file = (os.path.normcase(os.path.abspath(output_path))
                         == os.path.normcase(os.path.abspath(input_path)))
            if same_file:
                raise ValueError(
                    f"Output path is the same as the input path: {input_path}"
                )

            # Context manager closes the file handle once the pixels are converted.
            with Image.open(input_path) as source:
                input_img = source.convert('RGB')
            original_size = input_img.size

            input_tensor = self.transform(input_img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                generated_tensor = self.model(input_tensor)

            generated_img = self._tensor_to_pil(generated_tensor.squeeze(0))

            # The model works at IMAGE_SIZE x IMAGE_SIZE; restore the input's size.
            # (input_img itself is already at original_size.)
            if original_size != (IMAGE_SIZE, IMAGE_SIZE):
                generated_img = generated_img.resize(original_size, Image.LANCZOS)

            generated_img.save(output_path)
            print(f"✅ Generated image saved: {output_path}")

            if save_comparison:
                comparison_path = self._comparison_path(output_path)
                self._save_comparison(input_img, generated_img, comparison_path)
                print(f"📊 Comparison saved: {comparison_path}")

            return generated_img

        except Exception as e:
            print(f"❌ Error generating image: {e}")
            raise

    def _tensor_to_pil(self, tensor):
        """Convert a (C, H, W) tensor with values in [-1, 1] to a PIL image."""
        # Denormalize from [-1, 1] to [0, 1]; clamp guards against overshoot.
        tensor = torch.clamp((tensor.detach().cpu() + 1) / 2, 0, 1)
        return transforms.ToPILImage()(tensor)

    def _save_comparison(self, input_img, generated_img, output_path):
        """Save input (left) and generated (right) images side by side.

        Both images are expected to have the same size; generate_image resizes
        the result back to the input size before calling this.
        """
        width, height = input_img.size
        comparison = Image.new('RGB', (width * 2, height), (255, 255, 255))
        comparison.paste(input_img, (0, 0))
        comparison.paste(generated_img, (width, 0))
        comparison.save(output_path)

    @staticmethod
    def _find_image_files(input_folder, file_extensions):
        """Return sorted paths of regular files in input_folder with a matching extension.

        Extensions are compared case-insensitively (a leading '.' is optional)
        and each file appears once. Globbing '*.png' and '*.PNG' separately
        would return every file twice on case-insensitive file systems such as
        Windows. Hidden files (leading '.') are skipped, as '*' globbing did.
        Not recursive. A missing folder yields an empty list.
        """
        if not os.path.isdir(input_folder):
            return []
        wanted = {'.' + ext.lower().lstrip('.') for ext in file_extensions}
        return sorted(
            os.path.join(input_folder, name)
            for name in os.listdir(input_folder)
            if not name.startswith('.')
            and os.path.splitext(name)[1].lower() in wanted
            and os.path.isfile(os.path.join(input_folder, name))
        )

    def generate_batch(self, input_folder, output_folder, file_extensions=None):
        """Generate images for all matching files in a folder.

        Each result is saved as '<name>_generated<ext>' in output_folder. A file
        that fails is reported and skipped; the batch continues.

        Args:
            input_folder: Folder containing input images.
            output_folder: Folder to save generated images (created if missing).
            file_extensions: Extensions to process (default: common image types).
        """
        if file_extensions is None:
            file_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

        os.makedirs(output_folder, exist_ok=True)

        image_files = self._find_image_files(input_folder, file_extensions)

        if not image_files:
            print(f"⚠️ No image files found in {input_folder}")
            return

        total = len(image_files)
        print(f"\n📁 Processing {total} images from {input_folder}")

        for i, input_path in enumerate(image_files, 1):
            filename = os.path.basename(input_path)
            name, ext = os.path.splitext(filename)
            output_path = os.path.join(output_folder, f"{name}_generated{ext}")

            print(f"\n[{i}/{total}] Processing: {filename}")
            try:
                self.generate_image(input_path, output_path)
            except Exception as e:
                # One bad file should not abort the whole batch.
                print(f"❌ Failed to process {filename}: {e}")

        print("\n🎉 Batch processing completed!")
        print(f"📁 Results saved to: {output_folder}")

# MAIN EXECUTION FUNCTIONS
def generate_single_image():
    """Generate one image using the module-level configuration.

    Checks that MODEL_PATH and INPUT_IMAGE_PATH exist, refuses to write the
    result over the input image, creates the output directory if needed, then
    runs the generator. Problems are reported on stdout instead of raised.
    """
    print("="*60)
    print("PIX2PIX IMAGE GENERATOR - SINGLE IMAGE")
    print("="*60)

    # Validate paths
    if not os.path.exists(MODEL_PATH):
        print(f" Model not found: {MODEL_PATH}")
        print(" Please update MODEL_PATH in the script")
        return

    if not os.path.exists(INPUT_IMAGE_PATH):
        print(f" Input image not found: {INPUT_IMAGE_PATH}")
        print(" Please update INPUT_IMAGE_PATH in the script")
        return

    # Saving the result to the input path would destroy the source image.
    same_file = (os.path.normcase(os.path.abspath(OUTPUT_PATH))
                 == os.path.normcase(os.path.abspath(INPUT_IMAGE_PATH)))
    if same_file:
        print(f" Output path is the same as the input image: {OUTPUT_PATH}")
        print(" Please update OUTPUT_PATH in the script")
        return

    # dirname is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(OUTPUT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Same naming as the comparison image: suffix inserted before the extension.
    root, ext = os.path.splitext(OUTPUT_PATH)
    comparison_path = f"{root}_comparison{ext}"

    try:
        generator = Pix2PixGenerator(MODEL_PATH, DEVICE)
        generator.generate_image(INPUT_IMAGE_PATH, OUTPUT_PATH)

        print("\n SUCCESS!")
        print(f" Input: {INPUT_IMAGE_PATH}")
        print(f" Output: {OUTPUT_PATH}")
        print(f" Comparison: {comparison_path}")

    except Exception as e:
        print(f" Generation failed: {e}")

def generate_batch_images():
    """Run the generator over every image in the folder containing INPUT_IMAGE_PATH.

    Results are written to a 'batch_generated' sub-folder of OUTPUT_PATH's
    directory. Problems are reported on stdout instead of raised.
    """
    rule = "="*60
    print(rule)
    print(" PIX2PIX IMAGE GENERATOR - BATCH PROCESSING")
    print(rule)

    # The batch source is the folder holding the configured single input image.
    source_dir = os.path.dirname(INPUT_IMAGE_PATH)
    target_dir = os.path.join(os.path.dirname(OUTPUT_PATH), "batch_generated")

    print(f" Input folder: {source_dir}")
    print(f" Output folder: {target_dir}")

    # Report the first missing prerequisite (model checked before folder).
    problem = None
    if not os.path.exists(MODEL_PATH):
        problem = f" Model not found: {MODEL_PATH}"
    elif not os.path.exists(source_dir):
        problem = f" Input folder not found: {source_dir}"
    if problem is not None:
        print(problem)
        return

    try:
        Pix2PixGenerator(MODEL_PATH, DEVICE).generate_batch(source_dir, target_dir)
    except Exception as e:
        print(f" Batch generation failed: {e}")

def interactive_mode():
    """Interactive menu: load the model once, then process paths typed by the user.

    Option 1 generates a single image, option 2 processes a folder, option 3
    exits. Errors for an individual request are reported and the menu keeps
    running; a failure to load the model ends interactive mode.
    """
    print("="*60)
    print(" PIX2PIX GENERATOR - INTERACTIVE MODE")
    print("="*60)

    # Validate model
    if not os.path.exists(MODEL_PATH):
        print(f" Model not found: {MODEL_PATH}")
        print(" Please update MODEL_PATH in the script")
        return

    # Only initialization belongs here; per-request errors are handled in the loop.
    try:
        generator = Pix2PixGenerator(MODEL_PATH, DEVICE)
    except Exception as e:
        print(f" Failed to initialize generator: {e}")
        return

    while True:
        print("\n" + "="*50)
        print(" Choose an option:")
        print("1. Generate single image")
        print("2. Process folder")
        print("3. Exit")

        choice = input("\nEnter your choice (1-3): ").strip()

        if choice == '1':
            input_path = input("📸 Enter input image path: ").strip()
            if not os.path.exists(input_path):
                print(f" File not found: {input_path}")
                continue

            output_path = input("💾 Enter output path: ").strip()

            try:
                # dirname is '' for a bare filename, and os.makedirs('') raises.
                output_dir = os.path.dirname(output_path)
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                generator.generate_image(input_path, output_path)
                print(f" Generated: {output_path}")
            except Exception as e:
                print(f" Error: {e}")

        elif choice == '2':
            input_folder = input("📁 Enter input folder path: ").strip()
            if not os.path.exists(input_folder):
                print(f" Folder not found: {input_folder}")
                continue

            output_folder = input(" Enter output folder path: ").strip()

            try:
                generator.generate_batch(input_folder, output_folder)
            except Exception as e:
                print(f" Error: {e}")

        elif choice == '3':
            print(" Goodbye!")
            break

        else:
            print(" Invalid choice. Please enter 1, 2, or 3.")

def main():
    """Show the current configuration, then loop over the mode menu until Exit is chosen."""
    banner = "🚀 " + "="*70
    print(banner)
    print("🚀 PIX2PIX IMAGE GENERATOR")
    print(banner)

    print(f"""
 CURRENT CONFIGURATION:
    Model: {MODEL_PATH}
    Input: {INPUT_IMAGE_PATH}
    Output: {OUTPUT_PATH}
    Device: {DEVICE}
    Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}

 AVAILABLE MODES:
   1. Single Image Generation
   2. Batch Processing
   3. Interactive Mode
""")

    # Menu key -> handler; '4' (exit) is handled separately below.
    handlers = {
        '1': generate_single_image,
        '2': generate_batch_images,
        '3': interactive_mode,
    }
    menu_lines = (
        "1. Generate single image (using config above)",
        "2. Batch process folder",
        "3. Interactive mode",
        "4. Exit",
    )

    while True:
        print("\n" + "="*50)
        print(" Choose a mode:")
        for line in menu_lines:
            print(line)

        choice = input("\nEnter your choice (1-4): ").strip()

        if choice == '4':
            print(" Goodbye!")
            return

        handler = handlers.get(choice)
        if handler is None:
            print(" Invalid choice. Please enter 1, 2, 3, or 4.")
        else:
            handler()

if __name__ == '__main__':
    # Also true inside a notebook kernel, so running this cell starts the menu.
    main()

🚀 PIX2PIX IMAGE GENERATOR

📋 CURRENT CONFIGURATION:
   🎯 Model: G:/Sparse2/Results/Chest/10/checkpoints/best_model.pth
   📸 Input: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
   💾 Output: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
   🖥️ Device: auto
   📏 Image Size: 256x256

🎮 AVAILABLE MODES:
   1. Single Image Generation
   2. Batch Processing
   3. Interactive Mode


🎯 Choose a mode:
1. Generate single image (using config above)
2. Batch process folder
3. Interactive mode
4. Exit



Enter your choice (1-4):  1


🎨 PIX2PIX IMAGE GENERATOR - SINGLE IMAGE
💻 Using CPU


  checkpoint = torch.load(model_path, map_location=self.device)


📦 Model loaded successfully!
🏆 Best validation PSNR: 22.54 dB
✅ Pix2Pix Generator ready!
🎮 Device: cpu
📁 Model: G:/Sparse2/Results/Chest/10/checkpoints/best_model.pth

🖼️ Processing: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
✅ Generated image saved: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
📊 Comparison saved: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001_comparison.png

🎉 SUCCESS!
📸 Input: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
🎨 Output: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001.png
📊 Comparison: G:/Sparse2/Results/Chest/10/checkpoints/1/LIDC-IDRI-0004_1-001_comparison.png

🎯 Choose a mode:
1. Generate single image (using config above)
2. Batch process folder
3. Interactive mode
4. Exit
