In [1]:
import torch
import random
import os
import cv2
import numpy as np

import timm

from albumentations import Compose, RandomBrightnessContrast, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, DualTransform

from tqdm import tqdm

IMAGE_SIZE = 224
random.seed(42)
BATCH_SIZE = 64

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down, interpolation_up):
        super(IsotropicResize, self).__init__(False, 1)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):

        h, w = img.shape[:2]

        if max(w, h) == self.max_side:
            return img
        if w > h:
            scale = self.max_side / w
            h = h * scale
            w = self.max_side
        else:
            scale = self.max_side / h
            w = w * scale
            h = self.max_side
        interpolation = interpolation_up if scale > 1 else interpolation_down

        img = img.astype('uint8')
        resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
        return resized

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")

In [3]:
# Define dataloader
class DeepFakesDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, labels, image_size, mode='train'):
        self.image_paths = image_paths
        self.labels = torch.from_numpy(labels)
        self.image_size = image_size
        self.mode = mode
        self.n_samples = len(image_paths)
    
    def create_train_transforms(self, size):
        return Compose([
            ImageCompression(quality_lower=60, quality_upper=100, p=0.2),
            GaussNoise(p=0.3),
            HorizontalFlip(),
            OneOf([
                IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
                IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
                IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
            ], p=1),
            PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
            OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.4),
            ToGray(p=0.2),
            ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=5, border_mode=cv2.BORDER_CONSTANT, p=0.5),
        ]
        )
        
    def create_val_transform(self, size):
        return Compose([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
        ])

    def __getitem__(self, index):
        image = np.asarray(cv2.imread(self.image_paths[index]))
        
        if self.mode == 'train':
            transform = self.create_train_transforms(self.image_size)
        else:
            transform = self.create_val_transform(self.image_size)
        
        image = transform(image=image)['image']
        
        return torch.tensor(image).float(), self.labels[index]

    def __len__(self):
        return self.n_samples

In [4]:
def load_data(txt_path):
    dataset = []
    labels = []

    with open(txt_path, 'r') as f:
        for line in f:
            line = line.rstrip()
            words = line.split()
            dataset.append(words[0])
            labels.append(int(words[1]))

    return dataset, labels

In [6]:
# The evaluate function to calculate the correct predictions, positive class and negative class
def evaluate(preds, labels):
    rounded_preds = preds.round()

    correct = sum(pred == label for pred, label in zip(rounded_preds, labels)).item()
    positive_class = int(sum(rounded_preds).item())
    negative_class = (len(rounded_preds) - positive_class)
    
    return correct, positive_class, negative_class

In [7]:
# input the path of the model and the path of the txt file
def test(model_path, data_path):
    dataset, labels = load_data(data_path)

    # set the device to cuda if available to speed up training
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(torch.cuda.get_device_name())
    else:
        device = torch.device('cpu')

    # load the data loader
    dataset = DeepFakesDataset(dataset, np.asarray(labels), IMAGE_SIZE)
    dl = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    del dataset

    model = timm.create_model('tf_efficientnetv2_m', pretrained=True, num_classes=1)

    if torch.cuda.is_available():
        model.load_state_dict(torch.load(model_path))
    else:
        model.load_state_dict(torch.load(model_path), map_location=torch.device('cpu'))

    model = model.to(device)
    all_preds = torch.Tensor()
    all_labels = torch.Tensor()

    model.eval()
    for images, labels in tqdm(dl, desc='Testing'):
        images = np.transpose(images, (0, 3, 1, 2))
        images = images.to(device)

        with torch.no_grad():
            y_pred = model(images)
        
        y_pred = y_pred.cpu()

        all_preds = torch.cat((all_preds, torch.sigmoid(y_pred)))
        all_labels = torch.cat((all_labels, labels))

    corrects, positive, negative = evaluate(all_preds, all_labels)
    accuracy = corrects / len(all_labels)
    return accuracy, positive / len(all_labels), negative / len(all_labels)

In [8]:
# Please replace 'project_data.txt' with the path to the txt file
accuracy, positivity, negativity = test('model_33_10', 'project_data.txt')
accuracy, positivity, negativity

NVIDIA GeForce RTX 2080 Ti


Testing: 100%|██████████| 862/862 [09:21<00:00,  1.53it/s]


(0.9690039175856066, 0.5128228380731282, 0.4871771619268717)