# Imports

In [2]:
import itertools
import os
import sys
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.special import expit

# Make the parent directory importable *before* the project-local imports
# below; appending it after them has no effect on imports that already ran.
sys.path.append('..')

from blazeface import FaceExtractor, BlazeFace, VideoReader
from architectures import fornet
from architectures.fornet import FeatureExtractor
from utils import utils
from utils.utils import get_transformer
from utils.utils import plot_confusion_matrix


# Configure the architecture, device, face policy, face size, frames per video, dataset and model path

In [3]:
net_choices = ['TimmV2', 'TimmV2ST', 'ViT', 'ViTST']
# Short checkpoint key -> architecture class name in `fornet`.
choices = {'v2': 'TimmV2', 'v2st': 'TimmV2ST', 'vit': 'ViT', 'vitst': 'ViTST'}
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
face_policy = 'scale'
face_size = 224
frames_per_video = 32

dataset = "ffpp"
net_name = net_choices[0]
net_class = getattr(fornet, net_name)
# Derive the checkpoint key from the selected architecture instead of
# hard-coding "v2", so switching net_name loads matching weights
# (TimmV2 -> ../models/ffpp_v2.pth, exactly as before).
# NOTE(review): assumes every checkpoint is named <dataset>_<key>.pth — confirm.
net_key = {name: key for key, name in choices.items()}[net_name]
model_path = "../models/" + dataset + "_" + net_key + ".pth"


# Provide path to video files

In [4]:
# Folder holding the videos to classify.
input_dir = '../sample_videos/ffpp/real/'
video_paths = glob(os.path.join(input_dir, '**', '*.mp4'), recursive=True)
# Names relative to input_dir: unlike split("/")[4] this is independent of the
# directory depth and path separator, and keeps sub-folders for recursive
# matches so `input_dir + name` still points at the file. Sorted so that
# video_idxs refers to a stable ordering.
file_names = sorted(os.path.relpath(p, input_dir) for p in video_paths)

# Indices into file_names of the videos to process.
video_idxs = [1, 3]
assert file_names and max(video_idxs) < len(file_names), \
    f"video_idxs {video_idxs} out of range for {len(file_names)} videos in {input_dir}"


# Load weights

In [5]:
# Build the selected architecture in inference mode, then move it to the device.
net: FeatureExtractor = net_class().eval()
net = net.to(device)
# Restore trained weights stored under the checkpoint's 'net' key; mapping to
# the CPU first lets the file be read even when it was saved from a GPU.
net.load_state_dict(torch.load(model_path, map_location='cpu')['net'])


<All keys matched successfully>

# Load face extractor

In [6]:
# Evaluation-time pre-processing for face crops, using the network's own normalizer.
transf = get_transformer(face_policy, face_size,
                         net.get_normalizer(), train=False)

# Reader used to pull frames out of each video file.
videoreader = VideoReader(verbose=False)

# BlazeFace detector on the same device as the network, with its weights and anchors.
facedet = BlazeFace().to(device)
facedet.load_weights("blazeface/blazeface.pth")
facedet.load_anchors("blazeface/anchors.npy")


def video_read_fn(x):
    """Ask the shared `videoreader` for `frames_per_video` frames of the video at path `x`."""
    return videoreader.read_frames(x, num_frames=frames_per_video)


face_extractor = FaceExtractor(facedet=facedet, video_read_fn=video_read_fn)


# Extract faces

In [7]:
faces = face_extractor.process_videos(
    input_dir=input_dir, filenames=file_names, video_idxs=video_idxs)
total_videos = len(video_idxs)

# Row offsets of each video's block in faces_hc, e.g. [0, 32, 64] for two
# videos at 32 frames each. Only valid if every video yields exactly
# frames_per_video frames with a detected face (face-less frames are dropped below).
faces_frames = [k * frames_per_video for k in range(total_videos + 1)]

# First face of every frame that has one, transformed and stacked into a batch.
# NOTE(review): "_hc" presumably means highest-confidence face — confirm ordering.
faces_hc = torch.stack([
    transf(image=frame_info['faces'][0])['image']
    for frame_info in faces
    if len(frame_info['faces'])
])


# Make predictions

In [8]:
# Videos whose mean score is below this value are labelled 'real'.
real_threshold = 0.1
# Ground truth is the name of the folder the videos come from
# ('../sample_videos/ffpp/real/' -> 'real').
true_label = os.path.basename(os.path.normpath(input_dir))
# Source video of every row of faces_hc, built with the same per-frame filter
# that built faces_hc, so the two stay aligned. Grouping by video (instead of
# slicing fixed blocks of frames_per_video rows) keeps each score attached to
# the right file when a frame has no detected face or a video is skipped.
# NOTE(review): relies on each frame dict from FaceExtractor.process_videos
# carrying 'video_idx' (an index into file_names) — confirm.
row_video_idxs = [frame['video_idx'] for frame in faces if len(frame['faces'])]

predictions = {}
with torch.no_grad():
    for video_idx in video_idxs:
        video_path = input_dir + file_names[video_idx]
        rows = [row for row, idx in enumerate(row_video_idxs)
                if idx == video_idx]
        if not rows:
            # Scoring an empty batch would give a NaN mean; report and skip.
            print(f'No faces extracted from {video_path}; skipping it.')
            continue
        pred = net(faces_hc[rows].to(device)).cpu().numpy().flatten()
        # Average the per-frame outputs and map through the logistic sigmoid.
        score = float(expit(pred.mean()))
        predictions[video_path] = [round(score, 3), {
            'predicted_class': 'real' if score < real_threshold else 'fake',
            'true_class': true_label}]


In [9]:
# Per-video results keyed by video path: [score, {'predicted_class', 'true_class'}].
predictions


{'../sample_videos/ffpp/real/091.mp4': [0.006,
  {'predicted_class': 'real', 'true_class': 'real'}],
 '../sample_videos/ffpp/real/250.mp4': [0.002,
  {'predicted_class': 'real', 'true_class': 'real'}]}

# Analysis

In [10]:
# One [predicted_class, true_class] pair per video, in predictions order.
res = [[entry[1]['predicted_class'], entry[1]['true_class']]
       for entry in predictions.values()]
# The same labels split into separate predicted / true lists.
pclass = [pair[0] for pair in res]
tclass = [pair[1] for pair in res]


In [11]:
# Numeric id of each class; also the row/column order of the confusion matrix.
class_to_id = {'real': 0, 'fake': 1}
# A single mapping replaces the two copy-pasted loops. Unknown labels now raise
# KeyError here instead of silently reaching torch.Tensor as strings.
pclass = torch.Tensor([class_to_id[label] for label in pclass])
tclass = torch.Tensor([class_to_id[label] for label in tclass])


In [12]:
# Pair labels row-wise: column 0 = true class id, column 1 = predicted class id.
stacked = torch.stack([tclass, pclass], dim=-1)


In [None]:
# 2x2 confusion matrix: rows index the true class, columns the predicted class.
cmt = torch.zeros(2, 2, dtype=torch.int64)
for true_id, pred_id in stacked.tolist():
    cmt[int(true_id), int(pred_id)] += 1
cmt = cmt.detach().cpu().numpy()
cmt


In [None]:
# Class names for the confusion-matrix axes, in label-id order (real = 0, fake = 1).
names = ('real', 'fake')
plt.figure(figsize=(4, 4))  # new 4x4-inch figure for the plot
plot_confusion_matrix(cmt, names)
