<a href="https://colab.research.google.com/github/himanshu27tasveer/face_rec/blob/main/dataset_cleaning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only magic: pin the runtime to TensorFlow 1.x, because the alignment code
# below uses TF1-only APIs (tf.ConfigProto, tf.Session).
# NOTE(review): newer Colab runtimes may no longer support this magic — confirm.
%tensorflow_version 1.x

In [None]:
import tensorflow

# Sanity check: report which TensorFlow build the runtime actually loaded.
print("{}".format(tensorflow.__version__))

In [None]:
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import MTCNN


def _load_mtcnn(GPU_ratio):
    """Create a TF1 session and build the three MTCNN stage networks in it.

    Args:
        GPU_ratio: value for ``per_process_gpu_memory_fraction``; ``None``
            sets ``allow_growth`` instead.

    Returns:
        ``(sess, pnet, rnet, onet)``. The caller owns ``sess`` and must close it.
    """
    with tf.Graph().as_default():
        config = tf.ConfigProto(log_device_placement=True,
                                allow_soft_placement=True)
        if GPU_ratio is None:
            config.gpu_options.allow_growth = True
        else:
            config.gpu_options.per_process_gpu_memory_fraction = GPU_ratio
        sess = tf.Session(config=config)
        try:
            with sess.as_default():
                pnet, rnet, onet = MTCNN.create_mtcnn(sess, None)
        except Exception:
            # don't leak the session (and any resources it holds) if loading fails
            sess.close()
            raise
    return sess, pnet, rnet, onet


def _select_faces(bounding_boxes, img_size, detect_multiple_faces):
    """Choose which detected boxes to crop.

    Args:
        bounding_boxes: detector output; only the first four columns, read as
            ``(x1, y1, x2, y2)``, are used.
        img_size: ``(height, width)`` of the source image.
        detect_multiple_faces: keep every box when True; otherwise keep only
            the single box that best balances area against distance from the
            image center.

    Returns:
        A ``(k, 4)`` int32 array of boxes (``k == 0`` when nothing was found).
    """
    det = np.asarray(bounding_boxes)[:, 0:4]
    if det.shape[0] > 1 and not detect_multiple_faces:
        box_area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
        img_center = np.asarray(img_size) / 2
        offsets = np.vstack([(det[:, 0] + det[:, 2]) / 2 - img_center[1],
                             (det[:, 1] + det[:, 3]) / 2 - img_center[0]])
        offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
        # some extra weight on the centering
        index = np.argmax(box_area - offset_dist_squared * 2.0)
        det = det[index:index + 1]
    # int32 (not int16) so coordinates of very large images cannot overflow
    return det.astype(np.int32)


def _crop_face(img, box, margin):
    """Return the part of ``img`` inside ``box`` enlarged by ``margin`` pixels in total.

    Half of ``margin`` is added on each side and the result is clipped to the
    image bounds. The returned array is a view into ``img`` and is empty when
    the box lies entirely outside the image.
    """
    height, width = img.shape[:2]
    half = margin / 2
    x1 = int(max(box[0] - half, 0))
    y1 = int(max(box[1] - half, 0))
    x2 = int(min(box[2] + half, width))
    y2 = int(min(box[3] + half, height))
    return img[y1:y2, x1:x2, :]


def _show_pair(img, cropped):
    """Plot the BGR source image next to its BGR crop, flipped to RGB for matplotlib."""
    for position, picture in ((1, img), (2, cropped)):
        plt.subplot(1, 2, position)
        plt.imshow(picture[:, :, ::-1])
        plt.axis("off")
    plt.show()


def MTCNN_alignment(root_dir, output_dir, detect_multiple_faces=False, output_size=None, margin=44,
                    dataset_range=None, GPU_ratio=None, img_show=False):
    """Detect, crop and save faces for every class folder under ``root_dir``.

    Each immediate sub-folder of ``root_dir`` is treated as one class. Every
    png/bmp/jpg/jpeg file in it (case-insensitive) is run through MTCNN. Each
    kept face is saved as ``output_dir/<class folder>/<image stem>.png``.
    Additional faces of the same image are saved as ``<stem>_<i>.png``.

    Args:
        root_dir: directory containing one sub-folder per class.
        output_dir: destination root; class sub-folders are created as needed.
        detect_multiple_faces: save every detected face instead of only the
            most prominent (large and central) one.
        output_size: if not None, passed to ``cv2.resize`` as ``dsize``
            (``(width, height)``).
        margin: total extra pixels of context around each box (half per side).
        dataset_range: optional ``(start, stop)`` slice over the sorted class
            folders.
        GPU_ratio: GPU memory fraction for the TF session; None enables growth.
        img_show: plot each original image next to its crop.

    Returns:
        None. Prints progress and the average time per image. The average
        includes model-loading time and unreadable images.
    """
    # ----record the start time
    d_t = time.time()

    img_format = {'png', 'bmp', 'jpg', 'jpeg'}
    quantity = 0

    # ----collect all class folders
    dirs = sorted(entry.path for entry in os.scandir(root_dir) if entry.is_dir())
    if not dirs:
        print("No sub folders in ", root_dir)
        return
    print("Total class number: ", len(dirs))
    if dataset_range is not None:
        dirs = dirs[dataset_range[0]:dataset_range[1]]
        print("Working classes: {} to {}".format(dataset_range[0], dataset_range[1]))
    else:
        print("Working classes:All")

    # ----MTCNN parameters
    minsize = 20  # minimum size of face
    threshold = [0.6, 0.7, 0.7]  # one threshold per stage (pnet, rnet, onet)
    factor = 0.709  # scale factor
    sess, pnet, rnet, onet = _load_mtcnn(GPU_ratio)

    try:
        for dir_path in dirs:
            paths = sorted(
                entry.path for entry in os.scandir(dir_path)
                if entry.is_file()
                and os.path.splitext(entry.name)[1][1:].lower() in img_format)
            if not paths:
                print("No images in ", dir_path)
                continue

            # mirror the class folder name under output_dir; basename is portable,
            # whereas splitting on "\\" yields an absolute path on POSIX and
            # makes os.path.join drop output_dir entirely
            save_dir = os.path.join(output_dir, os.path.basename(os.path.normpath(dir_path)))
            os.makedirs(save_dir, exist_ok=True)
            quantity += len(paths)

            for path in paths:
                img = cv2.imread(path)
                if img is None:
                    print("Read failed:", path)
                    continue
                # the detector is fed RGB; img stays BGR for cropping and saving
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                bounding_boxes, _ = MTCNN.detect_face(img_rgb, minsize, pnet, rnet, onet,
                                                      threshold, factor)
                boxes = _select_faces(bounding_boxes, img.shape[:2], detect_multiple_faces)
                stem = os.path.splitext(os.path.basename(path))[0]

                for i, box in enumerate(boxes):
                    cropped = _crop_face(img, box, margin)
                    if cropped.size == 0:
                        print("Empty crop skipped:", path)
                        continue
                    if output_size is not None:
                        cropped = cv2.resize(cropped, output_size)
                    # first face keeps the image stem; extra faces get an index suffix
                    name = "{}.png".format(stem) if i == 0 else "{}_{}.png".format(stem, i)
                    save_path = os.path.join(save_dir, name)
                    if not cv2.imwrite(save_path, cropped):
                        print("Write failed:", save_path)
                    if img_show:
                        _show_pair(img, cropped)
    finally:
        sess.close()

    # ----statistics (to know the average process time of each image)
    if quantity != 0:
        d_t = time.time() - d_t
        print("ave process time of each image:", d_t / quantity)



if __name__ == "__main__":
    # ----input / output locations (Colab paths; one sub-folder per class)
    root_dir, output_dir = "/content/CASIA-WebFace", "/content/CASIA-WebFace_Aligned"

    # ----alignment options
    margin = 44
    detect_multiple_faces = img_show = False
    output_size = dataset_range = GPU_ratio = None

    MTCNN_alignment(
        root_dir,
        output_dir,
        margin=margin,
        dataset_range=dataset_range,
        GPU_ratio=GPU_ratio,
        img_show=img_show,
        detect_multiple_faces=detect_multiple_faces,
        output_size=output_size,
    )