In [1]:
from diffusers import StableDiffusionPipeline
import torch

# Load the pretrained latent-diffusion pipeline.
# Half precision and the CUDA move are only used when a GPU is actually
# present; on a CPU-only machine the cell falls back to float32 on CPU
# instead of failing on a hard-coded "cuda" device.
has_cuda = torch.cuda.is_available()
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if has_cuda else torch.float32,
)
vae = pipe.vae.to("cuda" if has_cuda else "cpu")
unet = pipe.unet.to("cuda" if has_cuda else "cpu")  # We'll fine-tune this
scheduler = pipe.scheduler

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os

# Hyperparameters
# Module-level defaults for the cell; `train_face_aging` declares its own
# keyword defaults for the epoch count and learning rate.
num_epochs, lr = 15, 1e-5
denoise_strength = 0.7
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Face-image dataset filtered by the race code embedded in each filename
class IndianFacesDataset(Dataset):
    """Face images from a folder, filtered by the race code in the filename.

    Only files whose name ends with ``jpg.chip.jpg`` are considered. The
    first three ``_``-separated fields of the name are parsed as integers
    ``age``, ``gender`` and ``race``; files where that parsing fails are
    reported on stdout and skipped. Files are indexed in sorted filename
    order so that a given index always refers to the same image.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned from ``__getitem__``.
        race_filter: Only files whose parsed race field equals this value
            are kept (default ``3``).

    Raises:
        FileNotFoundError: If ``img_dir`` does not exist.
    """

    # Age labels are divided by this value before being returned as tensors.
    MAX_AGE = 116.0
    # Filename suffix a file must have to be picked up.
    FILE_SUFFIX = 'jpg.chip.jpg'

    def __init__(self, img_dir, transform=None, race_filter=3):
        self.img_dir = img_dir
        self.transform = transform

        if not os.path.exists(img_dir):
            raise FileNotFoundError(f"Folder {img_dir} does not exist!")

        # os.listdir() returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible between runs and machines.
        self.img_names = sorted(
            f for f in os.listdir(img_dir) if f.endswith(self.FILE_SUFFIX)
        )
        self.metadata = []
        for img_name in self.img_names:
            parts = img_name.split("_")
            try:
                age, gender, race = int(parts[0]), int(parts[1]), int(parts[2])
            except (IndexError, ValueError) as e:
                print(f"Skipping {img_name}: Invalid format ({e})")
                continue
            if race == race_filter:
                self.metadata.append({
                    "filename": img_name,
                    "age": age,
                    "gender": gender
                })

        # Report the filter that was actually applied, not a fixed value.
        print(f"Found {len(self.metadata)} Indian images (race={race_filter})")

    def __len__(self):
        """Return the number of images that passed the race filter."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with keys ``"image"`` (RGB image, or the transform's output
            when a transform was given), ``"age"`` (float32 scalar tensor,
            the integer label divided by ``MAX_AGE``), ``"gender"`` (int64
            scalar tensor) and ``"filename"`` (str).
        """
        img_info = self.metadata[idx]
        img_path = os.path.join(self.img_dir, img_info["filename"])
        # convert() returns a fully loaded copy, so the underlying file can
        # be closed immediately instead of staying open until garbage
        # collection (which matters when iterating over many files).
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")

        age = torch.tensor(img_info["age"], dtype=torch.float32) / self.MAX_AGE
        gender = torch.tensor(img_info["gender"], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return {
            "image": image,
            "age": age,
            "gender": gender,
            "filename": img_info["filename"]
        }

# Dataset variant that additionally yields a text prompt for each sample
class FaceAgingDataset(IndianFacesDataset):
    """``IndianFacesDataset`` samples extended with a descriptive prompt.

    Each item carries the parent's ``"image"``, ``"age"`` and ``"gender"``
    entries plus a ``"prompt"`` string built from the gender and the age in
    years. The parent's ``"filename"`` entry is not forwarded.
    """

    def __getitem__(self, idx):
        """Return ``{"image", "age", "gender", "prompt"}`` for sample ``idx``."""
        data = super().__getitem__(idx)

        gender = data["gender"].item()
        # The parent stores age as float32(label / 116). Multiplying that
        # back by 116 can land just below the original integer (25 becomes
        # 24.9999996), so truncating with int() alone would report an age
        # one year too low in the prompt. Round to the nearest year instead.
        age_years = int(round(data["age"].item() * 116))
        gender_str = "male" if gender == 0 else "female"
        prompt = f"A high-quality photo of an Indian {gender_str}, age {age_years} years"

        return {
            "image": data["image"],
            "age": data["age"],
            "gender": data["gender"],
            "prompt": prompt
        }

# Maps a scalar age value to a 768-dimensional embedding vector
class AgeAdapter(nn.Module):
    """Two-layer MLP (1 -> 256 -> 768) followed by layer normalisation.

    The layers live in ``self.proj`` as an ``nn.Sequential`` so the
    parameter names are ``proj.0.*``, ``proj.2.*`` and ``proj.3.*``.
    """

    def __init__(self):
        super().__init__()
        hidden_width, embed_width = 256, 768
        stages = [
            nn.Linear(1, hidden_width),
            nn.SiLU(),
            nn.Linear(hidden_width, embed_width),
            nn.LayerNorm(embed_width),
        ]
        self.proj = nn.Sequential(*stages)

    def forward(self, age):
        """Append a trailing feature axis to ``age`` and project it.

        An input of shape ``(...)`` yields an output of shape ``(..., 768)``.
        """
        as_feature = age.unsqueeze(-1)
        return self.proj(as_feature)

# Training Loop
def train_face_aging(
    num_epochs=15,
    lr=1e-5,
    batch_size=2,
    image_size=512,
    model_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    data_dir="indian_images",
):
    """Fine-tune the Stable Diffusion UNet jointly with an ``AgeAdapter``.

    The text encoder and VAE are frozen. For every batch the prompt is
    encoded, the adapter's age embedding is added to the prompt embedding,
    and the UNet is trained to predict the noise that was added to the VAE
    latents (MSE loss).

    Args:
        num_epochs: Number of passes over the dataset.
        lr: AdamW learning rate.
        batch_size: Samples per optimisation step.
        image_size: Images are resized and centre-cropped to this size.
        model_id: Hub id or local path of the base checkpoint; tokenizer,
            text encoder, VAE, UNet and scheduler config are read from it.
        data_dir: Folder passed to ``FaceAgingDataset``.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist (raised by the
            dataset).
        ValueError: If no usable training image is found in ``data_dir``.

    Side effects:
        Prints the loss every 10 steps, writes the UNet to
        ``./face_aging_unet`` and the adapter weights to
        ``./age_adapter.pth``.
    """
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
    age_adapter = AgeAdapter().to(device)

    # Frozen feature extractors: no gradients, inference mode.
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.eval()
    vae.eval()
    # from_pretrained() hands models back in eval mode; the trained modules
    # are switched to train mode explicitly.
    unet.train()
    age_adapter.train()

    trainable_params = list(unet.parameters()) + list(age_adapter.parameters())
    optimizer = AdamW(trainable_params, lr=lr, weight_decay=1e-4)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = FaceAgingDataset(data_dir, transform=transform)
    if len(dataset) == 0:
        # Fail with a clear message instead of an obscure sampler error.
        raise ValueError(f"No training images found in {data_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Use the noise schedule stored with the checkpoint, so fine-tuning sees
    # the same betas the base model was trained with and that the inference
    # pipeline uses, rather than a hand-written schedule.
    noise_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

    for epoch in range(num_epochs):
        for step, batch in enumerate(dataloader):
            images = batch["image"].to(device)
            ages = batch["age"].to(device)
            prompts = batch["prompt"]

            input_ids = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            ).input_ids.to(device)

            with torch.no_grad():
                text_embeddings = text_encoder(input_ids)[0]
                latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor

            # ages: (B,) -> (B, 1); the adapter appends a feature axis and
            # returns (B, 1, 768), which broadcasts over the token axis.
            age_emb = age_adapter(ages.unsqueeze(1))
            # Gender is already stated in the text prompt. It is deliberately
            # NOT pushed through the age adapter: that module models age
            # only, and FaceAgingPipeline adds just the age embedding at
            # inference, so adding a gender term here would train on
            # conditioning that inference never reproduces.
            combined_embeddings = text_embeddings + age_emb

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=device
            ).long()

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=combined_embeddings).sample
            loss = F.mse_loss(noise_pred, noise)

            loss.backward()
            # Clip every parameter the optimizer updates, adapter included.
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            optimizer.zero_grad()

            if step % 10 == 0:
                print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    unet.save_pretrained("./face_aging_unet")
    torch.save(age_adapter.state_dict(), "./age_adapter.pth")
    print("Training complete!")

# Face Aging Inference Pipeline
class FaceAgingPipeline(StableDiffusionImg2ImgPipeline):
    """Img2img pipeline whose prompt embedding is shifted by an age embedding.

    The pipeline deliberately does not override ``__init__``: diffusers
    discovers a pipeline's components from the ``__init__`` signature, and a
    ``*args/**kwargs`` override hides them from ``from_pretrained``,
    ``.to()`` and ``.device``. The adapter weights are therefore attached
    after construction, either through
    ``from_pretrained(..., age_adapter_path=...)`` or ``load_age_adapter``.
    """

    # Set by load_age_adapter(); None until adapter weights are attached.
    age_adapter = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, age_adapter_path="./age_adapter.pth", **kwargs):
        """Load the base pipeline, then attach the age adapter weights.

        Args:
            pretrained_model_name_or_path: Forwarded to the parent loader.
            age_adapter_path: File written by ``torch.save`` holding the
                ``AgeAdapter`` state dict.
            **kwargs: Forwarded unchanged to the parent loader.
        """
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        pipe.load_age_adapter(age_adapter_path)
        return pipe

    def load_age_adapter(self, age_adapter_path="./age_adapter.pth"):
        """Create an ``AgeAdapter`` and load its weights from ``age_adapter_path``."""
        adapter = AgeAdapter()
        # map_location keeps a checkpoint saved on GPU loadable on a CPU-only
        # machine; weights_only restricts unpickling to tensor data.
        state = torch.load(age_adapter_path, map_location="cpu", weights_only=True)
        adapter.load_state_dict(state)
        adapter.eval()
        self.age_adapter = adapter
        return self

    def __call__(self, image, target_age, strength=0.7, guidance_scale=7.5):
        """Run img2img conditioned on ``target_age``.

        Args:
            image: Input image forwarded to the parent pipeline.
            target_age: Desired age in years; divided by 116 before being
                fed to the adapter, matching the dataset's normalisation.
            strength: Forwarded to the parent pipeline.
            guidance_scale: Forwarded to the parent pipeline.

        Returns:
            Whatever the parent pipeline's ``__call__`` returns.

        Raises:
            RuntimeError: If no age adapter has been loaded.
        """
        if self.age_adapter is None:
            raise RuntimeError(
                "Age adapter weights are not loaded; use "
                "from_pretrained(..., age_adapter_path=...) or load_age_adapter()."
            )

        prompt = f"A high-quality photo of an Indian person, age {target_age} years"

        input_ids = self.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(self.device)

        with torch.no_grad():
            text_embeddings = self.text_encoder(input_ids)[0]

            # The adapter is not a registered pipeline component, so moving
            # or casting the pipeline does not touch it. Align it with the
            # text embeddings here to avoid device/dtype mismatches.
            self.age_adapter.to(device=text_embeddings.device, dtype=text_embeddings.dtype)
            age_tensor = torch.tensor(
                [[target_age / 116.0]],
                device=text_embeddings.device,
                dtype=text_embeddings.dtype,
            )
            # (1, 1) -> adapter -> (1, 1, 768); broadcasts over the tokens.
            age_emb = self.age_adapter(age_tensor)
            combined_embeddings = text_embeddings + age_emb

        return super().__call__(
            prompt_embeds=combined_embeddings,
            image=image,
            strength=strength,
            guidance_scale=guidance_scale
        )

# Usage Example
def perform_face_aging(
    input_image_path,
    target_age,
    base_model_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    unet_dir="face_aging_unet",
    age_adapter_path="./age_adapter.pth",
    strength=0.75,
    guidance_scale=8.0,
):
    """Age the face in ``input_image_path`` to ``target_age`` years.

    Args:
        input_image_path: Path of the image to transform.
        target_age: Desired age in years.
        base_model_id: Checkpoint providing every pipeline component except
            the UNet.
        unet_dir: Directory written by ``unet.save_pretrained`` during
            training.
        age_adapter_path: File holding the saved ``AgeAdapter`` state dict.
        strength: Img2img strength forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first image of the pipeline result.
    """
    # Training saves only the UNet into ``unet_dir``; that folder is not a
    # complete pipeline directory. Load the UNet on its own and inject it
    # into a pipeline built from the base checkpoint.
    finetuned_unet = UNet2DConditionModel.from_pretrained(unet_dir)
    pipe = FaceAgingPipeline.from_pretrained(
        base_model_id,
        unet=finetuned_unet,
        safety_checker=None,
        requires_safety_checker=False,
        age_adapter_path=age_adapter_path
    ).to(device)

    # convert()/resize() return new images, so the file can be closed here.
    with Image.open(input_image_path) as source:
        input_image = source.convert("RGB").resize((512, 512))

    result = pipe(
        image=input_image,
        target_age=target_age,
        strength=strength,
        guidance_scale=guidance_scale
    )

    return result.images[0]

# Example Usage
# End-to-end driver: fine-tune first, then age one sample image.
if __name__ == "__main__":
    # Training: writes ./face_aging_unet and ./age_adapter.pth, which the
    # inference step below reads back.
    train_face_aging()
    
    # Inference: age the face in input.jpg to 50 years and save the result.
    aged_image = perform_face_aging("input.jpg", target_age=50)
    aged_image.save("output.jpg")

Found 1467 Indian images (race=3)
Epoch 0, Step 0, Loss: 0.0047
Epoch 0, Step 10, Loss: 0.0111
Epoch 0, Step 20, Loss: 0.0158
Epoch 0, Step 30, Loss: 0.0052
Epoch 0, Step 40, Loss: 0.0948
Epoch 0, Step 50, Loss: 0.0800
Epoch 0, Step 60, Loss: 0.0894
Epoch 0, Step 70, Loss: 0.3219
Epoch 0, Step 80, Loss: 0.1315
Epoch 0, Step 90, Loss: 0.5133
Epoch 0, Step 100, Loss: 0.0843
Epoch 0, Step 110, Loss: 0.1255
Epoch 0, Step 120, Loss: 0.0316
Epoch 0, Step 130, Loss: 0.0084
Epoch 0, Step 140, Loss: 0.1827
Epoch 0, Step 150, Loss: 0.0555
Epoch 0, Step 160, Loss: 0.0889
Epoch 0, Step 170, Loss: 0.0068
Epoch 0, Step 180, Loss: 0.0102
Epoch 0, Step 190, Loss: 0.0003
Epoch 0, Step 200, Loss: 0.1004
Epoch 0, Step 210, Loss: 0.0168
Epoch 0, Step 220, Loss: 0.0163
Epoch 0, Step 230, Loss: 0.1067
Epoch 0, Step 240, Loss: 0.4050
Epoch 0, Step 250, Loss: 0.0048
Epoch 0, Step 260, Loss: 0.0005
Epoch 0, Step 270, Loss: 0.3861
Epoch 0, Step 280, Loss: 0.0116
Epoch 0, Step 290, Loss: 0.4804
Epoch 0, Step 300