In [1]:
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTImageProcessor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from PIL import Image
import os
import shutil
from sklearn.metrics import accuracy_score
import random

In [2]:
# Custom Dataset Class
class DeepfakeDataset(Dataset):
    def __init__(self, image_dir, labels, transform):
        self.image_dir = image_dir
        self.labels = labels
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")
        label = self.labels[self.images[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label

In [3]:
# Load the pre-trained Google Vision Transformer
PRETRAINED_MODEL = "google/vit-base-patch16-224"

model = ViTForImageClassification.from_pretrained(PRETRAINED_MODEL)
processor = ViTImageProcessor.from_pretrained(PRETRAINED_MODEL)

In [4]:
# Create the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [5]:
# Prepare the Dataset
IMAGE_PATH = "/Users/Hung.Le/Downloads/df-training-images/"
AUTHENTIC = IMAGE_PATH + "authentic"
TAMPERED = IMAGE_PATH + "tampered"
TRAIN_IMAGE_FOLDER = "./train_images"
TEST_IMAGE_FOLDER = "./test_images"
TRAIN_PERCENTAGE = 0.9
TEST_PERCENTAGE = 1 - TRAIN_PERCENTAGE

def create_dataloader(target_image_dir: str, percent: float):
    labels = {}
    if os.path.isdir(target_image_dir):
        shutil.rmtree(target_image_dir) 
    os.mkdir(target_image_dir)
    
    authentic_list = os.listdir(AUTHENTIC)
    authentic_size = int(len(authentic_list) * percent)
    selected_authentic_list = random.choices(authentic_list, k=authentic_size)
    for f in selected_authentic_list:
        labels[f] = 0
        shutil.copy(AUTHENTIC + "/" + f, target_image_dir)
    
    tampered_list = os.listdir(TAMPERED)
    tampered_size = int(len(tampered_list) * percent)
    selected_tampered_list = random.choices(tampered_list, k=tampered_size)
    for f in selected_tampered_list:
        labels[f] = 1
        shutil.copy(TAMPERED + "/" + f, target_image_dir)

    print(len(os.listdir(target_image_dir)))
    
    dataset = DeepfakeDataset(image_dir=target_image_dir, labels=labels, transform=transform)
    return DataLoader(dataset, batch_size=32, shuffle=True)

data_loader = create_dataloader(TRAIN_IMAGE_FOLDER, TRAIN_PERCENTAGE)
test_data_loader = create_dataloader(TEST_IMAGE_FOLDER, TEST_PERCENTAGE)

5605
895


In [6]:
# Train the dragon
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

EPOCHS = 10

model.train()
for epoch in range(EPOCHS):
    for images, labels in data_loader:
        inputs = processor(images, return_tensors="pt", do_rescale=False, do_normalize=False)
        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch + 1}: Loss = {loss.item()}")

Epoch 1: Loss = 0.5213854908943176
Epoch 2: Loss = 1.2809793949127197
Epoch 3: Loss = 0.2759154438972473
Epoch 4: Loss = 0.017742659896612167
Epoch 5: Loss = 0.29475897550582886
Epoch 6: Loss = 0.28638553619384766
Epoch 7: Loss = 0.07151441276073456
Epoch 8: Loss = 0.0024875965900719166
Epoch 9: Loss = 0.004380452446639538


KeyboardInterrupt: 

In [None]:
model.eval()
predictions, true_labels = [], []

with torch.no_grad():
    for images, labels in test_data_loader:
        inputs = processor(images, return_tensors="pt")
        outputs = model(**inputs)
        preds = torch.argmax(outputs.logits, dim=-1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

accuracy = accuracy_score(true_labels, predictions)
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

Validation Accuracy: 50.00%


In [None]:
torch.save(model, "vit_deepfake_detector_first_model.pth")