<a href="https://colab.research.google.com/github/carinunez/ControlNet/blob/main/ControlNet_try2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ControlNet

Integrantes:
- xx
- xx
- xx
- xx

Arquitectura neuronal utilizada para agregar controles/restricciones espaciales a los modelos de difusión de texto a imagen preentrenados. En este caso, la arquitectura se aplicó a Stable Diffusion, creando una copia entrenable del encoder de la UNet que se une al modelo original mediante zero-convolutions.


Escoger qué condiciones se probarán: pose, profundidad, ...


In [None]:
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset as HFDataset
from PIL import Image
import torchvision.transforms as T


In [None]:
# Select the compute device: prefer the GPU and fall back to the CPU, so that
# `device` is always a valid device string for the later `.to(device)` calls.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Hugging Face Hub identifier of the pretrained Stable Diffusion checkpoint
# that every `from_pretrained` call in this notebook loads from.
model_id = "runwayml/stable-diffusion-v1-5"

# 1. Carga de los datos


In [None]:
# Mount Google Drive inside the Colab runtime; the image and edge-map folders
# used below live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import torchvision.transforms as T


# 1.2 Generación de Condiciones con Canny

In [None]:
# Folder with the source images and folder where their Canny edge maps are
# cached.
image_dir = "/content/drive/MyDrive/img_pkmn"
canny_dir = "/content/drive/MyDrive/canny"

# exist_ok=True replaces the check-then-create pattern: it is a no-op when the
# folder already exists and cannot fail if it appears between check and create.
os.makedirs(canny_dir, exist_ok=True)

def apply_canny_path(image: Image.Image, im_path) -> Image.Image:
    """Compute the Canny edge map of ``image``, save it and return it.

    Args:
        image: Source picture; it is converted to RGB before edge detection.
        im_path: File name of the source picture. The edge map is written to
            ``canny_dir`` under the name ``canny_<im_path>``.

    Returns:
        The edge map as a three-channel PIL image.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    # Hysteresis thresholds: 100 (low) and 200 (high).
    edge_map = cv2.Canny(rgb_pixels, 100, 200)
    # Canny yields a single channel; replicate it so the result has 3 channels.
    edge_rgb = np.repeat(edge_map[:, :, np.newaxis], 3, axis=-1)
    edge_image = Image.fromarray(edge_rgb)
    target_file = os.path.join(canny_dir, f'canny_{im_path}')
    edge_image.save(target_file)
    return edge_image

def load_images(folder, img_path):
    """Load one image from disk as an RGB PIL image.

    Args:
        folder: Directory that contains the file.
        img_path: File name inside ``folder``.

    Returns:
        An RGB copy of the image that no longer depends on the file on disk.
    """
    # The context manager closes the underlying file as soon as the pixels
    # have been copied by convert(), instead of leaving the handle open until
    # garbage collection (which matters when looping over a whole folder).
    with Image.open(os.path.join(folder, img_path)) as source:
        return source.convert('RGB')

# Build the training records: original image, Canny edge map and text prompt.
data = []
for im_path in os.listdir(image_dir):
    img = load_images(image_dir, im_path)
    try:
        # Reuse the edge map cached on disk when it exists.
        canny = load_images(canny_dir, f'canny_{im_path}')
    except FileNotFoundError:
        # No cached edge map for this image yet: compute and cache it instead
        # of aborting the whole loading loop.
        canny = apply_canny_path(img, im_path)
    # NOTE(review): the file name is appended to the prompt, so every sample
    # gets a distinct caption - confirm this is intended.
    data.append({"image": img, 'canny': canny, "text": f"a colorful pokemon_{im_path}"})

dataset_hf = HFDataset.from_list(data)


In [None]:
# Quick visual check: first sample next to its Canny edge map.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 12))
for axis, key, title in ((ax1, 'image', 'Original'), (ax2, 'canny', 'Canny')):
    axis.imshow(data[0][key])
    axis.set_title(title)


In [None]:
class ControlDataset(Dataset):
    """Torch dataset yielding ``(image, control, prompt)`` training triples.

    Each record of the wrapped dataset must provide the keys ``'image'``,
    ``'canny'`` and ``'text'``. The image and the control picture go through
    the same transform: resize to ``size`` x ``size``, conversion to a tensor
    and per-channel normalization with mean 0.5 and std 0.5.
    """

    def __init__(self, dataset, size=512):
        """Store the wrapped dataset and build the shared image transform.

        Args:
            dataset: Indexable collection of records (see class docstring).
            size: Side length, in pixels, of the square output images.
        """
        self.dataset = dataset
        resize = T.Resize((size, size))
        to_tensor = T.ToTensor()
        normalize = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # RGB
        self.image_trans = T.Compose([resize, to_tensor, normalize])

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the transformed image, transformed control and raw prompt."""
        record = self.dataset[idx]
        image_tensor, control_tensor = (
            self.image_trans(record[key]) for key in ('image', 'canny')
        )
        return image_tensor, control_tensor, record['text']


In [None]:
# Training loader: wraps the Hugging Face dataset in ControlDataset and serves
# shuffled mini-batches of two samples.
train_dataloader = DataLoader(
    dataset=ControlDataset(dataset_hf),
    shuffle=True,
    batch_size=2,
)

Se agrega un acelerador para hacer una prueba

In [None]:
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel
import torch

# Load the pretrained UNet.
unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet"
).to(device)

# Load the pretrained VAE.
vae = AutoencoderKL.from_pretrained(
    model_id, subfolder="vae"
).to(device)

# Load the CLIP text encoder that ships with the Stable Diffusion checkpoint.
# Every component then comes from the same repository, which keeps the text
# embeddings consistent with what the UNet expects and avoids downloading a
# second, unrelated copy of CLIP.
text_encoder = CLIPTextModel.from_pretrained(
    model_id, subfolder="text_encoder"
).to(device)

# Noise scheduler configured as in the pretrained pipeline.
scheduler = DDIMScheduler.from_pretrained(
    model_id, subfolder="scheduler"
)

# 2. Tokenizador

Importamos el tokenizador de Stable Diffusion para evitar problemas/inconsistencias entre el tamaño de los embeddings de texto y el que espera la UNet.

In [None]:
# Tokenizer and text encoder taken from the Stable Diffusion checkpoint itself.
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder='tokenizer')
# Move the encoder to the compute device right away: from_pretrained returns
# it on the CPU, and token ids sent to another device would not match it.
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder='text_encoder').to(device)

def get_text_embedding(prompt):
    """Encode a prompt (or a list of prompts) with the CLIP text encoder.

    Args:
        prompt: A string or a list of strings.

    Returns:
        The last hidden state of ``text_encoder`` for the tokenized prompt,
        computed without tracking gradients.
    """
    # `truncation` was previously misspelled, so over-long prompts were never
    # truncated to the encoder's maximum sequence length.
    inputs = tokenizer(prompt, padding='max_length', truncation=True,
                       max_length=tokenizer.model_max_length,
                       return_tensors='pt')

    with torch.no_grad():
        # Send the token ids to whichever device the encoder lives on instead
        # of assuming a GPU is present.
        return text_encoder(**inputs.to(text_encoder.device)).last_hidden_state

# 3. Bloques de ControlNet

El bloque de ControlNet toma:
- x: imagen ruidosa
- t: paso de difusión
- encoder_hidden_states: embeddings del texto
- control_im: imagen utilizada para condicionar

Se generan activaciones residuales a partir de control_im con las zeroConv2d y luego se inyectan estas activaciones en los bloques originales de Stable Diffusion.
Finalmente, se entrena manteniendo congelados los pesos originales de la UNet.

In [None]:
# Freeze the weights of the UNet, the VAE and the text encoder so that none of
# them receives gradient updates.
for frozen_module in (unet, vae, text_encoder):
    frozen_module.requires_grad_(False)

In [None]:
# Noise scheduler used for training: reuse the scheduler loaded together with
# the pretrained models (a fresh DDPMScheduler with 1000 training timesteps
# was the alternative considered).
noise_sch = scheduler

In [None]:
# Layer used to inject the condition into Stable Diffusion.
class zeroConv2d(nn.Conv2d):
    """2D convolution whose weight and bias start at exactly zero.

    Its output is therefore all zeros until training updates the parameters.
    """

    def __init__(self, in_c, out_c, kernel_size=1, stride=1, padding=0):
        """Create the convolution and zero every learnable tensor.

        Args:
            in_c: Number of input channels.
            out_c: Number of output channels.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Zero padding added to both sides of the input.
        """
        super().__init__(in_c, out_c, kernel_size=kernel_size,
                         stride=stride, padding=padding)
        for learnable in (self.weight, self.bias):
            if learnable is not None:
                nn.init.zeros_(learnable)

# 4.1 Entrenamiento pipeline

In [None]:
# Layer used to inject the condition into Stable Diffusion.
class zeroConv2d(nn.Conv2d):
    """2D convolution whose weight and bias start at exactly zero."""

    def __init__(self, in_c, out_c, kernel_size=1, stride=1, padding=0):
        super().__init__(in_c, out_c, kernel_size, stride, padding)
        nn.init.zeros_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class MYControlNet(nn.Module):
    """Small ControlNet that turns a control image into UNet residuals.

    The UNet adds the i-th additional residual to its i-th skip connection, so
    this module must emit exactly one tensor per skip connection, with the
    same channel count and spatial size:

    * one for the output of the UNet's ``conv_in``;
    * ``layers_per_block`` for every down block;
    * one more per down block for its downsampler (all blocks but the last);
    * one separate residual for the mid block.

    A trunk of ordinary convolutions extracts features from the control image.
    Each residual is produced by a 1x1 ``zeroConv2d`` on top of the trunk, so
    all residuals are zero at initialization while the trunk still carries a
    non-zero signal for the zero convolutions to learn from.
    """

    def __init__(self, unet: "UNet2DConditionModel", condition_image_channels=3,
                 cond_downscale=8):
        """Build the layers from the UNet configuration.

        Args:
            unet: Object whose ``config`` exposes ``block_out_channels`` and,
                optionally, ``layers_per_block`` (int or one int per block;
                defaults to 2). Only the configuration is read.
            condition_image_channels: Channels of the control image.
            cond_downscale: Integer factor between the control image
                resolution and the latent resolution; must be a power of two
                (8 for the Stable Diffusion VAE).

        Raises:
            ValueError: If the configuration is inconsistent or
                ``cond_downscale`` is not a positive power of two.
        """
        super().__init__()
        unet_config = unet.config
        block_channels = tuple(unet_config.block_out_channels)
        if not block_channels:
            raise ValueError("unet.config.block_out_channels must not be empty")

        layers_per_block = getattr(unet_config, "layers_per_block", 2)
        if isinstance(layers_per_block, int):
            layers_per_block = (layers_per_block,) * len(block_channels)
        layers_per_block = tuple(layers_per_block)
        if len(layers_per_block) != len(block_channels):
            raise ValueError(
                "layers_per_block must have one entry per UNet down block")

        if (not isinstance(cond_downscale, int) or cond_downscale < 1
                or cond_downscale & (cond_downscale - 1)):
            raise ValueError("cond_downscale must be a positive power of two")

        first_ch = block_channels[0]

        # Stem: lift the control image to the first block width and bring it
        # down to the latent resolution with stride-2 convolutions.
        stem = [nn.Conv2d(condition_image_channels, first_ch,
                          kernel_size=3, padding=1),
                nn.SiLU()]
        for _ in range(cond_downscale.bit_length() - 1):
            stem += [nn.Conv2d(first_ch, first_ch, kernel_size=3,
                               stride=2, padding=1),
                     nn.SiLU()]
        self.conv_in = nn.Sequential(*stem)

        # One trunk stage per UNet skip connection after conv_in, each paired
        # with the zero convolution that produces the residual for it.
        self.input_blocks = nn.ModuleList()
        self.zero_convs = nn.ModuleList([zeroConv2d(first_ch, first_ch)])

        in_ch = first_ch
        last_block = len(block_channels) - 1
        for block_idx, (out_ch, n_layers) in enumerate(
                zip(block_channels, layers_per_block)):
            for _ in range(n_layers):
                self._add_stage(in_ch, out_ch, stride=1)
                in_ch = out_ch
            if block_idx != last_block:
                # Mirrors the UNet downsampler: same width, half resolution.
                self._add_stage(in_ch, in_ch, stride=2)

        # Mid block: feature convolution followed by its zero convolution.
        self.middle_block = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1),
            nn.SiLU(),
            zeroConv2d(in_ch, in_ch))

    def _add_stage(self, in_ch, out_ch, stride):
        """Append one trunk stage and the zero convolution reading from it."""
        self.input_blocks.append(
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride,
                          padding=1),
                nn.SiLU()))
        self.zero_convs.append(zeroConv2d(out_ch, out_ch))

    def forward(self, sample, timestep=None, encoder_hidden_states=None,
                controlnet_cond=None, **kwargs):
        """Compute the residuals for the UNet from the control image.

        ``sample``, ``timestep``, ``encoder_hidden_states`` and ``kwargs`` are
        accepted for call compatibility but not used: this network looks only
        at the control image.

        Args:
            controlnet_cond: Control image batch of shape ``(B, C, H, W)``.

        Returns:
            ``(down_residuals, mid_residual)``: a list with one tensor per
            UNet skip connection, and the tensor for the mid block.

        Raises:
            ValueError: If ``controlnet_cond`` is ``None``.
        """
        if controlnet_cond is None:
            raise ValueError("controlnet_cond (the control image) is required")

        h = self.conv_in(controlnet_cond)
        controlnet_down_res = [self.zero_convs[0](h)]

        for block, zero_conv in zip(self.input_blocks, self.zero_convs[1:]):
            h = block(h)
            controlnet_down_res.append(zero_conv(h))

        # The deepest trunk feature feeds the mid block.
        controlnet_mid_res = self.middle_block(h)

        return controlnet_down_res, controlnet_mid_res

In [None]:
from tqdm.auto import tqdm

num_epochs = 10

# Build the trainable ControlNet from the UNet configuration and place it on
# the same device as the other models (instead of assuming a GPU).
model = MYControlNet(unet).to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)


In [None]:
import os

# Frozen components: move to the training device and switch to eval mode.
unet.to(device)
unet.eval()

vae.to(device)
vae.eval()

text_encoder.to(device)
text_encoder.eval()
text_encoder.requires_grad_(False)

# Only the ControlNet is optimized; the UNet stays frozen, so its parameters
# do not belong in the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# torch.save does not create missing folders.
os.makedirs("./checkpoints", exist_ok=True)

# The number of diffusion steps lives in the scheduler configuration.
num_train_timesteps = noise_sch.config.num_train_timesteps

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, (img_batch, contrl_batch, prompt_batch) in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        # Move the batch to the training device.
        images = img_batch.to(device)
        control_images = contrl_batch.to(device)
        # Keep the control image at the same resolution as the target image.
        control_images = torch.nn.functional.interpolate(control_images,
                                                         size=tuple(images.shape[-2:]),
                                                         mode="bilinear",
                                                         align_corners=False)

        # Everything produced by the frozen VAE / text encoder is computed
        # without building an autograd graph.
        with torch.no_grad():
            latent_dist = vae.encode(images).latent_dist
            target_latents = latent_dist.sample() * vae.config.scaling_factor

            # 1. Encode the conditioning text.
            text_inputs = tokenizer(
                              list(prompt_batch),
                              padding="max_length",
                              truncation=True,
                              max_length=tokenizer.model_max_length,
                              return_tensors="pt",
                              ).input_ids.to(device)
            encoder_hidden_states = text_encoder(text_inputs)[0]

        # 2. Draw one random timestep per sample.
        timesteps = torch.randint(0, num_train_timesteps,
                                  (target_latents.shape[0],), device=device).long()

        # 3. Add noise to the target latents.
        noise = torch.randn_like(target_latents)
        noisy_latents = noise_sch.add_noise(target_latents, noise, timesteps)

        # 4. ControlNet residuals (the model reads only `controlnet_cond`).
        controlnet_down_res, controlnet_mid_res = model(
                                sample=noisy_latents,
                                timestep=timesteps,
                                encoder_hidden_states=encoder_hidden_states,
                                controlnet_cond=control_images)

        # Print the residual shapes once, as a sanity check.
        if epoch == 0 and batch_idx == 0:
            for i, r in enumerate(controlnet_down_res):
                print(f"  Residual {i}: {r.shape}")
            print(f"  Mid residual: {controlnet_mid_res.shape}")

        # 5. Noise prediction of the frozen UNet, steered by the residuals.
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=controlnet_down_res,  # list of tensors
            mid_block_additional_residual=controlnet_mid_res,     # tensor
            return_dict=False)[0]  # take the output tensor, not the dict

        # 6. Loss: predicted noise against the noise that was actually added.
        loss = torch.nn.functional.mse_loss(model_pred, noise)

        # 7. Backpropagation.
        loss.backward()
        optimizer.step()

        # Log every 10 steps (inside the batch loop, not once per epoch).
        if (batch_idx + 1) % 10 == 0:
            print(f"Step {batch_idx+1}, Loss: {loss.item():.4f}")

    # Save a checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        checkpoint_path = f"./checkpoints/epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved to {checkpoint_path}")

# Printed once, after the last epoch.
print("Entrenamiento completado.")

# Save the final ControlNet weights.
torch.save(model.state_dict(), "my_controlnet_weights.pth")

# 4. Entrenamiento

In [None]:
from accelerate import Accelerator

# Accelerator handling mixed precision (fp16; bf16 is the alternative) and
# experiment logging (tensorboard; wandb is the alternative).
accelerator = Accelerator(mixed_precision="fp16",
                          log_with="tensorboard",
                          project_dir="./logs")

# Optimize only the parameters that still require gradients.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=1e-4,
)

# Cosine annealing of the learning rate, for smoother training.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

# Let the accelerator wrap everything that takes part in training.
model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler)
