In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from tqdm.notebook import tqdm, trange
from PIL import Image, ImageOps
import multiprocessing
import matplotlib.pyplot as plt
import torchsummary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# import resnet18 model from pytorch
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
import mxnet as mx
from mxnet import recordio
import torch.multiprocessing as mp
from sklearn.model_selection import train_test_split
from collections import defaultdict
import logging
import cv2
from facenet_pytorch import MTCNN

In [None]:
# Run MTCNN on the second GPU when one exists (the original setup); otherwise
# fall back to the first GPU, or to the CPU, so this cell still runs on
# machines with fewer devices instead of failing on the 'cuda:1' lookup.
if torch.cuda.device_count() > 1:
    mtcnn_device = 'cuda:1'
elif torch.cuda.is_available():
    mtcnn_device = 'cuda:0'
else:
    mtcnn_device = 'cpu'

mtcnn = MTCNN(keep_all=False, device=mtcnn_device, image_size=112, margin=0)

def detect_and_crop_face(image, mtcnn, target_size=(224, 224), padding_color=(0, 0, 0)):
    """Detect a face, crop it, pad the crop to a square and resize it.

    Args:
        image: Input picture as a numpy array, a torch tensor or a PIL image.
            Arrays are converted with ``Image.fromarray`` and tensors with
            ``transforms.ToPILImage``.
        mtcnn: Detector exposing ``detect(image)`` that returns a
            ``(boxes, probs)`` pair. ``boxes`` is ``None`` when nothing is
            found; otherwise its first row is used as the crop rectangle
            ``(left, top, right, bottom)``.
        target_size: ``(width, height)`` of the returned image.
        padding_color: RGB fill for the bars added when a non-square crop is
            padded to a square before resizing.

    Returns:
        An RGB PIL image of size ``target_size``. When no face is detected the
        whole input image is used in place of a crop.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
    """
    # Normalise every supported input type to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be a numpy array, torch tensor, or PIL Image.")

    if image.mode != 'RGB':
        image = image.convert('RGB')

    boxes, _ = mtcnn.detect(image)

    if boxes is not None:
        left, top, right, bottom = (float(coord) for coord in boxes[0][:4])
        face = image.crop((left, top, right, bottom))
    else:
        print("No face detected. Using original image.")
        face = image

    # Pad the shorter side so the picture is square before resizing; without
    # this step a non-square box would be stretched to fit target_size.
    width, height = face.size
    if width != height:
        side = max(width, height)
        square = Image.new('RGB', (side, side), tuple(padding_color))
        square.paste(face, ((side - width) // 2, (side - height) // 2))
        face = square

    return face.resize(target_size, Image.LANCZOS)
    

In [None]:
class RecCASIAWebFaceDataset(Dataset):
    """Map-style dataset reading images and labels from an MXNet RecordIO pair.

    The ``.rec`` file is opened together with the ``.idx`` file that shares its
    base name. Record 0 is read as a header: when its ``flag`` is positive, its
    first two label values are taken as the ``[start, stop)`` range of
    per-identity records, and each of those records in turn supplies the
    ``[a, b)`` range of image records for that identity. Otherwise every key of
    the index file is treated as an image record.

    Attributes:
        transform: Optional callable applied to each decoded PIL image.
        imgrec: The open ``MXIndexedRecordIO`` reader.
        imgidx, seq: The same list of record indices that hold images;
            position ``i`` in it is dataset item ``i``.
        header0, seq_identity, id2range: Identity metadata, filled in only
            when the header flag is positive (empty defaults otherwise).
    """

    def __init__(self, path_imgrec, transform=None):
        """Open the RecordIO files and collect the image record indices.

        Args:
            path_imgrec: Path to the ``.rec`` file. The index file is looked up
                next to it, with the extension replaced by ``.idx``.
            transform: Optional callable applied to the decoded PIL image.
        """
        self.transform = transform
        assert path_imgrec, "path_imgrec must be a non-empty path"

        logging.info('loading recordio %s...', path_imgrec)
        path_imgidx = os.path.splitext(path_imgrec)[0] + ".idx"
        print(path_imgrec, path_imgidx)
        self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')

        # Defaults so these attributes exist whichever branch runs below.
        self.header0 = None
        self.id2range = {}
        self.seq_identity = range(0)
        self.imgidx = []

        header, _ = recordio.unpack(self.imgrec.read_idx(0))
        if header.flag > 0:
            print('header0 label', header.label)
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.seq_identity = range(self.header0[0], self.header0[1])
            for identity in self.seq_identity:
                id_header, _ = recordio.unpack(self.imgrec.read_idx(identity))
                a, b = int(id_header.label[0]), int(id_header.label[1])
                self.id2range[identity] = (a, b)
                self.imgidx.extend(range(a, b))
            print('id2range', len(self.id2range))
        else:
            self.imgidx = list(self.imgrec.keys)
        self.seq = self.imgidx

    def __getitem__(self, idx):
        """Return ``(image, label)`` for dataset position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``range(-len(self), len(self))``.
        """
        # Look the record index up in self.seq instead of assuming idx + 1.
        # The list lookup raises IndexError past the end, which is what stops
        # plain `for sample in dataset` iteration at the last image.
        record_idx = self.seq[idx]

        header, payload = recordio.unpack(self.imgrec.read_idx(record_idx))
        pixels = mx.image.imdecode(payload).asnumpy()
        # header.label may be a scalar or an array; use its first value.
        label = int(np.asarray(header.label).ravel()[0])

        img = Image.fromarray(pixels)
        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Return the number of image records."""
        return len(self.seq)

In [None]:
# CASIA-WebFace records; each decoded image is turned into a tensor by ToTensor.
casia_dataset = RecCASIAWebFaceDataset(
    path_imgrec='./faces_webface_112x112/train.rec',
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [None]:
def make_lfw_dataset(path_save):
    """Load LFW from the FiftyOne zoo and save one MTCNN face crop per image.

    Each sample's image is passed through ``detect_and_crop_face`` (using the
    notebook-level ``mtcnn`` detector) and written to
    ``<path_save>/<label>/<original file name>``. Afterwards any file in those
    label directories whose name ends in ``_1.jpg`` to ``_4.jpg`` is deleted.

    Args:
        path_save: Root directory for the cropped images; created if missing.
    """
    os.makedirs(path_save, exist_ok=True)

    # Collect source paths and labels from the zoo dataset.
    lfw_dataset = foz.load_zoo_dataset("lfw")
    images = []
    labels = []
    for sample in lfw_dataset:
        images.append(sample.filepath)
        labels.append(sample['ground_truth']['label'])
    print("images", len(images))
    print("labels", len(labels))
    print("save_dir", path_save)
    if images:
        print(images[0])
        print(labels[0])

    # Destination keeps the original file name inside a per-label folder.
    save_paths = []
    for src_path, label in zip(images, labels):
        label_dir = os.path.join(path_save, str(label))
        os.makedirs(label_dir, exist_ok=True)
        save_paths.append(os.path.join(label_dir, os.path.basename(src_path)))
    print("save_paths", len(save_paths))
    if save_paths:
        print("save_paths", save_paths[0])

    # Open one image at a time and close it again; opening every file up front
    # keeps thousands of file handles alive at once.
    for src_path, save_path in tqdm(zip(images, save_paths), total=len(images)):
        with Image.open(src_path) as img:
            cropped_face = detect_and_crop_face(img, mtcnn)
            if isinstance(cropped_face, torch.Tensor):
                cropped_face = transforms.ToPILImage()(cropped_face)
            cropped_face.save(save_path)

    # delete images of form number_nr.jpg
    # (each label directory is scanned once, not once per image)
    unwanted_suffixes = ('_1.jpg', '_2.jpg', '_3.jpg', '_4.jpg')
    for label in dict.fromkeys(labels):
        label_dir = os.path.join(path_save, str(label))
        if not os.path.exists(label_dir):
            continue
        for img_file in os.listdir(label_dir):
            if img_file.endswith(unwanted_suffixes):
                os.remove(os.path.join(label_dir, img_file))

In [None]:
# Go through /home/ichitu/py-files/lfw_funneled_cropped and resize every image to 224x224, saving the results in a new directory
def resize_images_in_directory(input_dir, output_dir, target_size=(224, 224)):
    """Resize every ``.jpg``/``.png`` found under ``input_dir`` into ``output_dir``.

    The tree below ``input_dir`` is walked recursively. Each image is written
    to ``<output_dir>/<name of its immediate parent folder>/<file name>``, so
    only the last folder level is kept in the output layout.

    Args:
        input_dir: Directory tree to read images from.
        output_dir: Root directory for the resized copies; created if missing.
        target_size: ``(width, height)`` passed to ``Image.resize``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for root, _, files in os.walk(input_dir):
        # The parent folder name is reused as the sub-folder in the output.
        label_dir_path = os.path.join(output_dir, os.path.basename(root))
        for file in files:
            if not file.endswith(('.jpg', '.png')):
                continue
            os.makedirs(label_dir_path, exist_ok=True)
            # Close each source file as soon as it has been resized so a large
            # tree does not accumulate open file handles.
            with Image.open(os.path.join(root, file)) as img:
                resized = img.resize(target_size, Image.Resampling.LANCZOS)
            # Same file name, new directory.
            resized.save(os.path.join(label_dir_path, file))
# resize_images_in_directory('./lfw_funneled_cropped', './lfw_funneled_cropped_224x224')

In [None]:
def make_casia_dataset(casia_dataset, path_save):
    """Save an MTCNN face crop of every dataset sample, grouped by label.

    Each sample is read as ``(image, label)`` from its first two elements. The
    image goes through ``detect_and_crop_face`` (using the notebook-level
    ``mtcnn`` detector) and the result is saved as
    ``<path_save>/<label>/<running index>.jpg``.

    Sized, indexable datasets are read by explicit index over
    ``range(len(casia_dataset))``, so the loop stops at the last sample even if
    the dataset's ``__getitem__`` does not raise ``IndexError`` past the end.
    Any other iterable is consumed as-is.

    Args:
        casia_dataset: Dataset (or iterable) yielding ``(image, label, ...)``.
        path_save: Root output directory; created if missing.
    """
    os.makedirs(path_save, exist_ok=True)

    if hasattr(casia_dataset, '__len__') and hasattr(casia_dataset, '__getitem__'):
        num_samples = len(casia_dataset)
        samples = (casia_dataset[idx] for idx in range(num_samples))
    else:
        num_samples = None
        samples = casia_dataset

    for i, sample in enumerate(tqdm(samples, total=num_samples)):
        img = sample[0]
        label = sample[1]

        face = detect_and_crop_face(img, mtcnn)
        label_dir = os.path.join(path_save, str(label))
        os.makedirs(label_dir, exist_ok=True)
        face.save(os.path.join(label_dir, str(i) + '.jpg'))
        
        

In [None]:
# Write an MTCNN face crop of every CASIA-WebFace sample under this folder.
make_casia_dataset(
    casia_dataset,
    './casia_webface_224x224_cropped_mtcnn',
)