In [1]:
import sys, os, threading, time, torchvision, json

import numpy as np
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.autograd import Variable
from scipy.spatial import distance
from pandas import DataFrame
from app_utils.accessory_lib import pytorch_system_info
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image
from tqdm import tqdm

# Make the bundled FaceBoxes_PyTorch sources importable (models/, data/, utils/, layers/ below).
# Guarded so that re-running this cell does not append the same entry to sys.path again.
facebox_path = os.getcwd() + '/' + 'PIPNet/FaceBoxes_PyTorch/'
if facebox_path not in sys.path:
    sys.path.append(facebox_path)

from models.faceboxes import FaceBoxes
from data import cfg
from utils.box_utils import decode
from utils.nms_wrapper import nms
from layers.functions.prior_box import PriorBox

from PIPNet.lib.networks import *
from PIPNet.lib.functions import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Counters (all start at zero) ---------------------------------------------
initial_time = time.time()
elapsed_time = ref_frame = det_frame = fps = 0

# --- PIPNet landmark network ----------------------------------------------------
input_size, net_stride, num_nb, num_lms = 256, 32, 10, 68
data_name = 'data_300W'
experiment_name = 'pip_32_16_60_r101_l2_l1_10_1_nb10'

# --- Feature switches -----------------------------------------------------------
enable_gaze, enable_log = True, False

# --- Frame / face-box geometry --------------------------------------------------
image_scale = 0.0                    # fraction removed from the frame size when computing the "resized" size
offset_height = offset_width = 0
det_box_scale = 1.2                  # factor by which a detected face box is enlarged before cropping
eye_det = 0.15                       # NOTE(review): not read by the cells visible here — confirm its use

# --- Backend / device -----------------------------------------------------------
cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda")

In [3]:
img_path  = 'dataset/aihub_drowsy_dataset/Validation/image_semi_restrictions'
label_path = 'dataset/aihub_drowsy_dataset/Validation/label_semi_restrictions'

# Characters [11:13] of each per-clip directory name hold the driver-state code.
STATE_CODE_SLICE = slice(11, 13)
NORMAL_STATE_CODE, DROWSY_STATE_CODE, YAWN_STATE_CODE = '01', '02', '03'


def _group_dirs_by_state(root, dir_names):
    """Split ``dir_names`` (entries of ``root``) by driver-state code.

    Returns ``{state_code: [root + os.sep + dir_name, ...]}`` for the normal, drowsy and
    yawn codes, preserving the order of ``dir_names``. Names with any other code are ignored.
    """
    grouped = {NORMAL_STATE_CODE: [], DROWSY_STATE_CODE: [], YAWN_STATE_CODE: []}
    for dir_name in dir_names:
        state_code = dir_name[STATE_CODE_SLICE]
        if state_code in grouped:
            grouped[state_code].append(root + os.sep + dir_name)
    return grouped


def _list_files_sorted(dir_paths):
    """Return the full path of every entry in each directory, sorted by name within a directory."""
    return [
        dir_path + os.sep + file_name
        for dir_path in dir_paths
        for file_name in sorted(os.listdir(dir_path))
    ]


img_dir_name_list = sorted(os.listdir(img_path))
label_dir_name_list = sorted(os.listdir(label_path))

img_dirs_by_state = _group_dirs_by_state(img_path, img_dir_name_list)
label_dirs_by_state = _group_dirs_by_state(label_path, label_dir_name_list)

normal_state_img_dir_list = img_dirs_by_state[NORMAL_STATE_CODE]
yawn_state_img_dir_list = img_dirs_by_state[YAWN_STATE_CODE]
drowsy_state_img_dir_list = img_dirs_by_state[DROWSY_STATE_CODE]

normal_state_label_dir_list = label_dirs_by_state[NORMAL_STATE_CODE]
yawn_state_label_dir_list = label_dirs_by_state[YAWN_STATE_CODE]
drowsy_state_label_dir_list = label_dirs_by_state[DROWSY_STATE_CODE]

normal_state_img_path_list = _list_files_sorted(normal_state_img_dir_list)
yawn_state_img_path_list = _list_files_sorted(yawn_state_img_dir_list)
drowsy_state_img_path_list = _list_files_sorted(drowsy_state_img_dir_list)

normal_state_label_path_list = _list_files_sorted(normal_state_label_dir_list)
yawn_state_label_path_list = _list_files_sorted(yawn_state_label_dir_list)
drowsy_state_label_path_list = _list_files_sorted(drowsy_state_label_dir_list)

# Images and labels are paired purely by list position later on (shared index / zip), so a
# count mismatch would silently misalign every pair after the first missing file.
# TODO: a stricter check would compare file stems as well.
for state_name, state_img_paths, state_label_paths in (
    ('normal', normal_state_img_path_list, normal_state_label_path_list),
    ('yawn', yawn_state_img_path_list, yawn_state_label_path_list),
    ('drowsy', drowsy_state_img_path_list, drowsy_state_label_path_list),
):
    assert len(state_img_paths) == len(state_label_paths), (
        '{}: {} image files vs {} label files; positional image/label pairing would be '
        'misaligned'.format(state_name, len(state_img_paths), len(state_label_paths))
    )

In [4]:
class Driver_State_Dataset(Dataset):
    """Dataset of (key points, image) samples built from parallel label/image path lists.

    Item ``idx`` reads ``label_path_list[idx]`` (a JSON label) and ``img_path_list[idx]``
    (an image), so the two lists must be aligned element-for-element.
    """

    def __init__(self, label_path_list, img_path_list):
        """Store the path lists.

        Raises:
            ValueError: if the two lists differ in length, since samples are paired by index.
        """
        if len(label_path_list) != len(img_path_list):
            raise ValueError(
                'label_path_list has {} entries but img_path_list has {}; samples are '
                'paired by index'.format(len(label_path_list), len(img_path_list))
            )
        self.label_path = label_path_list
        self.img_path = img_path_list

    def __len__(self):
        return len(self.label_path)

    def __getitem__(self, idx):
        """Return ``(key_points, image)`` for sample ``idx``.

        ``key_points`` is ``label_data['ObjectInfo']['KeyPoints']['Points']`` converted to a
        list of floats; ``image`` is the ``read_image`` tensor moved to the global ``device``.
        """
        with open(self.label_path[idx]) as f:
            label_data = json.load(f)
        key_points = [float(point) for point in label_data['ObjectInfo']['KeyPoints']['Points']]

        img = read_image(self.img_path[idx])
        # NOTE(review): moving to `device` inside __getitem__ is only safe while the DataLoader
        # uses num_workers=0 if `device` is CUDA — revisit before adding worker processes.
        return key_points, img.to(device)

In [5]:
normal_state_dataset = Driver_State_Dataset(
    label_path_list=normal_state_label_path_list,
    img_path_list=normal_state_img_path_list,
)
# One (key_points, image) sample per batch.
normal_state_dataloader = DataLoader(dataset=normal_state_dataset, batch_size=1)

In [6]:
# ImageNet mean/std normalisation shared by both preprocessing pipelines below.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformations = transforms.Compose([transforms.Resize(448), transforms.ToTensor(), normalize])

# Neighbour index tables used later to merge each landmark's neighbour-based predictions.
meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(
    os.path.join('PIPNet/data', data_name, 'meanface.txt'), num_nb)

# Backbone is built without ImageNet weights: load_state_dict below is strict (the PyTorch
# default), so the checkpoint must supply every parameter of landmark_net and any pretrained
# values would be overwritten anyway — downloading them only costs time/network access.
resnet101 = models.resnet101(weights=None)
landmark_net = Pip_resnet101(resnet101, num_nb=num_nb, num_lms=num_lms, input_size=input_size, net_stride=net_stride)
landmark_net = landmark_net.to(device)

save_dir = os.path.join('PIPNet/snapshots', data_name, experiment_name)
# NOTE(review): 60 - 1 presumably selects the last epoch of a 60-epoch run — confirm.
weight_file = os.path.join(save_dir, 'epoch%d.pth' % (60 - 1))
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(weight_file, map_location=device)
landmark_net.load_state_dict(state_dict)
landmark_net.eval()

preprocess = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(), normalize])


def check_keys(model, pretrained_state_dict):
    """Print how a checkpoint's keys line up with ``model.state_dict()``.

    Reports the count of model keys absent from the checkpoint, checkpoint keys the model
    does not have, and keys present in both.

    Raises:
        AssertionError: if the two key sets share no key at all.

    Returns:
        True when at least one key is shared.
    """
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(pretrained_state_dict.keys())
    shared_keys = ckpt_keys.intersection(model_keys)

    report = (
        ('Missing keys:{}', model_keys.difference(ckpt_keys)),
        ('Unused checkpoint keys:{}', ckpt_keys.difference(model_keys)),
        ('Used keys:{}', shared_keys),
    )
    for template, keys in report:
        print(template.format(len(keys)))

    assert len(shared_keys) > 0, 'load NONE from pretrained checkpoint'
    return True


def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have a leading ``prefix`` removed.

    Only a prefix at the very start of a key is stripped (once); other keys are copied
    unchanged. Values are not copied. An empty ``prefix`` returns the keys unchanged
    (the previous ``str.split(prefix)`` form raised ValueError for ``''``).
    """
    print('remove prefix \'{}\''.format(prefix))
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(model, pretrained_path, load_to_cpu):
    """Load the checkpoint at ``pretrained_path`` into ``model`` (non-strict) and return it.

    Args:
        model: module to receive the weights.
        pretrained_path: checkpoint file passed to ``torch.load``.
        load_to_cpu: if True, tensors are mapped to CPU; otherwise to the current CUDA device.

    If the checkpoint has a ``"state_dict"`` entry, that entry is used. Keys have a leading
    ``'module.'`` stripped, key overlap is reported via ``check_keys``, and the weights are
    loaded with ``strict=False``.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        location_fn = lambda storage, loc: storage
    else:
        cuda_index = torch.cuda.current_device()
        location_fn = lambda storage, loc: storage.cuda(cuda_index)
    checkpoint = torch.load(pretrained_path, map_location=location_fn)

    weights = checkpoint['state_dict'] if "state_dict" in checkpoint.keys() else checkpoint
    weights = remove_prefix(weights, 'module.')

    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model


net = FaceBoxes(phase='test', size=None, num_classes=2)    # initialize detector
net = load_model(net, 'PIPNet/FaceBoxes_PyTorch/weights/Final_FaceBoxes.pth', False)
net.eval()
net = net.to(device)

# The reference frame size is taken from the first normal-state image.
if not normal_state_img_path_list:
    raise RuntimeError('normal_state_img_path_list is empty; cannot determine the frame size')
first_img_path = normal_state_img_path_list[0]
img = cv2.imread(first_img_path)
if img is None:
    # cv2.imread reports failure by returning None rather than raising.
    raise FileNotFoundError('cv2.imread could not read {}'.format(first_img_path))

frame_height, frame_width = img.shape[:2]

frame_width_resize = int((1 - image_scale) * frame_width)
frame_height_resize = int((1 - image_scale) * frame_height)

# [w, h, w, h] multiplier applied to decoded (x1, y1, x2, y2) boxes.
scale = torch.Tensor([frame_width_resize, frame_height_resize, frame_width_resize, frame_height_resize])
scale = scale.to(device)

det_frame = 0
image_border = image_scale / 2

Loading pretrained model from PIPNet/FaceBoxes_PyTorch/weights/Final_FaceBoxes.pth
remove prefix 'module.'
Missing keys:0
Unused checkpoint keys:0
Used keys:174


In [6]:
# Evaluate PIPNet eye landmarks against the AIHub labels on the drowsy split: for each image,
# detect faces with FaceBoxes, keep the single best detection, run PIPNet on the enlarged
# face crop, and collect predicted vs. labelled eye points (landmark indices 36-47).

FACE_SCORE_PRE_NMS = 0.05   # raw FaceBoxes candidates below this score are discarded
FACE_PRE_NMS_TOP_K = 5000   # at most this many candidates are passed to NMS
FACE_NMS_IOU = 0.3          # NMS overlap threshold
FACE_SCORE_MIN = 0.5        # the kept detection must score above this to be evaluated
EYE_LMK_START, EYE_LMK_END = 36, 48  # landmark index range [start, end) used as eye points


def _detect_best_face(image_bgr, priors, box_scale):
    """Run FaceBoxes on one cv2 image and return its highest-scoring detection.

    Returns a length-5 array ``[x1, y1, x2, y2, score]`` in pixels, or ``None`` when no
    candidate passes the score thresholds.
    """
    net_input = np.float32(image_bgr)
    net_input -= (104, 117, 123)  # fixed per-channel offset subtracted before the detector
    net_input = torch.from_numpy(net_input.transpose(2, 0, 1)).unsqueeze(0).to(device)
    loc, conf = net(net_input)

    boxes = decode(loc.data.squeeze(0), priors.data, cfg['variance'])
    boxes = (boxes * box_scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

    candidates = np.where(scores > FACE_SCORE_PRE_NMS)[0]
    if candidates.size == 0:
        return None
    boxes, scores = boxes[candidates], scores[candidates]
    order = scores.argsort()[::-1][:FACE_PRE_NMS_TOP_K]
    dets = np.hstack((boxes[order], scores[order][:, np.newaxis])).astype(np.float32, copy=False)
    dets = dets[nms(dets, FACE_NMS_IOU, force_cpu=False), :]
    if len(dets) == 0:
        return None
    best = dets[np.argmax(dets[:, 4])]
    return best if best[4] > FACE_SCORE_MIN else None


def _expand_and_clip_box(det, img_width, img_height):
    """Grow a detection symmetrically by ``det_box_scale`` and clip it to the image.

    Returns inclusive integer pixel bounds ``(xmin, ymin, xmax, ymax)``.
    """
    xmin, ymin, xmax, ymax = (int(v) for v in det[:4])
    pad_x = int((xmax - xmin) * (det_box_scale - 1) / 2)
    pad_y = int((ymax - ymin) * (det_box_scale - 1) / 2)
    return (max(xmin - pad_x, 0), max(ymin - pad_y, 0),
            min(xmax + pad_x, img_width - 1), min(ymax + pad_y, img_height - 1))


def _predict_eye_points(image_bgr, box):
    """Run PIPNet on the face crop ``box`` and return ``[x_36..x_47, y_36..y_47]`` in image pixels."""
    xmin, ymin, xmax, ymax = box
    crop_width = xmax - xmin + 1
    crop_height = ymax - ymin + 1
    # Bounds are inclusive: slice to xmax + 1 / ymax + 1 so the crop really is
    # crop_width x crop_height, the same factors used to map landmarks back below.
    face_crop = image_bgr[ymin:ymax + 1, xmin:xmax + 1, :]
    face_crop = cv2.resize(face_crop, (input_size, input_size))
    inputs = Image.fromarray(face_crop[:, :, ::-1].astype('uint8'), 'RGB')
    inputs = preprocess(inputs).unsqueeze(0).to(device)
    lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, _, _ = forward_pip(
        landmark_net, inputs, preprocess, input_size, net_stride, num_nb)

    # Merge each landmark's direct prediction with its neighbour-based predictions
    # (gathered through the get_meanface reverse-index tables) by averaging.
    nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(num_lms, max_len)
    nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(num_lms, max_len)
    merged_x = torch.mean(torch.cat((lms_pred_x, nb_x), dim=1), dim=1).cpu().numpy()
    merged_y = torch.mean(torch.cat((lms_pred_y, nb_y), dim=1), dim=1).cpu().numpy()

    # Landmarks are relative to the crop; keep sub-pixel precision for the comparison
    # with the float ground truth instead of truncating to int.
    eye_x = merged_x[EYE_LMK_START:EYE_LMK_END] * crop_width + xmin
    eye_y = merged_y[EYE_LMK_START:EYE_LMK_END] * crop_height + ymin
    return np.concatenate((eye_x, eye_y))


def _load_gt_eye_points(label_file):
    """Return the labelled ``[x_36..x_47, y_36..y_47]`` from one label JSON file."""
    with open(label_file) as f:
        label_data = json.load(f)
    # Points is read as a flat [x0, y0, x1, y1, ...] list.
    key_points = [float(v) for v in label_data['ObjectInfo']['KeyPoints']['Points']]
    return np.concatenate((key_points[2 * EYE_LMK_START:2 * EYE_LMK_END:2],
                           key_points[2 * EYE_LMK_START + 1:2 * EYE_LMK_END + 1:2]))


pred_eye_pos = []
actual_eye_pos = []
missed_label_paths = []  # label files whose image yielded no usable face detection
priors_by_shape = {}     # (height, width) -> (priors, box_scale), built once per image size

eval_pairs = list(zip(drowsy_state_label_path_list, drowsy_state_img_path_list))

# Inference only: disable autograd so no graph is recorded for either network.
with torch.no_grad():
    for sample_label_path, sample_img_path in tqdm(eval_pairs):
        img = cv2.imread(sample_img_path)
        if img is None:
            raise FileNotFoundError('cv2.imread could not read {}'.format(sample_img_path))
        img_height, img_width = img.shape[:2]

        # The image is fed to FaceBoxes at full resolution, so priors and the box scale
        # follow its real size (image_scale is not applied here); cache them per size.
        shape_key = (img_height, img_width)
        if shape_key not in priors_by_shape:
            priors = PriorBox(cfg, image_size=shape_key).forward().to(device)
            box_scale = torch.tensor([img_width, img_height, img_width, img_height],
                                     dtype=torch.float32, device=device)
            priors_by_shape[shape_key] = (priors, box_scale)
        priors, box_scale = priors_by_shape[shape_key]

        best_det = _detect_best_face(img, priors, box_scale)
        box = None if best_det is None else _expand_and_clip_box(best_det, img_width, img_height)
        if box is None or box[2] < box[0] or box[3] < box[1]:
            missed_label_paths.append(sample_label_path)
            continue

        # Exactly one prediction/label pair per image: each label file holds a single set of
        # key points, so pairing it with every detection above threshold would duplicate it.
        pred_eye_pos.append(_predict_eye_points(img, box))
        actual_eye_pos.append(_load_gt_eye_points(sample_label_path))

pred_eye_pos = np.array(pred_eye_pos)
actual_eye_pos = np.array(actual_eye_pos)

256it [00:25, 10.13it/s]
