# SCEdit-pytorch
This is an implementation of [SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing](https://scedit.github.io/) by [mkshing](https://twitter.com/mk1stats).

- Code: https://github.com/mkshing/scedit-pytorch


## **Setup**

In [None]:
!nvidia-smi
!git clone https://github.com/mkshing/scedit-pytorch.git
!pip install -r scedit-pytorch/requirements.txt
!pip install bitsandbytes
!pip install -U xformers torchvision --index-url https://download.pytorch.org/whl/cu121

In [None]:
#@markdown **(Optional) Log in to Hugging Face for `push_to_hub`**
from huggingface_hub import login
login()

In [None]:

# @markdown **(Optional) Log in to wandb**<br> If you don't use wandb for logging, remove `--report_to="wandb"` from the training command.
!pip install wandb
!wandb login


## **Run SC-Tuner**

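This example uses the five dog images commonly used in DreamBooth examples. You can download them manually from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ), or run the cell below, which fetches the `diffusers/dog-example` dataset from the Hugging Face Hub into `./dog`.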

In [None]:
#@title **Dataset**
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
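
Before training, it can help to confirm that the download actually produced the expected images. The following is a minimal sketch using only the standard library. The helper name `list_images` and the set of file extensions are assumptions made for illustration and are not part of scedit-pytorch.

```python
from pathlib import Path

# Assumed set of image extensions; adjust if your dataset uses others.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def list_images(folder):
    """Hypothetical helper: return image files directly inside `folder`, sorted by name."""
    root = Path(folder)
    if not root.is_dir():
        # A missing folder is reported as "no images" rather than raising.
        return []
    return sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


images = list_images("./dog")
print(f"Found {len(images)} image(s) in ./dog")
for path in images:
    print(" -", path.name)
```

If the count is zero, re-run the dataset cell before launching training.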


In [None]:
# @title **Train:**
MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
INSTANCE_DIR="dog"
OUTPUT_DIR="scedit-trained-xl"

! accelerate launch scedit-pytorch/train_dreambooth_scedit_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sbu dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=8 \
  --learning_rate=5e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --checkpointing_steps=100 \
  --validation_prompt="A photo of sbu dog in a bucket" \
  --validation_epochs=200 \
  --use_8bit_adam \
  --enable_xformers_memory_efficient_attention \
  --report_to="wandb" \
  --seed="0" \
  --push_to_hub
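# Optional: add `--gradient_checkpointing \` to the command above (before the final flag)
# to reduce GPU memory use at the cost of slower training.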
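
Several of the flags above interact: `--train_batch_size` and `--gradient_accumulation_steps` set the nominal number of images per optimizer update, while `--max_train_steps` and `--checkpointing_steps` set how many checkpoints are written. The arithmetic can be sketched as below. It assumes a single GPU process and that the script follows the usual diffusers DreamBooth conventions (one unit of `max_train_steps` is one optimizer update, and a checkpoint is saved every `checkpointing_steps` updates), which this notebook does not itself confirm.

```python
# Values copied from the training command above.
train_batch_size = 1
gradient_accumulation_steps = 8
max_train_steps = 1000
checkpointing_steps = 100

# Assumption: a single-GPU runtime, so one training process.
num_processes = 1

# Nominal number of images contributing to each optimizer update.
# With a very small dataset, the actual number per update may be lower,
# depending on how the script handles the end of each epoch.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes

# Assumption: one checkpoint is written every `checkpointing_steps` optimizer updates.
num_checkpoints = max_train_steps // checkpointing_steps

print("Nominal effective batch size:", effective_batch_size)
print("Checkpoints written during training:", num_checkpoints)
```

Raising `--gradient_accumulation_steps` increases the nominal batch size without increasing peak memory, at the cost of more forward and backward passes per update.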


### **Inference**

In [None]:
# @markdown **Load pipeline**
import sys
sys.path.append("/content/scedit-pytorch")
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
import torch
from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet


base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" # @param {type: "string"}
scedit_model_id = "/content/scedit-trained-xl" # @param {type: "string"}
scale = 1.0 # @param {type: "number"}
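# Optionally, read the base model ID from the trained model's repo card instead:
# card = RepoCard.load(scedit_model_id)
# base_model_id = card.data.to_dict()["base_model"]

# Load the UNet, enable SC-Tuner, and load the trained SCEdit weights into it
unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
unet.set_sctuner(scale=scale)
unet = load_scedit_into_unet(scedit_model_id, unet)
# Build the pipeline around the modified UNet and move it to the GPU in half precision
pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet)
pipe = pipe.to(device="cuda", dtype=torch.float16)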

In [None]:
image = pipe("A picture of a sbu dog in a bucket", num_inference_steps=25).images[0]
image.save("sbu_dog.png")
display(image)