In [1]:
!pip install webdataset



In [2]:
import kagglehub
import os
import random
import io
import copy
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torchvision.models import ConvNeXt_Tiny_Weights
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.metrics import accuracy_score
from PIL import Image
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import IterableDataset, ConcatDataset
import random
from huggingface_hub import get_token, hf_hub_url, HfFileSystem
import webdataset as wds
from collections import defaultdict
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

In [3]:
# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
def openImgaes(images, transform):
    pil_images = []
    for img_bytes in images:
        if isinstance(img_bytes, bytes):
            pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            pil_images.append(pil_img)
        else:
            pil_images.append(img_bytes)

    # Apply provided transform
    input_tensors = torch.stack([transform(img) for img in pil_images])
    return input_tensors


def labels_to_tensor(labels):
    label_to_idx = {
      "fake": 0,
       "real": 1
    }
    # labels is iterable of strings
    numeric_labels = [label_to_idx[l] for l in labels]
    return torch.tensor(numeric_labels, dtype=torch.long)

def evaluate_test_set(model, test_loader, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    acc = accuracy_score(all_labels, all_preds)
    print(f"Test set accuracy: {acc*100:.2f}%")

def unfreeze_model(model, unfreeze_from_layer=6):
    # ResNet layers: layer1, layer2, layer3, layer4
    # unfreeze_from_layer: number between 1 and 4 to unfreeze from that layer onwards
    layers = [model.layer1, model.layer2, model.layer3, model.layer4]

    for param in model.parameters():
        param.requires_grad = False  # Freeze all first

    for param in model.fc.parameters():
        param.requires_grad = True  # Always train final fc layer

    # Unfreeze from specified layer onwards
    for i in range(unfreeze_from_layer - 1, len(layers)):
        for param in layers[i].parameters():
            param.requires_grad = True

    print(f"Unfroze layers from layer{unfreeze_from_layer} onwards")

def predict_batch(model, images, device, class_names):
    model.eval()
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)
        confidences, preds = torch.max(probs, 1)
    pred_labels = [class_names[i] for i in preds.cpu().numpy()]
    confs = confidences.cpu().numpy()
    return pred_labels, confs


from torch.utils.data import DataLoader

def get_label_from_key(key_str):

    if 'fake' in key_str:
        return 'fake'
    elif 'real' in key_str:
        return 'real'
    else:
        return 'unknown'

def preprocess(sample):
    image = sample.get('png') or sample.get('jpg') or sample.get('tiff')
    key = sample.get('__key__', '')
    label = get_label_from_key(key)
    return image, label

In [5]:
# Replace "your_own_huggingface_token" with your actual Hugging Face access token
# Get one at: https://huggingface.co/settings/tokens
myTtoken = "your_own_huggingface_token"

os.environ["HF_TOKEN"] = myTtoken   
splits = {
    "train_fake": "**/fake_train/*.tar.gz",
    "train_real": "**/real_train/*.tar.gz",
    "test_fake":  "**/fake_test/*.tar.gz",
    "test_real":  "**/real_test/*.tar.gz"
}

def get_urls(split_pattern):
    fs = HfFileSystem()
    files = [fs.resolve_path(path) for path in fs.glob("hf://datasets/xingjunm/WildDeepfake/" + split_pattern)]
    return [hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset") for file in files]

def make_ds(urls):
    urls_pipe = f"pipe: curl -s -L -H 'Authorization: Bearer {myTtoken}' {'::'.join(urls)}"
    return wds.WebDataset(urls_pipe, shardshuffle=False).decode()


train_fake_urls = get_urls(splits["train_fake"])
train_real_urls = get_urls(splits["train_real"])
test_fake_urls  = get_urls(splits["test_fake"])
test_real_urls  = get_urls(splits["test_real"])

random.seed(42)  # For reproducibility
random.shuffle(train_fake_urls)
random.shuffle(train_real_urls)
random.shuffle(test_fake_urls)
random.shuffle(test_real_urls)

print(len(train_fake_urls), len(train_real_urls), len(test_fake_urls), len(test_real_urls), sep='\n')

train_fake_urls, val_fake_urls = train_fake_urls[:int(len(train_fake_urls)*0.8)], train_fake_urls[int(len(train_fake_urls)*0.8):]
train_real_urls, val_real_urls = train_real_urls[:int(len(train_real_urls)*0.8)], train_real_urls[int(len(train_real_urls)*0.8):]


max_samples_training =2000

max_smaples_test= 500

max_smaples_val= 500

fake_train = []
max_samples_per_url = 4
for url in train_fake_urls:
  if len(fake_train) >= max_samples_training:
    break
  train_fake =make_ds([url])
  train_fake_processed = list(train_fake.map(preprocess))
  random.shuffle(train_fake_processed)
  for image,label in train_fake_processed:
    if len(fake_train) >= max_samples_training:
      break
    fake_train.append((image,label))

# fake_train, fake_val = fake_data_set[:int(len(fake_data_set)*0.8)], fake_data_set[int(len(fake_data_set)*0.8):]
# random.shuffle(fake_data_set)
random.seed(100)
random.shuffle(fake_train)

real_train = []
max_samples_per_url = 7
for url in train_real_urls:
  if len(real_train) >= max_samples_training:
    break
  train_real =make_ds([url])
  train_real_processed = list(train_real.map(preprocess))
  random.shuffle(train_real_processed)
  for image,label in train_real_processed:
    if len(real_train) >= max_samples_training:
      break
    real_train.append((image,label))

# real_train, real_val = real_data_set[:int(len(real_data_set)*0.8)], real_data_set[int(len(real_data_set)*0.8):]
random.seed(5)
random.shuffle(real_train)
# random.seed(87)
# random.shuffle(real_val)

test_fake_data_set = []
max_samples_per_url = 4
for url in test_fake_urls:
  if len(test_fake_data_set) >= max_smaples_test:
    break
  test_fake =make_ds([url])
  test_fake_processed = list(test_fake.map(preprocess))
  random.shuffle(test_fake_processed)
  for image,label in test_fake_processed:
    if len(test_fake_data_set) >= max_smaples_test:
      break
    test_fake_data_set.append((image,label))
random.seed(13)
random.shuffle(test_fake_data_set)

test_real_data_set = []
max_samples_per_url = 13
for url in test_real_urls:
  if len(test_real_data_set) >= max_smaples_test:
    break
  test_real =make_ds([url])
  test_real_processed = list(test_real.map(preprocess))
  random.shuffle(test_real_processed)
  for image,label in test_real_processed:
    if len(test_real_data_set) >= max_smaples_test:
      break
    test_real_data_set.append((image,label))

random.seed(67)
random.shuffle(test_real_data_set)

fake_val = []
max_samples_per_url = 4
for url in val_fake_urls:
  if len(fake_val) >= max_smaples_val:
    break
  val_fake =make_ds([url])
  val_fake_processed = list(val_fake.map(preprocess))
  random.shuffle(val_fake_processed)
  for image,label in val_fake_processed:
    if len(fake_val) >= max_smaples_val:
      break
    fake_val.append((image,label))

random.seed(100)
random.shuffle(fake_val)

real_val = []
max_samples_per_url = 7
for url in val_real_urls:
  if len(real_val) >= max_smaples_val:
    break
  val_real =make_ds([url])
  val_real_processed = list(val_real.map(preprocess))
  random.shuffle(val_real_processed)
  for image,label in val_real_processed:
    if len(real_val) >= max_smaples_val:
      break
    real_val.append((image,label))


random.seed(87)
random.shuffle(real_val)

print(len(fake_train),len(real_train))
print(len(test_real_data_set),len(test_fake_data_set))
print(len(fake_val),len(real_val))

train_set=fake_train + real_train
test_set= test_real_data_set + test_fake_data_set
val_set = fake_val + real_val

random.seed(58)
random.shuffle(train_set)
random.seed(1)
random.shuffle(test_set)
random.seed(99)
random.shuffle(val_set)

592
371
115
42
2000 2000
500 500
500 500


In [6]:
# Load pretrained ConvNeXt Tiny and transforms
weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1
train_transform = weights.transforms()
val_transform = weights.transforms()

model = models.convnext_tiny(weights=weights)

# Freeze all params initially
for param in model.parameters():
    param.requires_grad = False

num_classes = 2  
in_features = model.classifier[2].in_features
model.classifier[2] = nn.Linear(in_features, num_classes)

for param in model.classifier.parameters():
    param.requires_grad = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()

print(model)

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [7]:
def mixup_data(x, y, alpha=0.4):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_progressive(model, train_dataset, val_dataset, device='cuda',
                      epochs_fc=5, epochs_stage=5, epochs_all=5):

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    def run_epoch(train_loader, val_loader, use_mixup=False):
        model.train()
        total_loss = 0
        for inputs, labels in train_loader:
            inputs = openImgaes(inputs, train_transform).to(device)
            labels_tensor = labels_to_tensor(labels).to(device)

            optimizer.zero_grad()

            if use_mixup:
                inputs, y_a, y_b, lam = mixup_data(inputs, labels_tensor)
                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels_tensor)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Validation
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = openImgaes(inputs, val_transform).to(device)
                labels_tensor = labels_to_tensor(labels).to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels_tensor.cpu().numpy())

        acc = accuracy_score(all_labels, all_preds)
        return avg_loss, acc
        
    optimizer = None
    def phase_training(epochs, description, lr, use_mixup, unfreeze_stages):
        nonlocal optimizer
        # Unfreeze given stages (ConvNeXt has features[0..4] stages)
        for stage_idx in unfreeze_stages:
            for param in model.features[stage_idx].parameters():
                param.requires_grad = True

        print(f"\n[{description}]")
        best_acc = 0.0
        best_model_wts = copy.deepcopy(model.state_dict())

        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr
        )
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        for epoch in range(epochs):
            loss, acc = run_epoch(train_loader, val_loader, use_mixup=use_mixup)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {loss:.4f}, Val Acc: {acc*100:.2f}%")
            scheduler.step(acc)

            if acc > best_acc:
                best_acc = acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, f"best_model_{description.replace(' ', '_')}.pth")

        model.load_state_dict(best_model_wts)
        print(f"Restored best model for {description} with Val Acc: {best_acc*100:.2f}%")

    # Phase 1: only FC (classifier)
    phase_training(epochs_fc, "Phase 1 - FC only", lr=1e-3, use_mixup=False, unfreeze_stages=[])

    # Phase 2: unfreeze last ConvNeXt stage = features[4]
    phase_training(epochs_stage, "Phase 2 - FC + Last Stage", lr=1e-4, use_mixup=True, unfreeze_stages=[4])

    # Phase 3: unfreeze all remaining stages (features[0..3])
    phase_training(epochs_all, "Phase 3 - All layers", lr=1e-5, use_mixup=True, unfreeze_stages=[0,1,2,3])

    return model

# Run training
model = train_progressive(model, train_set, val_set, device=device,
                          epochs_fc=15, epochs_stage=15, epochs_all=15)



[Phase 1 - FC only]
Epoch 1/15 - Loss: 0.2178, Val Acc: 70.00%
Epoch 2/15 - Loss: 0.0615, Val Acc: 70.40%
Epoch 3/15 - Loss: 0.0350, Val Acc: 72.10%
Epoch 4/15 - Loss: 0.0239, Val Acc: 73.20%
Epoch 5/15 - Loss: 0.0172, Val Acc: 72.00%
Epoch 6/15 - Loss: 0.0136, Val Acc: 73.30%
Epoch 7/15 - Loss: 0.0107, Val Acc: 72.10%
Epoch 8/15 - Loss: 0.0087, Val Acc: 73.50%
Epoch 9/15 - Loss: 0.0075, Val Acc: 73.30%
Epoch 10/15 - Loss: 0.0063, Val Acc: 74.10%
Epoch 11/15 - Loss: 0.0059, Val Acc: 73.90%
Epoch 12/15 - Loss: 0.0043, Val Acc: 73.60%
Epoch 13/15 - Loss: 0.0039, Val Acc: 71.10%
Epoch 14/15 - Loss: 0.0038, Val Acc: 74.70%
Epoch 15/15 - Loss: 0.0033, Val Acc: 73.00%
Restored best model for Phase 1 - FC only with Val Acc: 74.70%

[Phase 2 - FC + Last Stage]
Epoch 1/15 - Loss: 0.2668, Val Acc: 80.10%
Epoch 2/15 - Loss: 0.2736, Val Acc: 80.20%
Epoch 3/15 - Loss: 0.2355, Val Acc: 82.00%
Epoch 4/15 - Loss: 0.2538, Val Acc: 87.20%
Epoch 5/15 - Loss: 0.2125, Val Acc: 86.80%
Epoch 6/15 - Loss: 0.

In [9]:
model.eval()
all_preds = []
all_labels = []
val_loader = DataLoader(test_set, batch_size=265)
with torch.no_grad():
  for inputs, labels in val_loader:
    newInputs = openImgaes(inputs,val_transform)
    newInputs = newInputs.to(device)
    labels_tensor = labels_to_tensor(labels).to(device)
    outputs = model(newInputs)
    _, preds = torch.max(outputs, 1)
    all_preds.extend(preds.cpu().numpy())
    all_labels.extend(labels_tensor.cpu().numpy())

acc = accuracy_score(all_labels, all_preds)
print(f"test Accuracy: {acc*100:.2f}%\n")

test Accuracy: 94.50%

