In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(2022)

<torch._C.Generator at 0x7f6fb8075330>

In [2]:
from functools import partial
import numpy as np
import torch
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
from torch import nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
import torch.nn.functional as F 

def _encoder_spec(factory, features, drop_path_rate, pretrained=True):
    """Build one ``encoder_params`` entry.

    ``features`` is stored as-is under the "features" key; ``init_op`` is the
    given factory with ``pretrained`` and ``drop_path_rate`` pre-bound, so
    calling it with no arguments constructs the encoder.
    """
    return {
        "features": features,
        "init_op": partial(factory, pretrained=pretrained, drop_path_rate=drop_path_rate),
    }


# Registry of available encoders, keyed by configuration name.
# The "_03d" / "_04d" variants reuse the same factory with a different drop_path_rate.
# Note that the b2 entry is the only one created with pretrained=False.
encoder_params = {
    "tf_efficientnet_b3_ns": _encoder_spec(tf_efficientnet_b3_ns, 1536, 0.2),
    "tf_efficientnet_b2_ns": _encoder_spec(tf_efficientnet_b2_ns, 1408, 0.2, pretrained=False),
    "tf_efficientnet_b4_ns": _encoder_spec(tf_efficientnet_b4_ns, 1792, 0.5),
    "tf_efficientnet_b5_ns": _encoder_spec(tf_efficientnet_b5_ns, 2048, 0.2),
    "tf_efficientnet_b4_ns_03d": _encoder_spec(tf_efficientnet_b4_ns, 1792, 0.3),
    "tf_efficientnet_b5_ns_03d": _encoder_spec(tf_efficientnet_b5_ns, 2048, 0.3),
    "tf_efficientnet_b5_ns_04d": _encoder_spec(tf_efficientnet_b5_ns, 2048, 0.4),
    "tf_efficientnet_b6_ns": _encoder_spec(tf_efficientnet_b6_ns, 2304, 0.2),
    "tf_efficientnet_b7_ns": _encoder_spec(tf_efficientnet_b7_ns, 2560, 0.2),
    "tf_efficientnet_b6_ns_04d": _encoder_spec(tf_efficientnet_b6_ns, 2304, 0.4),
}

class DeepFakeClassifier(nn.Module):
    """Single-output classifier: an encoder followed by pooling, dropout and a linear head.

    The encoder is looked up by name in the module-level ``encoder_params``
    registry, which supplies both the constructor (``init_op``) and the number
    of feature channels the encoder produces (``features``).

    Args:
        encoder: Key into ``encoder_params`` selecting the backbone.
            An unknown key raises ``KeyError``.
        dropout_rate: Dropout probability applied to the pooled features
            before the linear head.
    """

    # The constructor must be spelled ``__init__`` (double underscores); with a
    # single underscore Python never calls it and no sub-module is created.
    def __init__(self, encoder, dropout_rate=0.0) -> None:
        super().__init__()
        params = encoder_params[encoder]
        self.encoder = params["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc1 = Linear(params["features"], 1)
        # fc2 is not used by forward(). It is kept so the module's attribute
        # and parameter set is unchanged -- presumably existing checkpoints
        # include it (TODO confirm before removing).
        self.fc2 = Linear(120, 1)

    def forward(self, x):
        """Return one raw (pre-sigmoid) score per input sample, shape (batch, 1)."""
        x = self.encoder.forward_features(x)
        # Global average pool to 1x1, then drop the spatial dims -> (batch, features).
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc1(x)
        return x

In [3]:
def test(model_ft, data_path, batch_size=10, num_workers=1):
    """Evaluate a single-output binary classifier on an ``ImageFolder`` dataset.

    Every image under ``data_path`` is resized to 299x299, converted to a
    tensor and normalised with mean [0.485, 0.456, 0.406] and std
    [0.229, 0.224, 0.225]. The model is moved to the notebook-level ``device``,
    put in eval mode and run without gradients. Its output is passed through a
    sigmoid to obtain a score per image; a score above 0.5 counts as class 1.

    Args:
        model_ft: Model whose output has shape (batch, 1).
        data_path: Root directory in the layout ``datasets.ImageFolder``
            expects (one sub-directory per class).
        batch_size: Number of images per batch (default 10).
        num_workers: Number of DataLoader worker processes (default 1).

    Returns:
        Tuple ``(accuracy, recall, precision, auc)`` of Python floats.
        ``auc`` is the area under the ROC curve (FPR on the x-axis, TPR on the
        y-axis). It is reported as computed: a value below 0.5 means the
        scores are inversely related to the labels and is not folded back
        above 0.5.
    """
    import torch
    from torch.utils.data import DataLoader
    from torchmetrics.functional import precision_recall, auc, roc
    from torchvision import datasets, transforms
    from tqdm import tqdm

    preprocess = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    all_data = datasets.ImageFolder(root=data_path,
                                    transform=preprocess,
                                    target_transform=None)

    test_dataloader = DataLoader(dataset=all_data,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)

    model_ft.to(device)
    model_ft.eval()

    # Collect per-batch results and concatenate once at the end; stacking onto
    # a growing tensor inside the loop would re-copy all earlier results on
    # every iteration.
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_dataloader):
            inputs = inputs.to(device, dtype=torch.float32)

            scores = torch.sigmoid(model_ft(inputs)).squeeze(1)

            pred_batches.append(scores.cpu())
            # Labels are only needed for the metrics, so they never go to the device.
            label_batches.append(labels.cpu())

    list_preds = torch.cat(pred_batches)
    list_labels = torch.cat(label_batches).long()

    precision, recall = precision_recall(list_preds, list_labels, average='macro', num_classes=1)
    fpr, tpr, _ = roc(list_preds, list_labels)

    # auc(x, y): the ROC curve has FPR on the x-axis and TPR on the y-axis.
    # Passing them the other way round yields 1 - AUROC.
    AUC = auc(fpr, tpr, reorder=True)

    # Threshold the scores only after the score-based metrics are computed.
    hard_preds = (list_preds > 0.5).long()
    acc = torch.sum(hard_preds == list_labels) / len(all_data)

    return acc.item(), recall.item(), precision.item(), AUC.item()

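Set the two paths below before running: `PATH_TO_MODEL` is the saved model checkpoint, and `PATH_TO_DATA` is the root directory of the evaluation images (one sub-folder per class).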

In [4]:
# Checkpoint file that the next cell passes to torch.load().
PATH_TO_MODEL = "./dfdc_t6.pt"
# Root directory of the evaluation images; passed to test() as data_path.
PATH_TO_DATA = "./data"

In [5]:
# NOTE(security): torch.load() unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# map_location places the stored tensors on the device chosen in the first
# cell, so a checkpoint saved from a GPU still loads on a CPU-only machine.
model = torch.load(PATH_TO_MODEL, map_location=device)
acc, recall, precision, auc = test(model, PATH_TO_DATA)
print(acc, recall, precision, auc)

100%|███████████████████████████████████████| 1200/1200 [02:33<00:00,  7.81it/s]


0.996666669845581 0.9975000023841858 0.9925373196601868 0.9999079062181409
