In [None]:
import copy
import os

import numpy as np
import torch
import torchvision.transforms as transforms

from easydict import EasyDict as edict

from recognition.arcface_torch.backbones import get_model
from recognition.arcface_torch.configs.aihub_r50_onegpu import config as cfg
from utils.utils_config import return_aihub_dataloader

from datasets.AIHubDataset import AIHubDataset
from validate_aihub import validate_aihub

In [None]:
import wandb

# Open a W&B run for this evaluation; metrics and run config are attached to it in later cells.
run = wandb.init(project="arcface-evaluation-with-aihub", entity="jongphago")

In [None]:
# AIHub preprocessing: square resize to the configured input size, convert to a tensor,
# then normalize with the AIHub channel mean/std taken from the config.
aihub_transforms = transforms.Compose([
    transforms.Resize(size=(cfg.data.image_size,) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=cfg.data.aihub_mean, std=cfg.data.aihub_std),
])

In [None]:
# Build the test dataloader for the "family" task.
# Work on a copy of the test config so this cell does not mutate the shared
# cfg.data.test object as a side effect (plain assignment would only alias it).
# `test` is read again later when recording task/split in the W&B config.
test = copy.deepcopy(cfg.data.test)
test.task = "family"
aihub_dataloader = return_aihub_dataloader(test, aihub_transforms)

In [None]:
# Checkpoint to evaluate: the ArcFace pre-trained backbone or the AIHub fine-tuned checkpoint.
# NOTE(review): these are absolute, machine-specific paths — consider moving them to a config/env var.
PRETRAINED_WEIGHT_PATH = f"/home/jupyter/face/utils/model/arcface/{cfg.network}/backbone.pth"
FINETUNED_WEIGHT_PATH = "/home/jongphago/insightface/work_dirs/aihub_r50_onegpu/checkpoint_gpu_0.pt"
USE_FINETUNED = True  # set to False to evaluate the pre-trained backbone instead

target_weight_path = FINETUNED_WEIGHT_PATH if USE_FINETUNED else PRETRAINED_WEIGHT_PATH

In [None]:
# Instantiate the backbone architecture named in the config; its weights are loaded in a later cell.
backbone = get_model(
    cfg.network,
    num_features=cfg.embedding_size,
    fp16=cfg.fp16,
    dropout=cfg.dropout,
)

In [None]:
# Load the checkpoint onto the CPU so loading does not depend on the device it was saved from;
# the backbone is moved to the GPU only after its weights are loaded.
state_dict = torch.load(target_weight_path, map_location="cpu")

# Training checkpoints keep the backbone weights under "state_dict_backbone";
# otherwise the file is assumed to be the backbone state dict itself.
model_weights = state_dict.get("state_dict_backbone", state_dict)

In [None]:
# Copy the checkpoint weights into the backbone (strict key matching, PyTorch's default),
# then move it to the GPU and switch to inference mode (disables dropout; BatchNorm uses running stats).
backbone.load_state_dict(model_weights, strict=True)
backbone.cuda()
backbone.eval()

In [None]:
# Evaluate the loaded backbone on the dataloader built earlier.
# NOTE(review): the trailing 0 is positional; presumably an epoch/step index — confirm against validate_aihub.
out = validate_aihub(backbone, aihub_dataloader, cfg.network, 0)

In [None]:
# validate_aihub returns (best_distances, metrics); the metrics part unpacks into six values.
(best_distances, (accuracy, precision, recall, roc_auc, tar, far)) = out

In [None]:
# Log the headline evaluation metrics. np.mean collapses per-fold arrays to a single
# number and passes scalars through unchanged.
# roc_auc was previously unpacked but never logged.
# NOTE(review): tar/far are not logged; presumably they are curve arrays whose mean is not meaningful — confirm.
summary_metrics = {
    "accuracy": accuracy,
    "precision": precision,
    "recall": recall,
    "roc_auc": roc_auc,
    "best_distances": best_distances,
}
wandb.log({name: np.mean(value) for name, value in summary_metrics.items()})

In [None]:
# Identify the evaluated checkpoint: its file name and the name of the directory holding it
# (for the fine-tuned path this is the work_dirs sub-directory, e.g. "aihub_r50_onegpu").
dir_name, ckp_name = os.path.split(target_weight_path)
config_name = os.path.basename(dir_name)

In [None]:
# Public W&B API client: addresses runs by "entity/project/run_id" path rather than via the live wandb.run.
api = wandb.Api()

In [None]:
# Metadata describing what this run evaluated, recorded in the W&B run config.
config = dict(
    task=test.task,
    split=test.split,
    checkpoint=ckp_name,
    config=config_name,
)

In [None]:
# Record which task/split/checkpoint this run evaluated in the run's config.
# Update the live run directly: edits made to a still-running run through wandb.Api()
# may be overwritten when the live run syncs its own config, and going through the API
# would also rebind `run` away from the handle returned by wandb.init.
username = wandb.run.entity
project = wandb.run.project
run_id = wandb.run.id
print(f"{username}/{project}/{run_id}")

# allow_val_change=True keeps "overwrite" semantics if this cell is re-run with a different checkpoint.
wandb.run.config.update(config, allow_val_change=True)

## Validate on each AIHub task (family, age, individuals)

In [None]:
# Re-evaluate the backbone on every AIHub task.
# Each iteration builds its own copy of the test config, so neither cfg.data.test nor the
# `test` object used for the W&B config is mutated.
# Results are kept per task instead of only the last one surviving the loop.
task_results = {}
for task in ["family", "age", "individuals"]:
    task_cfg = copy.deepcopy(cfg.data.test)
    task_cfg.task = task
    task_loader = return_aihub_dataloader(task_cfg, aihub_transforms)
    # NOTE(review): the "<task>_" prefix presumably namespaces metrics inside validate_aihub — confirm.
    task_distances, task_metrics = validate_aihub(
        backbone, task_loader, cfg.network, 0, task=f"{task}_"
    )
    # task_metrics is (accuracy, precision, recall, roc_auc, tar, far), matching the single-task cell.
    task_results[task] = (task_distances, task_metrics)