In [None]:
!pip install torch torchvision torchaudio --quiet
!pip install tqdm --quiet

In [None]:
!pip3 install face_recognition

# Mount Google Drive at /content/drive; the split lists and the checkpoint
# referenced further down are addressed through paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

import os
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, roc_curve, auc

# --- Pick the compute device: CUDA when available, CPU otherwise ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Enter the splits directory so the list files resolve by bare name ---
os.chdir("/content/drive/MyDrive/csc490/code_and_datasets/video_splits_output")
train_files, val_files, test_files = (
    read_list(list_name) for list_name in ("train.txt", "val.txt", "test.txt")
)

# Report how many entries each split list produced.
print("Loaded:")
for split_title, split_files in (
    ("Train:", train_files),
    ("Val:", val_files),
    ("Test:", test_files),
):
    print(split_title, len(split_files))

def _build_label_frame(video_paths):
    """Return a DataFrame with each path's base file name and its assigned label."""
    return pd.DataFrame({
        "file": list(map(os.path.basename, video_paths)),
        "label": list(map(assign_label, video_paths)),
    })


train_labels = _build_label_frame(train_files)
val_labels = _build_label_frame(val_files)
test_labels = _build_label_frame(test_files)

# Per-split class balance: label 0 is printed as "Real", label 1 as "Fake".
for split_title, label_frame in (
    ("Train Real:", train_labels),
    ("Val   Real:", val_labels),
    ("Test  Real:", test_labels),
):
    real_count = sum(label_frame.label == 0)
    fake_count = sum(label_frame.label == 1)
    print(split_title, real_count, "Fake:", fake_count)

# --- Frame preprocessing ---
im_size = 112
# Per-channel statistics handed to Normalize (the values match the widely
# used ImageNet mean/std).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _make_frame_pipeline():
    """Build the PIL -> resize -> tensor -> normalize pipeline applied to frames."""
    stages = [
        transforms.ToPILImage(),
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(stages)


# Training and evaluation use the same steps (no augmentation); each gets
# its own Compose instance.
train_transforms = _make_frame_pipeline()
test_transforms = _make_frame_pipeline()

# --- Dataset objects: one per split, all built with sequence_length=10 ---
_dataset_options = {"sequence_length": 10}

train_data = video_dataset(
    video_names=train_files, labels=train_labels,
    transform=train_transforms, **_dataset_options,
)
# Validation and test both go through the evaluation transform pipeline.
val_data = video_dataset(
    video_names=val_files, labels=val_labels,
    transform=test_transforms, **_dataset_options,
)
test_data = video_dataset(
    video_names=test_files, labels=test_labels,
    transform=test_transforms, **_dataset_options,
)

# --- Batch iterators: only the training loader shuffles ---
_loader_options = {"batch_size": 4, "num_workers": 4}
train_loader = DataLoader(train_data, shuffle=True, **_loader_options)
valid_loader = DataLoader(val_data, shuffle=False, **_loader_options)
test_loader = DataLoader(test_data, shuffle=False, **_loader_options)

# --- PGD attack function ---
def pgd_attack(model, x, target_label, epsilon=0.05, alpha=0.01, steps=5,
               clip_min=0.0, clip_max=1.0):
    """Craft targeted PGD adversarial examples within an L-infinity ball.

    Each step lowers the cross-entropy between the model's logits and
    ``target_label`` by moving against the sign of the input gradient, then
    projects back into ``[x - epsilon, x + epsilon]`` and finally clamps to
    ``[clip_min, clip_max]``.

    Args:
        model: Module whose forward pass returns a pair; the second element
            is used as the logits.
        x: Clean input batch. It is not modified.
        target_label: Class indices the attack pushes the predictions toward.
        epsilon: Maximum per-element deviation from ``x``.
        alpha: Step size of each signed-gradient update.
        steps: Number of PGD iterations; ``0`` returns a detached copy of ``x``.
        clip_min: Lower bound of the valid input range, or ``None`` for no
            lower bound. The default matches the original hard-coded 0.
        clip_max: Upper bound of the valid input range, or ``None`` for no
            upper bound. The default matches the original hard-coded 1.
            Inputs that do not live in [0, 1] (for example mean/std
            normalized images) need matching bounds, otherwise the clamp
            distorts them. Pass both bounds as numbers or both as tensors
            broadcastable to ``x``.

    Returns:
        A detached tensor shaped like ``x`` holding the adversarial inputs.
    """
    # Remember the caller's mode so it is restored instead of forcing
    # train mode on a model that was being evaluated.
    was_training = model.training
    model.eval()

    x_orig = x.detach()
    x_adv = x_orig.clone()
    try:
        for _ in range(steps):
            x_adv.requires_grad_(True)
            _, logits = model(x_adv)
            loss = F.cross_entropy(logits, target_label)
            # Differentiate w.r.t. the input only, so the model parameters'
            # .grad buffers are neither cleared nor filled by the attack.
            (grad,) = torch.autograd.grad(loss, x_adv)
            with torch.no_grad():
                # Targeted attack: move toward target_label
                x_adv = x_adv - alpha * grad.sign()
                x_adv = torch.min(torch.max(x_adv, x_orig - epsilon), x_orig + epsilon)
                if clip_min is not None or clip_max is not None:
                    x_adv = x_adv.clamp(min=clip_min, max=clip_max)
    finally:
        model.train(was_training)
    return x_adv.detach()

# --- Hyperparameters, model, optimizer and loss ---
num_epochs = 20
lr = 1e-5

model = Model(2)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Per-epoch history; the "test_*" lists are filled from the validation pass
# inside the training loop.
train_loss_avg, train_accuracy = [], []
test_loss_avg, test_accuracy = [], []

start_epoch = 1
best_acc = 0.0

# --- ADVERSARIAL TRAINING LOOP (Madry style) ---
# Every batch: fake samples (label 1) are replaced by PGD examples targeted
# at the real class (label 0), then the model is updated on the mixed batch.
# NOTE(review): pgd_attack is called with its [0, 1] clamp while the
# transforms above apply mean/std normalization, so attacked fake clips
# probably end up in a different value range than untouched real clips --
# confirm against video_dataset's output range.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, correct, total = 0, 0, 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Work on a detached copy so the clean batch is left untouched.
        inputs_adv = inputs.detach().clone()

        # Only the fake samples are attacked; their target is all zeros (REAL).
        fake_mask = labels.eq(1)
        if fake_mask.any():
            targets_fake = torch.zeros_like(labels)[fake_mask]
            inputs_adv[fake_mask] = pgd_attack(
                model,
                inputs_adv[fake_mask],
                targets_fake,
                epsilon=0.05,
                alpha=0.01,
                steps=5,
            )

        # Standard supervised step on the (partly adversarial) batch,
        # scored against the true labels.
        _, logits = model(inputs_adv)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a sample-weighted loss and the number of correct predictions.
        running_loss += loss.item() * inputs.size(0)
        preds = logits.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    train_loss_avg.append(running_loss / total)
    train_accuracy.append(correct / total)

    # --- Validation pass for this epoch ---
    true, pred, probs, val_loss, val_acc = test(epoch, model, valid_loader, criterion)
    test_loss_avg.append(val_loss)
    test_accuracy.append(val_acc)

    # --- Checkpoint only when validation accuracy strictly improves ---
    if not (val_acc > best_acc):
        continue
    best_acc = val_acc
    checkpoint_state = dict(
        epoch=epoch,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        best_acc=best_acc,
    )
    save_checkpoint(checkpoint_state)
    print(f"New best model saved with accuracy {best_acc:.2f}%")

# --- Load best model ---
load_checkpoint(model, optimizer, "/content/drive/MyDrive/csc490/code_and_datasets/checkpoints/no_pretraining_madry_style.pth")

# --- Plot results ---
plot_loss(train_loss_avg, test_loss_avg, len(train_loss_avg))
plot_accuracy(train_accuracy, test_accuracy, len(train_accuracy))

# --- Confusion matrix, F1, ROC ---
# Re-run validation with the restored checkpoint. The true/pred/probs left
# over from the training loop belong to the final epoch, which is not
# necessarily the best model that was just loaded, and they do not exist at
# all if the loop never ran.
true, pred, probs, _, _ = test(num_epochs, model, valid_loader, criterion)

# Length checks instead of truthiness: array-like results raise on bool().
if len(true) and len(pred) and len(probs):
    print_confusion_matrix(true, pred)
    f1 = f1_score(true, pred)
    print(f"F1 Score: {f1:.4f}")
    fpr, tpr, thresholds = roc_curve(true, probs)
    roc_auc = auc(fpr, tpr)
    print(f"AUC: {roc_auc:.4f}")
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:0.4f})')
    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc="lower right")
    plt.show()
else:
    print("Skipping confusion matrix, F1 score, and ROC plot due to incomplete validation data. Please run the training loop to completion.")
