# Fine-tuning Stable Diffusion v1-4 con Old Book Illustrations

Adaptación del notebook del docente (`2.finetuning_stable_diffusion.ipynb`)
para el dataset `gigant/oldbookillustrations`.

**3 cambios respecto al código del docente:**
1. `example["image"]` -> `example["1600px"]` (columna de imagen)
2. `example["text"]` -> `example["info_alt"]` (columna de caption)
3. `Resize((512,512))` -> `Resize(512) + CenterCrop(512)` (las imágenes no son cuadradas; así se conserva la proporción)

In [None]:
# Importamos las librerias necesarias:
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
from accelerate import Accelerator
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm

In [None]:
# Pick the compute device: "cuda" when torch reports a usable GPU, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Hub identifier of the base checkpoint reused by the later cells.
pretrained_model_name = "CompVis/stable-diffusion-v1-4"

# Load the complete pre-trained Stable Diffusion pipeline and move it to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name)
pipe = pipe.to(device)

In [None]:
# Generate a reference image BEFORE fine-tuning.
prompt = "an illustration of a ship sailing through a stormy sea"

# Fixed seed so the result is reproducible and depends on the model weights
# and the prompt rather than on the randomly drawn initial noise.
generator = torch.Generator(device=device).manual_seed(0)
image_before = pipe(prompt, generator=generator).images[0]

# PIL's Image.save does not create missing parent directories,
# so make sure the output folder exists before writing the file.
os.makedirs("../generated", exist_ok=True)
image_before.save("../generated/before_finetuning.png")
image_before

In [None]:
# Load the training split of the Old Book Illustrations dataset.
dataset_name = "gigant/oldbookillustrations"
dataset = load_dataset(path=dataset_name, split="train")

# Report the size of the first sample's image (only index 0 is inspected)
# together with the number of samples in the split.
size = dataset[0]["1600px"].size
print(f"Tamano de las imagenes del dataset: {size}")
print(f"Numero de muestras: {len(dataset)}")

In [None]:
# Preprocessing applied to every image before it is fed to the model.
# The Old Book Illustrations images are not square, so instead of squashing them
# with Resize((512, 512)) the shorter side is resized to `resolution` and a
# central square is cropped, which preserves the aspect ratio.
resolution = 512
image_transforms = transforms.Compose(
    [
        # Resize so that the shorter side equals `resolution`.
        transforms.Resize(resolution),
        # Keep the central `resolution` x `resolution` square.
        transforms.CenterCrop(resolution),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Per-channel normalization with mean 0.5 and std 0.5.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
# Apply the transformation to the first image of the dataset as a visual sanity check.
original_image = dataset[0]["1600px"]
transformed_image = image_transforms(original_image.convert("RGB"))

# `image_transforms` ends with Normalize(0.5, 0.5), so the tensor is no longer in
# the [0, 1] range that ToPILImage expects for float input. Undo the normalization
# (x * std + mean) and clamp before converting, otherwise the preview is corrupted.
preview_tensor = (transformed_image * 0.5 + 0.5).clamp(0, 1)
transformed_pil_image = transforms.ToPILImage()(preview_tensor)

print("Imagen transformada:")
transformed_pil_image

In [None]:
# Before loading the individual components, drop the reference to the full pipeline.
# When running on CUDA, also call torch.cuda.empty_cache() afterwards
# (presumably to hand the memory the pipeline occupied back to the GPU).
del pipe
if device == "cuda":
    torch.cuda.empty_cache()

In [None]:
# Load every Stable Diffusion component separately from its subfolder of the
# base checkpoint, so that each one can be handled individually during training.

# Tokenizer used to turn captions into token ids.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer")

# DDPM noise scheduler built from the checkpoint's scheduler configuration.
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name, subfolder="scheduler")

# CLIP text encoder, moved to the selected device.
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder")
text_encoder = text_encoder.to(device)

# VAE (autoencoder), moved to the selected device.
vae = AutoencoderKL.from_pretrained(pretrained_model_name, subfolder="vae")
vae = vae.to(device)

# Conditional UNet, moved to the selected device.
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name, subfolder="unet")
unet = unet.to(device)

In [None]:
# Batch size used when building the training DataLoader.
batch_size = 6


class Text2ImageDataset(Dataset):
    """Dataset wrapper that yields preprocessed image/caption pairs for training.

    Each item is read from the wrapped dataset, its image is converted to RGB and
    passed through the module-level ``image_transforms``, and its caption is
    tokenized with the module-level ``tokenizer`` (padded/truncated to
    ``tokenizer.model_max_length``).

    Args:
        dataset: Indexable collection of examples; each example must support
            lookup by ``image_column`` and ``caption_column``.
        image_column: Key of the image field. Defaults to ``"1600px"``
            (the teacher's original notebook used ``"image"``).
        caption_column: Key of the caption field. Defaults to ``"info_alt"``
            (the teacher's original notebook used ``"text"``).
    """

    def __init__(self, dataset, image_column="1600px", caption_column="info_alt"):
        self.dataset = dataset
        self.image_column = image_column
        self.caption_column = caption_column

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with ``pixel_values``, ``input_ids`` and ``attention_mask``."""
        example = self.dataset[idx]
        image = image_transforms(example[self.image_column].convert("RGB"))
        token = tokenizer(
            example[self.caption_column],
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        return {
            "pixel_values": image,
            # return_tensors="pt" adds a leading dimension; drop it so the
            # DataLoader can stack the items itself.
            "input_ids": token.input_ids.squeeze(0),
            "attention_mask": token.attention_mask.squeeze(0),
        }

# Wrap the raw dataset and serve it in shuffled batches of `batch_size` items.
train_dataset = Text2ImageDataset(dataset)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Only the UNet is fine-tuned: switch the VAE and the text encoder to eval mode
# and disable gradients for all of their parameters.
for frozen_module in (vae, text_encoder):
    frozen_module.eval()
for frozen_module in (vae, text_encoder):
    frozen_module.requires_grad_(False)

In [None]:
# AdamW over the UNet parameters only.
learning_rate = 1e-5
optimizer = torch.optim.AdamW(params=unet.parameters(), lr=learning_rate)

# Let Accelerate prepare the model, the optimizer and the dataloader,
# then show which device it selected.
accelerator = Accelerator()
unet, optimizer, train_dataloader = accelerator.prepare(
    unet,
    optimizer,
    train_dataloader,
)
print(accelerator.device)

In [None]:
# Training loop: fine-tune the UNet to predict the noise added to the image latents.
num_epochs = 2

os.makedirs("../outputs/checkpoints", exist_ok=True)

# Latent scaling factor taken from the VAE config; falls back to the value this
# notebook originally hard-coded if the config does not expose it.
latent_scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)

for epoch in range(num_epochs):
    # diffusers' from_pretrained returns models in eval mode, so switch the UNet
    # to training mode explicitly before optimizing it.
    unet.train()

    epoch_losses = []
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(accelerator.device)
        input_ids = batch["input_ids"].to(accelerator.device)

        # Neither the VAE nor the text encoder is being optimized here, so run both
        # without gradient tracking.
        with torch.no_grad():
            # Encode the pixels into the VAE latent space and scale the latents.
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * latent_scaling_factor
            # Encode the caption tokens.
            encoder_hidden_states = text_encoder(input_ids)[0]

        # Forward diffusion process:
        # 1. Draw random noise with the same shape as the latents.
        noise = torch.randn_like(latents)
        # 2. Draw one random timestep per sample in the batch.
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=latents.device,
        ).long()
        # 3. Add the noise to the latents according to each timestep.
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise from the noisy latents, the timestep and the text embedding.
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Compute the error against the true noise and update the UNet parameters.
        loss = torch.nn.functional.mse_loss(noise_pred, noise)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        # Read the scalar once and reuse it for both the log list and the progress bar.
        loss_value = loss.item()
        epoch_losses.append(loss_value)
        progress_bar.set_postfix(loss=loss_value)

    # Average loss over the epoch.
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"Epoch {epoch + 1}/{num_epochs} - Loss promedio: {avg_loss:.6f}")

    # Checkpoint after every epoch. Unwrap first: `accelerator.prepare` may return
    # a wrapped model (e.g. in distributed runs) that has no `save_pretrained`.
    checkpoint_path = f"../outputs/checkpoints/checkpoint-epoch-{epoch + 1}"
    accelerator.unwrap_model(unet).save_pretrained(checkpoint_path)
    print(f"Checkpoint guardado en {checkpoint_path}")

In [None]:
# Save the fine-tuned UNet together with the tokenizer files.
output_dir = "../outputs/finetuned-model"
os.makedirs(output_dir, exist_ok=True)

# `unet` went through `accelerator.prepare`, which may return a wrapped model
# (e.g. in distributed runs) that has no `save_pretrained`; unwrap it before saving.
accelerator.unwrap_model(unet).save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Modelo guardado en {output_dir}")

## Inferencia con el modelo fine-tuneado

Cargamos la UNet fine-tuneada y generamos una imagen con el mismo prompt que se usó antes del fine-tuning, para poder comparar ambos resultados.

In [None]:
# Load the fine-tuned UNet from disk and place it on the selected device.
finetuned_unet = UNet2DConditionModel.from_pretrained("../outputs/finetuned-model").to(device)

print('Modelo finetuneado cargado correctamente!')

In [None]:
from diffusers import StableDiffusionPipeline

# Rebuild the pre-trained pipeline, replacing its UNet with the fine-tuned one,
# and move the whole pipeline to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, unet=finetuned_unet)
pipe = pipe.to(device)

In [None]:
# Generate an image with the fine-tuned model.
prompt = "an illustration of a ship sailing through a stormy sea"

# Fixed seed so the result is reproducible and depends on the model weights
# and the prompt rather than on the randomly drawn initial noise.
generator = torch.Generator(device=device).manual_seed(0)
image_after = pipe(prompt, generator=generator).images[0]

# PIL's Image.save does not create missing parent directories,
# so make sure the output folder exists before writing the file.
os.makedirs("../generated", exist_ok=True)
image_after.save("../generated/after_finetuning.png")
image_after

In [None]:
# Side-by-side comparison of the image saved before fine-tuning and the one generated after it.
import matplotlib.pyplot as plt

# Reload the "before" image from disk.
image_before = Image.open("../generated/before_finetuning.png")

fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# One (image, title) pair per subplot, left to right.
panels = (
    (image_before, 'ANTES del fine-tuning'),
    (image_after, 'DESPUES del fine-tuning'),
)
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title, fontsize=14)
    ax.axis('off')

# Shared title showing the prompt, then write the figure to disk and display it.
plt.suptitle(f'Prompt: "{prompt}"', fontsize=12, style='italic')
plt.tight_layout()
plt.savefig("../generated/comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("Comparacion guardada en ../generated/comparison.png")