In [2]:
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from PIL import Image
import os
import shutil
from sklearn.metrics import accuracy_score
import random

In [4]:
# Custom Dataset Class
class DeepfakeDataset(Dataset):
    def __init__(self, image_dir, labels, transform):
        self.image_dir = image_dir
        self.labels = labels
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")
        label = self.labels[self.images[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label

In [5]:
# Load the pre-trained Google Vision Transformer
PRETRAINED_MODEL = "google/vit-base-patch16-224"

model = ViTForImageClassification.from_pretrained(PRETRAINED_MODEL)
processor = ViTImageProcessor.from_pretrained(PRETRAINED_MODEL)

In [6]:
# Create the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [9]:
# Prepare the Dataset
IMAGE_PATH = "/Users/Hung.Le/Downloads/df-training-images/"
AUTHENTIC = IMAGE_PATH + "source/real"
TAMPERED = IMAGE_PATH + "source/fake"
TRAIN_IMAGE_FOLDER = "./train_images"
TEST_IMAGE_FOLDER = "./test_images"
TRAIN_PERCENTAGE = 0.9
TEST_PERCENTAGE = 1 - TRAIN_PERCENTAGE

def create_dataloader(target_image_dir: str, percent: float):
    labels = {}
    if os.path.isdir(target_image_dir):
        shutil.rmtree(target_image_dir) 
    os.mkdir(target_image_dir)
    
    authentic_list = os.listdir(AUTHENTIC)
    authentic_size = int(len(authentic_list) * percent)
    selected_authentic_list = random.sample(authentic_list, k=authentic_size)
    for f in selected_authentic_list:
        labels[f] = 0
        shutil.copy(AUTHENTIC + "/" + f, target_image_dir)
    
    tampered_list = os.listdir(TAMPERED)
    tampered_size = int(len(tampered_list) * percent)
    selected_tampered_list = random.sample(tampered_list, k=tampered_size)
    for f in selected_tampered_list:
        labels[f] = 1
        shutil.copy(TAMPERED + "/" + f, target_image_dir)

    print("Number of images = " + str(len(os.listdir(target_image_dir))))
    
    dataset = DeepfakeDataset(image_dir=target_image_dir, labels=labels, transform=transform)
    return DataLoader(dataset, batch_size=32, shuffle=True)

In [10]:
# Train the dragon
data_loader = create_dataloader(TRAIN_IMAGE_FOLDER, TRAIN_PERCENTAGE)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

EPOCHS = 20

model.train()
for epoch in range(EPOCHS):
    for images, labels in data_loader:
        inputs = processor(images, return_tensors="pt", do_rescale=False, do_normalize=False)
        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch + 1}: Loss = {loss.item()}")

Number of images = 8550
Epoch 1: Loss = 0.36950746178627014
Epoch 2: Loss = 0.7061669230461121
Epoch 3: Loss = 0.150094673037529
Epoch 4: Loss = 0.5439589023590088
Epoch 5: Loss = 0.2041434794664383
Epoch 6: Loss = 0.0003474567783996463
Epoch 7: Loss = 0.5543035864830017
Epoch 8: Loss = 0.0010337498970329762
Epoch 9: Loss = 0.163868710398674
Epoch 10: Loss = 0.0037351239006966352
Epoch 11: Loss = 0.0014524428406730294
Epoch 12: Loss = 0.001491726259700954
Epoch 13: Loss = 0.1237998977303505
Epoch 14: Loss = 0.23559431731700897
Epoch 15: Loss = 0.09915665537118912
Epoch 16: Loss = 0.2214786559343338
Epoch 17: Loss = 0.015697797760367393
Epoch 18: Loss = 0.002015285659581423
Epoch 19: Loss = 0.008106004446744919
Epoch 20: Loss = 0.16620177030563354


In [11]:
# Test the dragon
data_loader = create_dataloader(TEST_IMAGE_FOLDER, TEST_PERCENTAGE)
model.eval()
predictions, true_labels = [], []

with torch.no_grad():
    for images, labels in data_loader:
        inputs = processor(images, return_tensors="pt")
        outputs = model(**inputs)
        preds = torch.argmax(outputs.logits, dim=-1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

accuracy = accuracy_score(true_labels, predictions)
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Number of images = 949
Validation Accuracy: 78.29%


In [None]:
torch.save(model, "vit_deepfake_detector_first_model.pth")