# I. ROI Extraction for Palmprint Images
Tested on the Tongji dataset (full palmprint images).

In [None]:
# Environment
conda env update --file environment.yml --prune

In [None]:
pip install torch

In [None]:
pip install torchvision

In [1]:
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torchvision import transforms, models
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets
import itertools
import matplotlib.pyplot as plt
import time
from PIL import Image, ImageDraw, ImageFont
import copy
torch.set_default_dtype(torch.float64)
import numpy as np
import cv2

from networks.ROILAnet import ROILAnet
from networks.TPSGridGen import TPSGridGen

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# State-dict file of the trained ROI localisation network, passed to loadROIModel further down.
ROIModelPath = 'weights/ROI_extractor_augmented_TJ-NTU.pt' # path to the trained ROI model
# Optional ResNet-18 classifier weights (for loadCNNModel); not used by the visible cells.
#CNNModelPath = 'weights/resnet18_tongji_unfreezed_extractor.pt'

In [4]:
# Tongji: 300 people with two palms each -> 600 palm classes, named user_0 ... user_599.
class_names = [f'user_{i}' for i in range(600)]

## Definition of Loaders

In [5]:
def loadROIModel(weightPath: str = None):
    """
    Build the ROI localisation network (ROILAnet) and load pretrained weights into it.

    @weightPath: path to a state-dict file saved from ROILAnet()
    Returns the network moved to `device`, switched to eval mode, with every
    parameter frozen (requires_grad=False).
    """
    model = ROILAnet()
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(weightPath, map_location=torch.device(device)))
    model = model.to(device)
    model.eval()
    # Freeze all parameters. Setting a `requires_grads` attribute on the module would have no
    # effect; requires_grad_ actually flips the flag on each parameter.
    model.requires_grad_(False)
    return model

In [6]:
def getThinPlateSpline(target_width: int = 112, target_height: int = 112) -> "TPSGridGen":
    """
    Create a thin-plate-spline (TPS) grid generator for an output of the given size.

    @target_width: desired I_ROI output width, in pixels
    @target_height: desired I_ROI output height, in pixels
    Returns the TPSGridGen instance (a module, not a tensor) moved to `device`. Calling it
    with predicted control points (e.g. getThetaHat's output) yields the sampling
    coordinates that sampleGrid reshapes and feeds to F.grid_sample.
    """
    # Target control points: the 3x3 lattice {-1, 0, 1} x {-1, 0, 1} in normalized
    # coordinates, i.e. 9 points matching the 9 (x, y) pairs predicted by the network.
    # The upper bound 1.00001 (slightly above 1.0) makes arange include 1.0 itself.
    target_control_points = torch.Tensor(list(itertools.product(
        torch.arange(-1.0, 1.00001, 1.0),
        torch.arange(-1.0, 1.00001, 1.0),
    )))
    gridgen = TPSGridGen(target_height=target_height, target_width=target_width, target_control_points=target_control_points)
    gridgen = gridgen.to(device)
    return gridgen

In [7]:
def getOriginalAndResizedInput(path: str = None) -> "tuple[Image.Image, torch.Tensor, torch.Tensor]":
    """
    Load an image from disk and build the inputs of the ROI pipeline.

    @path: path of the image file to load
    Returns the triple (PILMain, sourceImage, resizedImage):
      - PILMain: the image converted to RGB, as a PIL image;
      - sourceImage: the full-size image as a (1, 3, H, W) float64 tensor on `device`.
        The 0-255 pixel values are kept as-is, because ToTensor only rescales uint8 input;
      - resizedImage: the image resized to 56x56 and normalized with the ImageNet mean/std,
        as a (1, 3, 56, 56) tensor on `device` (input of the localisation network).
    Returns (None, None, None) when no path is given, so callers can always unpack three values.
    """
    if path is None:
        return (None, None, None)

    # Resize + normalize for the 56x56 input of the localisation network.
    resizeTransformer = transforms.Compose([
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Open inside a context manager so the file handle is released; convert() loads the pixels.
    with Image.open(path) as image_file:
        PILMain = image_file.convert(mode='RGB')

    sourceImage = np.array(PILMain).astype('float64')  # (H, W, 3) float64 array
    sourceImage = transforms.ToTensor()(sourceImage).unsqueeze_(0)  # -> (1, 3, H, W)
    sourceImage = sourceImage.to(device)

    resizedImage = resizeTransformer(PILMain).unsqueeze(0)  # add the batch dimension
    resizedImage = resizedImage.to(device)
    return (PILMain, sourceImage, resizedImage)

In [8]:
def getThetaHat(resizedImage: torch.Tensor = None, model = None) -> torch.Tensor:
    """
    Predict the TPS control points for a batch of resized palm images.

    @resizedImage: batch of images prepared for the localisation network, e.g. the
                   resizedImage returned by getOriginalAndResizedInput
    @model: ROI localisation network (ROILAnet)
    Returns a tensor of shape (batch, 9, 2) holding nine (x, y) pairs per image. The
    network's flat output is read as all x values followed by all y values
    ([x0..x8, y0..y8]). Returns None when no image is given.
    """
    if resizedImage is not None:
        with torch.no_grad():  # inference only: no gradient bookkeeping
            raw_output = model.forward(resizedImage)
        # (..., 18 values) -> (batch, 2, 9) -> (batch, 9, 2); contiguous() returns a fresh,
        # densely laid-out tensor.
        return raw_output.view(-1, 2, 9).transpose(1, 2).contiguous()
    return None

In [9]:
def sampleGrid(theta_hat: torch.Tensor = None, sourceImage: torch.Tensor = None, target_width: int = 112, target_height: int = 112) -> torch.Tensor:
    """
    Warp the ROI described by theta_hat out of the source image.

    @theta_hat: predicted (x, y) control points in normalized coordinates (see getThetaHat)
    @sourceImage: the original image batch, neither cropped nor resized
    @target_width: width of the extracted I_ROI, in pixels
    @target_height: height of the extracted I_ROI, in pixels
    Returns the sampled ROI, one target_height x target_width image per sampling grid.
    """
    tps = getThinPlateSpline(target_width, target_height)
    # Map every output pixel through the TPS to a (x, y) location in the source image,
    # then lay the coordinates out as the (N, H, W, 2) grid grid_sample expects.
    sampling_grid = tps(theta_hat).view(-1, target_height, target_width, 2).to(device)
    # Interpolate the source image at those locations (grid_sample's default mode).
    return F.grid_sample(sourceImage, sampling_grid, align_corners=False)

In [10]:
def printExtraction(target_image: torch.Tensor = None, source_image = None):
    """
    Show the original image, then the extracted ROI, each in its own plot.

    @target_image: extracted ROI as a tensor; after squeeze() it must be (C, H, W),
                   so presumably a single image (batch of 1) — confirm with callers
    @source_image: original image in any format plt.imshow accepts (e.g. PIL)
    """
    # Bring the ROI to host memory and reorder (C, H, W) -> (H, W, C) for PIL/matplotlib.
    roi_array = np.transpose(target_image.cpu().data.numpy().squeeze(), (1, 2, 0))
    roi_picture = Image.fromarray(roi_array.astype('uint8'))
    # Original first, ROI second.
    for picture in (source_image, roi_picture):
        plt.imshow(picture)
        plt.show()

In [11]:
def loadCNNModel(weightPath: str = None):
    """
    Build a ResNet-18 classifier for the palm classes and load trained weights into it.

    @weightPath: path to a state-dict file of this ResNet-18, whose final layer has
                 len(class_names) outputs
    Returns the model moved to `device`; its train/eval mode is left unchanged.
    """
    classifier = models.resnet18(pretrained=False)
    # Swap the default classification head for one output per palm class so the
    # saved state dict matches the architecture.
    classifier.fc = nn.Linear(classifier.fc.in_features, len(class_names))
    saved_state = torch.load(weightPath, map_location=torch.device(device))
    classifier.load_state_dict(saved_state)
    classifier.to(device)
    return classifier

In [12]:
def getIROI(model, input):
    """
    Extract a 224x224 ROI from each image of a full-resolution batch.

    @model: ROI localisation network (ROILAnet)
    @input: batch of source images, (N, C, H, W)
    Returns the sampled ROIs, (N, C, 224, 224), on `device`.
    """
    # NOTE(review): unlike getOriginalAndResizedInput, this resize uses F.interpolate's default
    # (nearest) mode and applies no ImageNet normalization. Confirm the network tolerates it.
    resizedImage = F.interpolate(input, (56, 56))
    theta_hat = getThetaHat(resizedImage=resizedImage, model=model)  # normalized control points
    IROI = sampleGrid(theta_hat=theta_hat, sourceImage=input, target_width=224, target_height=224)
    # Tensor.to is not in place: reassign so the move to `device` actually takes effect.
    IROI = IROI.to(device)
    return IROI

In [13]:
def markImage(image, theta):
    """
    Draw the predicted control points, numbered, on a copy of an RGB image.

    @image: RGB image (PIL image or uint8 array of shape (H, W, 3))
    @theta: sequence of (x, y) pairs in normalized [-1, 1] image coordinates,
            e.g. theta_hat[0] as returned by getThetaHat
    Returns an (H, W, 3) RGB numpy array with a circle and the point index drawn at each point.
    """
    # OpenCV draws on BGR arrays: convert here and flip back to RGB before returning.
    ocvim = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    height, width = ocvim.shape[:2]
    for idx, coord in enumerate(theta):
        # Map normalized coordinates to pixels with the align_corners=False convention that
        # sampleGrid passes to F.grid_sample (-1 -> -0.5, +1 -> size - 0.5). The markers then
        # sit where the ROI is actually sampled.
        x = int(((float(coord[0]) + 1) * width - 1) / 2)
        y = int(((float(coord[1]) + 1) * height - 1) / 2)
        ocvim = cv2.circle(ocvim, (x, y), 6, (200, 0, 0), 2)
        ocvim = cv2.putText(
            ocvim,                      # image to draw on
            str(idx),                   # label: index of the control point
            (x, y),                     # text origin
            cv2.FONT_HERSHEY_SIMPLEX,   # font family
            1,                          # font scale
            (209, 80, 0, 255),          # color; the 4th component has no effect on a 3-channel image
            3)                          # line thickness
    return ocvim[..., ::-1]

In [14]:
def getOriginalAndResizedInput(PILMain) -> "tuple[Image.Image, torch.Tensor, torch.Tensor]":
    """
    Build the inputs of the ROI pipeline from an already-loaded image.

    This definition replaces the path-based getOriginalAndResizedInput defined earlier.

    @PILMain: the image as a PIL image; it is used as-is (no RGB conversion here), so pass
              an RGB image to obtain 3-channel tensors
    Returns the triple (PILMain, sourceImage, resizedImage):
      - PILMain: the input image, unchanged;
      - sourceImage: the full-size image as a (1, C, H, W) float64 tensor on `device`.
        The 0-255 pixel values are kept as-is, because ToTensor only rescales uint8 input;
      - resizedImage: the image resized to 56x56 and normalized with the ImageNet mean/std,
        as a (1, C, 56, 56) tensor on `device` (input of the localisation network).
    Returns (None, None, None) when no image is given, so callers can always unpack three values.
    """
    if PILMain is None:
        return (None, None, None)

    # Resize + normalize for the 56x56 input of the localisation network.
    resizeTransformer = transforms.Compose([
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    sourceImage = np.array(PILMain).astype('float64')  # (H, W, C) float64 array
    sourceImage = transforms.ToTensor()(sourceImage).unsqueeze_(0)  # -> (1, C, H, W)
    sourceImage = sourceImage.to(device)

    resizedImage = resizeTransformer(PILMain).unsqueeze(0)  # add the batch dimension
    resizedImage = resizedImage.to(device)
    return (PILMain, sourceImage, resizedImage)

In [None]:
# Load the ROI localisation network once; predictAndShow reads it through this global name.
localisationNetwork = loadROIModel(weightPath=ROIModelPath)

In [16]:
from PIL import Image  
import PIL
from torchvision.utils import save_image
import torch
import torchvision
import matplotlib.pyplot as plt
from scipy.misc import face

# CROP ROI

In [17]:
def predictAndShow(path: str, pname, p_classname):
    """
    Extract the palm ROI from one image and save it as a picture.

    @path: path to a palm image in the (Tongji) dataset
    @pname: output file name, e.g. '00001.JPEG'; matplotlib picks the format from the extension
    @p_classname: class (user) folder inside ROI_session1/; the folder must already exist

    The ROI is predicted with the global `localisationNetwork` and sampled at 300x300. Its
    first channel is drawn as a grayscale image with the axes hidden, then saved to
    ROI_session1/<p_classname>/<pname>.
    """
    # Open inside a context manager so the file handle is released after loading.
    with Image.open(path) as image_file:
        inputPIL = image_file.convert('RGB')
    (_, sourceImage, resizedImage) = getOriginalAndResizedInput(inputPIL)
    # Re-batch as (1, C, H, W) after squeezing away singleton dimensions.
    sourceImage = torch.stack([sourceImage.squeeze()])
    resizedImage = torch.stack([resizedImage.squeeze()])

    theta_hat = getThetaHat(resizedImage, localisationNetwork)  # normalized control points
    IROI = sampleGrid(theta_hat=theta_hat, sourceImage=sourceImage, target_width=300, target_height=300)
    IROI = IROI[0]  # first (only) image of the batch: (C, 300, 300)

    # The figure size and subplot slot determine the resolution of the saved ROI picture.
    fig = plt.figure(figsize=(20, 20))
    try:
        plt.subplot(4, 1, 3)
        plt.imshow(IROI.detach().cpu()[0], cmap='gray')  # first channel only
        plt.axis('off')
        plt.savefig(os.path.join('ROI_session1', p_classname, pname), bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this, each call leaks a
        # 20x20-inch figure and memory grows over a whole class loop.
        plt.close(fig)

In [18]:
import os
# Folders holding the raw images of the two acquisition sessions.
path1, path2 = "session1/", "session2/"

In [19]:
# List both session folders in file-name order. The split cells below assume that the
# images of one person are consecutive in this order.
files1 = sorted(os.listdir(os.path.expanduser(path1)))
print(files1[0])
files2 = sorted(os.listdir(os.path.expanduser(path2)))
print(files2[0])

00001.tiff
00001.tiff


In [20]:
import shutil

In [None]:
# Split data for hand 1 of session 1.
# The sorted session folder holds 20 consecutive images per user: the first 10 (hand 1) are
# copied to hand1_session1/<user>/ and the next 10 (hand 2) are skipped. Users are numbered from 1.
IMAGES_PER_HAND = 10
IMAGES_PER_USER = 2 * IMAGES_PER_HAND

for index, img in enumerate(files1):
    user_index, position = divmod(index, IMAGES_PER_USER)
    if position >= IMAGES_PER_HAND:
        continue  # image of the second hand: not used
    user_dir = os.path.join('hand1_session1', str(user_index + 1), '')
    if position == 0:
        # exist_ok lets the cell be re-run; it also creates hand1_session1/ if it is missing.
        os.makedirs(user_dir, exist_ok=True)
    shutil.copy2(os.path.join('session1', img), user_dir)

In [None]:
# Add the suffix " (1)" to the images in session12/<class>/ to mark them as session-1 images.
# Both sessions name their files 00001.tiff, ..., so without the suffix the session-2 copies
# made by the next cell would overwrite them.
from glob import glob
path3 = "session12/"
SESSION1_SUFFIX = " (1)"

for class_ in range(1, 301):
    full_path_directory = os.path.join(path3, str(class_))
    for img in os.listdir(full_path_directory):
        # splitext keeps the real extension even when the name contains several dots.
        stem, ext = os.path.splitext(img)
        if stem.endswith(SESSION1_SUFFIX):
            continue  # already renamed by an earlier run of this cell
        old_file = os.path.join(full_path_directory, img)
        new_file = os.path.join(full_path_directory, stem + SESSION1_SUFFIX + ext)
        os.rename(old_file, new_file)
print('Finish!')

In [None]:
# Copy the hand-1 images of session 2 into session12/<user>/, next to that user's session-1 images.
# As in session 1, the sorted folder holds 20 consecutive images per user: the first 10 (hand 1)
# are copied and the next 10 (hand 2) are skipped. The destination folders must already exist;
# the trailing separator keeps copy2 from silently creating a file named after the user.
IMAGES_PER_HAND = 10
IMAGES_PER_USER = 2 * IMAGES_PER_HAND

for index, img in enumerate(files2):
    user_index, position = divmod(index, IMAGES_PER_USER)
    if position >= IMAGES_PER_HAND:
        continue  # image of the second hand: not used
    user_dir = os.path.join('session12', str(user_index + 1), '')
    shutil.copy2(os.path.join('session2', img), user_dir)

# Perform ROI extraction for the hand-1 images of session 1, split across several cells to avoid running out of RAM

In [None]:
from glob import glob
# ROI extraction for classes 1-27 (one batch of classes per cell to limit memory use).
for class_ in range(1, 28):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 28-54 (one batch of classes per cell to limit memory use).
for class_ in range(28, 55):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 55-81 (one batch of classes per cell to limit memory use).
for class_ in range(55, 82):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 82-108 (one batch of classes per cell to limit memory use).
for class_ in range(82, 109):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 109-134 (one batch of classes per cell to limit memory use).
for class_ in range(109, 135):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 135-162 (one batch of classes per cell to limit memory use).
for class_ in range(135, 163):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 163-190 (one batch of classes per cell to limit memory use).
for class_ in range(163, 191):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 191-218 (one batch of classes per cell to limit memory use).
for class_ in range(191, 219):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 219-240 (one batch of classes per cell to limit memory use).
for class_ in range(219, 241):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 241-268 (one batch of classes per cell to limit memory use).
for class_ in range(241, 269):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)

In [None]:
from glob import glob
# ROI extraction for classes 269-300 (one batch of classes per cell to limit memory use).
for class_ in range(269, 301):
    class_name = str(class_)
    # exist_ok lets an interrupted cell be re-run without FileExistsError.
    os.makedirs(os.path.join('ROI_session1', class_name), exist_ok=True)
    for img in sorted(glob(f'hand1_session1/{class_name}/*.tiff')):
        stem = os.path.splitext(os.path.basename(img))[0]
        predictAndShow(img, stem + ".JPEG", class_name)