In [None]:
import argparse
import os
from pathlib import Path

import albumentations as A
import pandas as pd
import pytorch_lightning as pl
import torch
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from classes import CustomDataset
from utils import *
import ast

In [None]:
# Root folder that holds test.csv (read in the next cell).
DATASET_PATHS = Path('test')
# Half of the logical CPUs for data loading. os.cpu_count() may return None,
# so fall back to 1 CPU -> 0 workers (load in the main process) in that case.
NUM_WORKERS = (os.cpu_count() or 1) // 2
# Global seed used for reproducible inference (see pl.seed_everything below).
SEED = 13


In [None]:
# Load the test manifest; later cells read its `image` and `box` columns.
test_df = pd.read_csv(DATASET_PATHS.joinpath('test.csv'))
# Placeholder for the per-row records; populated from test_df in the next cell.
test_data = list()
test_df

In [None]:
# Materialize the test rows as namedtuples (Index plus the CSV columns).
# Rebuilding the list, rather than appending to an existing one, keeps this
# cell idempotent: re-running it does not duplicate rows.
test_data = list(test_df.itertuples())
print(f'Test total len: {len(test_data)}')

# Reproducibility: seed the RNGs and force deterministic cuDNN kernels.
pl.seed_everything(SEED)
# Note the exact spelling: a misspelled attribute on torch.backends.cudnn is
# silently accepted and has no effect.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
def load_eval_module(checkpoint_path: str, device: torch.device) -> FaceNet:
    """Restore a FaceNet Lightning checkpoint ready for inference.

    Args:
        checkpoint_path: Path to a Lightning ``.ckpt`` file.
        device: Target device; a ``torch.device`` or a device string such as
            ``'cuda:0'`` / ``'cpu'``.

    Returns:
        The restored module, moved to ``device`` and switched to eval mode.
    """
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on a GPU can also be restored on a CPU-only machine.
    module = FaceNet.load_from_checkpoint(checkpoint_path, map_location=device)
    module.to(device)
    # eval() disables training-only behavior (dropout, batch-norm updates).
    module.eval()

    return module

# Checkpoint to evaluate. Set the FACE_SPOOF_CHECKPOINT environment variable
# to point at a different checkpoint without editing the notebook.
path = os.environ.get(
    'FACE_SPOOF_CHECKPOINT',
    '/4tb/nikonov/face_spoofing/checkpoints/lightning_logs/version_13/checkpoints/epoch=5-step=396.ckpt',
)
# Use the auto-detected `device` (not a hard-coded 'cuda:0') so the model
# lives on the same device the input tensors are moved to at prediction time,
# and the notebook also runs on CPU-only machines.
test_model = load_eval_module(path, device=device)

In [None]:
# Inference preprocessing: resize the face crop to 512x512 and convert it to a
# torch tensor. No normalization here; pixel scaling to [0, 1] happens when the
# prediction is made.
test_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ]
)

test_dataset = CustomDataset(test_data, test_transform)

In [None]:
# Draw the model's verdict on every test image:
# green box + 'no_spoof' for class 0, red box + 'spoof' for any other class.
NO_SPOOF_COLOR = (0, 255, 0)  # green (image is converted to RGB below)
SPOOF_COLOR = (255, 0, 0)     # red; same as the scalar 255 OpenCV would expand to

# rcParams only affect axes created afterwards, so set them before subplots().
plt.rcParams['axes.facecolor'] = 'white'
n_images = len(test_data)
# One column per image (instead of a hard-coded 3, which raised IndexError
# for larger test sets). squeeze=False keeps `axes` 2-D even for one image.
fig_width = 20 * max(n_images, 3) / 3
fig, axes = plt.subplots(1, max(n_images, 1), figsize=(fig_width, 10), squeeze=False)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    for k, sample in enumerate(test_data):
        bgr = cv2.imread(sample.image)
        if bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {sample.image}')
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # `box` is stored in the CSV as a Python literal; presumably
        # [x1, y1, x2, y2] in pixel coordinates -- TODO confirm with the data spec.
        box = ast.literal_eval(sample.box)
        x1, y1, x2, y2 = box[:4]

        img_for_model = test_transform(image=img[y1:y2, x1:x2])['image'].float()
        # Add a batch dimension and scale pixel values to [0, 1].
        img_for_model = (img_for_model.unsqueeze(0) / 255).to(device)
        predict = test_model(img_for_model).argmax(dim=-1).item()
        is_spoof = predict != 0

        img = cv2.rectangle(
            img, (x1, y1), (x2, y2),
            SPOOF_COLOR if is_spoof else NO_SPOOF_COLOR, 5,
        )
        ax = axes[0, k]
        ax.imshow(img)
        ax.set_xlabel('spoof' if is_spoof else 'no_spoof')

plt.show()
