In [1]:
import os
import time
import argparse
import numpy as np
from tqdm import tqdm
import seaborn as sns

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score, average_precision_score

from model import dataloader
from model import create_dpnet_one

from utils import *

from sklearn.metrics import make_scorer, roc_curve
from scipy.optimize import brentq
from scipy.interpolate import interp1d

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
np.random.seed(0)

def calculate_eer(y_true, y_score):
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    return eer

def calculate_tpr_10(y_true, y_score):
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1, drop_intermediate = False)
    idx = 0
    for _, i in enumerate(fpr):
        if i <= .1: idx = _
        else: break
    return tpr[idx]
    
class DPNet():
    def __init__(self, device, log_dir, args, train_loader, val_loader, test_loader):
        self.device = device
        self.log_dir = log_dir
        self.args = args
        
        self.train_loader = train_loader
        self.val_loader   = val_loader
        self.test_loader  = test_loader
        
        self.best_val_auc = 0
        self.counter = 0
        self.patience = 5
        
        self.build_model()
        
    def build_model(self):        
        self.model = create_dpnet_one()
        self.model = nn.DataParallel(self.model)
        self.criterion = nn.CrossEntropyLoss()

        if self.args.checkpoint:
            cp = torch.load(self.args.checkpoint)
            self.epoch = cp['epoch']
            self.model.load_state_dict(cp['state_dict'])

        self.model = self.model.to(self.device)  
        
    def test(self):
        self.validate_1epoch(test_mode=True)

    def validate_1epoch(self, test_mode = False):
        if test_mode:
            print('|--> [testing stage]')
        else:
            print('|--> Epoch:[{0}/{1}][validation stage]'.format(self.epoch+1, self.args.num_epochs))

        losses, top1 = AverageMeter(), AverageMeter()

        # Evaluate mode
        self.model.eval()
        self.dic_video_level_preds = {}
        
        start = time.time()
        with torch.no_grad():
            progress = tqdm(self.test_loader) if test_mode else tqdm(self.val_loader)
            for _, (video_names, inputs, labels) in enumerate(progress):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # Compute output
                batch_size = inputs.shape[0]             
                outputs, min_distances = self.model(inputs)

                # Accumulate video level prediction
                preds = outputs.data.cpu().numpy()
                for i in range(batch_size):
                    video_name = video_names[i]
                    if video_name not in self.dic_video_level_preds.keys():
                        self.dic_video_level_preds[video_name] = preds[i,:]
                    else:
                        self.dic_video_level_preds[video_name] += preds[i,:]
        
        # Calculate video level statistics
        video_top1, video_auc, video_loss, video_pauc_10, video_tar_10, video_eer = self.frame_2_video_level_accuracy()

        info = {'Epoch': [self.epoch],
                'Time':  [round(time.time()-start,3)],
                'Loss':  [round(video_loss,5)],
                'Acc':   [round(video_top1,4)],
                'AUC':   [round(video_auc,4)],
                'pAUC_10':    [round(video_pauc_10,4)],
                'TAR_10':    [round(video_tar_10,4)],
                'EER':   [round(video_eer,4)]}
        if test_mode:
            print(info)
        else:
            record_info(info, os.path.join(self.log_dir, 'test.csv'))
        return video_top1, video_auc, video_loss
                             
    def frame_2_video_level_accuracy(self):
        correct = 0
        video_level_preds = np.zeros((len(self.dic_video_level_preds),2))
        video_level_labels = np.zeros(len(self.dic_video_level_preds))
        
        for i, name in enumerate(sorted(self.dic_video_level_preds.keys())):
            preds = self.dic_video_level_preds[name]
            label = 1.0 if 'FAKE' in name else 0.0
                
            video_level_preds[i,:] = preds / 100
            video_level_labels[i] = label
            if np.argmax(preds) == (label):
                correct += 1
        if self.args.save_predictions:        
            np.save(open(f'predictions/{self.args.start_task}_{self.args.task}_labels_{self.args.stream}.npy','wb'), video_level_labels)
            np.save(open(f'predictions/{self.args.start_task}_{self.args.task}_preds_{self.args.stream}.npy','wb'), video_level_preds)

        video_level_labels = torch.from_numpy(video_level_labels).long()
        video_level_preds = torch.from_numpy(video_level_preds).float()
            
        top1 = accuracy(video_level_preds, video_level_labels, topk=(1,))
        loss = self.criterion(video_level_preds, video_level_labels)
                                 
        logits = nn.functional.softmax(video_level_preds, dim=1)[:, 1].numpy()
        auc = roc_auc_score(video_level_labels, logits)
        pauc_10 = roc_auc_score(video_level_labels, logits, max_fpr=0.1)
        tar_10 = calculate_tpr_10(video_level_labels, logits)
        eer = calculate_eer(video_level_labels, logits)
        
        return top1.item(), auc, loss.item(), pauc_10, tar_10, eer                         

In [2]:
for i in ['FF++']:
    for j in ['FF++', 'DFD', 'Celeb-DF', 'DeeperForensics']:
        print(f'{i} to {j}')

        class Args:
            gpu = '0,1'
            start_task = i
            task = j
            num_workers = 8
            num_epochs = 20
            batch_size = 32
            learning_rate = 2e-4
            stream = 'rgb'
            checkpoint = ''
            save_predictions = False
        args = Args()
        
        if args.stream =='rgb':
            args.checkpoint = f'record/c40/{i}/seed_2/best_val_checkpoint.pth'      
        elif args.stream == 'luminance':
            args.checkpoint = f'record/luminance/{i}/seed_1/best_val_checkpoint.pth'
        elif args.stream == 'sharpened':
            args.checkpoint = f'record/sharpened/{i}/seed_1/best_val_checkpoint.pth'

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        log_dir = ''

        train_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        data_dir = '/meladyfs/newyork/loctrinh/DATASETS/'
        frame_count = {f'{j}': pd.read_csv(os.path.join(data_dir, f'{j}', 'video_stat.csv'), index_col=0)}
        train_df = pd.read_csv(os.path.join(data_dir, '{}/splits/{}_trainlist_01.csv'.format(j, args.task)))
        val_df = pd.read_csv(os.path.join(data_dir, '{}/splits/{}_vallist_01.csv'.format(j, args.task)))
        test_df = pd.read_csv(os.path.join(data_dir, '{}/splits/{}_testlist_01.csv'.format(j, args.task)))
        
        if args.stream == 'rgb':
            data_loader = dataloader.StackC40ImageLoader(args.batch_size, args.num_workers, data_dir, frame_count,
                                                       train_df, val_df, test_df, train_transform, test_transform)
        elif args.stream == 'luminance':
            data_loader = dataloader.LuminanceGradientImageLoader(args.batch_size, args.num_workers, data_dir, frame_count,
                                                                  train_df, val_df, test_df, train_transform, test_transform)
        elif args.stream == 'sharpened':
            data_loader = dataloader.SharpenedImageLoader(args.batch_size, args.num_workers, data_dir, frame_count,
                                                                  train_df, val_df, test_df, train_transform, test_transform)
        
        train_loader, val_loader, test_loader, push_loader = data_loader.run()

        # =================== Training =================== 
        detector = DPNet(device=device,
                            log_dir=log_dir,
                            args=args,
                            train_loader=train_loader,
                            val_loader=val_loader,
                            test_loader=test_loader)
        detector.test()

FF++ to FF++
==> Training data: 972000 frames
==> Validation data: 70000 frames
==> Testing data: 70000 frames
==> Pushing data: 360000 frames
==> Loading pretrained model model/pretrained_models/hrnetv2_w48_imagenet_pretrained.pth


  0%|          | 0/2188 [00:00<?, ?it/s]

|--> [testing stage]


100%|██████████| 2188/2188 [12:53<00:00,  3.23it/s]


{'Epoch': [2], 'Time': [773.789], 'Loss': [0.37331], 'Acc': [83.0], 'AUC': [0.9091], 'pAUC_10': [0.8146], 'TAR_10': [0.7946], 'EER': [0.1821]}
FF++ to DFD
==> Training data: 749969 frames
==> Validation data: 24766 frames
==> Testing data: 33019 frames
==> Pushing data: 281940 frames
==> Loading pretrained model model/pretrained_models/hrnetv2_w48_imagenet_pretrained.pth


  0%|          | 0/1032 [00:00<?, ?it/s]

|--> [testing stage]


100%|██████████| 1032/1032 [06:14<00:00,  2.95it/s]


{'Epoch': [2], 'Time': [374.934], 'Loss': [0.75972], 'Acc': [64.759], 'AUC': [0.7061], 'pAUC_10': [0.5765], 'TAR_10': [0.2602], 'EER': [0.3606]}
FF++ to Celeb-DF
==> Training data: 1600307 frames
==> Validation data: 51740 frames
==> Testing data: 51740 frames
==> Pushing data: 601000 frames
==> Loading pretrained model model/pretrained_models/hrnetv2_w48_imagenet_pretrained.pth


  0%|          | 0/1617 [00:00<?, ?it/s]

|--> [testing stage]


100%|██████████| 1617/1617 [14:27<00:00,  2.60it/s]


{'Epoch': [2], 'Time': [867.727], 'Loss': [0.70246], 'Acc': [70.0772], 'AUC': [0.7176], 'pAUC_10': [0.5602], 'TAR_10': [0.2794], 'EER': [0.3265]}
FF++ to DeeperForensics
==> Training data: 2277684 frames
==> Validation data: 115200 frames
==> Testing data: 241200 frames
==> Pushing data: 843600 frames
==> Loading pretrained model model/pretrained_models/hrnetv2_w48_imagenet_pretrained.pth


  0%|          | 0/7538 [00:00<?, ?it/s]

|--> [testing stage]


100%|██████████| 7538/7538 [1:15:25<00:00,  2.41it/s]  


{'Epoch': [2], 'Time': [4525.527], 'Loss': [0.99517], 'Acc': [58.9967], 'AUC': [0.8452], 'pAUC_10': [0.7096], 'TAR_10': [0.6092], 'EER': [0.2402]}
