# Setup

In [4]:

import torch
import numpy as np
from imgaug import augmenters as iaa
import torchvision
import PIL.Image as Image
import cv2
import math

from blindMatch import *

Congratulations on [BlindAuth] file import!


# Data Generator


In [5]:
class FingerprintDataset(torch.utils.data.Dataset):
    """Fingerprint images listed in a ``path|label`` metadata file.

    Each non-blank line of ``meta`` is split on ``'|'``: the first field is an
    image path relative to ``root`` and the second field is an integer
    identity label.

    Args:
        meta: path to the metadata text file.
        degree: passed to ``RandomRotation``; 0 disables rotation.
        size: images are resized to ``(size, size)``.
        root: directory string prepended to every image path. It is
            concatenated as-is, so keep the trailing separator. The default is
            the original hard-coded PolyU location.
    """

    def __init__(self, meta, degree, size=224, root='../data/PolyU_Dataset/'):
        self.root = root
        # Stripped, non-blank metadata lines, one per sample.
        self.x = self._read_meta(meta)

        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((size, size)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomRotation(degree),
            # Mean 0.5 / std 1.0: shifts ToTensor's [0, 1] range to [-0.5, 0.5].
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))
        ])

    @staticmethod
    def _read_meta(meta):
        """Return the whitespace-stripped, non-blank lines of ``meta``."""
        with open(meta, 'r') as fin:
            return [line.strip() for line in fin if line.strip()]

    @staticmethod
    def _parse_entry(line):
        """Split a ``path|label`` line into ``(path, int(label))``."""
        fields = line.split('|')
        # int() tolerates surrounding whitespace such as a trailing newline.
        return fields[0], int(fields[1])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(transformed image tensor, np.int64 label array)`` for ``idx``."""
        rel_path, n_id = self._parse_entry(self.x[idx])
        # Data location is configurable through the ``root`` constructor argument.
        path = self.root + rel_path

        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising on a missing or unreadable file.
            raise FileNotFoundError(f'Could not read image: {path}')
        # cv2 loads channels as BGR, while PIL.Image.fromarray assumes RGB.
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        img = self.transform(img)

        return img, np.array(n_id, dtype=np.int64)


In [6]:
# Loader configuration shared by both splits.
batch = 512
# NOTE(review): `train_data` is not referenced by any visible cell.
train_data = 'center_p'

# Metadata files listing `path|label` entries for each split.
train_meta_path = '../dataset/Poly/polyu_meta.txt'
eval_meta_path = '../dataset/Poly/polyu_meta_eval.txt'

# degree=0: no rotation augmentation on either split.
trainset = FingerprintDataset(train_meta_path, 0)
testset = FingerprintDataset(eval_meta_path, 0)

loader_opts = {'batch_size': batch, 'num_workers': 5}
# Training: shuffled, with the final partial batch dropped.
dataloader = torch.utils.data.DataLoader(trainset, shuffle=True, drop_last=True, **loader_opts)
# Validation: fixed order, every sample kept.
dataloader_val = torch.utils.data.DataLoader(testset, shuffle=False, drop_last=False, **loader_opts)

# Training the model

In [None]:
c = 8
num_ids = 400
device = 'cuda'

# network selecting - ST/Resnet18
net = FingerSTNNet(16*c) # net = FingerNet(16*c)
net = net.to(device)
print("Feature size", 16*c, "STNet")

fc = FingerCentroids(num_ids, 16*c)
fc = fc.to(device)


loss_arcface = ArcFace(m=0.2)

# NOTE(review): only the backbone's parameters are optimised. If FingerCentroids
# owns learnable class centroids, they never move from their initial values, and
# their .grad keeps accumulating because optim.zero_grad() does not reach them.
# Confirm this is intended; otherwise optimise
# `list(net.parameters()) + list(fc.parameters())`.
optim = torch.optim.Adam(net.parameters(), lr=0.001)
# NOTE(review): `degree` is not used in this cell; both datasets were built with degree=0.
degree = 20
best_acc = 0
best_loss = 100
# Initialised up front so the summary cell does not raise NameError
# if validation accuracy never improves.
best_epoch = -1
best_sim = None

# Similarity thresholds swept on the validation set. The previous list had the
# typo `0,8`, which produced two thresholds (0 and 8) instead of 0.8.
thresholds = [0.0, 0.2, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95]


def pairwise_threshold_accuracy(sim, same_id, thresholds):
    """Pairwise verification accuracy of ``sim > t`` for each threshold ``t``.

    Only the strict upper triangle is scored, so self-pairs are excluded and
    each unordered pair is counted once.

    Args:
        sim: (N, N) similarity matrix.
        same_id: (N, N) bool matrix, True where samples i and j share a label.
        thresholds: iterable of float thresholds.

    Returns:
        ``(accuracies, total_pairs)``: one float accuracy per threshold, and
        the number of scored pairs, ``N * (N - 1) // 2``.
    """
    n = sim.size(0)
    total_pairs = n * (n - 1) // 2
    accuracies = []
    for t in thresholds:
        correct = ((sim > t) == same_id).triu(1).sum().item()
        # A single validation sample has no pairs; avoid dividing by zero.
        accuracies.append(correct / total_pairs if total_pairs else 0.0)
    return accuracies, total_pairs


for epoch in range(100):
    net = net.train()
    loss_accum = []

    for idx, (img, lbl) in enumerate(dataloader):
        img = img.to(device)
        lbl = lbl.to(device)

        feat = net(img)
        logit = fc(feat)
        # Presumably applies the ArcFace margin (m=0.2) to the target-class
        # logits; see ArcFace in blindMatch.
        logit = loss_arcface(logit, lbl)

        loss = torch.nn.functional.cross_entropy(logit, lbl)
        loss_accum.append(loss.item())

        loss.backward()
        optim.step()
        optim.zero_grad()

        if idx % 2 == 0:
            print('.', end='')  # progress marker every other batch

    loss_accum = torch.tensor(loss_accum)
    print(f'epoch: {epoch} | loss: {loss_accum.mean().item():.04f}')

    net = net.eval()

    feats = []
    lbl_batches = []

    with torch.no_grad():
        for img, lbl in dataloader_val:
            feats.append(net(img.to(device)).to('cpu'))
            lbl_batches.append(lbl)

        # L2-normalise so the dot product is cosine similarity in [-1, 1], the
        # range the thresholds assume. This is a no-op if the network already
        # emits unit-length features.
        eval_results = torch.nn.functional.normalize(torch.cat(feats), dim=1)
        mat_similarity = eval_results.matmul(eval_results.T)

        lbls = torch.cat(lbl_batches)
        # (N, N) bool matrix: True where samples i and j share an identity.
        lbls = lbls.view(-1, lbls.size(0)) == lbls.view(lbls.size(0), -1)

        accuracy, total_comp = pairwise_threshold_accuracy(mat_similarity, lbls, thresholds)
        thresh_best_acc = max(accuracy)

        print(f'Accuracy: {" | ".join(f"{acc:.03f}" for acc in accuracy)}')

        if best_acc < thresh_best_acc:
            best_acc = thresh_best_acc
            best_loss = loss_accum.mean().item()
            best_epoch = epoch
            best_sim = mat_similarity

    print('=' * 20)

Feature size 128 STNet


: 

In [None]:
# Summary of the best validation epoch recorded by the training cell.
# NOTE: `lbls` and `total_comp` come from the last validation pass. The
# validation loader is not shuffled, so its pair ordering matches `best_sim`.
print("MOST", best_loss, best_acc, best_epoch)

# Same sweep as validation; the `0,8` typo (thresholds 0 and 8) is corrected to 0.8.
thresholds = [0.0, 0.2, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95]

accuracy = []
# Recompute the best threshold accuracy from `best_sim` alone, rather than
# starting from the last epoch's leftover value.
thresh_best_acc = 0
for threshold in thresholds:
    threshed = best_sim > threshold

    # triu(1) keeps the strict upper triangle: no self-pairs, each pair counted once.
    correct = (threshed == lbls).triu(1).sum()

    accuracy.append(correct / total_comp)
    if accuracy[-1] > thresh_best_acc:
        thresh_best_acc = accuracy[-1]

print("MOST", best_loss, best_acc, thresh_best_acc)
# NOTE: `loss_accum` holds the per-batch losses of the *last* training epoch,
# not of `best_epoch`.
print(f'Accuracy: {" | ".join(f"{acc:.05f}" for acc in accuracy) } \n Loss: {loss_accum.mean():.05f}')
print(f"{accuracy[-1]:.05f}", f"{loss_accum[-1]:.05f}")
print(16*c, net)