# Download and Install Dependencies

In [None]:
# Clone facenet-pytorch into ./facenet_pytorch so `from facenet_pytorch import MTCNN`
# can import it; skip the clone when the directory already exists so this cell
# can be re-run without a "destination path already exists" error.
![ -d facenet_pytorch ] || git clone https://github.com/timesler/facenet-pytorch.git facenet_pytorch
# NOTE(review): these are CUDA 11.8 wheels, but the rest of the notebook runs on
# TPU via torch_xla, which must stay version-matched with torch - confirm this
# install is needed at all on the TPU runtime.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Pinned to the mxnet version recorded in this cell's output (1.9.1).
%pip install mxnet==1.9.1
%pip install torchmetrics

Cloning into 'facenet_pytorch'...
remote: Enumerating objects: 1338, done.[K
remote: Counting objects: 100% (293/293), done.[K
remote: Compressing objects: 100% (75/75), done.[K
remote: Total 1338 (delta 233), reused 225 (delta 217), pack-reused 1045 (from 1)[K
Receiving objects: 100% (1338/1338), 23.19 MiB | 42.32 MiB/s, done.
Resolving deltas: 100% (662/662), done.
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting mxnet
  Downloading mxnet-1.9.1-py3-none-manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting graphviz<0.9.0,>=0.8.1 (from mxnet)
  Downloading graphviz-0.8.4-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading mxnet-1.9.1-py3-none-manylinux2014_x86_64.whl (49.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.1/49.1 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: graphviz, mxnet
Successfully installed graphviz-0.8.4 mxnet-1.9.1
Collecting 

# Import Packages

In [None]:
import numpy as np
np.bool = bool
import mxnet as mx
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import math
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from facenet_pytorch import MTCNN
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torchmetrics import Accuracy
import torch_xla.utils.serialization as xser
from Utils import preprocess_image, CANONICAL_LANDMARKS
from CasiaWebFace import CASIAWebFaceDataset
from Intermediate_Strategy import MobileFaceNetIntermediate
from MobileFaceNet import MobileFaceNet
from Later_Strategy import MobileFaceNetLater
from LFW import LFWPairsDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.datasets import fetch_lfw_pairs
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

  state_dict = torch.load(state_dict_path)
  state_dict = torch.load(state_dict_path)
  state_dict = torch.load(state_dict_path)


# Cosine Similarity Evaluation Function

In [None]:
def eval_angles(model, threshold, device, dataloader):
    """Score face-verification pairs by the cosine similarity of their embeddings.

    Each item yielded by ``dataloader`` is ``(img1, img2, label)``. Both image
    batches are embedded with ``model``. A pair is predicted as a match (1)
    when the cosine similarity of its two embeddings is strictly greater than
    ``threshold``; otherwise it is predicted as 0. Predictions are compared
    with the labels.

    Args:
        model: network mapping an image batch to an embedding batch.
        threshold: cosine-similarity cut-off (similarities lie in [-1, 1]).
        device: device the images are moved to before the forward pass.
        dataloader: iterable of ``(img1, img2, label)`` batches. Any batch size
            is supported because similarities are computed per row.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall' and 'f1_score', each
        expressed as a percentage (0-100). ``zero_division=1`` means a metric
        with a zero denominator is reported as 100.

    Raises:
        ValueError: if ``dataloader`` yields no pairs.
    """
    model.eval()  # Set the model to evaluation mode
    similarities = []  # One cosine similarity per image pair
    labels = []  # One ground-truth label per image pair

    with torch.no_grad():  # Inference only: no gradient bookkeeping needed
        for img1, img2, label in dataloader:
            emb1 = model(img1.to(device))
            emb2 = model(img2.to(device))

            # Flatten each embedding batch to (batch, features). The similarity
            # is then computed row by row, so every pair in the batch gets
            # exactly one score, matching one label.
            emb1 = emb1.reshape(emb1.shape[0], -1)
            emb2 = emb2.reshape(emb2.shape[0], -1)
            cosine = F.cosine_similarity(emb1, emb2, dim=1).clamp(-1.0, 1.0)

            similarities.extend(cosine.cpu().tolist())
            labels.extend(label.reshape(-1).cpu().tolist())

    if not labels:
        raise ValueError("dataloader yielded no image pairs to evaluate")

    # Convert cosine similarities to binary predictions based on the threshold
    predictions = [1 if sim > threshold else 0 for sim in similarities]

    return {
        'accuracy': accuracy_score(labels, predictions) * 100,
        'precision': precision_score(labels, predictions, zero_division=1) * 100,
        'recall': recall_score(labels, predictions, zero_division=1) * 100,
        'f1_score': f1_score(labels, predictions, zero_division=1) * 100,
    }


# Initialize and Spawn the LFW Evaluation Process

In [None]:
def _build_lfw_eval_loader(max_pairs=1300, seed=0):
    """Build a non-augmented DataLoader over a seeded random subset of LFW pairs."""
    # ImageNet channel statistics used for input normalization.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Validation/test pipeline: tensor conversion, resize, normalize (no augmentation).
    # NOTE(review): Resize(112) matches only the shorter edge to 112, so
    # non-square inputs stay non-square; use Resize((112, 112)) if the network
    # requires exactly 112x112 inputs - confirm against MobileFaceNetIntermediate.
    test_val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(112),
        transforms.Normalize(mean=mean, std=std),
    ])

    lfw_pairs = fetch_lfw_pairs(color=True)
    full_dataset = LFWPairsDataset(lfw_pairs, transform=test_val_transform)

    # Draw a seeded random subset instead of a contiguous prefix. If the pairs
    # are stored grouped by label, a prefix is class-imbalanced and its accuracy
    # mostly reflects the class ratio. Clamping to the dataset length also
    # avoids an IndexError on a smaller dataset.
    total = len(full_dataset)
    num_pairs = total if max_pairs is None else min(max_pairs, total)
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(total, generator=generator)[:num_pairs].tolist()

    # batch_size=1: one image pair per step.
    return DataLoader(Subset(full_dataset, indices), batch_size=1,
                      shuffle=False, num_workers=8)


def _load_eval_model(device, checkpoint_path):
    """Create MobileFaceNetIntermediate on `device` and load weights from `checkpoint_path`."""
    model = MobileFaceNetIntermediate(embedding_size=128).to(device)
    # map_location="cpu" avoids depending on the device the checkpoint was
    # saved from; load_state_dict then copies the weights into `model`.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # strict=False skips mismatched keys silently. Report them so a checkpoint
    # that does not fit the architecture is not evaluated with partially
    # untrained weights.
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if load_result.missing_keys:
        xm.master_print(f"Missing keys in checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        xm.master_print(f"Unexpected keys in checkpoint: {load_result.unexpected_keys}")
    return model


def _mp_fn(rank, max_pairs=1300, seed=0, threshold=0.5,
           checkpoint_path="/content/xxx.pth"):
    """Per-process entry point for xmp.spawn: evaluate a checkpoint on LFW pairs.

    Args:
        rank: process index supplied by xmp.spawn (unused here).
        max_pairs: number of LFW pairs to evaluate; None evaluates all of them.
        seed: seed controlling which pairs are sampled.
        threshold: cosine-similarity cut-off passed to eval_angles.
        checkpoint_path: torch checkpoint containing a 'model_state_dict' entry.
    """
    # Define the device as TPU
    device = xm.xla_device()

    dataloader = _build_lfw_eval_loader(max_pairs=max_pairs, seed=seed)
    model = _load_eval_model(device, checkpoint_path)

    # Wrap DataLoader with MpDeviceLoader so batches are moved to the XLA device
    dataloader = pl.MpDeviceLoader(dataloader, device)

    metrics = eval_angles(model, threshold, device, dataloader)
    xm.master_print(f"Accuracy: {metrics['accuracy']:.2f}%")
    xm.master_print(f"Precision: {metrics['precision']:.2f}%")
    xm.master_print(f"Recall: {metrics['recall']:.2f}%")
    xm.master_print(f"F1-Score: {metrics['f1_score']:.2f}%")

# Run the evaluation in a single process (nprocs=1). The 'fork' start method
# lets the child process use functions defined interactively in this notebook.
xmp.spawn(_mp_fn, args=(), nprocs=1, start_method='fork')


  checkpoint = torch.load("/content/xxx.pth")


Accuracy: 84.62%
Precision: 84.62%
Recall: 100.00%
F1-Score: 91.67%
