In [1]:
import time
import argparse
import os
import matplotlib
from tqdm import tqdm

matplotlib.use('Agg')
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler

from dataloader.dataloader_RGB import load_cisia_surf
# from data.dataloader_RGB_Depth import load_cisia_surf
# from dataloader.dataloader_RGB_Depth_IR import load_cisia_surf

from models.model_RGB import Model
# from models.model_RGB_Depth import Model
# from models.model_RGB_Depth_IR import Model

from loger import Logger
from evalution import eval_model
from centerloss import CenterLoss
from utils import plot_roc_curve, plot_eval_metric

In [2]:
# Timestamp that makes each evaluation run's log directory unique.
time_object = time.localtime()
time_string = time.strftime('%Y-%m-%d_%I:%M_%p', time_object)
# torch.cuda.is_available() already returns a bool.
use_cuda = torch.cuda.is_available()

parser = argparse.ArgumentParser(description='face anti-spoofing test')
parser.add_argument('--batch-size', default=64, type=int, help='train batch size')
parser.add_argument('--test-size', default=64, type=int, help='test batch size')
parser.add_argument('--save-path', default='./logs/RGB/Test/', type=str, help='logs save path')
parser.add_argument('--message', default='', type=str, help='suffix appended to the run log directory name')
parser.add_argument('--mode', default=1, type=int, help='dataset protocol_mode')
# parse_known_args ignores unrecognised argv entries (e.g. ones a notebook kernel
# may inject) instead of erroring out.
args = parser.parse_known_args()[0]
print(args)

# os.path.join works whether or not --save-path ends with a separator; with the
# default (trailing '/') the result is identical to plain concatenation.
save_path = os.path.join(args.save_path, f'{time_string}_{args.message}')

Namespace(batch_size=64, message='', mode=1, save_path='./logs/RGB/Test/', test_size=64)


In [3]:
# Create the run's log directory. exist_ok=True avoids the race between a separate
# existence check and the creation call.
os.makedirs(save_path, exist_ok=True)

# logger = Logger(f'{save_path}/logs.logs')
# logger.Print(args.message)

# Only the second returned value (the test loader) is used for evaluation; the
# first (presumably the train loader) is discarded.
_, test_data = load_cisia_surf(train_size=args.batch_size, test_size=args.test_size, mode=args.mode)

# Accumulators appended to by val() across calls.
eval_history = []  # full eval_model() result of every val() call
eval_loss = []     # not written anywhere in the visible code (loss logging is disabled)
eval_score = []    # score of val() calls made with flag == 0
test_score = []    # score of val() calls made with any other flag


train dataset count : 5580
test dataset count : 2520
train loader count : 88
test loader count : 40


In [4]:
def _load_model(weight_dir, device):
    """Build a 2-class ``Model``, load the checkpoint at ``weight_dir`` and place it on ``device``.

    The checkpoint is expected to be a state_dict. A leading ``module.`` on its keys
    (what ``torch.nn.DataParallel`` adds when a wrapped model is saved) is stripped, so
    the weights are loaded into the bare model. Loading therefore works whether or not
    a GPU is present.

    Raises:
        ValueError: if ``weight_dir`` is empty.
    """
    if not weight_dir:
        raise ValueError('weight_dir must point to a model checkpoint file')

    model = Model(pretrained=False, num_classes=2)

    # map_location lets CUDA-saved checkpoints load on CPU-only machines.
    state_dict = torch.load(weight_dir, map_location=device)
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)

    if device.type == 'cuda':
        model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    return model.to(device)


def val(epoch=0, data_set=test_data, flag=1, weight_dir=''):
    """Evaluate the checkpoint at ``weight_dir`` on ``data_set`` and record the metrics.

    Args:
        epoch: Tag forwarded to the plotting helpers (an int epoch, or a string such
            as ``"_global_min"``).
        data_set: Iterable of batches where ``data[0]`` is the RGB input tensor and
            ``data[1]`` the label tensor. Defaults to ``test_data`` as it was bound
            when this cell was executed.
        flag: ``0`` appends the score to ``eval_score``; any other value appends it
            to ``test_score``.
        weight_dir: Path of the checkpoint to evaluate (required).

    Side effects:
        Appends the eval_model() result to ``eval_history`` and the score to
        ``eval_score`` / ``test_score``. Calls the ROC and metric plotting helpers
        with ``save_path``.

    Raises:
        ValueError: if ``weight_dir`` is empty or ``data_set`` yields no batches.
    """
    device = torch.device('cuda' if use_cuda else 'cpu')
    # Previously the weights were only loaded on GPU, so CPU runs silently evaluated
    # a randomly initialised model; loading now happens on every device.
    model = _load_model(weight_dir, device)
    model.eval()

    y_true = []
    y_pred = []
    y_prob = []

    with torch.no_grad():
        for data in tqdm(data_set):
            rgb_img = data[0].to(device)
            # Labels are only collected for metrics, so they stay on the CPU.
            labels = data[1]

            # Forward pass; the model's second output (features) is not needed here.
            outputs, _ = model(rgb_img)
            _, pred_outputs = torch.max(outputs, 1)
            # Softmax probability of class 1, used for the ROC curve.
            prob_outputs = F.softmax(outputs, 1)[:, 1]

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(pred_outputs.cpu().numpy())
            y_prob.extend(prob_outputs.cpu().numpy())

    if not y_true:
        raise ValueError('data_set yielded no batches; nothing to evaluate')

    eval_result, score, _acer = eval_model(y_true, y_pred, y_prob)
    eval_history.append(eval_result)

    plot_roc_curve(save_path, epoch, y_true, y_prob)
    plot_eval_metric(save_path, epoch, y_true, y_pred)

    if flag == 0:
        # No loss is computed during evaluation, so eval_loss is not updated here.
        eval_score.append(score)
    else:
        test_score.append(score)


In [5]:
# Evaluate the "global_min_acer" checkpoint (presumably the minimum-ACER model, per
# its filename) for the selected protocol mode.
# The checkpoint root can be overridden with the CHECKPOINT_ROOT environment variable
# instead of editing this machine-specific absolute path; the default is unchanged.
checkpoint_root = os.environ.get(
    'CHECKPOINT_ROOT',
    '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/models/output/RGB',
)
global_dir = os.path.join(checkpoint_root, f'checkpoint_v{args.mode}_0', 'global_min_acer_model.pth')
print("--global_dir start--")
val(epoch="_global_min", weight_dir=global_dir)

--global_dir start--
auc_value: 0.6674052028218693
