# Emoji Generator: Complete Pipeline
## Training, Evaluation & Inference with Stable Diffusion

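This notebook provides a complete end-to-end pipeline for training, evaluating, and using a Stable Diffusion–style latent diffusion model for emoji generation: a small UNET denoiser is trained from scratch on top of a frozen, pretrained CLIP text encoder and VAE. It is designed to run in Google Colab and Kaggle environments.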

### Features:
- Complete training pipeline without MLflow
- Direct configuration (no YAML files)
- Evaluation metrics and visualization
- Inference and image generation
- Google Colab & Kaggle compatible


## 1. Environment Setup & Installation


In [None]:
# Install required packages
# (torch wheels come from the CUDA 12.1 index below; change the URL for other CUDA versions)
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install transformers diffusers accelerate
%pip install pillow matplotlib tqdm pandas numpy
%pip install ipywidgets

# For Colab: Mount Google Drive if needed
# from google.colab import drive
# drive.mount('/content/drive')

import sys
import os
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# which can hide useful library/device diagnostics; narrow it when debugging.
warnings.filterwarnings('ignore')

print("Environment setup complete!")


## 2. Configuration & Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import GradScaler, autocast
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random
from typing import List, Optional, Tuple, Any, Dict
import os
from pathlib import Path

# Transformers and diffusers
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL

print("All imports successful!")


In [None]:
# =============================================================================
# CONFIGURATION - replaces the YAML config files
# =============================================================================

# Model Configuration
MODEL_CONFIG = {
    "stable_diffusion": {
        "h_dim": 384,  # UNET channel width; must divide by 32 (GroupNorm) and by n_head
        "n_head": 8,  # attention heads per UNETAttentionBlock
        "time_dim": 1280,  # NOTE(review): not read by the trainer; StableDiffusion fixes it at 4 * 320
        "num_train_timesteps": 1000,  # length of the DDPM noise schedule
        "beta_start": 0.00085,  # endpoints of the scaled-linear beta schedule
        "beta_end": 0.012
    },
    "clip": {
        "low_cpu_mem_usage": True,  # NOTE(review): not read; CLIPTextEncoder sets this itself
        # presumably must keep a 512-d text width to match UNETAttentionBlock's
        # default context_dim=512 — confirm before swapping models
        "model_id": "openai/clip-vit-base-patch32",
        "max_length": 77  # NOTE(review): not passed by the trainer; the encoder uses 77 on its own
    },
    "vae": {
        "low_cpu_mem_usage": True,  # NOTE(review): not read; VAEEncoder sets this itself
        "model_id": "stabilityai/sd-vae-ft-mse",
        "scaling_factor": 0.18215  # latents are multiplied by this after encoding, divided before decoding
    }
}

# Training Configuration
TRAINING_CONFIG = {
    "epochs": 50,  # Reduced for notebook demo
    "batch_size": 4,  # Reduced for memory constraints
    "learning_rate": 1e-4,
    "eta_min": 1e-5,  # floor of the cosine LR schedule
    "gradient_accumulation_steps": 1,  # NOTE(review): not read; the optimizer steps every batch
    "mixed_precision": True,  # fp16 autocast + GradScaler during training
    "weight_decay": 0.01,
    "save_every": 10,  # checkpoint every N epochs
    "log_every": 5  # NOTE(review): not read by StableDiffusionTrainer
}

# Data Configuration
DATA_CONFIG = {
    "data_dirs": [
        "datasets/blobs_crawled_data",
        "datasets/pepe_crawled_data"
    ],
    "image_size": [32, 32],
    "latent_size": [4, 4],  # NOTE(review): not read; presumably image_size / 8 (VAE downsampling)
    "train_split": 0.9,
    "val_split": 0.1,  # NOTE(review): not read; validation gets whatever train_split leaves
    "num_workers": 2,
    "pin_memory": True,
    "persistent_workers": False  # Set to False for notebooks
}

# Inference Configuration
INFERENCE_CONFIG = {
    "num_inference_steps": 25,  # Reduced for faster inference
    "guidance_scale": 7.5,  # classifier-free guidance strength
    "seed": 42
}

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG the notebook touches so runs are repeatable
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also switches cuDNN to deterministic kernels with autotuning disabled,
    and exports ``PYTHONHASHSEED`` (which only affects interpreters started
    afterwards, e.g. spawned worker processes).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)
print("Configuration complete!")


In [None]:
# =============================================================================
# SCHEDULER COMPONENT
# =============================================================================

def embed_timesteps(timesteps: torch.Tensor, embedding_dim: int = 320) -> torch.Tensor:
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: 1-D tensor of timestep indices, shape ``(B,)``. A 0-d tensor
            or a plain Python int is also accepted and treated as a batch of one.
        embedding_dim: Width of the embedding. The first half holds cosine
            features, the second half sine features.

    Returns:
        float32 tensor of shape ``(B, 2 * (embedding_dim // 2))`` on the same
        device as ``timesteps``.
    """
    # Promote scalars (e.g. a single timestep taken from scheduler.timesteps)
    # to a batch of one; `timesteps[:, None]` below needs at least one dim.
    timesteps = torch.atleast_1d(torch.as_tensor(timesteps))
    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down towards 1/10000.
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class DDPMScheduler:
    """DDPM scheduler: forward noising (``add_noise``) and ancestral sampling (``step``).

    Betas follow the "scaled linear" schedule: linearly spaced in sqrt-space
    between ``beta_start`` and ``beta_end``, then squared.

    Args:
        random_generator: ``torch.Generator`` used for every noise draw. It must
            live on the same device as the tensors passed to ``step``/``add_noise``.
        train_timesteps: Number of diffusion steps in the training schedule.
        beta_start: First beta of the schedule.
        beta_end: Last beta of the schedule.
    """
    def __init__(
        self,
        random_generator: torch.Generator,
        train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
    ):
        self.betas = (
            torch.linspace(
                beta_start**0.5, beta_end**0.5, train_timesteps, dtype=torch.float32
            ) ** 2
        )
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one_val = torch.tensor(1.0)
        self.generator = random_generator
        self.total_train_timesteps = train_timesteps
        # Default to the full training schedule so step()/_get_variance() work
        # even before set_steps() is called (previously an AttributeError).
        self.num_inference_steps = train_timesteps
        self.timesteps = torch.from_numpy(np.arange(0, train_timesteps)[::-1].copy())

    def set_steps(self, num_inference_steps: int = 50):
        """Use ``num_inference_steps`` evenly spaced timesteps, in descending order.

        Raises:
            ValueError: if ``num_inference_steps`` is outside ``[1, train_timesteps]``
                (0 divided by zero; larger values made every timestep 0).
        """
        if not 1 <= num_inference_steps <= self.total_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be between 1 and "
                f"{self.total_train_timesteps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.total_train_timesteps // num_inference_steps
        timesteps = np.arange(num_inference_steps, dtype=np.int64)[::-1] * step_ratio
        self.timesteps = torch.from_numpy(timesteps.copy())

    def _get_prev_timestep(self, timestep: int) -> int:
        # Previous entry of the (evenly spaced) inference schedule; negative at the end.
        return timestep - self.total_train_timesteps // self.num_inference_steps

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Posterior variance of x_{prev} given x_t and x_0 (clamped away from 0)."""
        prev_t = self._get_prev_timestep(timestep)
        alpha_cumprod_t = self.alphas_cumprod[timestep]
        alpha_cumprod_t_prev = (
            self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one_val
        )
        beta_t = 1 - alpha_cumprod_t / alpha_cumprod_t_prev
        variance = (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * beta_t
        return torch.clamp(variance, min=1e-20)

    def step(
        self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor
    ) -> torch.Tensor:
        """One reverse DDPM step from ``timestep`` to the previous scheduled timestep.

        Args:
            timestep: Current timestep (int, or a 0-d / single-element tensor).
            latents: Current noisy sample x_t.
            model_output: Predicted noise for x_t.

        Returns:
            The sample at the previous timestep. Gaussian noise is added unless
            ``timestep == 0``.
        """
        t = int(timestep)  # accept tensor timesteps taken from self.timesteps
        prev_t = self._get_prev_timestep(t)

        alpha_cumprod_t = self.alphas_cumprod[t]
        alpha_cumprod_t_prev = (
            self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one_val
        )
        beta_cumprod_t = 1 - alpha_cumprod_t
        beta_cumprod_t_prev = 1 - alpha_cumprod_t_prev
        alpha_t = alpha_cumprod_t / alpha_cumprod_t_prev
        beta_t = 1 - alpha_t

        # Predict original sample x_0 from the noise prediction
        pred_original_sample = (
            latents - beta_cumprod_t**0.5 * model_output
        ) / alpha_cumprod_t**0.5

        # Coefficients of the posterior mean q(x_prev | x_t, x_0)
        pred_original_sample_coeff = (
            alpha_cumprod_t_prev**0.5 * beta_t
        ) / beta_cumprod_t
        current_sample_coeff = alpha_t**0.5 * beta_cumprod_t_prev / beta_cumprod_t

        pred_prev_sample = (
            pred_original_sample_coeff * pred_original_sample
            + current_sample_coeff * latents
        )

        # Add posterior noise except on the final step
        variance = 0
        if t > 0:
            noise = torch.randn(
                model_output.shape,
                generator=self.generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            variance = (self._get_variance(t) ** 0.5) * noise

        return pred_prev_sample + variance

    def add_noise(
        self, original_samples: torch.Tensor, timesteps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Noise ``original_samples`` to per-sample ``timesteps`` (shape ``(B,)``).

        Returns:
            ``(noisy_samples, noise)``. ``noise`` is the standard-normal draw that
            was mixed in, i.e. the training target for noise prediction.
        """
        alphas_cumprod = self.alphas_cumprod.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        timesteps = timesteps.to(original_samples.device)

        # Reshape per-sample coefficients to (B, 1, 1, ...) for broadcasting
        broadcast_shape = (timesteps.shape[0], *([1] * (original_samples.ndim - 1)))
        sqrt_alpha_cumprod = (alphas_cumprod[timesteps] ** 0.5).view(broadcast_shape)
        sqrt_one_minus_alpha_cumprod = (
            (1 - alphas_cumprod[timesteps]) ** 0.5
        ).view(broadcast_shape)

        noise = torch.randn(
            original_samples.shape,
            generator=self.generator,
            device=original_samples.device,
            dtype=original_samples.dtype,
        )

        noisy_samples = (
            sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_cumprod * noise
        )
        return noisy_samples, noise

print("Scheduler components defined!")


In [None]:
# =============================================================================
# ATTENTION COMPONENTS
# =============================================================================

class SelfAttention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Args:
        n_heads: Number of heads; ``d_embed`` is split evenly between them.
        d_embed: Width of the input and output features.
        in_proj_bias: Whether the fused QKV projection carries a bias.
    """
    def __init__(self, n_heads: int, d_embed: int, in_proj_bias: bool = True):
        super().__init__()
        # Creation order (in_proj, then out_proj) is kept so seeded init and
        # state_dict keys are unchanged.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        # (B, L, D) -> (B, heads, L, d_head)
        return t.view(batch, seq, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor, causal_mask: bool = False) -> torch.Tensor:
        """Attend ``x`` of shape ``(B, L, d_embed)`` to itself.

        With ``causal_mask=True``, position i attends only to positions <= i.
        """
        batch, seq, _ = x.shape
        queries, keys, values = (
            self._split_heads(part, batch, seq)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        scores = queries @ keys.transpose(-1, -2)
        if causal_mask:
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, -torch.inf)

        probs = F.softmax(scores / math.sqrt(self.d_head), dim=-1)
        mixed = (probs @ values).transpose(1, 2).reshape(x.shape)
        return self.out_proj(mixed)


class CrossAttention(nn.Module):
    """Multi-head attention from queries ``x`` to a conditioning sequence ``y``.

    Args:
        n_heads: Number of heads; ``d_embed`` is split evenly between them.
        d_embed: Width of ``x`` and of the output.
        d_cross: Width of the conditioning features ``y``.
        in_proj_bias: Whether the Q/K/V projections carry biases.
    """
    def __init__(self, n_heads: int, d_embed: int, d_cross: int, in_proj_bias: bool = True):
        super().__init__()
        # Same creation order as before (q, k, v, out) for identical seeded init.
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Attend ``x`` (``(B, Lq, d_embed)``) over ``y`` (``(B, Lk, d_cross)``)."""
        batch = x.shape[0]

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, d_embed) -> (B, heads, L, d_head); L may differ for x and y
            return t.view(batch, -1, self.n_heads, self.d_head).transpose(1, 2)

        queries = to_heads(self.q_proj(x))
        keys = to_heads(self.k_proj(y))
        values = to_heads(self.v_proj(y))

        probs = F.softmax((queries @ keys.transpose(-1, -2)) / math.sqrt(self.d_head), dim=-1)
        mixed = (probs @ values).transpose(1, 2).reshape(x.shape)
        return self.out_proj(mixed)

print("Attention components defined!")


In [None]:
# =============================================================================
# UNET COMPONENTS
# =============================================================================

class TimeEmbedding(nn.Module):
    """Linear -> SiLU -> Linear MLP that widens a timestep embedding 4x.

    Maps ``(..., n_embd)`` to ``(..., 4 * n_embd)``.
    """
    def __init__(self, n_embd: int):
        super().__init__()
        widened = 4 * n_embd
        self.proj1 = nn.Linear(n_embd, widened)
        self.proj2 = nn.Linear(widened, widened)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj2(F.silu(self.proj1(x)))


class UNETResidualBlock(nn.Module):
    """Time-conditioned residual block.

    Computes ``conv(SiLU(GN(conv(SiLU(GN(x))) + time_bias))) + skip(x)``, where
    ``time_bias`` is a per-channel projection of ``SiLU(time_emb)``.

    Args:
        in_channels: Input channels (must be divisible by 32 for GroupNorm).
        out_channels: Output channels (must be divisible by 32 for GroupNorm).
        time_dim: Width of the time embedding.
    """
    def __init__(self, in_channels: int, out_channels: int, time_dim: int = 1280):
        super().__init__()
        self.gn_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.time_embedding_proj = nn.Linear(time_dim, out_channels)

        self.gn_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # A 1x1 projection is only needed when the channel count changes.
        self.residual_connection = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, input_feature: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # Per-channel time bias, broadcast over the spatial dims.
        time_bias = self.time_embedding_proj(F.silu(time_emb))[..., None, None]

        features = self.conv_feature(F.silu(self.gn_feature(input_feature)))
        features = self.conv_merged(F.silu(self.gn_merged(features + time_bias)))

        return features + self.residual_connection(input_feature)


class UNETAttentionBlock(nn.Module):
    """Transformer block applied to a feature map, with text cross-attention.

    The map is flattened to one token per spatial position. It then goes through
    pre-LayerNorm self-attention, cross-attention to ``context_tensor``, and a
    GEGLU feed-forward, each with a residual connection, before being reshaped
    back. An outer residual wraps the whole block.

    Args:
        num_heads: Attention heads.
        head_dim: Width per head; channels = ``num_heads * head_dim`` (must be
            divisible by 32 for the input GroupNorm).
        context_dim: Width of the conditioning features fed to cross-attention.
    """
    def __init__(self, num_heads: int, head_dim: int, context_dim: int = 512):
        super().__init__()
        embed_dim = num_heads * head_dim

        self.gn_in = nn.GroupNorm(32, embed_dim, eps=1e-6)
        self.proj_in = nn.Conv2d(embed_dim, embed_dim, kernel_size=1, padding=0)

        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attn_1 = SelfAttention(num_heads, embed_dim, in_proj_bias=False)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn_2 = CrossAttention(num_heads, embed_dim, context_dim, in_proj_bias=False)
        self.ln_3 = nn.LayerNorm(embed_dim)

        # GEGLU: one projection yields both the value half and the gate half.
        self.ffn_geglu = nn.Linear(embed_dim, 4 * embed_dim * 2)
        self.ffn_out = nn.Linear(4 * embed_dim, embed_dim)
        self.proj_out = nn.Conv2d(embed_dim, embed_dim, kernel_size=1, padding=0)

    def forward(self, input_tensor: torch.Tensor, context_tensor: torch.Tensor) -> torch.Tensor:
        batch, channels, height, width = input_tensor.shape

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position
        tokens = self.proj_in(self.gn_in(input_tensor))
        tokens = tokens.view(batch, channels, height * width).transpose(-1, -2)

        tokens = self.attn_1(self.ln_1(tokens)) + tokens
        tokens = self.attn_2(self.ln_2(tokens), context_tensor) + tokens

        value, gate = self.ffn_geglu(self.ln_3(tokens)).chunk(2, dim=-1)
        tokens = self.ffn_out(value * F.gelu(gate)) + tokens

        spatial = tokens.transpose(-1, -2).view(batch, channels, height, width)
        return self.proj_out(spatial) + input_tensor

print("UNET components defined!")


In [None]:
# =============================================================================
# COMPLETE MODELS
# =============================================================================

class UNET(nn.Module):
    """Simplified single-resolution UNET used as the denoiser backbone.

    The encoder has three blocks (residual, attention, residual), followed by an
    attention middle block and three mirrored decoder blocks. There is no
    down/up-sampling: every block keeps ``h_dim`` channels and the input's
    spatial size. Each decoder block first adds the matching encoder output
    (last in, first out).
    """
    def __init__(self, h_dim: int = 384, n_head: int = 8):
        super().__init__()
        head_dim = h_dim // n_head

        # 4 latent channels -> h_dim feature channels
        self.input_conv = nn.Conv2d(4, h_dim, kernel_size=3, padding=1)

        self.encoder_blocks = nn.ModuleList([
            UNETResidualBlock(h_dim, h_dim),
            UNETAttentionBlock(n_head, head_dim),
            UNETResidualBlock(h_dim, h_dim),
        ])

        self.middle_block = UNETAttentionBlock(n_head, head_dim)

        self.decoder_blocks = nn.ModuleList([
            UNETResidualBlock(h_dim, h_dim),
            UNETAttentionBlock(n_head, head_dim),
            UNETResidualBlock(h_dim, h_dim),
        ])

    @staticmethod
    def _run_block(block: nn.Module, h: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """Residual blocks are conditioned on time, attention blocks on text."""
        if isinstance(block, UNETResidualBlock):
            return block(h, time_emb)
        return block(h, context)

    def forward(self, latent: torch.Tensor, context: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        h = self.input_conv(latent)

        encoder_outputs = []
        for block in self.encoder_blocks:
            h = self._run_block(block, h, context, time_emb)
            encoder_outputs.append(h)

        h = self.middle_block(h, context)

        # Pair decoder blocks with encoder outputs in reverse order (stack pop order).
        for block, skip in zip(self.decoder_blocks, reversed(encoder_outputs)):
            h = self._run_block(block, h + skip, context, time_emb)

        return h


class UNETOutputLayer(nn.Module):
    """Output head: GroupNorm -> SiLU -> 3x3 conv from ``h_dim`` to ``output_channels``."""
    def __init__(self, h_dim: int, output_channels: int):
        super().__init__()
        self.gn = nn.GroupNorm(32, h_dim)
        self.conv = nn.Conv2d(h_dim, output_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gn(x))
        return self.conv(activated)


class StableDiffusion(nn.Module):
    """Noise-prediction network: timestep embedding + UNET + 4-channel output head.

    ``forward(latent, context, timestep)`` returns the predicted noise for
    ``latent`` at ``timestep``, conditioned on the text features ``context``.
    """
    def __init__(self, h_dim: int = 384, n_head: int = 8):
        super().__init__()
        # 320-d sinusoidal features are widened 4x to 1280, matching the
        # residual blocks' default time_dim.
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET(h_dim, n_head)
        self.unet_output = UNETOutputLayer(h_dim, 4)

    def forward(self, latent: torch.Tensor, context: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        sinusoidal = embed_timesteps(timestep).to(latent.device)
        features = self.unet(latent, context, self.time_embedding(sinusoidal))
        return self.unet_output(features)


class CLIPTextEncoder(nn.Module):
    """Frozen CLIP text encoder that turns prompts into conditioning features.

    Args:
        model_id: Hugging Face model id for both the tokenizer and text model.
        device: Device the text model is moved to and token ids are sent to.
        max_length: Token length prompts are padded/truncated to. It defaults
            to 77, the previously hard-coded value; this keyword is new and optional.
    """
    def __init__(self, model_id: str = "openai/clip-vit-base-patch32", device: str = "cuda", max_length: int = 77):
        super().__init__()
        # low_cpu_mem_usage is a model-loading option; tokenizers don't accept it,
        # so it is only passed to the model.
        self.tokenizer = CLIPTokenizer.from_pretrained(model_id)
        self.text_encoder = CLIPTextModel.from_pretrained(model_id, low_cpu_mem_usage=True)
        self.device = device
        self.max_length = max_length

        # Frozen: no gradients, eval mode, living on the target device.
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()
        self.text_encoder.to(device)

    def forward(self, prompts: List[str]) -> torch.Tensor:
        """Encode ``prompts`` and return CLIP's last hidden state.

        The result has one row per prompt and ``self.max_length`` token positions.
        """
        inputs = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = inputs.input_ids.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        with torch.no_grad():
            text_encoder_output = self.text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )

        return text_encoder_output.last_hidden_state


class VAEEncoder(nn.Module):
    """Frozen VAE wrapper mapping images to scaled latents and back.

    ``encode`` samples the latent posterior and multiplies by ``scaling_factor``.
    ``decode`` divides by it again before decoding. Neither tracks gradients.
    """
    def __init__(self, model_id: str = "stabilityai/sd-vae-ft-mse", scaling_factor: float = 0.18215):
        super().__init__()
        self.vae = AutoencoderKL.from_pretrained(model_id, low_cpu_mem_usage=True)
        self.scaling_factor = scaling_factor

        # Frozen: no gradients and eval mode.
        self.vae.requires_grad_(False)
        self.vae.eval()

    @torch.no_grad()
    def encode(self, images: torch.Tensor) -> torch.Tensor:
        posterior = self.vae.encode(images).latent_dist
        return posterior.sample() * self.scaling_factor

    @torch.no_grad()
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        unscaled = latents / self.scaling_factor
        return self.vae.decode(unscaled).sample

print("All model components defined!")


In [None]:
# =============================================================================
# DATA LOADING
# =============================================================================

class EmojiDataset(Dataset):
    """Emoji images paired with text prompts, gathered from one or more data directories.

    Each directory is expected to hold ``metadata.csv`` (with at least
    ``file_name`` and ``prompt`` columns) and an ``images/`` folder that
    ``file_name`` is relative to. Directories without ``metadata.csv`` are
    skipped. If none are found, a 100-row dummy dataset is used so the notebook
    still runs end to end.

    Rows whose image file is missing are served as random-noise images, as
    before. The number of such rows is now reported once at construction
    instead of going unnoticed.
    """
    def __init__(self, data_dirs: List[str], transform: Optional[transforms.Compose] = None):
        dataframes = []
        for data_dir in data_dirs:
            csv_file = os.path.join(data_dir, "metadata.csv")
            if not os.path.exists(csv_file):
                continue
            image_folder = os.path.join(data_dir, "images")
            df = pd.read_csv(csv_file)
            # Normalise Windows separators; regex=False so the backslash is
            # matched literally regardless of pandas version.
            df["image_path"] = df["file_name"].astype(str).str.replace("\\", "/", regex=False)
            df["full_image_path"] = [os.path.join(image_folder, p) for p in df["image_path"]]
            dataframes.append(df)

        if dataframes:
            self.dataframe = pd.concat(dataframes, ignore_index=True)
        else:
            # Create dummy data for demo purposes
            print("Warning: No data found, creating dummy dataset")
            self.dataframe = pd.DataFrame({
                'prompt': ['happy emoji', 'sad emoji', 'angry emoji', 'surprised emoji'] * 25,
                'full_image_path': ['dummy_path'] * 100
            })

        self.transform = transform
        # Empty CSV cells load as NaN (float); map them to "" so __getitem__
        # can always call str methods on the prompt.
        self.prompts = self.dataframe["prompt"].fillna("").astype(str).tolist()
        self.image_paths = self.dataframe["full_image_path"].tolist()

        missing = sum(
            1 for path in self.image_paths
            if path != 'dummy_path' and not os.path.exists(path)
        )
        if missing:
            print(
                f"Warning: {missing} of {len(self.image_paths)} image files not found; "
                "random-noise images will be used for them"
            )

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        image_path = self.image_paths[idx]
        prompt = self.prompts[idx].replace('"', "").replace("'", "")

        # Handle dummy data or missing images
        if image_path == 'dummy_path' or not os.path.exists(image_path):
            # Random RGB image for demo; randint's upper bound is exclusive,
            # so 256 covers the full uint8 range 0..255.
            image = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
        else:
            # Context manager closes the file handle; convert() returns an
            # independent, fully loaded image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, prompt


def get_transforms(image_size: Tuple[int, int] = (32, 32)) -> transforms.Compose:
    """Build the preprocessing pipeline: bicubic resize -> tensor -> scale to [-1, 1]."""
    pipeline = [
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range onto [-1, 1]
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(pipeline)


def create_data_loaders(data_config: Dict[str, Any], training_config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader]:
    """Build train/validation DataLoaders over an ``EmojiDataset``.

    The dataset is split by ``data_config["train_split"]`` with a fixed seed
    (42), so the split is identical across runs. Validation gets the remainder.

    Worker/pinning options are sanitised:
    - ``persistent_workers`` is only honoured when ``num_workers > 0``, since
      DataLoader rejects the combination otherwise.
    - ``pin_memory`` is only honoured when CUDA is available.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    transform = get_transforms(tuple(data_config["image_size"]))

    dataset = EmojiDataset(
        data_dirs=data_config["data_dirs"],
        transform=transform,
    )

    train_size = int(data_config["train_split"] * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )

    num_workers = data_config["num_workers"]
    loader_kwargs = dict(
        batch_size=training_config["batch_size"],
        num_workers=num_workers,
        pin_memory=bool(data_config["pin_memory"]) and torch.cuda.is_available(),
        persistent_workers=bool(data_config["persistent_workers"]) and num_workers > 0,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader


# Create data loaders (module-level globals train_loader / val_loader are
# used by later cells)
print("Creating data loaders...")
train_loader, val_loader = create_data_loaders(DATA_CONFIG, TRAINING_CONFIG)
print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Test data loading: fetch one training batch and show its shape and the first
# two prompts. Any failure is printed rather than raised, so the rest of the
# notebook can still run; note this also defines the globals sample_batch,
# images and prompts.
try:
    sample_batch = next(iter(train_loader))
    images, prompts = sample_batch
    print(f"Sample batch shape: {images.shape}")
    print(f"Sample prompts: {prompts[:2]}")
except Exception as e:
    print(f"Data loading test failed: {e}")

print("Data loading setup complete!")


In [None]:
# =============================================================================
# TRAINING PIPELINE
# =============================================================================

class StableDiffusionTrainer:
    """Trainer class for Stable Diffusion model"""
    def __init__(self, model_config: Dict[str, Any], training_config: Dict[str, Any], device: str = "cuda"):
        self.model_config = model_config
        self.training_config = training_config
        self.device = device
        
        # Initialize models
        self.diffusion_model = StableDiffusion(
            h_dim=model_config["stable_diffusion"]["h_dim"],
            n_head=model_config["stable_diffusion"]["n_head"],
        ).to(device)
        
        self.text_encoder = CLIPTextEncoder(
            model_id=model_config["clip"]["model_id"], 
            device=device
        )
        
        self.vae_encoder = VAEEncoder(
            model_id=model_config["vae"]["model_id"],
            scaling_factor=model_config["vae"]["scaling_factor"]
        ).to(device)
        
        # Initialize scheduler
        self.generator = torch.Generator(device=device)
        self.scheduler = DDPMScheduler(
            random_generator=self.generator,
            train_timesteps=model_config["stable_diffusion"]["num_train_timesteps"],
            beta_start=model_config["stable_diffusion"]["beta_start"],
            beta_end=model_config["stable_diffusion"]["beta_end"],
        )
        
        # Initialize optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            self.diffusion_model.parameters(),
            lr=training_config["learning_rate"],
            weight_decay=training_config["weight_decay"],
        )
        
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=training_config["epochs"],
            eta_min=training_config["eta_min"],
        )
        
        # Mixed precision training
        self.scaler = GradScaler() if training_config["mixed_precision"] else None
        
        # Loss function
        self.criterion = nn.MSELoss()
        
        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float("inf")
        self.train_losses = []
        self.val_losses = []

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train for one epoch"""
        self.diffusion_model.train()
        epoch_loss = 0.0
        num_batches = len(train_loader)
        
        progress_bar = tqdm(
            train_loader,
            desc=f"Epoch {self.current_epoch + 1}/{self.training_config['epochs']}",
        )
        
        for batch_idx, (images, prompts) in enumerate(progress_bar):
            images = images.to(self.device)
            
            # Encode images to latent space
            with torch.no_grad():
                latents = self.vae_encoder.encode(images)
            
            # Sample random timesteps
            timesteps = torch.randint(
                0,
                self.scheduler.total_train_timesteps,
                (latents.shape[0],),
                device=self.device,
            )
            
            # Add noise to latents
            noisy_latents, noise = self.scheduler.add_noise(latents, timesteps)
            
            # Encode text
            with torch.no_grad():
                text_embeddings = self.text_encoder(prompts)
            
            # Forward pass with mixed precision
            if self.scaler:
                with autocast("cuda", dtype=torch.float16):
                    noise_pred = self.diffusion_model(
                        noisy_latents, text_embeddings, timesteps
                    )
                    loss = self.criterion(noise_pred, noise)
            else:
                noise_pred = self.diffusion_model(
                    noisy_latents, text_embeddings, timesteps
                )
                loss = self.criterion(noise_pred, noise)
            
            # Backward pass
            self.optimizer.zero_grad()
            
            if self.scaler:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()
            
            # Update progress
            batch_loss = loss.item()
            epoch_loss += batch_loss
            self.global_step += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                "loss": f"{batch_loss:.5f}",
                "lr": f"{self.optimizer.param_groups[0]['lr']:.6f}",
            })
        
        return epoch_loss / num_batches

    def validate(self, val_loader: DataLoader) -> float:
        """Validate the model"""
        self.diffusion_model.eval()
        val_loss = 0.0
        num_batches = len(val_loader)
        
        with torch.no_grad():
            for images, prompts in tqdm(val_loader, desc="Validation"):
                images = images.to(self.device)
                
                # Encode images to latent space
                latents = self.vae_encoder.encode(images)
                
                # Sample random timesteps
                timesteps = torch.randint(
                    0,
                    self.scheduler.total_train_timesteps,
                    (latents.shape[0],),
                    device=self.device,
                )
                
                # Add noise to latents
                noisy_latents, noise = self.scheduler.add_noise(latents, timesteps)
                
                # Encode text
                text_embeddings = self.text_encoder(prompts)
                
                # Forward pass
                noise_pred = self.diffusion_model(
                    noisy_latents, text_embeddings, timesteps
                )
                loss = self.criterion(noise_pred, noise)
                
                val_loss += loss.item()
        
        return val_loss / num_batches

    def save_checkpoint(self, filepath: str, is_best: bool = False):
        """Save model checkpoint"""
        checkpoint = {
            "epoch": self.current_epoch,
            "model_state_dict": self.diffusion_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
            "best_loss": self.best_loss,
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "model_config": self.model_config,
            "training_config": self.training_config,
        }
        
        if self.scaler:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()
        
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        torch.save(checkpoint, filepath)
        
        if is_best:
            best_path = filepath.replace('.pt', '_best.pt')
            torch.save(checkpoint, best_path)

    def train(self, train_loader: DataLoader, val_loader: Optional[DataLoader] = None):
        """Complete training loop"""
        print(f"Starting training for {self.training_config['epochs']} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.diffusion_model.parameters() if p.requires_grad)}")
        
        for epoch in range(self.current_epoch, self.training_config["epochs"]):
            self.current_epoch = epoch
            
            # Train epoch
            train_loss = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)
            
            # Validation
            if val_loader and epoch % 5 == 0:  # Validate every 5 epochs
                val_loss = self.validate(val_loader)
                self.val_losses.append(val_loss)
                
                # Save best model
                if val_loss < self.best_loss:
                    self.best_loss = val_loss
                    self.save_checkpoint(f"checkpoints/emoji_sd_epoch_{epoch}_best.pt", is_best=True)
                
                print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.5f}, Val Loss: {val_loss:.5f}")
            else:
                print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.5f}")
            
            # Save checkpoint
            if epoch % self.training_config["save_every"] == 0:
                self.save_checkpoint(f"checkpoints/emoji_sd_epoch_{epoch}.pt")
            
            # Update learning rate
            self.lr_scheduler.step()
        
        # Save final model
        self.save_checkpoint(f"checkpoints/emoji_sd_final.pt")
        print("Training completed!")

print("Training pipeline defined!")


In [None]:
# =============================================================================
# INFERENCE PIPELINE
# =============================================================================

def rescale_tensor(tensor: torch.Tensor, from_range: Tuple[float, float], to_range: Tuple[float, float], clamp: bool = False) -> torch.Tensor:
    """Linearly map ``tensor`` from ``from_range`` onto ``to_range``.

    Computes ``(tensor - lo_in) / (hi_in - lo_in) * (hi_out - lo_out) + lo_out``.
    When ``clamp`` is true, the result is also clipped to ``to_range``.
    A new tensor is returned; the input is not modified in place.
    """
    src_lo, src_hi = from_range
    dst_lo, dst_hi = to_range

    # Normalise to the unit interval first, then stretch/shift into the target range.
    unit = (tensor - src_lo) / (src_hi - src_lo)
    mapped = unit * (dst_hi - dst_lo) + dst_lo

    return torch.clamp(mapped, min=dst_lo, max=dst_hi) if clamp else mapped


class EmojiGenerator:
    """Text-to-emoji sampler built on a trained latent diffusion model.

    Bundles the denoising network, the text encoder, the VAE and the noise
    scheduler. It runs classifier-free-guided sampling in latent space and
    decodes the result to PIL images.
    """

    def __init__(
        self,
        diffusion_model: StableDiffusion,
        text_encoder: CLIPTextEncoder,
        vae_encoder: VAEEncoder,
        scheduler: DDPMScheduler,
        device: str = "cuda",
    ):
        # NOTE(review): .eval() switches the passed module to eval mode in place.
        # When the model is shared with a trainer, confirm the training loop calls
        # .train() again before resuming.
        self.diffusion_model = diffusion_model.eval()
        self.text_encoder = text_encoder
        self.vae_encoder = vae_encoder
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        height: int = 32,
        width: int = 32,
        seed: Optional[int] = None,
        batch_size: int = 1,
    ) -> List[Image.Image]:
        """Generate ``batch_size`` emoji images for ``prompt``.

        Args:
            prompt: Text description of the emoji.
            num_inference_steps: Number of scheduler timesteps to run.
            guidance_scale: Classifier-free guidance weight. The guided prediction is
                ``uncond + guidance_scale * (cond - uncond)``.
            height: Output height in pixels; the latent grid height is ``height // 8``.
            width: Output width in pixels; the latent grid width is ``width // 8``.
            seed: Seed for the initial latent noise. ``None`` draws a fresh
                non-deterministic seed, so repeated calls give different images.
            batch_size: Number of images to generate for the prompt.

        Returns:
            A list of ``batch_size`` PIL images.
        """
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)
        else:
            # A newly constructed torch.Generator always starts from the same
            # default seed. Without reseeding, every "random" call would sample
            # identical initial latents.
            generator.seed()

        prompts = [prompt] * batch_size
        text_embeddings = self.text_encoder(prompts)

        # Empty prompts give the unconditional branch for classifier-free guidance.
        uncond_embeddings = self.text_encoder([""] * batch_size)

        # Unconditional first, conditional second; matches the chunk(2) split below.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # NOTE(review): this mutates the scheduler, which may be shared with a trainer.
        self.scheduler.set_steps(num_inference_steps)

        # 4 latent channels at 1/8 spatial resolution.
        latents_shape = (batch_size, 4, height // 8, width // 8)
        latents = torch.randn(latents_shape, generator=generator, device=self.device)

        for t in tqdm(self.scheduler.timesteps, desc="Generating"):
            # Duplicate latents so one forward pass covers both guidance branches.
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.diffusion_model(latent_model_input, text_embeddings, t)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            latents = self.scheduler.step(t, latents, noise_pred)

        images = self.vae_encoder.decode(latents)

        # Map decoder output from [-1, 1] to [0, 255]. Round instead of truncating
        # so each value lands on the nearest 8-bit level. Then NCHW -> NHWC for PIL.
        images = rescale_tensor(images, (-1, 1), (0, 255), clamp=True)
        images = images.round().permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

        return [Image.fromarray(img) for img in images]

    @classmethod
    def from_checkpoint(cls, checkpoint_path: str, device: str = "cuda") -> "EmojiGenerator":
        """Rebuild a generator from a saved checkpoint.

        Expects the checkpoint to be a dict holding ``model_config`` (with
        ``stable_diffusion``, ``clip`` and ``vae`` sections) and
        ``model_state_dict`` for the diffusion model.

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Device the checkpoint is mapped onto and the models are moved to.

        Returns:
            A ready-to-sample ``EmojiGenerator``.
        """
        # torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model_config = checkpoint["model_config"]

        diffusion_model = StableDiffusion(
            h_dim=model_config["stable_diffusion"]["h_dim"],
            n_head=model_config["stable_diffusion"]["n_head"],
        ).to(device)

        diffusion_model.load_state_dict(checkpoint["model_state_dict"])

        text_encoder = CLIPTextEncoder(
            model_id=model_config["clip"]["model_id"],
            device=device
        )

        vae_encoder = VAEEncoder(
            model_id=model_config["vae"]["model_id"],
            scaling_factor=model_config["vae"]["scaling_factor"]
        ).to(device)

        # Separate RNG handed to the scheduler (distinct from the latent-noise RNG
        # created per call in generate()).
        generator = torch.Generator(device=device)
        scheduler = DDPMScheduler(
            random_generator=generator,
            train_timesteps=model_config["stable_diffusion"]["num_train_timesteps"],
            beta_start=model_config["stable_diffusion"]["beta_start"],
            beta_end=model_config["stable_diffusion"]["beta_end"],
        )

        return cls(diffusion_model, text_encoder, vae_encoder, scheduler, device)


def plot_images(images: List[Image.Image], prompts: Optional[List[str]] = None, cols: int = 4):
    """Display images in a row-major grid, optionally titled with their prompts.

    Args:
        images: Images to show. An empty list prints a notice and returns.
        prompts: Optional titles. ``prompts[i]`` (first 30 characters) titles
            ``images[i]``; images beyond ``len(prompts)`` get no title.
        cols: Number of grid columns (must be >= 1). Unused trailing cells are blank.

    Raises:
        ValueError: If ``cols`` is less than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")
    if not images:
        # plt.subplots(0, cols) would raise; nothing to draw anyway.
        print("No images to plot.")
        return

    rows = (len(images) + cols - 1) // cols
    # squeeze=False always returns a 2-D array of Axes, so single-row and
    # single-column grids need no special-casing.
    _fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # hide ticks on every cell, including empty trailing ones
        if i >= len(images):
            continue
        ax.imshow(images[i])
        if prompts and i < len(prompts):
            ax.set_title(prompts[i][:30], fontsize=8)

    plt.tight_layout()
    plt.show()

print("Inference pipeline defined!")


In [None]:
# =============================================================================
# TRAINING EXECUTION
# =============================================================================

# Build the trainer from the notebook-wide configuration dictionaries and device.
trainer = StableDiffusionTrainer(MODEL_CONFIG, TRAINING_CONFIG, DEVICE)

# Describe the demo setup; the long-running training call itself is opt-in below.
print(
    "Starting training...",
    "Note: This is a demonstration with reduced epochs and batch size.",
    "For full training, increase epochs to 300+ and batch_size based on your GPU memory.",
    sep="\n",
)

# Uncomment the next line to start training
# trainer.train(train_loader, val_loader)

print("Training setup complete! Uncomment the trainer.train() line to start training.")


In [None]:
# =============================================================================
# TRAINING MONITORING
# =============================================================================

def plot_training_curves(trainer: StableDiffusionTrainer):
    """Plot per-epoch training loss and the periodic validation loss side by side.

    Args:
        trainer: Trainer whose ``train_losses`` and ``val_losses`` lists are plotted.
            ``train`` appends one training loss per epoch and one validation loss
            on every epoch with ``epoch % 5 == 0``. The validation x-axis assumes
            the run started at epoch 0.
    """
    if not trainer.train_losses:
        print("No training data to plot yet.")
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(trainer.train_losses, label='Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss')
    ax1.legend()
    ax1.grid(True)

    if trainer.val_losses:
        # Derive one x value per recorded validation loss so both sequences always
        # have equal length. A run interrupted after the training step of a
        # validation epoch would otherwise make matplotlib raise on the mismatch.
        val_epochs = [5 * i for i in range(len(trainer.val_losses))]
        ax2.plot(val_epochs, trainer.val_losses, label='Validation Loss', color='orange')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Loss')
        ax2.set_title('Validation Loss')
        ax2.legend()
        ax2.grid(True)
    else:
        ax2.text(0.5, 0.5, 'No validation data', transform=ax2.transAxes, ha='center', va='center')
        ax2.set_title('Validation Loss (Not Available)')

    plt.tight_layout()
    plt.show()


def monitor_training_progress(trainer: StableDiffusionTrainer):
    """Print the trainer's progress counters, best loss and most recent losses."""
    print(f"Current Epoch: {trainer.current_epoch}")
    print(f"Global Step: {trainer.global_step}")
    print(f"Best Loss: {trainer.best_loss:.6f}")

    # Report the last entry of each loss history that has been recorded so far.
    for label, history in (("Training", trainer.train_losses), ("Validation", trainer.val_losses)):
        if history:
            print(f"Latest {label} Loss: {history[-1]:.6f}")


# Usage hint: the monitoring helpers above are meant to be called once a trainer exists.
print(
    "Training monitoring functions defined.",
    "Use plot_training_curves(trainer) and monitor_training_progress(trainer) to track progress.",
    sep="\n",
)


In [None]:
# =============================================================================
# MODEL EVALUATION
# =============================================================================

def evaluate_model(trainer: StableDiffusionTrainer, val_loader: DataLoader) -> Dict[str, float]:
    """Run validation and report it alongside the trainer's summary metrics.

    Args:
        trainer: Trainer providing ``validate``, ``best_loss`` and ``current_epoch``.
        val_loader: Batches passed to ``trainer.validate``.

    Returns:
        Dict with ``validation_loss``, ``best_loss`` and ``epochs_trained``.

    NOTE: ``epochs_trained`` is ``trainer.current_epoch``. The training loop sets
    that to the 0-based index of the epoch it is running, so after a complete run
    it equals ``epochs - 1``.
    """
    print("Evaluating model...")

    loss_on_val = trainer.validate(val_loader)
    best = trainer.best_loss
    epoch_index = trainer.current_epoch

    print(f"Validation Loss: {loss_on_val:.6f}")
    print(f"Best Loss: {best:.6f}")
    print(f"Epochs Trained: {epoch_index}")

    return {
        "validation_loss": loss_on_val,
        "best_loss": best,
        "epochs_trained": epoch_index,
    }


def generate_evaluation_samples(checkpoint_path: Optional[str] = None, trainer: Optional[StableDiffusionTrainer] = None):
    """Generate a fixed prompt set for qualitative evaluation, then plot and save it.

    The model is loaded from ``checkpoint_path`` when that file exists; otherwise
    the live ``trainer`` components are used. Each prompt is sampled once with
    seed 42 so runs are comparable. A gray 32x32 placeholder replaces any failed
    generation, which keeps images aligned with their prompts. Images are saved
    under ``evaluation_samples/``.

    Args:
        checkpoint_path: Optional path to a saved checkpoint.
        trainer: Optional trainer used when no checkpoint file is available.
    """
    print("Generating evaluation samples...")

    has_checkpoint = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if checkpoint_path and not has_checkpoint:
        # Make the fallback visible instead of silently ignoring a bad path.
        print(f"Checkpoint not found: {checkpoint_path}")

    if has_checkpoint:
        generator = EmojiGenerator.from_checkpoint(checkpoint_path, DEVICE)
    elif trainer is not None:
        generator = EmojiGenerator(
            trainer.diffusion_model,
            trainer.text_encoder,
            trainer.vae_encoder,
            trainer.scheduler,
            DEVICE
        )
    else:
        print("No trained model available for evaluation.")
        return

    evaluation_prompts = [
        "happy smiling emoji",
        "sad crying emoji",
        "angry red face emoji",
        "surprised emoji with wide eyes",
        "laughing emoji with tears",
        "winking emoji",
        "heart eyes emoji",
        "thinking emoji"
    ]

    generated_images = []

    for prompt in tqdm(evaluation_prompts, desc="Generating samples"):
        try:
            images = generator.generate(
                prompt=prompt,
                num_inference_steps=INFERENCE_CONFIG["num_inference_steps"],
                guidance_scale=INFERENCE_CONFIG["guidance_scale"],
                seed=42  # Fixed seed for reproducible evaluation
            )
            generated_images.extend(images)
        except Exception as e:
            # Best-effort evaluation: report and keep going so one bad prompt does
            # not abort the whole sample sheet.
            print(f"Error generating for prompt '{prompt}': {e}")
            placeholder = Image.new('RGB', (32, 32), color='gray')
            generated_images.append(placeholder)

    if generated_images:
        print("Generated evaluation samples:")
        plot_images(generated_images, evaluation_prompts, cols=4)

        os.makedirs("evaluation_samples", exist_ok=True)
        for i, (img, prompt) in enumerate(zip(generated_images, evaluation_prompts)):
            img.save(f"evaluation_samples/sample_{i:02d}_{prompt.replace(' ', '_')}.png")

        print("Evaluation samples saved to 'evaluation_samples/' directory.")
    else:
        print("No images were generated successfully.")


# Usage hints for the evaluation helpers defined in this cell.
print(
    "Evaluation functions defined.",
    "Use evaluate_model(trainer, val_loader) for quantitative evaluation.",
    "Use generate_evaluation_samples(trainer=trainer) for qualitative evaluation.",
    sep="\n",
)


In [None]:
# =============================================================================
# INTERACTIVE INFERENCE
# =============================================================================

def interactive_generation(checkpoint_path: Optional[str] = None, trainer: Optional[StableDiffusionTrainer] = None):
    """Prompt-driven generation loop: read a prompt, generate and show 4 variations.

    The model is loaded from ``checkpoint_path`` when that file exists; otherwise
    the live ``trainer`` components are used. The loop ends on ``quit``/``exit``/``q``,
    on Ctrl-C, or when stdin is closed or non-interactive (``EOFError``).

    Args:
        checkpoint_path: Optional path to a saved checkpoint.
        trainer: Optional trainer used when no checkpoint file is available.
    """
    has_checkpoint = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if checkpoint_path and not has_checkpoint:
        print(f"Checkpoint not found: {checkpoint_path}")

    if has_checkpoint:
        print(f"Loading model from checkpoint: {checkpoint_path}")
        generator = EmojiGenerator.from_checkpoint(checkpoint_path, DEVICE)
    elif trainer is not None:
        print("Using current trainer model...")
        generator = EmojiGenerator(
            trainer.diffusion_model,
            trainer.text_encoder,
            trainer.vae_encoder,
            trainer.scheduler,
            DEVICE
        )
    else:
        print("No trained model available. Please train a model first or provide a checkpoint path.")
        return

    print("Interactive Emoji Generator Ready!")
    print("Enter text prompts to generate emojis. Type 'quit' to exit.")

    while True:
        try:
            prompt = input("\nEnter prompt: ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is closed or not interactive (e.g. a batch
            # notebook run). The generic handler below would otherwise catch it
            # and loop forever.
            print("\nGeneration interrupted.")
            break

        if prompt.lower() in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break

        if not prompt:
            print("Please enter a valid prompt.")
            continue

        print(f"Generating emoji for: '{prompt}'...")

        try:
            images = generator.generate(
                prompt=prompt,
                num_inference_steps=INFERENCE_CONFIG["num_inference_steps"],
                guidance_scale=INFERENCE_CONFIG["guidance_scale"],
                seed=None,  # Random seed for variety
                batch_size=4  # Generate 4 variations
            )

            print("Generated emojis:")
            plot_images(images, [prompt] * len(images), cols=2)
        except KeyboardInterrupt:
            print("\nGeneration interrupted.")
            break
        except Exception as e:
            # Keep the session alive after a failed generation; report and re-prompt.
            print(f"Error during generation: {e}")


def batch_generation(prompts: List[str], checkpoint_path: Optional[str] = None, trainer: Optional[StableDiffusionTrainer] = None):
    """Generate one emoji per prompt (seed 42), plot them and save them to disk.

    The model is loaded from ``checkpoint_path`` when that file exists; otherwise
    the live ``trainer`` components are used. Prompts whose generation fails are
    reported and skipped. Images are saved under ``batch_generation/``.

    Args:
        prompts: Text prompts to generate for.
        checkpoint_path: Optional path to a saved checkpoint.
        trainer: Optional trainer used when no checkpoint file is available.

    Returns:
        The list of generated PIL images (empty when no model is available).
    """
    import re

    has_checkpoint = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if checkpoint_path and not has_checkpoint:
        print(f"Checkpoint not found: {checkpoint_path}")

    if has_checkpoint:
        generator = EmojiGenerator.from_checkpoint(checkpoint_path, DEVICE)
    elif trainer is not None:
        generator = EmojiGenerator(
            trainer.diffusion_model,
            trainer.text_encoder,
            trainer.vae_encoder,
            trainer.scheduler,
            DEVICE
        )
    else:
        print("No trained model available.")
        return []

    all_images = []
    all_prompts = []

    print(f"Generating emojis for {len(prompts)} prompts...")

    for prompt in tqdm(prompts, desc="Batch generation"):
        try:
            images = generator.generate(
                prompt=prompt,
                num_inference_steps=INFERENCE_CONFIG["num_inference_steps"],
                guidance_scale=INFERENCE_CONFIG["guidance_scale"],
                seed=42  # Fixed seed for consistency
            )
            all_images.extend(images)
            all_prompts.extend([prompt] * len(images))
        except Exception as e:
            print(f"Error generating for '{prompt}': {e}")

    if all_images:
        print("Batch generation complete!")
        plot_images(all_images, all_prompts, cols=4)

        os.makedirs("batch_generation", exist_ok=True)
        for i, (img, prompt) in enumerate(zip(all_images, all_prompts)):
            # Prompts are arbitrary user text. Keep only filename-safe characters
            # and cap the length so saving cannot fail on invalid or over-long names.
            safe_prompt = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt)[:80]
            img.save(f"batch_generation/batch_{i:03d}_{safe_prompt}.png")

        print("Batch results saved to 'batch_generation/' directory.")

    return all_images


# Example prompts for batch_generation(); each one is "<expression> emoji".
sample_prompts = [
    f"{expression} emoji"
    for expression in ("happy", "sad", "angry", "laughing", "heart eyes")
]

print(
    "Inference functions ready!",
    "Use interactive_generation() for interactive mode.",
    "Use batch_generation(prompts) for batch processing.",
    f"Example: batch_generation({sample_prompts})",
    sep="\n",
)


In [None]:
# =============================================================================
# FINAL STATUS CHECK
# =============================================================================

# Summary banner listing the stages defined by this notebook.
print(
    "🎉 Emoji Generator Pipeline Complete! 🎉",
    "=" * 50,
    "✅ All components loaded successfully",
    "✅ Data loaders created",
    "✅ Model architecture defined",
    "✅ Training pipeline ready",
    "✅ Evaluation functions ready",
    "✅ Inference pipeline ready",
    "=" * 50,
    sep="\n",
)

# Next steps for the user.
print(
    "\n📋 Quick Action Items:",
    "1. Uncomment 'trainer.train(train_loader, val_loader)' to start training",
    "2. Use 'plot_training_curves(trainer)' to monitor progress",
    "3. Run 'generate_evaluation_samples(trainer=trainer)' for qualitative evaluation",
    "4. Use 'interactive_generation(trainer=trainer)' for interactive inference",
    sep="\n",
)

# Echo the active configuration values.
print(
    "\n🔧 Current Configuration:",
    f"   - Device: {DEVICE}",
    f"   - Model dimension: {MODEL_CONFIG['stable_diffusion']['h_dim']}",
    f"   - Training epochs: {TRAINING_CONFIG['epochs']}",
    f"   - Batch size: {TRAINING_CONFIG['batch_size']}",
    f"   - Learning rate: {TRAINING_CONFIG['learning_rate']}",
    sep="\n",
)

print(
    "\n🚀 Ready for Google Colab and Kaggle!",
    "💡 Tip: Increase epochs and batch_size for production training",
    sep="\n",
)
