In [50]:

import torch
import os

ORIGINAL_MODEL_PATH = os.path.abspath("trained_models/best_model.pyt")

# Load checkpoints
checkpoint = torch.load(ORIGINAL_MODEL_PATH, map_location='cpu')

print(f"Checkpoint keys: {list(checkpoint.keys())}")
print(f"Total model parameters: {len(checkpoint['model_state_dict'].keys())}")

tex_weight = checkpoint['model_state_dict']['texture_branch._6_linear.weight']
print("Texture Embedding Layer:")
print(f"   Shape: {tex_weight.shape}")
print(f"   Embedding dimensions: {tex_weight.shape[0]}")
print(f"   Input features: {tex_weight.shape[1]}")

# Check if there's a minutia branch
has_minutia = 'minutia_embedding._4_linear.weight' in checkpoint['model_state_dict']
print(f"   Has minutia branch: {'True' if has_minutia else 'False'}")

if has_minutia:
    min_weight = checkpoint['model_state_dict']['minutia_embedding._4_linear.weight']
    print(f"   Minutia embedding dimensions: {min_weight.shape[0]}")

# Try to determine number of training subjects from loss layer
loss_keys = [k for k in checkpoint['loss_state_dict'].keys() if 'classification' in k or 'centers' in k]
print("Loss function info:")
for key in loss_keys[:5]:  # Show first 5
    print(f"   {key}: {checkpoint['loss_state_dict'][key].shape}")

# Check for center loss centers which would tell us num_subjects
if 'texture_center_loss.centers' in checkpoint['loss_state_dict']:
    centers = checkpoint['loss_state_dict']['texture_center_loss.centers']
    num_training_subjects = centers.shape[0]
    print(f"Detected training subjects: {num_training_subjects}")
    print(f"   (from center loss centers shape: {centers.shape})")
    
print("\n" + "="*60)

Checkpoint keys: ['model_state_dict', 'loss_state_dict', 'optimizer_state_dict']
Total model parameters: 1170
Texture Embedding Layer:
   Shape: torch.Size([256, 1536])
   Embedding dimensions: 256
   Input features: 1536
   Has minutia branch: True
   Minutia embedding dimensions: 256
Loss function info:
   minu_loss_fun.center_loss_fun.centers: torch.Size([8000, 256])
   texture_loss_fun.center_loss_fun.centers: torch.Size([8000, 256])



In [52]:
from flx.extractor.fixed_length_extractor import (
    get_DeepPrint_TexMinu, 
    DeepPrintExtractor
)

extractor: DeepPrintExtractor = get_DeepPrint_TexMinu(num_training_subjects=8000, num_dims=256)

MODEL_DIR: str = os.path.abspath("trained_models/")
extractor.load_best_model(MODEL_DIR)

Loaded best model from /Users/koechian/Documents/Projects/fixed-length-fingerprint-extractors/notebooks/trained_models/best_model.pyt


In [66]:
if 'extractor' in globals() and extractor is not None:

    from flx.data.dataset import *
    from flx.data.image_loader import FVC2004Loader
    from flx.data.transformed_image_loader import TransformedImageLoader
    from flx.image_processing.binarization import LazilyAllocatedBinarizer
    from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size
    
    DATASET_PATH = os.path.abspath("dataset/split")
    
    image_loader = TransformedImageLoader(
        images=FVC2004Loader(DATASET_PATH),
        poses=None, 
        transforms=[
            LazilyAllocatedBinarizer(5.0),
            pad_and_resize_to_deepprint_input_size,
        ],
    )

    dataset = Dataset(image_loader, image_loader.ids)

    texture_embeddings, minutae_embeddings = extractor.extract(dataset)

Created IdentifierSet with 8 subjects and a total of 16 samples.


100%|██████████| 1/1 [00:05<00:00,  5.24s/it]
100%|██████████| 1/1 [00:05<00:00,  5.24s/it]


101 - My Index
102 - My Thumb
103_1 - Nelson's Thumb

In [None]:
import numpy as np
from flx.data.embedding_loader import EmbeddingLoader

# Show sample embedding shape    
sample_id = list(dataset.ids)[0]
sample_embeddings = texture_embeddings.get(sample_id)
print(f"Texture embedding shape: {sample_embeddings.shape}")

# Combine embeddings for full DeepPrint representation
combined_embeddings = EmbeddingLoader.combine(texture_embeddings, minutae_embeddings)
print(f"Combined embedding shape: {combined_embeddings.get(sample_id).shape}")

# Show all available identifiers with their image files
ids_list = list(dataset.ids)
print(f"\n{'='*80}")
print(f"Available Identifiers and Images ({len(ids_list)} total):")
print(f"{'='*80}")
for idx, id in enumerate(ids_list[:20]):
    filepath = image_loader._images._files.get(id)
    filename = filepath.split('/')[-1]
    print(f"[{idx:2d}] ID(subject={id.subject}, impression={id.impression}) -> {filename}")

# Sanity check: Test on a pair using FULL DeepPrint representation (texture + minutiae)
if len(ids_list) >= 10:
    print(f"\n{'='*80}")
    print("Similarity Test (using combined texture + minutiae embeddings):")
    print(f"{'='*80}")
    
    # Choose which identifiers to compare (modify these indices)
    idx1, idx2, idx3 = 0, 1, 4
    id1, id2, id3 = ids_list[idx1], ids_list[idx2], ids_list[idx3]
    
    # Get combined embeddings
    emb1 = combined_embeddings.get(id1)
    emb2 = combined_embeddings.get(id2)
    emb3 = combined_embeddings.get(id3)
    
    # Compute similarities (dot product of concatenated embeddings)
    same_score = np.dot(emb1, emb2)
    diff_score = np.dot(emb1, emb3)
    
    print(f"\nComparing:")
    print(f"  [{idx1}] {image_loader._images._files.get(id1).split('/')[-1]}")
    print(f"  [{idx2}] {image_loader._images._files.get(id2).split('/')[-1]}")
    print(f"  Similarity: {same_score:.6f}")
    
    print(f"\nComparing:")
    print(f"  [{idx1}] {image_loader._images._files.get(id1).split('/')[-1]}")
    print(f"  [{idx3}] {image_loader._images._files.get(id3).split('/')[-1]}")
    print(f"  Similarity: {diff_score:.6f}")


Texture embedding shape: (256,)
Created IdentifierSet with 8 subjects and a total of 16 samples.
Created IdentifierSet with 8 subjects and a total of 16 samples.
Created IdentifierSet with 8 subjects and a total of 16 samples.
Combined embedding shape: (512,)

Available Identifiers and Images (16 total):
[ 0] ID(subject=100, impression=0) -> 101_1.tif
[ 1] ID(subject=100, impression=1) -> 101_2.tif
[ 2] ID(subject=101, impression=0) -> 102_1.tif
[ 3] ID(subject=101, impression=1) -> 102_2.tif
[ 4] ID(subject=102, impression=0) -> 103_1.tif
[ 5] ID(subject=102, impression=1) -> 103_2.tif
[ 6] ID(subject=103, impression=0) -> 104_1.tif
[ 7] ID(subject=103, impression=1) -> 104_2.tif
[ 8] ID(subject=104, impression=0) -> 105_1.tif
[ 9] ID(subject=104, impression=1) -> 105_2.tif
[10] ID(subject=105, impression=0) -> 106_1.tif
[11] ID(subject=105, impression=1) -> 106_2.tif
[12] ID(subject=106, impression=0) -> 107_1.tif
[13] ID(subject=106, impression=1) -> 107_2.tif
[14] ID(subject=107, i

In [None]:
from flx.scripts.generate_benchmarks import create_verification_benchmark
from flx.benchmarks.matchers import CosineSimilarityMatcher
from flx.data.embedding_loader import EmbeddingLoader

ids_list = list(dataset.ids)
unique_subjects = sorted(set(id.subject for id in ids_list))
unique_impressions = sorted(set(id.impression for id in ids_list))

print(f"Dataset has {dataset.num_subjects} subjects")
print(f"Unique subjects: {unique_subjects}")
print(f"Unique impressions: {unique_impressions}")

NUM_IMPRESSIONS_PER_SUBJECT = min(8, len(unique_impressions))
impressions_to_use = unique_impressions[:NUM_IMPRESSIONS_PER_SUBJECT]

print(f"  Subjects: {unique_subjects}")
print(f"  Impressions per subject: {impressions_to_use}")

benchmark = create_verification_benchmark(
    subjects=unique_subjects,
    impressions_per_subject=impressions_to_use)

embeddings = EmbeddingLoader.combine(texture_embeddings, minutae_embeddings)
matcher = CosineSimilarityMatcher(embeddings)

results = benchmark.run(matcher)

print(f"\nEqual-Error-Rate: {results.get_equal_error_rate()}")


Dataset has 10 subjects
Unique subjects: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
Unique impressions: [0, 1, 2, 3, 4, 5, 6, 7]

Creating benchmark with:
  Subjects: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
  Impressions per subject: [0, 1, 2, 3, 4, 5, 6, 7]


100%|██████████| 10/10 [00:00<00:00, 20919.22it/s]
100%|██████████| 10/10 [00:00<00:00, 54971.22it/s]
100%|██████████| 10/10 [00:00<00:00, 48998.88it/s]
100%|██████████| 10/10 [00:00<00:00, 20919.22it/s]
100%|██████████| 10/10 [00:00<00:00, 54971.22it/s]
100%|██████████| 10/10 [00:00<00:00, 48998.88it/s]
100%|██████████| 10/10 [00:00<00:00, 56833.39it/s]
100%|██████████| 10/10 [00:00<00:00, 22028.91it/s]
100%|██████████| 10/10 [00:00<00:00, 23096.39it/s]
100%|██████████| 10/10 [00:00<00:00, 62508.26it/s]
100%|██████████| 10/10 [00:00<00:00, 29937.93it/s]
100%|██████████| 10/10 [00:00<00:00, 22028.91it/s]
100%|██████████| 10/10 [00:00<00:00, 23096.39it/s]
100%|██████████| 10/10 [00:00<00:00, 62508.26it/s]
100%|██████████| 10/10 [00:00<00:00, 29937.93it/s]



Created IdentifierSet with 10 subjects and a total of 80 samples.
Created IdentifierSet with 10 subjects and a total of 80 samples.
Created IdentifierSet with 10 subjects and a total of 80 samples.


100%|██████████| 920/920 [00:00<00:00, 113369.56it/s]


Equal-Error-Rate: 0.125





In [60]:
# Diagnostic: Check similarity score distributions
import matplotlib.pyplot as plt

# Get all genuine (same subject) and impostor (different subject) scores
genuine_scores = []
impostor_scores = []

for comparison in results._results:
    if comparison.comparison.sample1.subject == comparison.comparison.sample2.subject:
        genuine_scores.append(comparison.similarity)
    else:
        impostor_scores.append(comparison.similarity)

genuine_scores = np.array(genuine_scores)
impostor_scores = np.array(impostor_scores)

print(f"\n{'='*80}")
print("Score Distribution Analysis:")
print(f"{'='*80}")
print(f"Genuine comparisons: {len(genuine_scores)}")
print(f"  Mean: {genuine_scores.mean():.4f}")
print(f"  Std:  {genuine_scores.std():.4f}")
print(f"  Min:  {genuine_scores.min():.4f}")
print(f"  Max:  {genuine_scores.max():.4f}")

print(f"\nImpostor comparisons: {len(impostor_scores)}")
print(f"  Mean: {impostor_scores.mean():.4f}")
print(f"  Std:  {impostor_scores.std():.4f}")
print(f"  Min:  {impostor_scores.min():.4f}")
print(f"  Max:  {impostor_scores.max():.4f}")

print(f"\nSeparation (genuine mean - impostor mean): {genuine_scores.mean() - impostor_scores.mean():.4f}")
print(f"Equal-Error-Rate: {results.get_equal_error_rate():.4f}")

# Plot distributions
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(genuine_scores, bins=30, alpha=0.7, label='Genuine', color='green', edgecolor='black')
plt.hist(impostor_scores, bins=30, alpha=0.7, label='Impostor', color='red', edgecolor='black')
plt.xlabel('Similarity Score')
plt.ylabel('Count')
plt.title('Score Distributions')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.hist(genuine_scores, bins=30, alpha=0.7, label='Genuine', color='green', density=True, edgecolor='black')
plt.hist(impostor_scores, bins=30, alpha=0.7, label='Impostor', color='red', density=True, edgecolor='black')
plt.xlabel('Similarity Score')
plt.ylabel('Density')
plt.title('Normalized Distributions')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("  - Good separation: genuine scores should be much higher than impostor scores")
print("  - Overlap indicates discrimination difficulty")
print("  - For good performance: genuine mean should be > 0.8, impostor mean < 0.4")


AttributeError: 'VerificationResult' object has no attribute '_results'