In [1]:
%matplotlib inline

import os
import random

import torch
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm

from ml_glasses.commons import FaceAlignTransform


class ConvertPilJpegFileToImage:
    """Callable transform that rebuilds an image from its loaded core data.

    Calling an instance forces the image's pixel data to be loaded and then
    returns whatever ``image._new`` produces for that data -- presumably a
    plain PIL ``Image`` rather than the lazily-loaded JPEG file object
    (TODO confirm against the Pillow version in use; ``_new`` is private API).
    """

    def __call__(self, image):
        # Loading must happen first: `image.im` is read only after `load()`.
        image.load()
        core = image.im
        return image._new(core)

Let's build a CelebA-based eyeglasses dataset: take every image with eyeglasses plus an equal number of randomly chosen images without them, and align the faces in all of them.

In [2]:
import dlib
# Display dlib's DLIB_USE_CUDA flag -- presumably whether this dlib build uses CUDA
# (TODO confirm); the recorded output of this cell was True.
dlib.DLIB_USE_CUDA

True

In [None]:
# Build a balanced eyeglasses / no-eyeglasses subset of CelebA for each split:
# every image flagged as "eyeglasses" is kept, plus an equal number of randomly
# chosen images without that flag. Each kept image is passed through the face
# aligner, saved as PNG, and listed in a per-split meta file as "<file> <label>".
torch.manual_seed(5)
random.seed(5)

# Index into the CelebA attribute vector that this notebook treats as the
# "eyeglasses" flag.
EYEGLASSES_ATTR_IDX = 15

for split in ('train', 'valid'):
    celeba = torchvision.datasets.CelebA('.', split=split, transform=ConvertPilJpegFileToImage())
    aligner = FaceAlignTransform()

    # First pass: count the images with eyeglasses. The counter must start at 0
    # so that the no-eyeglasses quota below really equals that count.
    # NOTE(review): this pass loads every image only to read its attributes;
    # the dataset presumably exposes the attribute table directly -- confirm
    # before replacing the loop.
    total_eyeglasses = 0
    for _, attrs in tqdm(celeba):
        if attrs[EYEGLASSES_ATTR_IDX] == 1:
            total_eyeglasses += 1

    no_eyeglass_left = total_eyeglasses  # matches the no. of images with eyeglasses

    idxs = list(range(len(celeba)))
    random.shuffle(idxs)  # to pick random no-eyeglasses samples

    # Image.save does not create missing directories, so make sure the target
    # folder exists before the first write.
    os.makedirs(f'celeba_eyeglasses_{split}', exist_ok=True)

    errors = 0

    # `with` guarantees the meta file is flushed and closed even if an
    # unexpected exception interrupts the loop.
    with open(f'meta_celeba_{split}.txt', 'w') as meta_file:
        for idx in tqdm(idxs):
            image, attrs = celeba[idx]

            has_eyeglasses = bool(attrs[EYEGLASSES_ATTR_IDX] == 1)

            # Skip no-eyeglasses samples once their quota is used up.
            if not has_eyeglasses and no_eyeglass_left <= 0:
                continue

            label = 1 if has_eyeglasses else 0

            try:
                aligned = aligner(image)

                filename = f'celeba_eyeglasses_{split}/{idx}_{label}.png'
                Image.fromarray(aligned).save(filename)
                meta_file.write(f'{idx}_{label}.png {label}\n')
            except IndexError:
                # Presumably raised by the aligner when it cannot process the
                # image (TODO confirm); the sample is skipped and does not
                # consume the no-eyeglasses quota.
                errors += 1
                continue

            if label == 0:
                no_eyeglass_left -= 1

    # Surface the skipped-sample count instead of silently discarding it.
    print(f'{split}: {total_eyeglasses} eyeglasses images found, {errors} samples skipped due to IndexError')

 52%|█████▏    | 84470/162770 [01:10<01:04, 1222.28it/s]

In [65]:
# NOTE(review): `images` is not defined in any cell visible here, and the execution
# count (65) is out of sequence -- this cell appears to rely on leftover kernel state.
# Confirm where `images` comes from, or remove this scratch cell.
img = Image.fromarray(images[0])

In [66]:
# Scratch output: writes `img` to 'kek.png' (a path relative to the working directory).
img.save('kek.png')