In [None]:
!pip install transformers diffusers accelerate peft

In [1]:
def print_trainable_parameters(model):
    """
    Print the number of trainable parameters in the model.

    Iterates over ``model.named_parameters()``, summing ``numel()`` for every
    parameter and, separately, for those with ``requires_grad`` set, then
    prints both totals and the trainable percentage.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad``.

    Returns:
        None. The summary is written to stdout.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A model without any parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


In [2]:
from huggingface_hub import hf_hub_download

# Resolve local file paths for the two checkpoints used below:
# the Realistic Vision inpainting weights and the standalone VAE weights.
realistic_vision_path = hf_hub_download(
    repo_id="SG161222/Realistic_Vision_V5.1_noVAE",
    filename="Realistic_Vision_V5.1-inpainting.safetensors",
)
vae_path = hf_hub_download(
    repo_id="stabilityai/sd-vae-ft-mse-original",
    filename="vae-ft-mse-840000-ema-pruned.safetensors",
)

In [3]:
import sys
# Make the project's local modules under ../src importable from this notebook.
# Guarded so that re-running the cell does not pile up duplicate entries.
if '../src' not in sys.path:
    sys.path.append('../src')

In [4]:
from config import DatasetConfig, Config, ModelConfig, WandbConfig, EvaluationConfig, TrainConfig

import os

# Credentials must never be committed inside a notebook: read the Roboflow key
# from the environment instead. Any key that was previously hardcoded here
# should be treated as leaked and rotated.
roboflow_api_key = os.environ.get('ROBOFLOW_API_KEY')
if not roboflow_api_key:
    raise RuntimeError(
        "Set the ROBOFLOW_API_KEY environment variable before running this cell."
    )

# Dataset settings: Roboflow workspace/project/version, the local data folder,
# and the image size passed to the dataset config.
dataset_config = DatasetConfig(
    roboflow_api_key=roboflow_api_key,
    roboflow_workspace='arked',
    project_name='facades-flzke',
    dataset_version=11,
    data_root='facades_data',
    image_size=512,
)

# Model weights: both paths were resolved by the download cell above.
model_config = ModelConfig(model_path=realistic_vision_path, vae_path=vae_path)

wandb_config = WandbConfig(project_name='facades')

# Prompts handed to the evaluation config.
eval_config = EvaluationConfig(prompts=['white facade', 'brick facade'])

# Training settings; the checkpoint folder name is derived from the W&B project name.
train_config = TrainConfig(
    checkpoint_folder=wandb_config.project_name + "_checkpoints",
    train_batch_size=4,
    unet_lr=1e-4,
    text_encoder_lr=1e-4,
    scheduler_num_cycles=4,
    total_steps=1000,
)

# Bundle every section into the single top-level Config object.
config = Config(
    dataset=dataset_config,
    model=model_config,
    wandb=wandb_config,
    eval=eval_config,
    train=train_config,
)

In [5]:
from model import get_models

# Load all model components in one call; get_models returns six objects,
# unpacked here in the order it yields them.
(
    text_encoder,
    vae,
    unet,
    tokenizer,
    noise_scheduler,
    placeholder_token_ids,
) = get_models(
    config.model.model_path,
    config.model.vae_path,
    device=config.device,
    load_from_safetensor=True,
)

loading VAE...
loading model...


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [9]:
from peft import LoraConfig, LoraModel

UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "to_out.0", "ff.net.0.proj"] #, "proj_in", "conv1", "conv2"]

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=UNET_TARGET_MODULES,
    lora_dropout=0.1,
    bias='none',
)

In [10]:
unet_lora = LoraModel(unet, config, "default")

In [11]:
# Print trainable vs. total parameter counts for the LoRA-wrapped UNet.
print_trainable_parameters(unet_lora)

trainable params: 2492928 || all params: 862028292 || trainable%: 0.28919329250970804
