In [1]:
3==3

True

In [2]:
import time
import argparse
import os
from tqdm.notebook import tqdm as tqdm_notebook
# import matplotlib

# matplotlib.use('Agg')
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler

# from data.dataloader_RGB import load_cisia_surf
from dataloader.dataloader_RGB_Depth import load_cisia_surf
# from dataloader.dataloader_RGB_Depth_IR import load_cisia_surf

# from models.model_RGB import Model
from models.model_RGB_Depth import Model
# from models.model_RGB_Depth_IR import Model

# from loger import Logger
from evalution import eval_model
from centerloss import CenterLoss
from utils import plot_roc_curve, plot_eval_metric


In [3]:
time_object = time.localtime(time.time())
time_string = time.strftime('%Y-%m-%d_%I:%M_%p', time_object)
use_cuda = True if torch.cuda.is_available() else False

parser = argparse.ArgumentParser(description='face anti-spoofing test')
parser.add_argument('--batch-size', default='64', type=int, help='train batch size')
parser.add_argument('--test-size', default='64', type=int, help='test batch size')
parser.add_argument('--save-path', default='./logs/RGB_Depth/Test/', type=str, help='logs save path')
parser.add_argument('--message', default='test', type=str, help='pretrained model checkpoint')
parser.add_argument('--mode', default=1, type=int, help='dataset protocol_mode')
args = parser.parse_known_args()[0]

save_path = args.save_path + f'{time_string}' + '_' + f'{args.message}'

if not os.path.exists(save_path):
    os.makedirs(save_path)
# logger = Logger(f'{save_path}/logs.logs')
# logger.Print(args.message)

_, test_data = load_cisia_surf(train_size=args.batch_size,test_size=args.test_size, mode=args.mode)

eval_history = []
eval_loss = []
eval_score = []
test_score = []

train dataset count : 5580
test dataset count : 2520
train loader count : 88
test loader count : 40


In [4]:
def val(epoch=0, data_set=test_data,flag=1, weight_dir=''):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Model(pretrained=False, num_classes=2)

    if use_cuda:
        model = torch.nn.DataParallel(model,device_ids=list(range(torch.cuda.device_count()))) 
        model = model.to(device)
        model.load_state_dict(
            torch.load(weight_dir))

    y_true = []
    y_pred = []
    y_prob = []
    
    model.eval()

    total_batch = 0

#     with open(save_path+fprob_{epoch}.txt', 'w') as fb:
    with torch.no_grad():
        pbar = tqdm_notebook(enumerate(data_set, 1),total=40)
        for batch, data in pbar :

            rgb_img = data[0]
            depth_img = data[1]
            labels = data[2]

            if use_cuda:
                rgb_img = rgb_img.to(device)
                depth_img = depth_img.to(device)
                labels = labels.to(device)

            # 예측 오류 계산
            outputs, features = model(rgb_img, depth_img)
            _, pred_outputs = torch.max(outputs, 1)
            prob_outputs = F.softmax(outputs,1)[:,1]

            y_true.extend(labels.data.cpu().numpy())
            y_pred.extend(pred_outputs.data.cpu().numpy())
            y_prob.extend(prob_outputs.data.cpu().numpy())

            total_batch = batch

#         fb.close()

    eval_result, score, acer = eval_model(y_true, y_pred, y_prob)
    eval_history.append(eval_result)

    plot_roc_curve(save_path, epoch, y_true, y_prob)
    plot_eval_metric(save_path, epoch, y_true, y_pred)

    if flag == 0 :
        eval_score.append(score)
        # avg_loss = total_loss/total_batch
        # eval_loss.append(avg_loss)
        # message = f'|eval|loss:{avg_loss:.6f}|'
        # logger.Print(message)
    else:
        test_score.append(score)

    with open(save_path+f'/val_{epoch}.txt', 'w') as f:
        for i in range(len(y_true)):
            message = f'{y_prob[i]:.6f} {y_pred[i]} {y_true[i]}'
            f.write(message)
            f.write('\n')
        f.close()
        
    return y_true,y_pred

In [5]:
global_dir = '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/models/output/RGB_Depth_IR/checkpoint_v' + str(args.mode) + '_0/' +'Cycle_1_min_acer_model.pth'
global_dir

'/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/models/output/RGB_Depth_IR/checkpoint_v1_0/Cycle_1_min_acer_model.pth'

In [None]:
global_dir = '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/models/output/RGB_Depth/checkpoint_v1_0/Cycle_2_min_acer_model.pth'
# global_dir = '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/models/output/RGB_Depth/checkpoint_v1_0/std/Cycle_2_min_acer_model.pth'


print("--global_dir start--")
y_true,y_pred = val(epoch=1, weight_dir=global_dir)
print("--global_dir end--")


--global_dir start--


  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
3/0

In [None]:
from collections import Counter
print(Counter(y_true))
print(Counter(y_pred))

In [None]:
from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred)

In [None]:
from sklearn.metrics import f1_score
f1_score(y_true, y_pred)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(pretrained=False, num_classes=2)

if use_cuda:
    model = torch.nn.DataParallel(model,device_ids=list(range(torch.cuda.device_count()))) 
    model = model.to(device)
    model.load_state_dict(
        torch.load(global_dir))
model.eval()
print()

In [None]:
pbar = tqdm_notebook(enumerate(test_data, 1),total=40)
for batch, data in pbar :

    rgb_img = data[0]
    depth_img = data[1]
    labels = data[2]
    
    break

In [None]:
metadata_root = '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/code/metadata/'
data_root = '/mnt/nas3/yrkim/liveness_lidar_project/GC_project/data/'
datatxt='MakeTextFileCode_RGB_Depth/test_data_list.txt'
img_paths = []
rgb_paths = []
depth_paths = []
labels = []

lines_in_txt = open(os.path.join(metadata_root,datatxt),'r')

for line in lines_in_txt:
    line = line.rstrip() 
    split_str = line.split()
    rgb_path = os.path.join(data_root,split_str[0])
    depth_path = os.path.join(data_root,split_str[1])
    label = split_str[2] 
    rgb_paths.append(rgb_path)
    depth_paths.append(depth_path)

    labels.append(label)

In [None]:
rgb_paths_mask = [rgb_paths[int(i)] for i in range(len(labels)) if int(labels[i])==0]
rgb_paths_real = [rgb_paths[int(i)] for i in range(len(labels)) if int(labels[i])==1]

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import cv2
for k in range(30) :
    img = cv2.imread(rgb_paths_mask[k])
    print('Mask')
    display(Image.fromarray(img))
    img = cv2.imread(rgb_paths_real[k])
    print('real')
    display(Image.fromarray(img))


In [None]:
img = cv2.imread(rgb_paths_real[0])
Image.fromarray(img)