In [48]:
import argparse
import os

import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from tqdm import tqdm

from Data.DataPipe import TestDataGenerator
from Network.Det_RN50 import Det_RN50

In [None]:
# Checkpoints to evaluate. Predictions for each checkpoint are written to
# test_results/<name of the checkpoint's directory>/<test set>.csv, which is
# where the evaluation cell further down reads them from.
models = ["./train_model/model5/model5-cp-8.ckpt"]

# Names of the test sets to run every model on
test_set = ['biggan', 'crn', "cyclegan", "deepfake", "gaugan", "imle", "progan", "san", "seeingdark",
            "stargan", "stylegan", "stylegan2", "whichfaceisreal"]


def predict_image(model, path):
    """Return the model's sigmoid score for the single image stored at `path`.

    The image is read with cv2.imread, scaled to [0, 1] and fed to the model as
    a batch of one; the score is tf.sigmoid of output element [0][0].
    NOTE(review): cv2.imread yields BGR channel order - presumably the model
    was trained on the same ordering; confirm.

    Raises:
        FileNotFoundError: if cv2 cannot read the file. cv2.imread returns None
            instead of raising, which would otherwise surface as a confusing
            TypeError on the division below.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    batch = (img / 255.0)[np.newaxis]  # (h, w, 3) -> (1, h, w, 3)
    logit = model.predict(batch)[0][0]
    return tf.sigmoid(logit).numpy()


for ckpt in models:
    model = Det_RN50()
    model.load_weights(ckpt)
    # e.g. "./train_model/model5/model5-cp-8.ckpt" -> "test_results/model5".
    # Created up front so to_csv does not fail on a fresh checkout.
    out_dir = os.path.join('test_results', os.path.basename(os.path.dirname(ckpt)))
    os.makedirs(out_dir, exist_ok=True)
    for test_name in test_set:
        print('\n\nModel Loaded:{}'.format(ckpt))
        print('\nTesting on:{}\n'.format(test_name))
        img_idx = pd.read_csv("Img_index/test/" + test_name + "_test.csv")
        root = "../CNN_synth_testset/" + test_name + "/"
        # Collect rows in a list and build the frame once: growing it with
        # DataFrame.append is quadratic and the method is gone in pandas >= 2.0.
        records = [{'file': file_name, 'label': predict_image(model, root + file_name)}
                   for file_name in tqdm(img_idx['file'])]
        predictions = pd.DataFrame(records, columns=['file', 'label'])
        predictions.to_csv(os.path.join(out_dir, test_name + '.csv'), columns=['file', 'label'], index=False)
            

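### Evaluation metric calculation

Compute the average precision (AP, in %) of each model's saved predictions in `test_results/<model>/` against the ground-truth labels of every test set.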

In [50]:
# Test sets, in the column order of the results table
test_set = ["progan", "stylegan", 'biggan', "cyclegan", "stargan", "gaugan", 'crn', "imle", "seeingdark", "san", "deepfake",
            "stylegan2", 'whichfaceisreal']
# Sub-directories of test_results/ holding each model's prediction CSVs
models = ['model-2c', 'model-8c', 'model1', 'model2', 'model3', 'model4', 'model5']


def load_labels_and_scores(model_name, dataset):
    """Load ground-truth labels and predicted scores for one model on one test set.

    Ground truth comes from Img_index/test/<dataset>_test.csv and predictions
    from test_results/<model_name>/<dataset>.csv; both have 'file' and 'label'
    columns. The two files are paired row by row, so their 'file' columns must
    match exactly - otherwise every metric would silently compare the wrong
    images.

    Returns:
        (y_true, y_pred) as numpy arrays.

    Raises:
        ValueError: if the prediction file does not list the same files in the
            same order as the ground-truth index.
    """
    gt = pd.read_csv('Img_index/test/' + dataset + '_test.csv')
    pred = pd.read_csv('test_results/' + model_name + '/' + dataset + '.csv')
    if not gt['file'].equals(pred['file']):
        raise ValueError('{}: predictions for {} are not aligned with the ground-truth index'
                         .format(model_name, dataset))
    return gt['label'].to_numpy(), pred['label'].to_numpy()


def evaluate_predictions(y_true, y_pred, threshold=0.5):
    """Score predicted probabilities against binary ground-truth labels.

    Label 1 is the positive ("fake") class and 0 the "real" class; a score
    above `threshold` counts as a fake prediction.

    Returns a dict with:
        ap       - average precision of the raw scores (threshold-free)
        acc      - accuracy over all images
        real_acc - accuracy over images labelled 0 (NaN if there are none)
        fake_acc - accuracy over images labelled 1 (NaN if there are none)
    """
    y_true = np.asarray(y_true)
    is_fake_pred = np.asarray(y_pred) > threshold

    def class_acc(label):
        # Guard against a test set that contains only one class
        mask = y_true == label
        return accuracy_score(y_true[mask], is_fake_pred[mask]) if mask.any() else np.nan

    return {
        'ap': average_precision_score(y_true, y_pred),
        'acc': accuracy_score(y_true, is_fake_pred),
        'real_acc': class_acc(0),
        'fake_acc': class_acc(1),
    }


metric_records = []
for model_name in models:
    for dataset in test_set:
        y_true, y_pred = load_labels_and_scores(model_name, dataset)
        metric_records.append({'model': model_name, 'dataset': dataset,
                               **evaluate_predictions(y_true, y_pred)})

# All metrics in long form: one row per (model, test set)
metrics_long = pd.DataFrame(metric_records)

# AP in percent (1 decimal), one row per model and one column per test set
result_table = (
    metrics_long
    .pivot(index='model', columns='dataset', values='ap')
    .mul(100)
    .round(1)
    .reindex(index=models, columns=test_set)
    .rename_axis(index=None, columns=None)
)


In [51]:
# AP (%) of each model (rows) on each test set (columns), computed in the cell above
result_table

Unnamed: 0,progan,stylegan,biggan,cyclegan,stargan,gaugan,crn,imle,seeingdark,san,deepfake,stylegan2,whichfaceisreal
model-2c,97.6,82.3,66.4,82.8,85.5,88.1,96.2,97.9,79.2,58.1,61.7,81.1,99.1
model-8c,100.0,95.2,68.3,87.1,100.0,63.4,98.4,95.3,97.5,87.1,94.8,95.4,94.0
model1,100.0,96.6,72.1,76.0,100.0,59.3,97.8,94.1,98.9,91.1,96.4,99.8,97.5
model2,100.0,96.5,76.7,84.4,100.0,65.1,92.6,93.8,99.6,49.1,67.3,99.4,98.9
model3,100.0,96.2,74.7,86.1,95.9,90.3,99.4,99.6,85.0,69.3,88.5,94.6,95.3
model4,100.0,97.7,78.1,91.9,98.2,92.4,96.7,98.1,99.2,56.2,65.9,98.5,99.7
model5,99.6,93.6,72.1,81.7,99.2,73.1,87.7,81.5,98.8,54.5,68.0,97.4,97.5
