### Load packages

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from evaluate_utils import get_val_data, get_val_pair, evaluate
from get_model import get_model

### Define necassary variables and types

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
data_path = "data/faces_emore"

In [3]:
class FRDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets.repeat(2)
        
    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_data = self.inputs[idx]
        target_data = self.targets[idx]

        return input_data, target_data

def crop_memmap(img_memmap, img_shape):
    target_height = img_shape[0]
    start_height = (112 - target_height) // 2
    end_height = start_height + target_height

    target_width = img_shape[1]
    start_width = (112 - target_width) // 2
    end_width = start_width + target_width
    cropped_img_memmap = img_memmap[:, :, start_height:end_height, start_width:end_width]
    print(cropped_img_memmap.shape)
    return cropped_img_memmap

### Load data

**Option #1:**

In [4]:
val_data = get_val_data("data/faces_emore")
agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame = val_data

loading validation data memfile
loading validation data memfile
loading validation data memfile
loading validation data memfile
loading validation data memfile


**Option #2:**

In [5]:
# inputs, targets
# dataname_list = ['agedb_30', 'cfp_fp', 'lfw', 'cplfw', 'calfw']
agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')
cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
lfw, lfw_issame = get_val_pair(data_path, 'lfw')
cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw')
calfw, calfw_issame = get_val_pair(data_path, 'calfw')

loading validation data memfile
loading validation data memfile
loading validation data memfile
loading validation data memfile
loading validation data memfile


**Build dataset**

backbone architecture 찾을수 있으면 계속 추가 가능함.

In [6]:
model, img_shape = get_model("AdaFace")
model, img_shape = get_model("ArcFace")
model, img_shape = get_model("CosFace")

Load existing checkpoint: ckpts/model_ir_se50.pth
Load existing checkpoint: ckpts/cosface.pth


In [7]:
if img_shape != (112, 112):
    agedb_30 = crop_memmap(agedb_30, img_shape)
    cfp_fp = crop_memmap(cfp_fp, img_shape)
    lfw = crop_memmap(lfw, img_shape)
    cplfw = crop_memmap(cplfw, img_shape)
    calfw = crop_memmap(calfw, img_shape)
    
dataset_agedb_30 = FRDataset(agedb_30, agedb_30_issame)
dataset_cfp_fp = FRDataset(cfp_fp, cfp_fp_issame)
dataset_lfw = FRDataset(lfw, lfw_issame)
dataset_cplfw = FRDataset(cplfw, cplfw_issame)
dataset_calfw = FRDataset(calfw, calfw_issame)

(12000, 3, 112, 96)
(14000, 3, 112, 96)
(12000, 3, 112, 96)
(12000, 3, 112, 96)
(12000, 3, 112, 96)


In [8]:
batch_size = 128
test_dataloader = DataLoader(dataset_agedb_30,
                             batch_size=128,
                             num_workers=8,
                             shuffle=False,)

In [9]:
acc = 0
for idx, (inputs, targets) in enumerate(test_dataloader):
    features = model.forward(inputs.to(device), False)
    tpr, fpr, accuracy, best_thresholds = evaluate(features.detach().cpu().numpy(), targets[0::2])
    if idx > 0 and (idx) % 10 == 0:
        print(f"Iter [{idx}/{len(test_dataloader)}] - TAR@FAR=0.01%: {acc/(idx+1):.4f}")
    # acc, best_threshold = accuracy.mean(), best_thresholds.mean()
    acc += accuracy.mean()
print(f"Result TAR@FAR=0.01% - {acc/(idx+1)}")

Iter [10/94] - TAR@FAR=0.01%: 0.8162
Iter [20/94] - TAR@FAR=0.01%: 0.8840
Iter [30/94] - TAR@FAR=0.01%: 0.8935
Iter [40/94] - TAR@FAR=0.01%: 0.8952
Iter [50/94] - TAR@FAR=0.01%: 0.9082
Iter [60/94] - TAR@FAR=0.01%: 0.9101
Iter [70/94] - TAR@FAR=0.01%: 0.9169
Iter [80/94] - TAR@FAR=0.01%: 0.9195
Iter [90/94] - TAR@FAR=0.01%: 0.9222
Result TAR@FAR=0.01% - 0.9353343465045592
