# Check GPU version.

In [None]:
# Print the NVIDIA GPU(s) attached to this runtime and their driver status.
!nvidia-smi

# Mount Google Drive.

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the image directories used below live under '/content/drive/My Drive/'.
drive.mount('/content/drive')

# Install TensorFlow 1.14 (GPU build) and Keras 2.2.4.

In [None]:
# Select TensorFlow-1.x version.
%tensorflow_version 1.x

# Uninstall previous TensorFlow version.
!pip uninstall tensorflow -y 1>/dev/null 2>/dev/null 
!pip uninstall tensorflow-gpu -y 1>/dev/null 2>/dev/null 

# Install TensorFlow-1.14 and Keras-2.2.4.
!pip install --upgrade tensorflow-gpu==1.14.0 1>/dev/null 2>/dev/null 
!pip install --upgrade tensorflow==1.14.0 1>/dev/null 2>/dev/null 
!pip install --upgrade keras==2.2.4 1>/dev/null 2>/dev/null 

# Restart the runtime (Runtime → Restart runtime) so the newly installed TensorFlow and Keras versions are loaded.

# Set the root directory.

In [None]:
import os

# Work from the Colab scratch directory; the test dataset is downloaded and
# extracted into this directory later in the notebook.
root_dir = '/content/'
os.chdir(root_dir)

!ls -al

# Import TensorFlow-1.14.

In [None]:
try:
  # Colab-specific magic; the try/except lets the cell continue if it is unavailable.
  %tensorflow_version 1.x
except Exception:
  pass

import numpy as np
# Seed NumPy's global RNG so the random base-image choice in the evaluation loop is repeatable.
np.random.seed(7)

import tensorflow as tf
# Confirm which TensorFlow version was actually loaded (expected: 1.14.x).
print(tf.__version__)

### Install keras_vggface module.

In [None]:
!pip install keras_vggface 1>/dev/null 2>/dev/null 

### Download ResNet-50 model trained using VGG Face-2 dataset.

In [None]:
# Network input size (height, width, channels) passed to the ResNet-50 feature extractor.
image_shape = (224, 224, 3)
# Size test images are resized to before being center-cropped down to image_shape.
image_load_shape = (256, 256, 3)

In [None]:
# One aligned reference image per person; the filename, minus its '.jpg' extension, becomes the person's identifier.
aligned_image_dir = '/content/drive/My Drive/aligned_images/'
# Query images for the one-shot test; their filenames are printed as the ground truth.
# NOTE: test_image_dir is reassigned to '/content/test/' in the evaluation section.
test_image_dir = '/content/drive/My Drive/test_images/'

# Create ResNet-50 (trained on VGG Face-2 dataset) based feature extractor.

In [None]:
from keras_vggface.vggface import VGGFace

# ResNet-50 trained on VGGFace2, without the classification head (include_top=False);
# global average pooling ('avg') reduces the final feature map to one embedding vector per image.
model = VGGFace(model='resnet50', include_top=False, input_shape=image_shape, pooling='avg')

### Normalize input image.

In [None]:
from keras_vggface import utils

def normalize_image(image_filename, version=2):
  """Load an image file and turn it into a model-ready batch of one.

  The image is resized to image_shape[:2], converted to a float array, given a
  leading batch axis and mean-subtracted with keras_vggface's preprocessing.

  Args:
    image_filename: path to the image file.
    version: keras_vggface `preprocess_input` mode. keras_vggface documents
      version=1 for VGG16 and version=2 for ResNet-50 / SENet-50; the model in
      this notebook is ResNet-50, hence the default of 2.

  Returns:
    Batched array for `model.predict` — presumably (1, 224, 224, 3) with the
    default channels_last data format.
  """
  # Imported locally so this function does not silently depend on a later cell
  # having imported `image` into the notebook namespace.
  from keras.preprocessing import image

  input_image = image.load_img(image_filename, target_size=(image_shape[0], image_shape[1]))
  output_image = image.img_to_array(input_image)
  output_image = np.expand_dims(output_image, axis=0)
  output_image = utils.preprocess_input(output_image, version=version)
  return output_image

### Compute image features.

In [None]:
def compute_image_features(model, image_filename):
  """Return the unit-length embedding of one image file.

  The file is preprocessed by normalize_image, the model predicts a batch of a
  single feature vector, and that vector is divided by its Euclidean norm.
  """
  batch = normalize_image(image_filename)
  embedding = model.predict(batch)[0]
  return embedding / np.linalg.norm(embedding)

# Register each person using a single reference image.

In [None]:
from keras.preprocessing import image

def compute_features(model, aligned_image_dir):
  """Register every person in a directory from one reference image each.

  Each regular file in `aligned_image_dir` is embedded with
  compute_image_features, and the filename without its extension is used as
  the person's identifier.

  Args:
    model: feature-extraction model passed through to compute_image_features.
    aligned_image_dir: directory holding one aligned face image per person.

  Returns:
    dict mapping identifier -> unit-length feature vector.
  """
  image_features = {}
  # Sorted so registration order (and tie-breaking in identify_person) is deterministic.
  for image_filename in sorted(os.listdir(aligned_image_dir)):
    image_path = os.path.join(aligned_image_dir, image_filename)
    # Skip sub-directories and other non-files instead of failing on them.
    if not os.path.isfile(image_path):
      continue
    # splitext strips only the final extension; split('.jpg') also cut names at any
    # inner '.jpg' and left other extensions (e.g. '.png') attached.
    identifier = os.path.splitext(image_filename)[0]
    image_features[identifier] = compute_image_features(model, image_path)
  return image_features

In [None]:
# Embed one reference image per person: identifier -> unit-length feature vector.
image_features = compute_features(model, aligned_image_dir)

# Identify person using pre-computed image features.

In [None]:
def identify_person(image_features, current_features, threshold=100):
  """Find the registered person whose features are closest to `current_features`.

  Args:
    image_features: dict mapping person name -> feature vector.
    current_features: feature vector of the query image.
    threshold: largest Euclidean distance accepted as a match; if even the best
      match is farther away, the person is reported as 'unknown'.
      NOTE(review): compute_image_features returns unit-length vectors, whose
      pairwise distance is at most 2, so the default of 100 never rejects a match.

  Returns:
    (person_name, minimum_distance); ('unknown', inf) when image_features is empty.
  """
  best_name = 'unknown'
  best_distance = float('inf')

  for name, features in image_features.items():
    distance = np.linalg.norm(features - current_features)
    # Strict '<' keeps the first person seen when distances tie.
    if distance < best_distance:
      best_name, best_distance = name, distance

  if best_distance > threshold:
    return ('unknown', best_distance)
  return (best_name, best_distance)

# Test one-shot recognition.

In [None]:
def identify_persons(image_features, test_image_dir, feature_model=None):
  """Run one-shot recognition on every image file in `test_image_dir`.

  For each query image, prints the filename (without extension) as the ground
  truth next to the predicted identity and its feature distance.

  Args:
    image_features: dict person -> feature vector, as built by compute_features.
    test_image_dir: directory of query images.
    feature_model: model used to embed the query images. Defaults to the
      notebook-level `model`, which is what this function always used before.
  """
  if feature_model is None:
    feature_model = model

  for image_filename in sorted(os.listdir(test_image_dir)):
    image_path = os.path.join(test_image_dir, image_filename)
    # Skip sub-directories and other non-files instead of failing on them.
    if not os.path.isfile(image_path):
      continue

    # splitext strips only the final extension, whatever it is.
    identifier = os.path.splitext(image_filename)[0]
    current_features = compute_image_features(feature_model, image_path)
    person_name, minimum_distance = identify_person(image_features, current_features)
    print('**************************************************')
    print('ground truth -', identifier)
    print('predicted -', person_name)
    print('distance -', minimum_distance)
    print('**************************************************')

In [None]:
# identify_persons takes (image_features, test_image_dir). Passing `model` first shifted
# every argument and raised a TypeError (3 positional arguments for 2 parameters).
identify_persons(image_features, test_image_dir)

# Evaluate the model.

### Download and extract the aligned test dataset from Google Drive.

In [None]:
!gdown --id 1WEftISRMb-8v9iIFomzCzSAzbEvxOKFu # aligned_vggface2_test.tar.gz
!ls -al

In [None]:
# Extract the archive into the current directory; the evaluation reads the extracted test/ folder.
!tar -xzf aligned_vggface2_test.tar.gz
!ls -al

In [None]:
# Delete the archive once extracted to free disk space.
!rm -rf aligned_vggface2_test.tar.gz
!ls -al

### Check downloaded test dataset.

In [None]:
!ls -al
# Count the sub-directories (one per identity) inside test/.
!ls -l test/ | grep ^d | wc -l

# Preprocess input image.

In [None]:
def center_crop(input_image, target_size):
  """Cut a `target_size` (height, width) window out of the middle of an image.

  Slices only the first two axes, so any trailing channel axis is kept. When the
  size difference is odd, the extra pixel is left on the bottom/right side.
  """
  crop_height, crop_width = target_size
  top = (input_image.shape[0] - crop_height) // 2
  left = (input_image.shape[1] - crop_width) // 2
  rows = slice(top, top + crop_height)
  columns = slice(left, left + crop_width)
  return input_image[rows, columns]

In [None]:
import cv2
from keras_vggface import utils

def preprocess_test_image(image_filename, version=2):
  """Load a test image with OpenCV and turn it into a model-ready batch of one.

  The pipeline is:
    1. Read the image (BGR).
    2. Resize it to image_load_shape[:2].
    3. Center-crop it to image_shape[:2].
    4. Convert it to RGB.
    5. Convert it to float64.
    6. Add a batch axis.
    7. Apply keras_vggface mean subtraction.

  Args:
    image_filename: path to the image file.
    version: keras_vggface `preprocess_input` mode. keras_vggface documents
      version=1 for VGG16 and version=2 for ResNet-50 / SENet-50; the model in
      this notebook is ResNet-50, hence the default of 2.

  Returns:
    Array of shape (1, image_shape[0], image_shape[1], 3).

  Raises:
    IOError: if OpenCV cannot read the file.
  """
  input_image = cv2.imread(image_filename)
  # cv2.imread signals failure by returning None instead of raising.
  if input_image is None:
    raise IOError('could not read image: {}'.format(image_filename))
  # cv2.resize expects dsize as (width, height); the shape tuples are (height, width, channels).
  input_image = cv2.resize(input_image, (image_load_shape[1], image_load_shape[0]))
  input_image = center_crop(input_image, (image_shape[0], image_shape[1]))
  input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
  # np.float was only an alias of the builtin float (float64) and was removed in NumPy 1.24.
  input_image = input_image.astype(np.float64)
  input_image = np.expand_dims(input_image, axis=0)
  input_image = utils.preprocess_input(input_image, version=version)
  return input_image

# Compute test image features.

In [None]:
def compute_test_image_features(model, image_filename):
  """Embed a test image and scale the embedding to unit Euclidean length.

  Mirrors compute_image_features, but uses the OpenCV-based
  preprocess_test_image pipeline (resize + center crop).
  """
  embedding = model.predict(preprocess_test_image(image_filename))[0]
  length = np.linalg.norm(embedding)
  return embedding / length

### Evaluate the model on test dataset.

In [None]:
# Extracted VGGFace2 test split: the evaluation treats each sub-directory as one identity.
# NOTE: this reassigns test_image_dir, which previously pointed at the Drive test images.
test_image_dir = '/content/test/'
# Print running accuracies every 1000 compared images during evaluation.
show_accuracy = True

In [None]:
# Two independent same-identity criteria, each reported as its own accuracy:
# cosine similarity above minimum_similarity, and Euclidean distance below maximum_distance.
# NOTE(review): the features are unit length (distance^2 = 2 - 2 * similarity), so
# distance < 0.5 corresponds to similarity > 0.875; the two thresholds are not equivalent.
minimum_similarity = 0.5
maximum_distance = 0.5

In [None]:
# Field separator for the verification log.
# NOTE(review): ' ,' puts the space before the comma; ', ' may have been intended —
# confirm before parsing the file as CSV.
separator = ' ,'

# Log with one line per compared pair: base image, compared image, distance, similarity.
verification_filename = 'vggface2_verification'
# Opened here and closed at the end of the evaluation cell.
verification_file = open(verification_filename, 'w')

In [None]:
# Verification-style evaluation: for each identity directory, pick one random base image
# and compare it against that identity's other images. A pair counts as correct when
# its distance / similarity passes the thresholds defined above.
# Listings are sorted so the seeded random base choice is reproducible.
class_names = sorted(os.listdir(test_image_dir))
number_of_images = 0
positive_distance_images = 0
positive_similarity_images = 0
try:
  for class_name in class_names:
    class_root_dir = os.path.join(test_image_dir, class_name)
    if not os.path.isdir(class_root_dir):
      continue

    # Only regular files can serve as base or compared images.
    image_filenames = [name for name in sorted(os.listdir(class_root_dir))
                       if os.path.isfile(os.path.join(class_root_dir, name))]
    # A pair needs a base image plus at least one other image
    # (np.random.randint(0, 0) would also raise for an empty directory).
    if len(image_filenames) < 2:
      continue
    image_index = np.random.randint(0, len(image_filenames))

    base_image_filename = os.path.join(class_root_dir, image_filenames[image_index])
    base_features = compute_test_image_features(model, base_image_filename)

    for position, image_filename in enumerate(image_filenames):
      # Comparing the base image with itself is a guaranteed match and would inflate accuracy.
      if position == image_index:
        continue
      current_image_filename = os.path.join(class_root_dir, image_filename)
      number_of_images = number_of_images + 1

      current_features = compute_test_image_features(model, current_image_filename)

      current_distance = np.linalg.norm(base_features - current_features)
      positive_distance_images += int(current_distance < maximum_distance)

      # The features are 1-D unit vectors, so their dot product is the cosine similarity
      # (transposing a 1-D array is a no-op).
      current_similarity = np.dot(base_features, current_features)
      positive_similarity_images += int(current_similarity > minimum_similarity)

      # '\n', not os.linesep: text-mode files already translate newlines for the platform.
      verification_file.write(base_image_filename
                              + separator + current_image_filename
                              + separator + str(current_distance)
                              + separator + str(current_similarity)
                              + '\n')

      if show_accuracy and (number_of_images % 1000 == 0):
        print('accuracy (distance) - ', positive_distance_images/number_of_images)
        print('accuracy (similarity) - ', positive_similarity_images/number_of_images)
finally:
  # Close the log even if an image fails part-way through the run.
  verification_file.close()

if number_of_images:
  print('accuracy (distance) - ', positive_distance_images/number_of_images)
  print('accuracy (similarity) - ', positive_similarity_images/number_of_images)
else:
  print('no image pairs were evaluated')