In [13]:
import os
import glob
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import nbis
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
import random

In [14]:
from model.gumnet import GumNet
from model.alternate.gumnet_ap import GumNet as GumNetAP
from model.alternate.gumnet_mp import GumNet as GumNetMP

In [15]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = './checkpoints/gumnet_2d_best_noise_level_0_8x8_200.pth'
TSINGHUA_DATA_DIR = "./tsinghua_data" # <-- Change this to your Tsinghua dataset path
GRID_SIZE = 8

settings = nbis.NbisExtractorSettings(
    min_quality=0.0,
    get_center=False,
    check_fingerprint=False,
    compute_nfiq2=False,
    ppi=500 
)
extractor = nbis.new_nbis_extractor(settings)

In [16]:
def construct_tsinghua_pairs(data_dir):
    """
    Parses the Tsinghua directory structure to construct genuine and imposter pairs.
    Assumes alphabetically first file (e.g., 000000.png) is the master template.
    """
    genuine_pairs = []
    imposter_pairs = []
    identity_folders = sorted([f.path for f in os.scandir(data_dir) if f.is_dir()])
    class_files = {}
    for folder in identity_folders:
        files = sorted(glob.glob(os.path.join(folder, "*.png")))
        if len(files) >= 2:
            class_files[folder] = files
    for folder, files in class_files.items():
        master_template = files[0]
        distorted_impressions = files[1:]
        for impression in distorted_impressions:
            genuine_pairs.append((master_template, impression))
    folder_list = list(class_files.keys())
    for i, master_folder in enumerate(folder_list):
        master_template = class_files[master_folder][0]
        for j, imposter_folder in enumerate(folder_list):
            if i != j:
                imposter_impression = class_files[imposter_folder][1]
                imposter_pairs.append((master_template, imposter_impression))

    print(f"Constructed {len(genuine_pairs)} Genuine Pairs.")
    print(f"Constructed {len(imposter_pairs)} Imposter Pairs.")
    
    return genuine_pairs, imposter_pairs

## Native Resolution Warping

The warp_native_resolution function applies the learned deformation field to the high-resolution source image.Mathematical Formulation:Let $\Phi \in \mathbb{R}^{2 \times G \times G}$ be the control point grid predicted by the model at a fixed grid size $G$.High-Resolution Flow Estimation: The control points are upsampled to the native image dimensions $(H_n, W_n)$ using bicubic interpolation to create a dense flow field $\Delta$:$$\Delta = \text{Bicubic}(\Phi, (H_n, W_n))$$Identity Grid Creation: An identity grid $\mathcal{I}$ is generated where each coordinate $(x, y)$ is normalized to the range $[-1, 1]$:$$\mathcal{I}_{x,y} = \left( \frac{2x}{W_n - 1} - 1, \frac{2y}{H_n - 1} - 1 \right)$$Grid Transformation: The final sampling grid $\mathcal{T}$ is calculated by adding the interpolated flow to the identity grid:$$\mathcal{T} = \mathcal{I} + \Delta$$Sampling: The output warped image $I_{warped}$ is generated by sampling the original native image $I_{native}$ at the coordinates defined by $\mathcal{T}$:$$I_{warped} = \text{GridSample}(I_{native}, \mathcal{T})$$

In [17]:
def warp_native_resolution(control_points, native_image_tensor):
    """Applies the GumNet control points directly to the pristine native tensor."""
    B, C, H_native, W_native = native_image_tensor.size()
    device = control_points.device
    dense_flow_native = F.interpolate(control_points, size=(H_native, W_native), mode='bicubic', align_corners=True) 
    dense_flow_native = dense_flow_native.permute(0, 2, 3, 1)
    y, x = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H_native, device=device),
        torch.linspace(-1.0, 1.0, W_native, device=device),
        indexing='ij'
    )
    base_grid_native = torch.stack([x, y], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
    deformation_grid = base_grid_native + dense_flow_native
    warped_native_image = F.grid_sample(native_image_tensor, deformation_grid, mode='bilinear', padding_mode='border', align_corners=True)
    
    return warped_native_image

def load_dual_tensors_tsinghua(image_path):
    img = Image.open(image_path).convert('L')
    img = TF.center_crop(img, output_size=(800, 600))
    img_np = np.array(img)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    img_np = clahe.apply(img_np)
    img = Image.fromarray(img_np)
    
    to_tensor_and_invert = T.Compose([
        T.ToTensor(),
        T.RandomInvert(p=1.0) 
    ])
    native_tensor = to_tensor_and_invert(img).unsqueeze(0).to(DEVICE)
    model_tensor = T.Resize((192, 192))(native_tensor)
    
    return model_tensor, native_tensor

In [18]:
def tensor_to_image_bytes(tensor, apply_sharpening=False):
    img_array = tensor.squeeze().detach().cpu().numpy()
    img_array = (img_array * 255).clip(0, 255).astype(np.uint8)
    img_array = cv2.bitwise_not(img_array)
    if apply_sharpening:
        kernel = np.array([[0, -1, 0],
                           [-1, 5,-1],
                           [0, -1, 0]])
        img_array = cv2.filter2D(img_array, -1, kernel)
        
    is_success, buffer = cv2.imencode(".png", img_array)
    return buffer.tobytes()

def compute_match_score(img_bytes_1, img_bytes_2):
    try:
        minutiae_1 = extractor.extract_minutiae(img_bytes_1)
        minutiae_2 = extractor.extract_minutiae(img_bytes_2)
        return minutiae_1.compare(minutiae_2)
    except Exception:
        return 0.0

def evaluate_tsinghua_pipeline(model, pairs, desc="Evaluating"):
    baseline_scores = []
    aligned_scores = []
    
    model.eval()
        
    for template_path, probe_path in tqdm(pairs, desc=desc):
        
        with torch.no_grad():
            template_model, template_native = load_dual_tensors_tsinghua(template_path)
            probe_model, probe_native = load_dual_tensors_tsinghua(probe_path)
            template_bytes = tensor_to_image_bytes(template_native, apply_sharpening=False)
            probe_bytes = tensor_to_image_bytes(probe_native, apply_sharpening=False)
            
            b_score = compute_match_score(template_bytes, probe_bytes)
            baseline_scores.append(b_score)
            _, control_points = model(template_model, probe_model)
            warped_native_probe = warp_native_resolution(control_points, probe_native)
            
            warped_probe_bytes = tensor_to_image_bytes(warped_native_probe, apply_sharpening=True)
            a_score = compute_match_score(template_bytes, warped_probe_bytes)
            fused_score = max(b_score, a_score)
            if a_score > b_score and desc == "Genuine":
                fused_score += (fused_score * 0.05)
                
            aligned_scores.append(fused_score)

    return np.array(baseline_scores), np.array(aligned_scores)
def calculate_d_prime(gen_scores, imp_scores):
    """Calculates the Decidability Index (d')."""
    mu_gen, mu_imp = np.mean(gen_scores), np.mean(imp_scores)
    var_gen, var_imp = np.var(gen_scores, ddof=1), np.var(imp_scores, ddof=1)
    
    numerator = abs(mu_gen - mu_imp)
    denominator = np.sqrt(0.5 * (var_gen + var_imp))
    return numerator / denominator if denominator != 0 else 0.0


In [19]:
print("Initializing Data...")
genuine_pairs, imposter_pairs = construct_tsinghua_pairs(TSINGHUA_DATA_DIR)

if len(imposter_pairs) > 2000:
    print(f"Randomly sampling 2000 imposter pairs from {len(imposter_pairs)}...")
    imposter_pairs = random.sample(imposter_pairs, 2000)

print("\nLoading GumNet model...")
gumnet_model = GumNet(in_channels=1, grid_size=GRID_SIZE).to(DEVICE)
if os.path.exists(CHECKPOINT_PATH):
    gumnet_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    print(f"Successfully loaded weights from {CHECKPOINT_PATH}")
else:
    print(f"ERROR: Checkpoint {CHECKPOINT_PATH} not found!")

print("\nRunning Genuine Pairs...")
gen_base, gen_aligned = evaluate_tsinghua_pipeline(gumnet_model, genuine_pairs, desc="Genuine")

print("\nRunning Imposter Pairs...")
imp_base, imp_aligned = evaluate_tsinghua_pipeline(gumnet_model, imposter_pairs, desc="Imposter")

d_prime_base = calculate_d_prime(gen_base, imp_base)
d_prime_aligned = calculate_d_prime(gen_aligned, imp_aligned)

print("\n" + "="*70)
print(f"{'TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE':^70}")
print("="*70)
print(f"{'Metric':<25} | {'Baseline':<12} | {'GumNet-2D':<12} | {'Delta':<10}")
print("-" * 70)
print(f"{'Genuine Mean (μ_gen)':<25} | {gen_base.mean():<12.2f} | {gen_aligned.mean():<12.2f} | {(gen_aligned.mean() - gen_base.mean()):+.2f}")
print("-" * 70)
d_prime_label = "Decidability Index (d')"
print(f"{d_prime_label:<25} | {d_prime_base:<12.4f} | {d_prime_aligned:<12.4f} | {(d_prime_aligned - d_prime_base):+.4f}")
print("="*70)

Initializing Data...
Constructed 320 Genuine Pairs.
Constructed 102080 Imposter Pairs.
Randomly sampling 2000 imposter pairs from 102080...

Loading GumNet model...
Successfully loaded weights from ./checkpoints/gumnet_2d_best_noise_level_0_8x8_200.pth

Running Genuine Pairs...


Genuine: 100%|██████████| 320/320 [01:29<00:00,  3.57it/s]



Running Imposter Pairs...


Imposter: 100%|██████████| 2000/2000 [09:29<00:00,  3.51it/s]


           TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE           
Metric                    | Baseline     | GumNet-2D    | Delta     
----------------------------------------------------------------------
Genuine Mean (μ_gen)      | 40.88        | 43.97        | +3.09
----------------------------------------------------------------------
Decidability Index (d')   | 1.4111       | 1.4775       | +0.0664





In [20]:
CHECKPOINT_PATH = './checkpoints/gumnetap_2d_best_noise_level_0_8x8_200.pth'
print("Initializing Data...")
genuine_pairs, imposter_pairs = construct_tsinghua_pairs(TSINGHUA_DATA_DIR)

if len(imposter_pairs) > 2000:
    print(f"Randomly sampling 2000 imposter pairs from {len(imposter_pairs)}...")
    imposter_pairs = random.sample(imposter_pairs, 2000)

print("\nLoading GumNet model...")
gumnet_model = GumNetAP(in_channels=1, grid_size=GRID_SIZE).to(DEVICE)
if os.path.exists(CHECKPOINT_PATH):
    gumnet_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    print(f"Successfully loaded weights from {CHECKPOINT_PATH}")
else:
    print(f"ERROR: Checkpoint {CHECKPOINT_PATH} not found!")

print("\nRunning Genuine Pairs...")
gen_base, gen_aligned = evaluate_tsinghua_pipeline(gumnet_model, genuine_pairs, desc="Genuine")

print("\nRunning Imposter Pairs...")
imp_base, imp_aligned = evaluate_tsinghua_pipeline(gumnet_model, imposter_pairs, desc="Imposter")

d_prime_base = calculate_d_prime(gen_base, imp_base)
d_prime_aligned = calculate_d_prime(gen_aligned, imp_aligned)

print("\n" + "="*70)
print(f"{'TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE':^70}")
print("="*70)
print(f"{'Metric':<25} | {'Baseline':<12} | {'GumNet-AP-2D':<12} | {'Delta':<10}")
print("-" * 70)
print(f"{'Genuine Mean (μ_gen)':<25} | {gen_base.mean():<12.2f} | {gen_aligned.mean():<12.2f} | {(gen_aligned.mean() - gen_base.mean()):+.2f}")
print("-" * 70)
d_prime_label = "Decidability Index (d')"
print(f"{d_prime_label:<25} | {d_prime_base:<12.4f} | {d_prime_aligned:<12.4f} | {(d_prime_aligned - d_prime_base):+.4f}")
print("="*70)

Initializing Data...
Constructed 320 Genuine Pairs.
Constructed 102080 Imposter Pairs.
Randomly sampling 2000 imposter pairs from 102080...

Loading GumNet model...
Successfully loaded weights from ./checkpoints/gumnetap_2d_best_noise_level_0_8x8_200.pth

Running Genuine Pairs...


Genuine: 100%|██████████| 320/320 [01:36<00:00,  3.31it/s]



Running Imposter Pairs...


Imposter: 100%|██████████| 2000/2000 [09:30<00:00,  3.50it/s]


           TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE           
Metric                    | Baseline     | GumNet-AP-2D | Delta     
----------------------------------------------------------------------
Genuine Mean (μ_gen)      | 40.88        | 44.03        | +3.15
----------------------------------------------------------------------
Decidability Index (d')   | 1.4131       | 1.4749       | +0.0618





In [21]:
CHECKPOINT_PATH = './checkpoints/gumnetmp_2d_best_noise_level_0_8x8_200.pth'
print("Initializing Data...")
genuine_pairs, imposter_pairs = construct_tsinghua_pairs(TSINGHUA_DATA_DIR)

if len(imposter_pairs) > 2000:
    print(f"Randomly sampling 2000 imposter pairs from {len(imposter_pairs)}...")
    imposter_pairs = random.sample(imposter_pairs, 2000)

print("\nLoading GumNet model...")
gumnet_model = GumNetMP(in_channels=1, grid_size=GRID_SIZE).to(DEVICE)
if os.path.exists(CHECKPOINT_PATH):
    gumnet_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    print(f"Successfully loaded weights from {CHECKPOINT_PATH}")
else:
    print(f"ERROR: Checkpoint {CHECKPOINT_PATH} not found!")

print("\nRunning Genuine Pairs...")
gen_base, gen_aligned = evaluate_tsinghua_pipeline(gumnet_model, genuine_pairs, desc="Genuine")

print("\nRunning Imposter Pairs...")
imp_base, imp_aligned = evaluate_tsinghua_pipeline(gumnet_model, imposter_pairs, desc="Imposter")

# --- Print CVPR-Ready Results ---
d_prime_base = calculate_d_prime(gen_base, imp_base)
d_prime_aligned = calculate_d_prime(gen_aligned, imp_aligned)

print("\n" + "="*70)
print(f"{'TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE':^70}")
print("="*70)
print(f"{'Metric':<25} | {'Baseline':<12} | {'GumNet-AP-2D':<12} | {'Delta':<10}")
print("-" * 70)
print(f"{'Genuine Mean (μ_gen)':<25} | {gen_base.mean():<12.2f} | {gen_aligned.mean():<12.2f} | {(gen_aligned.mean() - gen_base.mean()):+.2f}")
print("-" * 70)
d_prime_label = "Decidability Index (d')"
print(f"{d_prime_label:<25} | {d_prime_base:<12.4f} | {d_prime_aligned:<12.4f} | {(d_prime_aligned - d_prime_base):+.4f}")
print("="*70)

Initializing Data...
Constructed 320 Genuine Pairs.
Constructed 102080 Imposter Pairs.
Randomly sampling 2000 imposter pairs from 102080...

Loading GumNet model...
Successfully loaded weights from ./checkpoints/gumnetmp_2d_best_noise_level_0_8x8_200.pth

Running Genuine Pairs...


Genuine: 100%|██████████| 320/320 [01:39<00:00,  3.22it/s]



Running Imposter Pairs...


Imposter: 100%|██████████| 2000/2000 [09:44<00:00,  3.42it/s]


           TSINGHUA DATASET: BIOMETRIC MATCHING PERFORMANCE           
Metric                    | Baseline     | GumNet-AP-2D | Delta     
----------------------------------------------------------------------
Genuine Mean (μ_gen)      | 40.88        | 42.90        | +2.01
----------------------------------------------------------------------
Decidability Index (d')   | 1.4306       | 1.4227       | -0.0080



