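## Requirements & Setup

Enable module autoreloading, make the project's `src` package importable, and install the dependencies listed in `requirements.txt`.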

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Later cells import project code as a package (`from src.config import ...`), which only
# resolves when the directory *containing* `src/` is on sys.path — appending `../src`
# itself does not make `src.*` importable. Walk up from the working directory to find
# that project root, falling back to the working directory if none is found.
project_root = next(
    (candidate for candidate in (Path.cwd(), *Path.cwd().parents) if (candidate / "src").is_dir()),
    Path.cwd(),
)

# Guard so re-running the cell does not keep adding duplicate sys.path entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

In [2]:
!pip install -q -r requirements.txt


[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


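## Config & Training

Download the base inpainting checkpoint and the VAE from the Hugging Face Hub, assemble the run configuration, then launch training.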

In [7]:
from huggingface_hub import hf_hub_download

# Base inpainting checkpoint. An alternative that was tried previously is
# repo "SG161222/Realistic_Vision_V6.0_B1_noVAE",
# file "Realistic_Vision_V6.0_NV_B1_inpainting.safetensors".
BASE_MODEL_REPO = "SG161222/Realistic_Vision_V5.1_noVAE"
BASE_MODEL_FILE = "Realistic_Vision_V5.1-inpainting.safetensors"

# VAE weights, downloaded separately (presumably because the "noVAE" checkpoint above
# does not bundle one — confirm against the model card).
VAE_REPO = "stabilityai/sd-vae-ft-mse-original"
VAE_FILE = "vae-ft-mse-840000-ema-pruned.safetensors"

# hf_hub_download returns the local path of the (cached) file.
realistic_vision_path = hf_hub_download(repo_id=BASE_MODEL_REPO, filename=BASE_MODEL_FILE)
vae_path = hf_hub_download(repo_id=VAE_REPO, filename=VAE_FILE)

In [10]:
import os
from getpass import getpass

from src.config import DatasetConfig, Config, ModelConfig, WandbConfig, EvaluationConfig, TrainConfig, LoraConfig, PromptConfig

# Never hardcode credentials in a notebook (they leak through git history and shared
# copies). Read the Roboflow key from the environment, falling back to an interactive
# prompt. Any key previously committed in this cell should be treated as leaked and rotated.
roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")

dataset_config = DatasetConfig(
    roboflow_api_key=roboflow_api_key,
    roboflow_workspace='arked',
    project_name='kvist_windows',
    dataset_version=7,
    image_size=512,
    normalize_images=False,
    scaling_pixels=25,
)

# Paths come from the Hugging Face download cell above.
model_config = ModelConfig(
    model_path=realistic_vision_path,
    vae_path=vae_path,
)

wandb_config = WandbConfig(
    project_name='kvist_windows',
    entity_name='maidacundo',
    run_name='test_hough_loss',
)

eval_config = EvaluationConfig(
    prompts=['kvist windows'],
    eval_epochs=20,
)

train_config = TrainConfig(
    # Checkpoints are grouped per W&B project, e.g. "kvist_windows_checkpoints".
    checkpoint_folder=wandb_config.project_name + "_checkpoints",
    train_batch_size=4,
    unet_lr=3e-4,
    text_encoder_lr=3e-4,
    learning_rate=1e-3,
    scheduler_num_cycles=2,
    lora_total_steps=2000,
    ti_total_steps=1000,
    scheduler_warmup_steps=100,
    criterion='mlsd',
)

# NOTE(review): lora_config is constructed but not passed to Config below, so these
# rank/alpha values may never reach training. Confirm Config's fields and wire it in
# (e.g. `lora=lora_config`) if supported.
lora_config = LoraConfig(
    rank=8,
    alpha=1,
)

config = Config(
    dataset=dataset_config,
    model=model_config,
    wandb=wandb_config,
    eval=eval_config,
    train=train_config,
)

In [None]:
# Launch training with the `config` object assembled in the configuration cell.
# NOTE(review): `lora_config` from that cell is not part of `config`; confirm whether
# train() picks up LoRA settings some other way before relying on its rank/alpha.
from src.training import train

train(config)