In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under /kaggle/input, one per line, walking
# sub-directories recursively via os.walk.
# NOTE: dirname / filenames / filename are notebook-level globals after this cell runs.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/prithvi.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/gourav.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/houshik.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/anukriti.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/kenny.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/aditya.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/ishika.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/akshat.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/divyanshu.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/kunal.png
/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1/naveen.png
/kaggle/input/stylized-images/stylized_images/original_images1/aditya.jpeg
/kaggle/input/stylized-images/stylized_i

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from PIL import Image
import random
import os
import gc
import logging
from tqdm import tqdm
from dataclasses import dataclass
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.models as models
from typing import List, Dict, Union, Tuple


# Notebook-wide logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # shared by every class and helper in this notebook

@dataclass
class TrainingConfig:
    """Hyper-parameters for LoRA style-transfer fine-tuning of Stable Diffusion.

    Values are checked in ``__post_init__`` so that invalid settings fail at
    construction time instead of deep inside training (for example
    ``lora_rank=0`` would otherwise surface as a ``ZeroDivisionError`` in the
    LoRA scaling, and ``validation_split >= 1`` would leave no training data).

    ``effective_batch_size`` is informational: the trainer's gradient
    accumulation is driven by ``accumulation_steps``. A warning is emitted when
    ``batch_size * accumulation_steps`` disagrees with it.
    """
    model_id: str = "runwayml/stable-diffusion-v1-5"
    image_size: int = 512
    batch_size: int = 2
    effective_batch_size: int = 8
    learning_rate: float = 1e-5
    num_epochs: int = 100  # Increased for better training
    save_every: int = 10
    mixed_precision: bool = True
    gradient_clip: float = 1.0
    lora_rank: int = 128  # Increased for more capacity
    lora_alpha: float = 128
    validation_split: float = 0.2
    accumulation_steps: int = 4
    use_enhanced_loss: bool = True
    perceptual_loss_weight: float = 0.1

    def __post_init__(self) -> None:
        """Reject settings that would crash or silently misbehave later.

        Raises:
            ValueError: if a size/count is not positive, ``learning_rate`` or
                ``gradient_clip`` is not positive, ``validation_split`` is
                outside ``[0, 1)``, or ``perceptual_loss_weight`` is negative.
        """
        for field_name in (
            "image_size",
            "batch_size",
            "effective_batch_size",
            "num_epochs",
            "save_every",
            "lora_rank",
            "accumulation_steps",
        ):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be a positive integer, got {value!r}")

        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be positive, got {self.learning_rate!r}")
        if self.gradient_clip <= 0:
            raise ValueError(f"gradient_clip must be positive, got {self.gradient_clip!r}")
        # 1.0 would leave zero training samples after the train/validation split.
        if not 0.0 <= self.validation_split < 1.0:
            raise ValueError(f"validation_split must be in [0, 1), got {self.validation_split!r}")
        if self.perceptual_loss_weight < 0:
            raise ValueError(
                f"perceptual_loss_weight must be non-negative, got {self.perceptual_loss_weight!r}"
            )

        if self.batch_size * self.accumulation_steps != self.effective_batch_size:
            import warnings

            warnings.warn(
                f"batch_size * accumulation_steps = {self.batch_size * self.accumulation_steps} "
                f"does not match effective_batch_size = {self.effective_batch_size}; "
                "accumulation_steps is what training actually uses."
            )

2025-06-30 14:23:16.693447: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751293397.142831      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751293397.254861      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:

class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    ``forward(x) = original_layer(x) + (x @ A.T) @ B.T * (alpha / rank)`` where
    ``A`` (``lora_A``) has shape ``(rank, in_features)`` and ``B`` (``lora_B``)
    has shape ``(out_features, rank)``. ``B`` starts at zero, so a freshly
    wrapped layer reproduces the original layer exactly.

    The adapter parameters are created on the wrapped layer's device and dtype.
    ``forward`` casts the input to the adapter dtype and the adapter output back
    to the original output dtype, so the adapters may be kept in a different
    precision from the base layer (e.g. float32 on top of a float16 layer).

    Args:
        original_layer: Layer to wrap; its parameters are frozen in place.
        rank: Inner dimension of the low-rank update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / rank``.

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, original_layer: nn.Linear, rank: int = 64, alpha: float = 64):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"LoRA rank must be positive, got {rank}")

        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Freeze original layer; only the adapters below are trainable.
        for param in self.original_layer.parameters():
            param.requires_grad = False

        device = original_layer.weight.device
        dtype = original_layer.weight.dtype

        # Small random A and zero B: the update B @ A is exactly zero at initialisation.
        self.lora_A = nn.Parameter(
            torch.randn(rank, original_layer.in_features, device=device, dtype=dtype) * 0.01
        )
        self.lora_B = nn.Parameter(
            torch.zeros(original_layer.out_features, rank, device=device, dtype=dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen layer's output plus the scaled low-rank update."""
        original_out = self.original_layer(x)

        # Tensor.to is a no-op when the dtype already matches.
        x_lora = x.to(self.lora_A.dtype)
        lora_out = F.linear(F.linear(x_lora, self.lora_A), self.lora_B) * self.scaling

        # Return in the base layer's output dtype so downstream modules see no change.
        return original_out + lora_out.to(original_out.dtype)

    def get_lora_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return the adapter weights and scaling for checkpointing.

        The weights are detached (plain tensors sharing storage with the live
        parameters), and ``scaling`` is always stored as float32 so it is not
        rounded when the adapters themselves are half precision.
        """
        return {
            'lora_A': self.lora_A.detach(),
            'lora_B': self.lora_B.detach(),
            'scaling': torch.tensor(self.scaling, dtype=torch.float32),
        }


In [4]:
class EnhancedStyleLoss(nn.Module):
    """Diffusion MSE loss with an optional VGG-19 perceptual term.

    ``forward(noise_pred, noise_target)`` returns the plain MSE between the two
    tensors. When ``pred_image`` and ``target_image`` are also given, the sum of
    MSEs between their VGG-19 feature maps is added with weight
    ``perceptual_weight``.

    Images are expected in ``[-1, 1]``. They are mapped to ``[0, 1]`` and then
    normalised with the ImageNet mean/std that torchvision's pretrained VGG
    weights were trained with.
    """

    # ImageNet statistics expected by torchvision's pretrained VGG weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, device="cuda", perceptual_weight=0.1):
        super().__init__()
        # Load VGG for perceptual loss (frozen feature extractor).
        vgg = models.vgg19(weights='VGG19_Weights.DEFAULT').features.to(device).eval()
        for param in vgg.parameters():
            param.requires_grad = False

        self.vgg = vgg
        # Indices into vgg19.features of relu1_2, relu2_2, relu3_3, relu4_2.
        # (Capturing at ReLU indices also avoids the following in-place ReLU
        # overwriting a captured conv activation.)
        self.vgg_layers = [3, 8, 15, 22]
        self.mse_loss = nn.MSELoss()
        self.perceptual_weight = perceptual_weight

    def _to_vgg_input(self, images):
        """Map ``[-1, 1]`` images to ImageNet-normalised float32 VGG input (3 channels)."""
        images = images.float()
        mean = images.new_tensor(self._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = images.new_tensor(self._IMAGENET_STD).view(1, -1, 1, 1)
        return ((images + 1) / 2 - mean) / std

    def get_vgg_features(self, x):
        """Run ``x`` through VGG and collect activations at ``self.vgg_layers``.

        Stops after the deepest requested layer instead of running the whole
        network.
        """
        wanted = set(self.vgg_layers)
        if not wanted:
            return []
        last_needed = max(wanted)

        features = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in wanted:
                features.append(x)
            if i >= last_needed:
                break
        return features

    def compute_perceptual_loss(self, pred_image, target_image):
        """Sum of per-layer MSEs between VGG features of prediction and target."""
        pred_features = self.get_vgg_features(self._to_vgg_input(pred_image))
        # The target is a fixed reference: no autograd graph is needed for it.
        with torch.no_grad():
            target_features = self.get_vgg_features(self._to_vgg_input(target_image))

        perceptual_loss = 0
        for pred_feat, target_feat in zip(pred_features, target_features):
            perceptual_loss = perceptual_loss + self.mse_loss(pred_feat, target_feat)

        return perceptual_loss

    def forward(self, noise_pred, noise_target, pred_image=None, target_image=None):
        """Return ``(total_loss, diffusion_loss, perceptual_loss)``.

        ``perceptual_loss`` is ``0.0`` (and ``total_loss`` equals
        ``diffusion_loss``) unless both images are provided.
        """
        diffusion_loss = self.mse_loss(noise_pred, noise_target)

        if pred_image is not None and target_image is not None:
            perceptual_loss = self.compute_perceptual_loss(pred_image, target_image)
            total_loss = diffusion_loss + self.perceptual_weight * perceptual_loss
            return total_loss, diffusion_loss, perceptual_loss
        return diffusion_loss, diffusion_loss, 0.0

In [5]:
class StyleTransferDataset(Dataset):
    """Paired (original, stylized) image dataset with validation and error handling.

    Images in ``original_dir`` and ``styled_dir`` are paired by file stem,
    case-insensitively (e.g. ``aditya.jpeg`` <-> ``aditya.png``). If the two
    directories share no stems at all, pairs are formed by sorted position
    instead. Each item is ``(original_tensor, styled_tensor, style_name)``.
    Both tensors are ``(3, image_size, image_size)`` and normalised to
    ``[-1, 1]``. When ``augment`` is true, both images of a pair receive the
    same random augmentation.
    """

    # Recognised image suffixes, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
    # Samples __getitem__ tries (the requested one plus random substitutes) before failing.
    MAX_LOAD_ATTEMPTS = 10

    def __init__(
        self,
        original_dir: str,
        styled_dir: str,
        style_name: str,
        image_size: int = 512,
        augment: bool = True,
        validate_pairs: bool = True,
    ):
        self.original_dir = Path(original_dir)
        self.styled_dir = Path(styled_dir)
        self.image_size = image_size
        self.style_name = style_name

        # Validate directories
        self._validate_directories()

        # Load and validate image pairs
        self.image_pairs = self._load_and_validate_pairs(validate_pairs)

        # Setup transforms
        self.transform = self._setup_transforms(augment)

        logger.info(f"Loaded {len(self.image_pairs)} valid image pairs for style: {self.style_name}")

    @classmethod
    def _list_images(cls, directory: Path) -> List[Path]:
        """Return the image files directly inside ``directory``, sorted by path."""
        return sorted(
            path for path in directory.iterdir()
            if path.is_file() and path.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    def _validate_directories(self):
        """Check that both directories exist and each contains at least one image.

        Uses the same extension set as pair loading, so e.g. ``.jpeg`` files are
        accepted here too.
        """
        if not self.original_dir.is_dir():
            raise FileNotFoundError(f"Original directory not found: {self.original_dir}")
        if not self.styled_dir.is_dir():
            raise FileNotFoundError(f"Styled directory not found: {self.styled_dir}")

        for directory in (self.original_dir, self.styled_dir):
            if not self._list_images(directory):
                raise ValueError(f"No image files found in {directory}")

    @staticmethod
    def _match_files(orig_files: List[Path], style_files: List[Path]) -> List[Tuple[Path, Path]]:
        """Pair files by case-insensitive stem, falling back to sorted position."""
        styled_by_stem: Dict[str, Path] = {}
        for path in style_files:
            styled_by_stem.setdefault(path.stem.lower(), path)

        matched = [
            (orig, styled_by_stem[orig.stem.lower()])
            for orig in orig_files
            if orig.stem.lower() in styled_by_stem
        ]
        if matched:
            unmatched = len(orig_files) - len(matched)
            if unmatched:
                logger.warning(f"{unmatched} original image(s) have no styled counterpart and were skipped")
            return matched

        logger.warning("No matching file names between directories; pairing images by sorted position")
        return list(zip(orig_files, style_files))

    def _load_and_validate_pairs(self, validate: bool) -> List[Dict]:
        """Build the list of pair records, optionally dropping unreadable/tiny images.

        Raises:
            ValueError: if no valid pair remains.
        """
        orig_files = self._list_images(self.original_dir)
        style_files = self._list_images(self.styled_dir)

        pairs = []
        for orig_path, style_path in self._match_files(orig_files, style_files):
            if validate and not self._validate_image_pair(orig_path, style_path):
                logger.warning(f"Skipping invalid pair: {orig_path.name} - {style_path.name}")
                continue

            pairs.append({
                'original_path': orig_path,
                'styled_path': style_path,
                'style_name': self.style_name
            })

        if not pairs:
            raise ValueError("No valid image pairs found")

        return pairs

    def _validate_image_pair(self, orig_path: Path, style_path: Path) -> bool:
        """Return True if both images open and are at least 64x64 pixels."""
        try:
            with Image.open(orig_path) as orig_img, Image.open(style_path) as style_img:
                return all(min(img.size) >= 64 for img in (orig_img, style_img))
        except Exception as e:
            logger.warning(f"Error validating pair {orig_path.name} - {style_path.name}: {e}")
            return False

    def _setup_transforms(self, augment: bool) -> transforms.Compose:
        """Resize, optionally augment, then convert to a tensor normalised to [-1, 1]."""
        transform_list = [
            transforms.Resize((self.image_size, self.image_size), interpolation=transforms.InterpolationMode.LANCZOS),
        ]

        if augment:
            transform_list.extend([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=5),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.3),
            ])

        transform_list.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        return transforms.Compose(transform_list)

    def __len__(self) -> int:
        return len(self.image_pairs)

    def _load_pair(self, pair: Dict) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Load and transform one pair, applying identical random augmentation to both images."""
        with Image.open(pair['original_path']) as orig_img:
            original_image = orig_img.convert('RGB').copy()

        with Image.open(pair['styled_path']) as style_img:
            styled_image = style_img.convert('RGB').copy()

        # The random transforms draw from torch's global RNG; replaying the same
        # RNG state makes flips/rotations/jitter identical for both images.
        rng_state = torch.get_rng_state()
        original_tensor = self.transform(original_image)
        torch.set_rng_state(rng_state)
        styled_tensor = self.transform(styled_image)

        return original_tensor, styled_tensor, pair['style_name']

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return sample ``idx``; on load errors substitute a random sample (bounded retries).

        Raises:
            IndexError: for an out-of-range ``idx``.
            RuntimeError: if ``MAX_LOAD_ATTEMPTS`` consecutive samples fail to load.
        """
        pair = self.image_pairs[idx]  # out-of-range indices raise, as for any sequence
        last_error = None
        for _ in range(self.MAX_LOAD_ATTEMPTS):
            try:
                return self._load_pair(pair)
            except Exception as e:
                logger.error(f"Error loading sample {idx}: {e}")
                last_error = e
                # Return a random valid sample instead
                idx = random.randint(0, len(self) - 1)
                pair = self.image_pairs[idx]

        raise RuntimeError(
            f"Failed to load a sample after {self.MAX_LOAD_ATTEMPTS} attempts"
        ) from last_error

In [6]:
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor

In [7]:
class DiffusionStyleTransferPipeline:
    """Stable Diffusion components plus LoRA adapters for style-transfer fine-tuning.

    Loads the VAE, CLIP text encoder/tokenizer, UNet and DDPM scheduler from
    ``config.model_id`` and freezes them. UNet ``nn.Linear`` layers whose module
    names contain one of ``_LORA_TARGETS`` are wrapped with ``LoRALinear``
    adapters. ``training_step`` computes the denoising loss the trainer
    optimises; ``stylize_image`` runs img2img inference with the LoRA-augmented
    UNet.
    """

    # Substrings of UNet module names whose nn.Linear layers receive LoRA adapters.
    _LORA_TARGETS = ('attn', 'to_q', 'to_k', 'to_v', 'to_out')

    def __init__(
        self,
        config: TrainingConfig,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        self.config = config
        self.device = device
        self.torch_dtype = torch_dtype
        self.lora_layers = {}
        # prompt string -> text embedding on self.device; filled lazily by encode_text.
        self._text_embedding_cache = {}

        logger.info("Initializing Stable Diffusion components...")

        # Load components
        self._load_components()

        # The perceptual-loss helper (loads VGG-19) exists only when enabled.
        if config.use_enhanced_loss:
            self.enhanced_loss = EnhancedStyleLoss(device=device, perceptual_weight=config.perceptual_loss_weight)

        # Built lazily by setup_inference_pipeline().
        self.inference_pipeline = None

        logger.info("Pipeline initialized successfully")

    def _load_components(self):
        """Load and freeze the SD components, then add LoRA adapters to the UNet.

        The text encoder stays on CPU when ``self.device == "cuda"`` and is
        moved to the device only while encoding.
        """
        try:
            self.vae = AutoencoderKL.from_pretrained(
                self.config.model_id,
                subfolder="vae",
                torch_dtype=self.torch_dtype
            )
            self.text_encoder = CLIPTextModel.from_pretrained(
                self.config.model_id,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype
            )
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.config.model_id,
                subfolder="tokenizer"
            )
            self.unet = UNet2DConditionModel.from_pretrained(
                self.config.model_id,
                subfolder="unet",
                torch_dtype=self.torch_dtype
            )
            self.scheduler = DDPMScheduler.from_pretrained(
                self.config.model_id,
                subfolder="scheduler"
            )

            self.vae.to(self.device)
            self.unet.to(self.device)

            # Keep text encoder on CPU initially to save memory
            if self.device == "cuda":
                self.text_encoder.to("cpu")
            else:
                self.text_encoder.to(self.device)

            # Freeze base models
            self.vae.requires_grad_(False)
            self.text_encoder.requires_grad_(False)
            self.unet.requires_grad_(False)

            self._add_lora_to_unet()

        except Exception as e:
            logger.error(f"Error loading components: {e}")
            raise

    def _add_lora_to_unet(self):
        """Replace each targeted UNet ``nn.Linear`` with a ``LoRALinear`` wrapper.

        The adapter weights are stored in float32 even when the UNet runs in
        half precision. ``torch.amp.GradScaler.unscale_`` refuses float16
        gradients, and tiny AdamW updates underflow in float16. ``LoRALinear``
        casts inputs/outputs, so the fp16 base layers are unaffected.
        """
        # Collect first, then replace, so the module tree is not mutated while being walked.
        targets = [
            (name, module)
            for name, module in self.unet.named_modules()
            if isinstance(module, nn.Linear) and any(target in name for target in self._LORA_TARGETS)
        ]

        for name, module in targets:
            lora_layer = LoRALinear(
                module,
                rank=self.config.lora_rank,
                alpha=self.config.lora_alpha
            )
            for param in (lora_layer.lora_A, lora_layer.lora_B):
                param.data = param.data.to(device=self.device, dtype=torch.float32)

            parent_name, _, child_name = name.rpartition('.')
            parent_module = self.unet.get_submodule(parent_name) if parent_name else self.unet
            setattr(parent_module, child_name, lora_layer)

            self.lora_layers[name] = lora_layer

        logger.info(f"Added {len(targets)} LoRA layers to UNet")

    def encode_text(self, prompts: List[str]) -> torch.Tensor:
        """Return CLIP text embeddings for ``prompts``, stacked along the batch dim.

        Embeddings are cached per prompt string on ``self.device``. The text
        encoder is frozen and run under ``no_grad``, and training prompts come
        from a small set of templates. Caching therefore avoids moving the
        encoder between CPU and GPU on every step. This assumes the encoder is
        in eval mode (the ``from_pretrained`` default), so outputs are
        deterministic.
        """
        cache = self._text_embedding_cache
        # dict.fromkeys de-duplicates while preserving order.
        missing = [prompt for prompt in dict.fromkeys(prompts) if prompt not in cache]
        if missing:
            cache.update(zip(missing, self._encode_uncached(missing)))
        return torch.stack([cache[prompt] for prompt in prompts])

    def _encode_uncached(self, prompts: List[str]) -> List[torch.Tensor]:
        """Run the text encoder on ``prompts``; return one embedding tensor per prompt."""
        # Restore the encoder to wherever it was before: CPU during training, but it must
        # stay on self.device once setup_inference_pipeline has moved it there.
        previous_device = next(self.text_encoder.parameters()).device
        self.text_encoder.to(self.device)
        try:
            text_inputs = self.tokenizer(
                prompts,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
            return list(text_embeddings.unbind(0))
        finally:
            self.text_encoder.to(previous_device)

    def create_enhanced_style_prompts(self, style_names: List[str]) -> List[str]:
        """Pick a random prompt template per style name (Ghibli-specific when the name mentions it)."""
        ghibli_templates = [
            "studio ghibli style, anime artwork, soft watercolor painting, hand-drawn animation, whimsical, detailed backgrounds",
            "miyazaki style illustration, ghibli anime, pastel colors, dreamy atmosphere, nature-inspired, magical realism",
            "ghibli movie style, traditional animation, soft lighting, organic shapes, peaceful scenery, artistic masterpiece",
            "spirited away style, howl's moving castle aesthetic, anime art, vibrant yet soft colors, detailed character design"
        ]

        general_templates = [
            "{} style artwork, high quality, detailed, masterpiece",
            "beautiful {} style painting, artistic, vibrant colors",
            "{} art style, professional illustration, stunning",
            "amazing {} style image, creative, high resolution",
        ]

        prompts = []
        for style_name in style_names:
            if "ghibli" in style_name.lower():
                prompts.append(random.choice(ghibli_templates))
            else:
                prompts.append(random.choice(general_templates).format(style_name))

        return prompts

    def _prediction_type(self) -> str:
        """The scheduler's prediction type ("epsilon" when not configured)."""
        return getattr(self.scheduler.config, "prediction_type", "epsilon")

    def _training_target(self, latents, noise, timesteps):
        """Regression target matching the scheduler's prediction type."""
        prediction_type = self._prediction_type()
        if prediction_type == "epsilon":
            return noise
        if prediction_type == "v_prediction":
            return self.scheduler.get_velocity(latents, noise, timesteps)
        if prediction_type == "sample":
            return latents
        raise ValueError(f"Unsupported prediction_type: {prediction_type}")

    def _decode_predicted_sample(self, noisy_latents, model_output, timesteps):
        """Estimate the clean latents from the UNet output and decode them with the VAE.

        Gradients flow through the (frozen) VAE decoder back to the UNet/LoRA.
        Returns float32 images on the VAE's output scale.
        """
        alphas_cumprod = self.scheduler.alphas_cumprod.to(device=noisy_latents.device, dtype=torch.float32)
        alpha_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        x_t = noisy_latents.float()
        output = model_output.float()

        prediction_type = self._prediction_type()
        if prediction_type == "epsilon":
            pred_latents = (x_t - (1 - alpha_bar).sqrt() * output) / alpha_bar.sqrt()
        elif prediction_type == "v_prediction":
            pred_latents = alpha_bar.sqrt() * x_t - (1 - alpha_bar).sqrt() * output
        else:
            pred_latents = output

        latents_for_vae = (pred_latents / self.vae.config.scaling_factor).to(self.vae.dtype)
        return self.vae.decode(latents_for_vae).sample.float()

    def training_step(
        self,
        original_images: torch.Tensor,
        styled_images: torch.Tensor,
        style_prompts: List[str],
        compute_perceptual: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Compute the denoising loss for one batch of styled images.

        Args:
            original_images: Content images of the pairs. They are not used by
                this objective: the UNet learns to denoise ``styled_images``
                conditioned only on the text prompt.
            styled_images: Target-style images, presumably ``(B, 3, H, W)`` in
                ``[-1, 1]`` as produced by the dataset.
            style_prompts: One prompt per image.
            compute_perceptual: When True and the enhanced loss is enabled,
                decode the model's predicted clean image and add the VGG
                perceptual loss against ``styled_images``. This costs a
                differentiable VAE decode per step (much more memory). By
                default only the diffusion loss is computed.

        Returns:
            Dict with ``loss``, ``diffusion_loss``, ``perceptual_loss`` (``0.0``
            when not computed), ``noise_pred`` and ``noise``. Losses are
            computed in float32.
        """
        try:
            batch_size = styled_images.shape[0]

            with torch.no_grad():
                # Match the VAE's dtype so encoding also works outside autocast (e.g. validation).
                vae_input = styled_images.to(device=self.device, dtype=self.vae.dtype)
                styled_latents = self.vae.encode(vae_input).latent_dist.sample()
                styled_latents = styled_latents * self.vae.config.scaling_factor

            noise = torch.randn_like(styled_latents)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps,
                (batch_size,), device=self.device
            ).long()

            noisy_latents = self.scheduler.add_noise(styled_latents, noise, timesteps)
            text_embeddings = self.encode_text(style_prompts)

            noise_pred = self.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=text_embeddings,
            ).sample

            target = self._training_target(styled_latents, noise, timesteps)
            pred_f, target_f = noise_pred.float(), target.float()

            enhanced_loss = getattr(self, "enhanced_loss", None) if self.config.use_enhanced_loss else None
            if enhanced_loss is not None and compute_perceptual:
                pred_images = self._decode_predicted_sample(noisy_latents, noise_pred, timesteps)
                total_loss, diffusion_loss, perceptual_loss = enhanced_loss(
                    pred_f, target_f, pred_images, styled_images.float()
                )
            elif enhanced_loss is not None:
                # Comparing styled_images with itself would give an identically
                # zero perceptual term, so only the diffusion loss is computed.
                total_loss, diffusion_loss, perceptual_loss = enhanced_loss(pred_f, target_f)
            else:
                total_loss = F.mse_loss(pred_f, target_f, reduction="mean")
                diffusion_loss, perceptual_loss = total_loss, 0.0

            return {
                "loss": total_loss,
                "diffusion_loss": diffusion_loss,
                "perceptual_loss": perceptual_loss,
                "noise_pred": noise_pred,
                "noise": noise,
            }

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory during training step")
            torch.cuda.empty_cache()
            raise
        except Exception as e:
            logger.error(f"Error in training step: {e}")
            raise

    def setup_inference_pipeline(self):
        """Build the img2img pipeline once, sharing the (LoRA-augmented) components.

        A separate scheduler instance is used so inference-time
        ``set_timesteps`` calls do not mutate the training scheduler.
        """
        if self.inference_pipeline is None:
            try:
                # The text encoder stays on the device from now on (encode_text restores it here).
                self.text_encoder.to(self.device)

                self.inference_pipeline = StableDiffusionImg2ImgPipeline(
                    vae=self.vae,
                    text_encoder=self.text_encoder,
                    tokenizer=self.tokenizer,
                    unet=self.unet,
                    scheduler=DDPMScheduler.from_config(self.scheduler.config),
                    safety_checker=None,
                    requires_safety_checker=False,
                    feature_extractor=None,
                )
                self.inference_pipeline.to(self.device)
                logger.info("Inference pipeline setup complete")

            except Exception as e:
                logger.error(f"Error setting up inference pipeline: {e}")
                raise

    def stylize_image(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        strength: float = 0.8,  # Higher strength for better style transfer
        guidance_scale: float = 12.0,  # Higher guidance for better prompt adherence
        num_inference_steps: int = 75,  # More steps for better quality
        negative_prompt: str = "blurry, low quality, distorted, ugly, bad anatomy"
    ) -> Image.Image:
        """Apply the learned style to ``image`` (a path or PIL image) via img2img.

        The input is converted to RGB and resized to ``config.image_size``
        square before generation. Returns the first generated image.
        """
        try:
            self.setup_inference_pipeline()

            if isinstance(image, str):
                with Image.open(image) as img:
                    image = img.convert('RGB').copy()
            elif image.mode != 'RGB':
                # Same RGB conversion as for file inputs (e.g. RGBA or grayscale images).
                image = image.convert('RGB')

            image = image.resize((self.config.image_size, self.config.image_size), Image.LANCZOS)

            with torch.no_grad():
                result = self.inference_pipeline(
                    prompt=prompt,
                    image=image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    negative_prompt=negative_prompt,
                )

            return result.images[0]

        except Exception as e:
            logger.error(f"Error during image stylization: {e}")
            raise

In [9]:
class StyleTransferTrainer:
    """Trains the pipeline's LoRA adapters with AdamW, gradient accumulation and optional AMP.

    Only parameters inside ``pipeline.lora_layers`` that require gradients are
    optimised; the wrapped base layers are frozen. The learning rate follows a
    cosine schedule that is stepped once per epoch.
    """

    def __init__(self, pipeline: DiffusionStyleTransferPipeline, config: TrainingConfig):
        self.pipeline = pipeline
        self.config = config

        # Setup optimizer for LoRA parameters only
        self._setup_optimizer()

        # GradScaler (and autocast) only when mixed precision is requested and CUDA exists.
        self.scaler = torch.amp.GradScaler('cuda') if config.mixed_precision and torch.cuda.is_available() else None

        # Per-epoch average losses.
        self.train_losses = []
        self.val_losses = []

        logger.info(f"Trainer initialized with {len(self._get_lora_params())} LoRA parameters")

    @staticmethod
    def _to_float(value) -> float:
        """Convert a loss (tensor or number) to a Python float, dropping any autograd graph."""
        return float(value.detach()) if torch.is_tensor(value) else float(value)

    def _autocast(self):
        """Autocast context matching the GradScaler setting; a no-op without mixed precision."""
        import contextlib

        if self.scaler is not None:
            return torch.amp.autocast('cuda')
        return contextlib.nullcontext()

    def _get_lora_params(self) -> List[torch.nn.Parameter]:
        """Trainable LoRA parameters (frozen base-layer weights are excluded)."""
        return [
            param
            for lora_layer in self.pipeline.lora_layers.values()
            for param in lora_layer.parameters()
            if param.requires_grad
        ]

    def _setup_optimizer(self):
        """Create AdamW over the LoRA parameters plus a per-epoch cosine LR schedule."""
        self._lora_params = self._get_lora_params()

        self.optimizer = optim.AdamW(
            self._lora_params,
            lr=self.config.learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.config.num_epochs,
            eta_min=self.config.learning_rate * 0.1
        )

    def _apply_accumulated_gradients(self):
        """Clip, step the optimizer (through the scaler under AMP) and clear gradients.

        Gradients are cleared even if the step fails, so a failure cannot
        accumulate into later steps.
        """
        try:
            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self._lora_params, self.config.gradient_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(self._lora_params, self.config.gradient_clip)
                self.optimizer.step()
        finally:
            self.optimizer.zero_grad()

    def train_epoch(self, dataloader: DataLoader, epoch: int) -> float:
        """Run one epoch of gradient-accumulated training; return the mean batch loss.

        Gradients from ``config.accumulation_steps`` consecutive batches are
        summed before each optimizer step. A trailing partial group at the end
        of the epoch is applied rather than discarded. Failed batches are logged
        and skipped, and they do not count towards the averages.

        Raises:
            RuntimeError: if no batch of the epoch completed successfully.
        """
        accumulation_steps = self.config.accumulation_steps
        total_loss = 0.0
        total_diffusion_loss = 0.0
        total_perceptual_loss = 0.0
        completed_batches = 0
        pending_batches = 0  # batches whose gradients have not been applied yet

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")

        self.optimizer.zero_grad()

        for batch_idx, (original_images, styled_images, batch_style_names) in enumerate(progress_bar):
            try:
                original_images = original_images.to(self.pipeline.device)
                styled_images = styled_images.to(self.pipeline.device)

                prompts = self.pipeline.create_enhanced_style_prompts(batch_style_names)

                with self._autocast():
                    outputs = self.pipeline.training_step(
                        original_images=original_images,
                        styled_images=styled_images,
                        style_prompts=prompts,
                    )
                    loss = outputs["loss"] / accumulation_steps

                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                pending_batches += 1

                # Floats only: running totals must not keep autograd graphs alive.
                batch_loss = loss.item() * accumulation_steps
                batch_diffusion = self._to_float(outputs.get("diffusion_loss", 0))
                batch_perceptual = self._to_float(outputs.get("perceptual_loss", 0))
                total_loss += batch_loss
                total_diffusion_loss += batch_diffusion
                total_perceptual_loss += batch_perceptual
                completed_batches += 1

                if pending_batches and (batch_idx + 1) % accumulation_steps == 0:
                    pending_batches = 0
                    self._apply_accumulated_gradients()

                progress_bar.set_postfix({
                    'Total': f'{batch_loss:.4f}',
                    'Diffusion': f'{batch_diffusion:.4f}',
                    'Perceptual': f'{batch_perceptual:.4f}',
                    'LR': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
                })

            except torch.cuda.OutOfMemoryError:
                logger.error("CUDA OOM during training - clearing cache")
                torch.cuda.empty_cache()
                self.optimizer.zero_grad()
                pending_batches = 0
                continue
            except Exception as e:
                logger.error(f"Error in batch {batch_idx}: {e}")
                continue

        # Apply gradients of a trailing partial accumulation group.
        if pending_batches:
            try:
                self._apply_accumulated_gradients()
            except Exception as e:
                logger.error(f"Error applying final accumulated gradients: {e}")

        if completed_batches == 0:
            raise RuntimeError(
                f"Epoch {epoch}: no training batch completed successfully (see errors logged above)"
            )

        avg_loss = total_loss / completed_batches
        self.train_losses.append(avg_loss)

        self.scheduler.step()

        logger.info(f"Epoch {epoch} - Total: {avg_loss:.4f}, "
                    f"Diffusion: {total_diffusion_loss / completed_batches:.4f}, "
                    f"Perceptual: {total_perceptual_loss / completed_batches:.4f}")

        return avg_loss

    def validate(self, val_dataloader: DataLoader) -> float:
        """Mean loss over successfully evaluated validation batches (``inf`` if none).

        Runs under the same autocast setting as training.
        """
        total_val_loss = 0.0
        completed_batches = 0

        with torch.no_grad():
            for original_images, styled_images, batch_style_names in val_dataloader:
                try:
                    original_images = original_images.to(self.pipeline.device)
                    styled_images = styled_images.to(self.pipeline.device)

                    prompts = self.pipeline.create_enhanced_style_prompts(batch_style_names)

                    with self._autocast():
                        outputs = self.pipeline.training_step(
                            original_images=original_images,
                            styled_images=styled_images,
                            style_prompts=prompts,
                        )

                    total_val_loss += self._to_float(outputs["loss"])
                    completed_batches += 1

                except Exception as e:
                    logger.warning(f"Error in validation batch: {e}")
                    continue

        avg_val_loss = total_val_loss / completed_batches if completed_batches > 0 else float('inf')
        self.val_losses.append(avg_val_loss)

        return avg_val_loss

    def save_model(self, path: str, epoch: int = None):
        """Save LoRA weights plus optimizer/scheduler/scaler state and loss history to ``path``.

        The config is stored as a plain dict so the checkpoint can be read with
        ``torch.load(..., weights_only=True)``.
        """
        from dataclasses import asdict, is_dataclass

        try:
            directory = os.path.dirname(path)
            if directory:  # a bare filename has no directory to create
                os.makedirs(directory, exist_ok=True)

            lora_state_dict = {
                name: {
                    key: value.detach() if torch.is_tensor(value) else value
                    for key, value in lora_layer.get_lora_state_dict().items()
                }
                for name, lora_layer in self.pipeline.lora_layers.items()
            }

            checkpoint = {
                'lora_state_dict': lora_state_dict,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'config': asdict(self.config) if is_dataclass(self.config) else self.config,
                'epoch': epoch,
                'train_losses': list(self.train_losses),
                'val_losses': list(self.val_losses),
            }
            if self.scaler is not None:
                checkpoint['scaler_state_dict'] = self.scaler.state_dict()

            torch.save(checkpoint, path)
            logger.info(f"Model saved to {path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise

    def load_model(self, path: str, weights_only: Union[bool, None] = None):
        """Restore LoRA weights and training state saved by ``save_model``.

        Args:
            path: Checkpoint file.
            weights_only: Forwarded to ``torch.load`` when given (``None`` keeps
                torch's default). Checkpoints that pickled the config object
                need ``False``. Never use ``False`` for untrusted files: it
                unpickles arbitrary objects.

        LoRA weights are copied in place (shape-checked, converted to the live
        parameters' dtype), so the optimizer keeps referencing the same tensors.
        """
        try:
            load_kwargs = {} if weights_only is None else {'weights_only': weights_only}
            checkpoint = torch.load(path, map_location=self.pipeline.device, **load_kwargs)

            saved_layers = checkpoint['lora_state_dict']
            missing = []
            for name, lora_layer in self.pipeline.lora_layers.items():
                state_dict = saved_layers.get(name)
                if state_dict is None:
                    missing.append(name)
                    continue
                with torch.no_grad():
                    lora_layer.lora_A.copy_(state_dict['lora_A'])
                    lora_layer.lora_B.copy_(state_dict['lora_B'])
                scaling = state_dict['scaling']
                lora_layer.scaling = scaling.item() if torch.is_tensor(scaling) else float(scaling)
            if missing:
                logger.warning(f"{len(missing)} LoRA layers not found in checkpoint; left unchanged")

            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if self.scaler is not None and 'scaler_state_dict' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

            self.train_losses = list(checkpoint.get('train_losses', []))
            self.val_losses = list(checkpoint.get('val_losses', []))

            logger.info(f"Model loaded from {path}")

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

def create_dataloaders(
    original_dir: str, 
    styled_dir: str, 
    style_name: str, 
    config: TrainingConfig
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation dataloaders over paired original/styled images.

    The pairs are split at random into a training subset and a validation
    subset; ``config.validation_split`` is the validation fraction. The split
    uses a seeded ``torch.Generator`` (``config.seed`` if the config defines a
    non-None ``seed``, otherwise 42), so the same pairs land in each subset on
    every run.

    Training samples come from an augmented dataset. Validation samples come
    from a second, non-augmented instance at the same indices, so validation
    loss is not perturbed by random augmentation.

    Args:
        original_dir: Directory of original (unstyled) images.
        styled_dir: Directory of the corresponding styled images.
        style_name: Style identifier forwarded to ``StyleTransferDataset``.
        config: Training configuration. This function reads ``image_size``,
            ``validation_split`` and ``batch_size``, and ``seed`` if present.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If no image pairs are found, or if the split would leave
            the training set empty.
    """
    dataset_kwargs = dict(
        original_dir=original_dir,
        styled_dir=styled_dir,
        style_name=style_name,
        image_size=config.image_size,
        validate_pairs=True,
    )

    # Augmented view of the data, used for training.
    train_base = StyleTransferDataset(augment=True, **dataset_kwargs)
    total = len(train_base)
    if total == 0:
        raise ValueError(
            f"No image pairs found in {original_dir!r} / {styled_dir!r}"
        )

    train_size = int((1 - config.validation_split) * total)
    val_size = total - train_size
    if train_size == 0:
        raise ValueError(
            f"Only {total} image pair(s) found; validation_split="
            f"{config.validation_split} leaves no training samples"
        )

    # Seeded split -> reproducible train/val membership across runs.
    seed = getattr(config, "seed", None)
    generator = torch.Generator().manual_seed(42 if seed is None else seed)
    train_dataset, val_split = torch.utils.data.random_split(
        train_base, [train_size, val_size], generator=generator
    )

    # Validation should measure the model on un-augmented images.
    val_base = StyleTransferDataset(augment=False, **dataset_kwargs)
    if len(val_base) == total:
        # NOTE(review): assumes StyleTransferDataset enumerates pairs in the
        # same order on every construction, so the split indices refer to the
        # same pairs in both instances -- confirm against the dataset class.
        val_dataset = torch.utils.data.Subset(val_base, val_split.indices)
    else:
        logger.warning(
            "Non-augmented dataset size differs from augmented one; "
            "falling back to augmented samples for validation"
        )
        val_dataset = val_split

    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    use_pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=use_pin_memory,
        # Dropping the ragged last batch is only safe when at least one full
        # batch exists; otherwise the loader would yield no batches at all.
        drop_last=train_size >= config.batch_size,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pin_memory,
        drop_last=False,
    )

    return train_loader, val_loader

def enhanced_test_inference(
    pipeline,
    style_name,
    original_dir,
    output_dir,
    num_test_images=3,
    prompt=None,
    strength=0.8,
    guidance_scale=12.0,
    num_inference_steps=75,
):
    """Stylize the first few images in ``original_dir`` and save the results.

    Images are found by extension: ``.jpg``, ``.jpeg`` or ``.png`` in any
    letter case. They are processed in sorted path order, so runs are
    reproducible. The k-th result (1-based) is saved as
    ``test_result_{k}_{stem}.png`` in ``output_dir``, which is created if
    needed. A failure on one image is logged and the remaining images are
    still processed.

    Args:
        pipeline: Object exposing ``stylize_image(image=..., prompt=...,
            strength=..., guidance_scale=..., num_inference_steps=...)`` that
            returns an object with a ``save(path)`` method.
        style_name: Name of the trained style. It is not used to build the
            prompt; pass ``prompt`` to override the default.
        original_dir: Directory containing the input images.
        output_dir: Directory the stylized results are written to.
        num_test_images: Maximum number of images to process. Values <= 0
            process none.
        prompt: Text prompt. ``None`` selects the default Ghibli prompt.
        strength, guidance_scale, num_inference_steps: Forwarded unchanged to
            ``pipeline.stylize_image``.

    Returns:
        None.
    """
    logger.info("Testing inference...")

    if prompt is None:
        prompt = (
            "studio ghibli style, anime artwork, soft watercolor painting, "
            "detailed, masterpiece"
        )

    # Match on the lower-cased suffix so .jpg/.jpeg/.png in any case are found;
    # a '*.[jp][pn]g' glob would miss '.jpeg' files entirely.
    image_suffixes = {".jpg", ".jpeg", ".png"}
    source_dir = Path(original_dir)
    test_files = (
        sorted(
            p for p in source_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        if source_dir.is_dir()
        else []
    )

    if not test_files:
        logger.warning("No test images found in original directory")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Clamp so a non-positive request processes nothing (a negative slice bound
    # would otherwise select images from the end of the list).
    num_tests = max(0, min(num_test_images, len(test_files)))

    for i, test_image_path in enumerate(test_files[:num_tests], start=1):
        logger.info(f"Testing with image {i}/{num_tests}: {test_image_path.name}")

        try:
            styled_result = pipeline.stylize_image(
                image=str(test_image_path),
                prompt=prompt,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            )

            # Save with descriptive filename
            test_output_path = os.path.join(
                output_dir, f"test_result_{i}_{test_image_path.stem}.png"
            )
            styled_result.save(test_output_path)
            logger.info(f"Test result {i} saved to {test_output_path}")

        except Exception as e:
            # Best effort: one bad image must not abort the remaining tests.
            logger.error(f"Error processing test image {test_image_path.name}: {e}")

    logger.info(f"Test inference completed. Results saved in {output_dir}")


In [11]:
def main(
    original_dir: str = "/kaggle/input/stylized-images/stylized_images/original_images1",
    styled_dir: str = "/kaggle/input/stylized-images/stylized_images/gibli_stylized_images1",
    style_name: str = "ghibli",
    output_dir: str = "outputs",
):
    """Train the style-transfer LoRA adapters, checkpoint them, and run test inference.

    Steps:
      1. Build the ``TrainingConfig`` for this run.
      2. Check that both image directories exist. If either is missing, log
         an error and return early.
      3. Build the pipeline and trainer on CUDA with float16 when available,
         otherwise on CPU with float32.
      4. For each epoch: train and validate. Save ``best_model.pt`` whenever
         validation loss improves, and ``checkpoint_epoch_{k}.pt`` every
         ``config.save_every`` epochs.
      5. Save ``final_model.pt``, then stylize a few test images.

    NOTE(review): test inference runs with the last-epoch weights, not the
    best-validation checkpoint.

    Args:
        original_dir: Directory of original images (defaults to the Kaggle input path).
        styled_dir: Directory of styled target images (defaults to the Kaggle input path).
        style_name: Style identifier forwarded to the dataset and inference.
        output_dir: Where checkpoints and test results are written.

    Raises:
        Any exception from pipeline construction, training or inference is
        logged and re-raised.
    """
    # Enhanced configuration
    config = TrainingConfig(
        batch_size=2,
        effective_batch_size=8,
        num_epochs=25,  # Increased for better training
        learning_rate=5e-5,  # Slightly higher learning rate
        save_every=10,
        lora_rank=128,  # Higher rank for more capacity
        lora_alpha=128,
        use_enhanced_loss=True,
        perceptual_loss_weight=0.1
    )

    os.makedirs(output_dir, exist_ok=True)

    # Validate directories
    if not os.path.exists(original_dir) or not os.path.exists(styled_dir):
        logger.error("Please update the directory paths in the main() function!")
        logger.error("Looking for:")
        logger.error(f"  Original images: {original_dir}")
        logger.error(f"  Styled images: {styled_dir}")
        return

    # Half precision only on GPU; CPU kernels need float32.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    logger.info(f"Using device: {device} with dtype: {torch_dtype}")

    try:
        pipeline = DiffusionStyleTransferPipeline(
            config=config,
            device=device,
            torch_dtype=torch_dtype
        )
        trainer = StyleTransferTrainer(pipeline, config)

        train_loader, val_loader = create_dataloaders(
            original_dir, styled_dir, style_name, config
        )
        logger.info(f"Created dataloaders: {len(train_loader)} train, {len(val_loader)} val batches")

        logger.info("Starting training...")
        best_val_loss = float('inf')

        for epoch in range(config.num_epochs):
            avg_train_loss = trainer.train_epoch(train_loader, epoch)
            avg_val_loss = trainer.validate(val_loader)

            logger.info(
                f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}, "
                f"Val Loss = {avg_val_loss:.4f}, "
                f"LR = {trainer.optimizer.param_groups[0]['lr']:.2e}"
            )

            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                trainer.save_model(os.path.join(output_dir, "best_model.pt"), epoch)

            # Save periodic checkpoints
            if (epoch + 1) % config.save_every == 0:
                checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pt")
                trainer.save_model(checkpoint_path, epoch)

        # Store the 0-based index of the last completed epoch, matching the
        # ``epoch`` value recorded by the per-epoch saves above (previously
        # ``num_epochs`` was stored, one past the last index).
        final_path = os.path.join(output_dir, "final_model.pt")
        trainer.save_model(final_path, config.num_epochs - 1)

        logger.info("Training completed successfully!")

        enhanced_test_inference(pipeline, style_name, original_dir, output_dir, num_test_images=3)

    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise

if __name__ == "__main__":
    main()


Epoch 0: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Total=0.0622, Diffusion=0.0622, Perceptual=0.0000, LR=5.00e-05]
Epoch 1: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it, Total=0.0195, Diffusion=0.0195, Perceptual=0.0000, LR=4.98e-05]
Epoch 2: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Total=0.0192, Diffusion=0.0192, Perceptual=0.0000, LR=4.93e-05]
Epoch 3: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Total=0.0680, Diffusion=0.0680, Perceptual=0.0000, LR=4.84e-05]
Epoch 4: 100%|██████████| 3/3 [00:04<00:00,  1.47s/it, Total=0.1955, Diffusion=0.1955, Perceptual=0.0000, LR=4.72e-05]
Epoch 5: 100%|██████████| 3/3 [00:04<00:00,  1.50s/it, Total=0.0808, Diffusion=0.0808, Perceptual=0.0000, LR=4.57e-05]
Epoch 6: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Total=0.1037, Diffusion=0.1037, Perceptual=0.0000, LR=4.39e-05]
Epoch 7: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it, Total=0.0880, Diffusion=0.0880, Perceptual=0.0000, LR=4.18e-05]
Epoch 8: 100%|██████████| 3/3 [00:04<00:00,  1.3

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]