<a href="https://colab.research.google.com/github/carinunez/ControlNet/blob/main/Paper_ControlNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install diffusers transformers accelerate opencv-python torch torchvision datasets

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [29]:
from diffusers import StableDiffusionPipeline

# Load the Stable Diffusion v1.5 pipeline and move all of its components to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
).to("cuda")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from diffusers.schedulers import DDIMScheduler
from types import MethodType
import os
from IPython.display import display
import copy

# Clases

In [23]:
class ConditionEncoder(nn.Module):
    """Small convolutional encoder that turns a condition image into a feature map.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    resolution, so an even-sized H x W input becomes H/16 x W/16 (e.g. 512 -> 32),
    while the channels grow 16 -> 32 -> 64 -> 128. The result is then resized
    bilinearly to ``out_size``.

    Args:
        in_channels: Number of channels of the condition image (default 3).
        out_size: ``(height, width)`` of the returned feature map. The default
            ``(64, 64)`` matches the spatial size of the ControlBlock input.
    """

    def __init__(self, in_channels=3, out_size=(64, 64)):
        super().__init__()
        self.out_size = tuple(out_size)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=4, stride=2, padding=1),  # H -> H/2
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # H/2 -> H/4
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # H/4 -> H/8
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # H/8 -> H/16
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` of shape [B, in_channels, H, W] into [B, 128, *out_size]."""
        x = self.encoder(x)  # [B, 128, H/16, W/16]
        # Resize so the features line up spatially with the tensor they are added to.
        x = F.interpolate(x, size=self.out_size, mode='bilinear', align_corners=False)
        return x

In [24]:
class ZeroConv(nn.Conv2d):
    """``nn.Conv2d`` whose weight and bias are initialised to zero.

    ``reset_parameters`` is the hook ``nn.Conv2d`` calls at construction time,
    so overriding it makes a freshly built layer output zeros.
    """

    def reset_parameters(self):
        nn.init.zeros_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class ControlBlock(nn.Module):
    """A frozen block plus a trainable copy, joined through zero convolutions.

    Computes ``locked(x) + zero_conv2(trainable(x + zero_conv1(condition)))``.
    Because both zero convolutions start at zero, a new ControlBlock returns
    exactly ``locked(x)`` until training updates them.

    Args:
        base_block: Module to wrap. It is switched to eval mode and frozen in
            place (the caller's module is modified), then deep-copied to
            create the trainable branch.
        in_channels: Channels of the ``condition`` feature map.
        out_channels: Channels of ``x`` and of the block output (default 320,
            the value that was previously hard-coded).
    """

    def __init__(self, base_block, in_channels, out_channels=320):
        super().__init__()
        self.locked_block = base_block.eval()
        for p in self.locked_block.parameters():
            p.requires_grad = False

        self.trainable_block = copy.deepcopy(base_block)
        # deepcopy preserves requires_grad, and base_block has just been frozen
        # (it may also have been frozen by the caller), so gradients must be
        # re-enabled explicitly or the "trainable" copy would never learn.
        for p in self.trainable_block.parameters():
            if p.is_floating_point():
                p.requires_grad = True

        self.zero_conv1 = ZeroConv(in_channels, out_channels, kernel_size=1)
        self.zero_conv2 = ZeroConv(out_channels, out_channels, kernel_size=1)

    def forward(self, x, condition, *block_args, **block_kwargs):
        """Run both branches and add the control signal to the frozen output.

        Args:
            x: Input features; must have the same shape as ``zero_conv1(condition)``.
            condition: Condition features with ``in_channels`` channels.
            *block_args, **block_kwargs: Passed unchanged to both the locked
                and the trainable block, for base blocks whose ``forward``
                needs more than one tensor (e.g. a time embedding).

        Raises:
            AssertionError: If the projected condition and ``x`` differ in shape.
        """
        zc1 = self.zero_conv1(condition)
        assert zc1.shape == x.shape, f"Shape mismatch: {zc1.shape} vs {x.shape}"

        # Reuse the projection computed above instead of running zero_conv1 twice.
        x_cond = x + zc1
        y_trainable = self.trainable_block(x_cond, *block_args, **block_kwargs)
        y_locked = self.locked_block(x, *block_args, **block_kwargs)
        y_final = y_locked + self.zero_conv2(y_trainable)
        return y_final

In [25]:
class ControlNetWrapper(nn.Module):
    """Drop-in replacement for a UNet that first applies control blocks.

    ``forward`` encodes the visual condition, runs ``sample`` through every
    control block in order, and hands the result to the frozen UNet.

    Args:
        unet: The UNet to wrap. It is put in eval mode and always frozen.
        control_blocks: Iterable of modules called as ``block(x, cond_feat)``.
            They are registered in an ``nn.ModuleList`` so ``.to()``,
            ``.parameters()`` and ``.state_dict()`` include them.
        condition_encoder: Module that maps the condition image to features.
            It is put in eval mode.
        freeze: When True (default, the previous behaviour) the control blocks
            and the condition encoder are frozen as well. Pass False, or
            re-enable ``requires_grad`` afterwards, if the wrapper is built
            before training: the wrapped modules are the caller's own objects,
            so freezing here also freezes them for the training loop.
    """

    def __init__(self, unet, control_blocks, condition_encoder, freeze=True):
        super().__init__()
        self.unet = unet.eval()  # frozen UNet
        # A plain Python list would hide the blocks from nn.Module bookkeeping.
        self.control_blocks = nn.ModuleList(control_blocks)
        self.condition_encoder = condition_encoder.eval()

        for p in self.unet.parameters():
            p.requires_grad = False
        if freeze:
            for p in self.control_blocks.parameters():
                p.requires_grad = False
            for p in self.condition_encoder.parameters():
                p.requires_grad = False

    def __getattr__(self, name):
        """Fall back to the wrapped UNet for attributes the wrapper lacks.

        Code that treats this object as the UNet (reading e.g. ``config`` or
        ``conv_in``) would otherwise raise AttributeError.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            unet = super().__getattr__("unet")
            return getattr(unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, condition=None, **kwargs):
        """Apply the control blocks to ``sample`` and run the frozen UNet.

        Args:
            sample: The noisy latent z_t.
            timestep: Diffusion timestep, forwarded to the UNet.
            encoder_hidden_states: Text embeddings, forwarded to the UNet.
            condition: Condition image tensor (e.g. a Canny edge map); required.
            **kwargs: Forwarded unchanged to the UNet.

        Raises:
            ValueError: If ``condition`` is None.
        """
        if condition is None:
            raise ValueError("Falta la condición visual en el forward")

        # Encode the visual condition once; every block receives the same features.
        cond_feat = self.condition_encoder(condition)

        # NOTE(review): the blocks run on `sample` before the UNet's own input
        # layers. ControlBlock asserts that its projected condition has the
        # same shape as `x`, so confirm the channel/batch sizes of `sample`
        # match what the blocks were built for.
        x = sample
        for cb in self.control_blocks:
            x = cb(x, cond_feat)

        return self.unet(x, timestep, encoder_hidden_states=encoder_hidden_states, **kwargs)

In [26]:
# DataLoader
class ControlNetDataset(torch.utils.data.Dataset):
    """Torch dataset yielding ``(image, condition, caption)`` triples.

    Args:
        dataset: Indexable collection whose items have the keys ``"image"``,
            ``"condition"`` and ``"caption"``.
        transform: Callable applied to both the image and the condition.
            When omitted, the module-level ``transform`` is looked up at item
            access time (the original behaviour), so that global must exist
            before the first item is read.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Prefer the injected transform; fall back to the notebook-level global.
        apply = self.transform if self.transform is not None else transform
        img = apply(item["image"])
        cond = apply(item["condition"])
        caption = item["caption"]
        return img, cond, caption

In [None]:
#from huggingface_hub import login

In [None]:
#login(token="TU_TOKEN_AQUI")

# Datos
Los datos están disponibles en <https://github.com/odegeasslbc/FastGAN-pytorch?tab=readme-ov-file>.

In [8]:
#dataset = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
#dataset = load_dataset("lambdalabs/pokemon", split="train")

FileNotFoundError: Couldn't find a dataset script at /content/lambdalabs/pokemon/pokemon.py or any data file in the same directory. Couldn't find 'lambdalabs/pokemon' on the Hugging Face Hub either: FileNotFoundError: Dataset 'lambdalabs/pokemon' doesn't exist on the Hub. If the repo is private or gated, make sure to log in with `huggingface-cli login`.

In [8]:
# Mount Google Drive inside the Colab VM at /content/drive so the image folder can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
# Folder on the mounted Drive that holds the training images.
image_dir = "/content/drive/MyDrive/Modelos Generativos Profundos/img_pkmn"

# Alphabetically sorted full paths of every PNG/JPEG file in the folder.
image_paths = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if name.lower().endswith((".png", ".jpg", ".jpeg"))
)

# One record per image: the decoded RGB picture plus a placeholder caption.
examples = [
    dict(image=Image.open(path).convert("RGB"), text="a colorful pokemon")
    for path in image_paths
]

dataset = Dataset.from_list(examples)

In [10]:
def add_caption(example):
    """Return the fixed caption column value; ``example`` itself is not read."""
    caption = "a cute colorful pokemon creature"
    return dict(text=caption)

# Replace every example's "text" field with the caption returned by add_caption.
dataset = dataset.map(add_caption)

Map:   0%|          | 0/833 [00:00<?, ? examples/s]

In [11]:
# Canny
def apply_canny(image: Image.Image) -> Image.Image:
    """Return a 3-channel Canny edge map of ``image`` as a PIL image.

    The image is converted to RGB, passed to ``cv2.Canny`` with thresholds
    100 and 200, and the single-channel result is copied into three
    identical channels.
    """
    rgb = np.array(image.convert("RGB"))
    edge_map = cv2.Canny(rgb, 100, 200)
    # Same edge map in each of the three channels.
    edge_rgb = np.stack((edge_map, edge_map, edge_map), axis=-1)
    return Image.fromarray(edge_rgb)

In [12]:
# Preprocesamiento
def preprocess(example):
    """Build one training record from a raw example.

    Resizes ``example["image"]`` to 512x512, derives its Canny edge map, and
    returns the resized image, the caption taken from ``example["text"]``,
    and the edge map under the key ``"condition"``.
    """
    resized = example["image"].resize((512, 512))
    edges = apply_canny(resized)
    return dict(image=resized, caption=example["text"], condition=edges)

In [13]:
# Run preprocess on every example: resized image, "caption" key, and Canny "condition".
dataset = dataset.map(preprocess)

Map:   0%|          | 0/833 [00:00<?, ? examples/s]

# Red

In [14]:
# Transform shared by images and conditions: resize to 512x512, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((512, 512)), transforms.ToTensor()]
)

# Shuffled loader that yields one (image, condition, caption) example per batch.
train_dataloader = DataLoader(
    dataset=ControlNetDataset(dataset),
    batch_size=1,
    shuffle=True,
)

In [15]:
# Parche a stable diffusion pipeline
def patched_call(self, prompt=None, condition=None, **kwargs):
    """Pipeline ``__call__`` replacement that delivers ``condition`` to the UNet.

    The original pipeline call has no ``condition`` argument to forward to
    ``self.unet``. For the duration of the call, ``self.unet.forward`` is
    therefore wrapped so each UNet invocation receives ``condition=...``,
    and the wrapper is removed again afterwards, even on error.

    Args:
        prompt: Text prompt, forwarded to the original call.
        condition: Condition tensor for the UNet. When None the original call
            runs untouched. When given, ``self.unet.forward`` must accept a
            ``condition`` keyword.
        **kwargs: Forwarded unchanged to the original call.

    Returns:
        Whatever ``self.__class__.original_call`` returns.
    """
    original_call = self.__class__.original_call
    if condition is None:
        return original_call(self, prompt=prompt, **kwargs)

    unet = self.unet
    # Remember whether `forward` was already overridden on the instance so the
    # exact previous state can be restored.
    had_instance_forward = "forward" in vars(unet)
    previous_forward = unet.forward

    def forward_with_condition(*args, **forward_kwargs):
        forward_kwargs.setdefault("condition", condition)
        return previous_forward(*args, **forward_kwargs)

    # NOTE(review): if the pipeline feeds the UNet a larger batch than
    # `condition` (e.g. duplicated latents for guidance), the condition may
    # need to be repeated to match; confirm against the UNet wrapper.
    unet.forward = forward_with_condition
    try:
        return original_call(self, prompt=prompt, **kwargs)
    finally:
        if had_instance_forward:
            unet.forward = previous_forward
        else:
            del unet.forward

In [16]:
# DDIM scheduler built from the scheduler config shipped with the SD v1.5 checkpoint.
scheduler = DDIMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
)

# Handles to the pipeline components used directly by the training loop.
vae, text_encoder, tokenizer = pipe.vae, pipe.text_encoder, pipe.tokenizer
unet = pipe.unet.eval()  # inference mode; frozen just below

# Freeze the U-Net: none of its weights receive gradients.
for p in unet.parameters():
    p.requires_grad = False

In [17]:
# Patch Stable Diffusion so the pipeline call accepts the visual condition.
# Save the stock __call__ only once: re-running this cell would otherwise store
# patched_call as "original_call" and make it call itself forever.
if not hasattr(StableDiffusionPipeline, "original_call"):
    StableDiffusionPipeline.original_call = StableDiffusionPipeline.__call__
# Install the plain function so Python binds it to whichever pipeline instance
# is called, instead of pinning the class-level __call__ to `pipe`.
StableDiffusionPipeline.__call__ = patched_call

In [30]:
# Minimal condition encoder.
condition_encoder = ConditionEncoder().to("cuda")

# One ControlBlock per ResNet of the first down block (index 0, then index 1),
# each fed with the encoder's 128-channel features.
control_block1, control_block2 = [
    ControlBlock(
        base_block=pipe.unet.down_blocks[0].resnets[i],
        in_channels=128,
    ).to("cuda")
    for i in (0, 1)
]

In [31]:
# Bundle the frozen U-Net, both control blocks and the condition encoder.
wrapper = ControlNetWrapper(
    unet,
    [control_block1, control_block2],
    condition_encoder,
)

# From here on the pipeline calls the wrapper wherever it used the U-Net.
pipe.unet = wrapper

In [44]:
# Release cached GPU memory that PyTorch's allocator is holding but not using.
torch.cuda.empty_cache()

In [45]:
# Optimizer
# ControlNetWrapper.__init__ sets requires_grad=False on every control-block
# parameter, and the wrapper is built before this cell. Gradients therefore have
# to be switched back on for the parts that are meant to learn; the locked
# copies inside each block stay frozen.
trainable_params = []
for cb in (control_block1, control_block2):
    for module in (cb.trainable_block, cb.zero_conv1, cb.zero_conv2):
        for p in module.parameters():
            p.requires_grad = True
            trainable_params.append(p)

optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

# Simple training loop
step = 0
for epoch in range(3):
    for imgs, conds, captions in tqdm(train_dataloader):
        imgs, conds = imgs.to("cuda"), conds.to("cuda")

        # The VAE and text encoder are not optimised, so run them without
        # autograd to avoid storing their activations on the GPU.
        with torch.no_grad():
            # 1. Encode the image into a clean latent. ToTensor yields values
            #    in [0, 1]; the Stable Diffusion VAE expects inputs in [-1, 1].
            latents = vae.encode(imgs * 2.0 - 1.0).latent_dist.sample() * 0.18215

            # 4. Encode the text prompt.
            text_inputs = tokenizer(
                list(captions),
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).to("cuda")
            text_embeddings = text_encoder(text_inputs.input_ids)[0]

        # 2. Sample the noise and a random timestep per example.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()

        # 3. Add the noise to the latents.
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # 5. Encode the visual condition.
        # NOTE(review): condition_encoder is not part of the optimizer, so it
        # keeps its initial weights; confirm whether it should be trained too.
        cond_feat = condition_encoder(conds)

        # 6. Run the control blocks. `pipe.unet` is the wrapper at this point,
        #    so take conv_in from the underlying U-Net directly.
        x = unet.conv_in(noisy_latents)
        # NOTE(review): the base blocks are the U-Net's own ResNet blocks;
        # confirm their forward accepts a single tensor (they presumably also
        # expect a time embedding).
        x = control_block1(x, cond_feat)
        x = control_block2(x, cond_feat)

        # 7. Predict the noise with the frozen U-Net.
        # NOTE(review): `x` is already the output of unet.conv_in, and calling
        # the full U-Net presumably applies conv_in again; verify the channel
        # counts, or hook the control blocks inside the U-Net instead.
        noise_pred = unet(
            x, timesteps, encoder_hidden_states=text_embeddings
        ).sample

        # 8. Loss between predicted and true noise.
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"[Epoch {epoch}] Step {step} - Loss: {loss.item()}")

        # Preview a generated sample every 250 steps.
        if step % 250 == 0:
            with torch.no_grad():
                pipe.unet = wrapper  # make sure the pipeline uses the wrapper
                img_out = pipe(prompt=captions[0], condition=conds[0].unsqueeze(0)).images[0]
                display(img_out)

        del latents, x, cond_feat, noise_pred
        torch.cuda.empty_cache()
        step += 1  # advance the global step counter

  0%|          | 0/833 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 126.12 MiB is free. Process 17477 has 14.62 GiB memory in use. Of the allocated memory 14.25 GiB is allocated by PyTorch, and 250.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Persist the condition encoder and both control blocks in a single checkpoint file.
torch.save(
    dict(
        encoder=condition_encoder.state_dict(),
        cb1=control_block1.state_dict(),
        cb2=control_block2.state_dict(),
    ),
    "controlnet_minimal.pth",
)

# Inferencia

In [None]:
# Trade some speed for lower peak GPU memory during sampling.
pipe.enable_attention_slicing()

# Use the first preprocessed example as the test input.
example = dataset[0]
cond = transform(example["condition"]).unsqueeze(0).to("cuda")
caption = example["caption"]

# Call the pipeline as usual, now also passing the visual condition.
with torch.no_grad():
    image = pipe(prompt=caption, condition=cond).images[0]

# Render inline: PIL's Image.show() launches an external viewer, which does not
# display anything inside a notebook.
display(image)
