Within this notebook we evaluate the performance of GUM-Net on the FVC2004 dataset and on our custom Kaggle dataset.
As the alignment measure, we use the Pearson Correlation Coefficient (PCC) $$
\rho_{X,Y} = \frac{\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_{i=1}^{n}(x_i-\bar{x})^2}\,\sqrt{\sum_{i=1}^{n}(y_i-\bar{y})^2}}$$ between a template X (Sa) and its (warped) impression Y (Sb).

## Dataset paths and Dataloaders
We need a single path to a database for the FVC2004 dataset and both a path to a master template and a path to its impressions for our custom dataset.

In [12]:
# FVC2004: a single database directory is all the FVC loader needs.
fvc_path = "data/FVC/FVC2004/Dbs/DB1_B"
# Custom Kaggle dataset: the directory holding the impressions of one finger,
# and the master template image those impressions are compared against.
custom_impressions_path = "data/Kaggle/data/5x5000/Finger_1"
custom_template_path = "data/Kaggle/data/5x5000/Master_Templates/1.png"

In [13]:
from datasets.DB1_data import get_set_eval_dataloader # Switch to DB2_data/DB3_data is possible
from datasets.kaggle_data import get_eval_dataloader

In [14]:
import torch
from PIL import Image
import torchvision.transforms as transforms

# Batch size of the Kaggle dataloader. The template batch built at the end of
# this cell is repeated to exactly this size, so both are driven by one value
# instead of two literals that had to be kept in sync by hand.
KAGGLE_BATCH_SIZE = 100

FVC_loader = get_set_eval_dataloader(
    data_root=fvc_path,
    batch_size=64,
    num_workers=0,
    num_images=10*8, # Be aware: Set A databases contain 100*8 images, set B databases only 10*8
)
Kaggle_loader = get_eval_dataloader(
    data_root=custom_impressions_path,
    batch_size=KAGGLE_BATCH_SIZE,
    num_workers=0,
)

# Preprocessing applied to the master template image.
# NOTE(review): the negative Pad values presumably crop the borders rather than
# pad them - confirm against the torchvision version in use.
template_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Pad(padding=(-7, -70, -8, -70), fill=255),
    transforms.Resize((192, 192)),
    transforms.RandomInvert(p=1.0),  # p=1.0: the image is always inverted
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # maps [0, 1] to [-1, 1]
])

# Open the file in a context manager so the handle is released right away;
# convert() hands back an independent image that outlives the `with` block.
with Image.open(custom_template_path) as _template_file:
    template_image = _template_file.convert("L")
template_tensor: torch.Tensor = template_transform(template_image) # type: ignore

# One copy of the template per sample of a Kaggle batch.
template = template_tensor.unsqueeze(0).repeat(KAGGLE_BATCH_SIZE, 1, 1, 1)

Loaded 80 samples (no split).
Loaded 5000 samples (no split).


## Variables and helper functions

In [15]:
from model.losses.pearson_correlation_loss import PearsonCorrelationLoss

In [16]:
# Accumulators filled by per_pair_cross_correlation: one entry per call, for
# the score of the raw pair and the score of the warped pair respectively.
cross_correlation, cross_correlation_warped = [], []

# Metric object called as loss(template, impression).
loss = PearsonCorrelationLoss()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2


def plot_side_by_side(image_left, image_middle, image_right, title_left="Template", title_middle="Impression", title_right="Warped"):
    """Display three images in one row of grayscale panels without axes.

    Args:
        image_left, image_middle, image_right: Images passed straight to
            ``Axes.imshow`` (left to right).
        title_left, title_middle, title_right: Title shown above each panel.
    """
    _, panels = plt.subplots(1, 3, figsize=(8, 4))

    contents = (
        (image_left, title_left),
        (image_middle, title_middle),
        (image_right, title_right),
    )
    # Every panel gets the same treatment, so configure them in one loop.
    for panel, (image, title) in zip(panels, contents):
        panel.imshow(image, cmap="gray")
        panel.set_title(title)
        panel.axis("off")

    plt.tight_layout()
    plt.show()

def per_pair_cross_correlation(Sa, Sb_CC, warped_Sb, batch_idx=None):
    """Score one batch before and after warping, record both scores and print them.

    The scores are computed with the module-level ``loss`` callable and
    appended to the module-level lists ``cross_correlation`` (template vs.
    original impression) and ``cross_correlation_warped`` (template vs. warped
    impression).

    Args:
        Sa: Template batch.
        Sb_CC: Impression batch before warping.
        warped_Sb: Impression batch after warping.
        batch_idx: Number shown in the printed line. When omitted, the number
            of batches recorded so far is used, so the function no longer
            depends on a loop variable named ``idx`` existing in the global
            namespace.
    """
    if batch_idx is None:
        # Position this batch will take in the accumulator list.
        batch_idx = len(cross_correlation)

    before = loss(Sa, Sb_CC)
    after = loss(Sa, warped_Sb)
    cross_correlation.append(before)
    cross_correlation_warped.append(after)
    print(f"Batch {batch_idx}: Before CC={before:.4f}, After CC={after:.4f}")


def tensor_to_uint8(t):
    """Convert normalized image data to a batch of uint8 images.

    The input is expected to be normalized with mean=0.5 and std=0.5; that
    normalization is undone, values are clamped to [0, 1], scaled to [0, 255]
    and rounded to the nearest integer.

    Args:
        t: A torch tensor or array-like (e.g. a NumPy array) with 2, 3 or 4
            dimensions. 3-D input gets a leading batch axis; 2-D input gets
            a leading batch axis and a channel axis.

    Returns:
        A ``np.uint8`` array. Axis 1 of the 4-D batch is squeezed away when it
        has size 1, so single-channel input yields shape (batch, H, W).

    Raises:
        ValueError: If the input has fewer than 2 or more than 4 dimensions.
    """
    # as_tensor leaves tensors untouched and wraps array-likes, so NumPy input
    # no longer fails on the `.dim()` call below.
    t = torch.as_tensor(t).detach()

    if t.dim() == 4:
        batch = t
    elif t.dim() == 3:
        batch = t.unsqueeze(0)
    elif t.dim() == 2:
        batch = t.unsqueeze(0).unsqueeze(0)
    else:
        raise ValueError(f"Unsupported tensor shape: {tuple(t.shape)}")

    # Undo normalization (mean=0.5, std=0.5)
    batch = batch * 0.5 + 0.5
    batch = torch.clamp(batch, 0, 1)

    scaled = batch.squeeze(1).cpu().numpy() * 255
    # Round instead of truncating: float round-off can leave a pixel that was
    # originally k at k - 1e-5, which a plain cast would turn into k - 1.
    imgs = np.rint(scaled).astype(np.uint8)
    return imgs

def align_with_orb(template_tensor, impression_tensor):
    """Align each impression to its template using ORB keypoint matching.

    Both inputs are converted to uint8 images with ``tensor_to_uint8``. For
    every pair, ORB descriptors are matched by brute force (Hamming distance,
    cross-checked) and a partial affine transform is estimated with RANSAC
    and applied to the impression. If descriptors are missing, fewer than 10
    matches are found, or no transform can be estimated, the impression is
    passed through unchanged.

    Args:
        template_tensor: Template batch; must be a torch tensor, because its
            ``device`` is used for the result.
        impression_tensor: Impression batch with the same batch size.

    Returns:
        A float tensor on ``template_tensor.device`` holding the aligned
        impressions, re-normalized with mean=0.5 / std=0.5 and with a channel
        axis inserted at dimension 1.

    Raises:
        ValueError: If the two batches differ in size.
    """
    templates = tensor_to_uint8(template_tensor)
    impressions = tensor_to_uint8(impression_tensor)

    if templates.ndim == 2:
        templates = templates[None, ...]
    if impressions.ndim == 2:
        impressions = impressions[None, ...]

    if templates.shape[0] != impressions.shape[0]:
        raise ValueError(
            f"Batch size mismatch: templates={templates.shape[0]} impressions={impressions.shape[0]}"
        )

    # Detector and matcher do not depend on the sample, so build them once
    # instead of once per image pair.
    orb = cv2.ORB_create(2000)
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    aligned_list = []
    for i in range(templates.shape[0]):
        template = templates[i]
        impression = impressions[i]

        kp1, des1 = orb.detectAndCompute(template, None)
        kp2, des2 = orb.detectAndCompute(impression, None)

        # Fallback: keep the impression as it is when alignment is not possible.
        aligned = impression
        if des1 is not None and des2 is not None:
            matches = bf.match(des1, des2)

            if len(matches) >= 10:
                matches = sorted(matches, key=lambda x: x.distance)

                # Query descriptors come from the template, train descriptors
                # from the impression; the transform maps impression -> template.
                src_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2) # type: ignore
                dst_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2) # type: ignore

                # Partial affine transform, chosen here over a full homography.
                # NOTE(review): estimateAffinePartial2D presumably fits only
                # rotation, uniform scale and translation - confirm in the OpenCV docs.
                M, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC)

                if M is not None:
                    aligned = cv2.warpAffine(impression, M, (template.shape[1], template.shape[0]))

        aligned_list.append(aligned)

    aligned_np = np.stack(aligned_list, axis=0)
    # Back to the [-1, 1] range produced by Normalize(mean=0.5, std=0.5).
    aligned_tensor = torch.from_numpy(aligned_np).float() / 255.0
    aligned_tensor = (aligned_tensor - 0.5) / 0.5
    aligned_tensor = aligned_tensor.unsqueeze(1).to(template_tensor.device)
    return aligned_tensor

## Model initialization

In [18]:
from model.gumnet import GumNet
from typing import Union

In [19]:
def init_gumnet(
    device: Union[str, torch.device] = "cpu",
    ckpt_path: str = "model/gumnet_2d_best_noise_level_0_8x8_200.pth.zip", # Use the checkpoints provided in the repository
    grid_size: int = 8,
):
    """Build a GumNet model, load checkpoint weights into it and move it to ``device``.

    Args:
        device: Device the weights are mapped to and the model is moved to.
        ckpt_path: Path of the checkpoint file holding the state dict.
        grid_size: Value passed to ``GumNet(grid_size=...)``.
            NOTE(review): presumably has to match the grid the checkpoint was
            trained with (the default file name contains "8x8") - confirm.

    Returns:
        The GumNet instance with the loaded weights. The model's train/eval
        mode is not changed here.
    """
    model = GumNet(grid_size=grid_size)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider weights_only=True if the installed
    # torch version supports it and the checkpoint is a plain state dict.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)

    return model

In [20]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the pretrained network and switch it to evaluation mode. The bare
# expression on the last line is what makes the cell display the model.
model = init_gumnet(device)
model.eval()

Using device: cpu


GumNet(
  (feature_extractor): GumNetFeatureExtraction(
    (shared_conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (bn1_sa): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn1_sb): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pool1): DCTSpectralPooling(
      (dct_h): LinearDCT(in_features=190, out_features=190, bias=False)
      (dct_w): LinearDCT(in_features=190, out_features=190, bias=False)
      (idct_h): LinearDCT(in_features=100, out_features=100, bias=False)
      (idct_w): LinearDCT(in_features=100, out_features=100, bias=False)
    )
    (shared_conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn2_sa): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2_sb): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pool2): DCTSpectralPooling(
      (dct_h): LinearDCT(in_features=98, out_features=98, bias=Fals

## Running inference

In [None]:
# Start from empty accumulators so that re-running this cell (e.g. after
# switching the dataloader) does not mix scores from an earlier run into the
# averages reported below.
cross_correlation.clear()
cross_correlation_warped.clear()

for idx, batch in enumerate(FVC_loader):
    #Sa = template.to(device) # Uncomment and switch the dataloader to use the Kaggle template
    Sa = batch["Sa"].to(device)
    Sb = batch["Sb"].to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        warped_Sb, control_points = model(Sa, Sb)

    #warped_Sb = align_with_orb(Sa, warped_Sb) # Uncomment to apply ORB-based alignment

    per_pair_cross_correlation(Sa, Sb, warped_Sb)

Batch 0: Before CC=0.6553, After CC=0.7212
Batch 1: Before CC=0.7477, After CC=0.8469
Batch 2: Before CC=0.6341, After CC=0.7251
Batch 3: Before CC=0.6399, After CC=0.7346
Batch 4: Before CC=0.7359, After CC=0.8202
Batch 5: Before CC=0.7208, After CC=0.7861
Batch 6: Before CC=0.7931, After CC=0.8620
Batch 7: Before CC=0.6491, After CC=0.7154
Batch 8: Before CC=0.6977, After CC=0.7619
Batch 9: Before CC=0.6850, After CC=0.7827


## Print results

In [22]:
def print_CC():
    """Print the mean score before and after warping and the relative improvement.

    Reads the module-level lists ``cross_correlation`` and
    ``cross_correlation_warped``. If either list is empty a hint is printed
    instead of raising ``ZeroDivisionError``; if the mean before warping is 0
    the percentage is reported as undefined.
    """
    if not cross_correlation or not cross_correlation_warped:
        print("No cross-correlation values recorded; run the inference loop first.")
        return

    # float() turns a 0-dim tensor mean into a plain number; floats pass through.
    total_before = float(sum(cross_correlation) / len(cross_correlation))
    total_after = float(sum(cross_correlation_warped) / len(cross_correlation_warped))
    print(f"Average Cross-Correlation Before Warping: {total_before:.4f}")
    print(f"Average Cross-Correlation After Warping: {total_after:.4f}")

    if total_before == 0:
        # A relative change with respect to a zero baseline is not defined.
        print("Percentage Improvement: undefined (baseline correlation is 0)")
        return

    percentage_improvement = ((total_after - total_before) / total_before) * 100
    print(f"Percentage Improvement: {percentage_improvement:.2f}%")

# Summarize the scores collected by the inference loop above.
print_CC()

Average Cross-Correlation Before Warping: 0.6959
Average Cross-Correlation After Warping: 0.7756
Percentage Improvement: 11.46%
