# In [None]:  (notebook cell marker left over from the export)
import os
import cv2
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm  # Recommended for Xception
from tqdm import tqdm

# ====================================================
# 1. CONFIGURATION
# ====================================================
# Central run settings; every other section of the script reads from this dict.
CONFIG = {
    # --- hardware / data loading ---
    "gpu_id": 0,
    "num_workers": 4,

    # --- input & optimisation ---
    # 299 is the standard input size for Xception.
    "im_size": 299,
    # No attack is generated on the fly here, so the batch can often be
    # larger than in the Madry (adversarial-training) variant.
    "batch_size": 16,
    "epochs": 20,
    "lr": 1e-4,
    # How many frames are extracted from each video.
    "sequence_length": 5,

    # --- locations ---
    # Folder holding the split list files (txt).
    "base_path": "/content/drive/MyDrive/csc490/code_and_datasets/video_splits_output",
    "checkpoint_dir": "/content/drive/MyDrive/csc490/code_and_datasets/checkpoints",
}

# Use the configured CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{CONFIG['gpu_id']}")
else:
    device = torch.device("cpu")

# Make sure the checkpoint folder exists before training tries to write to it.
os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)

# ====================================================
# 2. DATASET & TRANSFORMS
# ====================================================

# Train and test use the same pipeline, mirroring the Madry setup for consistency.
# ToTensor leaves pixels in the [0, 1] range; mean/std normalization is applied
# separately, right before the model.
_FRAME_SIZE = (CONFIG['im_size'], CONFIG['im_size'])

train_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(_FRAME_SIZE),
        transforms.ToTensor(),
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(_FRAME_SIZE),
        transforms.ToTensor(),
    ]
)

class DeepfakeVideoDataset(Dataset):
    """Dataset that yields a fixed-length stack of frames per video.

    Every non-blank line of the list file is treated as one video path.
    An item is ``(frames, label)`` where ``frames`` stacks
    ``sequence_length`` sampled frames and ``label`` is the integer
    returned by :meth:`get_label` for that path.
    """

    def __init__(self, file_list_path, sequence_length, transform=None):
        """Read the list of video paths.

        Args:
            file_list_path: Text file with one video path per line; blank
                lines are skipped and surrounding whitespace is stripped.
            sequence_length: Number of frames returned for each video.
            transform: Optional callable applied to every sampled RGB frame.
                When omitted, frames are converted to float tensors in
                channel-first layout (see ``_prepare_frame``).
        """
        # ASSUMPTION: 'train.txt' contains Original, Deepfake, AND Adversarial video paths
        with open(file_list_path, 'r', encoding='utf-8') as f:
            self.video_paths = [line.strip() for line in f if line.strip()]

        self.sequence_length = sequence_length
        self.transform = transform

    def __len__(self):
        """Return the number of listed videos."""
        return len(self.video_paths)

    def get_label(self, path):
        """Derive the class label from the video path.

        Returns 0 (real) when the lower-cased path contains "original",
        otherwise 1 (fake: "deepfakes", "adversarial", "neuraltextures", ...).

        NOTE(review): the substring test covers the whole path, so a parent
        directory containing "original" would label every video as real.
        """
        if "original" in path.lower():
            return 0
        return 1

    def _sample_indices(self, frame_count):
        """Return the ascending frame positions to keep for one video.

        A random subset of ``sequence_length`` positions is drawn when the
        video is longer than that; otherwise every position is kept.
        """
        if frame_count > self.sequence_length:
            return sorted(random.sample(range(frame_count), self.sequence_length))
        return list(range(frame_count))

    def _prepare_frame(self, frame):
        """Apply the transform, or convert the raw frame to a tensor.

        Without a transform the H x W x C array is turned into a float
        C x H x W tensor scaled by 1/255, so ``torch.stack`` works either
        way. Assumes 8-bit frames from OpenCV - TODO confirm for other
        sources.
        """
        if self.transform is not None:
            return self.transform(frame)
        return torch.from_numpy(frame).permute(2, 0, 1).float().div(255.0).contiguous()

    def __getitem__(self, idx):
        """Load one video and return ``(frames, label)``.

        Frames are read sequentially and only the sampled positions are
        kept. Short videos are padded by repeating the last frame; a video
        with no readable frames yields an all-zero clip (with a warning).
        """
        video_path = self.video_paths[idx]
        label = self.get_label(video_path)

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Set membership keeps the per-frame check O(1).
            wanted = set(self._sample_indices(frame_count))

            for i in range(frame_count):
                # grab() advances the stream without handing back the image,
                # so frames that were not sampled are skipped cheaply.
                if not cap.grab():
                    break
                if i not in wanted:
                    continue
                ret, frame = cap.retrieve()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self._prepare_frame(frame))
                if len(frames) >= self.sequence_length:
                    break
        finally:
            # Release the capture even if decoding or the transform raises.
            cap.release()

        if not frames:
            # Keep the best-effort fallback (so one bad file does not abort
            # training) but make it visible instead of silent.
            import warnings

            warnings.warn(
                f"No frames could be read from {video_path!r}; "
                "returning an all-zero clip.",
                RuntimeWarning,
            )
            # The placeholder is sized from CONFIG['im_size'].
            return torch.zeros((self.sequence_length, 3, CONFIG['im_size'], CONFIG['im_size'])), label

        # Padding: repeat the last frame until the clip has the full length.
        while len(frames) < self.sequence_length:
            frames.append(frames[-1])

        return torch.stack(frames), label

# ====================================================
# 3. MODEL & NORMALIZATION HELPER
# ====================================================

# Manual Normalization (Same as Madry code for consistency)
def normalize_batch(imgs):
    """Standardize a batch of images channel by channel.

    Subtracts a fixed per-channel mean and divides by a fixed per-channel
    std. The constants are reshaped to (1, 3, 1, 1), so they broadcast
    against a 3-channel batch laid out as (N, 3, H, W), and are created on
    the same device as ``imgs``.

    NOTE(review): the values look like the commonly used ImageNet
    statistics - confirm they match what the pretrained weights expect.
    """
    broadcast_shape = (1, 3, 1, 1)
    channel_mean = torch.tensor(
        [0.485, 0.456, 0.406], device=imgs.device
    ).view(*broadcast_shape)
    channel_std = torch.tensor(
        [0.229, 0.224, 0.225], device=imgs.device
    ).view(*broadcast_shape)

    centered = imgs - channel_mean
    return centered / channel_std

def get_model():
    """Create the pretrained Xception classifier and place it on ``device``.

    The network is built through timm with a 2-class output head.
    """
    xception = timm.create_model(
        'xception',
        pretrained=True,
        num_classes=2,
    )
    xception = xception.to(device)
    return xception

# ====================================================
# 4. MAIN TRAINING LOOP (Standard Fine-tuning)
# ====================================================

def _flatten_clip_batch(inputs, labels):
    """Turn a batch of clips into per-frame samples on ``device``.

    ``inputs`` of shape (B, Seq, C, H, W) becomes (B*Seq, C, H, W) and each
    label is repeated Seq times so every frame carries its video's label.
    """
    b, s, c, h, w = inputs.shape
    frames = inputs.view(b * s, c, h, w).to(device)
    frame_labels = labels.repeat_interleave(s).to(device)
    return frames, frame_labels


def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass; return (mean frame loss, frame accuracy %)."""
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader):
        inputs, labels = _flatten_clip_batch(inputs, labels)

        # --- DIFFERENCE TO THE MADRY VARIANT ---
        # No PGD attack is generated here; adversarial examples are assumed
        # to be part of the data already loaded from disk.
        outputs = model(normalize_batch(inputs))
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # true per-frame average even when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return train_loss / total, 100. * correct / total


def _evaluate(model, loader):
    """Return the frame-level accuracy (%) of ``model`` on ``loader``."""
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = _flatten_clip_batch(inputs, labels)
            outputs = model(normalize_batch(inputs))
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    return 100. * val_correct / val_total


def main():
    """Fine-tune the model on the train split and keep the best checkpoint.

    Reads ``train.txt`` and ``val.txt`` from ``CONFIG['base_path']``, trains
    for ``CONFIG['epochs']`` epochs with Adam and cross-entropy, validates
    after every epoch, and saves the weights whenever validation accuracy
    improves.

    Raises:
        ValueError: If either split lists no videos.
    """
    # 1. Load Data
    # 'train.txt' MUST include the paths to the pre-generated adversarial videos
    train_dataset = DeepfakeVideoDataset(os.path.join(CONFIG['base_path'], "train.txt"), CONFIG['sequence_length'], train_transforms)
    val_dataset = DeepfakeVideoDataset(os.path.join(CONFIG['base_path'], "val.txt"), CONFIG['sequence_length'], test_transforms)

    # Fail fast: an empty split would otherwise only surface as a
    # division by zero after a full training epoch.
    for split_name, split in (("train", train_dataset), ("val", val_dataset)):
        if len(split) == 0:
            raise ValueError(f"'{split_name}.txt' lists no videos; cannot train or evaluate on an empty split.")

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'])
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'])

    print(f"Dataset Loaded: Train {len(train_dataset)} | Val {len(val_dataset)}")
    print("NOTE: Ensure 'train.txt' contains mixed data (Original + Deepfake + Adversarial)")

    # 2. Model & Optimizer
    model = get_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    save_path = os.path.join(CONFIG['checkpoint_dir'], "finetune_best_model.pth")

    # 3. Training Loop
    for epoch in range(1, CONFIG['epochs'] + 1):
        print(f"\nEpoch {epoch}/{CONFIG['epochs']} - Fine-tuning Started...")

        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, optimizer, criterion)
        print(f"Train Loss: {epoch_loss:.4f} | Train Acc (Mixed): {epoch_acc:.2f}%")

        # 4. Validation
        val_acc = _evaluate(model, val_loader)
        print(f"Val Acc: {val_acc:.2f}%")

        # 5. Save Checkpoint (only when validation accuracy strictly improves)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"Saved Best Model to {save_path}")

# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()