<a href="https://colab.research.google.com/github/justadudewhohacks/ipynbs/blob/master/face_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [0]:
!pip install -U -q PyDrive
!pip install git+https://github.com/justadudewhohacks/image_augment.py
!pip install git+https://github.com/justadudewhohacks/colabsnippets

# Download Data

In [0]:
from colabsnippets.DataDownloader import DataDownloader

# Download spec: one entry per WIDER shard. The values look like remote file
# ids for the image archive and the boxes archive (NOTE(review): confirm
# against DataDownloader).
wider_shards = [
  {"images": "1JHmXqGPngDCbM56eYPeqsaCgJC4vgL4m", "boxes": "1aeAGd5LmL8EBB1yaZxKOp1NbZ1CBJBmm"},
]
downloads = {"WIDER": wider_shards}

# Everything is placed under ./data, which is where the loaders below read from.
# The second argument (['boxes']) is passed straight through to download_data;
# its meaning is defined by DataDownloader.
data_downloader = DataDownloader(data_dir = './data')
data_downloader.download_data(downloads, ['boxes'])

print('done!')

# Common

In [0]:
import cv2
import math
import json
import random
import time
import types
import os
import numpy as np
import tensorflow as tf
from augment import ImageAugmentor, augment
from augment.augment import abs_coords
from colabsnippets.utils import load_json
from colabsnippets import BatchLoader

'''
--------------------------------------------------------------------------------

Data Loader

--------------------------------------------------------------------------------
'''
  
def transform_boxes(boxes):
  """Convert box dicts into (x, y, width, height) tuples.

  Each box must provide the keys 'x', 'y', 'width' and 'height'. Values are
  expected to be relative (fractions of the image size), so any component
  whose absolute value exceeds 1.0 is rejected.

  Returns:
    list of (x, y, width, height) tuples, in input order.

  Raises:
    Exception: for the first box with a component where abs(value) > 1.0.
  """
  def _as_relative_tuple(box):
    coords = (box['x'], box['y'], box['width'], box['height'])
    # any() stops at the first offending component, like a plain loop would.
    if any(abs(component) > 1.0 for component in coords):
      raise Exception("box is probably not a valid relative box: {}".format(coords))
    return coords

  return [_as_relative_tuple(box) for box in boxes]
  
def extract_data_labels(data):
  """Load the ground-truth boxes for one sample descriptor.

  The boxes JSON lives in a directory parallel to the images:
    ./data/<db>/boxes/<image stem>.json
  or, when the descriptor has a 'shard' key,
    ./data/<db>/boxes-shard<shard>/<image stem>.json

  Args:
    data: dict with keys 'db' and 'file' (the image filename), and an
      optional 'shard' key.

  Returns:
    list of (x, y, width, height) tuples produced by transform_boxes.

  Raises:
    Exception: propagated from transform_boxes when a box is not relative.
  """
  db = data['db']
  img_file = data['file']
  # Swap only the trailing extension. str.replace('.jpg', '.json') would also
  # rewrite '.jpg' appearing elsewhere in the name, and would leave '.png' /
  # '.JPG' / '.jpeg' names unchanged, so it would look for the image filename
  # in the boxes directory instead of its .json file.
  boxes_file = os.path.splitext(img_file)[0] + '.json'
  boxes_dir = "boxes-shard{}".format(data['shard']) if 'shard' in data else 'boxes'
  boxes_path = "./data/{}/{}/{}".format(db, boxes_dir, boxes_file)
  boxes = load_json(boxes_path)
  return transform_boxes(boxes)
    
def resolve_image_path(data):
  """Return the on-disk path of a sample's image.

  Layout is ./data/<db>/images/<file>, or ./data/<db>/images-shard<shard>/<file>
  when the sample descriptor carries a 'shard' key.
  """
  if 'shard' in data:
    folder = "images-shard{}".format(data['shard'])
  else:
    folder = 'images'
  return "./data/{}/{}/{}".format(data['db'], folder, data['file'])

def min_bbox(boxes):
  """Return the corner-format box [min_x, min_y, max_x, max_y] enclosing all boxes.

  Input boxes are (x, y, w, h). Output is (x0, y0, x1, y1), not (x, y, w, h).

  The accumulators start at min = 1.0 and max = 0. As a result, the returned
  minimums never exceed 1.0 and the maximums never drop below 0. An empty
  input yields [1.0, 1.0, 0, 0].
  """
  lo_x, lo_y = 1.0, 1.0
  hi_x, hi_y = 0, 0
  # Top-left and bottom-right corner of every box, in input order.
  corners = [
    corner
    for (bx, by, bw, bh) in boxes
    for corner in ((bx, by), (bx + bw, by + bh))
  ]
  for cx, cy in corners:
    if cx < lo_x:
      lo_x = cx
    if cy < lo_y:
      lo_y = cy
    # "not <" (rather than ">=") keeps the exact tie/NaN behaviour of the
    # running-max update.
    if not cx < hi_x:
      hi_x = cx
    if not cy < hi_y:
      hi_y = cy
  return [lo_x, lo_y, hi_x, hi_y]

class DataLoader(BatchLoader):
  """BatchLoader producing (images, boxes) batches for face detection.

  Image paths and labels are resolved with resolve_image_path and
  extract_data_labels. Each sample is passed through augmentation with:
  - random_crop set to the tightest box around all of its boxes (min_bbox),
  - pad_to_square=True,
  - resize set to the requested image size.
  Augmentation uses the configured ImageAugmentor when one is given, and the
  module-level `augment` function otherwise.
  """
  def __init__(self, data, image_augmentor = None, start_epoch = None, is_test = False):
    """
    Args:
      data: list of sample descriptors, or a zero-argument callable that
        returns that list.
      image_augmentor: optional object with an `augment` method (e.g. an
        ImageAugmentor); when None, `augment` is used.
      start_epoch: forwarded unchanged to BatchLoader.
      is_test: forwarded unchanged to BatchLoader.
    """
    self.image_augmentor = image_augmentor
    # Any callable is treated as a data provider. A `type(...) is
    # types.FunctionType` check misses bound methods and functools.partial
    # objects; those would be wrapped in a lambda returning the callable itself.
    data_provider = data if callable(data) else (lambda: data)
    BatchLoader.__init__(
      self,
      data_provider,
      resolve_image_path,
      extract_data_labels,
      start_epoch = start_epoch,
      is_test = is_test
    )

  def load_image_and_labels_batch(self, datas, image_size):
    """Load and augment one batch of samples.

    Args:
      datas: iterable of sample descriptors.
      image_size: target size passed as `resize` to the augment call.

    Returns:
      (batch_x, batch_y): list of images and list of their transformed boxes,
      in the same order as `datas`.
    """
    # Both the augmentor method and the fallback take identical arguments, so
    # pick the function once instead of duplicating the call.
    augment_fn = augment if self.image_augmentor is None else self.image_augmentor.augment
    batch_x, batch_y = [], []
    for data in datas:
      # extract_data_labels / load_image are presumably provided by
      # BatchLoader from the functions passed to its constructor — confirm.
      boxes = self.extract_data_labels(data)
      image = self.load_image(data)
      # Region enclosing every box, used as the random-crop reference
      # (NOTE(review): exact crop semantics are defined by augment).
      roi = min_bbox(boxes)
      image, boxes = augment_fn(image, boxes = boxes, random_crop = roi, pad_to_square = True, resize = image_size)
      batch_x.append(image)
      batch_y.append(boxes)

    return batch_x, batch_y


'''
--------------------------------------------------------------------------------

utility

--------------------------------------------------------------------------------
'''

def gpu_session(callback):
  """Run `callback(session)` in a TF1 session under tf.device('/gpu:0').

  The session config enables:
  - GPU memory growth (gpu_options.allow_growth),
  - soft device placement,
  - device-placement logging.
  Returns whatever `callback` returns. The session is closed by its context
  manager when the callback finishes.
  """
  session_config = tf.ConfigProto()
  session_config.allow_soft_placement = True
  session_config.log_device_placement = True
  session_config.gpu_options.allow_growth = True
  # Single `with` enters the session first, then the device scope, and exits in
  # reverse order, exactly like the nested form.
  with tf.Session(config = session_config) as sess, tf.device('/gpu:0'):
    return callback(sess)

def get_checkpoint(model_name, epoch):
  """Return the checkpoint name '<model_name>.ckpt-<epoch>'."""
  return ''.join((model_name, '.ckpt-', str(epoch)))

def draw_box(img, box):
  """Draw `box` onto `img` in place.

  Draws a 1px rectangle plus a filled radius-2 dot on each of the four
  corners. `box` is converted with abs_coords(box, img) first; presumably it
  is a relative (x, y, w, h) box — confirm against augment.abs_coords.
  """
  left, top, width, height = abs_coords(box, img)
  right, bottom = left + width, top + height

  cv2.rectangle(img, (left, top), (right, bottom), (255, 0, 0), 1)
  for corner in ((left, top), (left, bottom), (right, top), (right, bottom)):
    cv2.circle(img, corner, 2, (0, 0, 255), -1)

# Debug

## Check Inputs

In [0]:
!rm -rf ./check_inputs && mkdir ./check_inputs

from IPython.display import Image, display

num_inputs = 10
image_size = 400
num_images_per_row = 2
db = 'WIDER'  # set to None to take samples from every db

image_augmentor = ImageAugmentor.load('./augmentor_4.json')
train_data = load_json('./data/trainData.json')

# Keep only the samples of the selected db (all samples when db is None).
db_data = [data for data in train_data if db is None or data['db'] == db]

data_loader = DataLoader(db_data, start_epoch = 0, image_augmentor = image_augmentor)
batch_x, batch_y = data_loader.next_batch(num_inputs, image_size)

# Render the batch as rows of num_images_per_row images with their boxes drawn.
# Bound the loop by the batch actually returned: if next_batch yields fewer
# than num_inputs samples, np.stack on an empty slice would raise.
num_shown = min(num_inputs, len(batch_x))
for file_idx, idx in enumerate(range(0, num_shown, num_images_per_row)):
  # np.stack copies, so drawing below leaves batch_x itself untouched.
  imgs = np.stack(batch_x[idx : idx + num_images_per_row], axis = 0)
  # Iterating the stacked array yields views, so draw_box writes into imgs.
  for img, boxes in zip(imgs, batch_y[idx : idx + num_images_per_row]):
    for box in boxes:
      draw_box(img, box)

  # Place the row's images side by side along the width axis.
  merged_img = np.concatenate(imgs, axis = 1)

  out_file = './check_inputs/' + str(file_idx) + '.jpg'
  cv2.imwrite(out_file, merged_img)
  display(Image(out_file))

!rm -rf ./check_inputs