In [14]:
%load_ext autoreload
%autoreload 2
%cd ~/dev/imaginedriving/

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/gasparyanartur/dev/imaginedriving


  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [15]:
from typing import Optional, Dict, Any, Union, Tuple
from diffusers.models.controlnet import ControlNetOutput, ControlNetModel
import torch
from src.data import read_image
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline, UNet2DConditionModel
from diffusers.loaders.lora import LoraLoaderMixin
from src.control_lora import ControlLoRAModel
import torch

In [42]:



class PeftCompatibleControlNet(ControlNetModel):
    def forward(
            self,
            sample: torch.FloatTensor,
            timestep: Union[torch.Tensor, float, int],
            encoder_hidden_states: torch.Tensor,
            controlnet_cond: torch.FloatTensor,
            conditioning_scale: float = 1.0,
            class_labels: Optional[torch.Tensor] = None,
            timestep_cond: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guess_mode: bool = False,
            return_dict: bool = True,
        ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
            """
            The [`ControlNetModel`] forward method.

            Args:
                sample (`torch.FloatTensor`):
                    The noisy input tensor.
                timestep (`Union[torch.Tensor, float, int]`):
                    The number of timesteps to denoise an input.
                encoder_hidden_states (`torch.Tensor`):
                    The encoder hidden states.
                controlnet_cond (`torch.FloatTensor`):
                    The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
                conditioning_scale (`float`, defaults to `1.0`):
                    The scale factor for ControlNet outputs.
                class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                    Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
                timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                    Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
                    timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
                    embeddings.
                attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                    An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                    is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                    negative values to the attention scores corresponding to "discard" tokens.
                added_cond_kwargs (`dict`):
                    Additional conditions for the Stable Diffusion XL UNet.
                cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                    A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
                guess_mode (`bool`, defaults to `False`):
                    In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                    you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
                return_dict (`bool`, defaults to `True`):
                    Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

            Returns:
                [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                    If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                    returned where the first element is the sample tensor.
            """
            # check channel order
            channel_order = self.config.controlnet_conditioning_channel_order

            if channel_order == "rgb":
                # in rgb order by default
                ...
            elif channel_order == "bgr":
                controlnet_cond = torch.flip(controlnet_cond, dims=[1])
            else:
                raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

            # prepare attention_mask
            if attention_mask is not None:
                attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # 1. time
            timesteps = timestep
            if not torch.is_tensor(timesteps):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = sample.device.type == "mps"
                if isinstance(timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(sample.device)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(sample.shape[0])

            t_emb = self.time_proj(timesteps)

            # timesteps does not contain any weights and will always return f32 tensors
            # but time_embedding might actually be running in fp16. so we need to cast here.
            # there might be better ways to encapsulate this.
            t_emb = t_emb.to(dtype=sample.dtype)

            emb = self.time_embedding(t_emb, timestep_cond)
            aug_emb = None

            if self.class_embedding is not None:
                if class_labels is None:
                    raise ValueError("class_labels should be provided when num_class_embeds > 0")

                if self.config.class_embed_type == "timestep":
                    class_labels = self.time_proj(class_labels)

                class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
                emb = emb + class_emb

            if self.config.addition_embed_type is not None:
                if self.config.addition_embed_type == "text":
                    aug_emb = self.add_embedding(encoder_hidden_states)

                elif self.config.addition_embed_type == "text_time":
                    if "text_embeds" not in added_cond_kwargs:
                        raise ValueError(
                            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                        )
                    text_embeds = added_cond_kwargs.get("text_embeds")
                    if "time_ids" not in added_cond_kwargs:
                        raise ValueError(
                            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                        )
                    time_ids = added_cond_kwargs.get("time_ids")
                    time_embeds = self.add_time_proj(time_ids.flatten())
                    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                    add_embeds = add_embeds.to(emb.dtype)
                    aug_emb = self.add_embedding(add_embeds)

            emb = emb + aug_emb if aug_emb is not None else emb

            # 2. pre-process
            sample = self.conv_in(sample)

            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
            sample = sample + controlnet_cond

            # 3. down
            down_block_res_samples = (sample,)
            for downsample_block in self.down_blocks:
                if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                    sample, res_samples = downsample_block(
                        hidden_states=sample,
                        temb=emb,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )
                else:
                    sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                down_block_res_samples += res_samples

            # 4. mid
            if self.mid_block is not None:
                if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                    sample = self.mid_block(
                        sample,
                        emb,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )
                else:
                    sample = self.mid_block(sample, emb)

            # 5. Control net blocks

            controlnet_down_block_res_samples = ()

            #controlnet_down_blocks = next((b for a, b in self.controlnet_down_blocks.named_children() if a == "modules_to_save"))["default"]
            controlnet_down_blocks = next((b for a, b in self.controlnet_down_blocks.named_children() if a == "original_module"))
            for down_block_res_sample, controlnet_block in zip(down_block_res_samples, controlnet_down_blocks):
                down_block_res_sample = controlnet_block(down_block_res_sample)
                controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = controlnet_down_block_res_samples

            mid_block_res_sample = self.controlnet_mid_block(sample)

            # 6. scaling
            if guess_mode and not self.config.global_pool_conditions:
                scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
                scales = scales * conditioning_scale
                down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
                mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
            else:
                down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
                mid_block_res_sample = mid_block_res_sample * conditioning_scale

            if self.config.global_pool_conditions:
                down_block_res_samples = [
                    torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
                ]
                mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

            if not return_dict:
                return (down_block_res_samples, mid_block_res_sample)

            return ControlNetOutput(
                down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
            )

In [17]:

model_id = "stabilityai/stable-diffusion-2-1"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16)
controlnet = PeftCompatibleControlNet.from_unet(unet)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    model_id, unet=unet, controlnet=controlnet, torch_dtype=torch.float16
)



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [18]:
pipe.unet

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_f

In [19]:
pipe.controlnet

PeftCompatibleControlNet(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (controlnet_cond_embedding): ControlNetConditioningEmbedding(
    (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (blocks): ModuleList(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (conv_out): Conv2d(256,

In [20]:
from typing import Union, Iterable
import re

def get_matching(model, patterns: Iterable[Union[re.Pattern, str]] = (".*",)):
    for i, pattern in enumerate(patterns):
        if isinstance(pattern, str):
            patterns[i] = re.compile(pattern)

    li = []
    for name, mod in model.named_modules():
        for pattern in patterns:
            if pattern.match(name):
                li.append((name, mod))
    return li


get_matching(pipe.unet, [r".*\.to_[qkv]", r".*\.to_out.0", r""])

[('',
  UNet2DConditionModel(
    (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (time_proj): Timesteps()
    (time_embedding): TimestepEmbedding(
      (linear_1): Linear(in_features=320, out_features=1280, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (down_blocks): ModuleList(
      (0): CrossAttnDownBlock2D(
        (attentions): ModuleList(
          (0-1): 2 x Transformer2DModel(
            (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
            (proj_in): Linear(in_features=320, out_features=320, bias=True)
            (transformer_blocks): ModuleList(
              (0): BasicTransformerBlock(
                (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
                (attn1): Attention(
                  (to_q): Linear(in_features=320, out_features=320, bias=False)
                  (to_k): Linear(in_features=320, out_features=320, bias=False)
       

In [21]:
unet_target_ranks = {
    "downblocks": {"attn": 4, "resnet": 4, "ff": 8, "proj": 8},
    "midblocks": {"attn": 8, "resnet": 8, "ff": 16, "proj": 16},
    "upblocks": {"attn": 8, "resnet": 8, "ff": 16, "proj": 16},
}

controlnet_target_ranks = {
    "downblocks": {"attn": 8, "resnet": 8, "ff": 16, "proj": 16},
    "midblocks": {"attn": 8, "resnet": 8, "ff": 16, "proj": 16},
}


def parse_target_ranks(target_ranks, prefix=r""):
    parsed_targets = {}

    for name, item in target_ranks.items():
        if not item:
            continue

        match name:
            case "":
                continue

            case "downblocks":
                assert isinstance(item, dict)
                parsed_targets.update(
                    parse_target_ranks(item, rf"{prefix}.*down_blocks")
                )

            case "midblocks":
                assert isinstance(item, dict)
                parsed_targets.update(
                    parse_target_ranks(item, rf"{prefix}.*mid_blocks")
                )

            case "upblocks":
                assert isinstance(item, dict)
                parsed_targets.update(
                    parse_target_ranks(item, rf"{prefix}.*up_blocks")
                )

            case "attn":
                assert isinstance(item, int)
                parsed_targets[f"{prefix}.*attn.*to_[kvq]"] = item
                parsed_targets[ rf"{prefix}.*attn.*to_out\.0"] = item


            case "resnet":
                assert isinstance(item, int)
                parsed_targets[rf"{prefix}.*resnets.*conv\d*"] = item
                parsed_targets[rf"{prefix}.*resnets.*time_emb_proj"] = item

            case "ff":
                assert isinstance(item, int)
                parsed_targets[rf"{prefix}.*ff\.net\.0\.proj"] = item
                parsed_targets[rf"{prefix}.*ff\.net\.2"] = item

            case "proj":
                assert isinstance(item, int)
                parsed_targets[rf"{prefix}.*attentions.*proj_in"] = item
                parsed_targets[rf"{prefix}.*attentions.*proj_out"] = item

            case "_":
                raise NotImplementedError(f"Unrecognized target: {name}")

    return parsed_targets


unet_ranks = parse_target_ranks(unet_target_ranks)
controlnet_ranks = parse_target_ranks(controlnet_target_ranks)

parsed_ranks = {"unet": unet_ranks, "controlnet": controlnet_ranks}
parsed_ranks

{'unet': {'.*down_blocks.*attn.*to_[kvq]': 4,
  '.*down_blocks.*attn.*to_out\\.0': 4,
  '.*down_blocks.*resnets.*conv\\d*': 4,
  '.*down_blocks.*resnets.*time_emb_proj': 4,
  '.*down_blocks.*ff\\.net\\.0\\.proj': 8,
  '.*down_blocks.*ff\\.net\\.2': 8,
  '.*down_blocks.*attentions.*proj_in': 8,
  '.*down_blocks.*attentions.*proj_out': 8,
  '.*mid_blocks.*attn.*to_[kvq]': 8,
  '.*mid_blocks.*attn.*to_out\\.0': 8,
  '.*mid_blocks.*resnets.*conv\\d*': 8,
  '.*mid_blocks.*resnets.*time_emb_proj': 8,
  '.*mid_blocks.*ff\\.net\\.0\\.proj': 16,
  '.*mid_blocks.*ff\\.net\\.2': 16,
  '.*mid_blocks.*attentions.*proj_in': 16,
  '.*mid_blocks.*attentions.*proj_out': 16,
  '.*up_blocks.*attn.*to_[kvq]': 8,
  '.*up_blocks.*attn.*to_out\\.0': 8,
  '.*up_blocks.*resnets.*conv\\d*': 8,
  '.*up_blocks.*resnets.*time_emb_proj': 8,
  '.*up_blocks.*ff\\.net\\.0\\.proj': 16,
  '.*up_blocks.*ff\\.net\\.2': 16,
  '.*up_blocks.*attentions.*proj_in': 16,
  '.*up_blocks.*attentions.*proj_out': 16},
 'controlnet':

In [22]:
list(unet_ranks.keys())

['.*down_blocks.*attn.*to_[kvq]',
 '.*down_blocks.*attn.*to_out\\.0',
 '.*down_blocks.*resnets.*conv\\d*',
 '.*down_blocks.*resnets.*time_emb_proj',
 '.*down_blocks.*ff\\.net\\.0\\.proj',
 '.*down_blocks.*ff\\.net\\.2',
 '.*down_blocks.*attentions.*proj_in',
 '.*down_blocks.*attentions.*proj_out',
 '.*mid_blocks.*attn.*to_[kvq]',
 '.*mid_blocks.*attn.*to_out\\.0',
 '.*mid_blocks.*resnets.*conv\\d*',
 '.*mid_blocks.*resnets.*time_emb_proj',
 '.*mid_blocks.*ff\\.net\\.0\\.proj',
 '.*mid_blocks.*ff\\.net\\.2',
 '.*mid_blocks.*attentions.*proj_in',
 '.*mid_blocks.*attentions.*proj_out',
 '.*up_blocks.*attn.*to_[kvq]',
 '.*up_blocks.*attn.*to_out\\.0',
 '.*up_blocks.*resnets.*conv\\d*',
 '.*up_blocks.*resnets.*time_emb_proj',
 '.*up_blocks.*ff\\.net\\.0\\.proj',
 '.*up_blocks.*ff\\.net\\.2',
 '.*up_blocks.*attentions.*proj_in',
 '.*up_blocks.*attentions.*proj_out']

In [23]:
from peft import get_peft_model, LoraConfig, inject_adapter_in_model


peft_unet_conf = LoraConfig(
    r=8,
    init_lora_weights="gaussian",
    target_modules="|".join(unet_ranks.keys()),
    rank_pattern=unet_ranks
)

peft_controlnet_conf = LoraConfig(
    r=8,
    init_lora_weights="gaussian",
    target_modules="|".join(controlnet_ranks.keys()),
    rank_pattern=controlnet_ranks,
    modules_to_save=["controlnet_down_blocks", "controlnet_mid_block", "controlnet_cond_embedding", ]
)

## Peft Interface

In [24]:
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16)
peft_unet = get_peft_model(unet, peft_unet_conf)
peft_unet.print_trainable_parameters()
peft_unet

trainable params: 6,965,760 || all params: 872,876,484 || trainable%: 0.7980235609142611


PeftModel(
  (base_model): LoraModel(
    (model): UNet2DConditionModel(
      (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_proj): Timesteps()
      (time_embedding): TimestepEmbedding(
        (linear_1): Linear(in_features=320, out_features=1280, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (down_blocks): ModuleList(
        (0): CrossAttnDownBlock2D(
          (attentions): ModuleList(
            (0-1): 2 x Transformer2DModel(
              (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
              (proj_in): Linear(
                in_features=320, out_features=320, bias=True
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=320, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
        

In [25]:
peft_controlnet = ControlNetModel.from_unet(peft_unet)
peft_controlnet = get_peft_model(controlnet, peft_controlnet_conf)
peft_controlnet.print_trainable_parameters()
peft_controlnet

RuntimeError: Error(s) in loading state_dict for ModuleList:
	Unexpected key(s) in state_dict: "0.attentions.0.proj_in.lora_A.default.weight", "0.attentions.0.proj_in.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "0.attentions.0.transformer_blocks.0.ff.net.2.lora_A.default.weight", "0.attentions.0.transformer_blocks.0.ff.net.2.lora_B.default.weight", "0.attentions.0.proj_out.lora_A.default.weight", "0.attentions.0.proj_out.lora_B.default.weight", "0.attentions.1.proj_in.lora_A.default.weight", "0.attentions.1.proj_in.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "0.attentions.1.transformer_blocks.0.ff.net.2.lora_A.default.weight", "0.attentions.1.transformer_blocks.0.ff.net.2.lora_B.default.weight", "0.attentions.1.proj_out.lora_A.default.weight", "0.attentions.1.proj_out.lora_B.default.weight", "0.resnets.0.conv1.lora_A.default.weight", "0.resnets.0.conv1.lora_B.default.weight", "0.resnets.0.time_emb_proj.lora_A.default.weight", "0.resnets.0.time_emb_proj.lora_B.default.weight", "0.resnets.0.conv2.lora_A.default.weight", "0.resnets.0.conv2.lora_B.default.weight", "0.resnets.1.conv1.lora_A.default.weight", "0.resnets.1.conv1.lora_B.default.weight", "0.resnets.1.time_emb_proj.lora_A.default.weight", "0.resnets.1.time_emb_proj.lora_B.default.weight", "0.resnets.1.conv2.lora_A.default.weight", "0.resnets.1.conv2.lora_B.default.weight", "1.attentions.0.proj_in.lora_A.default.weight", "1.attentions.0.proj_in.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "1.attentions.0.transformer_blocks.0.ff.net.2.lora_A.default.weight", "1.attentions.0.transformer_blocks.0.ff.net.2.lora_B.default.weight", "1.attentions.0.proj_out.lora_A.default.weight", "1.attentions.0.proj_out.lora_B.default.weight", "1.attentions.1.proj_in.lora_A.default.weight", "1.attentions.1.proj_in.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "1.attentions.1.transformer_blocks.0.ff.net.2.lora_A.default.weight", "1.attentions.1.transformer_blocks.0.ff.net.2.lora_B.default.weight", "1.attentions.1.proj_out.lora_A.default.weight", "1.attentions.1.proj_out.lora_B.default.weight", "1.resnets.0.conv1.lora_A.default.weight", "1.resnets.0.conv1.lora_B.default.weight", "1.resnets.0.time_emb_proj.lora_A.default.weight", "1.resnets.0.time_emb_proj.lora_B.default.weight", "1.resnets.0.conv2.lora_A.default.weight", "1.resnets.0.conv2.lora_B.default.weight", "1.resnets.1.conv1.lora_A.default.weight", "1.resnets.1.conv1.lora_B.default.weight", "1.resnets.1.time_emb_proj.lora_A.default.weight", "1.resnets.1.time_emb_proj.lora_B.default.weight", "1.resnets.1.conv2.lora_A.default.weight", "1.resnets.1.conv2.lora_B.default.weight", "2.attentions.0.proj_in.lora_A.default.weight", "2.attentions.0.proj_in.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "2.attentions.0.transformer_blocks.0.ff.net.2.lora_A.default.weight", "2.attentions.0.transformer_blocks.0.ff.net.2.lora_B.default.weight", "2.attentions.0.proj_out.lora_A.default.weight", "2.attentions.0.proj_out.lora_B.default.weight", "2.attentions.1.proj_in.lora_A.default.weight", "2.attentions.1.proj_in.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_q.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_q.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_k.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_k.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_v.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_v.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn1.to_out.0.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_q.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_q.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_k.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_k.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_v.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_v.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.attn2.to_out.0.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.ff.net.0.proj.lora_B.default.weight", "2.attentions.1.transformer_blocks.0.ff.net.2.lora_A.default.weight", "2.attentions.1.transformer_blocks.0.ff.net.2.lora_B.default.weight", "2.attentions.1.proj_out.lora_A.default.weight", "2.attentions.1.proj_out.lora_B.default.weight", "2.resnets.0.conv1.lora_A.default.weight", "2.resnets.0.conv1.lora_B.default.weight", "2.resnets.0.time_emb_proj.lora_A.default.weight", "2.resnets.0.time_emb_proj.lora_B.default.weight", "2.resnets.0.conv2.lora_A.default.weight", "2.resnets.0.conv2.lora_B.default.weight", "2.resnets.1.conv1.lora_A.default.weight", "2.resnets.1.conv1.lora_B.default.weight", "2.resnets.1.time_emb_proj.lora_A.default.weight", "2.resnets.1.time_emb_proj.lora_B.default.weight", "2.resnets.1.conv2.lora_A.default.weight", "2.resnets.1.conv2.lora_B.default.weight", "3.resnets.0.conv1.lora_A.default.weight", "3.resnets.0.conv1.lora_B.default.weight", "3.resnets.0.time_emb_proj.lora_A.default.weight", "3.resnets.0.time_emb_proj.lora_B.default.weight", "3.resnets.0.conv2.lora_A.default.weight", "3.resnets.0.conv2.lora_B.default.weight", "3.resnets.1.conv1.lora_A.default.weight", "3.resnets.1.conv1.lora_B.default.weight", "3.resnets.1.time_emb_proj.lora_A.default.weight", "3.resnets.1.time_emb_proj.lora_B.default.weight", "3.resnets.1.conv2.lora_A.default.weight", "3.resnets.1.conv2.lora_B.default.weight". 

In [None]:
peft_controlnet.save_pretrained("save_controlnet")

In [None]:
pipe_peft = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    model_id, unet=peft_unet, controlnet=peft_controlnet, torch_dtype=torch.float16
)



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:

im1 = read_image("reference/pandaset-01/renders/0m/01.jpg").to(dtype=torch.float32, device=torch.device("cuda"))
im2 = read_image("reference/pandaset-01/renders/2m/01.jpg").to(dtype=torch.float32, device=torch.device("cuda"))
pipe_peft(prompt="", image=im1, mask_image=im2)

AssertionError: 

## Inject interface

In [43]:
unet_injected = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16)
controlnet_injected = PeftCompatibleControlNet.from_unet(unet_injected)

In [27]:
unet_injected = inject_adapter_in_model(peft_unet_conf, unet_injected)

In [44]:
controlnet_injected = inject_adapter_in_model(peft_controlnet_conf, controlnet_injected)

In [None]:
from diffusers import StableDiffusionImg2ImgPipeline

pipe_injected_base = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, unet=unet_injected.to(torch.float32), torch_dtype=torch.float32)
pipe_injected_base.to("cuda")



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

StableDiffusionImg2ImgPipeline {
  "_class_name": "StableDiffusionImg2ImgPipeline",
  "_diffusers_version": "0.27.2",
  "_name_or_path": "stabilityai/stable-diffusion-2-1",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [39]:
import torchvision.transforms.v2 as tvtf
transforms = tvtf.Compose((tvtf.CenterCrop((1024, 1024)), tvtf.Resize((512, 512))))
im1 = read_image("reference/pandaset-01/renders/0m/01.jpg", tf_pipeline=transforms).to(dtype=torch.float32, device=torch.device("cuda"))[None, ...]
im2 = read_image("reference/pandaset-01/renders/2m/01.jpg", tf_pipeline=transforms).to(dtype=torch.float32, device=torch.device("cuda"))[None, ...]

pipe_injected_base(prompt="", image=im1)

NameError: name 'pipe_injected_base' is not defined

In [45]:
pipe_injected = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    model_id, unet=unet_injected, controlnet=controlnet_injected, torch_dtype=torch.float32
)
pipe_injected.to("cuda")

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 194.00 MiB. GPU 0 has a total capacity of 23.67 GiB of which 100.06 MiB is free. Including non-PyTorch memory, this process has 22.51 GiB memory in use. Of the allocated memory 22.02 GiB is allocated by PyTorch, and 176.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:

help(pipe_injected.controlnet.controlnet_down_blocks)

Help on ModulesToSaveWrapper in module peft.utils.other object:

class ModulesToSaveWrapper(torch.nn.modules.module.Module)
 |  ModulesToSaveWrapper(module_to_save, adapter_name)
 |  
 |  Method resolution order:
 |      ModulesToSaveWrapper
 |      torch.nn.modules.module.Module
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, module_to_save, adapter_name)
 |      Initialize internal Module state, shared by both nn.Module and ScriptModule.
 |  
 |  enable_adapters(self, enabled: bool)
 |      Toggle the enabling and disabling of adapters
 |      
 |      Takes care of setting the requires_grad flag for the adapter weights.
 |      
 |      Args:
 |          enabled (bool): True to enable adapters, False to disable adapters
 |  
 |  forward(self, *args, **kwargs)
 |      Define the computation performed at every call.
 |      
 |      Should be overridden by all subclasses.
 |      
 |      .. note::
 |          Although the recipe for forward pass needs t

In [None]:
pipe_injected.controlnet.controlnet_down_blocks

ModulesToSaveWrapper(
  (original_module): ModuleList(
    (0-3): 4 x Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (4-6): 3 x Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    (7-11): 5 x Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
  )
  (modules_to_save): ModuleDict(
    (default): ModuleList(
      (0-3): 4 x Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (4-6): 3 x Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (7-11): 5 x Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

In [46]:
import torchvision.transforms.v2 as tvtf
transforms = tvtf.Compose((tvtf.CenterCrop((1024, 1024)), tvtf.Resize((512, 512))))
im1 = read_image("reference/pandaset-01/renders/0m/01.jpg", tf_pipeline=transforms).to(dtype=torch.float32, device=torch.device("cuda"))[None, ...]
im2 = read_image("reference/pandaset-01/renders/2m/01.jpg", tf_pipeline=transforms).to(dtype=torch.float32, device=torch.device("cuda"))[None, ...]

pipe_injected(prompt="", image=im1, control_image=im2)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [None]:
im1.shape, im2.shape

(torch.Size([1, 3, 1080, 1920]), torch.Size([1, 3, 1080, 1920]))