Skip to content

Commit

Permalink
Basic inpainting training to LoRa PTI
Browse files Browse the repository at this point in the history
  • Loading branch information
levi committed Jan 29, 2023
1 parent 1707928 commit 5eb9880
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 5 deletions.
85 changes: 80 additions & 5 deletions lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,60 @@ def collate_fn(examples):

return train_dataloader

def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
    """Build a shuffling DataLoader whose batches carry inpainting inputs.

    Every batch is a dict with ``input_ids``, ``pixel_values``,
    ``mask_values`` and ``masked_image_values``. A ``mask`` entry is added
    when the first example of the batch provides one.

    ``vae`` and ``text_encoder`` are accepted but not used by this function.
    """

    def _stack_as_float(tensors):
        # New leading batch dimension, contiguous layout, float dtype.
        return torch.stack(tensors).to(memory_format=torch.contiguous_format).float()

    def collate_fn(examples):
        key_groups = [
            (
                "instance_prompt_ids",
                "instance_images",
                "instance_masks",
                "instance_masked_images",
            ),
        ]
        # Class (prior-preservation) samples, when present, are appended after
        # the instance samples so both go through a single forward pass.
        if examples[0].get("class_prompt_ids", None) is not None:
            key_groups.append(
                (
                    "class_prompt_ids",
                    "class_images",
                    "class_masks",
                    "class_masked_images",
                )
            )

        token_lists, images, masks, masked_images = [], [], [], []
        columns = (token_lists, images, masks, masked_images)
        for keys in key_groups:
            for column, key in zip(columns, keys):
                column.extend(sample[key] for sample in examples)

        padded = tokenizer.pad(
            {"input_ids": token_lists},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )

        batch = {
            "input_ids": padded.input_ids,
            "pixel_values": _stack_as_float(images),
            "mask_values": _stack_as_float(masks),
            "masked_image_values": _stack_as_float(masked_images),
        }

        # Optional per-example "mask" entry is stacked as-is.
        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([sample["mask"] for sample in examples])

        return batch

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

def loss_step(
batch,
unet,
vae,
text_encoder,
scheduler,
train_inpainting=False,
t_mutliplier=1.0,
mixed_precision=False,
mask_temperature=1.0,
Expand All @@ -186,6 +233,16 @@ def loss_step(
).latent_dist.sample()
latents = latents * 0.18215

if train_inpainting:
masked_image_latents = vae.encode(
batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
).latent_dist.sample()
masked_image_latents = masked_image_latents * 0.18215
mask = F.interpolate(
batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
scale_factor=1/8
)

noise = torch.randn_like(latents)
bsz = latents.shape[0]

Expand All @@ -199,21 +256,26 @@ def loss_step(

noisy_latents = scheduler.add_noise(latents, noise, timesteps)

if train_inpainting:
latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
else:
latent_model_input = noisy_latents

if mixed_precision:
with torch.cuda.amp.autocast():

encoder_hidden_states = text_encoder(
batch["input_ids"].to(text_encoder.device)
)[0]

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
else:

encoder_hidden_states = text_encoder(
batch["input_ids"].to(text_encoder.device)
)[0]

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

if scheduler.config.prediction_type == "epsilon":
target = noise
Expand Down Expand Up @@ -270,6 +332,7 @@ def train_inversion(
log_wandb: bool = False,
wandb_log_prompt_cnt: int = 10,
class_token: str = "person",
train_inpainting: bool = False,
mixed_precision: bool = False,
clip_ti_decay: bool = True,
):
Expand Down Expand Up @@ -302,6 +365,7 @@ def train_inversion(
vae,
text_encoder,
scheduler,
train_inpainting=train_inpainting,
mixed_precision=mixed_precision,
)
/ accum_iter
Expand Down Expand Up @@ -423,6 +487,7 @@ def perform_tuning(
lora_unet_target_modules,
lora_clip_target_modules,
mask_temperature,
train_inpainting,
):

progress_bar = tqdm(range(num_steps))
Expand All @@ -446,6 +511,7 @@ def perform_tuning(
vae,
text_encoder,
scheduler,
train_inpainting=train_inpainting,
t_mutliplier=0.8,
mixed_precision=True,
mask_temperature=mask_temperature,
Expand Down Expand Up @@ -508,6 +574,7 @@ def train(
stochastic_attribute: Optional[str] = None,
perform_inversion: bool = True,
use_template: Literal[None, "object", "style"] = None,
train_inpainting: bool = False,
placeholder_tokens: str = "",
placeholder_token_at_data: Optional[str] = None,
initializer_tokens: Optional[str] = None,
Expand Down Expand Up @@ -650,13 +717,19 @@ def train(
color_jitter=color_jitter,
use_face_segmentation_condition=use_face_segmentation_condition,
use_mask_captioned_data=use_mask_captioned_data,
train_inpainting=train_inpainting,
)

train_dataset.blur_amount = 200

train_dataloader = text2img_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
)
if train_inpainting:
train_dataloader = inpainting_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
)
else:
train_dataloader = text2img_dataloader(
train_dataset, train_batch_size, tokenizer, vae, text_encoder
)

index_no_updates = torch.arange(len(tokenizer)) != -1

Expand Down Expand Up @@ -710,6 +783,7 @@ def train(
log_wandb=log_wandb,
wandb_log_prompt_cnt=wandb_log_prompt_cnt,
class_token=class_token,
train_inpainting=train_inpainting,
mixed_precision=False,
tokenizer=tokenizer,
clip_ti_decay=clip_ti_decay,
Expand Down Expand Up @@ -807,6 +881,7 @@ def train(
lora_unet_target_modules=lora_unet_target_modules,
lora_clip_target_modules=lora_clip_target_modules,
mask_temperature=mask_temperature,
train_inpainting=train_inpainting,
)


Expand Down
31 changes: 31 additions & 0 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cv2
import numpy as np
from PIL import Image, ImageFilter
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms
import glob
Expand Down Expand Up @@ -85,6 +86,29 @@ def _shuffle(lis):
return random.sample(lis, len(lis))


def _get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128):
    """Sample random rectangular holes inside a ``height`` x ``width`` area.

    Args:
        height: Vertical extent of the area the holes must fit in.
        width: Horizontal extent of the area the holes must fit in.
        min_holes, max_holes: Inclusive bounds on the number of holes.
        min_height, max_height: Inclusive bounds on each hole's height.
        min_width, max_width: Inclusive bounds on each hole's width.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples, where ``(x1, y1)`` is the
        top-left corner and ``(x2, y2)`` the exclusive bottom-right corner.
    """
    holes = []
    for _ in range(random.randint(min_holes, max_holes)):
        # Clamp the sampled size to the area: without this, an area smaller
        # than the sampled hole makes the offset range below empty and
        # random.randint raises ValueError. The clamp does not change the
        # number or order of random draws.
        hole_height = min(random.randint(min_height, max_height), height)
        hole_width = min(random.randint(min_width, max_width), width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        holes.append((x1, y1, x1 + hole_width, y1 + hole_height))
    return holes


def _generate_random_mask(image):
    """Create a random inpainting mask for ``image`` and the masked image.

    The mask has the shape of ``image[:1]`` (the first slice along dim 0,
    presumably the channel axis) and is 1 inside the sampled cut-out
    rectangles and 0 elsewhere. In roughly a quarter of calls the whole mask
    is set to 1.

    Returns:
        ``(mask, masked_image)`` where ``masked_image`` is ``image`` with
        every masked position multiplied to zero.
    """
    mask = zeros_like(image[0:1])
    rows, cols = mask.shape[1], mask.shape[2]
    for left, top, right, bottom in _get_cutout_holes(rows, cols):
        mask[:, top:bottom, left:right] = 1.0

    # The holes are sampled before this draw even when the mask ends up fully
    # covered, so the order of random draws stays fixed.
    cover_everything = random.uniform(0, 1) < 0.25
    if cover_everything:
        mask.fill_(1.0)

    keep = mask < 0.5
    return mask, image * keep

class PivotalTuningDatasetCapation(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
Expand All @@ -106,11 +130,13 @@ def __init__(
resize=True,
use_mask_captioned_data=False,
use_face_segmentation_condition=False,
train_inpainting=False,
blur_amount: int = 70,
):
self.size = size
self.tokenizer = tokenizer
self.resize = resize
self.train_inpainting = train_inpainting

instance_data_root = Path(instance_data_root)
if not instance_data_root.exists():
Expand Down Expand Up @@ -239,6 +265,9 @@ def __getitem__(self, index):
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)

if self.train_inpainting:
example["instance_masks"], example["instance_masked_images"] = _generate_random_mask(example["instance_images"])

if self.use_template:
assert self.token_map is not None
input_tok = list(self.token_map.values())[0]
Expand Down Expand Up @@ -283,6 +312,8 @@ def __getitem__(self, index):
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
if self.train_inpainting:
example["class_masks"], example["class_masked_images"] = _generate_random_mask(example["class_images"])
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
Expand Down
34 changes: 34 additions & 0 deletions training_scripts/inpainting_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Example: LoRA pivotal-tuning run against the Stable Diffusion inpainting
# checkpoint. Adjust the three paths below for your own data.
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="./data/data_disney"
export OUTPUT_DIR="./exps/output_dsn"

# Optional flag (add it to the command below to enable):
#   --use_face_segmentation_condition
#
# Every continuation backslash is preceded by a space so adjacent arguments
# cannot fuse, and the last argument has no trailing backslash.
lora_pti \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --instance_data_dir="$INSTANCE_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --train_text_encoder \
  --train_inpainting \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --scale_lr \
  --learning_rate_unet=1e-4 \
  --learning_rate_text=1e-5 \
  --learning_rate_ti=5e-4 \
  --color_jitter \
  --lr_scheduler_lora="linear" \
  --lr_warmup_steps_lora=100 \
  --placeholder_tokens="<s1>|<s2>" \
  --use_template="style" \
  --save_steps=100 \
  --max_train_steps_ti=1000 \
  --max_train_steps_tuning=1000 \
  --perform_inversion=True \
  --clip_ti_decay \
  --weight_decay_ti=0.000 \
  --weight_decay_lora=0.001 \
  --continue_inversion \
  --continue_inversion_lr=1e-4 \
  --device="cuda:0" \
  --lora_rank=1

0 comments on commit 5eb9880

Please sign in to comment.