In [1]:
from batch_face import (
    RetinaFace,
    SixDRep
)
from sixdrepnet.model import SixDRepNet
import os
import numpy as np
import cv2
from math import cos, sin

import torch
from torchvision import transforms
from PIL import Image
from sixdrepnet import utils


In [2]:
# Preprocessing pipeline run on every face crop before it is handed to the
# head-pose model: Resize(224) -> CenterCrop(224) -> tensor conversion ->
# per-channel normalisation with the mean/std constants listed below.
transformations = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
# Face detector. gpu_id=-1 is used because this machine has no CUDA (macOS).
detector = RetinaFace(gpu_id=-1) # MacOS no cuda
# Presumably the webcam index for cv2.VideoCapture — TODO confirm.
cam = 1
device = torch.device('cpu')
# Head-pose network used by get_input_data().
# NOTE(review): pretrained=False with an empty backbone_file, and no
# load_state_dict() call for this model anywhere in the notebook, so it
# appears to run with freshly initialised weights. Confirm where the
# head-pose checkpoint is supposed to be loaded.
model = SixDRepNet(backbone_name='RepVGG-B1g2',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)
# The model is only used for inference: place it on the same device as its
# inputs and switch it to evaluation mode.
model.to(device)
model.eval()
# When True, the demo loop converts the eye crop to grayscale.
bw = False

In [4]:
def _crop_eye_patch(image, center_x, center_y, half_width, half_height, coeff):
    """Cut the rectangle around one eye landmark out of the original frame.

    The landmark centre and the half sizes are expressed in resized-image
    pixels; dividing by ``coeff`` maps them back to ``image`` pixels. The
    rectangle is clamped to the frame so that a landmark close to the border
    cannot yield negative slice indices (NumPy would count those from the
    opposite edge and return a wrong or empty patch).
    """
    frame_height, frame_width = image.shape[:2]
    top = max(0, int((center_y - half_height) / coeff))
    bottom = min(frame_height, int((center_y + half_height) / coeff))
    left = max(0, int((center_x - half_width) / coeff))
    right = min(frame_width, int((center_x + half_width) / coeff))
    return image[top:bottom, left:right]


def get_input_data(image, offset_coeff=1) -> "list | None":
    """Detect faces in ``image`` and build the per-face inputs for the eye net.

    The frame is resized to a width of 1280 px for detection. For every
    detection with score >= 0.95 the head pose is predicted from a padded face
    crop, and a strip made of the two eye patches (landmark 0 followed by
    landmark 1) is cut from the *original* frame.

    Args:
        image: frame as an array indexable as ``image[y, x]`` with a
            ``.shape`` attribute (e.g. what ``cv2.VideoCapture.read`` yields).
        offset_coeff: multiplier applied to the size of the eye patches.

    Returns:
        A list with one dict per accepted face, or ``None`` when ``image`` is
        ``None`` or processing raised. Each dict holds:
        ``'p_pred_deg'``, ``'y_pred_deg'``, ``'r_pred_deg'`` — predicted
        Euler angles in degrees (CPU tensors);
        ``'image'`` — both eye patches concatenated horizontally;
        ``'box'``, ``'landmarks'`` — detector outputs scaled back to the
        coordinates of ``image``.
        Faces whose face crop or eye patches are empty are skipped.
    """
    if image is None:
        # Typically a failed camera read; nothing to process.
        print('get_input_data: no image supplied')
        return None

    try:
        coeff = 1280 / image.shape[1]
        resized_image = cv2.resize(image, (1280, int(image.shape[0]*coeff)))
        result = []
        with torch.no_grad():
            for box, landmarks, score in detector(resized_image):
                # Ignore low-confidence detections.
                if score < .95:
                    continue

                x_min = int(box[0])
                y_min = int(box[1])
                x_max = int(box[2])
                y_max = int(box[3])

                right_eye_x = int(landmarks[0][0])
                right_eye_y = int(landmarks[0][1])
                left_eye_x = int(landmarks[1][0])
                left_eye_y = int(landmarks[1][1])

                # Half-size of the eye patches, derived from the mean
                # horizontal distance between each eye and the nearest edge
                # of the (unpadded) detection box.
                offset = abs(((right_eye_x - x_min)/2 + (x_max - left_eye_x)/2)/2)
                x_offset = int(offset*1.2*offset_coeff)
                y_offset = int(offset*0.8*offset_coeff)

                right_eye = _crop_eye_patch(
                    image, right_eye_x, right_eye_y, x_offset, y_offset, coeff)
                left_eye = _crop_eye_patch(
                    image, left_eye_x, left_eye_y, x_offset, y_offset, coeff)
                if right_eye.size == 0 or left_eye.size == 0:
                    # Degenerate patch: skip this face only instead of
                    # failing the whole frame inside cv2.resize.
                    continue

                # Pad the detection box by 20% before cropping the face.
                # The x padding uses the box height and the y padding the box
                # width; this is kept exactly as in the original code.
                bbox_width = abs(x_max - x_min)
                bbox_height = abs(y_max - y_min)
                crop_x_min = max(0, x_min - int(0.2*bbox_height))
                crop_y_min = max(0, y_min - int(0.2*bbox_width))
                crop_x_max = x_max + int(0.2*bbox_height)
                crop_y_max = y_max + int(0.2*bbox_width)

                face_crop = resized_image[crop_y_min:crop_y_max,
                                          crop_x_min:crop_x_max]
                if face_crop.size == 0:
                    continue

                # NOTE(review): the crop is handed to PIL without a channel
                # swap; if the frame is BGR (OpenCV default) the model sees
                # swapped channels. Left unchanged — confirm against how the
                # head-pose model was trained.
                face_tensor = transformations(
                    Image.fromarray(face_crop).convert('RGB'))
                face_tensor = face_tensor.unsqueeze(0).to(device)

                # No keyboard polling here: calling cv2.waitKey() in this
                # function would consume key presses that the caller's event
                # loop relies on (e.g. its quit key).
                R_pred = model(face_tensor)
                euler = utils.compute_euler_angles_from_rotation_matrices(
                    R_pred)*180/np.pi

                # Bring the second patch to the size of the first one so the
                # two can be concatenated side by side.
                left_eye = cv2.resize(
                    left_eye, (right_eye.shape[1], right_eye.shape[0]))

                result.append({
                    'p_pred_deg': euler[:, 0].cpu(),
                    'y_pred_deg': euler[:, 1].cpu(),
                    'r_pred_deg': euler[:, 2].cpu(),
                    'image': cv2.hconcat([right_eye, left_eye]),
                    # Map detector coordinates back onto the original frame.
                    'box': [value/coeff for value in box],
                    'landmarks': [[value/coeff for value in point]
                                  for point in landmarks],
                })
    except Exception as e:
        # Best effort: report the problem and let the caller skip the frame.
        print(e.args)
        return None
    return result

In [5]:
def draw_eye_axis(img, yaw, pitch, roll, tdx, tdy, size=100):
    """Draw a single direction line for the given yaw/pitch onto ``img``.

    The end point uses the same projection as the Z axis in ``draw_axis``:
    ``x = size * sin(-yaw) + tdx`` and
    ``y = size * (-cos(-yaw) * sin(pitch)) + tdy``.

    Args:
        img: image passed to ``cv2.line``; it is drawn on in place.
        yaw: yaw angle in degrees (its sign is flipped before use).
        pitch: pitch angle in degrees.
        roll: accepted for signature compatibility; it does not influence
            the drawn line.
        tdx, tdy: pixel coordinates of the line's starting point.
        size: length scale of the line in pixels.

    Returns:
        The same ``img`` object.
    """
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)

    x = size * (sin(yaw)) + tdx
    y = size * (-cos(yaw) * sin(pitch)) + tdy

    cv2.line(img, (int(tdx), int(tdy)), (int(x), int(y)), (255, 255, 0), 3)

    return img


In [6]:
# Outline a detected face with a green rectangle
def draw_face_box(image, box):
    """Draw ``box`` (x_min, y_min, x_max, y_max) on ``image`` in place.

    Each coordinate is truncated to an int before drawing; the rectangle is
    drawn with colour (0, 255, 0) and a thickness of 2.
    """
    left, top, right, bottom = (int(coordinate) for coordinate in box)
    top_left = (left, top)
    bottom_right = (right, bottom)
    cv2.rectangle(image, top_left, bottom_right, (0, 255, 0), 2)

# Mark every landmark point with a filled circle
def draw_landmarks(image, landmarks):
    """Draw each (x, y) pair in ``landmarks`` on ``image`` in place.

    Every point becomes a filled circle (thickness -1) of radius 5 in colour
    (0, 0, 255), centred on the integer-truncated coordinates.
    """
    for point in landmarks:
        point_x, point_y = point
        center = (int(point_x), int(point_y))
        cv2.circle(image, center, 5, (0, 0, 255), -1)

In [7]:
import torch
from torch import nn
import torch.nn.functional as F
class SixthEyeNet(nn.Module):
    """Small CNN combining an eye-strip image with a 3-value head pose.

    Two conv + max-pool stages and two fully connected layers reduce the
    image to 50 features; the head-pose vector is concatenated to them and a
    final linear layer produces 2 outputs (the demo loop reads them as pitch
    and negated yaw).

    The class derives from ``nn.Module`` (not the ``nn.ModuleList``
    container); the sub-modules keep their names and order, so existing
    state dicts load unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 9, 3)
        self.pool = nn.MaxPool2d(3, 3)
        self.conv2 = nn.Conv2d(9, 26, 3)
        # 3432 = 26 channels * 6 * 22, the flattened feature map obtained
        # for a 70x210 input image.
        self.fc1 = nn.Linear(3432, 600)
        self.fc2 = nn.Linear(600, 50)
        # 53 = 50 image features + 3 head-pose values.
        self.fc3 = nn.Linear(53, 2)

    def forward(self, x):
        """Run the network.

        Args:
            x: pair ``(image, head_pos)`` where ``image`` is a batch of
                3-channel images whose flattened conv output has 3432
                values (e.g. 70x210 inputs) and ``head_pos`` has shape
                ``(batch, 3)``.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        x, head_pos = x
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Append the head pose to the image features before the last layer.
        x = torch.cat((x, head_pos), 1)
        x = self.fc3(x)
        return x

In [13]:
def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100):
    """Draw the three axes of a pose given by yaw/pitch/roll onto ``img``.

    Args:
        img: image to draw on (modified in place by ``cv2.line``); its
            ``shape`` is read only when no origin is supplied.
        yaw, pitch, roll: angles in degrees (yaw's sign is flipped before
            use).
        tdx, tdy: pixel coordinates of the axes' origin. Unless BOTH are
            given, the origin defaults to the centre of ``img``.
        size: length scale of the axes in pixels.

    Returns:
        The same ``img`` object.
    """
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)
    roll = roll * np.pi / 180

    # Identity checks so that a coordinate of 0 still counts as "given".
    if tdx is None or tdy is None:
        height, width = img.shape[:2]
        tdx = width / 2
        tdy = height / 2

    # X-Axis pointing to right. drawn in red
    x1 = size * (cos(yaw) * cos(roll)) + tdx
    y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy

    # Y-Axis | drawn in green
    #        v
    x2 = size * (-cos(yaw) * sin(roll)) + tdx
    y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy

    # Z-Axis (out of the screen) drawn in blue
    x3 = size * (sin(yaw)) + tdx
    y3 = size * (-cos(yaw) * sin(pitch)) + tdy

    origin = (int(tdx), int(tdy))
    cv2.line(img, origin, (int(x1), int(y1)), (0, 0, 255), 4)
    cv2.line(img, origin, (int(x2), int(y2)), (0, 255, 0), 4)
    cv2.line(img, origin, (int(x3), int(y3)), (255, 0, 0), 4)

    return img


In [15]:
from torchvision import transforms
if __name__ == '__main__':
    # Live demo: grab webcam frames, estimate the gaze of every detected
    # face with SixthEyeNet and draw it on the frame. Press 'q' in the
    # "Demo" window to stop.

    # Bound to its own name: rebinding `transforms` would replace the
    # torchvision module with a Compose object for the rest of the kernel
    # session and break every other cell that calls transforms.*.
    eye_transforms = transforms.Compose([transforms.ToPILImage(),
                                         transforms.Resize((70, 210)),
                                         transforms.ToTensor()])
    cap = cv2.VideoCapture(cam)

    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

        # Check if the webcam is opened correctly
        if not cap.isOpened():
            raise IOError("Cannot open webcam")

        net = SixthEyeNet()
        EYE_MODEL_PATH = './models/sixth_eye_net_combined.pth'
        bw = False
        # map_location makes the checkpoint loadable on this (CPU) device
        # regardless of where it was saved; weights_only restricts
        # unpickling to tensors/plain containers, which is all a state dict
        # needs.
        net.load_state_dict(torch.load(EYE_MODEL_PATH,
                                       map_location=device,
                                       weights_only=True))
        net.to(device)
        net.eval()

        axis_size = 130

        with torch.no_grad():
            while True:
                ok, frame = cap.read()
                if not ok or frame is None:
                    # A failed read would otherwise spin forever.
                    print("Failed to read a frame from the webcam; stopping.")
                    break

                # None (processing error) and [] (no face) are both treated
                # as "nothing to draw"; the frame is still displayed and the
                # quit key still polled below.
                faces = get_input_data(frame) or []

                for face in faces:
                    x_min, y_min, x_max, y_max = (
                        int(value) for value in face['box'][:4])

                    # Same 20% padding as used for the face crop, so the
                    # axes are anchored at the centre of the padded box.
                    bbox_width = abs(x_max - x_min)
                    bbox_height = abs(y_max - y_min)

                    x_min = max(0, x_min-int(0.2*bbox_height))
                    y_min = max(0, y_min-int(0.2*bbox_width))
                    x_max += int(0.2*bbox_height)
                    y_max += int(0.2*bbox_width)

                    head_pitch = float(face['p_pred_deg'])
                    head_yaw = float(face['y_pred_deg'])
                    head_roll = float(face['r_pred_deg'])

                    eyes = cv2.resize(face['image'], (210, 70),
                                      interpolation=cv2.INTER_CUBIC)
                    if bw:
                        # NOTE(review): a single-channel image does not
                        # match SixthEyeNet's 3-channel first conv layer —
                        # confirm before enabling.
                        eyes = cv2.cvtColor(eyes, cv2.COLOR_BGR2GRAY)
                    eyes = torch.unsqueeze(eye_transforms(eyes), dim=0).to(device)

                    # Order (pitch, roll, yaw) kept from the original code;
                    # presumably it matches how the eye model was trained —
                    # TODO confirm.
                    head_pos = torch.unsqueeze(torch.tensor(
                        [head_pitch, head_roll, head_yaw],
                        dtype=torch.float32), dim=0).to(device)

                    res = net((eyes, head_pos)).tolist()[0]
                    pitch = res[0]
                    yaw = -res[1]

                    print(pitch, yaw)

                    draw_axis(frame, yaw, pitch, head_roll,
                              x_min+int(.5*(x_max-x_min)),
                              y_min+int(.5*(y_max-y_min)),
                              size=axis_size)

                cv2.imshow("Demo", frame)

                # Check if 'q' is pressed to exit the loop
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Release the camera and close the window, also when the loop is
        # left through an exception or a kernel interrupt.
        cap.release()
        cv2.destroyAllWindows()

  net.load_state_dict(torch.load(EYE_MODEL_PATH))


10.495601654052734 -2.5265488624572754
7.048986911773682 20.043874740600586
6.682283878326416 20.086645126342773
3.417095422744751 17.6622371673584
16.1337947845459 19.640209197998047
0.268480509519577 14.339947700500488
15.396345138549805 -6.4582953453063965
2.030202865600586 -21.04322624206543
11.470514297485352 -0.4900476038455963
9.315225601196289 -3.183924436569214
-2.678722858428955 11.458205223083496
-9.969350814819336 3.3005893230438232
-7.4863600730896 0.5643004775047302
-0.6463432908058167 3.51548171043396
-9.347675323486328 7.448936462402344
4.326562404632568 11.313180923461914
18.901512145996094 8.2750825881958
13.136411666870117 22.577693939208984
2.2814748287200928 20.149808883666992
-3.361856698989868 -8.726157188415527
-5.089542388916016 1.581687092781067
1.0036885738372803 19.32237434387207
-1.6245946884155273 9.524994850158691
-0.5818844437599182 11.042604446411133
-3.3284189701080322 11.471497535705566
-2.0994760990142822 8.605146408081055
-0.3149457275867462 9.19127

KeyboardInterrupt: 

In [16]:
# Manual cleanup cell: close any OpenCV windows and free the camera handle
# `cap` created by the demo cell (needed when that cell was interrupted
# before reaching its own release calls).
cv2.destroyAllWindows()
cap.release()