# Keypoint-based experiments

In [5]:
import os
import sys
from os import path
sys.path.append(path.join(path.dirname("eval_keypoint.py"), '..'))
sys.path.append(path.join(path.dirname("eval_keypoint.py"), '../..'))


import cv2
import h5py
import glob
import time
import math
import torch
import random
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from torch.nn import functional as F
from skimage.util.shape import view_as_windows
from facerec.retinaface_detection import RetinaDetector


['value_checker_ica.py', 'deeppix_customized.py', 'visualize_five.py', 'models', 'filter-analysis-ica.ipynb', 'main-ica.py', 'fine_tune.py', 'patch_based_heatmap_maker.py', 'keypoint_experiment.ipynb', 'analytics', 'infer.ipynb', 'config.py', 'patch_model.py', 'all_patch_train.py', 'config.json', 'patch_train.py', 'remove_damaged_images.py', '__pycache__', 'spoof-heatmap.png', 'eval', 'all_keypoin.py', 'facerec', 'deeppix_train.py', 'Untitled.ipynb', 'combined-heatmap.png', 'ica_filters.py', 'infer_deeppix.py', '__init__.py', 'live-heatmap.png', 'metrics', '.ipynb_checkpoints', 'middle_patch', 'patch_dicts', 'testpoint.py', 'explore_the_implementation.ipynb', 'ica_train.py', 'spoofv4.pt', 'eval_keypoint.py', 'gaze_data_maker.py', 'focal_loss.py']


In [6]:
# GPU indices handed to torch.nn.DataParallel; the first one is also the
# device that models and inputs are moved to.
device_ids = list(range(5))

In [7]:
class New_PatchDataset(torch.utils.data.Dataset):
    """Dataset yielding one face-centred square crop per image, plus a live/spoof label.

    Each item is ``(patch, ground)``:

    * ``patch``: the image is resized to 224x224. An ``im_size`` x ``im_size``
      crop is taken from its RGB version, centred between the eye midpoint and
      the nose, as located by ``RetinaDetector``. ``transform`` is then applied
      to it, or the raw crop is returned when ``transform`` is ``None``.
    * ``ground``: ``torch.ones(1)`` for spoof samples and ``torch.zeros(1)`` for
      live samples. It is inferred from the directory names in the image path.

    Args:
        path: Sequence of image file paths.
        transform: Callable applied to the cropped array, or ``None``.
        color_mode: Colour-space names converted via ``cv2.COLOR_BGR2<MODE>``.
            Only the ``'rgb'`` result is used for the crop, so it must be included.
            Defaults to ``['rgb']``.
        im_size: Side length of the square crop in pixels.
        patch_size, phase, training_type: Stored on the instance; not used by
            this class.
    """

    # Path components (exact match on "/"-separated parts) that mark a sample as
    # spoof (label 1) or live (label 0). Spoof markers are checked first.
    _SPOOF_PARTS = ("spoof", "device", "print", "print2", "video-replay2", "print1", "video-replay1")
    _LIVE_PARTS = ("live", "real")

    def __init__(self, path, transform=None, color_mode=None, im_size=96, patch_size=None, phase='train', training_type='None'):
        self.fnames = path
        # None sentinel instead of a shared mutable list default.
        self.color_mode = ['rgb'] if color_mode is None else color_mode
        self.im_size = im_size
        self.transform = transform
        self.patch_size = patch_size
        self.phase = phase
        self.training_type = training_type
        self.detector = RetinaDetector()

        self.count = 0  # not used within this class

    def __getitem__(self, idx):
        """Return ``(patch, ground)`` for the sample at ``idx``."""
        return self.create_single_sample(idx)

    def __len__(self):
        return len(self.fnames)

    def create_single_sample(self, idx):
        """Load image ``idx``, crop the face-centred patch and derive its label.

        Raises:
            OSError: if OpenCV cannot read the image file.
            ValueError: if the path contains none of the known live/spoof markers.
        """
        im_path = self.fnames[idx]

        img = cv2.imread(im_path)
        if img is None:  # cv2.imread signals failure by returning None, not raising
            raise OSError(f"cv2.imread could not read image: {im_path}")
        img = cv2.resize(img, (224, 224))

        # getattr instead of eval: look up e.g. cv2.COLOR_BGR2RGB by name.
        imgs = {
            mode: cv2.cvtColor(img, getattr(cv2, f'COLOR_BGR2{mode.upper()}'))
            for mode in self.color_mode
        }

        patch = self.getting_patches(imgs, im_path)
        ground = self._ground_truth_from_path(im_path)
        return patch, ground

    @classmethod
    def _ground_truth_from_path(cls, im_path):
        """Map an image path to ``torch.ones(1)`` (spoof) or ``torch.zeros(1)`` (live)."""
        parts = im_path.split("/")
        if any(marker in parts for marker in cls._SPOOF_PARTS):
            return torch.ones(1)
        if any(marker in parts for marker in cls._LIVE_PARTS):
            return torch.zeros(1)
        raise ValueError(f"Cannot infer live/spoof label from path: {im_path}")

    def getting_patches(self, img_patch, im_path_name):
        """Crop the face-centred window from ``img_patch['rgb']`` and apply ``transform``."""
        y1, y2, x1, x2 = self.getting_middle_path(im_path_name)
        crop = img_patch['rgb'][y1:y2, x1:x2, :]
        if self.transform is None:
            return crop
        return self.transform(crop)

    def divide_single_img_into_patches(self, img, size=(224, 224), patch_size=(48, 48, 3), step=1):
        """Resize ``img`` to ``size`` and return its sliding-window patch grid.

        The grid comes from ``skimage.util.shape.view_as_windows`` with window
        shape ``patch_size`` and stride ``step``.
        """
        img = cv2.resize(img, size)
        return view_as_windows(img, patch_size, step)

    def get_another_landmark(self):
        """Return fixed fallback landmarks, shaped (1, 5, 2), used when detection finds none.

        Points follow the same order as the detector's output (eye, eye, nose,
        lip-left, lip-right). The coordinates are presumably typical positions
        in a 224x224 face image; verify against how they were obtained.
        """
        return np.array([[[ 73.25691 , 106.048355],
                          [149.7507  , 105.70673 ],
                          [111.53898 , 141.779   ],
                          [ 79.02094 , 186.59473 ],
                          [141.23201 , 186.73901 ]]], dtype=np.float32)

    def getting_middle_path(self, im_path):
        """Return crop bounds ``(y1, y2, x1, x2)`` of an ``im_size`` square around the face centre.

        Landmarks of the first detected face are indexed as ``landmarks[0][k]`` = (x, y):
            landmark 0: eye-1
            landmark 1: eye-2
            landmark 2: nose
            landmark 3: lip-left
            landmark 4: lip-right

        The centre is the midpoint between the eye midpoint and the nose, shifted
        up by 15 px. The window is clamped so it stays inside the detector's
        image. If no landmarks are found, ``get_another_landmark()`` is used and
        the path is appended to ``not_found.txt``.
        """
        im, _faces, landmarks = self.detector.infer(im_path, resize=[224, 224])

        # `not` also treats a None result as "no landmarks"; `== False` did not.
        if not np.array(landmarks).any():
            landmarks = self.get_another_landmark()
            with open("not_found.txt", "a") as log_file:
                log_file.write(f"{im_path}\n")

        half = self.im_size // 2  # for 96, half = 48
        pts = landmarks[0]

        eye_mid_x = (pts[0][0] + pts[1][0]) // 2
        center_x = (eye_mid_x + pts[2][0]) // 2
        eye_mid_y = (pts[0][1] + pts[1][1]) // 2
        center_y = (eye_mid_y + pts[2][1]) // 2

        center_y -= 15  # empirical upward shift of the window

        # Clamp rows against the image height and columns against the image width.
        y1, y2 = self._crop_bounds(center_y, half, im.shape[0])
        x1, x2 = self._crop_bounds(center_x, half, im.shape[1])
        return y1, y2, x1, x2

    def _crop_bounds(self, center, half, limit):
        """Return ``(start, end)`` of an ``im_size``-long window around ``center`` within ``[0, limit]``."""
        if center - half < 0:
            return 0, self.im_size
        if center + half > limit:
            return limit - self.im_size, limit
        start = int(center) - half
        # start + im_size (not center + half) keeps the window im_size long for odd sizes.
        return start, start + self.im_size

In [9]:
def eval_data(model, dataloader, batch_size, is_keras_model=False, keras_model_name=None):
    """Compute the classification accuracy of a model over ``dataloader``.

    Args:
        model: Torch module whose output is reduced with ``torch.max(outputs, 1)``.
            It therefore presumably returns (batch, num_classes) scores; confirm.
            Inputs are moved to ``cuda:0``. Unused when ``is_keras_model`` is True.
        dataloader: Iterable of ``(images, labels)`` batches.
        batch_size: Unused; kept for backward compatibility.
        is_keras_model: If True, evaluate a Keras model instead. It is loaded with
            ``laod_keras_model`` and run with ``infer_keras_model``. Both helpers
            must be defined elsewhere; they are not in this file section.
        keras_model_name: Optional model name passed to ``laod_keras_model``.

    Returns:
        Fraction of samples whose predicted class equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if is_keras_model:
        if keras_model_name is None:
            keras_model = laod_keras_model()
        else:
            keras_model = laod_keras_model(keras_model_name)
    else:
        model = model.eval()

    num_correct = 0
    total_steps = 0

    tq = tqdm(dataloader)

    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for idx, (imgs, label) in enumerate(tq):
            if is_keras_model:
                _preds, running_corrects = infer_keras_model(keras_model, imgs, label)
            else:
                outputs = model(imgs.cuda(0))
                _, preds = torch.max(outputs, 1)
                label = label.reshape(-1)
                running_corrects = np.sum(preds.cpu().numpy() == label.detach().cpu().numpy())

            num_correct += running_corrects
            # len() works for both code paths; the Keras path previously used an undefined `img`.
            total_steps += len(imgs)

            tq.set_postfix(iter=idx, acc=num_correct / total_steps)

    if total_steps == 0:
        raise ValueError("eval_data: dataloader yielded no samples")

    total_accuracy = num_correct / total_steps
    print(f"Total num of acc: {num_correct}| Total iter {total_steps}")
    print(f"Total eval acc: {total_accuracy}")

    return total_accuracy

In [12]:
def oulu(hard_protocol='Protocol_4', root="/home/ec2-user/SageMaker/dataset/spoof-data/oulu"):
    """Return the sorted OULU test-image paths for ``hard_protocol``.

    Matches ``{root}/{hard_protocol}/Test/*/*``. Without ``recursive=True``,
    ``glob`` treats ``**`` like ``*``, so exactly one directory level below
    ``Test/`` is matched (e.g. ``Test/<class>/<file>``).

    Args:
        hard_protocol: Protocol directory name, e.g. ``'Protocol_4'``.
        root: Dataset root directory. Defaults to the original hard-coded location.

    Returns:
        Sorted list of matching paths. Sorting makes evaluation order reproducible,
        since ``glob`` order is arbitrary.
    """
    test_imgs = sorted(glob.glob(f"{root}/{hard_protocol}/Test/**/*"))

    print("OULU data test vs val: ", len(test_imgs))
    return test_imgs

In [10]:
def load_model(im_size, model_name="oulu", modeltype="patch"):
    """Load a checkpointed model wrapped in DataParallel, on ``device_ids[0]``, in eval mode.

    Args:
        im_size: Patch size. Selects ``PatchModel(im_size)`` and the checkpoint
            directory for the ``"patch"`` model; ignored for ``"deeppix"``.
        model_name: Checkpoint stem for ``"patch"``
            (``../ckpts/patch_based_cnn/{im_size}/{model_name}_model.pth``);
            ignored for ``"deeppix"``.
        modeltype: ``"patch"`` or ``"deeppix"``.

    Returns:
        The model as ``torch.nn.DataParallel`` over the module-level ``device_ids``.

    Raises:
        ValueError: if ``modeltype`` is not supported.
    """
    if modeltype not in ("patch", "deeppix"):
        raise ValueError(f"Unknown modeltype {modeltype!r}; expected 'patch' or 'deeppix'")

    if modeltype == "patch":
        from models.patch_based_cnn.model import PatchModel
        model = torch.nn.DataParallel(PatchModel(im_size), device_ids)
        model_path = f"../ckpts/patch_based_cnn/{im_size}/{model_name}_model.pth"
        model.load_state_dict(torch.load(model_path))
    else:
        # The checkpoint holds a whole pickled (DataParallel) model. torch.load
        # unpickles it, so only load trusted files. Re-wrap its inner module
        # with this file's device_ids.
        model_path = '../ckpts/deeppix/six_channel/proto4/proto4-best.pt'
        model = torch.load(model_path)
        model = torch.nn.DataParallel(model.module, device_ids)

    model = model.cuda(device_ids[0])
    model.eval()

    return model

In [13]:
# Build the OULU Protocol_4 evaluation pipeline.
# NOTE(review): `transform`, `image_size` and `batch_size` are not defined in the
# cells shown here; they must come from an earlier cell. The checkpoint is the
# 48-pixel patch model, so presumably image_size == 48 — confirm.
oulu_model = load_model(48, model_name="oulu_Protocol_4")
images = oulu("Protocol_4")
dataset = New_PatchDataset(
    images,
    transform=transform,
    color_mode=['rgb'],
    im_size=image_size,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,  # keep file order for reproducible evaluation
    num_workers=0,
)

OULU data test vs val:  47839
