# Fine Tuning


Dataset URL

https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan

In [None]:
from datasets import load_dataset
import torch.nn.functional as F

# Load the WikiArt training split from the Hugging Face Hub.
dataset = load_dataset("huggan/wikiart", split="train")

# Inspect the first record. The caption lookup is guarded so that a missing
# "text" column is reported with the real schema instead of a bare KeyError.
# NOTE(review): the huggan/wikiart dataset card lists image/artist/genre/style
# columns -- confirm whether a "text" caption column actually exists.
first_example = dataset[0]
if "text" in first_example:
    print(first_example["text"])
else:
    print(f"No 'text' column in this dataset; available columns: {dataset.column_names}")

In [None]:
# Debes convertir las imágenes a tensores y normalizarlas
from torchvision import transforms

# Side length (pixels) every image is resized/cropped to. SDXL base was trained
# at 1024x1024; lower this (e.g. 512) to reduce disk and GPU memory usage.
IMAGE_SIZE = 1024

# Images in the dataset come in arbitrary sizes. They are resized and
# centre-cropped to a fixed square so every tensor has the same shape (required
# for batching) with dimensions the VAE/UNet can down-sample cleanly, then
# converted to tensors and scaled from [0, 1] to [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def preprocess_images(examples):
    """Replace the images of a batch with normalized RGB tensors.

    Args:
        examples: Batched mapping whose "image" entry is a list of PIL images.

    Returns:
        The same mapping, with "image" replaced by a list of tensors of shape
        (3, IMAGE_SIZE, IMAGE_SIZE).
    """
    examples["image"] = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return examples


# NOTE(review): map() materializes every tensor in the dataset cache, which is
# large for the full dataset -- consider an on-the-fly transform if disk space
# becomes a problem.
dataset = dataset.map(preprocess_images, batched=True)

In [None]:
from transformers import AutoTokenizer

# The SDXL repository is a diffusers pipeline: its tokenizers live in
# sub-folders ("tokenizer" and "tokenizer_2"), not at the repository root, so
# the sub-folder has to be requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="tokenizer",
)


def tokenize_captions(examples):
    """Tokenize the captions of a batch to fixed-length (77 token) sequences.

    Args:
        examples: Batched mapping with a "text" entry holding the captions.

    Returns:
        The tokenizer output (includes "input_ids"), padded/truncated to 77
        tokens.
    """
    # NOTE(review): this requires a "text" caption column -- confirm the
    # dataset provides one (or build captions from its label columns first).
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=77)


dataset = dataset.map(tokenize_captions, batched=True)
# Only expose the columns the training loop consumes, as torch tensors.
dataset.set_format("torch", columns=["input_ids", "image"])

# Hold out 10% of the examples for validation.
dataset = dataset.train_test_split(test_size=0.1)
train_dataset = dataset["train"]
val_dataset = dataset["test"]

## Fine-tuning SDXL

In [None]:
from transformers import CLIPTextModel
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import AutoencoderKL
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import os
import torch.nn.functional as F

# Training configuration.
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
DATASET_NAME = "huggan/wikiart"
OUTPUT_DIR = "sdxl-wikiart"
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 4  # micro-batches accumulated per optimizer update
LEARNING_RATE = 1e-5
NUM_EPOCHS = 3
MIXED_PRECISION = "fp16"

# Datasets: reuse `train_dataset` / `val_dataset` built by the preprocessing
# and tokenization cells. Reloading the raw dataset here would throw away the
# image tensors and the "input_ids" column that `collate_fn` reads.
_required_columns = {"input_ids", "image"}
for _split_name, _split in (("train", train_dataset), ("validation", val_dataset)):
    _missing = _required_columns - set(_split.column_names)
    if _missing:
        raise ValueError(
            f"{_split_name} split is missing columns {sorted(_missing)}; "
            "run the image preprocessing and tokenization cells first."
        )

# Build the DataLoaders
def collate_fn(examples):
    """Stack per-example tensors into one batch.

    Each example must provide "input_ids" and "image" tensors; the result maps
    the same two keys to tensors with a new leading batch dimension.
    """
    batch = {}
    for key in ("input_ids", "image"):
        batch[key] = torch.stack([item[key] for item in examples])
    return batch

# Options shared by both loaders; only the training loader shuffles.
_loader_options = {"batch_size": BATCH_SIZE, "collate_fn": collate_fn}

train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, **_loader_options)

# Load the SDXL components.
# SDXL conditions its UNet on TWO text encoders plus extra "micro-conditioning"
# inputs, so the second encoder (and its tokenizer, for the pad/eos ids) is
# loaded here as well.
from transformers import AutoTokenizer, CLIPTextModelWithProjection

unet = UNet2DConditionModel.from_pretrained(MODEL_NAME, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(MODEL_NAME, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(MODEL_NAME, subfolder="text_encoder_2")
tokenizer_2 = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer_2")

# Only the UNet is fine-tuned: freeze everything else so no gradients are
# computed or stored for it.
for _frozen in (vae, text_encoder, text_encoder_2):
    _frozen.requires_grad_(False)
    _frozen.eval()

# Optimizer and learning-rate schedule (one schedule step per optimizer update).
optimizer = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=NUM_EPOCHS * len(train_dataset) // (BATCH_SIZE * GRAD_ACCUM_STEPS),
)

# Prepare with Accelerator. Only trainable objects go through prepare(); the
# frozen models are just moved to the training device (kept in full precision).
accelerator = Accelerator(
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    mixed_precision=MIXED_PRECISION,
)

train_dataloader, val_dataloader, unet, optimizer, lr_scheduler = accelerator.prepare(
    train_dataloader, val_dataloader, unet, optimizer, lr_scheduler
)
for _frozen in (vae, text_encoder, text_encoder_2):
    _frozen.to(accelerator.device)


def _encode_prompts(input_ids):
    """Encode caption token ids with both SDXL text encoders.

    Args:
        input_ids: Tensor of token ids produced by the first SDXL tokenizer,
            one row per caption.

    Returns:
        A tuple (prompt_embeds, pooled_embeds): the penultimate hidden states
        of both encoders concatenated on the feature axis, and the projected
        pooled embedding of the second encoder.
    """
    # The batch only carries ids from the first tokenizer. Re-pad them with the
    # second tokenizer's pad id (everything after the first end-of-text token).
    # NOTE(review): assumes both SDXL tokenizers share the same vocabulary and
    # differ only in their pad token -- confirm.
    first_eos = (input_ids == tokenizer_2.eos_token_id).int().argmax(dim=-1, keepdim=True)
    positions = torch.arange(input_ids.shape[-1], device=input_ids.device)
    input_ids_2 = input_ids.masked_fill(positions > first_eos, tokenizer_2.pad_token_id)

    output_1 = text_encoder(input_ids, output_hidden_states=True)
    output_2 = text_encoder_2(input_ids_2, output_hidden_states=True)
    prompt_embeds = torch.cat(
        [output_1.hidden_states[-2], output_2.hidden_states[-2]], dim=-1
    )
    return prompt_embeds, output_2.text_embeds


def _denoising_loss(batch):
    """Return the noise-prediction MSE loss of the UNet for one batch.

    Args:
        batch: Mapping with "image" (normalized image tensors) and "input_ids"
            (caption token ids), as produced by `collate_fn`.

    Returns:
        Scalar loss tensor (attached to the UNet graph unless called under
        `torch.no_grad()`).
    """
    pixel_values = batch["image"].to(accelerator.device, dtype=vae.dtype)
    input_ids = batch["input_ids"].to(accelerator.device)

    # The frozen models never need gradients.
    with torch.no_grad():
        # Use the VAE's own scaling factor rather than a hard-coded constant.
        latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
        prompt_embeds, pooled_embeds = _encode_prompts(input_ids)

    # Add noise at a random timestep per example (created on the latents'
    # device so it can be fed to the UNet directly).
    batch_size = latents.shape[0]
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (batch_size,),
        device=latents.device,
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # SDXL micro-conditioning: (original size, crop top-left, target size).
    # The pre-resize size is not available here, so the tensor size is used.
    height, width = pixel_values.shape[-2:]
    time_ids = torch.tensor(
        [[height, width, 0, 0, height, width]],
        device=latents.device,
        dtype=prompt_embeds.dtype,
    ).repeat(batch_size, 1)

    # Predict the noise.
    noise_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=prompt_embeds,
        added_cond_kwargs={"text_embeds": pooled_embeds, "time_ids": time_ids},
    ).sample

    # Compute the loss in full precision.
    return F.mse_loss(noise_pred.float(), noise.float())


# Training
for epoch in range(NUM_EPOCHS):
    # Training phase
    unet.train()
    total_train_loss = 0.0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch} [Train]")

    for batch in progress_bar:
        # accumulate() makes the optimizer/scheduler steps below take effect
        # only once every GRAD_ACCUM_STEPS micro-batches.
        with accelerator.accumulate(unet):
            loss = _denoising_loss(batch)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        step_loss = loss.detach().item()
        total_train_loss += step_loss
        progress_bar.set_postfix(loss=step_loss)

    # Validation phase
    unet.eval()
    total_val_loss = 0.0
    val_progress = tqdm(val_dataloader, desc=f"Epoch {epoch} [Val]")

    for batch in val_progress:
        with torch.no_grad():
            step_val_loss = _denoising_loss(batch).item()

        total_val_loss += step_val_loss
        val_progress.set_postfix(val_loss=step_val_loss)

    # Save a checkpoint after every epoch.
    accelerator.save_state(f"{OUTPUT_DIR}/checkpoint-{epoch}")

    # Report metrics (max() guards against an empty loader).
    avg_train_loss = total_train_loss / max(len(train_dataloader), 1)
    avg_val_loss = total_val_loss / max(len(val_dataloader), 1)
    print(f"\nEpoch {epoch} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}\n")

# Save the final model as a complete pipeline (with model_index.json) so that
# StableDiffusionXLPipeline.from_pretrained(OUTPUT_DIR) can load it later.
accelerator.wait_for_everyone()
unet = accelerator.unwrap_model(unet)
if accelerator.is_main_process:
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        MODEL_NAME,
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
    )
    pipeline.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    print(f"✅ Modelo final guardado en: {OUTPUT_DIR}")

## Pruebas

In [None]:
import os

from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel

# A pipeline can only be loaded from a directory that contains
# model_index.json. If OUTPUT_DIR holds just the fine-tuned UNet weights,
# load the base pipeline instead and swap in the fine-tuned UNet.
if os.path.isfile(os.path.join(OUTPUT_DIR, "model_index.json")):
    pipe = StableDiffusionXLPipeline.from_pretrained(
        OUTPUT_DIR,
        torch_dtype=torch.float16,
    )
else:
    finetuned_unet = UNet2DConditionModel.from_pretrained(
        OUTPUT_DIR,
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        MODEL_NAME,
        unet=finetuned_unet,
        torch_dtype=torch.float16,
    )
pipe = pipe.to("cuda")

# Generate one sample image and write it to disk.
prompt = "A painting in Van Gogh style with vibrant colors"
image = pipe(prompt, num_inference_steps=30).images[0]
image.save("van_gogh_style.png")