<a href="https://colab.research.google.com/github/nitaymayo/Donut_Blink/blob/master/Donut_Blink_Quick_Version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Donut Blink - Short Version

This notebook demonstrates only the performance of the pretrained Donut_Blink model.<br>
*For the full code of the model, please visit the full notebook - <a href="https://github.com/nitaymayo/Donut_Blink/blob/master/Donut_Blink.ipynb">here</a>.

## Imports

In [10]:
import tensorflow as tf
import os
import matplotlib.patches as patches
import dlib
import cv2
import numpy as np
from matplotlib.patches import Circle
from IPython.display import display, Javascript, Image
from google.colab.output import eval_js
from base64 import b64decode, b64encode
import PIL
import io
import html

## Eyes clipping algorithms

In [11]:
! wget https://github.com/nitaymayo/Donut_Blink/raw/master/shape_predictor_68_face_landmarks.dat

# Face detector
hog_face_detector = dlib.get_frontal_face_detector()

# Eye detector
dlib_facelandmark = dlib.shape_predictor("/content/shape_predictor_68_face_landmarks.dat")


--2023-04-03 07:49:39--  https://github.com/nitaymayo/Donut_Blink/raw/master/shape_predictor_68_face_landmarks.dat
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/nitaymayo/Donut_Blink/master/shape_predictor_68_face_landmarks.dat [following]
--2023-04-03 07:49:40--  https://raw.githubusercontent.com/nitaymayo/Donut_Blink/master/shape_predictor_68_face_landmarks.dat
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 99693937 (95M) [application/octet-stream]
Saving to: ‘shape_predictor_68_face_landmarks.dat.1’


2023-04-03 07:49:41 (175 MB/s) - ‘shape_predictor_68_face_landmarks.dat.1’ saved [99693

In [12]:
# Clipping the eyes from the pictures
# Padding in pixels added around the eye corners (vertical / horizontal).
y_factor = 18
x_factor = 15

def eye_box(img, left_point, right_point):
  """
    Compute a padded box around an eye and crop that box out of the image.

    Args:
    img: array image of the full face, indexed as img[row, column]
    left_point: the left corner of the eye in the img (object exposing .x and .y)
    right_point: the right corner of the eye in the img (object exposing .x and .y)

    Returns:
    tuple as so:
    in index 0: tuple -> (the lower left corner (x, y) of the box, the upper right corner (x, y) of the box);
                these are NOT clamped and may lie outside the image
    in index 1: crop of img covering the part of the box that lies inside the image
                (empty when the box is entirely outside the image)
  """
  corner_l_l = (left_point.x - x_factor, left_point.y + y_factor)
  corner_u_r = (right_point.x + x_factor, right_point.y - y_factor)

  height, width = img.shape[:2]

  def _clamp(value, upper):
    # Keep slice bounds inside [0, upper]; a negative bound would otherwise be
    # interpreted by numpy as "count from the end" and wrap around.
    return max(0, min(int(value), upper))

  x_start = _clamp(min(corner_l_l[0], corner_u_r[0]), width)
  x_stop = _clamp(max(corner_l_l[0], corner_u_r[0]), width)
  y_start = _clamp(min(corner_l_l[1], corner_u_r[1]), height)
  y_stop = _clamp(max(corner_l_l[1], corner_u_r[1]), height)

  return ((corner_l_l, corner_u_r), img[y_start:y_stop, x_start:x_stop])

In [13]:
IMG_SIZE = (224, 224)
def get_eyes_imgs(img, img_size=IMG_SIZE):
  """
    Function to get all the eyes in a given image.
    Faces are found with the dlib HOG face detector, the eye corners come
    from the dlib facelandmark predictor (landmarks 36/39 and 42/45) and
    each eye is cropped with the function 'eye_box()'.

    Args:
    img: image in array form (anything np.asarray accepts)
    img_size: size passed to cv2.resize for every eye crop

    Returns:
    False when no face is detected or no face yields usable eye crops,
    otherwise an array with dictionaries as so:
    {'img': resized eye crop, 'coordinates': tuple with the (lower left, upper right) points of the box}
    Every kept face contributes two consecutive entries.
  """
  img = np.asarray(img)
  eyes = []

  # Getting faces position
  faces = hog_face_detector(img)
  if not faces:
    return False

  # Landmark indices of the two corners of each eye.
  eye_corner_indices = ((36, 39), (42, 45))

  for face in faces:
    face_landmarks = dlib_facelandmark(img, face)

    face_eyes = []
    for first_idx, second_idx in eye_corner_indices:
      coordinates, crop = eye_box(img,
                                  face_landmarks.part(first_idx),
                                  face_landmarks.part(second_idx))
      if crop.size == 0:
        # cv2.resize raises on an empty image (eye box outside the frame),
        # which would abort the whole call; drop this face instead.
        break
      face_eyes.append({"img": cv2.resize(crop, img_size),
                        "coordinates": coordinates})

    # Only keep faces for which both eyes could be cropped.
    if len(face_eyes) == len(eye_corner_indices):
      eyes.extend(face_eyes)

  # Preserve the original "nothing found" sentinel for existing callers.
  return eyes if eyes else False

## Load the pretrained model

In [14]:
! wget https://github.com/nitaymayo/Donut_Blink/raw/master/eye_state_detector_fine_tuned.h5

final_model = tf.keras.models.load_model("eye_state_detector_fine_tuned.h5")

--2023-04-03 07:49:44--  https://github.com/nitaymayo/Donut_Blink/raw/master/eye_state_detector_fine_tuned.h5
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/nitaymayo/Donut_Blink/master/eye_state_detector_fine_tuned.h5 [following]
--2023-04-03 07:49:45--  https://raw.githubusercontent.com/nitaymayo/Donut_Blink/master/eye_state_detector_fine_tuned.h5
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24390504 (23M) [application/octet-stream]
Saving to: ‘eye_state_detector_fine_tuned.h5.1’


2023-04-03 07:49:45 (99.5 MB/s) - ‘eye_state_detector_fine_tuned.h5.1’ saved [24390504/24390504]



## The final function
This function combines the eye-cropping function and the model prediction to give the final result:
1. eye box coordinates
2. eye state

In [15]:
def get_eyes_bbox_state(img, model):
  """
    Locate the eyes in an image and predict the state of each one.

    Args:
    img: to predict on
    model: to make the predictions (called as model.predict(batch, verbose=0))

    Returns:
    False when get_eyes_imgs() finds no eyes, otherwise an
    array with dictionaries {'coordinates': eye box coordinates,
                             'state': 0 for closed eye and 1 for open one}
  """
  # Separate the img into separate eye imgs
  eyes = get_eyes_imgs(img)
  if not eyes:
      return False

  # A single batched prediction covers all eyes; no per-eye predict is needed.
  eyes_imgs = np.asarray([eye["img"] for eye in eyes])
  rounded = np.round(model.predict(eyes_imgs, verbose=0))

  # Reshape to one value per eye rather than squeezing: squeeze would turn a
  # batch of one into a scalar that cannot be indexed.
  # assumes the model emits exactly one score per eye -- TODO confirm
  preds = tf.reshape(rounded, [len(eyes)])

  res = []
  for i, eye in enumerate(eyes):
    res.append({"coordinates": eye["coordinates"],
                "state": preds[i]})

  return res

## JavaScript configuration for using the webcam to show the model's predictions

In [16]:
# function to convert the JavaScript object into an OpenCV image
def js_to_image(js_reply):
  """
  Decode the image string received from the browser into an OpenCV image.

  Params:
          js_reply: string whose second comma-separated field is the
                    base64-encoded image (e.g. a data URL from the webcam)
  Returns:
          img: image decoded by OpenCV with flags=1 (OpenCV BGR image)
  """
  # the base64 payload follows the first comma
  encoded_payload = js_reply.split(',')[1]
  # raw image bytes viewed as a flat uint8 buffer, as cv2.imdecode expects
  raw_buffer = np.frombuffer(b64decode(encoded_payload), dtype=np.uint8)
  return cv2.imdecode(raw_buffer, flags=1)

# function to convert OpenCV Rectangle bounding box image into base64 byte string to be overlayed on video stream
def bbox_to_bytes(bbox_array):
  """
  Encode an RGBA overlay array as a PNG data URL.

  Params:
          bbox_array: Numpy array (pixels) containing rectangle to overlay on video stream;
                      interpreted as 'RGBA' by PIL.
  Returns:
        bytes: string of the form 'data:image/png;base64,<base64 PNG>'
  """
  png_buffer = io.BytesIO()
  # render the array as a PNG into the in-memory buffer
  PIL.Image.fromarray(bbox_array, 'RGBA').save(png_buffer, format='png')
  # base64 text of the PNG bytes
  encoded_png = b64encode(png_buffer.getvalue()).decode('utf-8')
  return 'data:image/png;base64,{}'.format(encoded_png)

In [17]:
# JavaScript to properly create our live video stream using our webcam as input
def video_stream():
  """
  Display the JavaScript that sets up the webcam stream in the notebook output.

  The script defines browser-side state and four functions:
    removeDom():        stops the first video track and removes the created elements.
    onAnimationFrame(): re-schedules itself until `shutdown` is set; when a capture
                        is pending it draws the video onto a 640x480 canvas and
                        resolves the pending promise with a JPEG data URL
                        (quality 0.8), or with "" if `shutdown` is set.
    createDom():        builds (once) a bordered container with a status label, a
                        <video> fed by getUserMedia, an absolutely positioned <img>
                        overlay and an instruction line; clicking the video, the
                        overlay or the instruction sets `shutdown`.
    stream_frame(label, imgData):
                        if `shutdown` is set, removes the DOM and returns '';
                        otherwise updates the label (when non-empty) and the overlay
                        image (when non-empty), waits for the next captured frame and
                        returns an object with timing fields 'create', 'show',
                        'capture' and the frame under 'img'.

  Nothing is returned; the script is only emitted through `display`.
  """
  # The JavaScript below is runtime text executed in the browser.
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;
    
    var pendingResolve = null;
    var shutdown = false;
    
    function removeDom() {
       stream.getVideoTracks()[0].stop();
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }
    
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 640, 480);
          result = captureCanvas.toDataURL('image/jpeg', 0.8)
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }
    
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);
      
      const modelOut = document.createElement('div');
      modelOut.innerHTML = "<span>Status:</span>";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);
           
      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      imgElement.style.position = 'absolute';
      imgElement.style.zIndex = 1;
      imgElement.onclick = () => { shutdown = true; };
      div.appendChild(imgElement);
      
      const instruction = document.createElement('div');
      instruction.innerHTML = 
          '<span style="color: red; font-weight: bold;">' +
          'When finished, click here or on the video to stop this demo</span>';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };
      
      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 640; //video.videoWidth;
      captureCanvas.height = 480; //video.videoHeight;
      window.requestAnimationFrame(onAnimationFrame);
      
      return stream;
    }
    async function stream_frame(label, imgData) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();
      
      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }
            
      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        imgElement.style.top = videoRect.top + "px";
        imgElement.style.left = videoRect.left + "px";
        imgElement.style.width = videoRect.width + "px";
        imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }
      
      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      shutdown = false;
      
      return {'create': preShow - preCreate, 
              'show': preCapture - preShow, 
              'capture': Date.now() - preCapture,
              'img': result};
    }
    ''')

  # Emit the script into the cell output so the functions exist in the page.
  display(js)
  
def video_frame(label, bbox):
  """
  Call the browser-side `stream_frame(label, bbox)` and return its result.

  Params:
          label: text/HTML shown as the status label (an empty string leaves it unchanged)
          bbox: overlay image as a data-URL string (an empty string leaves the overlay unchanged)
  Returns:
          data: whatever `eval_js` returns for the call; per the JavaScript in
                `video_stream()` this is '' after shutdown, otherwise an object
                with timing fields and the captured frame under 'img'
  """
  # Local import so this cell does not depend on an edit to the imports cell.
  import json

  # json.dumps yields a correctly quoted and escaped JavaScript string literal,
  # so quotes, backslashes or newlines in the label (e.g. HTML attributes)
  # cannot break or inject into the evaluated call. str() mirrors the
  # stringification that '{}'.format() applied before.
  call = 'stream_frame({}, {})'.format(json.dumps(str(label)), json.dumps(str(bbox)))
  data = eval_js(call)
  return data

## Turn on the webcam and see the model predictions

In [None]:
# start streaming video from webcam
video_stream()
# label for video
label_html = 'Capturing...'
# initialize bounding box overlay to empty
bbox = ''
# NOTE(review): `count` is not used in this cell; kept in case other cells rely on it.
count = 0
while True:
    js_reply = video_frame(label_html, bbox)
    # An empty reply means the stream was shut down in the browser. The frame
    # itself can also be empty when the stop click lands while a capture is
    # pending; decoding it would raise, so stop in that case as well.
    if not js_reply or not js_reply["img"]:
        break

    # convert JS response to OpenCV Image
    img = js_to_image(js_reply["img"])

    # create transparent overlay for bounding box (same 640x480 size as the capture canvas)
    bbox_array = np.zeros([480,640,4], dtype=np.uint8)

    # get eye boxes and their predicted states
    eyes = get_eyes_bbox_state(img, final_model)
    if eyes:
        for eye in eyes:
            # state 0 (closed) -> (255,0,0), otherwise (open) -> (0,255,0)
            color = (255,0,0) if eye["state"] == 0 else (0,255,0)

            bbox_array = cv2.rectangle(bbox_array,eye["coordinates"][0],eye["coordinates"][1],color,2)

        # make only the drawn pixels opaque
        bbox_array[:,:,3] = (bbox_array.max(axis = 2) > 0 ).astype(int) * 255

    # Always send a fresh overlay, even an empty one, so that boxes from an
    # earlier frame do not linger once no eyes are detected.
    bbox = bbox_to_bytes(bbox_array)