In [None]:

!pip install torch torchvision torchaudio tqdm face_recognition pandas timm opencv-python --quiet

import os
import cv2
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from tqdm import tqdm
from google.colab import drive


# Mount Google Drive so that CONFIG["checkpoint_dir"] lives on persistent storage.
# Test for a live mount point rather than mere existence: the directory can be left
# behind by a failed or interrupted mount, and skipping the mount in that case would
# make later writes under /content/drive silently go to the local runtime disk.
if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive', force_remount=True)

# Directory holding the split files (train.txt, val_clean.txt, val_adv.txt).
TXT_FOLDER_PATH = "/content"

print(f"checking text file: {TXT_FOLDER_PATH}")

_train_list_path = os.path.join(TXT_FOLDER_PATH, "train.txt")
if not os.path.isfile(_train_list_path):
    # Fail fast: training cannot proceed without the train split list.
    raise FileNotFoundError(f"Files not uploaded: expected {_train_list_path}")
print("found file")

# Hyper-parameters and paths shared by the rest of the notebook.
CONFIG = dict(
    gpu_id=0,               # CUDA device index used when a GPU is available
    batch_size=8,
    epochs=1,
    lr=1e-4,                # Adam learning rate
    epsilon=0.05,           # PGD L-inf perturbation bound, in [0, 1] pixel units
    alpha=0.01,             # PGD step size
    pgd_steps=1,
    adv_ratio=0.5,          # chance that a training batch is swapped for its PGD copy
    sequence_length=5,      # frames sampled per video clip
    im_size=299,
    # NOTE(review): w_clean / w_adv are not read by the training loop in this cell —
    # confirm whether a weighted clean/adversarial loss was intended.
    w_clean=0.6,
    w_adv=0.4,
    base_path=TXT_FOLDER_PATH,  # directory containing the split .txt files
    checkpoint_dir="/content/drive/MyDrive/csc490/code_and_datasets/checkpoints",
)
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

# Prefer the configured GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{CONFIG['gpu_id']}")
else:
    device = torch.device("cpu")


# Per-frame preprocessing: OpenCV/NumPy RGB frame -> PIL image -> square resize ->
# float tensor in [0, 1]. The side length comes from CONFIG so that it tracks the
# configured model input size instead of duplicating the constant.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((CONFIG["im_size"], CONFIG["im_size"])),
    transforms.ToTensor(),
])

class DeepfakeVideoDataset(Dataset):
    """Video clips listed one path per line in a split file under ``CONFIG['base_path']``.

    Each item is ``(clip, label)``. ``clip`` stacks ``sequence_length`` frames
    sampled at random (kept in temporal order) from the video and passed through
    ``transform``. ``label`` is 0 when the path contains "original"
    (case-insensitive) and 1 otherwise; presumably 0 = pristine and 1 = manipulated.

    Args:
        txt_filename: split file name, resolved against ``CONFIG['base_path']``.
            A missing file yields an empty dataset rather than an error.
        sequence_length: number of frames per clip. Videos with fewer decodable
            frames are padded by repeating their last decoded frame.
        transform: callable applied to each decoded RGB frame (a NumPy array from
            OpenCV). It must return a tensor so that frames can be stacked.
        limit: if truthy, keep only the first ``limit`` paths.
        im_size: side length of the all-zero fallback clip returned for unreadable
            videos. Defaults to ``CONFIG['im_size']``; it should match the
            transform's output size so that batches still collate.
    """

    def __init__(self, txt_filename, sequence_length, transform=None, limit=None, im_size=None):
        file_path = os.path.join(CONFIG['base_path'], txt_filename)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
            # A falsy limit (None or 0) means "no cap".
            self.video_paths = lines[:limit] if limit else lines
        else:
            # Optional splits (e.g. a validation list that was not uploaded) become empty.
            self.video_paths = []
        self.sequence_length = sequence_length
        self.transform = transform
        self.im_size = CONFIG['im_size'] if im_size is None else im_size

    def __len__(self):
        return len(self.video_paths)

    def get_label(self, path):
        """Return 0 if ``path`` contains "original" (case-insensitive), else 1."""
        return 0 if "original" in path.lower() else 1

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.get_label(video_path)
        frames = self._read_frames(video_path)
        if not frames:
            # Unreadable or empty video: an all-zero clip keeps default collation working.
            return torch.zeros((self.sequence_length, 3, self.im_size, self.im_size)), label
        # Pad short videos by repeating the last decoded frame.
        frames.extend([frames[-1]] * (self.sequence_length - len(frames)))
        return torch.stack(frames), label

    def _read_frames(self, video_path):
        """Decode up to ``sequence_length`` randomly chosen frames, in temporal order.

        Returns an empty list, after emitting a warning, when the video cannot be
        opened, reports no frames, or decoding or transforming raises. The capture
        is always released.
        """
        import warnings

        cap = None
        frames = []
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                warnings.warn(f"could not open video: {video_path}")
                return []
            cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if cnt <= 0:
                warnings.warn(f"video reports no frames: {video_path}")
                return []
            if cnt > self.sequence_length:
                wanted = set(random.sample(range(cnt), self.sequence_length))
            else:
                wanted = set(range(cnt))
            # Stop decoding once the last wanted frame has been read.
            for i in range(max(wanted, default=-1) + 1):
                ret, frame = cap.read()
                if not ret:
                    break
                if i in wanted:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    if self.transform:
                        frame = self.transform(frame)
                    frames.append(frame)
        except Exception as err:  # best effort: one bad video must not abort the epoch
            warnings.warn(f"failed to read video {video_path}: {err!r}")
            return []
        finally:
            if cap is not None:
                cap.release()
        return frames

class VideoXception(nn.Module):
    """Frame-level timm Xception classifier whose logits are averaged over each clip.

    ``forward`` takes ``x`` unpacked as ``(batch, seq, 3, H, W)`` with pixel values
    presumably in [0, 1] (as produced by ``transforms.ToTensor``). It returns
    ``(batch, 2)`` clip logits, the mean of the per-frame logits.
    """

    def __init__(self):
        super().__init__()
        # Local import: resolve_data_config is only needed here.
        from timm.data import resolve_data_config

        self.backbone = timm.create_model('xception', pretrained=True, num_classes=2)
        # Normalise with the statistics the pretrained weights expect, taken from the
        # backbone's own data config, instead of assuming ImageNet mean/std.
        # NOTE(review): for timm's Xception this is presumably (0.5, 0.5, 0.5) — verify.
        data_cfg = resolve_data_config({}, model=self.backbone)
        mean = torch.tensor(data_cfg['mean'], dtype=torch.float32).view(1, 3, 1, 1)
        std = torch.tensor(data_cfg['std'], dtype=torch.float32).view(1, 3, 1, 1)
        # Non-persistent buffers follow .to(device) but stay out of state_dict, so
        # checkpoint keys are unchanged.
        self.register_buffer('pixel_mean', mean, persistent=False)
        self.register_buffer('pixel_std', std, persistent=False)

    def forward(self, x):
        b, s, c, h, w = x.shape
        # Fold the time axis into the batch so that the 2-D backbone sees single frames.
        frames = x.reshape(b * s, c, h, w)
        frames = (frames - self.pixel_mean) / self.pixel_std
        logits = self.backbone(frames)
        return logits.view(b, s, -1).mean(dim=1)

def pgd_attack(model, x, y, eps, alpha, steps):
    """Return L-inf PGD adversarial examples around ``x`` (no random start).

    Takes ``steps`` signed-gradient ascent steps of size ``alpha`` on the
    cross-entropy loss. After each step the result is projected back onto the
    ``eps``-ball around ``x`` and onto the pixel range [0, 1]. The model runs in eval
    mode during the attack, and its previous train/eval mode is restored afterwards.
    Only the gradient with respect to the input is computed, so parameter ``.grad``
    buffers are left untouched.

    Args:
        model: classifier mapping ``x``-shaped input to logits.
        x: clean input batch (not modified).
        y: integer class targets for ``F.cross_entropy``.
        eps: maximum absolute per-element perturbation.
        alpha: step size per iteration.
        steps: number of iterations; 0 returns a detached copy of ``x``.

    Returns:
        A detached tensor with the same shape as ``x``.
    """
    was_training = model.training
    model.eval()
    x = x.detach()
    x_adv = x.clone()
    try:
        for _ in range(steps):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, x_adv)
            with torch.no_grad():
                x_adv = x_adv + alpha * grad.sign()
                # Project onto the eps-ball around x, then onto the valid pixel range.
                delta = torch.clamp(x_adv - x, -eps, eps)
                x_adv = torch.clamp(x + delta, 0, 1)
    finally:
        model.train(was_training)
    return x_adv.detach()

def main():
    """Train VideoXception on randomly PGD-perturbed batches, evaluate, and save it.

    Loads the train, clean-validation and adversarial-validation split lists, capped
    at 32 / 16 / 16 videos for a fast run. Trains for ``CONFIG['epochs']`` epochs;
    each batch is replaced by its PGD version with probability
    ``CONFIG['adv_ratio']``. Reports accuracy on both validation splits and writes
    the final ``state_dict`` into ``CONFIG['checkpoint_dir']``.
    """
    seq_len = CONFIG['sequence_length']
    train_ds = DeepfakeVideoDataset("train.txt", seq_len, train_transforms, limit=32)
    val_clean_ds = DeepfakeVideoDataset("val_clean.txt", seq_len, train_transforms, limit=16)
    val_adv_ds = DeepfakeVideoDataset("val_adv.txt", seq_len, train_transforms, limit=16)

    loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=2)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_clean_loader = DataLoader(val_clean_ds, shuffle=False, **loader_kwargs)
    val_adv_loader = DataLoader(val_adv_ds, shuffle=False, **loader_kwargs)

    model = VideoXception().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    criterion = nn.CrossEntropyLoss()

    def accuracy(correct, total):
        # Percentage; 0.0 instead of ZeroDivisionError when no samples were seen.
        return 100. * correct / total if total else 0.0

    def validate(loader):
        """Return top-1 accuracy (%) of ``model`` on ``loader``; 0.0 if it yields no samples."""
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                correct += model(x).argmax(1).eq(y).sum().item()
                total += x.size(0)
        return accuracy(correct, total)

    epochs = CONFIG['epochs']
    for epoch in range(epochs):
        print(f"\nSTARTING EPOCH {epoch + 1}/{epochs} (Fast Run)...")
        model.train()
        t_corr, t_total = 0, 0
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            # Per-batch coin flip: train on a PGD-perturbed copy instead of the clean clip.
            if random.random() < CONFIG['adv_ratio']:
                x = pgd_attack(model, x, y, CONFIG['epsilon'], CONFIG['alpha'], CONFIG['pgd_steps'])
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t_corr += out.argmax(1).eq(y).sum().item()
            t_total += x.size(0)
        print(f"\n Train Done Acc: {accuracy(t_corr, t_total):.2f}%")

    acc_clean = validate(val_clean_loader)
    acc_adv = validate(val_adv_loader)

    print("------------- FINAL RESULT -------------")
    print(f"Val Clean Acc : {acc_clean:.2f}%")
    print(f"Val Robust Acc: {acc_adv:.2f}%")
    print("----------------------------------------")

    # Save into the configured checkpoint directory instead of a bare relative path,
    # which would land in the working directory (typically ephemeral on Colab).
    save_path = os.path.join(CONFIG['checkpoint_dir'], "final_result_model.pth")
    torch.save(model.state_dict(), save_path)
    print(f"saved: {save_path}")

# Run the fast training pipeline when this cell or file is executed directly
# (an IPython kernel's namespace is also named "__main__").
if __name__ == "__main__":
    main()

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ 98.9/100.1 MB 41.5 MB/s eta 0:00:01
ERROR: Operation cancelled by user

ValueError: mount failed