In [1]:
from torchvision import models
import torch
# torch.random.manual_seed(10)
# torch.random.get_rng_state()

In [2]:
'''
from:
https://github.com/kazuto1011/cifar10-pytorch
with an added BatchNorm layer at the start
'''

import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    ''' Basicblock: conv-batchnorm-relu with a residual connection. '''

    def __init__(self, n_in, n_out, stride=1):
        super(BasicBlock, self).__init__()
        self.connection = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(n_in, n_out, 3, stride, 1, bias=False)),
            ('norm1', nn.BatchNorm2d(n_out)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv2', nn.Conv2d(n_out, n_out, 3, 1, 1, bias=False)),
            ('norm2', nn.BatchNorm2d(n_out)),
        ]))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = nn.Sequential(
            nn.Conv2d(n_in, n_out, 1, stride, bias=False),
            nn.BatchNorm2d(n_out),
        )
        self.stride = stride

    def forward(self, x):
        mapping = self.connection(x)
        if self.stride != 1:
            x = self.downsample(x)
        return self.relu(mapping + x)


class ResidualBlock(nn.Module):
    ''' n_block times the basicblock. '''

    def __init__(self, n_in, n_out, n_block, stride=1):
        super(ResidualBlock, self).__init__()
        self.blocks = nn.Sequential()
        self.blocks.add_module('block0', BasicBlock(n_in, n_out, stride))
        for i in range(n_block - 1):
            block = BasicBlock(n_out, n_out)
            self.blocks.add_module('block{}'.format(i + 1), block)

    def forward(self, x):
        return self.blocks(x)


class ResNetCifar10(nn.Module):
    ''' Residual network built for the CIFAR10 database. '''

    def __init__(self, n_classes=2, n_block=3):
        super(ResNetCifar10, self).__init__()
        ch = [16, 32, 64]
        self.bn0 = nn.BatchNorm2d(3, affine=True)
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, ch[0], 3, 1, 1, bias=False)),
            ('norm1', nn.BatchNorm2d(ch[0])),
            ('relu1', nn.ReLU(inplace=True)),
            ('resb1', ResidualBlock(ch[0], ch[0], n_block)),
            ('resb2', ResidualBlock(ch[0], ch[1], n_block, 2)),
            ('resb3', ResidualBlock(ch[1], ch[2], n_block, 2)),
            ('avgpl', nn.AvgPool2d(8)),
        ]))
        self.fc = nn.Linear(ch[2], n_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.bn0(x)
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [3]:
''' Patch-based deep learning spoof detection. '''

import os
import dlib
import torch
import numpy as np
from scipy.misc import imresize
import torch.nn.functional as F
from torch.autograd import Variable
# from utils import networks

# from document_check_framework.monitoring import logger
# log = logger.get_logger(__name__)

class EmptyVideoException(Exception):
    ''' Exception to throw when the video is empty. '''
    pass

class FaceNotFoundException(Exception):
    ''' Exception to throw when dlib face detection fails. '''
    pass

class NonRGBImageException(Exception):
    ''' Exception to throw when the image is not an RGB image. '''
    pass


class spoof_detector(object):
    ''' The object used to detect spoofs. '''

    def __init__(self, network_location):
        # size and patch_size are required to be 128 and 32 respectively
        # size is the edge length to resize the image to before taking patches
        # the n_blocks however can be changed freely - lower is faster
        # it indicates the amount of patches to take horizontal and vertical
        # recommended range for resolution is somewhere between 8 and 32
        self.size = 128
        self.patch_size = 32
        self.n_blocks = 32
        self.threshold = 0.90
        self.threshold = float(self.threshold)
        self.detector = dlib.get_frontal_face_detector()
        self.network = ResNetCifar10(4)
#         filename = os.path.join(network_location, 'network_dict')
        filename = network_location
        state_dict = torch.load(filename, map_location=lambda s, loc: s)
        self.network.load_state_dict(state_dict)
        self.network.eval()
        self.network.train(False)

    def detect_face(self, img):
        ''' Detect faces in the image and return crop of the largest face. '''
        # detect face in all four rotations
        # img is an RGB image with shape: (h, w, 3)
        for r in range(4):
            # rotate 90 * r degrees
            # if we rotate we also need to copy otherwise
            # dlib and pytorch receive a non-continuous ndarray
            rotated_img = np.rot90(img, r).copy()
            # detect faces
            dets, conf, _ = self.detector.run(rotated_img)
            # filter detections with conf <= 0
            dets = [d for d, c in zip(dets, conf) if c > 0]
            print(dets, conf)
            # if at least one face was detected
            if len(dets) > 0:
                # calculate face sizes per detection
                size = [d.bottom() - d.top() + d.right() - d.left() for d in dets]
                # select detection corresponding with largest face
                det = dets[np.argmax(size)]
                # prevent top and left from becoming negative values
                top, left = max(det.top(), 0), max(det.left(), 0)
                bottom, right = det.bottom(), det.right()
                # return cropped image
                return rotated_img[top:bottom, left:right]
        return None

    def create_heatmap(self, face_crop):
        '''
        Predict uniformly sampled patches from the face_crop.
        First create a batch containing all patches.
        Then run the batch with patches through the network.
        Return the output predictions reshaped as a heatmap.
        '''
        # face_crop is an RGB image with shape: (self.size, self.size, 3)
        face_crop = face_crop.transpose(2, 0, 1)
        _, h, w = face_crop.shape
        # create a batch with uniformly sampled patches
        hr = float(h - self.patch_size) / (self.n_blocks - 1)
        wr = float(w - self.patch_size) / (self.n_blocks - 1)
        # create empty batch
        n_patches = self.n_blocks**2
        batch = np.zeros((n_patches, 3, self.patch_size, self.patch_size))
        for y in range(self.n_blocks):
            for x in range(self.n_blocks):
                y1 = int(y * hr)
                y2 = int(y * hr + self.patch_size)
                x1 = int(x * wr)
                x2 = int(x * wr + self.patch_size)
                patch_float = face_crop[:, y1:y2, x1:x2] / 255.0
                batch[y * self.n_blocks + x, :, :, :] = patch_float
        # network makes a prediction per patch
        inputs = Variable(torch.from_numpy(batch).float(), volatile=True)
        outputs = F.softmax(self.network(inputs)).data.cpu().numpy()
        # reshape outputs into a heatmap where channels are class predictions:
        # (genuine_sample, picture_attack, screen_attack and document_attack)
        return outputs.reshape((self.n_blocks, self.n_blocks, -1))

    def spoof(self, image):
        '''
        Turns input image into a spoof attempt prediction.
        Note: a score of 1.0 indicates a genuine samples,
        while a score of 0.0 indicates a spoofing attempt!
        '''
        # filter non RGB images
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise NonRGBImageException
        # detect face in image
        face_crop = self.detect_face(image)
        if face_crop is None:
            raise FaceNotFoundException
        # resize image
        face_crop = imresize(face_crop, (self.size, self.size))
        # spoof prediction
        heatmap = self.create_heatmap(face_crop)
        # transform heatmap into a genuine sample prediction
        spoof_score = np.mean(heatmap[:, :, 0])
        return spoof_score

In [4]:
''' API functions to call from the main service. '''

import io
import os
# import av
import PIL
import numpy as np
# from document_check_framework.imago import Imago
# from document_check_framework.monitoring import logger
# from ._methods.spoof_detection import spoof_detector, \
#     FaceNotFoundException, NonRGBImageException, EmptyVideoException
# log = logger.get_logger(__name__)

# EXIF_ORIENTATION_TAG = 274

def score_photo(uuid, detector=None):
    ''' Download image based on uuid and returns predicted spoof score. '''
    if detector is None:
        detector = spoof_detector(model_dict)
    # download image using Imago and transform to numpy with PIL
    img = Imago.Instance().get_object(uuid, 'live_photos')
    img = PIL.Image.open(io.BytesIO(img.content))
    # rotate image based on exif data
    image = np.array(img, dtype=np.uint8)
    if hasattr(img, '_getexif') and img._getexif() is not None:
        if EXIF_ORIENTATION_TAG in img._getexif():
            orientation = img._getexif()[EXIF_ORIENTATION_TAG]
            if orientation in [3, 4]:
                image = np.rot90(image, 2)
            if orientation in [5, 6]:
                image = np.rot90(image, 3)
            if orientation in [7, 8]:
                image = np.rot90(image, 1)
    # detect spoof and return score
    return detector.spoof(image)

In [5]:
import os, cv2
os.getcwd()
image = '0a9a1f3b-e79e-4043-81bd-0518b683084e.jpg'
image = '08a7ac11-a455-465a-9a6e-f2fded3a339b.jpg'
im = cv2.imread(image, cv2.IMREAD_UNCHANGED)
print(im.shape)

(1280, 768, 3)


In [6]:
import torch
filename = '/media/dataserver/workspace/blanca/MSDNet-GCN/test/utils/network_dict'
# state_dict = torch.load(filename, map_location=lambda s, loc: s)

# network = ResNetCifar10(4)
# network.load_state_dict(state_dict)
# for i, j in state_dict.items(): print(i)
detector = spoof_detector(filename)
detector.spoof(im)

# [rectangle(280,321,651,692)] 0.567812
# [rectangle(238,404,610,775)] 1.20046

[rectangle(238,404,610,775)] 1.20046


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


0.94968033

In [7]:
def read_image_OpenCV(path):
        im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        im = cv2.resize(im, (256, 256))
        return im

input_path = '/workspace/blanca/training/train/good_fit/300w01_indoor_203.png'
im = read_image_OpenCV(input_path)