In this tutorial we will learn to:
- Instantiate a DeepPrintExtractor
- Prepare a training dataset
- Train a DeepPrintExtractor

## Instantiate a DeepPrintExtractor

This package implements a number of variants of the DeepPrint architecture. The wrapper class for all these variants is called `DeepPrintExtractor`.
It has a `fit` method to train (and save) the model as well as an `extract` method to extract the DeepPrint features for fingerprint images. 

You can also try to implement your own models, but currently this is not directly supported by the package.

In [2]:
# GPU Detection and Setup
import torch

print("=" * 80)
print("DEVICE DETECTION")
print("=" * 80)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("\n⚠️  WARNING: CUDA not available! Training will run on CPU (VERY SLOW!)")
    print("\nPossible reasons:")
    print("1. PyTorch was installed without CUDA support")
    print("2. NVIDIA drivers are not installed or outdated")
    print("3. CUDA toolkit version mismatch")
    print("\nTo fix:")
    print("  pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
print("=" * 80)

DEVICE DETECTION
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
Number of GPUs: 1
Current GPU: NVIDIA GeForce RTX 2080
GPU Memory: 8.00 GB


## Training the model

To train the model, we first load the training data (see the [data tutorial](./dataset_tutorial.ipynb) for how to implement your own dataset).

Besides the fingerprint images, we also need a mapping from subjects to integer labels (for PyTorch). For some variants we also need minutiae data. To see how a more complex dataset can be loaded, have a look at `flx/setup/datasets.py`.
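As a minimal sketch of what such a mapping involves, each distinct subject can be assigned a contiguous integer starting at 0. This is an illustration in plain Python, not the implementation of `LabelIndex`; the function name `build_label_index` and the example subject IDs are hypothetical.

```python
def build_label_index(subject_ids):
    """Map each distinct subject ID to a contiguous integer label (0..N-1).

    Hypothetical helper for illustration; not part of the flx package.
    """
    # Sort so that the mapping is deterministic across runs.
    unique_subjects = sorted(set(subject_ids))
    return {subject: label for label, subject in enumerate(unique_subjects)}


# One entry per sample; a subject appears once per impression.
sample_subjects = [101, 205, 101, 333, 205]
label_index = build_label_index(sample_subjects)
sample_labels = [label_index[s] for s in sample_subjects]
print(label_index)    # {101: 0, 205: 1, 333: 2}
print(sample_labels)  # [0, 1, 0, 2, 1]
```

Samples of the same subject share a label, so the number of distinct labels equals the number of subjects.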

Finally, we call the `fit` method, which trains the model and saves it to the specified path.

There is also the option to add a validation set, which will be used to evaluate the embeddings during training. This is useful to monitor the training progress and to avoid overfitting.
In this example we will not use a validation set for simplicity.
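If you do want a validation set, one common approach is to split by subject so that no finger appears in both sets. The following is a rough sketch in plain Python; `split_subjects`, the 10% fraction, and the seed are assumptions made for illustration and are not part of the package.

```python
import random


def split_subjects(subject_ids, validation_fraction=0.1, seed=0):
    """Split distinct subject IDs into disjoint training and validation lists.

    Hypothetical helper for illustration; not part of the flx package.
    """
    subjects = sorted(set(subject_ids))
    rng = random.Random(seed)  # local RNG so the split is reproducible
    rng.shuffle(subjects)
    num_validation = round(len(subjects) * validation_fraction)
    validation = sorted(subjects[:num_validation])
    training = sorted(subjects[num_validation:])
    return training, validation


training_subjects, validation_subjects = split_subjects(range(6000))
print(len(training_subjects), len(validation_subjects))  # 5400 600
```

The resulting subject lists would still have to be turned into the dataset objects that `fit` expects; that step depends on the package and is not shown here.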

In [3]:
import os
import torch 

from flx.data.dataset import *
from flx.data.image_loader import SOCOFingLoader, CrossmatchLoader
from flx.data.transformed_image_loader import TransformedImageLoader
from flx.data.label_index import LabelIndex
from flx.image_processing.binarization import LazilyAllocatedBinarizer
from flx.image_processing.augmentation import RandomPoseTransform
from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size
from flx.extractor.fixed_length_extractor import get_DeepPrint_Tex

# Define paths
SOCO_DIR = os.path.abspath("dataset/soco")
CROSSMATCH_DIR = os.path.abspath("dataset/crossmatch")
STAGE1_MODEL_DIR = os.path.abspath("trained_models/stage1_soco")
STAGE2_MODEL_DIR = os.path.abspath("trained_models/stage2_crossmatch")

# Create output directories
os.makedirs(STAGE1_MODEL_DIR, exist_ok=True)
os.makedirs(STAGE2_MODEL_DIR, exist_ok=True)

print("=" * 80)
print("TWO-STAGE TRAINING STRATEGY")
print("=" * 80)
print("\nStage 1: Feature Learning on SOCOFing (6000 subjects, 1 impression each)")
print("Stage 2: Fine-tuning on Crossmatch (77 subjects, 8 impressions each)")
print("\n" + "=" * 80)

TWO-STAGE TRAINING STRATEGY

Stage 1: Feature Learning on SOCOFing (6000 subjects, 1 impression each)
Stage 2: Fine-tuning on Crossmatch (77 subjects, 8 impressions each)



## Stage 1: Feature Learning on SOCOFing

In this stage, we train on the SOCOFing dataset (6000 subjects, 1 impression each).

**Goal**: Learn discriminative features that distinguish between different fingers.

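**Why this works**: Even with only 1 impression per subject, the classification loss teaches the model to extract features that separate different fingerprints, and the random pose augmentation (rotation and translation) applied during loading varies each image between epochs. The large number of subjects (6000) exposes the model to many distinct fingerprints.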
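To put the training log in context, it helps to know what chance level looks like for a 6000-class problem. The arithmetic below is a small sketch; it assumes a classifier that assigns equal probability to every class, and it says nothing about how the package weights or sums its loss terms.

```python
import math

num_classes = 6000

# Cross-entropy of a uniform prediction: -ln(1 / num_classes) = ln(num_classes).
chance_cross_entropy = math.log(num_classes)

# Accuracy when exactly one of 6000 samples is classified correctly,
# which is also the expected accuracy of uniform random guessing.
chance_accuracy = 1 / num_classes

print(f"{chance_cross_entropy:.3f}")  # 8.700
print(f"{chance_accuracy:.6f}")       # 0.000167
```

A logged multiclass accuracy of about 0.000167 on 6000 samples would therefore correspond to a single correctly classified sample.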

In [4]:
# STAGE 1: Load SOCOFing dataset
print("\n--- Loading SOCOFing Dataset ---")
soco_loader = SOCOFingLoader(SOCO_DIR)
print(f"Total samples: {len(soco_loader.ids)}")
print(f"Total subjects: {soco_loader.ids.num_subjects}")
print("Impressions per subject: 1")

# Apply preprocessing and augmentation for training
soco_image_loader = TransformedImageLoader(
    images=soco_loader,
    poses=RandomPoseTransform(),  # Random rotation and translation
    transforms=[
        LazilyAllocatedBinarizer(ridge_width=5.0),  # Gabor filtering + binarization
        pad_and_resize_to_deepprint_input_size
    ]
)

# Create datasets
soco_fingerprints = Dataset(soco_image_loader, soco_loader.ids)
soco_labels = Dataset(LabelIndex(soco_loader.ids), soco_loader.ids)

print(f"\nStage 1 Training Set: {len(soco_fingerprints.ids)} samples")
print(f"Number of classes (subjects): {soco_fingerprints.ids.num_subjects}")


--- Loading SOCOFing Dataset ---
Created IdentifierSet with 6000 subjects and a total of 6000 samples.
Total samples: 6000
Total subjects: 6000
Impressions per subject: 1

Stage 1 Training Set: 6000 samples
Number of classes (subjects): 6000


In [5]:
# STAGE 1: Create and train the model
print("\n--- Stage 1: Training on SOCOFing ---")

stage1_extractor = get_DeepPrint_Tex(
    num_training_subjects=soco_loader.ids.num_subjects,
    num_texture_dims=512
)
print("=" * 80)

stage1_extractor.fit(
    fingerprints=soco_fingerprints,
    minutia_maps=None,
    labels=soco_labels,
    validation_fingerprints=None,  # No validation in Stage 1
    validation_benchmark=None,
    num_epochs=40,
    out_dir=STAGE1_MODEL_DIR
)

print("\n" + "=" * 80)
print("Stage 1 Complete! Model saved to:", STAGE1_MODEL_DIR)
print("=" * 80)


--- Stage 1: Training on SOCOFing ---
Using device cuda:0
No model file found at c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage1_soco\model.pyt


 --- Starting Epoch 1 of 40 ---

Training:


100%|██████████| 750/750 [11:31<00:00,  1.08it/s]



Average Loss: 13.307163186391195
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=1,
    training_loss=13.307163186391195,
    loss_statistics={'crossent_loss_sum': 1.4132321206728617, 'center_loss_sum': 0.25016327770551045},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 2 of 40 ---

Training:


100%|██████████| 750/750 [11:19<00:00,  1.10it/s]



Average Loss: 13.301856862386067
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=2,
    training_loss=13.301856862386067,
    loss_statistics={'crossent_loss_sum': 0.7076156949202219, 'center_loss_sum': 0.12375035820404688},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 3 of 40 ---

Training:


100%|██████████| 750/750 [11:26<00:00,  1.09it/s]



Average Loss: 12.673852755228678
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=3,
    training_loss=12.673852755228678,
    loss_statistics={'crossent_loss_sum': 0.44704227670033775, 'center_loss_sum': 0.08103492149379518},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 4 of 40 ---

Training:


100%|██████████| 750/750 [11:00<00:00,  1.14it/s]



Average Loss: 12.207549285888671
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=4,
    training_loss=12.207549285888671,
    loss_statistics={'crossent_loss_sum': 0.321889443953832, 'center_loss_sum': 0.059596470793088274},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 5 of 40 ---

Training:


100%|██████████| 750/750 [10:58<00:00,  1.14it/s]



Average Loss: 11.89570574315389
Multiclass accuracy: 0.00016666666488163173
TrainingLogEntry(
    epoch=5,
    training_loss=11.89570574315389,
    loss_statistics={'crossent_loss_sum': 0.25063581059773765, 'center_loss_sum': 0.046756832949320475},
    training_accuracy=0.00016666666488163173,
    validation_equal_error_rate=0.00016666666488163173,
}


 --- Starting Epoch 6 of 40 ---

Training:


100%|██████████| 750/750 [10:59<00:00,  1.14it/s]



Average Loss: 11.713527089436848
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=6,
    training_loss=11.713527089436848,
    loss_statistics={'crossent_loss_sum': 0.2058198075029585, 'center_loss_sum': 0.038212006664938396},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 7 of 40 ---

Training:


100%|██████████| 750/750 [10:58<00:00,  1.14it/s]



Average Loss: 11.59484490331014
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=7,
    training_loss=11.59484490331014,
    loss_statistics={'crossent_loss_sum': 0.17493116042727516, 'center_loss_sum': 0.03211964107978912},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 8 of 40 ---

Training:


100%|██████████| 750/750 [10:59<00:00,  1.14it/s]



Average Loss: 11.488104349772135
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=8,
    training_loss=11.488104349772135,
    loss_statistics={'crossent_loss_sum': 0.1519393995006879, 'center_loss_sum': 0.02756223095456759},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 9 of 40 ---

Training:


100%|██████████| 750/750 [10:57<00:00,  1.14it/s]



Average Loss: 11.41110138575236
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=9,
    training_loss=11.41110138575236,
    loss_statistics={'crossent_loss_sum': 0.13446318251998335, 'center_loss_sum': 0.024024336795012157},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 10 of 40 ---

Training:


100%|██████████| 750/750 [10:57<00:00,  1.14it/s]



Average Loss: 11.342518018086752
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=10,
    training_loss=11.342518018086752,
    loss_statistics={'crossent_loss_sum': 0.12057966655095419, 'center_loss_sum': 0.021201808456579842},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 11 of 40 ---

Training:


100%|██████████| 750/750 [10:55<00:00,  1.14it/s]



Average Loss: 11.272408111572265
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=11,
    training_loss=11.272408111572265,
    loss_statistics={'crossent_loss_sum': 0.1091948129047047, 'center_loss_sum': 0.018900733991102737},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 12 of 40 ---

Training:


100%|██████████| 750/750 [10:56<00:00,  1.14it/s]



Average Loss: 11.181955197652181
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=12,
    training_loss=11.181955197652181,
    loss_statistics={'crossent_loss_sum': 0.09949012287457784, 'center_loss_sum': 0.016988577236731847},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 13 of 40 ---

Training:


100%|██████████| 750/750 [10:58<00:00,  1.14it/s]



Average Loss: 11.186284006754557
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=13,
    training_loss=11.186284006754557,
    loss_statistics={'crossent_loss_sum': 0.09218564520126735, 'center_loss_sum': 0.015374777911565243},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 14 of 40 ---

Training:


100%|██████████| 750/750 [10:54<00:00,  1.15it/s]



Average Loss: 11.128980729420979
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=14,
    training_loss=11.128980729420979,
    loss_statistics={'crossent_loss_sum': 0.08536869637171428, 'center_loss_sum': 0.013997202934253784},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 15 of 40 ---

Training:


100%|██████████| 750/750 [10:55<00:00,  1.14it/s]



Average Loss: 11.077694531758626
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=15,
    training_loss=11.077694531758626,
    loss_statistics={'crossent_loss_sum': 0.07950356800291274, 'center_loss_sum': 0.01281055298116472},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 16 of 40 ---

Training:


100%|██████████| 750/750 [10:55<00:00,  1.14it/s]



Average Loss: 11.015957948048909
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=16,
    training_loss=11.015957948048909,
    loss_statistics={'crossent_loss_sum': 0.07428664127985636, 'center_loss_sum': 0.011775530190517506},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 17 of 40 ---

Training:


100%|██████████| 750/750 [10:56<00:00,  1.14it/s]



Average Loss: 10.976946455637615
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=17,
    training_loss=10.976946455637615,
    loss_statistics={'crossent_loss_sum': 0.06984717875835943, 'center_loss_sum': 0.010865662891490787},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 18 of 40 ---

Training:


100%|██████████| 750/750 [10:59<00:00,  1.14it/s]



Average Loss: 10.912006905873616
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=18,
    training_loss=10.912006905873616,
    loss_statistics={'crossent_loss_sum': 0.06571405023998685, 'center_loss_sum': 0.010063775462132912},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 19 of 40 ---

Training:


100%|██████████| 750/750 [11:00<00:00,  1.14it/s]



Average Loss: 10.850717001597086
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=19,
    training_loss=10.850717001597086,
    loss_statistics={'crossent_loss_sum': 0.062037636589585686, 'center_loss_sum': 0.009348659381531833},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 20 of 40 ---

Training:


100%|██████████| 750/750 [10:57<00:00,  1.14it/s]



Average Loss: 10.712154322306315
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=20,
    training_loss=10.712154322306315,
    loss_statistics={'crossent_loss_sum': 0.058238046725591026, 'center_loss_sum': 0.008712917818625768},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 21 of 40 ---

Training:


100%|██████████| 750/750 [10:57<00:00,  1.14it/s]



Average Loss: 10.629973234812418
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=21,
    training_loss=10.629973234812418,
    loss_statistics={'crossent_loss_sum': 0.05513324264496092, 'center_loss_sum': 0.008140407454399836},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 22 of 40 ---

Training:


100%|██████████| 750/750 [12:14<00:00,  1.02it/s]



Average Loss: 10.536871514638264
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=22,
    training_loss=10.536871514638264,
    loss_statistics={'crossent_loss_sum': 0.052246325088269784, 'center_loss_sum': 0.007622263092886318},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 23 of 40 ---

Training:


100%|██████████| 750/750 [11:14<00:00,  1.11it/s]



Average Loss: 10.448148256937662
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=23,
    training_loss=10.448148256937662,
    loss_statistics={'crossent_loss_sum': 0.049630142768224084, 'center_loss_sum': 0.007153271545534548},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 24 of 40 ---

Training:


100%|██████████| 750/750 [11:11<00:00,  1.12it/s]



Average Loss: 10.347637208302816
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=24,
    training_loss=10.347637208302816,
    loss_statistics={'crossent_loss_sum': 0.047167294555240206, 'center_loss_sum': 0.006726649261183209},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 25 of 40 ---

Training:


100%|██████████| 750/750 [11:08<00:00,  1.12it/s]



Average Loss: 10.29520227432251
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=25,
    training_loss=10.29520227432251,
    loss_statistics={'crossent_loss_sum': 0.04513926426887512, 'center_loss_sum': 0.006336747081279755},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 26 of 40 ---

Training:


100%|██████████| 750/750 [11:07<00:00,  1.12it/s]



Average Loss: 10.203918346405029
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=26,
    training_loss=10.203918346405029,
    loss_statistics={'crossent_loss_sum': 0.04307864792224688, 'center_loss_sum': 0.005978651852179796},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 27 of 40 ---

Training:


100%|██████████| 750/750 [11:07<00:00,  1.12it/s]



Average Loss: 10.135078564961752
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=27,
    training_loss=10.135078564961752,
    loss_statistics={'crossent_loss_sum': 0.04127196077652919, 'center_loss_sum': 0.005649699181686213},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 28 of 40 ---

Training:


100%|██████████| 750/750 [11:06<00:00,  1.13it/s]



Average Loss: 10.065642870585124
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=28,
    training_loss=10.065642870585124,
    loss_statistics={'crossent_loss_sum': 0.039588903236956824, 'center_loss_sum': 0.005347002428202402},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 29 of 40 ---

Training:


100%|██████████| 750/750 [11:07<00:00,  1.12it/s]



Average Loss: 9.978020519256592
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=29,
    training_loss=9.978020519256592,
    loss_statistics={'crossent_loss_sum': 0.03794303579440062, 'center_loss_sum': 0.005065673285517199},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 30 of 40 ---

Training:


100%|██████████| 750/750 [11:09<00:00,  1.12it/s]



Average Loss: 9.926965780893962
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=30,
    training_loss=9.926965780893962,
    loss_statistics={'crossent_loss_sum': 0.03655635964075724, 'center_loss_sum': 0.004805997778309716},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 31 of 40 ---

Training:


100%|██████████| 750/750 [11:06<00:00,  1.13it/s]



Average Loss: 9.867165739695231
Multiclass accuracy: 0.00016666666488163173
TrainingLogEntry(
    epoch=31,
    training_loss=9.867165739695231,
    loss_statistics={'crossent_loss_sum': 0.03522280189811543, 'center_loss_sum': 0.004564156707256071},
    training_accuracy=0.00016666666488163173,
    validation_equal_error_rate=0.00016666666488163173,
}


 --- Starting Epoch 32 of 40 ---

Training:


100%|██████████| 750/750 [11:10<00:00,  1.12it/s]



Average Loss: 9.807834962209066
Multiclass accuracy: 0.00016666666488163173
TrainingLogEntry(
    epoch=32,
    training_loss=9.807834962209066,
    loss_statistics={'crossent_loss_sum': 0.033973124347627164, 'center_loss_sum': 0.004338730967914065},
    training_accuracy=0.00016666666488163173,
    validation_equal_error_rate=0.00016666666488163173,
}


 --- Starting Epoch 33 of 40 ---

Training:


100%|██████████| 750/750 [11:08<00:00,  1.12it/s]



Average Loss: 9.743643268585204
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=33,
    training_loss=9.743643268585204,
    loss_statistics={'crossent_loss_sum': 0.03277742728079208, 'center_loss_sum': 0.0041303123500612045},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 34 of 40 ---

Training:


100%|██████████| 750/750 [11:06<00:00,  1.12it/s]



Average Loss: 9.6978664894104
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=34,
    training_loss=9.6978664894104,
    loss_statistics={'crossent_loss_sum': 0.03172108352418039, 'center_loss_sum': 0.003932837371732675},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 35 of 40 ---

Training:


100%|██████████| 750/750 [11:14<00:00,  1.11it/s]



Average Loss: 9.648100273132325
Multiclass accuracy: 0.00016666666488163173
TrainingLogEntry(
    epoch=35,
    training_loss=9.648100273132325,
    loss_statistics={'crossent_loss_sum': 0.030708119603565762, 'center_loss_sum': 0.0037493814028444743},
    training_accuracy=0.00016666666488163173,
    validation_equal_error_rate=0.00016666666488163173,
}


 --- Starting Epoch 36 of 40 ---

Training:


100%|██████████| 750/750 [11:08<00:00,  1.12it/s]



Average Loss: 9.597571629842122
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=36,
    training_loss=9.597571629842122,
    loss_statistics={'crossent_loss_sum': 0.029747646636433072, 'center_loss_sum': 0.0035772548803576715},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 37 of 40 ---

Training:


100%|██████████| 750/750 [11:09<00:00,  1.12it/s]



Average Loss: 9.529968819936116
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=37,
    training_loss=9.529968819936116,
    loss_statistics={'crossent_loss_sum': 0.02877975936623307, 'center_loss_sum': 0.0034160812207170437},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 38 of 40 ---

Training:


100%|██████████| 750/750 [11:09<00:00,  1.12it/s]



Average Loss: 9.495313529968263
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=38,
    training_loss=9.495313529968263,
    loss_statistics={'crossent_loss_sum': 0.027970371911400244, 'center_loss_sum': 0.003264212076601229},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 39 of 40 ---

Training:


100%|██████████| 750/750 [11:09<00:00,  1.12it/s]



Average Loss: 9.431253532409668
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=39,
    training_loss=9.431253532409668,
    loss_statistics={'crossent_loss_sum': 0.027107792183884188, 'center_loss_sum': 0.0031205845131323886},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}


 --- Starting Epoch 40 of 40 ---

Training:


100%|██████████| 750/750 [11:09<00:00,  1.12it/s]



Average Loss: 9.396391512552897
Multiclass accuracy: 0.0
TrainingLogEntry(
    epoch=40,
    training_loss=9.396391512552897,
    loss_statistics={'crossent_loss_sum': 0.02637725165486336, 'center_loss_sum': 0.002986471762508154},
    training_accuracy=0.0,
    validation_equal_error_rate=0.0,
}

Stage 1 Complete! Model saved to: c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage1_soco



## Stage 2: Fine-tuning on Crossmatch

In this stage, we load the trained model from Stage 1 and fine-tune it on the Crossmatch dataset (77 subjects, 8 impressions each).

**Goal**: Learn intra-class similarity, i.e. that multiple impressions of the same finger should map to similar embeddings.

**Why this works**: The center loss now has multiple impressions per subject to work with, so it can pull embeddings of the same finger closer together. The model already knows how to extract discriminative features from Stage 1.
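
To make the center-loss argument concrete, here is a minimal pure-Python sketch. It assumes the common formulation (half the mean squared distance between each embedding and its class center) and, as a simplification, uses the class mean as the center; real implementations usually keep learned or running centers, and the package's own loss may differ. The function names are illustrative and not part of `flx`.

```python
from collections import defaultdict


def class_centers(embeddings, labels):
    """Mean embedding per label (illustrative helper, not part of flx)."""
    grouped = defaultdict(list)
    for vec, label in zip(embeddings, labels):
        grouped[label].append(vec)
    return {
        label: [sum(dim) / len(vecs) for dim in zip(*vecs)]
        for label, vecs in grouped.items()
    }


def center_loss(embeddings, labels):
    """Half the mean squared distance of each embedding to its class center."""
    centers = class_centers(embeddings, labels)
    total = 0.0
    for vec, label in zip(embeddings, labels):
        total += sum((x - c) ** 2 for x, c in zip(vec, centers[label]))
    return 0.5 * total / len(embeddings)


# One impression per subject: each embedding is its own center, so there is
# no intra-class spread to penalize.
single = center_loss([[0.0, 1.0], [1.0, 0.0]], labels=[0, 1])

# Two impressions per subject: the loss now measures intra-class spread.
multi = center_loss(
    [[0.0, 1.0], [0.2, 0.8], [1.0, 0.0], [0.8, 0.2]],
    labels=[0, 0, 1, 1],
)
print(single, multi)
```

With a single impression per class the spread term vanishes in this simplified version, which is why the multi-impression Crossmatch data gives the center loss something to optimize.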

In [6]:
# STAGE 2: Load Crossmatch dataset
print("\n--- Loading Crossmatch Dataset ---")
crossmatch_loader = CrossmatchLoader(CROSSMATCH_DIR)
print(f"Total samples: {len(crossmatch_loader.ids)}")
print(f"Total subjects: {crossmatch_loader.ids.num_subjects}")

# Analyze impression distribution
from collections import defaultdict
impression_counts = defaultdict(int)
for identifier in crossmatch_loader.ids:
    impression_counts[identifier.subject] += 1

avg_impressions = sum(impression_counts.values()) / len(impression_counts)
print(f"Average impressions per subject: {avg_impressions:.1f}")

# Apply preprocessing and augmentation
crossmatch_image_loader = TransformedImageLoader(
    images=crossmatch_loader,
    poses=RandomPoseTransform(),  # Same augmentation as Stage 1
    transforms=[
        LazilyAllocatedBinarizer(ridge_width=5.0),
        pad_and_resize_to_deepprint_input_size
    ]
)

# Create datasets
crossmatch_fingerprints = Dataset(crossmatch_image_loader, crossmatch_loader.ids)
crossmatch_labels = Dataset(LabelIndex(crossmatch_loader.ids), crossmatch_loader.ids)

print(f"\nStage 2 Training Set: {len(crossmatch_fingerprints.ids)} samples")
print(f"Number of classes (subjects): {crossmatch_fingerprints.ids.num_subjects}")


--- Loading Crossmatch Dataset ---
Created IdentifierSet with 77 subjects and a total of 616 samples.
Total samples: 616
Total subjects: 77
Average impressions per subject: 8.0

Stage 2 Training Set: 616 samples
Number of classes (subjects): 77
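
The next cell transfers Stage 1 weights by keeping only entries whose name and shape match the Stage 2 model. The sketch below isolates that filtering idea using plain dictionaries of shapes in place of tensors; the layer names and sizes are made up for illustration and do not reflect the real DeepPrint layers.

```python
def filter_matching(pretrained, target):
    """Keep pretrained entries whose key and shape both match the target."""
    return {
        name: shape
        for name, shape in pretrained.items()
        if name in target and shape == target[name]
    }


# Shapes stand in for tensors; names and sizes are hypothetical.
stage1_shapes = {
    "backbone.conv.weight": (64, 1, 3, 3),
    "embedding.weight": (512, 1024),
    "classifier.weight": (6000, 512),  # 6000 Stage 1 classes
}
stage2_shapes = {
    "backbone.conv.weight": (64, 1, 3, 3),
    "embedding.weight": (512, 1024),
    "classifier.weight": (77, 512),  # 77 Stage 2 classes
}

transferable = filter_matching(stage1_shapes, stage2_shapes)
skipped = sorted(set(stage2_shapes) - set(transferable))
print(f"Transferred {len(transferable)}/{len(stage2_shapes)}; skipped: {skipped}")
```

With real tensors the same comprehension is applied to the two `state_dict()` mappings, and the result is passed to `load_state_dict(..., strict=False)` so that the skipped head keeps its fresh initialization.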


In [8]:
# STAGE 2: Load Stage 1 model and fine-tune
print("\n--- Stage 2: Fine-tuning on Crossmatch ---")

# Create new extractor with Crossmatch dimensions
stage2_extractor = get_DeepPrint_Tex(
    num_training_subjects=crossmatch_loader.ids.num_subjects,
    num_texture_dims=512
)

# Load the best model from Stage 1 (only feature extractor, not classification head)
print(f"Loading best model from Stage 1: {STAGE1_MODEL_DIR}")
from flx.setup.paths import get_best_model_file
import torch
from flx.models.torch_helpers import get_device

model_path = get_best_model_file(STAGE1_MODEL_DIR)
checkpoint = torch.load(model_path, map_location=get_device())
stage1_state = checkpoint["model_state_dict"]

# Load only the feature extractor layers (exclude classification head with mismatched size)
stage2_state = stage2_extractor.model.state_dict()
pretrained_dict = {k: v for k, v in stage1_state.items() 
                   if k in stage2_state and v.shape == stage2_state[k].shape}

print(f"Loaded {len(pretrained_dict)}/{len(stage2_state)} layers from Stage 1")
print(f"Skipped classification head (different classes: {soco_loader.ids.num_subjects} -> {crossmatch_loader.ids.num_subjects})")
stage2_extractor.model.load_state_dict(pretrained_dict, strict=False)

print(f"\nModel: DeepPrint_Tex_512 (pre-trained from Stage 1)")
print(f"Learning rate: 0.025 (default - consider reducing if unstable)")
print(f"Number of epochs: 30")
print(f"Expected time: ~5-10 min/epoch on GPU (smaller dataset)")
print("\nStarting fine-tuning...")
print("Monitor: Center loss should decrease as model learns intra-class similarity")
print("=" * 80)

stage2_extractor.fit(
    fingerprints=crossmatch_fingerprints,
    minutia_maps=None,
    labels=crossmatch_labels,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=30,
    out_dir=STAGE2_MODEL_DIR
)

print("\n" + "=" * 80)
print("Stage 2 Complete! Final model saved to:", STAGE2_MODEL_DIR)
print("=" * 80)
print("\nTraining Summary:")
print(f"  Stage 1: Trained on {len(soco_fingerprints.ids)} SOCOFing samples")
print(f"  Stage 2: Fine-tuned on {len(crossmatch_fingerprints.ids)} Crossmatch samples")
print(f"  Final model location: {STAGE2_MODEL_DIR}/best_model.pyt")
print("\nNext steps:")
print("  1. Evaluate the model on held-out test data")
print("  2. Extract embeddings for your fingerprint database")
print("  3. Use for fingerprint matching/verification")
print("=" * 80)


--- Stage 2: Fine-tuning on Crossmatch ---
Loading best model from Stage 1: c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage1_soco




Loaded 896/898 layers from Stage 1
Skipped classification head (different classes: 6000 -> 77)

Model: DeepPrint_Tex_512 (pre-trained from Stage 1)
Learning rate: 0.025 (default - consider reducing if unstable)
Number of epochs: 30
Expected time: ~5-10 min/epoch on GPU (smaller dataset)

Starting fine-tuning...
Monitor: Center loss should decrease as model learns intra-class similarity
Using device cuda:0
No model file found at c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage2_crossmatch\model.pyt


 --- Starting Epoch 1 of 30 ---

Training:


100%|██████████| 77/77 [01:17<00:00,  1.01s/it]



Average Loss: 6.1834256803834595
Multiclass accuracy: 0.027597401291131973
TrainingLogEntry(
    epoch=1,
    training_loss=6.1834256803834595,
    loss_statistics={'crossent_loss_sum': 0.5406462008302862, 'center_loss_sum': 0.23228201192694825},
    training_accuracy=0.027597401291131973,
    validation_equal_error_rate=0.027597401291131973,
}


 --- Starting Epoch 2 of 30 ---

Training:


100%|██████████| 77/77 [01:15<00:00,  1.02it/s]



Average Loss: 5.627669972258729
Multiclass accuracy: 0.02110389620065689
TrainingLogEntry(
    epoch=2,
    training_loss=5.627669972258729,
    loss_statistics={'crossent_loss_sum': 0.2529046725143086, 'center_loss_sum': 0.09882470046157961},
    training_accuracy=0.02110389620065689,
    validation_equal_error_rate=0.02110389620065689,
}


 --- Starting Epoch 3 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.04it/s]


Average Loss: 5.3322619401015245
Multiclass accuracy: 0.01461038924753666
TrainingLogEntry(
    epoch=3,
    training_loss=5.3322619401015245,
    loss_statistics={'crossent_loss_sum': 0.1660195840643598, 'center_loss_sum': 0.05615799528953833},
    training_accuracy=0.01461038924753666,
    validation_equal_error_rate=0.01461038924753666,
}


 --- Starting Epoch 4 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.04it/s]



Average Loss: 5.040047719881132
Multiclass accuracy: 0.0178571417927742
TrainingLogEntry(
    epoch=4,
    training_loss=5.040047719881132,
    loss_statistics={'crossent_loss_sum': 0.12159547703219699, 'center_loss_sum': 0.03590601600416295},
    training_accuracy=0.0178571417927742,
    validation_equal_error_rate=0.0178571417927742,
}


 --- Starting Epoch 5 of 30 ---

Training:


100%|██████████| 77/77 [01:12<00:00,  1.06it/s]



Average Loss: 4.8191471533341845
Multiclass accuracy: 0.027597403153777122
TrainingLogEntry(
    epoch=5,
    training_loss=4.8191471533341845,
    loss_statistics={'crossent_loss_sum': 0.09600780219226689, 'center_loss_sum': 0.024470877105539496},
    training_accuracy=0.027597403153777122,
    validation_equal_error_rate=0.027597403153777122,
}


 --- Starting Epoch 6 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 4.70989719613806
Multiclass accuracy: 0.025974024087190628
TrainingLogEntry(
    epoch=6,
    training_loss=4.70989719613806,
    loss_statistics={'crossent_loss_sum': 0.08066625267396242, 'center_loss_sum': 0.01745660546602625},
    training_accuracy=0.025974024087190628,
    validation_equal_error_rate=0.025974024087190628,
}


 --- Starting Epoch 7 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.04it/s]


Average Loss: 4.592771852171266
Multiclass accuracy: 0.037337660789489746
TrainingLogEntry(
    epoch=7,
    training_loss=4.592771852171266,
    loss_statistics={'crossent_loss_sum': 0.06924410349584025, 'center_loss_sum': 0.012769679495708841},
    training_accuracy=0.037337660789489746,
    validation_equal_error_rate=0.037337660789489746,
}


 --- Starting Epoch 8 of 30 ---

Training:


100%|██████████| 77/77 [01:15<00:00,  1.02it/s]



Average Loss: 4.40889319816193
Multiclass accuracy: 0.037337660789489746
TrainingLogEntry(
    epoch=8,
    training_loss=4.40889319816193,
    loss_statistics={'crossent_loss_sum': 0.05932705888113418, 'center_loss_sum': 0.00956189732805088},
    training_accuracy=0.037337660789489746,
    validation_equal_error_rate=0.037337660789489746,
}


 --- Starting Epoch 9 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.04it/s]



Average Loss: 4.239085494697868
Multiclass accuracy: 0.034090910106897354
TrainingLogEntry(
    epoch=9,
    training_loss=4.239085494697868,
    loss_statistics={'crossent_loss_sum': 0.05160559337548535, 'center_loss_sum': 0.0072705942443954995},
    training_accuracy=0.034090910106897354,
    validation_equal_error_rate=0.034090910106897354,
}


 --- Starting Epoch 10 of 30 ---

Training:


100%|██████████| 77/77 [01:12<00:00,  1.06it/s]



Average Loss: 4.24345748145859
Multiclass accuracy: 0.03896103799343109
TrainingLogEntry(
    epoch=10,
    training_loss=4.24345748145859,
    loss_statistics={'crossent_loss_sum': 0.04744439624346696, 'center_loss_sum': 0.00559882220703286},
    training_accuracy=0.03896103799343109,
    validation_equal_error_rate=0.03896103799343109,
}


 --- Starting Epoch 11 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 4.036047037545737
Multiclass accuracy: 0.04220779240131378
TrainingLogEntry(
    epoch=11,
    training_loss=4.036047037545737,
    loss_statistics={'crossent_loss_sum': 0.04152193095495456, 'center_loss_sum': 0.004342239891061254},
    training_accuracy=0.04220779240131378,
    validation_equal_error_rate=0.04220779240131378,
}


 --- Starting Epoch 12 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]



Average Loss: 4.089575352606835
Multiclass accuracy: 0.02922077849507332
TrainingLogEntry(
    epoch=12,
    training_loss=4.089575352606835,
    loss_statistics={'crossent_loss_sum': 0.03916508320606116, 'center_loss_sum': 0.0034346601954012208},
    training_accuracy=0.02922077849507332,
    validation_equal_error_rate=0.02922077849507332,
}


 --- Starting Epoch 13 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 4.012444480673059
Multiclass accuracy: 0.0357142835855484
TrainingLogEntry(
    epoch=13,
    training_loss=4.012444480673059,
    loss_statistics={'crossent_loss_sum': 0.03587186205518115, 'center_loss_sum': 0.0027093349506596585},
    training_accuracy=0.0357142835855484,
    validation_equal_error_rate=0.0357142835855484,
}


 --- Starting Epoch 14 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.04it/s]



Average Loss: 3.8905168477590983
Multiclass accuracy: 0.03896103799343109
TrainingLogEntry(
    epoch=14,
    training_loss=3.8905168477590983,
    loss_statistics={'crossent_loss_sum': 0.03257729140747898, 'center_loss_sum': 0.0021594661617986787},
    training_accuracy=0.03896103799343109,
    validation_equal_error_rate=0.03896103799343109,
}


 --- Starting Epoch 15 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]



Average Loss: 3.89661405612896
Multiclass accuracy: 0.040584415197372437
TrainingLogEntry(
    epoch=15,
    training_loss=3.89661405612896,
    loss_statistics={'crossent_loss_sum': 0.0307388762374977, 'center_loss_sum': 0.0017329075942178825},
    training_accuracy=0.040584415197372437,
    validation_equal_error_rate=0.040584415197372437,
}


 --- Starting Epoch 16 of 30 ---

Training:


100%|██████████| 77/77 [01:15<00:00,  1.02it/s]



Average Loss: 3.7970824148747826
Multiclass accuracy: 0.04545454680919647
TrainingLogEntry(
    epoch=16,
    training_loss=3.7970824148747826,
    loss_statistics={'crossent_loss_sum': 0.028265476783181167, 'center_loss_sum': 0.0013992297296811426},
    training_accuracy=0.04545454680919647,
    validation_equal_error_rate=0.04545454680919647,
}


 --- Starting Epoch 17 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.871733374409861
Multiclass accuracy: 0.04220779240131378
TrainingLogEntry(
    epoch=17,
    training_loss=3.871733374409861,
    loss_statistics={'crossent_loss_sum': 0.02732581551335439, 'center_loss_sum': 0.001142812211200242},
    training_accuracy=0.04220779240131378,
    validation_equal_error_rate=0.04220779240131378,
}


 --- Starting Epoch 18 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]



Average Loss: 3.7620215632698755
Multiclass accuracy: 0.045454543083906174
TrainingLogEntry(
    epoch=18,
    training_loss=3.7620215632698755,
    loss_statistics={'crossent_loss_sum': 0.02519952825137547, 'center_loss_sum': 0.000925621533871213},
    training_accuracy=0.045454543083906174,
    validation_equal_error_rate=0.045454543083906174,
}


 --- Starting Epoch 19 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.04it/s]



Average Loss: 3.7918136181769433
Multiclass accuracy: 0.037337664514780045
TrainingLogEntry(
    epoch=19,
    training_loss=3.7918136181769433,
    loss_statistics={'crossent_loss_sum': 0.02418102015392958, 'center_loss_sum': 0.0007651221428527379},
    training_accuracy=0.037337664514780045,
    validation_equal_error_rate=0.037337664514780045,
}


 --- Starting Epoch 20 of 30 ---

Training:


100%|██████████| 77/77 [01:20<00:00,  1.04s/it]



Average Loss: 3.6860181542185995
Multiclass accuracy: 0.04220779240131378
TrainingLogEntry(
    epoch=20,
    training_loss=3.6860181542185995,
    loss_statistics={'crossent_loss_sum': 0.022409342132605516, 'center_loss_sum': 0.000628271306465779},
    training_accuracy=0.04220779240131378,
    validation_equal_error_rate=0.04220779240131378,
}


 --- Starting Epoch 21 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.694410534648152
Multiclass accuracy: 0.03246753290295601
TrainingLogEntry(
    epoch=21,
    training_loss=3.694410534648152,
    loss_statistics={'crossent_loss_sum': 0.021476932952782957, 'center_loss_sum': 0.0005136060309017523},
    training_accuracy=0.03246753290295601,
    validation_equal_error_rate=0.03246753290295601,
}


 --- Starting Epoch 22 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.7278044007041236
Multiclass accuracy: 0.05032467469573021
TrainingLogEntry(
    epoch=22,
    training_loss=3.7278044007041236,
    loss_statistics={'crossent_loss_sum': 0.02074538306756453, 'center_loss_sum': 0.0004353238040979384},
    training_accuracy=0.05032467469573021,
    validation_equal_error_rate=0.05032467469573021,
}


 --- Starting Epoch 23 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]



Average Loss: 3.6671009063720703
Multiclass accuracy: 0.048701297491788864
TrainingLogEntry(
    epoch=23,
    training_loss=3.6671009063720703,
    loss_statistics={'crossent_loss_sum': 0.01957340273461323, 'center_loss_sum': 0.00035649344886698876},
    training_accuracy=0.048701297491788864,
    validation_equal_error_rate=0.048701297491788864,
}


 --- Starting Epoch 24 of 30 ---

Training:


100%|██████████| 77/77 [01:15<00:00,  1.02it/s]



Average Loss: 3.6331436045758134
Multiclass accuracy: 0.05032467469573021
TrainingLogEntry(
    epoch=24,
    training_loss=3.6331436045758134,
    loss_statistics={'crossent_loss_sum': 0.018619534757101174, 'center_loss_sum': 0.00030308818692559156},
    training_accuracy=0.05032467469573021,
    validation_equal_error_rate=0.05032467469573021,
}


 --- Starting Epoch 25 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.04it/s]



Average Loss: 3.609228988746544
Multiclass accuracy: 0.0535714253783226
TrainingLogEntry(
    epoch=25,
    training_loss=3.609228988746544,
    loss_statistics={'crossent_loss_sum': 0.017789731273403414, 'center_loss_sum': 0.0002564136616208337},
    training_accuracy=0.0535714253783226,
    validation_equal_error_rate=0.0535714253783226,
}


 --- Starting Epoch 26 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.652727845427278
Multiclass accuracy: 0.04220778867602348
TrainingLogEntry(
    epoch=26,
    training_loss=3.652727845427278,
    loss_statistics={'crossent_loss_sum': 0.01734536777545403, 'center_loss_sum': 0.00021582385259945612},
    training_accuracy=0.04220778867602348,
    validation_equal_error_rate=0.04220778867602348,
}


 --- Starting Epoch 27 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.6056313576636376
Multiclass accuracy: 0.04870130121707916
TrainingLogEntry(
    epoch=27,
    training_loss=3.6056313576636376,
    loss_statistics={'crossent_loss_sum': 0.016510419171265882, 'center_loss_sum': 0.00018231865191521297},
    training_accuracy=0.04870130121707916,
    validation_equal_error_rate=0.04870130121707916,
}


 --- Starting Epoch 28 of 30 ---

Training:


100%|██████████| 77/77 [01:19<00:00,  1.03s/it]



Average Loss: 3.608847066953585
Multiclass accuracy: 0.034090906381607056
TrainingLogEntry(
    epoch=28,
    training_loss=3.608847066953585,
    loss_statistics={'crossent_loss_sum': 0.015952076368960025, 'center_loss_sum': 0.00015884808405926227},
    training_accuracy=0.034090906381607056,
    validation_equal_error_rate=0.034090906381607056,
}


 --- Starting Epoch 29 of 30 ---

Training:


100%|██████████| 77/77 [01:13<00:00,  1.05it/s]



Average Loss: 3.5670240947178433
Multiclass accuracy: 0.04220779240131378
TrainingLogEntry(
    epoch=29,
    training_loss=3.5670240947178433,
    loss_statistics={'crossent_loss_sum': 0.015237351116235087, 'center_loss_sum': 0.00013775268900744113},
    training_accuracy=0.04220779240131378,
    validation_equal_error_rate=0.04220779240131378,
}


 --- Starting Epoch 30 of 30 ---

Training:


100%|██████████| 77/77 [01:14<00:00,  1.03it/s]



Average Loss: 3.558208886679117
Multiclass accuracy: 0.04707792028784752
TrainingLogEntry(
    epoch=30,
    training_loss=3.558208886679117,
    loss_statistics={'crossent_loss_sum': 0.014708289000895117, 'center_loss_sum': 0.00011758134323394014},
    training_accuracy=0.04707792028784752,
    validation_equal_error_rate=0.04707792028784752,
}

Stage 2 Complete! Final model saved to: c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage2_crossmatch

Training Summary:
  Stage 1: Trained on 6000 SOCOFing samples
  Stage 2: Fine-tuned on 616 Crossmatch samples
  Final model location: c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\stage2_crossmatch/best_model.pyt

Next steps:
  1. Evaluate the model on held-out test data
  2. Extract embeddings for your fingerprint database
  3. Use for fingerprint matching/verification
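
The first next step, evaluation on held-out data, is usually summarized by the equal error rate (EER). As a rough sketch of what that number means, the following pure-Python function approximates the EER from lists of genuine (same finger) and impostor (different finger) similarity scores. It is not the package's benchmark code; the function name and the toy scores are assumptions for illustration.

```python
def equal_error_rate(genuine, impostor):
    """Approximate EER from similarity scores (higher means more similar)."""
    best_gap, eer = None, None
    # Try every observed score as the accept threshold.
    for threshold in sorted(set(genuine) | set(impostor)):
        # False accept rate: impostor pairs scoring at or above the threshold.
        far = sum(s >= threshold for s in impostor) / len(impostor)
        # False reject rate: genuine pairs scoring below the threshold.
        frr = sum(s < threshold for s in genuine) / len(genuine)
        gap = abs(far - frr)
        if best_gap is None or gap < best_gap:
            best_gap, eer = gap, (far + frr) / 2
    return eer


# Toy scores, e.g. cosine similarities between embedding pairs.
genuine_scores = [0.9, 0.8, 0.75, 0.4]
impostor_scores = [0.1, 0.3, 0.5, 0.2]
eer = equal_error_rate(genuine_scores, impostor_scores)
print(f"EER: {eer:.2f}")
```

Lower is better: an EER of 0 means some threshold separates genuine from impostor pairs perfectly.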


## Optional: Quick Test Before Full Training

Before running the full 70-epoch training (40 + 30), you can test with fewer epochs to verify everything works:
- Stage 1: 5 epochs on SOCOFing
- Stage 2: 5 epochs on Crossmatch

This will take ~1-2 hours and help you catch any issues early.

In [6]:
# QUICK TEST: runs both stages with reduced epochs.
# Skip this cell (or comment it out) when doing the full training above.

TEST_STAGE1_DIR = os.path.abspath("trained_models/test_stage1")
TEST_STAGE2_DIR = os.path.abspath("trained_models/test_stage2")
os.makedirs(TEST_STAGE1_DIR, exist_ok=True)
os.makedirs(TEST_STAGE2_DIR, exist_ok=True)

# Quick test Stage 1
test_extractor1 = get_DeepPrint_Tex(
    num_training_subjects=soco_loader.ids.num_subjects,
    num_texture_dims=512
)
test_extractor1.fit(
    fingerprints=soco_fingerprints,
    minutia_maps=None,
    labels=soco_labels,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=5,  # Just 5 epochs for testing
    out_dir=TEST_STAGE1_DIR
)

# Quick test Stage 2
test_extractor2 = get_DeepPrint_Tex(
    num_training_subjects=crossmatch_loader.ids.num_subjects,
    num_texture_dims=512
)
test_extractor2.load_best_model(TEST_STAGE1_DIR)
test_extractor2.fit(
    fingerprints=crossmatch_fingerprints,
    minutia_maps=None,
    labels=crossmatch_labels,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=5,  # Just 5 epochs for testing
    out_dir=TEST_STAGE2_DIR
)

print("\nQuick test complete! If successful, run the full training above.")

Using device cuda:0
Loaded existing model from c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\test_stage1\model.pyt
Loaded existing model from c:\Users\koechian\Documents\Projects\fixed-length-fingerprint-extractors\notebooks\trained_models\test_stage1\model.pyt


  checkpoint = torch.load(full_param_path, map_location=get_device())




 --- Starting Epoch 4 of 5 ---

Training:


  0%|          | 0/750 [00:33<?, ?it/s]



KeyboardInterrupt: 

## Important Notes

### Training Time Expectations
- **Stage 1 (40 epochs)**: ~7-10 hours on GPU, ~40-60 hours on CPU
- **Stage 2 (30 epochs)**: ~2.5-5 hours on GPU, ~15-25 hours on CPU
- **Total**: ~10-15 hours on GPU
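
As a rough sanity check on these figures, the training time of a stage can be estimated from two numbers the progress bar prints: batches per epoch and iterations per second. The sketch below is a back-of-the-envelope calculation only. `estimate_training_hours` is a hypothetical helper, not part of the package, and it ignores validation, checkpointing, and data-loading overhead. The example assumes 750 batches per epoch (the Stage 1 progress bar total shown in the quick test output) and a throughput of about 1 it/s, which depends on your GPU.

```python
def estimate_training_hours(batches_per_epoch: int, its_per_second: float, num_epochs: int) -> float:
    """Rough wall-clock estimate (in hours) for the training loop alone."""
    if its_per_second <= 0:
        raise ValueError("its_per_second must be positive")
    seconds = batches_per_epoch / its_per_second * num_epochs
    return seconds / 3600


# Assumed values: 750 batches per epoch, about 1 it/s, 40 epochs.
stage1_hours = estimate_training_hours(batches_per_epoch=750, its_per_second=1.0, num_epochs=40)
print(f"Stage 1 estimate: {stage1_hours:.1f} h")
```

Under these assumptions the estimate lands inside the 7-10 hour range quoted for Stage 1. Substitute the values from your own progress bar to get a figure for your hardware.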

### What to Monitor During Training

**Stage 1 Metrics:**
- `training_loss`: Should steadily decrease
- `training_accuracy`: Should increase (aim for >90%)
- `crossent_loss_sum` (in `loss_statistics`): Classification loss; should decrease
- `center_loss_sum` (in `loss_statistics`): May not decrease much, since Stage 1 has only 1 impression per subject

**Stage 2 Metrics:**
- `training_loss`: Should continue to decrease
- `training_accuracy`: Should remain high or increase
- `center_loss_sum` (in `loss_statistics`): Should decrease significantly, since Stage 2 has multiple impressions per subject. This is where the model learns intra-class similarity.
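
To check these trends after a run, the per-epoch values can be compared programmatically. The sketch below assumes you have collected one metric into a plain list of floats, one value per epoch (for example by copying them from the `TrainingLogEntry` output printed during training). `trend` is a hypothetical helper, not part of the package, and the example values are made up for illustration.

```python
from statistics import mean
from typing import List


def trend(values: List[float], window: int = 3, tolerance: float = 0.01) -> str:
    """Compare the mean of the first and last `window` epochs.

    Returns "decreasing", "increasing", or "flat" when the relative
    change stays within `tolerance`.
    """
    if len(values) < 2 * window:
        raise ValueError("need at least 2 * window values")
    start = mean(values[:window])
    end = mean(values[-window:])
    denominator = abs(start) or 1.0  # avoid division by zero
    change = (end - start) / denominator
    if change < -tolerance:
        return "decreasing"
    if change > tolerance:
        return "increasing"
    return "flat"


# Made-up per-epoch losses, for illustration only.
example_losses = [5.2, 4.8, 4.5, 4.1, 3.9, 3.7]
print(trend(example_losses))
```

Averaging over a small window rather than comparing single epochs keeps one noisy epoch from flipping the result.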

### Troubleshooting

**If training loss increases:**
- The learning rate might be too high
- Try reducing it from 0.025 to 0.01 in `flx/setup/config.py`

**If accuracy plateaus early:**
- A plateau around 85-95% is normal for Stage 1
- Accuracy should improve in Stage 2 with fine-tuning

**If you run out of memory:**
- Reduce batch size in the dataloader settings
- Use a smaller embedding size (e.g., 256 instead of 512)

### After Training

The trained models are saved in:
- Stage 1: `trained_models/stage1_soco/best_model.pyt`
- Stage 2: `trained_models/stage2_crossmatch/best_model.pyt`

Use the Stage 2 model for production: it combines discriminative features with intra-class similarity.
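
Before loading a model for evaluation or embedding extraction, it can help to confirm that the checkpoint files actually exist. The sketch below uses only the standard library and the two relative paths listed above. `find_checkpoints` is a hypothetical helper, not part of the package, and the base directory is assumed to be the notebook's working directory.

```python
from pathlib import Path
from typing import Dict, Optional

# Relative checkpoint paths as listed above.
EXPECTED_CHECKPOINTS = {
    "stage1": Path("trained_models/stage1_soco/best_model.pyt"),
    "stage2": Path("trained_models/stage2_crossmatch/best_model.pyt"),
}


def find_checkpoints(base_dir: str = ".") -> Dict[str, Optional[Path]]:
    """Map each stage to its checkpoint path, or None if the file is missing."""
    base = Path(base_dir)
    found: Dict[str, Optional[Path]] = {}
    for stage, relative_path in EXPECTED_CHECKPOINTS.items():
        candidate = base / relative_path
        found[stage] = candidate if candidate.is_file() else None
    return found


for stage, path in find_checkpoints().items():
    print(f"{stage}: {path if path else 'not found'}")
```

If a stage is reported as missing, check the `out_dir` passed to `fit` for that stage.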