# Install and Import Libraries

In [None]:

# IPython shell magic: quietly install MediaPipe (face landmarks) and PyTorch Geometric (graph layers).
!pip install -q mediapipe torch_geometric
print("Libraries installed.")

In [None]:
import math
import random
import numpy as np
from tqdm.notebook import tqdm
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATv2Conv, global_mean_pool
from sklearn.datasets import fetch_lfw_people
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, auc
import matplotlib.pyplot as plt

# Mediapipe Setup

In [None]:
# IPython shell magic: download the MediaPipe Face Landmarker bundle (float16) into the working directory.
!wget -q -O face_landmarker.task https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task
base_options = python.BaseOptions(model_asset_path='face_landmarker.task')
# Only landmarks are needed: blendshapes and transformation matrices are disabled,
# and at most one face is returned per image.
options = vision.FaceLandmarkerOptions(
    base_options=base_options,
    output_face_blendshapes=False,
    output_facial_transformation_matrixes=False,
    num_faces=1)
# Module-level detector reused by process_data_to_dual_stream for every image.
detector = vision.FaceLandmarker.create_from_options(options)

In [None]:
# Pre-compute the face-mesh edge list shared by every landmark graph. Each
# tessellation connection (a tuple or an object with .start/.end) is emitted as two
# directed edges, a->b and b->a, so message passing runs in both directions.
connections = list(vision.FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION)
endpoint_pairs = [
    (conn[0], conn[1]) if isinstance(conn, tuple) else (conn.start, conn.end)
    for conn in connections
]
src = [a for a, _ in endpoint_pairs]
dst = [b for _, b in endpoint_pairs]
EDGE_INDEX = torch.tensor([src + dst, dst + src], dtype=torch.long)

# Data Preparation for Closed Set task

In [None]:
# Image-stream preprocessing: array -> PIL image -> 224x224 -> float tensor, then
# standardised with the ImageNet channel mean/std that match the ImageNet-pretrained
# ResNet-18 used by the CNN stream.
global_transform = transforms.Compose(
    transforms=[
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def process_data_to_dual_stream(images, labels):
    """Pair every face image with the MediaPipe landmark graph extracted from it.

    Images in which the module-level ``detector`` finds no face are dropped, so
    the result can be shorter than ``images``. Some identities may lose samples
    or disappear entirely.

    Args:
        images: sized sequence of RGB float images (e.g. ``lfw_people.images``
            from ``fetch_lfw_people(color=True)``).
        labels: identity label for each image, aligned with ``images``.

    Returns:
        list of dicts with keys:
          ``'image'``: the uint8 array given to the detector.
          ``'graph'``: ``Data`` whose ``x`` and ``pos`` are the landmark
          ``(x, y, z)`` coordinates, with edges ``EDGE_INDEX``.
          ``'label'``: the identity label.
    """
    processed_data = []

    # MediaPipe is fed uint8 pixels. Multiply by 255 only when the floats look
    # normalised to [0, 1]; data already on a 0-255 scale would otherwise
    # overflow and wrap around in the uint8 cast. Decided once for the whole set
    # so every image is converted the same way.
    peak = max((float(np.max(img)) for img in images), default=0.0)
    scale = 255.0 if peak <= 1.0 else 1.0

    for img_arr, label in tqdm(zip(images, labels), total=len(images)):

        # Clip before casting so out-of-range values saturate instead of wrapping.
        img_u8 = np.clip(np.asarray(img_arr) * scale, 0, 255).astype(np.uint8)

        # MediaPipe Detection
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img_u8)
        detection_result = detector.detect(mp_image)

        if detection_result.face_landmarks:
            landmarks = detection_result.face_landmarks[0]

            # Create Graph Data
            # Node Features are JUST coordinates (x, y, z)
            pos = torch.tensor([[lm.x, lm.y, lm.z] for lm in landmarks], dtype=torch.float32)
            graph = Data(x=pos, edge_index=EDGE_INDEX, pos=pos)

            # Image Data
            processed_data.append({
                'image': img_u8,
                'graph': graph,
                'label': label
            })

    return processed_data

In [None]:
# Load the colour LFW faces (resize=1.0, i.e. no downscaling) for people with at least
# 3 pictures, then keep only the images in which MediaPipe detects a face.
lfw_people = fetch_lfw_people(min_faces_per_person=3, resize=1.0, color=True)
dataset_list = process_data_to_dual_stream(
    images=lfw_people.images,
    labels=lfw_people.target,
)
print(f"{len(dataset_list)} valid face samples")

In [None]:
# Pair datasets used for validation and testing in the open-set and closed-set settings. Each item is (img1, graph1, img2, graph2, target), with target 0 = same person and 1 = different person.
class DualStreamPairDatasetOpen(Dataset):
    """Verification pairs anchored on every sample of ``data_list``.

    Item ``index`` is ``(img1, graph1, img2, graph2, target)``: sample ``index``
    plus a partner. ``target`` is 0.0 for a genuine pair (same label) and 1.0
    for an impostor pair (different label). A genuine pair is attempted with
    probability 1/2. An anchor whose label has no other sample always gets an
    impostor partner: pairing an image with itself would give distance 0 and
    trivially inflate verification metrics.

    Args:
        data_list: list of dicts with ``'image'``, ``'graph'`` (must support
            ``.clone()``) and ``'label'`` keys.
        transform: optional callable applied to both images.
        seed: ``None`` (default) draws new pairs on every access from the global
            ``random`` module. An int fixes the partner chosen for each index,
            so validation/test scores do not fluctuate between epochs.

    Raises:
        ValueError: from ``__getitem__`` when an impostor pair is needed but the
            dataset contains a single label.
    """

    def __init__(self, data_list, transform=None, seed=None):
        self.data_list = data_list
        self.transform = transform
        self.seed = seed

        # Group indices by label
        self.label_to_indices = {}
        for idx, item in enumerate(data_list):
            self.label_to_indices.setdefault(item['label'], []).append(idx)
        # Computed once instead of on every impostor draw
        self.all_labels = list(self.label_to_indices)

    def __len__(self):
        return len(self.data_list)

    def _rng(self, index):
        """Return the global ``random`` module, or a generator seeded per index."""
        if self.seed is None:
            return random
        return random.Random(self.seed * 1_000_003 + index)

    def _load(self, idx):
        """Return ``(image, graph)`` of sample ``idx``; the graph is a fresh clone."""
        item = self.data_list[idx]
        img = self.transform(item['image']) if self.transform else item['image']
        return img, item['graph'].clone()

    def _impostor_index(self, label, rng):
        """Pick a random sample whose label differs from ``label``."""
        if len(self.all_labels) < 2:
            # Rejection sampling below would otherwise loop forever.
            raise ValueError("Cannot build an impostor pair: the dataset has a single identity.")
        other = label
        while other == label:
            other = rng.choice(self.all_labels)
        return rng.choice(self.label_to_indices[other])

    def __getitem__(self, index):
        rng = self._rng(index)
        label1 = self.data_list[index]['label']
        img1, graph1 = self._load(index)

        same_label = self.label_to_indices[label1]
        # 50% chance of Same Person (0), 50% Different (1); singletons always Different.
        if rng.randint(0, 1) == 0 and len(same_label) > 1:
            idx2 = index
            while idx2 == index:  # a different image of the same person
                idx2 = rng.choice(same_label)
            target = 0.0
        else:
            idx2 = self._impostor_index(label1, rng)
            target = 1.0

        img2, graph2 = self._load(idx2)
        return img1, graph1, img2, graph2, torch.tensor(target, dtype=torch.float32)

class DualStreamPairDatasetClosed(Dataset):
    """Verification pairs anchored only on samples whose label has another sample.

    Item ``i`` is ``(img1, graph1, img2, graph2, target)`` built from the
    ``i``-th anchor in ``valid_indices``. ``target`` is 0.0 for a genuine pair
    (a *different* image with the same label) and 1.0 for an impostor pair
    (any sample with another label). Each is chosen with probability 1/2.

    Args:
        data_list: list of dicts with ``'image'``, ``'graph'`` (must support
            ``.clone()``) and ``'label'`` keys.
        transform: optional callable applied to both images.
        seed: ``None`` (default) draws new pairs on every access from the global
            ``random`` module. An int fixes the partner chosen for each item,
            so validation/test scores do not fluctuate between epochs.

    Raises:
        ValueError: from ``__getitem__`` when an impostor pair is needed but the
            dataset contains a single label.
    """

    def __init__(self, data_list, transform=None, seed=None):
        self.data_list = data_list
        self.transform = transform
        self.seed = seed

        self.label_to_indices = {}
        for idx, item in enumerate(data_list):
            self.label_to_indices.setdefault(item['label'], []).append(idx)
        # Computed once instead of on every impostor draw
        self.all_labels = list(self.label_to_indices)

        # Only samples with a genuine mate can anchor a pair.
        self.valid_indices = [
            idx for idx, item in enumerate(data_list)
            if len(self.label_to_indices[item['label']]) > 1
        ]

    def __len__(self):
        return len(self.valid_indices)

    def _rng(self, i):
        """Return the global ``random`` module, or a generator seeded per item."""
        if self.seed is None:
            return random
        return random.Random(self.seed * 1_000_003 + i)

    def _load(self, idx):
        """Return ``(image, graph)`` of sample ``idx``; the graph is a fresh clone."""
        item = self.data_list[idx]
        img = self.transform(item['image']) if self.transform else item['image']
        return img, item['graph'].clone()

    def _impostor_index(self, label, rng):
        """Pick a random sample whose label differs from ``label``."""
        if len(self.all_labels) < 2:
            # Rejection sampling below would otherwise loop forever.
            raise ValueError("Cannot build an impostor pair: the dataset has a single identity.")
        other = label
        while other == label:
            other = rng.choice(self.all_labels)
        return rng.choice(self.label_to_indices[other])

    def __getitem__(self, i):
        rng = self._rng(i)
        index = self.valid_indices[i]
        label1 = self.data_list[index]['label']
        img1, graph1 = self._load(index)

        if rng.randint(0, 1) == 0:
            possible_indices = self.label_to_indices[label1]
            idx2 = index
            while idx2 == index:  # a different image of the same person
                idx2 = rng.choice(possible_indices)
            target = 0.0
        else:
            idx2 = self._impostor_index(label1, rng)
            target = 1.0

        img2, graph2 = self._load(idx2)
        return img1, graph1, img2, graph2, torch.tensor(target, dtype=torch.float32)

# Collate function to manage images and graphs
def dual_pair_collate_fn(batch):
    """Merge ``(img1, graph1, img2, graph2, target)`` samples into one pair batch.

    Images and targets are stacked along a new leading dimension; each side's
    graphs are merged into a single PyG ``Batch``. Returns
    ``(imgs1, graphs1, imgs2, graphs2, targets)``.
    """
    def column(position):
        return [sample[position] for sample in batch]

    first_images = torch.stack(column(0))
    first_graphs = Batch.from_data_list(column(1))
    second_images = torch.stack(column(2))
    second_graphs = Batch.from_data_list(column(3))
    pair_targets = torch.stack(column(4))
    return first_images, first_graphs, second_images, second_graphs, pair_targets

In [None]:
# Dataset class to handle graphs and images
class DualStreamDataset(Dataset):
    """Single-sample dataset yielding ``(image, graph, label)`` triples.

    The image goes through ``transform`` when one is given. The graph is a
    clone of the stored one, so the cached copy cannot be mutated. The label
    becomes a ``torch.long`` scalar tensor.
    """

    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sample = self.data_list[idx]
        image = self.transform(sample['image']) if self.transform else sample['image']
        graph_copy = sample['graph'].clone()
        return image, graph_copy, torch.tensor(sample['label'], dtype=torch.long)

# Collator to handle Graph Batches
def dual_collate_fn(batch):
    """Merge ``(image, graph, label)`` samples into ``(images, graphs, labels)``.

    Images and labels are stacked along a new batch dimension; graphs are merged
    into a single PyG ``Batch``.
    """
    def column(position):
        return [sample[position] for sample in batch]

    stacked_images = torch.stack(column(0))
    graph_batch = Batch.from_data_list(column(1))
    stacked_labels = torch.stack(column(2))
    return stacked_images, graph_batch, stacked_labels

In [None]:
# Split data for train/val/test (70/15/15), stratified by identity.
# NOTE(review): stratification needs at least 2 samples per identity in each call;
# identities left with fewer samples after the MediaPipe filtering would make
# train_test_split raise. Confirm on the actual data.
train_val_data, test_data = train_test_split(
    dataset_list, test_size=0.15, stratify=[d['label'] for d in dataset_list], random_state=42
)
train_data, val_data = train_test_split(
    train_val_data, test_size=(0.15/0.85), stratify=[d['label'] for d in train_val_data], random_state=42
)

# Training loader (Batches of single images for ArcFace Loss)
train_loader = DataLoader(
    DualStreamDataset(train_data, transform=global_transform),
    batch_size=32, shuffle=True, collate_fn=dual_collate_fn
)

# Validation and Test loader (Pairs for Verification Metrics)
val_loader = DataLoader(
    DualStreamPairDatasetClosed(val_data, transform=global_transform),
    batch_size=32, shuffle=False, collate_fn=dual_pair_collate_fn
)

test_loader = DataLoader(
    DualStreamPairDatasetClosed(test_data, transform=global_transform),
    batch_size=32, shuffle=False, collate_fn=dual_pair_collate_fn
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

all_labels = [d['label'] for d in dataset_list]
# The labels are the raw integer indices from lfw_people.target, and ArcFace uses them
# directly as rows of its class-weight matrix, so it needs max(label) + 1 rows.
# Counting distinct labels is too small whenever every image of some identity was
# dropped by the face detector (labels are then no longer contiguous), which would
# push the one-hot scatter out of range.
num_classes = int(max(all_labels)) + 1
print(f"Training with {len(set(all_labels))} unique identities.")

# Loss and Model definition

In [None]:
# Define ArcFace Loss
class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace) classification loss.

    Features and class weights are L2-normalised, so logits are cosines. For the
    true class the logit ``cos(theta)`` is replaced by ``cos(theta + m)``; all
    logits are scaled by ``s`` and fed to cross-entropy.

    ``cos(theta + m)`` stops decreasing once ``theta + m > pi``. There the
    margin would *raise* the target logit, so past ``theta > pi - m`` the logit
    falls back to ``cos(theta) - m * sin(m)``, which keeps the penalty monotonic.

    Args:
        in_features: embedding size.
        num_classes: number of rows of the class-weight matrix; targets must be
            in ``[0, num_classes)``.
        s: logit scale.
        m: angular margin in radians.
    """

    def __init__(self, in_features, num_classes, s=30.0, m=0.40):
        super().__init__()
        self.in_features = in_features
        self.num_classes = num_classes
        self.s = s
        self.m = m

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Boundary cos(pi - m) and the linear fallback offset used beyond it
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        # Classifier weight matrix
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, targets):
        # Normalize features and weights
        features = F.normalize(features)
        W = F.normalize(self.weight)

        # Dot product, i.e. Cosine similarity (clamped so sqrt below stays real)
        cosine = F.linear(features, W).clamp(-1.0 + 1e-7, 1.0 - 1e-7)

        # Angular Margin: cos(theta + m), with the monotonic fallback past pi - m
        sine = torch.sqrt(1.0 - cosine * cosine)
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Apply margin only to the correct class and scale
        targets = targets.long()
        one_hot = F.one_hot(targets, num_classes=self.num_classes).to(cosine.dtype)
        output = (one_hot * phi + (1.0 - one_hot) * cosine) * self.s

        return F.cross_entropy(output, targets)

In [None]:
class HybridFaceNetworkV2(nn.Module):
    """Two-stream face embedder: ResNet-18 on pixels plus GATv2 on the landmark mesh.

    The CNN stream maps images to ``embedding_dim`` features (BatchNorm + ReLU).
    The GNN stream runs three GATv2 layers over each landmark graph (node input
    size 3) and mean-pools node features per graph. The two vectors are
    concatenated, projected back to ``embedding_dim`` and L2-normalised.

    Note: constructing the model loads the ImageNet ResNet-18 weights through
    torchvision.

    Args:
        embedding_dim: output size of each stream and of the final embedding.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.embedding_dim = embedding_dim

        # ResNet18 with its classifier replaced by an embedding head
        self.cnn = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embedding_dim)
        self.cnn_bn = nn.BatchNorm1d(embedding_dim)

        # GNN with GAT: 3 input coordinates per node; 4 concatenated heads x 32 = 128 hidden
        self.gnn_conv1 = GATv2Conv(3, 32, heads=4, concat=True)
        self.gnn_bn1 = nn.BatchNorm1d(128)

        self.gnn_conv2 = GATv2Conv(128, 32, heads=4, concat=True)
        self.gnn_bn2 = nn.BatchNorm1d(128)

        self.gnn_conv3 = GATv2Conv(128, embedding_dim, heads=1, concat=False)
        self.gnn_bn3 = nn.BatchNorm1d(embedding_dim)

        # Final Layer: [CNN | GNN] concatenation (2 * embedding_dim wide) -> embedding_dim.
        # Derived from embedding_dim so non-default sizes produce matching shapes.
        self.fusion_fc = nn.Linear(2 * embedding_dim, embedding_dim)

    def forward_cnn(self, images):
        """Image batch -> ReLU(BatchNorm(ResNet-18 embedding))."""
        x = self.cnn(images)
        x = self.cnn_bn(x)
        return F.relu(x)

    def forward_gnn(self, data):
        """Graph batch -> one mean-pooled ``embedding_dim`` vector per graph."""
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.gnn_conv1(x, edge_index)
        x = self.gnn_bn1(x)
        x = F.elu(x)

        x = self.gnn_conv2(x, edge_index)
        x = self.gnn_bn2(x)
        x = F.elu(x)

        x = self.gnn_conv3(x, edge_index)
        x = self.gnn_bn3(x)

        x = global_mean_pool(x, batch)
        return x

    def forward(self, images, graph_data):
        """Return L2-normalised fused embeddings of shape ``(batch, embedding_dim)``."""
        # Extract features
        emb_cnn = self.forward_cnn(images)
        emb_gnn = self.forward_gnn(graph_data)

        # Concatenate
        combined = torch.cat([emb_cnn, emb_gnn], dim=1)

        # Final Layer and Normalization
        final_emb = self.fusion_fc(combined)
        return F.normalize(final_emb, p=2, dim=1)

In [None]:
# Validation function
def validate(model, loader, device):
    """Return the best pair-verification accuracy of ``model`` on ``loader``.

    Each pair's cosine distance ``1 - cos(emb1, emb2)`` is computed, and a pair
    is predicted "different" (label 1) when its distance exceeds a threshold.
    Thresholds 0.00, 0.05, ..., 1.95 are tried; the highest accuracy is
    returned (0 if no threshold beats it).
    """
    model.eval()
    pair_dists, pair_labels = [], []

    with torch.no_grad():
        for img_a, graph_a, img_b, graph_b, tgt in loader:
            emb_a = model(img_a.to(device), graph_a.to(device))
            emb_b = model(img_b.to(device), graph_b.to(device))
            pair_dists.extend((1 - F.cosine_similarity(emb_a, emb_b)).cpu().numpy())
            pair_labels.extend(tgt.numpy())

    pair_dists = np.array(pair_dists)
    pair_labels = np.array(pair_labels)

    best_acc = 0
    for thresh in np.arange(0, 2.0, 0.05):
        acc = accuracy_score(pair_labels, (pair_dists > thresh).astype(float))
        best_acc = max(best_acc, acc)
    return best_acc

In [None]:
# Initialize the model: pick the device, build the hybrid network and its ArcFace head,
# and optimise both with one Adam (the ArcFace class-weight matrix is trained too).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = HybridFaceNetworkV2(embedding_dim=128).to(device)
criterion = ArcFaceLoss(in_features=128, num_classes=num_classes).to(device)
param_groups = [
    {'params': model.parameters()},
    {'params': criterion.parameters()},
]
optimizer = torch.optim.Adam(param_groups, lr=0.001, weight_decay=1e-4)
# Halve the learning rate every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Training the Model for the Closed Set task

In [None]:
# Closed-set training: ArcFace on single images; the epoch with the best validation-pair
# accuracy is checkpointed to disk and restored into `model` after the loop.
best_val_acc = 0
best_state = None
epochs = 25

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for imgs, graphs, labels in train_loader:
        imgs, graphs, labels = imgs.to(device), graphs.to(device), labels.to(device)

        optimizer.zero_grad()
        embeddings = model(imgs, graphs)
        loss = criterion(embeddings, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # Check accuracy on validation pairs
    val_acc = validate(model, val_loader, device)

    # Checkpointing: save to disk and keep an in-memory copy of the best weights
    save_msg = ""
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "closed_hybrid_model.pth")
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        save_msg = "--> Best Model Saved!"

    scheduler.step()

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Acc: {val_acc*100:.2f}% {save_msg}")

print(f"Final Best Validation Accuracy: {best_val_acc*100:.2f}%")

# The test cells evaluate `model` directly, so load the selected (best-validation)
# weights instead of leaving the last epoch's weights in place.
if best_state is not None:
    model.load_state_dict(best_state)

# Test the Model and Compute Performance Metrics for the Closed Set task

In [None]:
# Evaluation loop: cosine distance and ground-truth label (0 = same, 1 = different)
# for every test pair.
model.eval()
distances = []
labels_gt = []

with torch.no_grad():
    for img1, g1, img2, g2, targets in test_loader:
        emb1 = model(img1.to(device), g1.to(device))
        emb2 = model(img2.to(device), g2.to(device))
        pair_distance = 1 - F.cosine_similarity(emb1, emb2)
        distances.extend(pair_distance.cpu().numpy())
        labels_gt.extend(targets.cpu().numpy())

distances = np.array(distances)
labels_gt = np.array(labels_gt)

In [None]:
# Metric calculation
# Every observed distance is a candidate threshold. A pair is accepted as
# "same person" when its distance is strictly below the threshold; the same rule
# is used for the FAR/FRR sweep and for the final predictions below.
thresholds = np.sort(distances)

is_same_person = (labels_gt == 0)
is_diff_person = (labels_gt == 1)
num_pos = np.sum(is_same_person)
num_neg = np.sum(is_diff_person)

fars = []
frrs = []

for thresh in thresholds:
    pred_same = distances < thresh

    # False Accept Rate (FAR): impostor pairs accepted
    num_fa = np.sum(pred_same & is_diff_person)
    far = num_fa / num_neg if num_neg > 0 else 0

    # False Reject Rate (FRR): genuine pairs rejected
    num_fr = np.sum((~pred_same) & is_same_person)
    frr = num_fr / num_pos if num_pos > 0 else 0

    fars.append(far)
    frrs.append(frr)

fars = np.array(fars)
frrs = np.array(frrs)
gars = 1 - frrs  # Genuine Accept Rate (GAR)

# Equal Error Rate: threshold where FAR and FRR are closest
diffs = np.abs(fars - frrs)
min_diff_idx = np.argmin(diffs)
eer_threshold = thresholds[min_diff_idx]
eer_val = (fars[min_diff_idx] + frrs[min_diff_idx]) / 2

# AUC
sorted_indices = np.argsort(fars)
roc_auc = auc(fars[sorted_indices], gars[sorted_indices])

print(f"\n--- Evaluation Results ---")
print(f"Best Threshold (EER): {eer_threshold:.4f}")
print(f"EER (Equal Error Rate): {eer_val:.2f}")
print(f"AUC (Area Under Curve): {roc_auc:.4f}")

# ML metrics at the EER threshold. Distance >= threshold -> "different person" (1),
# matching the `distances < thresh` acceptance rule used for FAR/FRR above.
final_preds = (distances >= eer_threshold).astype(float)

acc = accuracy_score(labels_gt, final_preds)
prec = precision_score(labels_gt, final_preds, pos_label=0)
rec = recall_score(labels_gt, final_preds, pos_label=0)
f1 = f1_score(labels_gt, final_preds, pos_label=0)

print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"F1 Score:  {f1:.4f}")

In [None]:
plt.figure(figsize=(14, 6))

# Left panel: genuine vs impostor distance histograms with the EER threshold.
ax_hist = plt.subplot(1, 2, 1)
ax_hist.hist(distances[labels_gt==0], bins=30, alpha=0.6, color='green', label='Same Person')
ax_hist.hist(distances[labels_gt==1], bins=30, alpha=0.6, color='red', label='Diff Person')
ax_hist.axvline(eer_threshold, color='black', linestyle='--', label='EER Threshold')
ax_hist.set_title("Cosine Distance Distributions (Hybrid Model)")
ax_hist.set_xlabel("Distance (Lower is more similar)")
ax_hist.legend()

# Right panel: FAR and FRR versus threshold; the marked point is the EER.
ax_err = plt.subplot(1, 2, 2)
ax_err.plot(thresholds, fars, label='FAR (False Accept)', color='red')
ax_err.plot(thresholds, frrs, label='FRR (False Reject)', color='blue')
ax_err.scatter(eer_threshold, eer_val, color='black', zorder=5)
ax_err.text(eer_threshold, eer_val + 0.05, f" EER: {eer_val:.2f}", fontsize=10)
ax_err.set_title("FAR vs FRR Curves")
ax_err.set_xlabel("Threshold")
ax_err.set_ylabel("Error Rate")
ax_err.legend()
ax_err.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ROC and DET side by side (one row, two panels).
plt.figure(figsize=(14, 6))

# ROC Curve (GAR vs FAR)
plt.subplot(1, 2, 1)
plt.plot(fars, gars, color='darkorange', lw=2, label=f'ROC curve (EER = {eer_val:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
plt.xlim([-0.01, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Accept Rate (FAR)')
plt.ylabel('Genuine Accept Rate (GAR)')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

# DET Curve (FRR vs FAR in Log scale). A rate of exactly 0 has no logarithm, so
# only operating points where both rates are positive are drawn.
plt.subplot(1, 2, 2)
det_mask = (fars > 0) & (frrs > 0)
plt.loglog(fars[det_mask], frrs[det_mask], color='blue', lw=2, label='DET curve')
if eer_val > 0:
    plt.scatter(eer_val, eer_val, color='black', zorder=5, label=f'EER {eer_val:.2f}')
plt.xlabel('False Accept Rate (FAR) - Log Scale')
plt.ylabel('False Reject Rate (FRR) - Log Scale')
plt.title('Detection Error Tradeoff (DET)')
plt.grid(True, which="both", ls="-", alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# CMC / CMS CALCULATION
def calculate_cmc(model, data_list, device, transform, max_rank=50):
    """Compute a leave-one-out Cumulative Match Characteristic (CMC) curve.

    Every sample is used once as a probe against all *other* samples (the
    gallery). Its rank is the 1-based position of the first gallery sample with
    the same label when the gallery is sorted by decreasing similarity (dot
    product of the model's embeddings). A probe is never matched with itself.
    Probes whose label has no other sample can never be identified, so they are
    excluded from the denominator.

    Args:
        model: network mapping ``(images, graphs)`` to embeddings.
        data_list: samples with ``'image'``, ``'graph'`` and ``'label'`` keys.
        device: device to run the model on.
        transform: image transform applied before the forward pass.
        max_rank: number of CMC points returned.

    Returns:
        list of ``max_rank`` floats. Element ``k - 1`` is the fraction of
        identifiable probes whose first correct match lies within rank ``k``.
        All zeros if no probe has a mate.
    """
    model.eval()
    all_embeddings = []
    all_labels = []

    loader = DataLoader(
        DualStreamDataset(data_list, transform=transform),
        batch_size=32, shuffle=False, collate_fn=dual_collate_fn
    )

    with torch.no_grad():
        for imgs, graphs, labels in loader:
            imgs, graphs = imgs.to(device), graphs.to(device)
            embeddings = model(imgs, graphs)
            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    num_samples = all_labels.size(0)

    # Similarity Matrix of shape [N_probes, N_gallery]; cosine similarity when the
    # model L2-normalises its output.
    sim_matrix = torch.mm(all_embeddings, all_embeddings.t())
    # Self-similarity goes to the very end of every ranking...
    sim_matrix.fill_diagonal_(float('-inf'))
    # ...and never counts as a genuine match.
    self_mask = torch.eye(num_samples, dtype=torch.bool)
    genuine = (all_labels.unsqueeze(1) == all_labels.unsqueeze(0)) & ~self_mask

    # hits[i, r] == 1 when probe i's r-th most similar gallery sample shares its label
    order = torch.argsort(sim_matrix, dim=1, descending=True)
    hits = torch.gather(genuine.long(), 1, order)

    has_mate = hits.sum(dim=1) > 0
    num_probes = int(has_mate.sum())
    if num_probes == 0:
        return [0.0] * max_rank

    # argmax returns the first maximal position, i.e. the first genuine match
    first_match_rank = hits[has_mate].argmax(dim=1) + 1

    # Calculate CMS at each rank
    return [
        (first_match_rank <= k).sum().item() / num_probes
        for k in range(1, max_rank + 1)
    ]

# Run the leave-one-out identification on the test split up to rank 100.
max_k = 100
cmc_curve = calculate_cmc(model, test_data, device, global_transform, max_rank=max_k)

print(f"CMS at Rank 1 (Top-1 Accuracy): {cmc_curve[0]*100:.2f}%")
print(f"CMS at Rank 5: {cmc_curve[4]*100:.2f}%")
print(f"CMS at Rank 10: {cmc_curve[9]*100:.2f}%")

# Plot CMC: identification probability against rank k.
rank_axis = range(1, max_k + 1)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(rank_axis, cmc_curve, marker='o', linestyle='-', color='blue')
ax.set_xlabel('Rank (k)')
ax.set_ylabel('Probability of Identification (CMS)')
ax.set_title('Cumulative Match Characteristic (CMC) Curve')
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])
plt.show()

# Data Preparation for Open Set task

In [None]:
# Open-set protocol: identities (not images) are split 70/15/15, so validation and
# test people are never seen during training. Sorted for a reproducible split.
all_labels = sorted({d['label'] for d in dataset_list})

# Split the identities, not the images
train_labels, val_test_labels = train_test_split(
    all_labels, train_size=0.70, random_state=42
)

val_labels, test_labels = train_test_split(
    val_test_labels, test_size=0.50, random_state=42
)

# Sets for O(1) membership tests while filtering samples
train_label_set = set(train_labels)
val_label_set = set(val_labels)
test_label_set = set(test_labels)

# ArcFace needs contiguous class indices 0..K-1 for the training identities. The
# sample dicts are copied, not relabelled in place: they are shared with
# `dataset_list`, and mutating them would corrupt it (and make a re-run of this
# cell fail with a KeyError in `label_map`).
label_map = {old_label: i for i, old_label in enumerate(train_labels)}
train_data = [
    {**d, 'label': label_map[d['label']]}
    for d in dataset_list if d['label'] in train_label_set
]
val_data   = [d for d in dataset_list if d['label'] in val_label_set]
test_data  = [d for d in dataset_list if d['label'] in test_label_set]

print(f"Open-Set Split -> Train People: {len(train_labels)}, Val People: {len(val_labels)}, Test People: {len(test_labels)}")
print(f"Total Samples -> Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

num_classes = len(train_labels)

# Create Loaders
train_loader = DataLoader(
    DualStreamDataset(train_data, transform=global_transform),
    batch_size=32, shuffle=True, collate_fn=dual_collate_fn
)

val_loader = DataLoader(
    DualStreamPairDatasetOpen(val_data, transform=global_transform),
    batch_size=32, shuffle=False, collate_fn=dual_pair_collate_fn
)

test_loader = DataLoader(
    DualStreamPairDatasetOpen(test_data, transform=global_transform),
    batch_size=32, shuffle=False, collate_fn=dual_pair_collate_fn
)

# Fresh model, ArcFace head, optimizer and scheduler for the open-set experiment.
# The closed-set model was trained on a stratified split of every identity, which
# includes the open-set validation/test people, so continuing from it would leak
# them into training. Its ArcFace head is also sized for the closed-set classes,
# and its scheduler has already decayed the learning rate.
model = HybridFaceNetworkV2(embedding_dim=128).to(device)
criterion = ArcFaceLoss(in_features=128, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': criterion.parameters()}
], lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print(f"Training with {num_classes} unique identities.")

# Training the Model for the Open Set task

In [None]:
# Open-set training: ArcFace on the training identities; the epoch with the best
# validation-pair accuracy is checkpointed to disk and restored into `model` afterwards.
best_val_acc = 0
best_state = None
epochs = 25

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for imgs, graphs, labels in train_loader:
        imgs, graphs, labels = imgs.to(device), graphs.to(device), labels.to(device)

        optimizer.zero_grad()
        embeddings = model(imgs, graphs)
        loss = criterion(embeddings, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # Check accuracy on validation pairs
    val_acc = validate(model, val_loader, device)

    # Checkpointing: save to disk and keep an in-memory copy of the best weights
    save_msg = ""
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "open_hybrid_model.pth")
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        save_msg = "--> Best Model Saved!"

    scheduler.step()

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Acc: {val_acc*100:.2f}% {save_msg}")

print(f"Final Best Validation Accuracy: {best_val_acc*100:.2f}%")

# The test cells evaluate `model` directly, so load the selected (best-validation)
# weights instead of leaving the last epoch's weights in place.
if best_state is not None:
    model.load_state_dict(best_state)

# Test the Model and Compute Performance Metrics for the Open Set task

In [None]:
# Evaluation loop: cosine distance and ground-truth label (0 = same, 1 = different)
# for every test pair.
model.eval()
distances = []
labels_gt = []

with torch.no_grad():
    for img1, g1, img2, g2, targets in test_loader:
        emb1 = model(img1.to(device), g1.to(device))
        emb2 = model(img2.to(device), g2.to(device))
        pair_distance = 1 - F.cosine_similarity(emb1, emb2)
        distances.extend(pair_distance.cpu().numpy())
        labels_gt.extend(targets.cpu().numpy())

distances = np.array(distances)
labels_gt = np.array(labels_gt)

In [None]:
# Metric calculation
# Every observed distance is a candidate threshold. A pair is accepted as
# "same person" when its distance is strictly below the threshold; the same rule
# is used for the FAR/FRR sweep and for the final predictions below.
thresholds = np.sort(distances)

is_same_person = (labels_gt == 0)
is_diff_person = (labels_gt == 1)
num_pos = np.sum(is_same_person)
num_neg = np.sum(is_diff_person)

fars = []
frrs = []

for thresh in thresholds:
    pred_same = distances < thresh

    # False Accept Rate (FAR): impostor pairs accepted
    num_fa = np.sum(pred_same & is_diff_person)
    far = num_fa / num_neg if num_neg > 0 else 0

    # False Reject Rate (FRR): genuine pairs rejected
    num_fr = np.sum((~pred_same) & is_same_person)
    frr = num_fr / num_pos if num_pos > 0 else 0

    fars.append(far)
    frrs.append(frr)

fars = np.array(fars)
frrs = np.array(frrs)
gars = 1 - frrs  # Genuine Accept Rate (GAR)

# Equal Error Rate: threshold where FAR and FRR are closest
diffs = np.abs(fars - frrs)
min_diff_idx = np.argmin(diffs)
eer_threshold = thresholds[min_diff_idx]
eer_val = (fars[min_diff_idx] + frrs[min_diff_idx]) / 2

# AUC
sorted_indices = np.argsort(fars)
roc_auc = auc(fars[sorted_indices], gars[sorted_indices])

# EER and accuracy are fractions in [0, 1], so they are printed without a % sign.
print(f"\n--- Evaluation Results ---")
print(f"Best Threshold (EER): {eer_threshold:.4f}")
print(f"EER (Equal Error Rate): {eer_val:.4f}")
print(f"AUC (Area Under Curve): {roc_auc:.4f}")

# ML metrics at the EER threshold. Distance >= threshold -> "different person" (1),
# matching the `distances < thresh` acceptance rule used for FAR/FRR above.
final_preds = (distances >= eer_threshold).astype(float)

acc = accuracy_score(labels_gt, final_preds)
prec = precision_score(labels_gt, final_preds, pos_label=0)
rec = recall_score(labels_gt, final_preds, pos_label=0)
f1 = f1_score(labels_gt, final_preds, pos_label=0)

print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"F1 Score:  {f1:.4f}")

In [None]:
plt.figure(figsize=(14, 6))

# Left panel: genuine vs impostor distance histograms with the EER threshold.
ax_hist = plt.subplot(1, 2, 1)
ax_hist.hist(distances[labels_gt==0], bins=30, alpha=0.6, color='green', label='Same Person')
ax_hist.hist(distances[labels_gt==1], bins=30, alpha=0.6, color='red', label='Diff Person')
ax_hist.axvline(eer_threshold, color='black', linestyle='--', label='EER Threshold')
ax_hist.set_title("Cosine Distance Distributions (Hybrid Model)")
ax_hist.set_xlabel("Distance (Lower is more similar)")
ax_hist.legend()

# Right panel: FAR and FRR versus threshold; the marked point is the EER.
ax_err = plt.subplot(1, 2, 2)
ax_err.plot(thresholds, fars, label='FAR (False Accept)', color='red')
ax_err.plot(thresholds, frrs, label='FRR (False Reject)', color='blue')
ax_err.scatter(eer_threshold, eer_val, color='black', zorder=5)
ax_err.text(eer_threshold, eer_val + 0.05, f" EER: {eer_val:.2f}", fontsize=10)
ax_err.set_title("FAR vs FRR Curves")
ax_err.set_xlabel("Threshold")
ax_err.set_ylabel("Error Rate")
ax_err.legend()
ax_err.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ROC and DET side by side (one row, two panels).
plt.figure(figsize=(14, 6))

# ROC Curve (GAR vs FAR)
plt.subplot(1, 2, 1)
plt.plot(fars, gars, color='darkorange', lw=2, label=f'ROC curve (EER = {eer_val:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
plt.xlim([-0.01, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Accept Rate (FAR)')
plt.ylabel('Genuine Accept Rate (GAR)')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

# DET Curve (FRR vs FAR in Log scale). A rate of exactly 0 has no logarithm, so
# only operating points where both rates are positive are drawn.
plt.subplot(1, 2, 2)
det_mask = (fars > 0) & (frrs > 0)
plt.loglog(fars[det_mask], frrs[det_mask], color='blue', lw=2, label='DET curve')
if eer_val > 0:
    plt.scatter(eer_val, eer_val, color='black', zorder=5, label=f'EER {eer_val:.2f}')
plt.xlabel('False Accept Rate (FAR) - Log Scale')
plt.ylabel('False Reject Rate (FRR) - Log Scale')
plt.title('Detection Error Tradeoff (DET)')
plt.grid(True, which="both", ls="-", alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# CMC / CMS CALCULATION
def calculate_cmc(model, data_list, device, transform, max_rank=50):
    """Compute a leave-one-out Cumulative Match Characteristic (CMC) curve.

    Every sample is used once as a probe against all *other* samples (the
    gallery). Its rank is the 1-based position of the first gallery sample with
    the same label when the gallery is sorted by decreasing similarity (dot
    product of the model's embeddings). A probe is never matched with itself.
    Probes whose label has no other sample can never be identified, so they are
    excluded from the denominator.

    Args:
        model: network mapping ``(images, graphs)`` to embeddings.
        data_list: samples with ``'image'``, ``'graph'`` and ``'label'`` keys.
        device: device to run the model on.
        transform: image transform applied before the forward pass.
        max_rank: number of CMC points returned.

    Returns:
        list of ``max_rank`` floats. Element ``k - 1`` is the fraction of
        identifiable probes whose first correct match lies within rank ``k``.
        All zeros if no probe has a mate.
    """
    model.eval()
    all_embeddings = []
    all_labels = []

    loader = DataLoader(
        DualStreamDataset(data_list, transform=transform),
        batch_size=32, shuffle=False, collate_fn=dual_collate_fn
    )

    with torch.no_grad():
        for imgs, graphs, labels in loader:
            imgs, graphs = imgs.to(device), graphs.to(device)
            embeddings = model(imgs, graphs)
            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    num_samples = all_labels.size(0)

    # Similarity Matrix of shape [N_probes, N_gallery]; cosine similarity when the
    # model L2-normalises its output.
    sim_matrix = torch.mm(all_embeddings, all_embeddings.t())
    # Self-similarity goes to the very end of every ranking...
    sim_matrix.fill_diagonal_(float('-inf'))
    # ...and never counts as a genuine match.
    self_mask = torch.eye(num_samples, dtype=torch.bool)
    genuine = (all_labels.unsqueeze(1) == all_labels.unsqueeze(0)) & ~self_mask

    # hits[i, r] == 1 when probe i's r-th most similar gallery sample shares its label
    order = torch.argsort(sim_matrix, dim=1, descending=True)
    hits = torch.gather(genuine.long(), 1, order)

    has_mate = hits.sum(dim=1) > 0
    num_probes = int(has_mate.sum())
    if num_probes == 0:
        return [0.0] * max_rank

    # argmax returns the first maximal position, i.e. the first genuine match
    first_match_rank = hits[has_mate].argmax(dim=1) + 1

    # Calculate CMS at each rank
    return [
        (first_match_rank <= k).sum().item() / num_probes
        for k in range(1, max_rank + 1)
    ]

# Run the leave-one-out identification on the open-set test identities up to rank 100.
max_k = 100
cmc_curve = calculate_cmc(model, test_data, device, global_transform, max_rank=max_k)

print(f"CMS at Rank 1 (Top-1 Accuracy): {cmc_curve[0]*100:.2f}%")
print(f"CMS at Rank 5: {cmc_curve[4]*100:.2f}%")
print(f"CMS at Rank 10: {cmc_curve[9]*100:.2f}%")

# Plot CMC: identification probability against rank k.
rank_axis = range(1, max_k + 1)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(rank_axis, cmc_curve, marker='o', linestyle='-', color='blue')
ax.set_xlabel('Rank (k)')
ax.set_ylabel('Probability of Identification (CMS)')
ax.set_title('Cumulative Match Characteristic (CMC) Curve')
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])
plt.show()