In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image
import cv2
from einops import repeat
from einops.layers.torch import Rearrange
import numpy as np

# =============================
# Dataset
# =============================
class VideoFaceDataset(Dataset):
    """Dataset of face videos stored as one directory of frame images per video.

    Videos are discovered under ``<root_dir>/image/<split>/<category>/``;
    every sub-directory found there is treated as one video. Each item is a
    clip of ``num_frames`` frames sampled evenly across the video, together
    with a class label (0 for the ``"real"`` category, 1 for anything else).

    Args:
        root_dir: Directory that contains the ``image`` folder.
        split: Name of the split sub-folder (e.g. ``"train"``).
        category: Name of the category sub-folder; also determines the label.
        transform: Optional callable applied to every frame (as a PIL image).
            Frames are stacked with ``torch.stack`` afterwards, so the
            transform has to return equally shaped tensors.
        num_frames: Number of frames sampled from every video.
    """

    def __init__(self, root_dir, split="train", category="real", transform=None, num_frames=32):
        self.video_root = os.path.join(root_dir, "image", split, category)
        self.transform = transform
        self.num_frames = num_frames
        self.category = category

        # Sorted so that the index -> video mapping is stable between runs.
        self.video_dirs = sorted(
            d for d in os.listdir(self.video_root)
            if os.path.isdir(os.path.join(self.video_root, d))
        )

        print(f"[{split}][{category}] Found {len(self.video_dirs)} videos")
        self.label_value = 0 if category == "real" else 1

    def __len__(self):
        """Return the number of video directories found."""
        return len(self.video_dirs)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            A ``(frames, label)`` pair where ``frames`` is the stack of the
            ``num_frames`` sampled (and transformed) frames and ``label`` is
            a scalar ``torch.long`` tensor.

        Raises:
            ValueError: If the video directory holds no ``.jpg``/``.png`` file.
            OSError: If a selected frame cannot be decoded by OpenCV.
        """
        video_dir = os.path.join(self.video_root, self.video_dirs[idx])
        frame_files = sorted(
            f for f in os.listdir(video_dir)
            if f.lower().endswith((".jpg", ".png"))
        )
        if not frame_files:
            # Without this guard linspace(0, -1, ...) would index into an
            # empty list and fail with an unhelpful IndexError.
            raise ValueError(f"No .jpg/.png frames found in video directory: {video_dir}")

        # Evenly spaced positions over the whole video. linspace always yields
        # exactly num_frames positions, so videos shorter than num_frames
        # simply repeat frames and no padding is needed afterwards.
        last_index = len(frame_files) - 1
        indices = torch.linspace(0, last_index, steps=self.num_frames).long().tolist()

        frames = []
        for frame_idx in indices:
            frame_path = os.path.join(video_dir, frame_files[frame_idx])
            frame = cv2.imread(frame_path)
            if frame is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError(f"Could not read frame image: {frame_path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
            frame = to_pil_image(frame)
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)

        frames = torch.stack(frames)  # (T, C, H, W)

        label_tensor = torch.tensor(self.label_value, dtype=torch.long)
        return frames, label_tensor

# =============================
# Model
# =============================
class PatchEmbedding(nn.Module):
    """Cut an image batch into square patches and embed each patch linearly."""

    def __init__(self, in_channels=3, patch_size=8, emb_size=128):
        super().__init__()
        # Length of one flattened patch: patch_size x patch_size pixels per channel.
        patch_dim = patch_size * patch_size * in_channels
        to_patches = Rearrange(
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=patch_size,
            p2=patch_size,
        )
        to_embedding = nn.Linear(patch_dim, emb_size)
        self.projection = nn.Sequential(to_patches, to_embedding)

    def forward(self, x):
        """Map a (b, c, H, W) batch to (b, num_patches, emb_size) embeddings."""
        patch_embeddings = self.projection(x)
        return patch_embeddings

class Attention(nn.Module):
    """Multi-head self-attention preceded by separate q/k/v linear projections."""

    def __init__(self, dim, n_heads, dropout=0.):
        super().__init__()
        self.att = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=n_heads,
            dropout=dropout,
        )
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over the tokens of ``x`` and return a tensor of the same layout."""
        # self.att is built without batch_first, so it works on
        # (seq, batch, dim); swap the first two axes on the way in and out.
        swap_batch_and_seq = (1, 0, 2)
        query, key, value = (
            projection(x).permute(*swap_batch_and_seq)
            for projection in (self.q, self.k, self.v)
        )
        attended, _weights = self.att(query, key, value)
        return attended.permute(*swap_batch_and_seq)

class PreNorm(nn.Module):
    """Apply LayerNorm to the input before handing it to the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        # Layers are created in execution order and handed to nn.Sequential.
        layers = [
            nn.Linear(dim, hidden_dim),  # expand to the hidden width
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to the model width
            nn.Dropout(dropout),
        ]
        super().__init__(*layers)

class ResidualAdd(nn.Module):
    """Wrap a callable with a skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Run the wrapped callable on ``x`` and add ``x`` back to its output."""
        return self.fn(x, **kwargs) + x

class ViT(nn.Module):
    """Frame-wise Vision Transformer classifier for video clips.

    Every frame of a clip is classified independently by a small ViT and the
    per-frame outputs are averaged over the time axis.

    Args:
        ch: Number of image channels.
        img_size: Side length used to size the positional embedding
            (``(img_size // patch_size) ** 2`` patches plus the class token).
        patch_size: Side length of one square patch.
        emb_dim: Token embedding width.
        n_layers: Number of attention + feed-forward layers.
        dropout: Dropout rate used inside attention and feed-forward blocks.
        heads: Number of attention heads.
        out_dim: Size of the output vector (number of classes).
    """

    def __init__(self, ch=3, img_size=144, patch_size=16, emb_dim=64,
                n_layers=4, dropout=0.1, heads=2, out_dim=2):
        super().__init__()
        self.patch_embedding = PatchEmbedding(ch, patch_size, emb_dim)
        num_patches = (img_size // patch_size) ** 2

        # One learned position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.rand(1, 1, emb_dim))

        self.layers = nn.ModuleList([
            nn.Sequential(
                ResidualAdd(PreNorm(emb_dim, Attention(emb_dim, heads, dropout))),
                ResidualAdd(PreNorm(emb_dim, FeedForward(emb_dim, emb_dim * 2, dropout)))
            ) for _ in range(n_layers)
        ])

        self.head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, out_dim),
        )

    def forward(self, imgs):
        """Classify a batch of clips.

        Args:
            imgs: Tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B, out_dim): per-frame outputs averaged over T.
        """
        B, T, C, H, W = imgs.shape
        # Fold time into the batch axis so each frame is processed on its own.
        # reshape (unlike view) also accepts non-contiguous input, e.g. a clip
        # that was permuted or sliced before being passed in.
        imgs = imgs.reshape(B * T, C, H, W)  # (B*T, C, H, W)

        x = self.patch_embedding(imgs)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        for layer in self.layers:
            x = layer(x)
        # Only the class token (position 0) feeds the classification head.
        out = self.head(x[:, 0, :])  # (B*T, out_dim)

        # Unfold time again and average the per-frame outputs.
        return out.reshape(B, T, -1).mean(dim=1)

# =============================
# Collate
# =============================
def custom_collate(batch):
    """Merge ``(frames, label)`` samples into batched frame and label tensors.

    Returns a ``(frames, labels)`` pair: the clips stacked along a new leading
    batch axis, and the labels gathered into a 1-D ``torch.long`` tensor.
    """
    clips = []
    targets = []
    for sample in batch:
        clips.append(sample[0])
        targets.append(sample[1])

    stacked_clips = torch.stack(clips, dim=0)  # (B, T, C, H, W)
    label_batch = torch.tensor(targets, dtype=torch.long)
    return stacked_clips, label_batch

# =============================
# Training
# =============================
# Every frame is resized to 144x144 (the ViT's default img_size) and turned
# into a tensor before the frames of a clip are stacked.
transform = T.Compose([
    T.Resize((144, 144)),
    T.ToTensor(),
])

root_dir = "D://projectdeepfake//thesis work"

# One dataset per (split, category) pair; "real" is labelled 0, "attack" 1.
train_real_dataset = VideoFaceDataset(root_dir, split="train", category="real", transform=transform)
train_attack_dataset = VideoFaceDataset(root_dir, split="train", category="attack", transform=transform)
devel_real_dataset = VideoFaceDataset(root_dir, split="devel", category="real", transform=transform)
devel_attack_dataset = VideoFaceDataset(root_dir, split="devel", category="attack", transform=transform)
test_real_dataset = VideoFaceDataset(root_dir, split="test", category="real", transform=transform)
test_attack_dataset = VideoFaceDataset(root_dir, split="test", category="attack", transform=transform)

# Merge both categories of each split into a single dataset.
train_dataset = ConcatDataset([train_real_dataset, train_attack_dataset])
devel_dataset = ConcatDataset([devel_real_dataset, devel_attack_dataset])
test_dataset = ConcatDataset([test_real_dataset, test_attack_dataset])

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=custom_collate)
devel_loader = DataLoader(devel_dataset, batch_size=2, shuffle=False, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=custom_collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(out_dim=2).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Loss scaling only applies to CUDA mixed precision. Disable it explicitly
# when running on CPU instead of relying on GradScaler warning and turning
# itself off.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

num_epochs = 3
# Mixed precision is a CUDA feature here. Entering the CUDA autocast context
# on a CPU-only machine only produces a warning on every iteration, so it is
# switched off explicitly in that case.
use_amp = device.type == "cuda"

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # Scale the loss before backward, then let the scaler run the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        torch.cuda.empty_cache()
    # Mean of the per-batch losses.
    avg_train_loss = running_loss / len(train_loader)

    # ---- Validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in devel_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(devel_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# =============================
# Testing
# =============================
# Evaluate the trained model on the held-out test split.
model.eval()
test_loss = 0.0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Autocast is only enabled on CUDA; on CPU the CUDA autocast context
        # would just emit a warning for every batch.
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        test_loss += loss.item()
# Mean of the per-batch losses.
avg_test_loss = test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

[train][real] Found 30 videos
[train][attack] Found 300 videos
[devel][real] Found 60 videos
[devel][attack] Found 300 videos
[test][real] Found 80 videos
[test][attack] Found 400 videos
Epoch 1/3, Train Loss: 0.3145, Val Loss: 0.6822
Epoch 2/3, Train Loss: 0.3194, Val Loss: 0.4713
Epoch 3/3, Train Loss: 0.2998, Val Loss: 0.4352
Test Loss: 0.4349
