In [None]:
!pip install --upgrade --quiet transformers

In [1]:
import transformers
print(transformers.__version__)

4.52.4


In [2]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
import kagglehub

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import kagglehub
path = kagglehub.dataset_download("hungle3401/faceforensics")
print("Dataset path:", path)

Dataset path: /kaggle/input/faceforensics


In [5]:
#extracting images from the video

import os
import cv2
from tqdm import tqdm

real_video_dir = os.path.join(path, "FF++", "real")
fake_video_dir = os.path.join(path, "FF++", "fake")
output_real = '/kaggle/working/frames/real'
output_fake = '/kaggle/working/frames/fake'

def extract_frames_from_videos(video_dir, output_dir, label, frames_per_video=5):
    os.makedirs(output_dir, exist_ok=True)
    video_files = [f for f in os.listdir(video_dir) if f.endswith('.mp4')]
    for video_file in tqdm(video_files, desc=f"Extracting {label}"):
        video_path = os.path.join(video_dir, video_file)
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(1, frame_count // frames_per_video)
        for i in range(frames_per_video):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if ret:
                img_name = f"{label}_{os.path.splitext(video_file)[0]}_frame{i}.jpg"
                img_path = os.path.join(output_dir, img_name)
                cv2.imwrite(img_path, frame)
        cap.release()

extract_frames_from_videos(real_video_dir, output_real, 'real')
extract_frames_from_videos(fake_video_dir, output_fake, 'fake')

Extracting real: 100%|██████████| 200/200 [06:04<00:00,  1.82s/it]
Extracting fake: 100%|██████████| 200/200 [06:06<00:00,  1.83s/it]


In [6]:
from transformers import ViTForImageClassification, ViTImageProcessor, TrainingArguments, Trainer
from torch.utils.data import Dataset, random_split
from PIL import Image
import torch
import numpy as np

class FaceForensicsDataset(Dataset):
    def __init__(self, root_dir, processor):
        self.samples = []
        self.processor = processor
        for label, subfolder in enumerate(['real', 'fake']):
            folder = os.path.join(root_dir, subfolder)
            for fname in os.listdir(folder):
                if fname.endswith('.jpg'):
                    self.samples.append((os.path.join(folder, fname), label))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        processed = self.processor(images=image, return_tensors="pt")
        item = {k: v.squeeze(0) for k, v in processed.items()}
        item['labels'] = torch.tensor(label)
        return item

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
dataset = FaceForensicsDataset('/kaggle/working/frames', processor)

# Split into train/val
val_pct = 0.2
val_size = int(len(dataset) * val_pct)
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

In [10]:
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=2,
    id2label={0: 'real', 1: 'fake'},
    label2id={'real': 0, 'fake': 1}
)

from sklearn.metrics import accuracy_score, f1_score

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
        'labels': torch.tensor([item['labels'] for item in batch])
    }

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds)
    }

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./vit-ff',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    learning_rate=2e-5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.6179,0.57429,0.735,0.765487
2,0.4758,0.522705,0.7575,0.789588
3,0.479,0.514581,0.7425,0.770601
4,0.3746,0.519964,0.725,0.746544
5,0.4176,0.526852,0.725,0.74537




TrainOutput(global_step=500, training_loss=0.4831079092025757, metrics={'train_runtime': 480.7606, 'train_samples_per_second': 16.64, 'train_steps_per_second': 1.04, 'total_flos': 6.19935916916736e+17, 'train_loss': 0.4831079092025757, 'epoch': 5.0})

In [11]:
eval_results = trainer.evaluate()
print("Validation results:", eval_results)

def predict_vit(image_path, model, processor, device='cuda'):
    image = Image.open(image_path).convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    model = model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        pred = logits.argmax(dim=1).item()
    return 'real' if pred == 0 else 'fake'



Validation results: {'eval_loss': 0.5227053761482239, 'eval_accuracy': 0.7575, 'eval_f1': 0.7895878524945771, 'eval_runtime': 14.169, 'eval_samples_per_second': 28.231, 'eval_steps_per_second': 1.764, 'epoch': 5.0}
