In [1]:
import dlib
import cv2
import numpy as np
from PIL import Image
from diff_exp.data.attribute_celeba_dataset import default_args, Dataset
from omegaconf import OmegaConf
from diff_exp.transforms_utils import get_transform
import yaml
from tqdm import tqdm
from diff_exp.utils import TransformDataset, tensor2pil
from torchvision import transforms as tr

In [3]:
# 68-point facial landmark model. The .dat file is not shipped with dlib and
# must be downloaded separately and placed next to this notebook.
predictor = dlib.shape_predictor(
    "shape_predictor_68_face_landmarks.dat"
)
# dlib's built-in frontal face detector; it yields the face rectangles that
# the landmark predictor is run on.
detector = dlib.get_frontal_face_detector()
    

In [None]:
def colorize_lips(img, predictor, lips_colors=(0, 255, 0), alpha=0.5, face_detector=None):
    """Return a copy of ``img`` with the lips of the first detected face tinted.

    Args:
        img: Image array in RGB channel order (it is converted to grayscale
            with ``cv2.COLOR_RGB2GRAY``). Presumably a uint8 array of shape
            (H, W, 3) -- TODO confirm against the dataset.
        predictor: Landmark predictor called as ``predictor(gray, face)``; the
            result must expose ``parts()`` with at least 68 points.
        lips_colors: Colour used to fill the lip polygon.
        alpha: Blend weight of the painted copy; the untouched image gets
            weight ``1 - alpha``.
        face_detector: Callable returning the detected faces for a grayscale
            image. Defaults to the notebook-level ``detector`` so existing
            calls keep working.

    Returns:
        A new array holding the blended image; ``img`` itself is not modified.

    Raises:
        IndexError: If no face is detected in ``img``.
    """
    if face_detector is None:
        # Fall back to the detector created in the setup cell.
        face_detector = detector

    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    faces = face_detector(gray)
    if len(faces) == 0:
        # Same exception type the bare ``faces[0]`` lookup used to raise, but
        # with a message that explains what went wrong.
        raise IndexError("no face detected in image")
    # Only the first detection is used, even if several faces were found.
    shape = predictor(gray, faces[0])

    # Points 48..67 are the mouth landmarks of the 68-point scheme.
    lips_landmarks = shape.parts()[48:68]
    lips_points = np.array([(p.x, p.y) for p in lips_landmarks], dtype=np.int32)

    # Paint the lips on a copy, then blend that copy with the original so the
    # colour is applied with opacity ``alpha``. Pixels outside the polygon are
    # the same in both inputs of the blend.
    img_copy = img.copy()
    cv2.fillPoly(img_copy, [lips_points], color=lips_colors)
    cv2.addWeighted(img_copy, alpha, img, 1 - alpha, 0, img_copy)
    return img_copy


In [4]:
# Wrap the dataset's default arguments in an OmegaConf config so individual
# fields can be overridden by attribute access.
args = OmegaConf.create(default_args())
args.data_dir = "../data"

# Show the effective configuration before building the dataset from it.
print(OmegaConf.to_yaml(args))

dataset = Dataset(**args)

target_attr: Smiling
data_dir: ../data
split: train
filter_path: null



In [5]:
# Transform pipeline written as a YAML list of [name, kwargs] pairs:
# a 178-pixel centre crop followed by a resize to 64.
transform_str = """
- - center_crop
  - size: 178
- - resize
  - size: 64
"""
# safe_load is enough here: the document holds only plain lists and mappings,
# and unlike yaml.Loader it cannot construct arbitrary Python objects.
transform_cfg = yaml.safe_load(transform_str)
transform_cfg = OmegaConf.create(transform_cfg)
transform = get_transform(transform_cfg)

In [None]:
# Build a paired dataset: every image whose lips could be recoloured
# contributes two entries, the transformed original (label 0) and the
# transformed green-lips version (label 1).
images_out = []
labels_out = []

n_failed = 0  # images skipped because recolouring raised
tot = 0       # images visited so far

# The context manager closes the progress bar even if the loop is interrupted.
with tqdm(total=len(dataset)) as pbar:
    for x, y in dataset:  # the dataset label ``y`` is not used here
        pbar.update()
        tot += 1

        x_np = np.array(x)
        try:
            out = colorize_lips(x_np, predictor, lips_colors=(0, 255, 0), alpha=0.5)
        except Exception:
            # Best effort: colorize_lips fails when no face is detected, and
            # such images are skipped rather than aborting the whole run.
            n_failed += 1
            out = None

        # Refresh the counter on every item so failures show up immediately.
        pbar.set_description(f"Num failed: {n_failed}/{tot}")
        if out is None:
            continue

        # Convert back to a PIL image so both versions go through the same
        # transform pipeline.
        out = Image.fromarray(out)
        out = transform(out)
        x = transform(x)

        images_out.append(x)
        images_out.append(out)

        labels_out.append(0)
        labels_out.append(1)

        

Num failed: 1686/59621:  37%|████████████████████████████▏                                                | 59621/162770 [03:49<06:37, 259.81it/s]

In [16]:
# Number of dataset items that are not represented in images_out. Each kept
# item contributes two entries (the original and the recoloured copy), hence
# the division by two.
len(dataset) - len(images_out) // 2

4492

In [12]:
# Turn every collected image into an array, then stack them along a new
# leading axis so images and labels line up index by index.
np_images = list(map(np.array, images_out))
all_images = np.stack(np_images)
all_labels = np.array(labels_out)

In [17]:
# The arrays are passed positionally, so np.savez stores them under the
# default keys "arr_0" (images) and "arr_1" (labels); readers of the file
# must use those keys.
np.savez("green_lips_train.npz", all_images, all_labels)