Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic dataset pipelines #139

Merged
merged 10 commits into from
Jan 16, 2023
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ __test*
merged_lora*
wandb
exps*
.vscode
.vscode
build
lora_diffusion.egg-info
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .lora import *
from .dataset import *
from .utils import *
from .preprocess_files import *
4 changes: 3 additions & 1 deletion lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def loss_step(
+ 0.05
)

mask = mask / mask.mean()
mask = mask / mask.max()

model_pred = model_pred * mask

Expand Down Expand Up @@ -535,6 +535,7 @@ def train(
continue_inversion: bool = False,
continue_inversion_lr: Optional[float] = None,
use_face_segmentation_condition: bool = False,
use_mask_captioned_data: bool = False,
scale_lr: bool = False,
lr_scheduler: str = "linear",
lr_warmup_steps: int = 0,
Expand Down Expand Up @@ -629,6 +630,7 @@ def train(
size=resolution,
color_jitter=color_jitter,
use_face_segmentation_condition=use_face_segmentation_condition,
use_mask_captioned_data=use_mask_captioned_data,
)

train_dataset.blur_amount = 200
Expand Down
144 changes: 87 additions & 57 deletions lora_diffusion/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from PIL import Image, ImageFilter
from torch.utils.data import Dataset
from torchvision import transforms
import glob

from .preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
"a photo of a {}",
Expand Down Expand Up @@ -93,18 +96,88 @@ def __init__(
h_flip=True,
color_jitter=False,
resize=True,
use_mask_captioned_data=False,
use_face_segmentation_condition=False,
blur_amount: int = 70,
):
self.size = size
self.tokenizer = tokenizer
self.resize = resize

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
instance_data_root = Path(instance_data_root)
if not instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.instance_images_path = []
self.mask_path = []

assert not (
use_mask_captioned_data and use_template
), "Can't use both mask caption data and template."

# Prepare the instance images
if use_mask_captioned_data:
src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
for f in src_imgs:
idx = int(str(Path(f).stem).split(".")[0])
mask_path = f"{instance_data_root}/{idx}.mask.png"

if Path(mask_path).exists():
self.instance_images_path.append(f)
self.mask_path.append(mask_path)
else:
print(f"Mask not found for {f}")

self.captions = open(f"{instance_data_root}/caption.txt").readlines()

else:
possibily_src_images = glob.glob(
str(instance_data_root) + "/*.jpg"
) + glob.glob(str(instance_data_root) + "/*.png")
possibily_src_images = (
set(possibily_src_images)
- set(glob.glob(str(instance_data_root) + "/*mask.png"))
- set([str(instance_data_root) + "/caption.txt"])
)

self.instance_images_path = list(set(possibily_src_images))

assert (
len(self.instance_images_path) > 0
), "No images found in the instance data root."

self.instance_images_path = sorted(self.instance_images_path)

self.use_mask = use_face_segmentation_condition or use_mask_captioned_data

if use_face_segmentation_condition:

for idx in range(len(self.instance_images_path)):
targ = f"{instance_data_root}/{idx}.mask.png"
# see if the mask exists
if not Path(targ).exists():
print(f"Mask not found for {targ}")

print(
"Warning : this will pre-process all the images in the instance data root."
)

if len(self.mask_path) > 0:
print(
"Warning : masks already exists, but will be overwritten."
)

masks = face_mask_google_mediapipe(
[Image.open(f) for f in self.instance_images_path]
)
for idx, mask in enumerate(masks):
mask.save(f"{instance_data_root}/{idx}.mask.png")

break

for idx in range(len(self.instance_images_path)):
self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")

self.num_instance_images = len(self.instance_images_path)
self.token_map = token_map

Expand All @@ -122,6 +195,7 @@ def __init__(
self.class_prompt = class_prompt
else:
self.class_data_root = None

self.h_flip = h_flip
self.image_transforms = transforms.Compose(
[
Expand All @@ -138,14 +212,6 @@ def __init__(
]
)

self.use_face_segmentation_condition = use_face_segmentation_condition
if self.use_face_segmentation_condition:
import mediapipe as mp

mp_face_detection = mp.solutions.face_detection
self.face_detection = mp_face_detection.FaceDetection(
model_selection=1, min_detection_confidence=0.5
)
self.blur_amount = blur_amount

def __len__(self):
Expand All @@ -166,64 +232,28 @@ def __getitem__(self, index):

text = random.choice(self.templates).format(input_tok)
else:
text = self.instance_images_path[index % self.num_instance_images].stem
text = self.captions[index % self.num_instance_images].strip()

if self.token_map is not None:
for token, value in self.token_map.items():
text = text.replace(token, value)

print(text)

if self.use_face_segmentation_condition:
image = cv2.imread(
str(self.instance_images_path[index % self.num_instance_images])
)
results = self.face_detection.process(
cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
)
black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

if results.detections:

for detection in results.detections:

x_min = int(
detection.location_data.relative_bounding_box.xmin
* image.shape[1]
)
y_min = int(
detection.location_data.relative_bounding_box.ymin
* image.shape[0]
)
width = int(
detection.location_data.relative_bounding_box.width
* image.shape[1]
)
height = int(
detection.location_data.relative_bounding_box.height
* image.shape[0]
)

# draw the colored rectangle
black_image[y_min : y_min + height, x_min : x_min + width] = 255

# blur the image
black_image = Image.fromarray(black_image, mode="L").filter(
ImageFilter.GaussianBlur(radius=self.blur_amount)
if self.use_mask:
example["mask"] = (
self.image_transforms(
Image.open(self.mask_path[index % self.num_instance_images])
)
* 0.5
+ 1
)
# to tensor
black_image = transforms.ToTensor()(black_image)
# resize as the instance image
black_image = transforms.Resize(
self.size, interpolation=transforms.InterpolationMode.BILINEAR
)(black_image)

example["mask"] = black_image

if self.h_flip and random.random() > 0.5:
hflip = transforms.RandomHorizontalFlip(p=1)

example["instance_images"] = hflip(example["instance_images"])
if self.use_face_segmentation_condition:
if self.use_mask:
example["mask"] = hflip(example["mask"])

example["instance_prompt_ids"] = self.tokenizer(
Expand Down